@@ -102,6 +102,9 @@ if __name__ == '__main__':
             return_dict_in_generate=True, output_scores=True,
         )
 
+        if enable_h2o_generation:
+            self._clean_cache()
+
         tokens = tokenizer.convert_ids_to_tokens(output_sequences['sequences'].squeeze(0))[len(input_ids[0]):]
         logprobs = [logits.log_softmax(dim=-1).max().item() for logits in output_sequences['scores']]
         top_logprobs = [{i: v for i, v in zip(tokens, logprobs)}]