
Merge pull request #2 from meta-llama/inference_changes

Inference/Finetuning changes
Suraj Subramanian 8 months ago
Parent
Commit
3516a5cea7

+ 8 - 1
.github/scripts/spellcheck_conf/wordlist.txt

@@ -1406,5 +1406,12 @@ DLAI
 agentic
 containts
 dlai
+Prerequirements
+tp
+QLoRA
+ntasks
+srun
+xH
+unquantized
 eom
-ipython
+ipython

+ 1 - 3
docs/LLM_finetuning.md

@@ -1,6 +1,6 @@
 ## LLM Fine-Tuning
 
-Here we discuss fine-tuning Meta Llama 3 with a couple of different recipes. We will cover two scenarios here:
+Here we discuss fine-tuning Meta Llama with a couple of different recipes. We will cover two scenarios here:
 
 
 ## 1. **Parameter Efficient Model Fine-Tuning**
@@ -18,8 +18,6 @@ These methods will address three aspects:
 
 HF [PEFT](https://github.com/huggingface/peft) library provides an easy way of using these methods which we make use of here. Please read more [here](https://huggingface.co/blog/peft).
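+
+As a quick illustration, here is a minimal sketch of applying LoRA through the PEFT library (the model name and hyperparameters below are illustrative assumptions, not values prescribed by this recipe):
+
+```python
+from peft import LoraConfig, get_peft_model
+from transformers import AutoModelForCausalLM
+
+# Illustrative base model and LoRA hyperparameters; adjust to your setup
+model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
+lora_config = LoraConfig(
+    r=8,
+    lora_alpha=32,
+    target_modules=["q_proj", "v_proj"],
+    lora_dropout=0.05,
+    task_type="CAUSAL_LM",
+)
+model = get_peft_model(model, lora_config)
+model.print_trainable_parameters()  # only the adapter weights are trainable
+```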
 
-
-
 ## 2. **Full/ Partial Parameter Fine-Tuning**
 
 Full parameter fine-tuning has its own advantages, in this method there are multiple strategies that can help:

+ 3 - 4
docs/multi_gpu.md

@@ -6,13 +6,12 @@ To run fine-tuning on multi-GPUs, we will  make use of two packages:
 
 2. [FSDP](https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html) which helps us parallelize the training over multiple GPUs. [More details](LLM_finetuning.md/#2-full-partial-parameter-finetuning).
 
-Given the combination of PEFT and FSDP, we would be able to fine tune a Meta Llama 3 8B model on multiple GPUs in one node or multi-node.
+Combining PEFT and FSDP allows us to fine-tune a Meta Llama 8B model on multiple GPUs within a single node.
+For large models like the 405B, we need a multi-node setup for fine-tuning even when 4-bit quantization is enabled.
 
 ## Requirements
 To run the examples, make sure to install the llama-recipes package and clone the github repository in order to use the provided [`finetuning.py`](../recipes/quickstart/finetuning/finetuning.py) script with torchrun (See [README.md](../README.md) for details).
 
-**Please note that the llama_recipes package will install PyTorch 2.0.1 version, in case you want to run FSDP + PEFT, please make sure to install PyTorch nightlies.**
-
 ## How to run it
 
 Get access to a machine with multiple GPUs ( in this case we tested with 4 A100 and A10s).
@@ -61,7 +60,7 @@ torchrun --nnodes 1 --nproc_per_node 8  recipes/quickstart/finetuning/finetuning
 This has been tested on 4 H100s GPUs.
 
 ```bash
- FSDP_CPU_RAM_EFFICIENT_LOADING=1 ACCELERATE_USE_FSDP=1 torchrun --nnodes 1 --nproc_per_node 4  finetuning.py --enable_fsdp  --quantization int4 --model_name /path_of_model_folder/70B  --mixed_precision False --low_cpu_fsdp --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model
+ FSDP_CPU_RAM_EFFICIENT_LOADING=1 ACCELERATE_USE_FSDP=1 torchrun --nnodes 1 --nproc_per_node 4  finetuning.py --enable_fsdp  --quantization 4bit --model_name /path_of_model_folder/70B  --mixed_precision False --low_cpu_fsdp --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model
 ```
 
 ### Fine-tuning using FSDP on 70B Model

+ 75 - 0
recipes/3p_integrations/vllm/README.md

@@ -0,0 +1,75 @@
+# Llama inference with vLLM
+
+This folder contains an example for running Llama inference on multiple GPUs in single-node as well as multi-node scenarios using vLLM.
+
+## Prerequirements
+
+To run this example we need to install vLLM, as well as ray if multi-node inference is the goal.
+
+```bash
+pip install vllm
+
+# For multi-node inference we also need to install ray
+pip install ray[default]
+```
+
+For the following examples we will assume that we fine-tuned a base model using the LoRA method and have set up the following environment variables pointing to the base model as well as the LoRA adapter:
+
+```bash
+export MODEL_PATH=/path/to/out/base/model
+export PEFT_MODEL_PATH=/path/to/out/peft/model
+```
+
+## Single-node multi-gpu inference
+To launch the inference, simply execute the following command, changing the tp_size parameter to the number of GPUs you have available:
+
+``` bash
+python inference.py --model_name $MODEL_PATH --peft_model_name $PEFT_MODEL_PATH --tp_size 8 --user_prompt "Hello my name is"
+```
+After completing a generation, the script will ask for another prompt in a loop; you can exit by pressing enter and leaving the prompt empty.
+When using multiple GPUs, the model will automatically be split across the available GPUs using tensor parallelism.
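+
+If you are unsure how many GPUs are available on the node, a quick check (a sketch, assuming an NVIDIA setup with nvidia-smi installed) is:
+
+```bash
+# Count the GPUs visible on this node and use the result for --tp_size
+nvidia-smi -L | wc -l
+```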
+
+## Multi-node multi-gpu inference
+The FP8 quantized variants of Meta Llama (i.e. meta-llama/Meta-Llama-3.1-405B-FP8 and meta-llama/Meta-Llama-3.1-405B-Instruct-FP8) can be executed on a single node with 8x80GB H100 using the script located in this folder.
+To run the unquantized Meta Llama 405B variants (i.e. meta-llama/Meta-Llama-3.1-405B and meta-llama/Meta-Llama-3.1-405B-Instruct) we need multi-node inference.
+vLLM allows this by leveraging pipeline parallelism across nodes while still applying tensor parallelism inside each node.
+To start a multi-node inference we first need to set up a ray cluster, which will be leveraged by vLLM to execute the model across node boundaries.
+
+```bash
+# On the head node we start the cluster as follows
+ray start --head
+
+# After the server starts it prints out a couple of lines including the command to add nodes to the cluster e.g.:
+# To add another node to this Ray cluster, run
+#   ray start --address='<head-node-ip-address>:6379'
+# Where the head node ip address will depend on your environment
+
+# We can then add the worker nodes by executing the command in a shell on the worker node
+ray start --address='<head-node-ip-address>:6379'
+
+# We can check if the cluster was launched successfully by executing this on any node
+ray status
+
+# It should show the number of nodes we have added as well as the head node
+# Node status
+# ---------------------------------------------------------------
+# Active:
+#  1 node_82143b740a25228c24dc8bb3a280b328910b2fcb1987eee52efb838b
+#  1 node_3f2c673530de5de86f953771538f35437ab60e3cacd7730dbca41719
+```
+
+To launch the inference we can then execute the inference script, adapting pp_size and tp_size to our environment:
+
+- pp_size - number of nodes (head node plus worker nodes)
+- tp_size - number of GPUs per node
+
+If our environment consists of two nodes with 8 GPUs each we would execute:
+```bash
+python inference.py --model_name $MODEL_PATH --peft_model_name $PEFT_MODEL_PATH --pp_size 2 --tp_size 8 --user_prompt "Hello my name is"
+```
+
+The launch of the vLLM engine will take some time depending on your environment, as each worker needs to load the checkpoint files to extract its fraction of the weights. The process may appear to hang during this phase, so be patient and let the workers finish loading.

+ 35 - 12
recipes/3p_integrations/vllm/inference.py

@@ -1,11 +1,13 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import uuid
+import asyncio
 import fire
 
 import torch
-from vllm import LLM
-from vllm import LLM, SamplingParams
+from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
+from vllm.lora.request import LoRARequest
 from accelerate.utils import is_xpu_available
 
 if is_xpu_available():
@@ -15,13 +17,24 @@ else:
 
 torch.manual_seed(42)
 
-def load_model(model_name, tp_size=1):
+def load_model(model_name, peft_model=None, pp_size=1, tp_size=1):
+    additional_configs = {}
+    if peft_model:
+        additional_configs["enable_lora"] = True
+        
+    engine_config = AsyncEngineArgs(
+        model=model_name,
+        pipeline_parallel_size=pp_size,
+        tensor_parallel_size=tp_size,
+        max_loras=1,
+        **additional_configs)
 
-    llm = LLM(model_name, tensor_parallel_size=tp_size)
+    llm = AsyncLLMEngine.from_engine_args(engine_config)
     return llm
 
-def main(
+async def main(
     model,
+    peft_model_name=None,
     max_new_tokens=100,
     user_prompt=None,
     top_p=0.9,
@@ -35,26 +48,36 @@ def main(
 
         print(f"sampling params: top_p {top_p} and temperature {temperature} for this inference request")
         sampling_param = SamplingParams(top_p=top_p, temperature=temperature, max_tokens=max_new_tokens)
-        
 
-        outputs = model.generate(user_prompt, sampling_params=sampling_param)
+        lora_request = None
+        if peft_model_name:
+            lora_request = LoRARequest("lora", 1, peft_model_name)  # lora_int_id must be > 0
+
+        req_id = str(uuid.uuid4())
+
+        generator = model.generate(user_prompt, sampling_param, req_id, lora_request=lora_request)
+        output = None
+        async for request_output in generator:
+            output = request_output
    
-        print(f"model output:\n {user_prompt} {outputs[0].outputs[0].text}")
+        print(f"model output:\n {user_prompt} {output.outputs[0].text}")
         user_prompt = input("Enter next prompt (press Enter to exit): ")
         if not user_prompt:
             break
 
 def run_script(
     model_name: str,
-    peft_model=None,
-    tp_size=1,
+    peft_model_name=None,
+    pp_size : int = 1,
+    tp_size : int = 1,
     max_new_tokens=100,
     user_prompt=None,
     top_p=0.9,
     temperature=0.8
 ):
-    model = load_model(model_name, tp_size)
-    main(model, max_new_tokens, user_prompt, top_p, temperature)
+    model = load_model(model_name, peft_model_name, pp_size, tp_size)
+
+    asyncio.get_event_loop().run_until_complete(main(model, peft_model_name, max_new_tokens, user_prompt, top_p, temperature))
 
 if __name__ == "__main__":
     fire.Fire(run_script)

+ 1 - 3
recipes/quickstart/finetuning/LLM_finetuning_overview.md

@@ -1,6 +1,6 @@
 ## LLM Fine-Tuning
 
-Here we discuss fine-tuning Meta Llama 3 with a couple of different recipes. We will cover two scenarios here:
+Here we discuss fine-tuning Meta Llama with a couple of different recipes. We will cover two scenarios here:
 
 
 ## 1. **Parameter Efficient Model Fine-Tuning**
@@ -18,8 +18,6 @@ These methods will address three aspects:
 
 HF [PEFT](https://github.com/huggingface/peft) library provides an easy way of using these methods which we make use of here. Please read more [here](https://huggingface.co/blog/peft).
 
-
-
 ## 2. **Full/ Partial Parameter Fine-Tuning**
 
 Full parameter fine-tuning has its own advantages, in this method there are multiple strategies that can help:

+ 25 - 0
recipes/quickstart/finetuning/multigpu_finetuning.md

@@ -68,7 +68,32 @@ If you are running full parameter fine-tuning on the 70B model, you can enable `
 torchrun --nnodes 1 --nproc_per_node 8 finetuning.py --enable_fsdp --low_cpu_fsdp --fsdp_config.pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
 ```
 
+**Multi GPU multi node**:
 
+Here we use a slurm script to schedule a job over multiple nodes.
+
+```bash
+
+sbatch recipes/quickstart/finetuning/multi_node.slurm
+# Change the number of nodes and GPUs per node in the script before running.
+
+```
+
+To fine-tune the Meta Llama 405B model with LoRA on 32x H100 80GB GPUs, we need to combine 4-bit quantization (QLoRA) and FSDP.
+We can achieve this by adding the following environment variables to the slurm script (before the srun command at the bottom).
+
+```bash
+export FSDP_CPU_RAM_EFFICIENT_LOADING=1
+export ACCELERATE_USE_FSDP=1 
+```
+
+Then we need to replace the bottom srun command with the following:
+
+```bash
+srun  torchrun --nproc_per_node 8 --rdzv_id $RANDOM --rdzv_backend c10d --rdzv_endpoint $head_node_ip:29500 ./finetuning.py  --enable_fsdp --use_peft --peft_method lora --quantization 4bit  --quantization_config.quant_type nf4 --mixed_precision False --low_cpu_fsdp
+```
+
+Do not forget to adjust the number of nodes, ntasks and gpus-per-task at the top of the slurm script, as sketched below.
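+
+As an illustration (a sketch only; the exact directives depend on your cluster), a 4 node x 8 GPU allocation matching the 32xH100 example above could be requested with header lines like these:
+
+```bash
+#SBATCH --nodes=4          # number of nodes
+#SBATCH --ntasks=4         # one torchrun launcher per node
+#SBATCH --gpus-per-task=8  # GPUs available to each task/node
+```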
 
 ## Running with different datasets
 Currently 3 open source datasets are supported that can be found in [Datasets config file](../../../src/llama_recipes/configs/datasets.py). You can also use your custom dataset (more info [here](./datasets/README.md)).

+ 7 - 4
recipes/quickstart/inference/local_inference/README.md

@@ -27,8 +27,8 @@ samsum_prompt.txt
 ...
 ```
 
-**Note**
-Currently pad token by default in [HuggingFace Tokenizer is `None`](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py#L110). We add the padding token as a special token to the tokenizer, which in this case requires to resize the token_embeddings as shown below:
+**Note on Llama version < 3.1**
+The default padding token in [HuggingFace Tokenizer is `None`](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py#L110). To use padding, the padding token needs to be added as a special token to the tokenizer, which in this case requires resizing the token_embeddings as shown below:
 
 ```python
 tokenizer.add_special_tokens(
@@ -39,8 +39,7 @@ tokenizer.add_special_tokens(
     )
 model.resize_token_embeddings(model.config.vocab_size + 1)
 ```
-Padding would be required for batch inference. In this this [example](inference.py), batch size = 1 so essentially padding is not required. However,We added the code pointer as an example in case of batch inference.
-
+Padding would be required for batched inference. In this [example](inference.py), batch size = 1, so padding is essentially not required. However, we added the code pointer as an example in case of batch inference. For Llama version 3.1, use the special token `<|finetune_right_pad_id|>` (128004) for padding.
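+
+A minimal sketch of what this could look like for Llama 3.1 (assuming its tokenizer, which already reserves this token, so no embedding resize is needed):
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
+# Reuse the reserved padding token (id 128004) instead of adding a new one
+tokenizer.pad_token = "<|finetune_right_pad_id|>"
+```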
 
 ## Chat completion
 The inference folder also includes a chat completion example, that adds built-in safety features in fine-tuned models to the prompt tokens. To run the example:
@@ -85,3 +84,7 @@ Then run inference using:
 python inference.py --model_name <training_config.output_dir> --prompt_file <test_prompt_file>
 
 ```
+
+## Inference on large models like Meta Llama 405B
+The FP8 quantized variants of Meta Llama (i.e. meta-llama/Meta-Llama-3.1-405B-FP8 and meta-llama/Meta-Llama-3.1-405B-Instruct-FP8) can be executed on a single node with 8x80GB H100 using the scripts located in this folder.
+To run the unquantized Meta Llama 405B variants (i.e. meta-llama/Meta-Llama-3.1-405B and meta-llama/Meta-Llama-3.1-405B-Instruct) we need to use a multi-node setup for inference. The llama-recipes inference script currently does not allow multi-node inference. To run this model you can use vLLM with pipeline and tensor parallelism as shown in [this example](../../../3p_integrations/vllm/README.md).
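+
+For reference, a two-node run with 8 GPUs per node would mirror the command from the linked example, executed from the vLLM example folder (a sketch; the model path and parallelism sizes are assumptions to adapt to your setup):
+
+```bash
+python inference.py --model_name meta-llama/Meta-Llama-3.1-405B-Instruct --pp_size 2 --tp_size 8 --user_prompt "Hello my name is"
+```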

+ 13 - 11
recipes/quickstart/inference/local_inference/chat_completion/chat_completion.py

@@ -4,6 +4,7 @@
 # from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
 import fire
+import json
 import os
 import sys
 
@@ -18,7 +19,7 @@ from accelerate.utils import is_xpu_available
 def main(
     model_name,
     peft_model: str=None,
-    quantization: bool=False,
+    quantization: str = None, # Options: 4bit, 8bit
     max_new_tokens =256, #The maximum numbers of tokens to generate
     min_new_tokens:int=0, #The minimum numbers of tokens to generate
     prompt_file: str=None,
@@ -47,33 +48,32 @@ def main(
 
     elif not sys.stdin.isatty():
         dialogs = "\n".join(sys.stdin.readlines())
+        try:
+            dialogs = json.loads(dialogs)
+        except:
+            print("Could not parse json from stdin. Please provide a json file with the user prompts. Exiting.")
+            sys.exit(1)
     else:
         print("No user prompt provided. Exiting.")
         sys.exit(1)
 
     print(f"User dialogs:\n{dialogs}")
     print("\n==================================\n")
-
-
+    
     # Set the seeds for reproducibility
     if is_xpu_available():
         torch.xpu.manual_seed(seed)
     else:
         torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
-    model = load_model(model_name, quantization, use_fast_kernels)
+
+    model = load_model(model_name, quantization, use_fast_kernels, **kwargs)
     if peft_model:
         model = load_peft_model(model, peft_model)
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    tokenizer.add_special_tokens(
-        {
-
-            "pad_token": "<PAD>",
-        }
-    )
 
-    chats = tokenizer.apply_chat_template(dialogs)
+    chats = [tokenizer.apply_chat_template(dialog) for dialog in dialogs]
 
     with torch.no_grad():
         for idx, chat in enumerate(chats):
@@ -99,12 +99,14 @@ def main(
                 sys.exit(1)  # Exit the program with an error status
             tokens= torch.tensor(chat).long()
             tokens= tokens.unsqueeze(0)
+            attention_mask = torch.ones_like(tokens)
             if is_xpu_available():
                 tokens= tokens.to("xpu:0")
             else:
                 tokens= tokens.to("cuda:0")
             outputs = model.generate(
                 input_ids=tokens,
+                attention_mask=attention_mask,
                 max_new_tokens=max_new_tokens,
                 do_sample=do_sample,
                 top_p=top_p,

+ 145 - 128
recipes/quickstart/inference/local_inference/inference.py

@@ -1,68 +1,46 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
-
-import fire
 import os
 import sys
 import time
+
+import fire
 import gradio as gr
 
 import torch
-from transformers import AutoTokenizer
 
-from llama_recipes.inference.safety_utils import get_safety_checker, AgentType
+from accelerate.utils import is_xpu_available
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 
-from accelerate.utils import is_xpu_available
+from llama_recipes.inference.safety_utils import AgentType, get_safety_checker
+from transformers import AutoTokenizer
+
 
 def main(
     model_name,
-    peft_model: str=None,
-    quantization: bool=False,
-    max_new_tokens =100, #The maximum numbers of tokens to generate
-    prompt_file: str=None,
-    seed: int=42, #seed value for reproducibility
-    do_sample: bool=True, #Whether or not to use sampling ; use greedy decoding otherwise.
-    min_length: int=None, #The minimum length of the sequence to be generated, input prompt + min_new_tokens
-    use_cache: bool=True,  #[optional] Whether or not the model should use the past last key/values attentions Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
-    top_p: float=1.0, # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
-    temperature: float=1.0, # [optional] The value used to modulate the next token probabilities.
-    top_k: int=50, # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering.
-    repetition_penalty: float=1.0, #The parameter for repetition penalty. 1.0 means no penalty.
-    length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation.
-    enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
-    enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
-    enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
-    enable_llamaguard_content_safety: bool=False,
-    max_padding_length: int=None, # the max padding length to be used with tokenizer padding the prompts.
-    use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
-    **kwargs
+    peft_model: str = None,
+    quantization: str = None, # Options: 4bit, 8bit
+    max_new_tokens=100,  # The maximum numbers of tokens to generate
+    prompt_file: str = None,
+    seed: int = 42,  # seed value for reproducibility
+    do_sample: bool = True,  # Whether or not to use sampling ; use greedy decoding otherwise.
+    min_length: int = None,  # The minimum length of the sequence to be generated, input prompt + min_new_tokens
+    use_cache: bool = True,  # [optional] Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
+    top_p: float = 1.0,  # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+    temperature: float = 1.0,  # [optional] The value used to modulate the next token probabilities.
+    top_k: int = 50,  # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering.
+    repetition_penalty: float = 1.0,  # The parameter for repetition penalty. 1.0 means no penalty.
+    length_penalty: int = 1,  # [optional] Exponential penalty to the length that is used with beam-based generation.
+    enable_azure_content_safety: bool = False,  # Enable safety check with Azure content safety api
+    enable_sensitive_topics: bool = False,  # Enable check for sensitive topics using AuditNLG APIs
+    enable_salesforce_content_safety: bool = True,  # Enable safety check with Salesforce safety flan t5
+    enable_llamaguard_content_safety: bool = False,
+    max_padding_length: int = None,  # the max padding length to be used with tokenizer padding the prompts.
+    use_fast_kernels: bool = False,  # Enable using SDPA from PyTorch Accelerated Transformers, makes use of Flash Attention and Xformers memory-efficient kernels
+    share_gradio: bool = False,  # Enable endpoint creation for gradio.live
+    **kwargs,
 ):
-
-  def inference(user_prompt, temperature, top_p, top_k, max_new_tokens, **kwargs,):
-    safety_checker = get_safety_checker(enable_azure_content_safety,
-                                        enable_sensitive_topics,
-                                        enable_salesforce_content_safety,
-                                        enable_llamaguard_content_safety
-                                        )
-
-    # Safety check of the user prompt
-    safety_results = [check(user_prompt) for check in safety_checker]
-    are_safe = all([r[1] for r in safety_results])
-    if are_safe:
-        print("User prompt deemed safe.")
-        print(f"User prompt:\n{user_prompt}")
-    else:
-        print("User prompt deemed unsafe.")
-        for method, is_safe, report in safety_results:
-            if not is_safe:
-                print(method)
-                print(report)
-        print("Skipping the inference as the prompt is not safe.")
-        sys.exit(1)  # Exit the program with an error status
-
     # Set the seeds for reproducibility
     if is_xpu_available():
         torch.xpu.manual_seed(seed)
@@ -70,7 +48,7 @@ def main(
         torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
 
-    model = load_model(model_name, quantization, use_fast_kernels)
+    model = load_model(model_name, quantization, use_fast_kernels, **kwargs)
     if peft_model:
         model = load_peft_model(model, peft_model)
 
@@ -79,86 +57,125 @@ def main(
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.pad_token = tokenizer.eos_token
 
-    batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
-    if is_xpu_available():
-        batch = {k: v.to("xpu") for k, v in batch.items()}
-    else:
-        batch = {k: v.to("cuda") for k, v in batch.items()}
-
-    start = time.perf_counter()
-    with torch.no_grad():
-        outputs = model.generate(
-            **batch,
-            max_new_tokens=max_new_tokens,
-            do_sample=do_sample,
-            top_p=top_p,
-            temperature=temperature,
-            min_length=min_length,
-            use_cache=use_cache,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            length_penalty=length_penalty,
-            **kwargs
+    def inference(
+        user_prompt,
+        temperature,
+        top_p,
+        top_k,
+        max_new_tokens,
+        **kwargs,
+    ):
+        safety_checker = get_safety_checker(
+            enable_azure_content_safety,
+            enable_sensitive_topics,
+            enable_salesforce_content_safety,
+            enable_llamaguard_content_safety,
         )
-    e2e_inference_time = (time.perf_counter()-start)*1000
-    print(f"the inference time is {e2e_inference_time} ms")
-    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    # Safety check of the model output
-    safety_results = [check(output_text, agent_type=AgentType.AGENT, user_prompt=user_prompt) for check in safety_checker]
-    are_safe = all([r[1] for r in safety_results])
-    if are_safe:
-        print("User input and model output deemed safe.")
-        print(f"Model output:\n{output_text}")
-    else:
-        print("Model output deemed unsafe.")
-        for method, is_safe, report in safety_results:
-            if not is_safe:
-                print(method)
-                print(report)
-    return output_text
-
-  if prompt_file is not None:
-      assert os.path.exists(
-          prompt_file
-      ), f"Provided Prompt file does not exist {prompt_file}"
-      with open(prompt_file, "r") as f:
-          user_prompt = "\n".join(f.readlines())
-      inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
-  elif not sys.stdin.isatty():
-      user_prompt = "\n".join(sys.stdin.readlines())
-      inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
-  else:
-      gr.Interface(
-        fn=inference,
-        inputs=[
-            gr.components.Textbox(
-                lines=9,
-                label="User Prompt",
-                placeholder="none",
-            ),
-            gr.components.Slider(
-                minimum=0, maximum=1, value=1.0, label="Temperature"
-            ),
-            gr.components.Slider(
-                minimum=0, maximum=1, value=1.0, label="Top p"
-            ),
-            gr.components.Slider(
-                minimum=0, maximum=100, step=1, value=50, label="Top k"
-            ),
-            gr.components.Slider(
-                minimum=1, maximum=2000, step=1, value=200, label="Max tokens"
-            ),
-        ],
-        outputs=[
-            gr.components.Textbox(
-                lines=5,
-                label="Output",
+
+        # Safety check of the user prompt
+        safety_results = [check(user_prompt) for check in safety_checker]
+        are_safe = all([r[1] for r in safety_results])
+        if are_safe:
+            print("User prompt deemed safe.")
+            print(f"User prompt:\n{user_prompt}")
+        else:
+            print("User prompt deemed unsafe.")
+            for method, is_safe, report in safety_results:
+                if not is_safe:
+                    print(method)
+                    print(report)
+            print("Skipping the inference as the prompt is not safe.")
+            return  # Exit the program with an error status
+
+        batch = tokenizer(
+            user_prompt,
+            padding="max_length",
+            truncation=True,
+            max_length=max_padding_length,
+            return_tensors="pt",
+        )
+        if is_xpu_available():
+            batch = {k: v.to("xpu") for k, v in batch.items()}
+        else:
+            batch = {k: v.to("cuda") for k, v in batch.items()}
+
+        start = time.perf_counter()
+        with torch.no_grad():
+            outputs = model.generate(
+                **batch,
+                max_new_tokens=max_new_tokens,
+                do_sample=do_sample,
+                top_p=top_p,
+                temperature=temperature,
+                min_length=min_length,
+                use_cache=use_cache,
+                top_k=top_k,
+                repetition_penalty=repetition_penalty,
+                length_penalty=length_penalty,
+                **kwargs,
             )
-        ],
-        title="Meta Llama3 Playground",
-        description="https://github.com/facebookresearch/llama-recipes",
-      ).queue().launch(server_name="0.0.0.0", share=True)
+        e2e_inference_time = (time.perf_counter() - start) * 1000
+        print(f"the inference time is {e2e_inference_time} ms")
+        output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Safety check of the model output
+        safety_results = [
+            check(output_text, agent_type=AgentType.AGENT, user_prompt=user_prompt)
+            for check in safety_checker
+        ]
+        are_safe = all([r[1] for r in safety_results])
+        if are_safe:
+            print("User input and model output deemed safe.")
+            print(f"Model output:\n{output_text}")
+            return output_text
+        else:
+            print("Model output deemed unsafe.")
+            for method, is_safe, report in safety_results:
+                if not is_safe:
+                    print(method)
+                    print(report)
+            return None
+
+    if prompt_file is not None:
+        assert os.path.exists(
+            prompt_file
+        ), f"Provided Prompt file does not exist {prompt_file}"
+        with open(prompt_file, "r") as f:
+            user_prompt = "\n".join(f.readlines())
+        inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
+    elif not sys.stdin.isatty():
+        user_prompt = "\n".join(sys.stdin.readlines())
+        inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
+    else:
+        gr.Interface(
+            fn=inference,
+            inputs=[
+                gr.components.Textbox(
+                    lines=9,
+                    label="User Prompt",
+                    placeholder="none",
+                ),
+                gr.components.Slider(
+                    minimum=0, maximum=1, value=1.0, label="Temperature"
+                ),
+                gr.components.Slider(minimum=0, maximum=1, value=1.0, label="Top p"),
+                gr.components.Slider(
+                    minimum=0, maximum=100, step=1, value=50, label="Top k"
+                ),
+                gr.components.Slider(
+                    minimum=1, maximum=2000, step=1, value=200, label="Max tokens"
+                ),
+            ],
+            outputs=[
+                gr.components.Textbox(
+                    lines=5,
+                    label="Output",
+                )
+            ],
+            title="Meta Llama3 Playground",
+            description="https://github.com/meta-llama/llama-recipes",
+        ).queue().launch(server_name="0.0.0.0", share=share_gradio)
+
 
 if __name__ == "__main__":
     fire.Fire(main)

+ 22 - 5
src/llama_recipes/inference/model_utils.py

@@ -1,19 +1,36 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the GNU General Public License version 3.
 
+from llama_recipes.utils.config_utils import update_config
+from llama_recipes.configs import quantization_config  as QUANT_CONFIG
 from peft import PeftModel
 from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaConfig
+from warnings import warn
 
 # Function to load the main model for text generation
-def load_model(model_name, quantization, use_fast_kernels):
+def load_model(model_name, quantization, use_fast_kernels, **kwargs):
+    if isinstance(quantization, bool):
+        warn("Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.", FutureWarning)
+        quantization = "8bit"
+
+    bnb_config = None
+    if quantization:
+        quant_config = QUANT_CONFIG()
+        update_config(quant_config, **kwargs)
+        bnb_config = quant_config.create_bnb_config(quantization)
+
     print(f"use_fast_kernels{use_fast_kernels}")
+
+    kwargs = {}
+    if bnb_config:
+        kwargs["quantization_config"]=bnb_config
+    kwargs["device_map"]="auto"
+    kwargs["low_cpu_mem_usage"]=True
+    kwargs["attn_implementation"]="sdpa" if use_fast_kernels else None
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         return_dict=True,
-        load_in_8bit=quantization,
-        device_map="auto",
-        low_cpu_mem_usage=True,
-        attn_implementation="sdpa" if use_fast_kernels else None,
+        **kwargs,
     )
     return model