
benchmark_summarization

Allen, 1 year ago
Parent commit 864f0d9a67

File diff suppressed because it is too large.
+ 1000 - 0
research/long-context-llama/H2O/data/summarization/cnn_dailymail.jsonl


File diff suppressed because it is too large.
+ 1000 - 0
research/long-context-llama/H2O/data/summarization/xsum.jsonl
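
The diffs for the two data files are suppressed, but judging from the fields consumed by generation.py below, each JSONL line is expected to look roughly like the record sketched here (the field values are illustrative only, not taken from the actual datasets):

# illustrative record; the real files are the suppressed JSONL datasets above
example_request = {
    "article": "Full source document to summarize ...",
    "summary_gt": "Reference summary used for ROUGE scoring ...",
    "max_tokens": 64,
    "temperature": 0.3,
    "top_p": 1.0,
    "n": 1,
    "stop": ["###"],
}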


+ 13 - 0
research/long-context-llama/H2O/exp.sh

@@ -0,0 +1,13 @@
+# Baseline: full KV cache
+CUDA_VISIBLE_DEVICES=$1 python -u generation.py \
+--input-path data/summarization/xsum.jsonl \
+--output-path xsum_baseline.jsonl \
+--model-name meta-llama/Llama-2-7b-hf 
+
+
+# H2O: evict down to 256 heavy-hitter + 256 recent tokens per head
+CUDA_VISIBLE_DEVICES=$1 python -u generation.py \
+--input-path data/summarization/xsum.jsonl \
+--output-path xsum_h2o.jsonl \
+--model-name meta-llama/Llama-2-7b-hf \
+--enable_h2o_generation \
+--num_heavy_hitter_tokens 256 \
+--num_local_windows 256
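
To eyeball the two runs side by side, a minimal sketch (assuming both commands above have finished; `load_summaries` is a hypothetical helper written for this sketch, and the JSONL schema is the one written by generation.py below):

import json

def load_summaries(path):
    # each line holds {'request': ..., 'result': {'choices': [{'text': ...}]}}
    with open(path) as f:
        return [json.loads(line)['result']['choices'][0]['text'] for line in f if line.strip()]

baseline = load_summaries('xsum_baseline.jsonl')
h2o = load_summaries('xsum_h2o.jsonl')
for full_cache, evicted in zip(baseline, h2o):
    print('baseline:', full_cache[:120])
    print('h2o     :', evicted[:120])
    print('-' * 80)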

+ 142 - 0
research/long-context-llama/H2O/generation.py

@@ -0,0 +1,142 @@
+import os
+import json
+import argparse
+
+import torch
+import tqdm
+import numpy as np
+from rouge import Rouge
+
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+from utils_llama import H2OLlamaForCausalLM
+
+def set_seed(args):
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if args.n_gpu > 0:
+        torch.cuda.manual_seed_all(args.seed)
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--input-path", type=str, default="")
+    parser.add_argument("--output-path", type=str, default="")
+
+    parser.add_argument("--model-name", type=str, default="")
+
+    parser.add_argument("--enable_h2o_generation", action='store_true')
+    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=256)
+    parser.add_argument("--num_local_windows", type=int, default=256)
+
+    parser.add_argument("--enable_position_rolling", action='store_true')
+
+    parser.add_argument("--sample_num", type=int, default=10)
+    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
+
+    args = parser.parse_args()
+
+    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    args.n_gpu = torch.cuda.device_count()
+    set_seed(args)
+
+    model_name = args.model_name
+    input_path = args.input_path
+    output_path = args.output_path
+    output_dir = os.path.dirname(output_path)
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+
+    config = AutoConfig.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+
+    if args.enable_h2o_generation:
+        config.hh_size = args.num_heavy_hitter_tokens
+        config.recent_size = args.num_local_windows
+        config.enable_position_rolling = args.enable_position_rolling
+        model = H2OLlamaForCausalLM.from_pretrained(model_name, config=config)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+
+    model.half().eval().cuda()
+
+    # loading inference data
+    requests = []
+    with open(input_path, 'r') as f:
+        for line in f:
+            if line.strip() != '':
+                requests.append(json.loads(line))
+
+    if args.sample_num < len(requests):
+        print('Sample {} Examples from {} samples'.format(args.sample_num, len(requests)))
+    requests = requests[:args.sample_num]
+
+    results = []
+    rouge = Rouge()
+    rouge1_score_list = []
+    rouge2_score_list = []
+    rougel_score_list = []
+
+    with torch.no_grad():
+        for request in tqdm.tqdm(requests):
+            result = {'request': request, 'result': {}}
+            prompt = request['article']
+            label = request['summary_gt']
+            temperature = request['temperature']
+            stop = request['stop']
+
+            input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').input_ids.to(model.device)
+
+            output_sequences = model.generate(
+                input_ids=input_ids,
+                max_length=request['max_tokens'] + len(input_ids[0]),
+                temperature=temperature,
+                top_p=request['top_p'],
+                do_sample=True,
+                num_return_sequences=request['n'],
+                return_dict_in_generate=True, output_scores=True,
+            )
+
+            tokens = tokenizer.convert_ids_to_tokens(output_sequences['sequences'].squeeze(0))[len(input_ids[0]):]
+            logprobs = [logits.log_softmax(dim=-1).max().item() for logits in output_sequences['scores']]
+            top_logprobs = [{i: v for i, v in zip(tokens, logprobs)}]
+
+            generate_text = tokenizer.decode(output_sequences['sequences'].squeeze(0)[len(input_ids[0]):])
+            if stop and stop[0] in generate_text:
+                generate_text = generate_text[: generate_text.find(stop[0])]
+
+            scores = rouge.get_scores(generate_text, label)[0]
+            rouge1_score_list.append(scores['rouge-1']['f'])
+            rouge2_score_list.append(scores['rouge-2']['f'])
+            rougel_score_list.append(scores['rouge-l']['f'])
+
+            result['result'] = {
+                "choices": [
+                    {
+                        "text": generate_text,
+                        "logprobs": {
+                            "tokens": tokens, 
+                            "token_logprobs": logprobs, 
+                            "top_logprobs": top_logprobs, 
+                            "text_offset": []
+                        }, 
+                        "finish_reason": "length"
+                    }
+                ], 
+                "request_time": {
+                    "batch_time": 0, 
+                    "batch_size": 1}
+            }
+            
+            results.append(result)
+
+    print('Final Results: rouge-1: {:.6f}, rouge-2: {:.6f}, rouge-l: {:.6f}'.format(np.mean(rouge1_score_list), np.mean(rouge2_score_list), np.mean(rougel_score_list)))
+    with open(output_path, 'w') as f:
+        for result in results:
+            f.write(json.dumps(result) + '\n')
+
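
If a results file from the script above needs to be rescored without rerunning generation, a minimal sketch along the lines of the scoring loop above (assuming the same `rouge` package; the file name is just the one produced by exp.sh):

import json
import numpy as np
from rouge import Rouge

rouge = Rouge()
r1, r2, rl = [], [], []
with open('xsum_h2o.jsonl') as f:
    for line in f:
        record = json.loads(line)
        hypothesis = record['result']['choices'][0]['text']
        reference = record['request']['summary_gt']
        if hypothesis.strip():  # Rouge() raises on empty hypotheses
            scores = rouge.get_scores(hypothesis, reference)[0]
            r1.append(scores['rouge-1']['f'])
            r2.append(scores['rouge-2']['f'])
            rl.append(scores['rouge-l']['f'])
print('rouge-1: {:.6f}, rouge-2: {:.6f}, rouge-l: {:.6f}'.format(np.mean(r1), np.mean(r2), np.mean(rl)))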

+ 356 - 0
research/long-context-llama/H2O/utils_llama.py

@@ -0,0 +1,356 @@
+import math
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+import torch.utils.checkpoint
+import torch.nn.functional as F
+
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaRotaryEmbedding,
+    LlamaLinearScalingRotaryEmbedding,
+    LlamaDynamicNTKScalingRotaryEmbedding,
+    LlamaForCausalLM,
+    rotate_half,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
+
+__all__ = ["H2OLlamaForCausalLM"]
+
+def _make_causal_mask(
+    bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device):
+    """
+    Make causal mask used for bi-directional self-attention.
+    """
+    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
+    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
+    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
+    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
+    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+    x_embed = (x * cos) + (rotate_half(x) * sin)
+    return x_embed
+
+class H2OKVCache_LayerWise:
+    def __init__(
+        self,
+        hh_size=4,
+        recent_size=512,
+        k_seq_dim=2,
+        v_seq_dim=2,
+    ):
+        self.hh_size = hh_size
+        self.recent_size = recent_size
+        self.cache_size = hh_size + recent_size
+        self.k_seq_dim = k_seq_dim
+        self.v_seq_dim = v_seq_dim
+        # per-head accumulated attention scores, lazily initialized on first call
+        self.hh_score = None
+
+    def __call__(self, past_key_values, attn_score_cache):
+
+        self._update_hh_score(attn_score_cache)
+
+        if past_key_values is None:
+            return None
+        seq_len = past_key_values[0].size(self.k_seq_dim)
+        if seq_len <= self.cache_size:
+            return past_key_values
+
+        # hh-selection
+        bsz, num_heads, _, head_dim = past_key_values[0].shape
+
+        select_hh_scores = self.hh_score[:, :seq_len - self.recent_size]
+        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
+        keep_topk = keep_topk.sort().values
+
+        # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
+        keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
+        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
+
+        mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
+        mask = mask.scatter(-1, keep_idx, 1)
+
+        # `squeeze()` assumes batch size 1; the boolean mask keeps hh_size + recent_size positions per head
+        k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
+        v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
+
+        self.hh_score = self.hh_score[mask].view(num_heads, self.cache_size)
+
+        return (k_hh_recent, v_hh_recent)
+
+    def evict_for_space(self, past_key_values, num_coming):
+        if past_key_values is None:
+            return None
+        seq_len = past_key_values[0].size(self.k_seq_dim)
+        if seq_len + num_coming <= self.cache_size:
+            return past_key_values
+
+        # hh-selection
+        bsz, num_heads, _, head_dim = past_key_values[0].shape
+
+        select_hh_scores = self.hh_score[:, :seq_len - self.recent_size + num_coming]
+        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
+        keep_topk = keep_topk.sort().values
+
+        # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
+        keep_recent = torch.arange(seq_len - self.recent_size + num_coming, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
+        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
+
+        mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
+        mask = mask.scatter(-1, keep_idx, 1)
+
+        k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
+        v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
+
+        self.hh_score = self.hh_score[mask].view(num_heads, self.cache_size)
+
+        return (k_hh_recent, v_hh_recent)
+
+    def _update_hh_score(self, attn_score_cache):
+
+        num_new_tokens = attn_score_cache.shape[2]
+
+        if self.hh_score is None:
+            self.hh_score = attn_score_cache.sum(0).sum(1)
+        else:
+            attn_score_cache = attn_score_cache.sum(0).sum(1)
+            attn_score_cache[:, :-num_new_tokens] += self.hh_score
+            self.hh_score = attn_score_cache
+
+    def _clean_scores(self):
+        self.hh_score = None
+
+class H2OLlamaAttention(nn.Module):
+    """Llama multi-headed attention with heavy-hitter (H2O) KV-cache eviction."""
+
+    def __init__(self, config: LlamaConfig):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+
+        if (self.head_dim * self.num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+        self._init_rope()
+
+        self.kv_cache = H2OKVCache_LayerWise(
+            hh_size=config.hh_size,
+            recent_size=config.recent_size,
+            k_seq_dim=2,
+            v_seq_dim=2,
+        )
+
+    def _init_rope(self):
+        if self.config.rope_scaling is None:
+            self.rotary_emb = LlamaRotaryEmbedding(
+                self.head_dim,
+                max_position_embeddings=self.max_position_embeddings,
+                base=self.rope_theta,
+            )
+        else:
+            scaling_type = self.config.rope_scaling["type"]
+            scaling_factor = self.config.rope_scaling["factor"]
+            if scaling_type == "linear":
+                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                )
+            elif scaling_type == "dynamic":
+                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                )
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def _clean_cache(self):
+        self.kv_cache._clean_scores()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.config.pretraining_tp > 1:
+            key_value_slicing = (
+                self.num_key_value_heads * self.head_dim
+            ) // self.config.pretraining_tp
+            query_slices = self.q_proj.weight.split(
+                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+            )
+            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+            query_states = [
+                F.linear(hidden_states, query_slices[i])
+                for i in range(self.config.pretraining_tp)
+            ]
+            query_states = torch.cat(query_states, dim=-1)
+
+            key_states = [
+                F.linear(hidden_states, key_slices[i])
+                for i in range(self.config.pretraining_tp)
+            ]
+            key_states = torch.cat(key_states, dim=-1)
+
+            value_states = [
+                F.linear(hidden_states, value_slices[i])
+                for i in range(self.config.pretraining_tp)
+            ]
+            value_states = torch.cat(value_states, dim=-1)
+
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(
+            bsz, q_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+        key_states = key_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(1, 2)
+        value_states = value_states.view(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(1, 2)
+
+        # rebuild the causal mask locally: H2O eviction changes the cached key length,
+        # so the attention_mask passed in by the caller may no longer match
+        attention_mask = _make_causal_mask(
+            bsz=bsz,
+            tgt_len=q_len,
+            past_key_values_length=past_key_value[0].shape[-2] if past_key_value is not None else 0,
+            dtype=query_states.dtype,
+            device=query_states.device,
+        )
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        # single-token decoding step: pin the query position to the end of the (possibly evicted) cache
+        if position_ids.nelement() == 1:
+            position_ids[0][0] = kv_seq_len - 1
+
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        ### Shift Pos: query pos is min(cache_size, idx)
+        # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
+        ###
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        ### Shift Pos: key pos is the pos in cache (Rolling KV Cache and using relative pos emb)
+        key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
+        key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
+        ###
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
+            self.head_dim
+        )
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+            query_states.dtype
+        )
+
+        # accumulate attention scores and evict low-scoring positions from the KV cache
+        past_key_value = self.kv_cache(past_key_value, attn_weights.detach().clone())
+
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        if self.config.pretraining_tp > 1:
+            attn_output = attn_output.split(
+                self.hidden_size // self.config.pretraining_tp, dim=2
+            )
+            o_proj_slices = self.o_proj.weight.split(
+                self.hidden_size // self.config.pretraining_tp, dim=1
+            )
+            attn_output = sum(
+                [
+                    F.linear(attn_output[i], o_proj_slices[i])
+                    for i in range(self.config.pretraining_tp)
+                ]
+            )
+        else:
+            attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+class H2OLlamaForCausalLM(LlamaForCausalLM):
+    def __init__(self, config):
+        super().__init__(config)
+        # replace every attention module with the H2O variant; pretrained weights
+        # are loaded into the matching q/k/v/o projections by from_pretrained
+        num_layers = len(self.model.layers)
+        for layer_idx in range(num_layers):
+            self.model.layers[layer_idx].self_attn = H2OLlamaAttention(config)
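
Finally, a minimal sketch (not part of this commit) that exercises the H2OKVCache_LayerWise eviction policy above in isolation, using random tensors and the batch-size-1 layout the cache assumes:

import torch
from utils_llama import H2OKVCache_LayerWise

cache = H2OKVCache_LayerWise(hh_size=4, recent_size=8)   # keep at most 12 positions
bsz, heads, seq_len, head_dim = 1, 2, 32, 16
k = torch.randn(bsz, heads, seq_len, head_dim)
v = torch.randn(bsz, heads, seq_len, head_dim)
attn = torch.rand(bsz, heads, seq_len, seq_len)          # stand-in attention weights

k_kept, v_kept = cache((k, v), attn)
print(k_kept.shape)  # torch.Size([1, 2, 12, 16]): 4 heavy hitters + 8 recent tokens per head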