realm_index.py

import itertools
import os
import pickle
import shutil

import numpy as np
import torch

from megatron import get_args
from megatron import mpu


def detach(tensor):
    return tensor.detach().cpu().numpy()


class OpenRetreivalDataStore(object):
    """
    Serializable data structure for holding data for blocks --
    embeddings and necessary metadata for Retriever
    """
    def __init__(self, embedding_path=None, load_from_path=True, rank=None):
        self.embed_data = dict()
        if embedding_path is None:
            args = get_args()
            embedding_path = args.embedding_path
            rank = args.rank
        self.embedding_path = embedding_path
        self.rank = rank

        if load_from_path:
            self.load_from_file()

        block_data_name = os.path.splitext(self.embedding_path)[0]
        self.temp_dir_name = block_data_name + '_tmp'

    def state(self):
        return {
            'embed_data': self.embed_data,
        }
    def clear(self):
        """
        Clear the embedding data structures to save memory.
        The metadata ends up getting used, and is also much smaller in
        dimensionality so it isn't really worth clearing.
        """
        self.embed_data = dict()
    def load_from_file(self):
        """Populate members from instance saved to file"""

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print("\n> Unpickling BlockData", flush=True)
        with open(self.embedding_path, 'rb') as f:
            state_dict = pickle.load(f)
        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print(">> Finished unpickling BlockData\n", flush=True)

        self.embed_data = state_dict['embed_data']
    def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
        """
        Add data for set of blocks
        :param row_id: 1D array of unique int ids for the blocks
        :param block_embeds: 2D array of embeddings of the blocks
            In the case of retriever this will be [start_idx, end_idx, doc_idx]
        """
        for idx, embed in zip(row_id, block_embeds):
            if not allow_overwrite and idx in self.embed_data:
                raise ValueError("Unexpectedly tried to overwrite block data")

            self.embed_data[idx] = np.float16(embed)
    def save_shard(self):
        """
        Save the block data that was created in this process
        """
        if not os.path.isdir(self.temp_dir_name):
            os.makedirs(self.temp_dir_name, exist_ok=True)

        # save the data for each shard
        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') \
                as writer:
            pickle.dump(self.state(), writer)
    def merge_shards_and_save(self):
        # combine all the shards made using save_shard
        shard_names = os.listdir(self.temp_dir_name)
        seen_own_shard = False

        for fname in os.listdir(self.temp_dir_name):
            shard_rank = int(os.path.splitext(fname)[0])
            if shard_rank == self.rank:
                seen_own_shard = True
                continue

            with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
                data = pickle.load(f)
                old_size = len(self.embed_data)
                shard_size = len(data['embed_data'])

                # add the shard's data and check to make sure there
                # is no overlap
                self.embed_data.update(data['embed_data'])
                assert len(self.embed_data) == old_size + shard_size

        assert seen_own_shard

        # save the consolidated shards and remove temporary directory
        with open(self.embedding_path, 'wb') as final_file:
            pickle.dump(self.state(), final_file)
        shutil.rmtree(self.temp_dir_name, ignore_errors=True)

        print("Finished merging {} shards for a total of {} embeds".format(
            len(shard_names), len(self.embed_data)), flush=True)
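
# Typical multi-rank workflow (a minimal sketch, not part of the original file;
# the path, ids, and rank handling below are illustrative assumptions):
#
#     store = OpenRetreivalDataStore('block_embeds.pkl', load_from_path=False, rank=rank)
#     store.add_block_data(row_ids, block_embeds)   # each data-parallel rank adds its rows
#     store.save_shard()                            # writes block_embeds_tmp/<rank>.pkl
#     # after a barrier, a single rank consolidates everything into block_embeds.pkl:
#     store.merge_shards_and_save()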


class FaissMIPSIndex(object):
    """
    Wrapper object for a BlockData that does similarity search via FAISS under the hood
    """
    def __init__(self, embed_size, embed_data=None, use_gpu=False):
        self.embed_size = embed_size
        self.embed_data = embed_data
        self.use_gpu = use_gpu

        self.mips_index = None
        self._set_mips_index()
    def _set_mips_index(self):
        """
        Create a Faiss Flat index with inner product as the metric
        to search against
        """
        try:
            import faiss
        except ImportError:
            raise Exception("Error: Please install faiss to use FaissMIPSIndex")

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print("\n> Building index", flush=True)

        cpu_index = faiss.IndexFlatIP(self.embed_size)

        if self.use_gpu:
            # create resources and config for GpuIndex
            config = faiss.GpuMultipleClonerOptions()
            config.shard = True
            config.useFloat16 = True
            gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
            self.mips_index = faiss.IndexIDMap(gpu_index)
            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
                print(">> Initialized index on GPU", flush=True)
        else:
            # CPU index supports IDs so wrap with IDMap
            self.mips_index = faiss.IndexIDMap(cpu_index)
            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
                print(">> Initialized index on CPU", flush=True)

        # if we were constructed with a BlockData, then automatically load it
        # when the FAISS structure is built
        if self.embed_data is not None:
            self.add_embed_data(self.embed_data)
    def reset_index(self):
        """Delete the existing index and create a new one"""
        del self.mips_index

        # reset the block data so that _set_mips_index will reload it as well
        if self.embed_data is not None:
            embed_data_path = self.embed_data.embedding_path
            del self.embed_data
            self.embed_data = OpenRetreivalDataStore(embed_data_path)

        self._set_mips_index()
    def update_index(self):
        """Delete the existing index and create a new one"""
        del self.mips_index

        # reload the block data from file so that _set_mips_index will pick it up
        if self.embed_data is not None:
            self.embed_data.load_from_file()

        self._set_mips_index()
    def add_embed_data(self, all_embed_data):
        """Add the embedding of each block to the underlying FAISS index"""

        # this assumes the embed_data is a dict : {int: np.array<float>}
        block_indices, block_embeds = zip(*all_embed_data.embed_data.items())

        # the embeddings have to be entered in as float32 even though the math
        # internally is done with float16.
        embeds_arr = np.float32(np.array(block_embeds))
        indices_arr = np.array(block_indices)

        # we no longer need the embedding data since it's in the index now
        all_embed_data.clear()

        self.mips_index.add_with_ids(embeds_arr, indices_arr)

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print(">>> Finished adding block data to index", flush=True)
    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
        """
        Get the top-k blocks by the index distance metric.

        :param reconstruct: if True: return a [num_queries x k x embed_dim]
                                array of blocks
                            if False: return [num_queries x k] array of
                                distances, and another for indices
        """
        query_embeds = np.float32(detach(query_embeds))

        if reconstruct:
            # get the vectors themselves
            top_k_block_embeds = self.mips_index.search_and_reconstruct(
                query_embeds, top_k)
            return top_k_block_embeds
        else:
            # get distances and indices of closest vectors
            distances, block_indices = self.mips_index.search(query_embeds, top_k)
            return distances, block_indices
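

# Usage sketch (not part of the original module; a minimal example assuming faiss
# is installed and that embedding_path/rank are passed explicitly so Megatron's
# get_args() is never consulted; the path and sizes below are illustrative):
if __name__ == "__main__":
    embed_size = 128
    store = OpenRetreivalDataStore(embedding_path='block_embeds.pkl',
                                   load_from_path=False, rank=0)

    # fake block embeddings keyed by integer row ids
    row_ids = np.arange(1000)
    block_embeds = np.random.randn(1000, embed_size).astype(np.float32)
    store.add_block_data(row_ids, block_embeds)

    # build a CPU flat inner-product index over the stored embeddings
    index = FaissMIPSIndex(embed_size=embed_size, embed_data=store, use_gpu=False)

    # query with a batch of torch tensors; get back distances and block ids
    queries = torch.randn(4, embed_size)
    distances, block_indices = index.search_mips_index(queries, top_k=5,
                                                       reconstruct=False)
    print(distances.shape, block_indices.shape)  # (4, 5), (4, 5)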