Browse Source

Update cache_utils.py

Allen 1 year ago
parent
commit
769e93e770
1 changed file with 3 additions and 0 deletions
  1. 3 0
      research/long-context-llama/H2O/cache_utils.py

+ 3 - 0
research/long-context-llama/H2O/cache_utils.py

@@ -502,6 +502,9 @@ class HHCache(Cache):
         if len(self.accumulated_attention_scores) <= layer_idx:
         if len(self.accumulated_attention_scores) <= layer_idx:
             self.accumulated_attention_scores.append(attention_scores.sum(2)[:,::num_kv_groups, :]) # [bs, num_heads, key_len]
             self.accumulated_attention_scores.append(attention_scores.sum(2)[:,::num_kv_groups, :]) # [bs, num_heads, key_len]
         else:
         else:
+            if layer_idx == 0:
+                import pdb; pdb.set_trace()
+
             num_new_tokens = attention_scores.shape[2]
             num_new_tokens = attention_scores.shape[2]
             updated_attention_scores = attention_scores.sum(2)[:,::num_kv_groups, :] # [bs, num_heads, key_len]
             updated_attention_scores = attention_scores.sum(2)[:,::num_kv_groups, :] # [bs, num_heads, key_len]
             updated_attention_scores[:, :, :-num_new_tokens] += self.accumulated_attention_scores[layer_idx]
             updated_attention_scores[:, :, :-num_new_tokens] += self.accumulated_attention_scores[layer_idx]