utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from transformers import AutoTokenizer


class FakeTokenizer(object):
    def __init__(self):
        self.pad_token_id = 0
        self.bos_token_id = 42
        self.eos_token_id = 43
        self.sep_token_id = 3
        self.vocab_size = 128256

        self.pad_token = "<|pad_id|>"
        self.bos_token = "<|bos_id|>"
        self.eos_token = "<|eos_id|>"
        self.sep_token = "<|sep_id|>"
        self.tokenizer = self
        self.padding_side = "left"

    def __call__(self, *args, **kwargs):
        print(f"{kwargs=}")
        ids = self.encode(*args, **kwargs)
        return {"input_ids": ids}

    def encode(self, text, *args, **kwargs):
        # Fake "tokenization": one id per whitespace-separated word, given by the word's length.
        return [self.bos_token_id] + [len(c) for c in text.split(" ")] + [self.eos_token_id]

    def __len__(self):
        return 128256

    def pad(self, *args, **kwargs):
        args = args[0]
        max_len = max([len(a["input_ids"]) for a in args])
        for a in args:
            for k in a.keys():
                # Pad each field of every sample up to the longest input_ids in the batch.
                a[k] = a[k] + ([self.pad_token_id if k == "input_ids" else 0] * (max_len - len(a[k])))
        out = {}
        for k in args[0].keys():
            out[k] = [a[k] for a in args]
        return out


def maybe_tokenizer(name):
    if name == "fake_llama":
        return FakeTokenizer()
    try:
        return AutoTokenizer.from_pretrained(name)
    except OSError:
        return None
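
A minimal sketch of how these helpers might be exercised, assuming only the code above; the example strings are illustrative and not part of the original file.

if __name__ == "__main__":
    tok = maybe_tokenizer("fake_llama")  # returns a FakeTokenizer instance
    batch = [tok("hello world"), tok("a longer example sentence")]
    # pad() appends pad_token_id to "input_ids" (and 0 to any other field)
    # until every sample matches the longest sequence in the batch.
    print(tok.pad(batch))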