|
@@ -205,6 +205,7 @@ class H2OLlamaAttention(nn.Module):
|
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
output_attentions: bool = False,
|
|
output_attentions: bool = False,
|
|
use_cache: bool = False,
|
|
use_cache: bool = False,
|
|
|
|
+ cache_position: Optional[torch.LongTensor] = None,
|
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
|
|
|
|
bsz, q_len, _ = hidden_states.size()
|
|
bsz, q_len, _ = hidden_states.size()
|