tokenizers.py
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# The following code has been taken from
# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
# licensed as of now. More details on the license can be found
# at https://github.com/facebookresearch/DPR/blob/master/LICENSE

"""
Most of the tokenizers code here is copied from the DrQA codebase to avoid adding an extra dependency
"""

import copy
import logging

import regex
import spacy

logger = logging.getLogger(__name__)


class Tokens(object):
    """A class to represent a list of tokenized text."""
    TEXT = 0
    TEXT_WS = 1
    SPAN = 2
    POS = 3
    LEMMA = 4
    NER = 5

    def __init__(self, data, annotators, opts=None):
        self.data = data
        self.annotators = annotators
        self.opts = opts or {}

    def __len__(self):
        """The number of tokens."""
        return len(self.data)

    def slice(self, i=None, j=None):
        """Return a view of the list of tokens from [i, j)."""
        new_tokens = copy.copy(self)
        new_tokens.data = self.data[i: j]
        return new_tokens

    def untokenize(self):
        """Returns the original text (with whitespace reinserted)."""
        return ''.join([t[self.TEXT_WS] for t in self.data]).strip()

    def words(self, uncased=False):
        """Returns a list of the text of each token

        Args:
            uncased: lower cases text
        """
        if uncased:
            return [t[self.TEXT].lower() for t in self.data]
        else:
            return [t[self.TEXT] for t in self.data]

    def offsets(self):
        """Returns a list of [start, end) character offsets of each token."""
        return [t[self.SPAN] for t in self.data]

    def pos(self):
        """Returns a list of part-of-speech tags of each token.

        Returns None if this annotation was not included.
        """
        if 'pos' not in self.annotators:
            return None
        return [t[self.POS] for t in self.data]

    def lemmas(self):
        """Returns a list of the lemmatized text of each token.

        Returns None if this annotation was not included.
        """
        if 'lemma' not in self.annotators:
            return None
        return [t[self.LEMMA] for t in self.data]

    def entities(self):
        """Returns a list of named-entity-recognition tags of each token.

        Returns None if this annotation was not included.
        """
        if 'ner' not in self.annotators:
            return None
        return [t[self.NER] for t in self.data]

    def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
        """Returns a list of all ngrams from length 1 to n.

        Args:
            n: upper limit of ngram length
            uncased: lower cases text
            filter_fn: user function that takes in an ngram list and returns
              True or False to keep or not keep the ngram
            as_strings: return the ngram as a string vs list
        """

        def _skip(gram):
            if not filter_fn:
                return False
            return filter_fn(gram)

        words = self.words(uncased)
        ngrams = [(s, e + 1)
                  for s in range(len(words))
                  for e in range(s, min(s + n, len(words)))
                  if not _skip(words[s:e + 1])]

        # Concatenate into strings
        if as_strings:
            ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]

        return ngrams

    def entity_groups(self):
        """Group consecutive entity tokens with the same NER tag."""
        entities = self.entities()
        if not entities:
            return None
        non_ent = self.opts.get('non_ent', 'O')
        groups = []
        idx = 0
        while idx < len(entities):
            ner_tag = entities[idx]
            # Check for entity tag
            if ner_tag != non_ent:
                # Chomp the sequence
                start = idx
                while (idx < len(entities) and entities[idx] == ner_tag):
                    idx += 1
                groups.append((self.slice(start, idx).untokenize(), ner_tag))
            else:
                idx += 1
        return groups
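
# Illustrative sketch (editor's example, hypothetical values): each entry of
# `data` is a tuple laid out according to the class constants above, i.e.
# (TEXT, TEXT_WS, SPAN, POS, LEMMA, NER); SimpleTokenizer below only fills the
# first three fields. A hand-built example of the accessors:
#
#   toks = Tokens(
#       data=[
#           ('Obama', 'Obama ', (0, 5), 'NNP', 'obama', 'PERSON'),
#           ('spoke', 'spoke', (6, 11), 'VBD', 'speak', 'O'),
#       ],
#       annotators={'pos', 'lemma', 'ner'},
#   )
#   toks.words()          # ['Obama', 'spoke']
#   toks.untokenize()     # 'Obama spoke'
#   toks.entity_groups()  # [('Obama', 'PERSON')] with the default non_ent='O'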


class Tokenizer(object):
    """Base tokenizer class.
    Tokenizers implement tokenize, which should return a Tokens class.
    """

    def tokenize(self, text):
        raise NotImplementedError

    def shutdown(self):
        pass

    def __del__(self):
        self.shutdown()


class SimpleTokenizer(Tokenizer):
    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
    NON_WS = r'[^\p{Z}\p{C}]'

    def __init__(self, **kwargs):
        """
        Args:
            annotators: None or empty set (only tokenizes).
        """
        self._regexp = regex.compile(
            '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
        )
        if len(kwargs.get('annotators', {})) > 0:
            logger.warning('%s only tokenizes! Skipping annotators: %s' %
                           (type(self).__name__, kwargs.get('annotators')))
        self.annotators = set()

    def tokenize(self, text):
        data = []
        matches = [m for m in self._regexp.finditer(text)]
        for i in range(len(matches)):
            # Get text
            token = matches[i].group()

            # Get whitespace
            span = matches[i].span()
            start_ws = span[0]
            if i + 1 < len(matches):
                end_ws = matches[i + 1].span()[0]
            else:
                end_ws = span[1]

            # Format data
            data.append((
                token,
                text[start_ws: end_ws],
                span,
            ))
        return Tokens(data, self.annotators)
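
# Usage sketch for SimpleTokenizer (editor's example; the sample string and the
# expected outputs are illustrative, not from the original file):
#
#   tok = SimpleTokenizer()
#   tokens = tok.tokenize('hello, world!')
#   tokens.words()       # ['hello', ',', 'world', '!']
#   tokens.offsets()     # [(0, 5), (5, 6), (7, 12), (12, 13)]
#   tokens.untokenize()  # 'hello, world!'
#   tokens.ngrams(n=2)   # unigrams and bigrams as space-joined strings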


class SpacyTokenizer(Tokenizer):

    def __init__(self, **kwargs):
        """
        Args:
            annotators: set that can include pos, lemma, and ner.
            model: spaCy model to use (either path, or keyword like 'en').
        """
        model = kwargs.get('model', 'en')
        self.annotators = copy.deepcopy(kwargs.get('annotators', set()))
        nlp_kwargs = {'parser': False}
        if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
            nlp_kwargs['tagger'] = False
        if 'ner' not in self.annotators:
            nlp_kwargs['entity'] = False
        self.nlp = spacy.load(model, **nlp_kwargs)

    def tokenize(self, text):
        # We don't treat new lines as tokens.
        clean_text = text.replace('\n', ' ')
        tokens = self.nlp.tokenizer(clean_text)
        if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
            self.nlp.tagger(tokens)
        if 'ner' in self.annotators:
            self.nlp.entity(tokens)

        data = []
        for i in range(len(tokens)):
            # Get whitespace
            start_ws = tokens[i].idx
            if i + 1 < len(tokens):
                end_ws = tokens[i + 1].idx
            else:
                end_ws = tokens[i].idx + len(tokens[i].text)

            data.append((
                tokens[i].text,
                text[start_ws: end_ws],
                (tokens[i].idx, tokens[i].idx + len(tokens[i].text)),
                tokens[i].tag_,
                tokens[i].lemma_,
                tokens[i].ent_type_,
            ))

        # Set special option for non-entity tag: '' vs 'O' in spaCy
        return Tokens(data, self.annotators, opts={'non_ent': ''})
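
# Usage sketch for SpacyTokenizer (editor's example, illustrative only). Note
# that the class above targets the older spaCy 1.x-style API: passing
# parser/tagger/entity booleans to spacy.load() and calling self.nlp.tagger /
# self.nlp.entity directly on a Doc; newer spaCy releases instead disable
# pipeline components via spacy.load(name, disable=[...]) and apply the whole
# pipeline with nlp(text). The model keyword 'en' below is an assumption:
#
#   tok = SpacyTokenizer(model='en', annotators={'pos', 'lemma', 'ner'})
#   tokens = tok.tokenize('Barack Obama visited Paris.')
#   tokens.pos()            # per-token tags, since 'pos' was requested
#   tokens.entity_groups()  # e.g. grouped PERSON / GPE spans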