Prechádzať zdrojové kódy

Update cache_utils.py

Allen pred 1 rokom
rodič
commit
f5404d89c8
1 zmenený súbor: 11 pridaní a 10 odobraní
  1. 11 10
      research/long-context-llama/H2O/cache_utils.py

+ 11 - 10
research/long-context-llama/H2O/cache_utils.py

@@ -432,22 +432,23 @@ class HHCache(Cache):
         Return:
             A tuple containing the updated key and value states.
         """
-        import pdb; pdb.set_trace()
         # Update the number of seen tokens
         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]
 
+        import pdb; pdb.set_trace()
+
         # Update the cache
         # [bsz, num_heads, seq_len, head_dim]
-        if len(self.key_cache) <= layer_idx:
-            # Empty cache
-            self.key_cache.append(key_states)
-            self.value_cache.append(value_states)
-
-        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
-            # Growing cache
-            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+        if key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
+            if len(self.key_cache) <= layer_idx:
+                # Empty cache
+                self.key_cache.append(key_states)
+                self.value_cache.append(value_states)
+            else:
+                # Growing cache
+                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
 
         else:
             # Shifting cache