Allen, 1 year ago
Commit 77427ea66c
2 files changed, 3 insertions and 2 deletions:
  1. research/long-context-llama/H2O/generation.py (+3, -0)
  2. research/long-context-llama/H2O/utils_llama.py (+0, -2)

research/long-context-llama/H2O/generation.py (+3, -0)

@@ -102,6 +102,9 @@ if __name__ == '__main__':
                 return_dict_in_generate=True, output_scores=True,
             )
 
+            if enable_h2o_generation:
+                self._clean_cache()
+
             tokens = tokenizer.convert_ids_to_tokens(output_sequences['sequences'].squeeze(0))[len(input_ids[0]):]
             logprobs = [logits.log_softmax(dim=-1).max().item() for logits in output_sequences['scores']]
             top_logprobs = [{i: v for i, v in zip(tokens, logprobs)}]

research/long-context-llama/H2O/utils_llama.py (+0, -2)

@@ -57,8 +57,6 @@ class H2OKVCache_LayerWise:
         self.cache_size = hh_size + recent_size
         self.k_seq_dim = k_seq_dim
         self.v_seq_dim = v_seq_dim
-        self.k_slice = DIM_TO_SLICE[k_seq_dim]
-        self.v_slice = DIM_TO_SLICE[v_seq_dim]
         self.hh_score = None
 
     def __call__(self, past_key_values, attn_score_cache):
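
The removed `self.k_slice` / `self.v_slice` attributes were precomputed slicing helpers looked up from `DIM_TO_SLICE` by sequence dimension; after this change the constructor keeps only the size/dimension bookkeeping and the `hh_score` accumulator shown in the remaining lines. An illustrative way to slice along an arbitrary sequence dimension without such a lookup table, assuming PyTorch tensors and the same `k_seq_dim` / `v_seq_dim` convention (the helper name `slice_recent` is hypothetical):

import torch

def slice_recent(tensor: torch.Tensor, seq_dim: int, recent_size: int) -> torch.Tensor:
    # Keep only the last `recent_size` positions along `seq_dim`.
    seq_len = tensor.size(seq_dim)
    idx = torch.arange(max(seq_len - recent_size, 0), seq_len, device=tensor.device)
    return tensor.index_select(seq_dim, idx)

# Example: keys shaped (batch, heads, seq_len, head_dim) with k_seq_dim = 2.
keys = torch.randn(1, 8, 128, 64)
recent_keys = slice_recent(keys, seq_dim=2, recent_size=32)  # -> (1, 8, 32, 64)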