浏览代码

Update utils_llama.py

Allen 1 年之前
父节点
当前提交
036620e6d7
共有 1 个文件被更改,包括 2 次插入1 次删除
  1. 2 1
      research/long-context-llama/H2O/utils_llama.py

+ 2 - 1
research/long-context-llama/H2O/utils_llama.py

@@ -187,7 +187,8 @@ class H2OLlamaAttention(nn.Module):
                 query_position_ids = position_ids
                 key_position_ids = position_ids
 
-            cos, sin = self.rotary_emb(value_states, key_position_ids)
+            key_cos, key_sin = self.rotary_emb(value_states, key_position_ids)
+            query_cos, query_sin = self.rotary_emb(value_states, query_position_ids)
 
             if self.layer_idx == 0:
                 print(kv_seq_len, query_position_ids, key_position_ids)