Browse source

Implement H2O for long context inference on summarization tasks (#411)

Hamid Shojanazeri 10 months ago
parent
commit
4e1466c572

+ 7 - 0
.github/scripts/spellcheck_conf/wordlist.txt

@@ -1351,6 +1351,13 @@ Weaviate
 MediaGen
 SDXL
 SVD
+KV
+KVs
+XSUM
+contrains
+knowlege
+kv
+prefilling
 DataFrame
 DuckDB
 Groq

File diff suppressed because it is too large
+ 50 - 0
recipes/experimental/long-context/H2O/README.md


File diff suppressed because it is too large
+ 1000 - 0
recipes/experimental/long-context/H2O/data/summarization/cnn_dailymail.jsonl


File diff suppressed because it is too large
+ 1000 - 0
recipes/experimental/long-context/H2O/data/summarization/xsum.jsonl


+ 4 - 0
recipes/experimental/long-context/H2O/requirements.txt

@@ -0,0 +1,4 @@
+transformers
+rouge
+xopen
+needlehaystack
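As a usage note (not part of the diff): the recipe's dependencies can be installed with pip from the repository root, using the requirements file added above. Note that `transformers` is unpinned here even though the cache and attention code below tracks a specific version of its internal Llama API, so pinning a compatible version may be necessary.

```bash
pip install -r recipes/experimental/long-context/H2O/requirements.txt
```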

+ 91 - 0
recipes/experimental/long-context/H2O/run_streaming.py

@@ -0,0 +1,91 @@
+import torch
+import argparse
+import json
+import os
+import time
+import re
+import sys
+
+from utils.streaming import load, download_url, load_jsonl, greedy_generate
+
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+from utils.llama import H2OLlamaForCausalLM
+from utils.cache import Cache, HHCache, StaticCache
+
+
+@torch.no_grad()
+def streaming_inference_h2o(model, tokenizer, config, prompts, max_gen_len=1000, enable_h2o_generation=False):
+    past_key_values = None
+    for idx, prompt in enumerate(prompts):
+        prompt = "USER: " + prompt + "\n\nASSISTANT: "
+        print("\n" + prompt, end="")
+        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+        input_ids = input_ids.to(model.device)
+        seq_len = input_ids.shape[1]
+
+        past_key_values = greedy_generate(
+            model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
+        )
+        if enable_h2o_generation:
+            space_needed = seq_len + max_gen_len
+            past_key_values = HHCache.from_legacy_cache(config.num_window_length, config.num_heavy_hitter_tokens, past_key_values)
+            past_key_values.evict_for_space(space_needed)
+            past_key_values = past_key_values.to_legacy_cache()
+
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--input-path", type=str, default="")
+    parser.add_argument("--model-name", type=str, default="lmsys/vicuna-13b-v1.5")
+
+    parser.add_argument("--enable_h2o_generation", action='store_true')
+    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=128)
+    parser.add_argument("--num_window_length", type=int, default=256)
+
+    parser.add_argument("--enable_position_rolling", action='store_true')
+
+    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
+
+    args = parser.parse_args()
+
+    model_name = args.model_name
+    data_root = args.input_path
+
+    config = AutoConfig.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+
+    if args.enable_h2o_generation:
+        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
+        config.num_window_length = args.num_window_length
+        config.enable_position_rolling = args.enable_position_rolling
+        model = H2OLlamaForCausalLM.from_pretrained(model_name,
+            torch_dtype=torch.float16,
+            device_map='auto',
+            low_cpu_mem_usage=True,
+            config=config)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(model_name,
+            torch_dtype=torch.float16,
+            device_map='auto',
+            low_cpu_mem_usage=True,)
+
+    test_filepath = os.path.join(data_root, "mt_bench.jsonl")
+    print(f"Loading data from {test_filepath} ...")
+
+    if not os.path.exists(test_filepath):
+        download_url(
+            "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl",
+            data_root,
+        )
+        os.rename(os.path.join(data_root, "question.jsonl"), test_filepath)
+
+    list_data = load_jsonl(test_filepath)
+    prompts = []
+    for sample in list_data:
+        prompts += sample["turns"]
+
+    streaming_inference_h2o(model, tokenizer, config, prompts, enable_h2o_generation=args.enable_h2o_generation)
+
+if __name__ == "__main__":
+    main()
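For reference, a minimal invocation of this script with H2O enabled might look like the following sketch; the flag names come from the argparse definitions above, the heavy-hitter and window values are simply the script's defaults, and `data` is the directory into which the MT-Bench questions are downloaded. The `src/streaming.sh` wrapper added later in this PR runs a similar command with larger cache sizes.

```bash
python run_streaming.py \
    --input-path data \
    --model-name lmsys/vicuna-13b-v1.5 \
    --enable_h2o_generation \
    --num_heavy_hitter_tokens 128 \
    --num_window_length 256
```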

+ 147 - 0
recipes/experimental/long-context/H2O/run_summarization.py

@@ -0,0 +1,147 @@
+import os
+import tqdm
+import json
+import copy
+import math
+
+import torch
+import logging
+import argparse
+
+import numpy as np
+from rouge import Rouge
+
+import dataclasses
+from xopen import xopen
+
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+from utils.llama import H2OLlamaForCausalLM
+
+def set_seed(args):
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed_all(args.seed)
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--input-path", type=str, default="")
+    parser.add_argument("--output-path", type=str, default="")
+
+    parser.add_argument("--model-name", type=str, default="")
+
+    parser.add_argument("--enable_h2o_generation", action='store_true')
+    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=-1)
+    parser.add_argument("--num_window_length", type=int, default=256)
+
+    parser.add_argument("--enable_position_rolling", action='store_true')
+
+    parser.add_argument("--sample_num", type=int, default=500)
+    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
+
+    args = parser.parse_args()
+
+    set_seed(args)
+
+    model_name = args.model_name
+    input_path = args.input_path
+    output_path = args.output_path
+    os.makedirs(os.path.dirname(output_path), exist_ok=True)
+
+    config = AutoConfig.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+    if args.num_heavy_hitter_tokens == -1:
+        print('Number of heavy hitter tokens not specified; using half of the window length: {}'.format(args.num_window_length // 2))
+        args.num_heavy_hitter_tokens = args.num_window_length // 2
+
+    if args.enable_h2o_generation:
+        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
+        config.num_window_length = args.num_window_length
+        config.enable_position_rolling = args.enable_position_rolling
+        model = H2OLlamaForCausalLM.from_pretrained(model_name,
+            torch_dtype=torch.float16,
+            device_map='auto',
+            low_cpu_mem_usage=True,
+            config=config)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(model_name,
+            torch_dtype=torch.float16,
+            device_map='auto',
+            low_cpu_mem_usage=True,)
+
+    # loading inference data
+    requests = []
+    with open(input_path, 'r') as f:
+        for line in f:
+            if line.strip() != '':
+                requests.append(json.loads(line))
+
+    if args.sample_num < len(requests):
+        print('Sampling {} examples from {} total requests'.format(args.sample_num, len(requests)))
+    requests = requests[:args.sample_num]
+
+    results = []
+    rouge = Rouge()
+    rouge1_score_list = []
+    rouge2_score_list = []
+    rougel_score_list = []
+
+    with torch.no_grad():
+        for request in tqdm.tqdm(requests):
+            result = {'request': request, 'result': {}}
+            prompt = request['article']
+            label = request['summary_gt']
+            temperature = request['temperature']
+            stop = request['stop']
+
+            input_ids = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').input_ids.to(model.device)
+
+            output_sequences = model.generate(
+                input_ids=input_ids,
+                max_length=request['max_tokens'] + len(input_ids[0]),
+                temperature=temperature,
+                top_p=request['top_p'],
+                do_sample=True,
+                num_return_sequences=request['n'],
+                return_dict_in_generate=True, output_scores=True,
+                pad_token_id=tokenizer.eos_token_id
+            )
+
+            tokens = tokenizer.convert_ids_to_tokens(output_sequences['sequences'].squeeze(0))[len(input_ids[0]):]
+            logprobs = [logits.log_softmax(dim=-1).max().item() for logits in output_sequences['scores']]
+            top_logprobs = [{i: v for i, v in zip(tokens, logprobs)}]
+
+            generate_text = tokenizer.decode(output_sequences['sequences'].squeeze(0)[len(input_ids[0]):])
+            # Truncate at the first stop sequence, if one is present
+            stop_index = generate_text.find(stop[0])
+            if stop_index != -1:
+                generate_text = generate_text[:stop_index]
+
+            scores = rouge.get_scores(generate_text, label)[0]
+            rouge1_score_list.append(scores['rouge-1']['f'])
+            rouge2_score_list.append(scores['rouge-2']['f'])
+            rougel_score_list.append(scores['rouge-l']['f'])
+
+            result['result'] = {
+                "choices": [
+                    {
+                        "text": generate_text,
+                        "logprobs": {
+                            "tokens": tokens, 
+                            "token_logprobs": logprobs, 
+                            "top_logprobs": top_logprobs, 
+                            "text_offset": []
+                        }, 
+                        "finish_reason": "length"
+                    }
+                ], 
+                "request_time": {
+                    "batch_time": 0, 
+                    "batch_size": 1}
+            }
+            
+            results.append(result)
+
+    print('Average Rouge1: {:.6f}, Rouge-2: {:.6f}, Rouge-l: {:.6f}'.format(np.mean(rouge1_score_list), np.mean(rouge2_score_list), np.mean(rougel_score_list)))
+    with open(output_path, 'w') as f:
+        for result in results:
+            f.write(json.dumps(result) + '\n')
+
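A possible invocation of this script on the XSUM data shipped with the recipe is sketched below. The flags mirror the argparse definitions above; the input path is the `xsum.jsonl` file added in this PR, while the output path and model id are placeholders (the Vicuna checkpoint is just the model used elsewhere in this recipe; any Llama-architecture model should work). When `--num_heavy_hitter_tokens` is omitted it defaults to half of `--num_window_length`, as handled in the script.

```bash
python run_summarization.py \
    --input-path data/summarization/xsum.jsonl \
    --output-path results/xsum_h2o.jsonl \
    --model-name lmsys/vicuna-13b-v1.5 \
    --enable_h2o_generation \
    --num_window_length 256
```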

+ 23 - 0
recipes/experimental/long-context/H2O/src/streaming.sh

@@ -0,0 +1,23 @@
+method=$1
+if [[ ${method} == 'h2o' ]]; then
+    python -u run_streaming.py \
+        --input-path data \
+        --model-name lmsys/vicuna-13b-v1.5 \
+        --enable_h2o_generation \
+        --num_heavy_hitter_tokens 2048 \
+        --num_window_length 4096 \
+        --enable_position_rolling
+elif [[ ${method} == 'full' ]]; then
+    python -u run_streaming.py \
+        --input-path data \
+        --model-name lmsys/vicuna-13b-v1.5
+else
+    echo 'unknown argument for method'
+fi
+
+
+
+
+
+
+
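The wrapper above takes a single positional argument selecting the cache policy. A possible way to run it, assuming the working directory is the H2O recipe root (the script calls `run_streaming.py` by a relative path):

```bash
bash src/streaming.sh h2o    # H2O cache: heavy-hitter + recent-window eviction
bash src/streaming.sh full   # full KV cache baseline
```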

+ 644 - 0
recipes/experimental/long-context/H2O/utils/cache.py

@@ -0,0 +1,644 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+@dataclass
+class Cache:
+    """
+    Base, abstract class for all caches. The actual data structure is specific to each subclass.
+    """
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
+                cache to be created.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        raise NotImplementedError("Make sure to implement `update` in a subclass.")
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states, if there is any."""
+        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
+
+    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
+        """Given the sequence length of the new inputs, returns the usable length of the cache."""
+        # Cache without size limit -> all cache is usable
+        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
+        #   length, we will need to evict part of the cache (and thus not all cache is usable)
+        max_length = self.get_max_length()
+        previous_seq_length = self.get_seq_length(layer_idx)
+        if max_length is not None and previous_seq_length + new_seq_length > max_length:
+            return max_length - new_seq_length
+        return previous_seq_length
+
+    @property
+    def seen_tokens(self):
+        logger.warning_once(
+            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
+            "model input instead."
+        )
+        if hasattr(self, "_seen_tokens"):
+            return self._seen_tokens
+        else:
+            return None
+
+
+class DynamicCache(Cache):
+    """
+    A cache that grows dynamically as more tokens are generated. This is the default for generative models.
+
+    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+    `[batch_size, num_heads, seq_len, head_dim]`.
+    """
+
+    def __init__(self) -> None:
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+
+    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        """
+        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
+        sequence length.
+        """
+        if layer_idx < len(self):
+            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+    def __iter__(self):
+        """
+        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
+        keys and values
+        """
+        for layer_idx in range(len(self)):
+            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
+
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.key_cache)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        # Update the number of seen tokens
+        if layer_idx == 0:
+            self._seen_tokens += key_states.shape[-2]
+
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+        else:
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
+        return None
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        for layer_idx in range(len(self.key_cache)):
+            device = self.key_cache[layer_idx].device
+            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+            device = self.value_cache[layer_idx].device
+            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+        """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
+        legacy_cache = ()
+        for layer_idx in range(len(self)):
+            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+        return legacy_cache
+
+    @classmethod
+    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
+        cache = cls()
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                cache.update(key_states, value_states, layer_idx)
+        return cache
+
+
+class SinkCache(Cache):
+    """
+    A cache, as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453), that allows the model to
+    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
+    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
+
+    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+    `[batch_size, num_heads, seq_len, head_dim]`.
+
+    Parameters:
+        window_length (`int`):
+            The length of the context window.
+        num_sink_tokens (`int`):
+            The number of sink tokens. See the original paper for more information.
+    """
+
+    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+        self.window_length = window_length
+        self.num_sink_tokens = num_sink_tokens
+        self.cos_sin_cache = {}
+        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+
+    @staticmethod
+    def _rotate_half(x):
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def _apply_key_rotary_pos_emb(
+        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+    ) -> torch.Tensor:
+        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
+        return rotated_key_states
+
+    def _get_rerotation_cos_sin(
+        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if key_states.shape[-2] not in self.cos_sin_cache:
+            # Upcast to float32 temporarily for better accuracy
+            cos = cos.to(torch.float32)
+            sin = sin.to(torch.float32)
+
+            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
+            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
+            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
+            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
+            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
+            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
+            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
+
+            self.cos_sin_cache[key_states.shape[-2]] = (
+                rerotation_cos.to(key_states.dtype).unsqueeze(0),
+                rerotation_sin.to(key_states.dtype).unsqueeze(0),
+            )
+        return self.cos_sin_cache[key_states.shape[-2]]
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states."""
+        return self.window_length
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
+                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
+                rotation as the tokens are shifted.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
+        # with partially rotated position embeddings, like Phi or Persimmon.
+        sin = cache_kwargs.get("sin")
+        cos = cache_kwargs.get("cos")
+        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
+        using_rope = cos is not None and sin is not None
+
+        # Update the number of seen tokens
+        if layer_idx == 0:
+            self._seen_tokens += key_states.shape[-2]
+
+        # [bsz, num_heads, seq_len, head_dim]
+        if len(self.key_cache) <= layer_idx:
+            # Empty cache
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+
+        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
+            # Growing cache
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+        else:
+            # Shifting cache
+            keys_to_keep = self.key_cache[layer_idx][
+                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
+            ]
+
+            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
+            if using_rope:
+                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
+                    key_states, cos[: self.window_length], sin[: self.window_length]
+                )
+                if partial_rotation_size is not None:
+                    keys_to_keep, keys_pass = (
+                        keys_to_keep[..., :partial_rotation_size],
+                        keys_to_keep[..., partial_rotation_size:],
+                    )
+                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
+                if partial_rotation_size is not None:
+                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
+
+            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
+            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
+            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
+
+            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
+            values_to_keep = self.value_cache[layer_idx][
+                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
+            ]
+            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        for layer_idx in range(len(self.key_cache)):
+            device = self.key_cache[layer_idx].device
+            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+            device = self.value_cache[layer_idx].device
+            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+
+class HHCache(Cache):
+    """
+    A cache that applies the Heavy-Hitter Oracle (H2O, https://proceedings.neurips.cc/paper_files/paper/2023/file/6ceefa7b15572587b78ecfcebb2827f8-Paper-Conference.pdf).
+    Only the heavy-hitter tokens and the most recent tokens are stored in the cache.
+
+    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+    `[batch_size, num_heads, seq_len, head_dim]`.
+
+    Parameters:
+        window_length (`int`):
+            The length of the context window.
+        num_hh_tokens (`int`):
+            The number of heavy hitter tokens. See the original paper for more information.
+    """
+
+    def __init__(self, window_length: int, num_hh_tokens: int) -> None:
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+        self.window_length = window_length
+        self.num_hh_tokens = num_hh_tokens
+        self.accumulated_attention_scores: List[torch.Tensor] = []
+        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+
+    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        """
+        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
+        sequence length.
+        """
+        if layer_idx < len(self):
+            return (self.key_cache[layer_idx], self.value_cache[layer_idx], self.accumulated_attention_scores[layer_idx])
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+    def __iter__(self):
+        """
+        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
+        keys and values
+        """
+        for layer_idx in range(len(self)):
+            yield (self.key_cache[layer_idx], self.value_cache[layer_idx], self.accumulated_attention_scores[layer_idx])
+
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.key_cache)
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states."""
+        return self.window_length
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+        accumulated_attention_scores: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `HHCache`.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        # Update the number of seen tokens
+
+        if accumulated_attention_scores is not None:
+            self.accumulated_attention_scores.append(accumulated_attention_scores)
+
+        if layer_idx == 0:
+            self._seen_tokens += key_states.shape[-2]
+
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+        else:
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+    def update_slimming(
+        self,
+        attention_scores: torch.Tensor,
+        num_kv_groups: int,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Slims the cache based on accumulated attention scores, keeping only the heavy-hitter and local (most recent) tokens.
+
+        Parameters:
+            attention_scores (`torch.Tensor`):
+                Attention_scores for current steps.
+            num_kv_groups (`int`):
+                The number of kv groups in repeat kv.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `HHCache`.
+        Return:
+            A tuple containing the updated key and value states.
+        """
+
+        # Update score metrics (Accumulated attention scores)
+        if len(self.accumulated_attention_scores) <= layer_idx:
+            self.accumulated_attention_scores.append(attention_scores.sum(2)[:,::num_kv_groups, :]) # [bs, num_heads, key_len]
+        else:
+            num_new_tokens = attention_scores.shape[2]
+            updated_attention_scores = attention_scores.sum(2)[:,::num_kv_groups, :] # [bs, num_heads, key_len]
+            updated_attention_scores[:, :, :-num_new_tokens] += self.accumulated_attention_scores[layer_idx]
+            self.accumulated_attention_scores[layer_idx] = updated_attention_scores
+
+        # Update KV Cache
+        if self.get_seq_length(layer_idx) > self.window_length:
+
+            seq_scores = self.accumulated_attention_scores[layer_idx][:, :, :-self.window_length + self.num_hh_tokens]
+            _, keep_hh_index = torch.topk(seq_scores, self.num_hh_tokens, dim=-1)
+            keep_hh_index = keep_hh_index.sort().values
+
+            keep_local_index = torch.arange(self.get_seq_length(layer_idx) - self.window_length + self.num_hh_tokens, self.get_seq_length(layer_idx), device=keep_hh_index.device).repeat(keep_hh_index.shape[0], keep_hh_index.shape[1], 1)
+            keep_index = torch.cat([keep_hh_index, keep_local_index], dim=-1)
+
+            mask = torch.zeros(self.accumulated_attention_scores[layer_idx].shape, dtype=torch.bool).to(keep_hh_index.device)
+            mask = mask.scatter(-1, keep_index, 1)
+
+            bsz, num_heads, _, head_dim = self.key_cache[layer_idx].shape
+            self.key_cache[layer_idx] = self.key_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
+            self.value_cache[layer_idx] = self.value_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
+            self.accumulated_attention_scores[layer_idx] = self.accumulated_attention_scores[layer_idx][mask].view(bsz, num_heads, -1)
+
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        for layer_idx in range(len(self.key_cache)):
+            device = self.key_cache[layer_idx].device
+            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+            device = self.value_cache[layer_idx].device
+            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+        """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
+        legacy_cache = ()
+        for layer_idx in range(len(self)):
+            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx], self.accumulated_attention_scores[layer_idx],))
+        return legacy_cache
+
+    @classmethod
+    def from_legacy_cache(cls, window_length: int, num_hh_tokens: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "HHCache":
+        """Converts a cache in the legacy cache format into an equivalent `HHCache`."""
+        cache = cls(window_length, num_hh_tokens)
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values) // 3):
+                key_states = past_key_values[layer_idx * 3]
+                value_states = past_key_values[layer_idx * 3 + 1]
+                accumulated_attention_scores = past_key_values[layer_idx * 3 + 2]
+                cache.update(key_states, value_states, layer_idx, accumulated_attention_scores=accumulated_attention_scores)
+        return cache
+
+    def evict_for_space(self, space_needed: int):
+        num_layers = len(self.key_cache)
+
+        # Update score metrics (Accumulated attention scores)
+        if len(self.accumulated_attention_scores) < num_layers:
+            raise ValueError("The accumulated_attention_scores should be updated before evicting the cache.")
+
+        for layer_idx in range(num_layers):
+            # Update KV Cache, Evict for new coming prompts
+            if self.get_seq_length(layer_idx) + space_needed > self.window_length:
+                if self.window_length - self.num_hh_tokens <= space_needed:
+                    raise ValueError("The space_needed should be less than the window_length - num_hh_tokens.")
+
+                seq_scores = self.accumulated_attention_scores[layer_idx][:, :, :-self.window_length + self.num_hh_tokens + space_needed]
+                _, keep_hh_index = torch.topk(seq_scores, self.num_hh_tokens, dim=-1)
+                keep_hh_index = keep_hh_index.sort().values
+
+                keep_local_index = torch.arange(self.get_seq_length(layer_idx) - self.window_length + self.num_hh_tokens + space_needed, self.get_seq_length(layer_idx), device=keep_hh_index.device).repeat(keep_hh_index.shape[0], keep_hh_index.shape[1], 1)
+                keep_index = torch.cat([keep_hh_index, keep_local_index], dim=-1)
+
+                mask = torch.zeros(self.accumulated_attention_scores[layer_idx].shape, dtype=torch.bool).to(keep_hh_index.device)
+                mask = mask.scatter(-1, keep_index, 1)
+
+                bsz, num_heads, _, head_dim = self.key_cache[layer_idx].shape
+                self.key_cache[layer_idx] = self.key_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
+                self.value_cache[layer_idx] = self.value_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
+                self.accumulated_attention_scores[layer_idx] = self.accumulated_attention_scores[layer_idx][mask].view(bsz, num_heads, -1)
+
+
+
+
+class StaticCache(Cache):
+    """
+    Static Cache class to be used with `torch.compile(model)`.
+
+    Parameters:
+        config (`PretrainedConfig`):
+            The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
+            required to initialize the static cache.
+        max_batch_size (`int`):
+            The maximum batch size with which the model will be used.
+        max_cache_len (`int`):
+            The maximum sequence length with which the model will be used.
+        device (`torch.device`):
+            The device on which the cache should be initialized. Should be the same as the layer.
+        dtype (*optional*, defaults to `torch.float32`):
+            The default `dtype` to use when initializing the layer.
+    """
+
+    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
+        super().__init__()
+        self.max_batch_size = max_batch_size
+        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
+        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
+        self.head_dim = (
+            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+        )
+
+        self.dtype = dtype if dtype is not None else torch.float32
+        self.num_key_value_heads = (
+            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+        )
+
+        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
+        self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+        self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for. Kept for backward compatibility
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len`
+                to know how much of the cache it should overwrite.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        new_cache_positions = cache_kwargs.get("cache_position")
+        k_out = self.key_cache
+        v_out = self.value_cache
+
+        k_out[:, :, new_cache_positions] = key_states
+        v_out[:, :, new_cache_positions] = value_states
+
+        return k_out, v_out
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
+        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+        # limit the check to the first batch member and head dimension.
+        # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
+        # https://github.com/pytorch/pytorch/issues/120248 is fixed
+        return (self.key_cache[0, 0].any(dim=-1)).sum()
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
+        return self.max_cache_len
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        device = self.key_cache.device
+        self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
+        device = self.value_cache.device
+        self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))
+
+    def to_legacy_cache(self):
+        """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
+        return None
+
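To make the eviction behaviour of `HHCache` concrete, here is a small standalone sketch (not part of the PR) that drives the cache directly with random tensors. The shapes and the `window_length`/`num_hh_tokens` values are illustrative only; in the recipe these come from the model config.

```python
import torch
from utils.cache import HHCache

bsz, num_heads, seq_len, head_dim = 1, 2, 12, 4
cache = HHCache(window_length=8, num_hh_tokens=2)

# Prefill: append key/value states for layer 0.
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)
cache.update(k, v, layer_idx=0)
print(cache.get_seq_length(0))  # 12: nothing evicted yet

# Accumulate attention scores and slim: since 12 > window_length, only the
# num_hh_tokens heavy hitters plus the (window_length - num_hh_tokens) most
# recent tokens are kept per head.
attn = torch.softmax(torch.randn(bsz, num_heads, seq_len, seq_len), dim=-1)
cache.update_slimming(attn, num_kv_groups=1, layer_idx=0)
print(cache.get_seq_length(0))  # 8 == window_length
```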

+ 453 - 0
recipes/experimental/long-context/H2O/utils/llama.py

@@ -0,0 +1,453 @@
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+import warnings
+warnings.filterwarnings("ignore")
+
+import pdb
+import types
+import torch
+from torch import nn
+import torch.utils.checkpoint
+import torch.nn.functional as F
+
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    rotate_half,
+    apply_rotary_pos_emb,
+    repeat_kv,
+    LlamaRotaryEmbedding,
+    LlamaLinearScalingRotaryEmbedding,
+    LlamaDynamicNTKScalingRotaryEmbedding,
+    LlamaForCausalLM,
+)
+from utils.cache import Cache, HHCache, StaticCache
+from transformers.utils import logging
+from transformers.modeling_outputs import BaseModelOutputWithPast
+
+logger = logging.get_logger(__name__)
+
+__all__ = ["H2OLlamaForCausalLM"]
+
+def _make_causal_mask(
+    bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device):
+    """
+    Make causal mask used for bi-directional self-attention.
+    """
+    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+def apply_rotary_pos_emb_single(x, cos, sin, position_ids=None, unsqueeze_dim=1):
+
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    x_embed = (x * cos) + (rotate_half(x) * sin)
+
+    return x_embed
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class H2OLlamaAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        self.attention_dropout = config.attention_dropout
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.is_causal = True
+        self.positional_rolling = config.enable_position_rolling
+
+        if (self.head_dim * self.num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+        self._init_rope()
+
+    def _init_rope(self):
+        if self.config.rope_scaling is None:
+            self.rotary_emb = LlamaRotaryEmbedding(
+                self.head_dim,
+                max_position_embeddings=self.max_position_embeddings,
+                base=self.rope_theta,
+            )
+        else:
+            scaling_type = self.config.rope_scaling["type"]
+            scaling_factor = self.config.rope_scaling["factor"]
+            if scaling_type == "linear":
+                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                )
+            elif scaling_type == "dynamic":
+                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                )
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.config.pretraining_tp > 1:
+            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+            query_slices = self.q_proj.weight.split(
+                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+            )
+            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+            query_states = torch.cat(query_states, dim=-1)
+
+            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+            key_states = torch.cat(key_states, dim=-1)
+
+            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+            value_states = torch.cat(value_states, dim=-1)
+
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        past_key_value = getattr(self, "past_key_value", past_key_value)
+
+        if not self.positional_rolling:
+            cos, sin = self.rotary_emb(value_states, position_ids)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+            if past_key_value is not None:
+                # sin and cos are specific to RoPE models; cache_position needed for the static cache
+                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+        else:
+            if past_key_value is not None:
+                # sin and cos are specific to RoPE models; cache_position needed for the static cache
+                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
+
+            kv_seq_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else key_states.shape[-2]
+
+            if not position_ids.nelement() > 1:
+                # decoding stage
+                key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device).unsqueeze(0)
+                query_position_ids = key_position_ids[:, -1].unsqueeze(0)
+            elif not kv_seq_len == position_ids.shape[-1]:
+                # prefilling stage with evicting
+                query_position_ids = position_ids
+                key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device).unsqueeze(0)
+            else:
+                # prefilling stage
+                query_position_ids = position_ids
+                key_position_ids = position_ids
+
+            key_cos, key_sin = self.rotary_emb(value_states, key_position_ids)
+            query_cos, query_sin = self.rotary_emb(value_states, query_position_ids)
+
+            query_states = apply_rotary_pos_emb_single(query_states, query_cos, query_sin)
+            key_states = apply_rotary_pos_emb_single(key_states, key_cos, key_sin)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+
+        # Update KV Cache based on Heavy-Hitter Oracle
+        if past_key_value is not None:
+            past_key_value.update_slimming(attn_weights, self.num_key_value_groups, self.layer_idx)
+
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        if self.config.pretraining_tp > 1:
+            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+        else:
+            attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+        
+        return attn_output, attn_weights, past_key_value
+
+
+def enable_h2ocache_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError(
+            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+        )
+
+    if self.gradient_checkpointing and self.training and use_cache:
+        logger.warning_once(
+            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+        )
+        use_cache = False
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    past_seen_tokens = 0
+    if use_cache:  # kept for BC (cache positions)
+        if not isinstance(past_key_values, StaticCache):
+            past_key_values = HHCache.from_legacy_cache(self.num_window_length, self.num_heavy_hitter_tokens, past_key_values)
+            past_seen_tokens = past_key_values.get_seq_length()
+
+    if cache_position is None:
+        if isinstance(past_key_values, StaticCache):
+            raise ValueError("cache_position is a required argument when using StaticCache.")
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
+
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+
+    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+
+    # embed positions
+    hidden_states = inputs_embeds
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = None
+
+    for decoder_layer in self.layers:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if self.gradient_checkpointing and self.training:
+            layer_outputs = self._gradient_checkpointing_func(
+                decoder_layer.__call__,
+                hidden_states,
+                causal_mask,
+                position_ids,
+                past_key_values,
+                output_attentions,
+                use_cache,
+                cache_position,
+            )
+        else:
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = None
+    if use_cache:
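+        # Convert the HHCache back to the legacy tuple format expected by the caller.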
+        next_cache = (
+            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
+        )
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+    
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+
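+# Drop-in replacement for LlamaForCausalLM: each decoder layer's self-attention is swapped for
+# H2OLlamaAttention and the model's forward is patched so generation runs with an HHCache.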
+class H2OLlamaForCausalLM(LlamaForCausalLM):
+    def __init__(self, config):
+        super().__init__(config)
+        num_layers = len(self.model.layers)
+        for layer_idx in range(num_layers):
+            self.model.layers[layer_idx].self_attn = H2OLlamaAttention(config, layer_idx)
+
+        self.model.forward = types.MethodType(enable_h2ocache_forward, self.model)
+        self.model.num_heavy_hitter_tokens = config.num_heavy_hitter_tokens
+        self.model.num_window_length = config.num_window_length
+    
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
+    ):
+        # With a static cache, `past_key_values` is None
+        # TODO joao: standardize interface for the different Cache classes and remove this if
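+        # The rest largely follows transformers' LlamaForCausalLM.prepare_inputs_for_generation,
+        # plus a branch for the flat legacy tuples produced by HHCache.to_legacy_cache().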
+
+        has_static_cache = False
+        if past_key_values is None:
+            past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
+            has_static_cache = past_key_values is not None
+
+        past_length = 0
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                past_length = cache_position[0]
+                max_cache_length = (
+                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+                    if past_key_values.get_max_length() is not None
+                    else None
+                )
+                cache_length = past_key_values.get_seq_length()
+
+            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
+            else:
+                past_length = cache_position[0]
+                cache_length = past_key_values[0].shape[2]  # the legacy HHCache is a flat tuple of num_layers * 3 tensors (key, value, score); dim 2 of the first key tensor is the cached sequence length
+                max_cache_length = None
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+            # input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        if cache_position is None:
+            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+        else:
+            cache_position = cache_position[-input_length:]
+
+        if has_static_cache:
+            past_key_values = None
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs

+ 123 - 0
recipes/experimental/long-context/H2O/utils/streaming.py

@@ -0,0 +1,123 @@
+"""
+    Source Code: https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/utils.py
+"""
+
+import torch
+import argparse
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+)
+import os.path as osp
+import ssl
+import urllib.request
+import os
+import json
+
+
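+# Loads a causal LM in fp16 (sharded across available devices via device_map="auto") together
+# with its tokenizer, falling back to the eos token (or 0) when no pad token id is set.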
+def load(model_name_or_path):
+    print(f"Loading model from {model_name_or_path} ...")
+    # note: tensor parallelism can cause bugs when running Falcon models
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name_or_path,
+        trust_remote_code=True,
+    )
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name_or_path,
+        device_map="auto",
+        torch_dtype=torch.float16,
+        trust_remote_code=True,
+    )
+    if tokenizer.pad_token_id is None:
+        if tokenizer.eos_token_id is not None:
+            tokenizer.pad_token_id = tokenizer.eos_token_id
+        else:
+            tokenizer.pad_token_id = 0
+
+    model.eval()
+
+    return model, tokenizer
+
+
+def download_url(url: str, folder="folder"):
+    """
+    Downloads the content of a URL to a folder. Modified from \
+    https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric
+
+    Args:
+        url (string): The URL of the target file.
+        folder (string): The target folder.
+
+    Returns:
+        string: File path of the downloaded file.
+    """
+
+    file = url.rpartition("/")[2]
+    file = file if file[0] == "?" else file.split("?")[0]
+    path = osp.join(folder, file)
+    if osp.exists(path):
+        print(f"File {file} exists, use existing file.")
+        return path
+
+    print(f"Downloading {url}")
+    os.makedirs(folder, exist_ok=True)
+    ctx = ssl._create_unverified_context()
+    data = urllib.request.urlopen(url, context=ctx)
+    with open(path, "wb") as f:
+        f.write(data.read())
+
+    return path
+
+
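+# Reads a JSON Lines file: one JSON object per line.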
+def load_jsonl(
+    file_path,
+):
+    list_data_dict = []
+    with open(file_path, "r") as f:
+        for line in f:
+            list_data_dict.append(json.loads(line))
+    return list_data_dict
+
+
+@torch.no_grad()
+def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
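+    """
+    Greedy decoding that streams the output to stdout word by word. Assumes a batch size of 1.
+    Returns the updated past_key_values so the caller can reuse the cache across prompts.
+    """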
+    outputs = model(
+        input_ids=input_ids,
+        past_key_values=past_key_values,
+        use_cache=True,
+    )
+    past_key_values = outputs.past_key_values
+    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
+    generated_ids = [pred_token_idx.item()]
+    pos = 0
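+    # `pos` tracks how many whitespace-separated words have been printed so far, so each step
+    # only prints newly completed words.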
+    for _ in range(max_gen_len - 1):
+        outputs = model(
+            input_ids=pred_token_idx,
+            past_key_values=past_key_values,
+            use_cache=True,
+        )
+        past_key_values = outputs.past_key_values
+        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
+        generated_ids.append(pred_token_idx.item())
+        generated_text = (
+            tokenizer.decode(
+                generated_ids,
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=True,
+                spaces_between_special_tokens=False,
+            )
+            .strip()
+            .split(" ")
+        )
+
+        now = len(generated_text) - 1
+        if now > pos:
+            print(" ".join(generated_text[pos:now]), end=" ", flush=True)
+            pos = now
+
+        if pred_token_idx == tokenizer.eos_token_id:
+            break
+    print(" ".join(generated_text[pos:]), flush=True)
+    return past_key_values
+