# streaming.py

  1. """
  2. Source Code: https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/utils.py
  3. """
  4. import torch
  5. import argparse
  6. from transformers import (
  7. AutoTokenizer,
  8. AutoModelForCausalLM,
  9. )
  10. import os.path as osp
  11. import ssl
  12. import urllib.request
  13. import os
  14. import json

def load(model_name_or_path):
    print(f"Loading model from {model_name_or_path} ...")
    # Note: tensor parallelism is known to be buggy when running Falcon models.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    # Ensure a pad token exists: fall back to the EOS token, then to token id 0.
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer.pad_token_id = 0
    model.eval()
    return model, tokenizer
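
# A minimal usage sketch. The checkpoint name below is illustrative (it is the
# default used by the upstream StreamingLLM demo), not something this file
# requires; any causal LM on the Hugging Face Hub works the same way:
#
#     model, tokenizer = load("lmsys/vicuna-13b-v1.3")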

def download_url(url: str, folder="folder"):
    """
    Downloads the content of a URL to a folder. Modified from
    https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric

    Args:
        url (string): The URL of the target file.
        folder (string): The target folder.

    Returns:
        string: The file path of the downloaded file.
    """
    file = url.rpartition("/")[2]
    file = file if file[0] == "?" else file.split("?")[0]
    path = osp.join(folder, file)

    if osp.exists(path):
        print(f"File {file} exists, using existing file.")
        return path

    print(f"Downloading {url}")
    os.makedirs(folder, exist_ok=True)
    # Skip SSL certificate verification so downloads work behind self-signed
    # proxies; this trades away transport security.
    ctx = ssl._create_unverified_context()
    data = urllib.request.urlopen(url, context=ctx)
    with open(path, "wb") as f:
        f.write(data.read())
    return path
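
# Usage sketch (assumption: the MT-Bench question file fetched by the upstream
# streaming demo; any URL pointing at a single file behaves the same):
#
#     path = download_url(
#         "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl",
#         folder="data",
#     )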

def load_jsonl(
    file_path,
):
    # Parse one JSON object per line into a list of dicts.
    list_data_dict = []
    with open(file_path, "r") as f:
        for line in f:
            list_data_dict.append(json.loads(line))
    return list_data_dict
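
# Usage sketch: each line of the input file must be a standalone JSON object.
# The "turns" key below matches the MT-Bench schema and is an assumption here,
# not something load_jsonl itself requires:
#
#     prompts = [d["turns"][0] for d in load_jsonl("data/question.jsonl")]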

@torch.no_grad()
def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
    # Prefill: run the full prompt through the model once to seed the KV cache.
    outputs = model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        use_cache=True,
    )
    past_key_values = outputs.past_key_values
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    generated_ids = [pred_token_idx.item()]
    pos = 0
    # Decode: feed back one token at a time, reusing the cached KV states.
    for _ in range(max_gen_len - 1):
        outputs = model(
            input_ids=pred_token_idx,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        generated_ids.append(pred_token_idx.item())
        generated_text = (
            tokenizer.decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
                spaces_between_special_tokens=False,
            )
            .strip()
            .split(" ")
        )

        # Stream output: print only the words finalized since the last step
        # (the final word may still change as new tokens merge into it).
        now = len(generated_text) - 1
        if now > pos:
            print(" ".join(generated_text[pos:now]), end=" ", flush=True)
            pos = now

        if pred_token_idx == tokenizer.eos_token_id:
            break
    print(" ".join(generated_text[pos:]), flush=True)
    return past_key_values
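

# A minimal end-to-end sketch tying the helpers together. Everything below is
# illustrative: the default checkpoint and the prompt are assumptions, and the
# full StreamingLLM demo additionally manages the KV cache with a sink-aware
# eviction policy that this file does not include.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path", type=str, default="lmsys/vicuna-13b-v1.3"
    )
    args = parser.parse_args()

    model, tokenizer = load(args.model_name_or_path)
    prompt = "USER: What is the capital of France?\nASSISTANT:"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    # Start with an empty cache; greedy_generate returns the updated cache so a
    # caller could chain multi-turn generations on top of it.
    past_key_values = greedy_generate(
        model, tokenizer, input_ids, past_key_values=None, max_gen_len=128
    )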