|
@@ -502,6 +502,9 @@ class HHCache(Cache):
|
|
|
if len(self.accumulated_attention_scores) <= layer_idx:
|
|
|
self.accumulated_attention_scores.append(attention_scores.sum(2)[:,::num_kv_groups, :]) # [bs, num_heads, key_len]
|
|
|
else:
|
|
|
+ if layer_idx == 0:
|
|
|
+ import pdb; pdb.set_trace()
|
|
|
+
|
|
|
num_new_tokens = attention_scores.shape[2]
|
|
|
updated_attention_scores = attention_scores.sum(2)[:,::num_kv_groups, :] # [bs, num_heads, key_len]
|
|
|
updated_attention_scores[:, :, :-num_new_tokens] += self.accumulated_attention_scores[layer_idx]
|