run_needle_haystack_test.py

import os
import tqdm
import glob
import json
import copy
import math
import torch
import logging
import argparse
import numpy as np

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from utils.llama import H2OLlamaForCausalLM


def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument("--input-path", type=str, default="")
    parser.add_argument("--output-path", type=str, default="")
    parser.add_argument("--model-provider", type=str, default="Huggingface")
    parser.add_argument("--model-name", type=str, default="")

    parser.add_argument("--enable_h2o_generation", action='store_true')
    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=-1)
    parser.add_argument("--num_window_length", type=int, default=256)
    parser.add_argument("--num_chunk_size", type=int, default=2048)
    parser.add_argument("--enable_position_rolling", action='store_true')

    parser.add_argument("--max_new_tokens", type=int, default=1024)
    parser.add_argument("--temperature", type=float, default=0.1)
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    args = parser.parse_args()
    set_seed(args)

    model_name = args.model_name
    input_path = args.input_path
    output_path = args.output_path
    model_provider = args.model_provider
    # output_path is used as a directory below, so make sure it exists.
    os.makedirs(output_path, exist_ok=True)

    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    if args.num_heavy_hitter_tokens == -1:
        print('num_heavy_hitter_tokens not specified; defaulting to half of the cache size: {}'.format(args.num_window_length // 2))
        args.num_heavy_hitter_tokens = args.num_window_length // 2
    if args.enable_h2o_generation:
        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
        config.num_window_length = args.num_window_length
        config.enable_position_rolling = args.enable_position_rolling
        model = H2OLlamaForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map='auto',
            low_cpu_mem_usage=True,
            config=config)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map='auto',
            low_cpu_mem_usage=True)
    # Load the test prompts and generate one completion per prompt file.
    for filename in tqdm.tqdm(glob.glob(f'{input_path}/{args.model_provider}_*_prompts.json')):
        with open(filename, 'r') as f:
            input_data = json.load(f)

        prompt = input_data[0]['content'] + '\n' + input_data[1]['content']
        inputs = tokenizer(prompt, truncation=False, return_tensors="pt").to(model.device)
        context_length = inputs.input_ids.shape[-1]

        if context_length > args.num_chunk_size:
            # Truncate the context to the maximum chunk size, keeping the most recent tokens.
            inputs = {k: v[:, -args.num_chunk_size:] for k, v in inputs.items()}
            context_length = args.num_chunk_size

        output = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            num_beams=1,
            temperature=args.temperature,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens, dropping the prompt.
        pred = tokenizer.decode(output[0][context_length:], skip_special_tokens=True)
        pred = pred.strip()

        basename = os.path.basename(filename)
        newname = basename.replace('.json', '.txt').replace('_prompts', '')
        with open(f'{output_path}/{newname}', 'w') as f:
            f.write(pred)
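
# Example invocation (illustrative sketch only; the paths and model name below
# are placeholders, not values taken from this repository):
#
#   python run_needle_haystack_test.py \
#       --input-path data/needle_prompts \
#       --output-path results/needle \
#       --model-name meta-llama/Llama-2-7b-hf \
#       --enable_h2o_generation \
#       --num_window_length 256 \
#       --num_heavy_hitter_tokens 128
#
# With --enable_h2o_generation set, the KV cache is capped at num_window_length
# entries, of which num_heavy_hitter_tokens slots are reserved for heavy-hitter
# tokens (here half the window, matching the script's default behavior).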