run_needle_haystack_test.py

import os
import glob
import json
import argparse

import tqdm
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from utils.llama import H2OLlamaForCausalLM


def set_seed(args):
    # Seed NumPy and PyTorch (including all CUDA devices) so runs are reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-path", type=str, default="")
    parser.add_argument("--output-path", type=str, default="")
    parser.add_argument("--model-provider", type=str, default="Huggingface")
    parser.add_argument("--model-name", type=str, default="")

    # H2O-specific options: budget of heavy-hitter tokens to retain, length of
    # the recent-token window, and optional position rolling.
    parser.add_argument("--enable_h2o_generation", action='store_true')
    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=128)
    parser.add_argument("--num_window_length", type=int, default=256)
    parser.add_argument("--enable_position_rolling", action='store_true')

    parser.add_argument("--max_new_tokens", type=int, default=1024)
    parser.add_argument("--temperature", type=float, default=0.1)
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    args = parser.parse_args()

    set_seed(args)
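
    # Example invocation (paths and the model name are illustrative, not taken
    # from the script itself):
    #
    #   python run_needle_haystack_test.py \
    #       --input-path data/needle_prompts \
    #       --output-path results/needle \
    #       --model-name meta-llama/Llama-2-7b-hf \
    #       --enable_h2o_generation \
    #       --num_heavy_hitter_tokens 128 --num_window_length 256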

    model_name = args.model_name
    input_path = args.input_path
    output_path = args.output_path
    model_provider = args.model_provider

    # Make sure the output directory exists before predictions are written into it.
    os.makedirs(output_path, exist_ok=True)

    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    if args.enable_h2o_generation:
        # Patch the model config with the H2O cache settings before loading,
        # so the custom attention keeps only the heavy-hitter tokens and the
        # recent window instead of the full KV cache.
        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
        config.num_window_length = args.num_window_length
        config.enable_position_rolling = args.enable_position_rolling
        model = H2OLlamaForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map='auto',  # let accelerate place layers on available devices
            low_cpu_mem_usage=True,
            config=config,
        )
    else:
        # Baseline: the stock model with a full KV cache.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map='auto',
            low_cpu_mem_usage=True,
        )
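
    # Rough budget arithmetic (assuming, as in the H2O paper, that the cache
    # keeps heavy hitters plus a recent window): with the defaults above the
    # model retains about 128 + 256 = 384 key/value positions per head,
    # independent of how long the haystack prompt is.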

    # Load the testing prompts and generate one answer per prompt file.
    for filename in tqdm.tqdm(glob.glob(f'{input_path}/{model_provider}_*_prompts.json')):
        with open(filename, 'r') as f:
            input_data = json.load(f)

        # Join the first two chat messages (system and user content) into a
        # single plain-text prompt.
        prompt = input_data[0]['content'] + '\n' + input_data[1]['content']
        inputs = tokenizer(prompt, truncation=False, return_tensors="pt").to(model.device)
        context_length = inputs.input_ids.shape[-1]

        # Note: do_sample is left at its default (False), so decoding is greedy
        # and the temperature setting has no effect in recent transformers versions.
        output = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            num_beams=1,
            temperature=args.temperature,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Keep only the newly generated tokens, dropping the echoed prompt.
        pred = tokenizer.decode(output[0][context_length:], skip_special_tokens=True)
        pred = pred.strip()

        # Mirror the input naming: <provider>_..._prompts.json -> <provider>_....txt
        basename = os.path.basename(filename)
        newname = basename.replace('.json', '.txt').replace('_prompts', '')
        with open(f'{output_path}/{newname}', 'w') as f:
            f.write(pred)
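
    # A possible follow-up step, not part of this script: score each answer by
    # checking whether the hidden "needle" sentence appears in it. The needle
    # text is a placeholder; the real one comes from the prompt-generation step.
    #
    #   needle = "..."  # hypothetical needle sentence
    #   for path in glob.glob(f'{output_path}/*.txt'):
    #       with open(path) as f:
    #           found = needle.lower() in f.read().lower()
    #       print(os.path.basename(path), 'PASS' if found else 'FAIL')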