generation.py

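"""Summarization with optional H2O (heavy-hitter) KV-cache eviction.

Reads a JSONL file of generation requests, samples a summary for each with a
(possibly H2O-patched) Llama model, scores it against the reference with
ROUGE, and writes OpenAI-style result records back out as JSONL.

Example invocation (a sketch; the paths and model name below are
illustrative, not taken from the source):

    python generation.py \
        --input-path data/summarization.jsonl \
        --output-path results/summarization_h2o.jsonl \
        --model-name huggyllama/llama-7b \
        --enable_h2o_generation --num_heavy_hitter_tokens 256
"""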
import os
import json
import argparse

import tqdm
import torch
import numpy as np
from rouge import Rouge
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from utils_llama import H2OLlamaForCausalLM
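

# Seed NumPy and PyTorch (including all CUDA devices) for reproducible sampling.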
def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-path", type=str, default="")
    parser.add_argument("--output-path", type=str, default="")
    parser.add_argument("--model-name", type=str, default="")
    parser.add_argument("--enable_h2o_generation", action='store_true')
    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=256)
    parser.add_argument("--num_local_windows", type=int, default=256)
    parser.add_argument("--enable_position_rolling", action='store_true')
    parser.add_argument("--sample_num", type=int, default=10)
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()
    set_seed(args)

    model_name = args.model_name
    input_path = args.input_path
    output_path = args.output_path
    # Guard against a bare filename: os.path.dirname('') would make makedirs fail.
    if os.path.dirname(output_path):
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
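
    # H2O keeps a fixed KV-cache budget: `hh_size` heavy-hitter tokens
    # (those with the largest accumulated attention scores) plus
    # `recent_size` most recent tokens; other entries are evicted during generation.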
    if args.enable_h2o_generation:
        config.hh_size = args.num_heavy_hitter_tokens
        config.recent_size = args.num_local_windows
        config.enable_position_rolling = args.enable_position_rolling
        # Pass the patched config as a keyword; positional arguments after the
        # model name are forwarded to the model constructor, not used as config.
        model = H2OLlamaForCausalLM.from_pretrained(model_name, config=config)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name)
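    # fp16 eval mode; .cuda() assumes a CUDA-capable GPU is available.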
    model.half().eval().cuda()

    # Load inference data (one JSON request per line).
    requests = []
    with open(input_path, 'r') as f:
        for line in f:
            if line.strip() != '':
                requests.append(json.loads(line))

    if args.sample_num < len(requests):
        print('Sampling {} examples from {} total'.format(args.sample_num, len(requests)))
        requests = requests[:args.sample_num]
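
    # Each request is expected to provide 'article', 'summary_gt',
    # 'temperature', 'top_p', 'n', 'max_tokens', and 'stop' (a list of stop
    # strings), mirroring the OpenAI completion request format.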

    results = []
    rouge = Rouge()
    rouge1_score_list = []
    rouge2_score_list = []
    rougel_score_list = []

    with torch.no_grad():
        for request in tqdm.tqdm(requests):
            result = {'request': request, 'result': {}}

            prompt = request['article']
            label = request['summary_gt']
            temperature = request['temperature']
            stop = request['stop']

            input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').input_ids.to(model.device)
            output_sequences = model.generate(
                input_ids=input_ids,
                max_length=request['max_tokens'] + len(input_ids[0]),
                temperature=temperature,
                top_p=request['top_p'],
                do_sample=True,
                num_return_sequences=request['n'],
                return_dict_in_generate=True,
                output_scores=True,
            )
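
            # `output_scores=True` yields one logits tensor per generated step.
            # Two caveats of this reporting scheme: .max() records the highest
            # log-probability at each step, which equals the sampled token's
            # log-probability only when that token was the argmax, and the
            # squeeze(0)/decode below assume a single returned sequence (n == 1).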
            tokens = tokenizer.convert_ids_to_tokens(output_sequences['sequences'].squeeze(0))[len(input_ids[0]):]
            logprobs = [logits.log_softmax(dim=-1).max().item() for logits in output_sequences['scores']]
            top_logprobs = [{token: logprob} for token, logprob in zip(tokens, logprobs)]
            generate_text = tokenizer.decode(output_sequences['sequences'].squeeze(0)[len(input_ids[0]):])

            # Truncate at the first stop string; str.find returns -1 when the
            # stop string is absent, which would otherwise drop the last character.
            stop_idx = generate_text.find(stop[0])
            if stop_idx != -1:
                generate_text = generate_text[:stop_idx]
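
            # Track F-measure for ROUGE-1/2/L against the reference summary.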
            scores = rouge.get_scores(generate_text, label)[0]
            rouge1_score_list.append(scores['rouge-1']['f'])
            rouge2_score_list.append(scores['rouge-2']['f'])
            rougel_score_list.append(scores['rouge-l']['f'])
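
            # Package the generation as an OpenAI-completions-style record.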
            result['result'] = {
                "choices": [
                    {
                        "text": generate_text,
                        "logprobs": {
                            "tokens": tokens,
                            "token_logprobs": logprobs,
                            "top_logprobs": top_logprobs,
                            "text_offset": [],
                        },
                        "finish_reason": "length",
                    }
                ],
                "request_time": {
                    "batch_time": 0,
                    "batch_size": 1,
                },
            }
            results.append(result)

    print('Final Results: rouge-1: {:.6f}, rouge-2: {:.6f}, rouge-l: {:.6f}'.format(
        np.mean(rouge1_score_list), np.mean(rouge2_score_list), np.mean(rougel_score_list)))

    with open(output_path, 'w') as f:
        for result in results:
            f.write(json.dumps(result) + '\n')