Allen 1 rok temu
rodzic
commit
d3932c8185

+ 0 - 33
research/long-context-llama/H2O/cache_utils.py

@@ -408,32 +408,6 @@ class HHCache(Cache):
         """Returns the maximum sequence length of the cached states."""
         return self.window_length
 
-
-    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
-        """
-        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
-        sequence length.
-        """
-        if layer_idx < len(self):
-            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
-        else:
-            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
-
-    def __iter__(self):
-        """
-        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
-        keys and values
-        """
-        for layer_idx in range(len(self)):
-            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
-
-    def __len__(self):
-        """
-        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
-        to the number of layers in the model.
-        """
-        return len(self.key_cache)
-
     def update(
         self,
         key_states: torch.Tensor,
@@ -517,12 +491,6 @@ class HHCache(Cache):
             device = self.value_cache[layer_idx].device
             self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
 
-
-
-
-
-
-
     def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
         """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
         legacy_cache = ()
@@ -533,7 +501,6 @@ class HHCache(Cache):
     @classmethod
     def from_legacy_cache(cls, window_length: int, num_hh_tokens: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
         """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
-        import pdb; pdb.set_trace()
         cache = cls(window_length, num_hh_tokens)
         if past_key_values is not None:
             for layer_idx in range(len(past_key_values)):

+ 6 - 5
research/long-context-llama/H2O/utils_llama.py

@@ -257,11 +257,6 @@ class H2OLlamaAttention(nn.Module):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-        if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
@@ -273,6 +268,12 @@ class H2OLlamaAttention(nn.Module):
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, attn_weights, self.layer_idx, cache_kwargs)
+
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)