# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# copied from fairseq/fairseq/data/indexed_dataset.py
# Removed IndexedRawTextDataset since it relied on Fairseq dictionary
# other slight modifications to remove fairseq dependencies
# Added document index to index file and made it accessible.
# An empty sentence no longer separates documents.

from functools import lru_cache
import os
import shutil
import struct
from itertools import accumulate

import numpy as np
import torch

from megatron import print_rank_0


def __best_fitting_dtype(vocab_size=None):
    if vocab_size is not None and vocab_size < 65500:
        return np.uint16
    else:
        return np.int32


def get_available_dataset_impl():
    return ['lazy', 'cached', 'mmap']


def infer_dataset_impl(path):
    if IndexedDataset.exists(path):
        with open(index_file_path(path), 'rb') as f:
            magic = f.read(8)
            if magic == IndexedDataset._HDR_MAGIC:
                return 'cached'
            elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
                return 'mmap'
            else:
                return None
    else:
        print(f"Dataset does not exist: {path}")
        print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
        return None


def make_builder(out_file, impl, vocab_size=None):
    if impl == 'mmap':
        return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
    else:
        return IndexedDatasetBuilder(out_file)


def make_dataset(path, impl, skip_warmup=False):
    if not IndexedDataset.exists(path):
        print(f"Dataset does not exist: {path}")
        print("Path should be a basename that both .idx and .bin can be appended to get full filenames.")
        return None
    if impl == 'infer':
        impl = infer_dataset_impl(path)
    if impl == 'lazy' and IndexedDataset.exists(path):
        return IndexedDataset(path)
    elif impl == 'cached' and IndexedDataset.exists(path):
        return IndexedCachedDataset(path)
    elif impl == 'mmap' and MMapIndexedDataset.exists(path):
        return MMapIndexedDataset(path, skip_warmup)

    print(f"Unknown dataset implementation: {impl}")
    return None
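

# A minimal usage sketch (the '/tmp/my_dataset' prefix and the token ids below are
# hypothetical, for illustration only): write a dataset with a builder, then read
# it back with make_dataset().
#
#     builder = make_builder(data_file_path('/tmp/my_dataset'), impl='mmap',
#                            vocab_size=30000)
#     builder.add_item(torch.tensor([10, 11, 12]))     # one sentence of token ids
#     builder.end_document()                           # mark a document boundary
#     builder.finalize(index_file_path('/tmp/my_dataset'))
#
#     dataset = make_dataset('/tmp/my_dataset', impl='mmap')
#     tokens = dataset[0]                              # numpy array for the first sentence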


def dataset_exists(path, impl):
    if impl == 'mmap':
        return MMapIndexedDataset.exists(path)
    else:
        return IndexedDataset.exists(path)


def read_longs(f, n):
    a = np.empty(n, dtype=np.int64)
    f.readinto(a)
    return a


def write_longs(f, a):
    f.write(np.array(a, dtype=np.int64))


dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float,
    7: np.double,
    8: np.uint16
}


def code(dtype):
    for k in dtypes.keys():
        if dtypes[k] == dtype:
            return k
    raise ValueError(dtype)


def index_file_path(prefix_path):
    return prefix_path + '.idx'


def data_file_path(prefix_path):
    return prefix_path + '.bin'


def create_doc_idx(sizes):
    doc_idx = [0]
    for i, s in enumerate(sizes):
        if s == 0:
            doc_idx.append(i + 1)
    return doc_idx
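

# On-disk layout of the legacy ('lazy'/'cached') index file, as written by
# IndexedDatasetBuilder.finalize() and read by IndexedDataset.read_index()
# (all header fields are little-endian):
#   magic             8 bytes   b'TNTIDX\x00\x00'
#   version           uint64    currently 1
#   dtype code        uint64    key into `dtypes` above
#   element size      uint64    bytes per element
#   number of items   uint64    len(data_offsets) - 1
#   number of sizes   uint64    len(sizes)
#   doc count         uint64    len(doc_idx)
#   dim_offsets       int64[number of items + 1]
#   data_offsets      int64[number of items + 1]
#   sizes             int64[number of sizes]
#   doc_idx           int64[doc count]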


class IndexedDataset(torch.utils.data.Dataset):
    """Loader for IndexedDataset"""
    _HDR_MAGIC = b'TNTIDX\x00\x00'

    def __init__(self, path):
        super().__init__()
        self.path = path
        self.data_file = None
        self.read_index(path)

    def read_index(self, path):
        with open(index_file_path(path), 'rb') as f:
            magic = f.read(8)
            assert magic == self._HDR_MAGIC, (
                'Index file doesn\'t match expected format. '
                'Make sure that --dataset-impl is configured properly.'
            )
            version = f.read(8)
            assert struct.unpack('<Q', version) == (1,)
            code, self.element_size = struct.unpack('<QQ', f.read(16))
            self.dtype = dtypes[code]
            self._len, self.s = struct.unpack('<QQ', f.read(16))
            self.doc_count = struct.unpack('<Q', f.read(8))
            self.dim_offsets = read_longs(f, self._len + 1)
            self.data_offsets = read_longs(f, self._len + 1)
            self.sizes = read_longs(f, self.s)
            self.doc_idx = read_longs(f, self.doc_count)

    def read_data(self, path):
        self.data_file = open(data_file_path(path), 'rb', buffering=0)

    def check_index(self, i):
        if i < 0 or i >= self._len:
            raise IndexError('index out of range')

    def __del__(self):
        if self.data_file:
            self.data_file.close()

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if not self.data_file:
            self.read_data(self.path)
        if isinstance(idx, int):
            i = idx
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            return a
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            sizes = self.sizes[self.dim_offsets[start]:self.dim_offsets[stop]]
            size = sum(sizes)
            a = np.empty(size, dtype=self.dtype)
            self.data_file.seek(self.data_offsets[start] * self.element_size)
            self.data_file.readinto(a)
            offsets = list(accumulate(sizes))
            sents = np.split(a, offsets[:-1])
            return sents

    def __len__(self):
        return self._len

    def num_tokens(self, index):
        return self.sizes[index]

    def size(self, index):
        return self.sizes[index]

    @staticmethod
    def exists(path):
        return (
            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
        )

    @property
    def supports_prefetch(self):
        return False  # avoid prefetching to save memory


class IndexedCachedDataset(IndexedDataset):

    def __init__(self, path):
        super().__init__(path)
        self.cache = None
        self.cache_index = {}

    @property
    def supports_prefetch(self):
        return True

    def prefetch(self, indices):
        if all(i in self.cache_index for i in indices):
            return
        if not self.data_file:
            self.read_data(self.path)
        indices = sorted(set(indices))
        total_size = 0
        for i in indices:
            total_size += self.data_offsets[i + 1] - self.data_offsets[i]
        self.cache = np.empty(total_size, dtype=self.dtype)
        ptx = 0
        self.cache_index.clear()
        for i in indices:
            self.cache_index[i] = ptx
            size = self.data_offsets[i + 1] - self.data_offsets[i]
            a = self.cache[ptx: ptx + size]
            self.data_file.seek(self.data_offsets[i] * self.element_size)
            self.data_file.readinto(a)
            ptx += size
        if self.data_file:
            # close and delete data file after prefetch so we can pickle
            self.data_file.close()
            self.data_file = None

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            i = idx
            self.check_index(i)
            tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
            a = np.empty(tensor_size, dtype=self.dtype)
            ptx = self.cache_index[i]
            np.copyto(a, self.cache[ptx: ptx + a.size])
            return a
        elif isinstance(idx, slice):
            # Hack just to make this work, can optimize later if necessary
            sents = []
            for i in range(*idx.indices(len(self))):
                sents.append(self[i])
            return sents


class IndexedDatasetBuilder(object):
    element_sizes = {
        np.uint8: 1,
        np.int8: 1,
        np.int16: 2,
        np.int32: 4,
        np.int64: 8,
        np.float: 4,
        np.double: 8
    }

    def __init__(self, out_file, dtype=np.int32):
        self.out_file = open(out_file, 'wb')
        self.dtype = dtype
        self.data_offsets = [0]
        self.dim_offsets = [0]
        self.sizes = []
        self.element_size = self.element_sizes[self.dtype]
        self.doc_idx = [0]

    def add_item(self, tensor):
        bytes = self.out_file.write(np.array(tensor.numpy(), dtype=self.dtype))
        self.data_offsets.append(self.data_offsets[-1] + bytes / self.element_size)
        for s in tensor.size():
            self.sizes.append(s)
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def end_document(self):
        self.doc_idx.append(len(self.sizes))

    def merge_file_(self, another_file):
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype

        begin = self.data_offsets[-1]
        for offset in index.data_offsets[1:]:
            self.data_offsets.append(begin + offset)
        self.sizes.extend(index.sizes)
        begin = self.dim_offsets[-1]
        for dim_offset in index.dim_offsets[1:]:
            self.dim_offsets.append(begin + dim_offset)

        with open(data_file_path(another_file), 'rb') as f:
            while True:
                data = f.read(1024)
                if data:
                    self.out_file.write(data)
                else:
                    break

    def finalize(self, index_file):
        self.out_file.close()
        index = open(index_file, 'wb')
        index.write(b'TNTIDX\x00\x00')
        index.write(struct.pack('<Q', 1))
        index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
        index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
        index.write(struct.pack('<Q', len(self.doc_idx)))
        write_longs(index, self.dim_offsets)
        write_longs(index, self.data_offsets)
        write_longs(index, self.sizes)
        write_longs(index, self.doc_idx)
        index.close()


def _warmup_mmap_file(path):
    with open(path, 'rb') as stream:
        while stream.read(100 * 1024 * 1024):
            pass
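

# On-disk layout of the mmap index (.idx) file, as written by
# MMapIndexedDataset.Index.writer() and read by MMapIndexedDataset.Index.__init__()
# (header fields are little-endian):
#   magic        9 bytes   b'MMIDIDX\x00\x00'
#   version      uint64    currently 1
#   dtype code   uint8     key into `dtypes` above
#   len          uint64    number of sizes/pointers entries
#   doc count    uint64    number of doc_idx entries
#   sizes        int32[len]        length of each sentence, in tokens
#   pointers     int64[len]        byte offset of each sentence in the .bin file
#   doc_idx      int64[doc count]  sentence index at which each document starts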


class MMapIndexedDataset(torch.utils.data.Dataset):
    class Index(object):
        _HDR_MAGIC = b'MMIDIDX\x00\x00'

        @classmethod
        def writer(cls, path, dtype):
            class _Writer(object):
                def __enter__(self):
                    self._file = open(path, 'wb')

                    self._file.write(cls._HDR_MAGIC)
                    self._file.write(struct.pack('<Q', 1))
                    self._file.write(struct.pack('<B', code(dtype)))

                    return self

                @staticmethod
                def _get_pointers(sizes):
                    dtype_size = dtype().itemsize
                    address = 0
                    pointers = []

                    for size in sizes:
                        pointers.append(address)
                        address += size * dtype_size

                    return pointers

                def write(self, sizes, doc_idx):
                    pointers = self._get_pointers(sizes)

                    self._file.write(struct.pack('<Q', len(sizes)))
                    self._file.write(struct.pack('<Q', len(doc_idx)))

                    sizes = np.array(sizes, dtype=np.int32)
                    self._file.write(sizes.tobytes(order='C'))
                    del sizes

                    pointers = np.array(pointers, dtype=np.int64)
                    self._file.write(pointers.tobytes(order='C'))
                    del pointers

                    doc_idx = np.array(doc_idx, dtype=np.int64)
                    self._file.write(doc_idx.tobytes(order='C'))

                def __exit__(self, exc_type, exc_val, exc_tb):
                    self._file.close()

            return _Writer()

        def __init__(self, path, skip_warmup=False):
            with open(path, 'rb') as stream:
                magic_test = stream.read(9)
                assert self._HDR_MAGIC == magic_test, (
                    'Index file doesn\'t match expected format. '
                    'Make sure that --dataset-impl is configured properly.'
                )
                version = struct.unpack('<Q', stream.read(8))
                assert (1,) == version

                dtype_code, = struct.unpack('<B', stream.read(1))
                self._dtype = dtypes[dtype_code]
                self._dtype_size = self._dtype().itemsize

                self._len = struct.unpack('<Q', stream.read(8))[0]
                self._doc_count = struct.unpack('<Q', stream.read(8))[0]
                offset = stream.tell()

            if not skip_warmup:
                print_rank_0(" warming up index mmap file...")
                _warmup_mmap_file(path)

            self._bin_buffer_mmap = np.memmap(path, mode='r', order='C')
            self._bin_buffer = memoryview(self._bin_buffer_mmap)
            print_rank_0(" reading sizes...")
            self._sizes = np.frombuffer(
                self._bin_buffer,
                dtype=np.int32,
                count=self._len,
                offset=offset)
            print_rank_0(" reading pointers...")
            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
                                           offset=offset + self._sizes.nbytes)
            print_rank_0(" reading document index...")
            self._doc_idx = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._doc_count,
                                          offset=offset + self._sizes.nbytes + self._pointers.nbytes)

        def __del__(self):
            self._bin_buffer_mmap._mmap.close()
            del self._bin_buffer_mmap

        @property
        def dtype(self):
            return self._dtype

        @property
        def sizes(self):
            return self._sizes

        @property
        def doc_idx(self):
            return self._doc_idx

        @lru_cache(maxsize=8)
        def __getitem__(self, i):
            return self._pointers[i], self._sizes[i]

        def __len__(self):
            return self._len

    def __init__(self, path, skip_warmup=False):
        super().__init__()

        self._path = None
        self._index = None
        self._bin_buffer = None

        self._do_init(path, skip_warmup)

    def __getstate__(self):
        return self._path

    def __setstate__(self, state):
        # state is the path saved by __getstate__; skip the warmup pass when
        # re-initializing from a pickled dataset.
        self._do_init(state, skip_warmup=True)

    def _do_init(self, path, skip_warmup):
        self._path = path
        self._index = self.Index(index_file_path(self._path), skip_warmup)

        if not skip_warmup:
            print_rank_0(" warming up data mmap file...")
            _warmup_mmap_file(data_file_path(self._path))
        print_rank_0(" creating numpy buffer of mmap...")
        self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode='r', order='C')
        print_rank_0(" creating memory view of numpy buffer...")
        self._bin_buffer = memoryview(self._bin_buffer_mmap)

    def __del__(self):
        self._bin_buffer_mmap._mmap.close()
        del self._bin_buffer_mmap
        del self._index

    def __len__(self):
        return len(self._index)

    # @lru_cache(maxsize=8)
    def __getitem__(self, idx):
        if isinstance(idx, int):
            ptr, size = self._index[idx]
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
                                     count=size, offset=ptr)
            return np_array
        elif isinstance(idx, slice):
            start, stop, step = idx.indices(len(self))
            if step != 1:
                raise ValueError("Slices into indexed_dataset must be contiguous")
            ptr = self._index._pointers[start]
            sizes = self._index._sizes[idx]
            offsets = list(accumulate(sizes))
            total_size = sum(sizes)
            np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
                                     count=total_size, offset=ptr)
            sents = np.split(np_array, offsets[:-1])
            return sents

    def get(self, idx, offset=0, length=None):
        """ Retrieves a single item from the dataset with the option to only
        return a portion of the item.

        get(idx) is the same as [idx] but get() does not support slicing.
        """
        ptr, size = self._index[idx]
        if length is None:
            length = size - offset
        ptr += offset * np.dtype(self._index.dtype).itemsize
        np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype,
                                 count=length, offset=ptr)
        return np_array

    @property
    def sizes(self):
        return self._index.sizes

    @property
    def doc_idx(self):
        return self._index.doc_idx

    def get_doc_idx(self):
        return self._index._doc_idx

    def set_doc_idx(self, doc_idx_):
        self._index._doc_idx = doc_idx_

    @property
    def supports_prefetch(self):
        return False

    @staticmethod
    def exists(path):
        return (
            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
        )


class MMapIndexedDatasetBuilder(object):
    def __init__(self, out_file, dtype=np.int64):
        self._data_file = open(out_file, 'wb')
        self._dtype = dtype
        self._sizes = []
        self._doc_idx = [0]

    def add_item(self, tensor):
        np_array = np.array(tensor.numpy(), dtype=self._dtype)
        self._data_file.write(np_array.tobytes(order='C'))
        self._sizes.append(np_array.size)

    def end_document(self):
        self._doc_idx.append(len(self._sizes))

    def merge_file_(self, another_file):
        # Concatenate index
        index = MMapIndexedDataset.Index(index_file_path(another_file))
        assert index.dtype == self._dtype

        for size in index.sizes:
            self._sizes.append(size)

        # Concatenate data
        with open(data_file_path(another_file), 'rb') as f:
            shutil.copyfileobj(f, self._data_file)

    def finalize(self, index_file):
        self._data_file.close()

        with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
            index.write(self._sizes, self._doc_idx)
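

# Sketch of direct reads from an existing MMapIndexedDataset (the 'my_dataset'
# prefix is hypothetical): the document index maps document d to the range of
# sentence indices [doc_idx[d], doc_idx[d + 1]).
#
#     dataset = MMapIndexedDataset('my_dataset')
#     d = 0
#     start, end = dataset.doc_idx[d], dataset.doc_idx[d + 1]
#     for i in range(start, end):
#         sentence = dataset[i]                        # numpy array of token ids
#     head = dataset.get(start, offset=0, length=5)    # first 5 tokens of one sentence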