123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- # coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Detokenization."""
- import re
def ptb_detokenizer(string):
    """Reverse Penn Treebank style token spacing in ``string``.

    A fixed list of literal substring substitutions is applied one after
    another; each substitution sees the output of the previous one, so
    the order of the table below is significant.

    Args:
        string: Tokenized text.

    Returns:
        The text with every substitution applied.
    """
    # (old, new) pairs, applied top to bottom.
    substitutions = (
        (" '", "'"),        # glue apostrophes to the preceding token
        (" \n", "\n"),      # drop the space before a newline
        ("\n ", "\n"),      # drop the space after a newline
        (" n't", "n't"),    # re-attach the "n't" contraction
        # NOTE(review): the leading space is dropped here ("about N x"
        # becomes "about1 x"); confirm this is intended before changing it.
        (" N ", "1 "),
        ("$ 1", "$1"),
        ("# 1", "#1"),
    )
    detokenized = string
    for old, new in substitutions:
        detokenized = detokenized.replace(old, new)
    return detokenized
def wikitext_detokenizer(string):
    """Reverse WikiText style tokenization artifacts in ``string``.

    The rules run in a fixed order (contractions, number separators,
    punctuation, bracket/quote interiors, miscellaneous); each rule sees
    the output of the previous one, so the order must not be changed.

    Args:
        string: Tokenized text.

    Returns:
        The text with every rule applied.
    """
    # contractions
    string = string.replace("s '", "s'")
    # Re-attach a digit to the apostrophe before it ("' 90s" -> "'90s").
    # The previous pattern kept regex-literal "/.../" delimiters inside the
    # Python pattern, so it could only match a literal "/' 5/" and would
    # have written the text "[0-9]" instead of the digit.
    string = re.sub(r"' ([0-9])", r"'\1", string)

    # number separators: " @-@ ", " @,@ ", " @.@ " encode "-", ",", "."
    for separator in ("-", ",", "."):
        string = string.replace(" @" + separator + "@ ", separator)

    # punctuation: remove the space in front of the mark, keep the one after
    for mark in (":", ";", ".", "!", "?", ","):
        string = string.replace(" " + mark + " ", mark + " ")

    # double brackets: trim whitespace just inside matching delimiters
    string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
    string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
    string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
    string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
    string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)

    # miscellaneous
    # Heading markers: longest run first so "= = = =" is not split up by
    # the shorter rules.
    string = string.replace("= = = =", "====")
    string = string.replace("= = =", "===")
    string = string.replace("= =", "==")
    degree = chr(176)  # the degree sign
    string = string.replace(" " + degree + " ", degree)
    string = string.replace(" \n", "\n")
    string = string.replace("\n ", "\n")
    string = string.replace(" N ", " 1 ")
    string = string.replace(" 's", "'s")
    return string
def lambada_detokenizer(string):
    """Return ``string`` unchanged; no detokenization rules are applied."""
    unchanged = string
    return unchanged
# Registry of detokenizers, keyed by the name fragment that selects them.
# Entries are kept in this order: ptb, wiki, lambada.
_DETOKENIZERS = dict(
    ptb=ptb_detokenizer,
    wiki=wikitext_detokenizer,
    lambada=lambada_detokenizer,
)
def get_detokenizer(path):
    """Pick a detokenizer based on ``path``.

    The keys of ``_DETOKENIZERS`` are tried in the registry's iteration
    order, and the function registered under the first key for which
    ``key in path`` holds is returned.

    Args:
        path: Value tested with ``key in path`` for each registry key.

    Returns:
        The matching detokenizer function, or ``None`` if no key matches.
    """
    for name, detokenizer in _DETOKENIZERS.items():
        if name in path:
            return detokenizer
    return None
|