run_summarization.py

import os
import copy
import math
import json
import logging
import argparse
import dataclasses

import tqdm
import torch
import numpy as np
from rouge import Rouge
from xopen import xopen
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from utils.llama import H2OLlamaForCausalLM


def set_seed(args):
    # Seed numpy and torch (CPU and all CUDA devices) for reproducible sampling.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-path", type=str, default="")
    parser.add_argument("--output-path", type=str, default="")
    parser.add_argument("--model-name", type=str, default="")

    # H2O cache settings: number of heavy-hitter tokens to keep and the total KV-cache window length.
    parser.add_argument("--enable_h2o_generation", action='store_true')
    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=-1)
    parser.add_argument("--num_window_length", type=int, default=256)
    parser.add_argument("--enable_position_rolling", action='store_true')

    parser.add_argument("--sample_num", type=int, default=500)
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    args = parser.parse_args()

    set_seed(args)

    model_name = args.model_name
    input_path = args.input_path
    output_path = args.output_path
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    # Default the heavy-hitter budget to half of the cache size when not specified.
    if args.num_heavy_hitter_tokens == -1:
        print('Number of heavy-hitter tokens not assigned; using half of the cache size: {}'.format(args.num_window_length // 2))
        args.num_heavy_hitter_tokens = args.num_window_length // 2

    if args.enable_h2o_generation:
        # H2O variant of Llama: keeps only heavy-hitter tokens plus the configured window in the KV cache.
        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
        config.num_window_length = args.num_window_length
        config.enable_position_rolling = args.enable_position_rolling
        model = H2OLlamaForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map='auto',
            low_cpu_mem_usage=True,
            config=config,
        )
    else:
        # Full-cache baseline.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map='auto',
            low_cpu_mem_usage=True,
        )

    # Load inference data: one JSON request per line.
    requests = []
    with open(input_path, 'r') as f:
        for line in f:
            if line.strip() != '':
                requests.append(json.loads(line))

    if args.sample_num < len(requests):
        print('Sampling {} examples out of {} requests'.format(args.sample_num, len(requests)))
        requests = requests[:args.sample_num]

    results = []
    rouge = Rouge()
    rouge1_score_list = []
    rouge2_score_list = []
    rougel_score_list = []

    with torch.no_grad():
        for request in tqdm.tqdm(requests):
            result = {'request': request, 'result': {}}
            prompt = request['article']
            label = request['summary_gt']
            temperature = request['temperature']
            stop = request['stop']

            input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').input_ids.to(model.device)

            output_sequences = model.generate(
                input_ids=input_ids,
                max_length=request['max_tokens'] + len(input_ids[0]),
                temperature=temperature,
                top_p=request['top_p'],
                do_sample=True,
                num_return_sequences=request['n'],
                return_dict_in_generate=True,
                output_scores=True,
                pad_token_id=tokenizer.eos_token_id,
            )

            # Tokens of the generated continuation (prompt tokens dropped).
            tokens = tokenizer.convert_ids_to_tokens(output_sequences['sequences'].squeeze(0))[len(input_ids[0]):]
            # Per-step log-probability of the highest-scoring token.
            logprobs = [logits.log_softmax(dim=-1).max().item() for logits in output_sequences['scores']]
            top_logprobs = [{i: v for i, v in zip(tokens, logprobs)}]
            generate_text = tokenizer.decode(output_sequences['sequences'].squeeze(0)[len(input_ids[0]):])

            # Truncate at the first stop string, if it appears in the generation.
            stop_index = generate_text.find(stop[0])
            if stop_index != -1:
                generate_text = generate_text[:stop_index]

            scores = rouge.get_scores(generate_text, label)[0]
            rouge1_score_list.append(scores['rouge-1']['f'])
            rouge2_score_list.append(scores['rouge-2']['f'])
            rougel_score_list.append(scores['rouge-l']['f'])

            result['result'] = {
                "choices": [
                    {
                        "text": generate_text,
                        "logprobs": {
                            "tokens": tokens,
                            "token_logprobs": logprobs,
                            "top_logprobs": top_logprobs,
                            "text_offset": [],
                        },
                        "finish_reason": "length",
                    }
                ],
                "request_time": {
                    "batch_time": 0,
                    "batch_size": 1,
                },
            }
            results.append(result)

    print('Average Rouge-1: {:.6f}, Rouge-2: {:.6f}, Rouge-L: {:.6f}'.format(
        np.mean(rouge1_score_list), np.mean(rouge2_score_list), np.mean(rougel_score_list)))

    with open(output_path, 'w') as f:
        for result in results:
            f.write(json.dumps(result) + '\n')
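
# Example invocation sketch (hypothetical data path and model name; only the flags
# defined by the argparse block above are assumed):
#   python run_summarization.py \
#       --input-path data/summarization_prompts.jsonl \
#       --output-path results/summarization_h2o.jsonl \
#       --model-name meta-llama/Llama-2-7b-hf \
#       --enable_h2o_generation \
#       --num_window_length 256 \
#       --num_heavy_hitter_tokens 128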