detokenizer.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # coding=utf-8
  2. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Detokenization."""
  16. import re
  17. def ptb_detokenizer(string):
  18. string = string.replace(" '", "'")
  19. string = string.replace(" \n", "\n")
  20. string = string.replace("\n ", "\n")
  21. string = string.replace(" n't", "n't")
  22. string = string.replace(" N ", "1 ")
  23. string = string.replace("$ 1", "$1")
  24. string = string.replace("# 1", "#1")
  25. return string
  26. def wikitext_detokenizer(string):
  27. # contractions
  28. string = string.replace("s '", "s'")
  29. string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
  30. # number separators
  31. string = string.replace(" @-@ ", "-")
  32. string = string.replace(" @,@ ", ",")
  33. string = string.replace(" @.@ ", ".")
  34. # punctuation
  35. string = string.replace(" : ", ": ")
  36. string = string.replace(" ; ", "; ")
  37. string = string.replace(" . ", ". ")
  38. string = string.replace(" ! ", "! ")
  39. string = string.replace(" ? ", "? ")
  40. string = string.replace(" , ", ", ")
  41. # double brackets
  42. string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
  43. string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
  44. string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
  45. string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
  46. string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
  47. # miscellaneous
  48. string = string.replace("= = = =", "====")
  49. string = string.replace("= = =", "===")
  50. string = string.replace("= =", "==")
  51. string = string.replace(" " + chr(176) + " ", chr(176))
  52. string = string.replace(" \n", "\n")
  53. string = string.replace("\n ", "\n")
  54. string = string.replace(" N ", " 1 ")
  55. string = string.replace(" 's", "'s")
  56. return string
  57. def lambada_detokenizer(string):
  58. return string
  59. _DETOKENIZERS = {
  60. 'ptb': ptb_detokenizer,
  61. 'wiki': wikitext_detokenizer,
  62. 'lambada': lambada_detokenizer,
  63. }
  64. def get_detokenizer(path):
  65. for key in _DETOKENIZERS.keys():
  66. if key in path:
  67. return _DETOKENIZERS[key]