run_needle_haystack_test.py

import os
import glob
import json
import argparse

import tqdm
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from utils.llama import H2OLlamaForCausalLM


def set_seed(args):
    """Seed numpy and torch (CPU and all CUDA devices) for reproducibility."""
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-path", type=str, default="")
    parser.add_argument("--output-path", type=str, default="")
    parser.add_argument("--model-provider", type=str, default="Huggingface")
    parser.add_argument("--model-name", type=str, default="")
    parser.add_argument("--enable_h2o_generation", action='store_true')
    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=128)
    parser.add_argument("--num_window_length", type=int, default=256)
    parser.add_argument("--num_chunk_size", type=int, default=2048)
    parser.add_argument("--enable_position_rolling", action='store_true')
    parser.add_argument("--max_new_tokens", type=int, default=1024)
    parser.add_argument("--temperature", type=float, default=0.1)
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    args = parser.parse_args()

    set_seed(args)
    model_name = args.model_name
    input_path = args.input_path
    output_path = args.output_path
    model_provider = args.model_provider
    os.makedirs(output_path, exist_ok=True)

    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    if args.enable_h2o_generation:
        # Pass the H2O cache-eviction settings through the model config:
        # the KV cache keeps `num_heavy_hitter_tokens` high-attention tokens
        # plus a recent window of `num_window_length` tokens.
        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
        config.num_window_length = args.num_window_length
        config.enable_position_rolling = args.enable_position_rolling
        model = H2OLlamaForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map='auto',
            low_cpu_mem_usage=True,
            config=config,
        )
    else:
        # Baseline: standard dense-attention model with a full KV cache.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map='auto',
            low_cpu_mem_usage=True,
        )

    # Load the testing prompts and generate a prediction for each one.
    for filename in tqdm.tqdm(glob.glob(f'{input_path}/{model_provider}_*_prompts.json')):
        with open(filename, 'r') as f:
            input_data = json.load(f)

        # Concatenate the system and user messages into a single prompt.
        prompt = input_data[0]['content'] + '\n' + input_data[1]['content']
        inputs = tokenizer(prompt, truncation=False, return_tensors="pt").to(model.device)

        context_length = inputs.input_ids.shape[-1]
        if context_length > args.num_chunk_size:
            # Truncate the context to the maximum chunk size, keeping the most
            # recent tokens, and update the prompt length to match.
            inputs = {k: v[:, -args.num_chunk_size:] for k, v in inputs.items()}
            context_length = args.num_chunk_size

        # Greedy decoding (num_beams=1, sampling disabled); `temperature` only
        # takes effect when sampling is enabled via do_sample=True.
        output = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            num_beams=1,
            temperature=args.temperature,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens, dropping the prompt.
        pred = tokenizer.decode(output[0][context_length:], skip_special_tokens=True)
        pred = pred.strip()

        basename = os.path.basename(filename)
        newname = basename.replace('.json', '.txt').replace('_prompts', '')
        with open(f'{output_path}/{newname}', 'w') as f:
            f.write(pred)
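
# Example invocation, shown as a sketch: the checkpoint name and the paths
# below are illustrative assumptions, not values taken from this script.
#
#   python run_needle_haystack_test.py \
#       --input-path data/needle_prompts \
#       --output-path results/h2o \
#       --model-name meta-llama/Llama-2-7b-hf \
#       --enable_h2o_generation \
#       --num_heavy_hitter_tokens 128 \
#       --num_window_length 256
#
# Omitting --enable_h2o_generation runs the same test with the dense baseline.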