# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for OpenAI GPT-2."""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import sys
import json
import logging
import os
import regex as re
from io import open

try:
    from functools import lru_cache
except ImportError:
    # Just a dummy decorator to get the checks to run on python2
    # because honestly I don't want to support a byte-level unicode BPE
    # tokenizer on python 2 right now.
    def lru_cache():
        return lambda func: func

logger = logging.getLogger(__name__)

PRETRAINED_VOCAB_ARCHIVE_MAP = {
    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
}
PRETRAINED_MERGES_ARCHIVE_MAP = {
    'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
}
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
    'gpt2': 1024,
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'


@lru_cache()
def bytes_to_unicode():
    """
    Returns a mapping from utf-8 bytes to corresponding unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    This also avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    _chr = unichr if sys.version_info[0] == 2 else chr
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
        list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [_chr(n) for n in cs]
    return dict(zip(bs, cs))
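

# Illustrative sketch (not part of the original module): the byte-to-unicode table
# maps printable bytes to themselves and shifts the remaining bytes to code points
# above 255, e.g. the space byte (0x20) becomes 'Ġ' (U+0120), which is why GPT-2
# BPE tokens mark a leading space with 'Ġ'.
#
#     byte_encoder = bytes_to_unicode()
#     assert byte_encoder[ord('a')] == 'a'        # printable byte maps to itself
#     assert byte_encoder[ord(' ')] == u'\u0120'  # space maps to 'Ġ'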


def get_pairs(word):
    """Return set of symbol pairs in a word.

    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
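

# Example (sketch): get_pairs(('h', 'e', 'l', 'l', 'o')) returns
# {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}; duplicate adjacent
# pairs collapse because the result is a set.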


class GPT2Tokenizer(object):
    """
    GPT-2 BPE tokenizer. Peculiarities:
        - Byte-level BPE
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a GPT2Tokenizer from a pre-trained model file.
        Download and cache the pre-trained model file if needed.
        """
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
            special_tokens_file = None
        else:
            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
            merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
            if not os.path.exists(special_tokens_file):
                special_tokens_file = None
            else:
                logger.info("loading special tokens file {}".format(special_tokens_file))
        # redirect to the cache, if necessary
        try:
            from .file_utils import cached_path
            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "
                "We assumed '{}' was a path or url but couldn't find files {} and {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path,
                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
                    pretrained_model_name_or_path,
                    vocab_file, merges_file))
            return None
        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
            logger.info("loading vocabulary file {}".format(vocab_file))
            logger.info("loading merges file {}".format(merges_file))
        else:
            logger.info("loading vocabulary file {} from cache at {}".format(
                vocab_file, resolved_vocab_file))
            logger.info("loading merges file {} from cache at {}".format(
                merges_file, resolved_merges_file))
        if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
            # if we're using a pretrained model, ensure the tokenizer won't index
            # sequences longer than the number of positional embeddings
            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
        # Instantiate tokenizer.
        if special_tokens_file and 'special_tokens' not in kwargs:
            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
        else:
            special_tokens = kwargs.pop('special_tokens', [])
        tokenizer = cls(
            resolved_vocab_file,
            resolved_merges_file,
            special_tokens=special_tokens,
            *inputs,
            **kwargs)
        return tokenizer
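
    # Typical usage (sketch; assumes network access or an already-populated cache,
    # or a local directory holding vocab.json / merges.txt, e.g. one produced by
    # save_vocabulary(); the local path below is purely illustrative):
    #
    #     tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    #     tokenizer = GPT2Tokenizer.from_pretrained('/path/to/saved/tokenizer')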

    def __init__(self, vocab_file, merges_file, errors='replace',
                 special_tokens=None, max_len=None):
        self.max_len = max_len if max_len is not None else int(1e12)
        self.encoder = json.load(open(vocab_file, encoding='utf-8'))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors  # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}

        # Should have added re.IGNORECASE so BPE merges can happen for
        # capitalized versions of contractions
        self.pat = re.compile(
            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

        self.special_tokens = {}
        self.special_tokens_decoder = {}
        self.set_special_tokens(special_tokens)

    def __len__(self):
        return len(self.encoder) + len(self.special_tokens)

    def set_special_tokens(self, special_tokens):
        """ Add a list of additional tokens to the encoder.
            The additional tokens are indexed starting from the last index of the
            current vocabulary in the order of the `special_tokens` list.
        """
        if not special_tokens:
            self.special_tokens = {}
            self.special_tokens_decoder = {}
            return
        self.special_tokens = dict((tok, len(self.encoder) + i)
                                   for i, tok in enumerate(special_tokens))
        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
        logger.info("Special tokens {}".format(self.special_tokens))

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur again; keep the rest of the word as-is
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word
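
    # Worked example (sketch; the actual merges depend on the loaded merges.txt):
    # starting from word = ('h', 'e', 'l', 'l', 'o'), each iteration picks the
    # adjacent pair with the lowest rank in self.bpe_ranks, e.g. ('l', 'l') -> 'll',
    # then perhaps ('e', 'll') -> 'ell', and so on until no ranked pair remains;
    # the surviving symbols are then joined with spaces, e.g. 'h ello' or 'hello'.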

    def tokenize(self, text):
        """ Tokenize a string. """
        bpe_tokens = []
        for token in re.findall(self.pat, text):
            if sys.version_info[0] == 2:
                token = ''.join(self.byte_encoder[ord(b)] for b in token)
            else:
                token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens
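
    # Example (sketch; exact pieces depend on the vocab/merges in use): with the
    # pretrained 'gpt2' files, tokenize("Hello world") typically yields something
    # like ['Hello', 'Ġworld'], where 'Ġ' is the byte-level encoding of the space.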

    def convert_tokens_to_ids(self, tokens):
        """ Converts a sequence of tokens into ids using the vocab. """
        ids = []
        if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            else:
                return self.encoder.get(tokens, 0)
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                ids.append(self.encoder.get(token, 0))
        if len(ids) > self.max_len:
            logger.warning(
                "Token indices sequence length is longer than the specified maximum "
                "sequence length for this OpenAI GPT-2 model ({} > {}). Running this"
                " sequence through the model will result in indexing errors".format(
                    len(ids), self.max_len)
            )
        return ids

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
        """Converts a sequence of ids into BPE tokens using the vocab."""
        tokens = []
        for i in ids:
            if i in self.special_tokens_decoder:
                if not skip_special_tokens:
                    tokens.append(self.special_tokens_decoder[i])
            else:
                tokens.append(self.decoder[i])
        return tokens

    def encode(self, text):
        return self.convert_tokens_to_ids(self.tokenize(text))

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
        return text
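
    # Round-trip sketch: for ordinary text, decode(encode(text)) recovers the input,
    # since encode() maps UTF-8 bytes through bytes_to_unicode() and decode() inverts
    # both steps via self.decoder and self.byte_decoder. Example (the ids shown are
    # from memory of the pretrained 'gpt2' vocab and may differ):
    #
    #     ids = tokenizer.encode("Hello world")   # e.g. [15496, 995]
    #     assert tokenizer.decode(ids) == "Hello world"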

    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary and merge files to a directory."""
        if not os.path.isdir(vocab_path):
            logger.error("Vocabulary path ({}) should be a directory".format(vocab_path))
            return
        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
        merge_file = os.path.join(vocab_path, MERGES_NAME)
        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)

        with open(vocab_file, 'w', encoding='utf-8') as f:
            f.write(json.dumps(self.encoder, ensure_ascii=False))

        index = 0
        with open(merge_file, "w", encoding="utf-8") as writer:
            writer.write(u'#version: 0.2\n')
            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
                    index = token_index
                writer.write(' '.join(bpe_tokens) + u'\n')
                index += 1

        index = len(self.encoder)
        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
            for token, token_index in sorted(self.special_tokens.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving special tokens vocabulary to {}: BPE indices are not consecutive."
                                   " Please check that the tokenizer is not corrupted!".format(special_tokens_file))
                    index = token_index
                writer.write(token + u'\n')
                index += 1

        return vocab_file, merge_file, special_tokens_file
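

# Persisting and reloading sketch (illustrative paths; save_vocabulary() requires
# the target directory to exist already):
#
#     vocab_file, merge_file, special_tokens_file = tokenizer.save_vocabulary('/tmp/gpt2-tok')
#     reloaded = GPT2Tokenizer.from_pretrained('/tmp/gpt2-tok')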