import itertools
import os
import pickle
import shutil

import numpy as np
import torch

from megatron import get_args
from megatron import mpu


def detach(tensor):
    return tensor.detach().cpu().numpy()
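
# e.g., a quick sanity check of the conversion (shape is illustrative):
#   detach(torch.randn(2, 4))  # -> float32 np.ndarray of shape (2, 4)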


class OpenRetreivalDataStore(object):
    """
    Serializable data structure for holding data for blocks --
    embeddings and necessary metadata for Retriever
    """
    def __init__(self, embedding_path=None, load_from_path=True, rank=None):
        self.embed_data = dict()
        if embedding_path is None:
            args = get_args()
            embedding_path = args.embedding_path
            rank = args.rank
        self.embedding_path = embedding_path
        self.rank = rank

        if load_from_path:
            self.load_from_file()

        block_data_name = os.path.splitext(self.embedding_path)[0]
        self.temp_dir_name = block_data_name + '_tmp'

    def state(self):
        return {
            'embed_data': self.embed_data,
        }

    def clear(self):
        """
        Clear the embedding data structures to save memory.
        The metadata ends up getting used, and is also much smaller in
        dimensionality so it isn't really worth clearing.
        """
        self.embed_data = dict()

    def load_from_file(self):
        """Populate members from an instance saved to file"""
        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print("\n> Unpickling BlockData", flush=True)

        # use a context manager so the file handle is closed after loading
        with open(self.embedding_path, 'rb') as f:
            state_dict = pickle.load(f)

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print(">> Finished unpickling BlockData\n", flush=True)

        self.embed_data = state_dict['embed_data']

    def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
        """
        Add data for a set of blocks
        :param row_id: 1D array of unique int ids for the blocks
        :param block_embeds: 2D array of embeddings of the blocks
            In the case of retriever this will be [start_idx, end_idx, doc_idx]
        """
        for idx, embed in zip(row_id, block_embeds):
            if not allow_overwrite and idx in self.embed_data:
                raise ValueError("Unexpectedly tried to overwrite block data")

            # store as float16 to halve the memory footprint
            self.embed_data[idx] = np.float16(embed)
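
    # A minimal usage sketch (the path, ids, and shapes are illustrative):
    #   store = OpenRetreivalDataStore('embeds.pkl', load_from_path=False,
    #                                  rank=0)
    #   store.add_block_data(np.arange(4), np.random.rand(4, 128))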

    def save_shard(self):
        """
        Save the block data that was created in this process
        """
        if not os.path.isdir(self.temp_dir_name):
            os.makedirs(self.temp_dir_name, exist_ok=True)

        # save the data for each shard
        shard_path = '{}/{}.pkl'.format(self.temp_dir_name, self.rank)
        with open(shard_path, 'wb') as writer:
            pickle.dump(self.state(), writer)

    def merge_shards_and_save(self):
        """Combine all the shards made using save_shard"""
        shard_names = os.listdir(self.temp_dir_name)
        seen_own_shard = False

        for fname in shard_names:
            shard_rank = int(os.path.splitext(fname)[0])
            if shard_rank == self.rank:
                seen_own_shard = True
                continue

            with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
                data = pickle.load(f)
                old_size = len(self.embed_data)
                shard_size = len(data['embed_data'])

                # add the shard's data and check to make sure there
                # is no overlap
                self.embed_data.update(data['embed_data'])
                assert len(self.embed_data) == old_size + shard_size

        assert seen_own_shard

        # save the consolidated shards and remove the temporary directory
        with open(self.embedding_path, 'wb') as final_file:
            pickle.dump(self.state(), final_file)
        shutil.rmtree(self.temp_dir_name, ignore_errors=True)

        print("Finished merging {} shards for a total of {} embeds".format(
            len(shard_names), len(self.embed_data)), flush=True)
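

# A sketch of the intended multi-process workflow, assumed from the shard
# logic above rather than taken from a driver script in this file: every
# data parallel rank embeds its slice of the blocks and writes one shard,
# then a single rank merges all the shards into the final pickle.
#
#   store = OpenRetreivalDataStore('embeds.pkl', load_from_path=False,
#                                  rank=data_parallel_rank)
#   store.add_block_data(my_ids, my_embeds)
#   store.save_shard()                 # writes embeds_tmp/<rank>.pkl
#   # ... synchronize all ranks here ...
#   if data_parallel_rank == 0:
#       store.merge_shards_and_save()  # writes embeds.pkl, removes embeds_tmp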


class FaissMIPSIndex(object):
    """
    Wrapper object for a BlockData which does similarity search
    via FAISS under the hood
    """
    def __init__(self, embed_size, embed_data=None, use_gpu=False):
        self.embed_size = embed_size
        self.embed_data = embed_data
        self.use_gpu = use_gpu

        self.mips_index = None
        self._set_mips_index()

    def _set_mips_index(self):
        """
        Create a Faiss Flat index with inner product as the metric
        to search against
        """
        try:
            import faiss
        except ImportError:
            raise Exception(
                "Error: Please install faiss to use FaissMIPSIndex")

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print("\n> Building index", flush=True)

        cpu_index = faiss.IndexFlatIP(self.embed_size)

        if self.use_gpu:
            # create resources and config for GpuIndex
            config = faiss.GpuMultipleClonerOptions()
            config.shard = True
            config.useFloat16 = True

            gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
            self.mips_index = faiss.IndexIDMap(gpu_index)
            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
                print(">> Initialized index on GPU", flush=True)
        else:
            # CPU index supports IDs so wrap with IDMap
            self.mips_index = faiss.IndexIDMap(cpu_index)
            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
                print(">> Initialized index on CPU", flush=True)

        # if we were constructed with a BlockData, then automatically load it
        # when the FAISS structure is built
        if self.embed_data is not None:
            self.add_embed_data(self.embed_data)
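
    # Note: IndexFlatIP ranks by inner product, so the search below is exact
    # maximum inner product search (MIPS), and the IndexIDMap wrapper makes
    # search return the block ids registered via add_with_ids rather than
    # raw row positions.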

    def reset_index(self):
        """Delete the existing index and create a new one"""
        del self.mips_index

        # reset the block data so that _set_mips_index will reload it as well
        if self.embed_data is not None:
            embed_data_path = self.embed_data.embedding_path
            del self.embed_data
            self.embed_data = OpenRetreivalDataStore(embed_data_path)

        self._set_mips_index()

    def update_index(self):
        """Delete the existing index and create a new one from the embedding
        data reloaded from file"""
        del self.mips_index

        # reload the block data so that _set_mips_index picks up new embeddings
        if self.embed_data is not None:
            self.embed_data.load_from_file()

        self._set_mips_index()

    def add_embed_data(self, all_embed_data):
        """Add the embedding of each block to the underlying FAISS index"""
        # this assumes the embed_data is a dict : {int: np.array<float>}
        block_indices, block_embeds = zip(*all_embed_data.embed_data.items())

        # the embeddings have to be entered in as float32 even though the math
        # internally is done with float16.
        embeds_arr = np.float32(np.array(block_embeds))
        indices_arr = np.array(block_indices)

        # we no longer need the embedding data since it's in the index now
        all_embed_data.clear()

        self.mips_index.add_with_ids(embeds_arr, indices_arr)

        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
            print(">>> Finished adding block data to index", flush=True)

    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
        """
        Get the top-k blocks by the index distance metric.

        :param reconstruct: if True: return a [num_queries x k x embed_dim]
                            array of blocks
                            if False: return [num_queries x k] array of
                            distances, and another for indices
        """
        query_embeds = np.float32(detach(query_embeds))

        if reconstruct:
            # get the vectors themselves; FAISS returns distances, indices,
            # and reconstructed vectors, of which only the vectors are needed
            _, _, top_k_block_embeds = self.mips_index.search_and_reconstruct(
                query_embeds, top_k)
            return top_k_block_embeds
        else:
            # get distances and indices of closest vectors
            distances, block_indices = self.mips_index.search(
                query_embeds, top_k)
            return distances, block_indices
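

# A minimal end-to-end sketch (the path and embedding size are illustrative,
# assuming the sharded embeddings were already merged to disk):
#
#   embed_data = OpenRetreivalDataStore('embeds.pkl')   # loads the pickle
#   index = FaissMIPSIndex(embed_size=128, embed_data=embed_data)
#   distances, block_ids = index.search_mips_index(
#       query_tensor, top_k=5, reconstruct=False)       # each [n_queries x 5]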