# run_streaming.py

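"""Streaming multi-turn inference over MT-Bench prompts, with optional H2O
(heavy-hitter) KV-cache eviction between turns."""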
import torch
import argparse
import json
import os
import time
import re
import sys

from utils.streaming import load, download_url, load_jsonl, greedy_generate
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from utils.llama import H2OLlamaForCausalLM
from utils.cache import Cache, HHCache, StaticCache


@torch.no_grad()
def streaming_inference_h2o(model, tokenizer, config, prompts, max_gen_len=1000, enable_h2o_generation=False):
    past_key_values = None
    for idx, prompt in enumerate(prompts):
        prompt = "USER: " + prompt + "\n\nASSISTANT: "
        print("\n" + prompt, end="")
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(model.device)
        seq_len = input_ids.shape[1]

        # Greedy decoding, carrying the KV cache across turns.
        past_key_values = greedy_generate(
            model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
        )

        if enable_h2o_generation:
            # Evict cache entries so the next prompt plus its generation fits
            # within the heavy-hitter / recent-window budget.
            space_needed = seq_len + max_gen_len
            past_key_values = HHCache.from_legacy_cache(
                config.num_window_length, config.num_heavy_hitter_tokens, past_key_values
            )
            past_key_values.evict_for_space(space_needed)
            past_key_values = past_key_values.to_legacy_cache()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-path", type=str, default="")
    parser.add_argument("--model-name", type=str, default="lmsys/vicuna-13b-v1.5")
    parser.add_argument("--enable_h2o_generation", action='store_true')
    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=128)
    parser.add_argument("--num_window_length", type=int, default=256)
    parser.add_argument("--enable_position_rolling", action='store_true')
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    args = parser.parse_args()

    model_name = args.model_name
    data_root = args.input_path

    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    if args.enable_h2o_generation:
        # Pass the H2O cache budget through the model config.
        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
        config.num_window_length = args.num_window_length
        config.enable_position_rolling = args.enable_position_rolling
        model = H2OLlamaForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map='auto',
            low_cpu_mem_usage=True,
            config=config,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map='auto',
            low_cpu_mem_usage=True,
        )

    # Fetch the MT-Bench questions if they are not already on disk.
    test_filepath = os.path.join(data_root, "mt_bench.jsonl")
    print(f"Loading data from {test_filepath} ...")
    if not os.path.exists(test_filepath):
        download_url(
            "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl",
            data_root,
        )
        os.rename(os.path.join(data_root, "question.jsonl"), test_filepath)

    list_data = load_jsonl(test_filepath)
    prompts = []
    for sample in list_data:
        prompts += sample["turns"]

    streaming_inference_h2o(
        model, tokenizer, config, prompts, enable_h2o_generation=args.enable_h2o_generation
    )


if __name__ == "__main__":
    main()
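
# Example invocations (illustrative only; the paths are assumptions, the flags
# follow the argparse definitions above):
#
#   # Baseline streaming inference with the full KV cache:
#   python run_streaming.py --input-path data/
#
#   # H2O generation with a 128-token heavy-hitter budget and a 256-token window:
#   python run_streaming.py --input-path data/ --enable_h2o_generation \
#       --num_heavy_hitter_tokens 128 --num_window_length 256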