|
@@ -523,8 +523,10 @@ class HHCache(Cache):
|
|
|
self.key_cache[layer_idx] = self.key_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
|
|
|
self.value_cache[layer_idx] = self.value_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
|
|
|
self.accumulated_attention_scores[layer_idx] = self.accumulated_attention_scores[layer_idx][mask].view(bsz, num_heads, -1)
|
|
|
+
|
|
|
+ if layer_idx == 0:
|
|
|
+ import pdb; pdb.set_trace()
|
|
|
|
|
|
- pdb.set_trace()
|
|
|
|
|
|
def reorder_cache(self, beam_idx: torch.LongTensor):
|
|
|
"""Reorders the cache for beam search, given the selected beam indices."""
|