Explorar o código

Support converting fine-tuned llama 3.2 vision model to HF format and then local inference (#737)

Sanyam Bhutani hai 6 meses
pai
achega
799e90eb95

+ 2 - 0
recipes/quickstart/finetuning/finetune_vision_model.md

@@ -22,6 +22,8 @@ For **LoRA finetuning with FSDP**, we can run the following code:
 
 For more details about the finetuning configurations, please read the [finetuning readme](./README.md).
 
+For more details about local inference with the fine-tuned checkpoint, please read [Inference with FSDP checkpoints section](https://github.com/meta-llama/llama-recipes/tree/main/recipes/quickstart/inference/local_inference#inference-with-fsdp-checkpoints) to learn how to convert the FSDP weights into a consolidated Hugging Face formatted model for local inference.
+
 ### How to use a custom dataset to fine-tune vision model
 
 In order to use a custom dataset, please follow the steps below:

+ 6 - 3
recipes/quickstart/inference/local_inference/README.md

@@ -1,11 +1,14 @@
 # Local Inference
 
+## Hugging face setup
+**Important Note**: Before running the inference, you'll need your Hugging Face access token, which you can get at your Settings page [here](https://huggingface.co/settings/tokens). Then run `huggingface-cli login` and copy and paste your Hugging Face access token to complete the login to make sure the scripts can download Hugging Face models if needed.
+
 ## Multimodal Inference
-For Multi-Modal inference we have added [multi_modal_infer.py](multi_modal_infer.py) which uses the transformers library
+For Multi-Modal inference we have added [multi_modal_infer.py](multi_modal_infer.py) which uses the transformers library.
 
-The way to run this would be
+The way to run this would be:
 ```
-python multi_modal_infer.py --image_path "./resources/image.jpg" --prompt_text "Describe this image" --temperature 0.5 --top_p 0.8 --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct"
+python multi_modal_infer.py --image_path PATH_TO_IMAGE --prompt_text "Describe this image" --temperature 0.5 --top_p 0.8 --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct"
 ```
 
 ## Text-only Inference

+ 59 - 24
recipes/quickstart/inference/local_inference/multi_modal_infer.py

@@ -1,10 +1,11 @@
+import argparse
 import os
 import sys
-import argparse
-from PIL import Image as PIL_Image
+
 import torch
+from accelerate import Accelerator
+from PIL import Image as PIL_Image
 from transformers import MllamaForConditionalGeneration, MllamaProcessor
-from accelerate import  Accelerator
 
 accelerator = Accelerator()
 
@@ -14,15 +15,19 @@ device = accelerator.device
 DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 
 
-def load_model_and_processor(model_name: str, hf_token: str):
+def load_model_and_processor(model_name: str):
     """
     Load the model and processor based on the 11B or 90B model.
     """
-    model = MllamaForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16,use_safetensors=True, device_map=device,
-                                                            token=hf_token)
-    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token,use_safetensors=True)
+    model = MllamaForConditionalGeneration.from_pretrained(
+        model_name,
+        torch_dtype=torch.bfloat16,
+        use_safetensors=True,
+        device_map=device,
+    )
+    processor = MllamaProcessor.from_pretrained(model_name, use_safetensors=True)
 
-    model, processor=accelerator.prepare(model, processor)
+    model, processor = accelerator.prepare(model, processor)
     return model, processor
 
 
@@ -37,37 +42,67 @@ def process_image(image_path: str) -> PIL_Image.Image:
         return PIL_Image.open(f).convert("RGB")
 
 
-def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
+def generate_text_from_image(
+    model, processor, image, prompt_text: str, temperature: float, top_p: float
+):
     """
     Generate text from an image using the model and processor.
     """
     conversation = [
-        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
+        {
+            "role": "user",
+            "content": [{"type": "image"}, {"type": "text", "text": prompt_text}],
+        }
     ]
-    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+    prompt = processor.apply_chat_template(
+        conversation, add_generation_prompt=True, tokenize=False
+    )
     inputs = processor(image, prompt, return_tensors="pt").to(device)
