
Fix quantization for inference

Matthias Reso 9 months ago
parent
commit
0920b1a415

+ 3 - 4
recipes/quickstart/inference/local_inference/README.md

@@ -27,8 +27,8 @@ samsum_prompt.txt
 ...
 ```
 
-**Note**
-Currently pad token by default in [HuggingFace Tokenizer is `None`](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py#L110). We add the padding token as a special token to the tokenizer, which in this case requires to resize the token_embeddings as shown below:
+**Note on Llama version < 3.1**
+By default, the padding token in the [HuggingFace Tokenizer](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py#L110) is `None`. To use padding, the padding token needs to be added as a special token to the tokenizer, which in this case requires resizing the token_embeddings as shown below:
 
 ```python
 tokenizer.add_special_tokens(
@@ -39,8 +39,7 @@ tokenizer.add_special_tokens(
     )
 model.resize_token_embeddings(model.config.vocab_size + 1)
 ```
-Padding would be required for batch inference. In this this [example](inference.py), batch size = 1 so essentially padding is not required. However,We added the code pointer as an example in case of batch inference.
-
+Padding would be required for batched inference. In this [example](inference.py), batch size = 1, so padding is essentially not required; we added the code pointer as an example in case of batched inference. For Llama version 3.1, use the reserved special token `<|finetune_right_pad_id|>` (token id 128004) for padding.
 
 ## Chat completion
 The inference folder also includes a chat completion example, that adds built-in safety features in fine-tuned models to the prompt tokens. To run the example:
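A minimal sketch of the Llama 3.1 padding path from the note above, assuming the tokenizer already ships the reserved `<|finetune_right_pad_id|>` token so no embedding resize is needed (the model id is a placeholder for illustration):

```python
# Sketch, not part of this commit: Llama 3.1 tokenizers already contain a reserved
# padding token, so it can be assigned directly instead of adding a new special
# token and resizing the embeddings.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # placeholder model id
tokenizer.pad_token = "<|finetune_right_pad_id|>"
print(tokenizer.pad_token_id)  # expected to resolve to 128004 per the note above
```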

+ 3 - 2
recipes/quickstart/inference/local_inference/chat_completion/chat_completion.py

@@ -19,7 +19,7 @@ from accelerate.utils import is_xpu_available
 def main(
     model_name,
     peft_model: str=None,
-    quantization: bool=False,
+    quantization: str = None, # Options: 4bit, 8bit
     max_new_tokens =256, #The maximum numbers of tokens to generate
     min_new_tokens:int=0, #The minimum numbers of tokens to generate
     prompt_file: str=None,
@@ -66,7 +66,8 @@ def main(
     else:
         torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
-    model = load_model(model_name, quantization, use_fast_kernels)
+
+    model = load_model(model_name, quantization, use_fast_kernels, **kwargs)
     if peft_model:
         model = load_peft_model(model, peft_model)
 

+ 2 - 2
recipes/quickstart/inference/local_inference/inference.py

@@ -20,7 +20,7 @@ from transformers import AutoTokenizer
 def main(
     model_name,
     peft_model: str = None,
-    quantization: bool = False,
+    quantization: str = None, # Options: 4bit, 8bit
     max_new_tokens=100,  # The maximum numbers of tokens to generate
     prompt_file: str = None,
     seed: int = 42,  # seed value for reproducibility
@@ -48,7 +48,7 @@ def main(
         torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
 
-    model = load_model(model_name, quantization, use_fast_kernels)
+    model = load_model(model_name, quantization, use_fast_kernels, **kwargs)
     if peft_model:
         model = load_peft_model(model, peft_model)
 
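Since `quantization` is now a string in both scripts, the same call pattern applies to `inference.py` and `chat_completion.py`. Below is a hedged sketch of invoking the updated `main` directly with the new argument; the module import and model id are illustrative, and extra keyword arguments flow into `load_model` via `**kwargs`:

```python
# Sketch: calling the updated entry point with the string-valued quantization flag.
# The old boolean form still works but now emits a FutureWarning and falls back to "8bit".
from inference import main  # local_inference/inference.py from the diff above

main(
    model_name="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
    quantization="8bit",                            # or "4bit"; replaces the old bool flag
    prompt_file="samsum_prompt.txt",
)
```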

+ 15 - 2
src/llama_recipes/inference/model_utils.py

@@ -1,16 +1,29 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the GNU General Public License version 3.
 
+from llama_recipes.utils.config_utils import update_config
+from llama_recipes.configs import quantization_config as QUANT_CONFIG
 from peft import PeftModel
 from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaConfig
+from warnings import warn
 
 # Function to load the main model for text generation
-def load_model(model_name, quantization, use_fast_kernels):
+def load_model(model_name, quantization, use_fast_kernels, **kwargs):
+    if isinstance(quantization, bool):
+        warn("Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.", FutureWarning)
+        quantization = "8bit"
+
+    bnb_config = None
+    if quantization:
+        quant_config = QUANT_CONFIG()
+        update_config(quant_config, **kwargs)
+        bnb_config = quant_config.create_bnb_config(quantization)
+
     print(f"use_fast_kernels{use_fast_kernels}")
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         return_dict=True,
-        load_in_8bit=quantization,
+        quantization_config=bnb_config,
         device_map="auto",
         low_cpu_mem_usage=True,
         attn_implementation="sdpa" if use_fast_kernels else None,
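For context, `quant_config.create_bnb_config(quantization)` is expected to return a `transformers.BitsAndBytesConfig` that replaces the old `load_in_8bit` flag. A rough hand-rolled equivalent follows; the 4-bit settings are illustrative choices, not necessarily the defaults of `llama_recipes.configs.quantization_config`:

```python
# Illustrative sketch of what create_bnb_config("8bit") / create_bnb_config("4bit")
# might produce; only the two modes accepted by load_model are handled.
import torch
from transformers import BitsAndBytesConfig

def make_bnb_config(quantization: str) -> BitsAndBytesConfig:
    if quantization == "8bit":
        return BitsAndBytesConfig(load_in_8bit=True)
    if quantization == "4bit":
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",              # illustrative choice
            bnb_4bit_compute_dtype=torch.bfloat16,  # illustrative choice
        )
    raise ValueError(f"Unsupported quantization mode: {quantization!r}")
```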