import torch
import argparse
import json
import os
import time
import re
import sys

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from utils.streaming import load, download_url, load_jsonl, greedy_generate
from utils.llama import H2OLlamaForCausalLM
from utils.cache import Cache, HHCache, StaticCache
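
# `greedy_generate` lives in the local utils.streaming module and is not shown
# in this file. As a rough mental model, a greedy decoding loop that reuses the
# KV cache across calls might look like the sketch below (the signature and all
# details here are assumptions for illustration, not the actual utils code):
#
#   @torch.no_grad()
#   def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
#       for _ in range(max_gen_len):
#           out = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
#           past_key_values = out.past_key_values
#           next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
#           if next_token.item() == tokenizer.eos_token_id:
#               break
#           print(tokenizer.decode(next_token[0]), end="", flush=True)
#           input_ids = next_token  # feed only the new token; the cache holds the rest
#       return past_key_values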

@torch.no_grad()
def streaming_inference_h2o(model, tokenizer, config, prompts, max_gen_len=1000, enable_h2o_generation=False):
    """Run multi-turn greedy decoding, reusing the KV cache across turns.

    When `enable_h2o_generation` is set, the cache is compressed after each
    turn with H2O heavy-hitter eviction so it stays within the configured budget.
    """
    past_key_values = None
    for idx, prompt in enumerate(prompts):
        prompt = "USER: " + prompt + "\n\nASSISTANT: "
        print("\n" + prompt, end="")
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(model.device)
        seq_len = input_ids.shape[1]

        past_key_values = greedy_generate(
            model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
        )

        if enable_h2o_generation:
            # Budget the cache for this prompt plus its generated continuation:
            # wrap the legacy cache in an HHCache, evict down to the budget,
            # then convert back to the legacy tuple format the model expects.
            space_needed = seq_len + max_gen_len
            past_key_values = HHCache.from_legacy_cache(
                config.num_window_length, config.num_heavy_hitter_tokens, past_key_values
            )
            past_key_values.evict_for_space(space_needed)
            past_key_values = past_key_values.to_legacy_cache()
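
# `HHCache.evict_for_space` comes from utils.cache and implements H2O-style cache
# compression: keep the most recent tokens plus the "heavy hitters" whose
# accumulated attention scores are largest, and drop everything else. A minimal
# single-layer sketch of that idea (shapes and names below are assumptions for
# illustration, not the actual utils.cache implementation):
#
#   def evict_for_space(keys, values, acc_attn, num_heavy, num_recent):
#       # keys/values: (batch, heads, seq, dim); acc_attn: (seq,) summed attention
#       seq_len = keys.shape[-2]
#       if seq_len <= num_heavy + num_recent:
#           return keys, values
#       recent = torch.arange(seq_len - num_recent, seq_len)
#       heavy = acc_attn[: seq_len - num_recent].topk(num_heavy).indices
#       keep = torch.cat([heavy.sort().values, recent])
#       return keys[..., keep, :], values[..., keep, :]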

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-path", type=str, default="")
    parser.add_argument("--model-name", type=str, default="lmsys/vicuna-13b-v1.5")
    parser.add_argument("--enable_h2o_generation", action="store_true")
    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=128)
    parser.add_argument("--num_window_length", type=int, default=256)
    parser.add_argument("--enable_position_rolling", action="store_true")
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    model_name = args.model_name
    data_root = args.input_path

    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    if args.enable_h2o_generation:
        # Pass the H2O cache budget through the model config so the patched
        # attention layers can read it.
        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
        config.num_window_length = args.num_window_length
        config.enable_position_rolling = args.enable_position_rolling
        model = H2OLlamaForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
            config=config,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
    # Fetch the MT-Bench questions from FastChat if they are not cached locally.
    test_filepath = os.path.join(data_root, "mt_bench.jsonl")
    print(f"Loading data from {test_filepath} ...")
    if not os.path.exists(test_filepath):
        download_url(
            "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl",
            data_root,
        )
        os.rename(os.path.join(data_root, "question.jsonl"), test_filepath)

    # Flatten every conversation turn into a single list of prompts.
    list_data = load_jsonl(test_filepath)
    prompts = []
    for sample in list_data:
        prompts += sample["turns"]

    streaming_inference_h2o(
        model, tokenizer, config, prompts, enable_h2o_generation=args.enable_h2o_generation
    )


if __name__ == "__main__":
    main()
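
# Example invocation (the script filename and data directory below are
# illustrative assumptions, not taken from this file):
#
#   python run_streaming.py --input-path data --enable_h2o_generation \
#       --num_heavy_hitter_tokens 128 --num_window_length 256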