-    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
-    return processor.decode(output[0])[len(prompt):]
+    output = model.generate(
+        **inputs, temperature=temperature, top_p=top_p, max_new_tokens=512
+    )
+    return processor.decode(output[0])[len(prompt) :]
 
 
-def main(image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str, hf_token: str):
+def main(
+    image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str
+):
     """
-    Call all the functions. 
+    Call all the functions.
     """
-    model, processor = load_model_and_processor(model_name, hf_token)
+    model, processor = load_model_and_processor(model_name)
     image = process_image(image_path)
-    result = generate_text_from_image(model, processor, image, prompt_text, temperature, top_p)
+    result = generate_text_from_image(
+        model, processor, image, prompt_text, temperature, top_p
+    )
     print("Generated Text: " + result)
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Generate text from an image and prompt using the 3.2 MM Llama model.")
+    parser = argparse.ArgumentParser(
+        description="Generate text from an image and prompt using the 3.2 MM Llama model."
+    )
     parser.add_argument("--image_path", type=str, help="Path to the image file")
-    parser.add_argument("--prompt_text", type=str, help="Prompt text to describe the image")
-    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation (default: 0.7)")
-    parser.add_argument("--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)")
-    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help=f"Model name (default: '{DEFAULT_MODEL}')")
-    parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token for authentication")
+    parser.add_argument(
+        "--prompt_text", type=str, help="Prompt text to describe the image"
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.7,
+        help="Temperature for generation (default: 0.7)",
+    )
+    parser.add_argument(
+        "--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)"
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default=DEFAULT_MODEL,
+        help=f"Model name (default: '{DEFAULT_MODEL}')",
+    )
 
     args = parser.parse_args()
-    main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)
+    main(
+        args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name
+    )

+ 31 - 13
src/llama_recipes/inference/model_utils.py

@@ -1,17 +1,29 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the GNU General Public License version 3.
 
+from warnings import warn
+
+from llama_recipes.configs import quantization_config as QUANT_CONFIG
 from llama_recipes.utils.config_utils import update_config
-from llama_recipes.configs import quantization_config  as QUANT_CONFIG
 from peft import PeftModel
-from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaConfig
-from warnings import warn
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    LlamaConfig,
+    LlamaForCausalLM,
+    MllamaConfig,
+    MllamaForConditionalGeneration,
+)
+
 
 # Function to load the main model for text generation
 def load_model(model_name, quantization, use_fast_kernels, **kwargs):
     if type(quantization) == type(True):
-            warn("Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.", FutureWarning)
-            quantization = "8bit"
+        warn(
+            "Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.",
+            FutureWarning,
+        )
+        quantization = "8bit"
 
     bnb_config = None
     if quantization:
@@ -23,10 +35,10 @@ def load_model(model_name, quantization, use_fast_kernels, **kwargs):
 
     kwargs = {}
     if bnb_config:
-        kwargs["quantization_config"]=bnb_config
-    kwargs["device_map"]="auto"
-    kwargs["low_cpu_mem_usage"]=True
-    kwargs["attn_implementation"]="sdpa" if use_fast_kernels else None
+        kwargs["quantization_config"] = bnb_config
+    kwargs["device_map"] = "auto"
+    kwargs["low_cpu_mem_usage"] = True
+    kwargs["attn_implementation"] = "sdpa" if use_fast_kernels else None
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         return_dict=True,
@@ -40,10 +52,16 @@ def load_peft_model(model, peft_model):
     peft_model = PeftModel.from_pretrained(model, peft_model)
     return peft_model
 
+
 # Loading the model from config to load FSDP checkpoints into that
 def load_llama_from_config(config_path):
-    model_config = LlamaConfig.from_pretrained(config_path) 
-    model = LlamaForCausalLM(config=model_config)
+    config = AutoConfig.from_pretrained(config_path)
+    if config.model_type == "mllama":
+        model = MllamaForConditionalGeneration(config=config)
+    elif config.model_type == "llama":
+        model = LlamaForCausalLM(config=config)
+    else:
+        raise ValueError(
+            f"Unsupported model type: {config.model_type}, Please use llama or mllama model."
+        )
     return model
-    
-