run_summarization.py

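# Summarization evaluation with optional H2O (Heavy-Hitter Oracle) KV-cache
# eviction during generation. Example invocation (the model name and paths
# below are illustrative, not taken from this repo):
#
#   python run_summarization.py \
#       --model-name meta-llama/Llama-2-7b-hf \
#       --input-path data/summarization.jsonl \
#       --output-path results/summarization_h2o.jsonl \
#       --enable_h2o_generation \
#       --num_heavy_hitter_tokens 128 --num_window_length 256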
import os
import json
import argparse

import numpy as np
import torch
import tqdm
from rouge import Rouge
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from utils.llama import H2OLlamaForCausalLM


def set_seed(args):
    # Seed NumPy and PyTorch (CPU and all CUDA devices) for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-path", type=str, default="")
    parser.add_argument("--output-path", type=str, default="")
    parser.add_argument("--model-name", type=str, default="")
    parser.add_argument("--enable_h2o_generation", action='store_true')
    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=128)
    parser.add_argument("--num_window_length", type=int, default=256)
    parser.add_argument("--enable_position_rolling", action='store_true')
    parser.add_argument("--sample_num", type=int, default=500)
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    args = parser.parse_args()
    set_seed(args)

    model_name = args.model_name
    input_path = args.input_path
    output_path = args.output_path
    # dirname() is '' for a bare filename; fall back to the current directory.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

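    # When H2O is enabled, the patched Llama attention keeps a reduced KV cache:
    # a set of "heavy hitter" tokens with high accumulated attention plus the
    # most recent tokens, sized by the two options applied below.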
    if args.enable_h2o_generation:
        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
        config.num_window_length = args.num_window_length
        config.enable_position_rolling = args.enable_position_rolling
        model = H2OLlamaForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map='auto',
            low_cpu_mem_usage=True,
            config=config,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map='auto',
            low_cpu_mem_usage=True,
        )
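
    # Each non-empty input line is a JSON request; the fields accessed below are
    # 'article', 'summary_gt', 'temperature', 'top_p', 'n', 'max_tokens', 'stop'.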
    requests = []
    with open(input_path, 'r') as f:
        for line in f:
            if line.strip() != '':
                requests.append(json.loads(line))

    if args.sample_num < len(requests):
        print('Sampling {} examples from {} requests'.format(args.sample_num, len(requests)))
        requests = requests[:args.sample_num]

    results = []
    rouge = Rouge()
    rouge1_score_list = []
    rouge2_score_list = []
    rougel_score_list = []

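    # Generate one summary per request and score it against the reference
    # summary with ROUGE-1/2/L F-measure.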
    with torch.no_grad():
        for request in tqdm.tqdm(requests):
            result = {'request': request, 'result': {}}
            prompt = request['article']
            label = request['summary_gt']
            temperature = request['temperature']
            stop = request['stop']

            input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').input_ids.to(model.device)

            output_sequences = model.generate(
                input_ids=input_ids,
                max_length=request['max_tokens'] + len(input_ids[0]),
                temperature=temperature,
                top_p=request['top_p'],
                do_sample=True,
                num_return_sequences=request['n'],
                return_dict_in_generate=True,
                output_scores=True,
                pad_token_id=tokenizer.eos_token_id,
            )
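
            # output_sequences['scores'] holds one (batch, vocab_size) score
            # tensor per generated step; everything below assumes a single
            # returned sequence (request['n'] == 1), as the squeeze(0) implies.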
            generated_ids = output_sequences['sequences'].squeeze(0)[len(input_ids[0]):]
            tokens = tokenizer.convert_ids_to_tokens(generated_ids)
            # Log-probability of each sampled token at its generation step.
            logprobs = [
                step_scores.log_softmax(dim=-1)[0, token_id].item()
                for step_scores, token_id in zip(output_sequences['scores'], generated_ids)
            ]
            top_logprobs = [{token: logprob} for token, logprob in zip(tokens, logprobs)]

            generate_text = tokenizer.decode(generated_ids)
            # Truncate at the first occurrence of the stop string, if any.
            if stop and stop[0] in generate_text:
                generate_text = generate_text[:generate_text.find(stop[0])]

            scores = rouge.get_scores(generate_text, label)[0]
            rouge1_score_list.append(scores['rouge-1']['f'])
            rouge2_score_list.append(scores['rouge-2']['f'])
            rougel_score_list.append(scores['rouge-l']['f'])

            result['result'] = {
                "choices": [
                    {
                        "text": generate_text,
                        "logprobs": {
                            "tokens": tokens,
                            "token_logprobs": logprobs,
                            "top_logprobs": top_logprobs,
                            "text_offset": [],
                        },
                        "finish_reason": "length",
                    }
                ],
                "request_time": {"batch_time": 0, "batch_size": 1},
            }
            results.append(result)

    print('Average Rouge-1: {:.6f}, Rouge-2: {:.6f}, Rouge-L: {:.6f}'.format(
        np.mean(rouge1_score_list), np.mean(rouge2_score_list), np.mean(rougel_score_list)))

    with open(output_path, 'w') as f:
        for result in results:
            f.write(json.dumps(result) + '\n')