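# Summarization evaluation: generate summaries with a Llama-family model,
# optionally under H2O heavy-hitter KV-cache eviction, and report average
# ROUGE-1/2/L against the reference summaries.
#
# Example invocation (script name, paths, and model are illustrative):
#   python run_summarization.py \
#       --input-path data/requests.jsonl \
#       --output-path results/h2o/predictions.jsonl \
#       --model-name meta-llama/Llama-2-7b-hf \
#       --enable_h2o_generation
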
import os
import tqdm
import json
import torch
import argparse
import numpy as np

from rouge import Rouge
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

from utils.llama import H2OLlamaForCausalLM
 
def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
 
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-path", type=str, default="", help="input JSONL file of requests")
    parser.add_argument("--output-path", type=str, default="", help="output JSONL file of results")
    parser.add_argument("--model-name", type=str, default="", help="Hugging Face model name or path")
    parser.add_argument("--enable_h2o_generation", action='store_true', help="use H2O KV-cache eviction")
    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=-1, help="heavy-hitter budget; -1 defaults to half the cache size")
    parser.add_argument("--num_window_length", type=int, default=256, help="total KV-cache size")
    parser.add_argument("--enable_position_rolling", action='store_true')
    parser.add_argument("--sample_num", type=int, default=500, help="maximum number of requests to evaluate")
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    args = parser.parse_args()
 
    set_seed(args)

    model_name = args.model_name
    input_path = args.input_path
    output_path = args.output_path
    # Only create the output directory when the path actually contains one
    # (os.makedirs('') raises for a bare filename in the current directory).
    if os.path.dirname(output_path):
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
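
    # H2O cache budget: --num_window_length is the total KV-cache size; unless
    # set explicitly, half of it is reserved for heavy-hitter tokens (in the
    # H2O scheme, the remainder holds the most recent tokens).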
 
    if args.num_heavy_hitter_tokens == -1:
        print('Number of heavy-hitter tokens not specified; defaulting to half the cache size: {}'.format(args.num_window_length // 2))
        args.num_heavy_hitter_tokens = args.num_window_length // 2
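
    # Model construction: with --enable_h2o_generation, load the H2O Llama
    # variant (heavy-hitter KV-cache eviction); otherwise the vanilla
    # full-cache model.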
 
    if args.enable_h2o_generation:
        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
        config.num_window_length = args.num_window_length
        config.enable_position_rolling = args.enable_position_rolling
        model = H2OLlamaForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map='auto',
            low_cpu_mem_usage=True,
            config=config)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map='auto',
            low_cpu_mem_usage=True)
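
    # Each input line is one JSON request. Illustrative shape (the field names
    # are the ones read below; the values here are made up):
    #   {"article": "...", "summary_gt": "...", "temperature": 0.3,
    #    "top_p": 1.0, "n": 1, "max_tokens": 64, "stop": ["###"]}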
 
    # Load inference requests (one JSON object per non-empty line).
    requests = []
    with open(input_path, 'r') as f:
        for line in f:
            if line.strip() != '':
                requests.append(json.loads(line))

    if args.sample_num < len(requests):
        print('Sampling {} examples from {} requests'.format(args.sample_num, len(requests)))
    requests = requests[:args.sample_num]
 
    results = []
    rouge = Rouge()
    rouge1_score_list = []
    rouge2_score_list = []
    rougel_score_list = []
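
    # Generate a summary for each request and score it as we go.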
 
    with torch.no_grad():
        for request in tqdm.tqdm(requests):
            result = {'request': request, 'result': {}}
            prompt = request['article']
            label = request['summary_gt']
            temperature = request['temperature']
            stop = request['stop']

            input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').input_ids.to(model.device)

            output_sequences = model.generate(
                input_ids=input_ids,
                max_length=request['max_tokens'] + len(input_ids[0]),
                temperature=temperature,
                top_p=request['top_p'],
                do_sample=True,
                num_return_sequences=request['n'],
                return_dict_in_generate=True,
                output_scores=True,
                pad_token_id=tokenizer.eos_token_id,
            )

            # The squeeze(0) below assumes a single returned sequence (n == 1).
            generated_ids = output_sequences['sequences'].squeeze(0)[len(input_ids[0]):]
            tokens = tokenizer.convert_ids_to_tokens(generated_ids)
            # Log-probability of the token actually sampled at each step;
            # 'scores' holds one logits tensor per generated position.
            logprobs = [step_scores.log_softmax(dim=-1)[0, token_id].item()
                        for step_scores, token_id in zip(output_sequences['scores'], generated_ids)]
            # One {token: logprob} entry per generated position.
            top_logprobs = [{token: logprob} for token, logprob in zip(tokens, logprobs)]

            generate_text = tokenizer.decode(generated_ids)
            # Truncate at the first stop string only if it appears; str.find
            # returns -1 when absent, which would otherwise clip the text.
            stop_index = generate_text.find(stop[0])
            if stop_index != -1:
                generate_text = generate_text[:stop_index]
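
            # ROUGE F-measures of the generation against the reference summary.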
 
            scores = rouge.get_scores(generate_text, label)[0]
            rouge1_score_list.append(scores['rouge-1']['f'])
            rouge2_score_list.append(scores['rouge-2']['f'])
            rougel_score_list.append(scores['rouge-l']['f'])
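
            # Package the generation as an OpenAI-completions-style payload,
            # with per-token logprobs alongside the text.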
 
            result['result'] = {
                "choices": [
                    {
                        "text": generate_text,
                        "logprobs": {
                            "tokens": tokens,
                            "token_logprobs": logprobs,
                            "top_logprobs": top_logprobs,
                            "text_offset": []
                        },
                        "finish_reason": "length"
                    }
                ],
                "request_time": {
                    "batch_time": 0,
                    "batch_size": 1
                }
            }

            results.append(result)
 
    print('Average ROUGE-1: {:.6f}, ROUGE-2: {:.6f}, ROUGE-L: {:.6f}'.format(
        np.mean(rouge1_score_list), np.mean(rouge2_score_list), np.mean(rougel_score_list)))

    with open(output_path, 'w') as f:
        for result in results:
            f.write(json.dumps(result) + '\n')
 
 