indexer.py

import sys
import time
import torch
import torch.distributed as dist

from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.checkpointing import load_biencoder_checkpoint
from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
from megatron.data.realm_index import detach, OpenRetreivalDataStore
from megatron.model.biencoder_model import get_model_provider
from megatron.training import get_model
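

# IndexBuilder takes one pass over the open-retrieval wiki dataset, embeds
# each context block with the biencoder's context encoder, and writes the
# embeddings to an OpenRetreivalDataStore shard (one shard per data-parallel
# rank, merged by the rank 0 process at the end).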
class IndexBuilder(object):
    """
    Object for taking one pass over a dataset and creating a BlockData of its
    embeddings
    """
    def __init__(self):
        args = get_args()
        self.model = None
        self.dataloader = None
        self.evidence_embedder_obj = None
        self.biencoder_shared_query_context_model = \
            args.biencoder_shared_query_context_model

        # need to know whether we're using a REALM checkpoint (args.load)
        # or ICT checkpoint
        assert not (args.load and args.ict_load)

        self.log_interval = args.indexer_log_interval
        self.batch_size = args.indexer_batch_size

        self.load_attributes()
        self.is_main_builder = mpu.get_data_parallel_rank() == 0
        self.num_total_builders = mpu.get_data_parallel_world_size()
        self.iteration = self.total_processed = 0

    def load_attributes(self):
        """
        Load the necessary attributes: model, dataloader and empty BlockData
        """
        only_context_model = True
        if self.biencoder_shared_query_context_model:
            only_context_model = False

        model = get_model(get_model_provider(
            only_context_model=only_context_model,
            biencoder_shared_query_context_model=\
                self.biencoder_shared_query_context_model))

        self.model = load_biencoder_checkpoint(
            model, only_context_model=only_context_model)

        assert len(self.model) == 1
        self.model[0].eval()

        self.dataset = get_open_retrieval_wiki_dataset()
        self.dataloader = iter(get_one_epoch_dataloader(self.dataset,
                                                        self.batch_size))

        self.evidence_embedder_obj = OpenRetreivalDataStore(
            load_from_path=False)

    def track_and_report_progress(self, batch_size):
        """
        Utility function for tracking progress
        """
        self.iteration += 1
        self.total_processed += batch_size * self.num_total_builders
        if self.is_main_builder and self.iteration % self.log_interval == 0:
            print('Batch {:10d} | Total {:10d}'.format(self.iteration,
                self.total_processed), flush=True)

    def build_and_save_index(self):
        """
        Goes through one epoch of the dataloader and adds all data to this
        instance's BlockData.

        The copy of BlockData is saved as a shard, which when run in a
        distributed setting will be consolidated by the rank 0 process
        and saved as a final pickled BlockData.
        """
        assert len(self.model) == 1
        unwrapped_model = self.model[0]
        while not hasattr(unwrapped_model, 'embed_text'):
            unwrapped_model = unwrapped_model.module

        while True:
            try:
                # batch also has query_tokens and query_pad_data
                row_id, context_tokens, context_mask, context_types, \
                    context_pad_mask = get_open_retrieval_batch(
                        self.dataloader)
            except (StopIteration, IndexError):
                break

            # TODO: can we add with torch.no_grad() to reduce memory usage
            # detach, separate fields and add to BlockData
            assert context_mask.dtype == torch.bool
            context_logits = unwrapped_model.embed_text(
                unwrapped_model.context_model, context_tokens, context_mask,
                context_types)

            context_logits = detach(context_logits)
            row_id = detach(row_id)

            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
            self.track_and_report_progress(batch_size=len(row_id))

        # This process signals to finalize its shard and then synchronize with
        # the other processes
        self.evidence_embedder_obj.save_shard()
        torch.distributed.barrier()
        del self.model

        # rank 0 process builds the final copy
        if self.is_main_builder:
            self.evidence_embedder_obj.merge_shards_and_save()
            # make sure that every single piece of data was embedded
            assert len(self.evidence_embedder_obj.embed_data) == \
                len(self.dataset)
        self.evidence_embedder_obj.clear()

        # complete building the final copy
        torch.distributed.barrier()
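

# Usage sketch: this class assumes Megatron has already been initialized by a
# driver script (args parsed, mpu and the distributed backend set up), e.g.
# via megatron.initialize.initialize_megatron(); the exact entry point and
# import path for IndexBuilder may differ between Megatron-LM versions.
#
#     from megatron.initialize import initialize_megatron
#     from megatron.indexer import IndexBuilder
#
#     initialize_megatron()
#     index_builder = IndexBuilder()
#     index_builder.build_and_save_index()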