utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from transformers import AutoTokenizer


class FakeTokenizer(object):
    """Minimal stand-in for a Hugging Face tokenizer, used in tests."""

    def __init__(self):
        self.pad_token_id = 0
        self.bos_token_id = 42
        self.eos_token_id = 43
        self.sep_token_id = 3
        self.vocab_size = 128256
        self.pad_token = "<|pad_id|>"
        self.bos_token = "<|bos_id|>"
        self.eos_token = "<|eos_id|>"
        self.sep_token = "<|sep_id|>"
        self.tokenizer = self
        self.padding_side = "left"

    def __call__(self, *args, **kwargs):
        ids = self.encode(*args, **kwargs)
        return {"input_ids": ids}

    def encode(self, text, *args, **kwargs):
        # Fake encoding: each whitespace-separated word becomes a "token id"
        # equal to its character length, wrapped in BOS/EOS.
        return [self.bos_token_id] + [len(c) for c in text.split(" ")] + [self.eos_token_id]

    def __len__(self):
        return self.vocab_size

    def pad(self, *args, **kwargs):
        # Pad every feature of every example to the length of the longest
        # input_ids in the batch, then collate into a dict of lists.
        # Note: padding is appended on the right, regardless of padding_side.
        features = args[0]
        max_len = max(len(f["input_ids"]) for f in features)
        for f in features:
            for k in f.keys():
                f[k] = f[k] + [self.pad_token_id if k == "input_ids" else 0] * (max_len - len(f[k]))
        out = {}
        for k in features[0].keys():
            out[k] = [f[k] for f in features]
        return out
def maybe_tokenizer(name):
    # Return the fake tokenizer for tests, a real tokenizer if `name` resolves
    # to a Hugging Face model or local path, or None if it cannot be loaded.
    if name == "fake_llama":
        return FakeTokenizer()
    try:
        return AutoTokenizer.from_pretrained(name)
    except OSError:
        return None
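

# ---------------------------------------------------------------------------
# Usage sketch (an illustrative addition, not part of the original file):
# demonstrates FakeTokenizer through maybe_tokenizer. "fake_llama" is the
# sentinel handled above; any other name is forwarded to
# AutoTokenizer.from_pretrained and may trigger a real model download.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tok = maybe_tokenizer("fake_llama")

    # encode() maps each word to its length, wrapped in BOS/EOS:
    # "hello world" -> [42, 5, 5, 43]
    print(tok.encode("hello world"))

    # __call__ mimics the Hugging Face interface by returning a feature dict.
    batch = [tok("hello world"), tok("hi")]
    for item in batch:
        item["attention_mask"] = [1] * len(item["input_ids"])

    # pad() right-pads every feature to the longest input_ids in the batch:
    # {'input_ids': [[42, 5, 5, 43], [42, 2, 43, 0]],
    #  'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 0]]}
    print(tok.pad(batch))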