
Upstream merge (#677)

Hamid Shojanazeri 6 months ago
parent
commit
3470eafbc6

+ 3 - 0
.github/scripts/spellcheck_conf/wordlist.txt

@@ -1451,4 +1451,7 @@ openhathi
 sarvam
 subtask
 acc
+OCRVQA
+OCRVQADataCollator
+ocrvqa
 langchain

File diff suppressed because it is too large
+ 14 - 32
README.md


+ 90 - 0
recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py

@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
+
+
+import copy
+from datasets import load_dataset
+import itertools
+import torch
+
+# Check whether the system-prompt or user-prompt header token sequence appears in the current token list
+def check_header(targets,seq):
+    for i in range(len(seq)-3):
+        if seq[i:i+3] in targets:
+            return True
+    return False
+def replace_target(target,seq):
+    for i in range(len(seq)-3):
+        if seq[i:i+3] == target:
+            seq[i],seq[i+1],seq[i+2] = -100,-100,-100
+    return seq
+def tokenize_dialogs(dialogs, images, processor):
+    text_prompt = processor.apply_chat_template(dialogs)
+    batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
+    label_list = []
+    for i in range(len(batch["input_ids"])):
+        dialog_tokens = batch["input_ids"][i].tolist()
+        labels = copy.copy(dialog_tokens)
+        eot_indices = [i for i,n in enumerate(labels) if n == 128009]
+        last_idx = 0
+        # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
+        # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
+        prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
+        for n, idx in enumerate(eot_indices):
+            current_seq = labels[last_idx:idx+1]
+            if check_header(prompt_header_seqs,current_seq):
+                # found prompt header, indicating that this seq should be masked
+                labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+            else:
+                last_idx = idx+1
+        # Mask the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
+        assistant_header_seq = [128006, 78191, 128007]
+        labels = replace_target(assistant_header_seq,labels)
+        # Mask the padding tokens and the image token (id 128256)
+        for i in range(len(labels)):
+            if labels[i] == processor.tokenizer.pad_token_id or labels[i] == 128256: #  128256 is image token index
+                labels[i] = -100
+        label_list.append(labels)
+    batch["labels"] = torch.tensor(label_list)
+    return batch
+
+
+def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
+    # load_dataset returns a DatasetDict; all of the data lives in the 'train' split
+    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")
+    dataset = dataset_dict['train']
+    # For quick testing we only use 2000 samples; comment out the following line to use the full dataset
+    dataset = dataset.select(range(2000))
+    dataset = dataset.train_test_split(test_size=1-split_ratio, shuffle=True, seed=42)[split]
+    return dataset
+
+class OCRVQADataCollator:
+    def __init__(self, processor):
+        self.processor = processor
+        self.processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
+    def __call__(self, samples):
+        dialogs,images = [],[]
+        for sample in samples:
+            image_list,sample_list = sample["images"],sample["texts"]
+            if len(image_list) > 1:
+                raise ValueError("Only one image per sample is supported")
+            image = image_list[0].convert("RGB") # only use the first image
+            dialog = []
+            for sample_dict in sample_list:
+                if not dialog:
+                    # only append image to the first sentence
+                    dialog += [
+                    {"role":"user","content":[{"type": "image"},{"type": "text", "text": sample_dict["user"].strip()}]},
+                    {"role":"assistant","content":[{"type": "text", "text": sample_dict["assistant"].strip()}]}
+                ]
+                
+                else:
+                    dialog += [
+                    {"role":"user","content":[{"type": "text", "text": sample_dict["user"].strip()}]},
+                    {"role":"assistant","content":[{"type": "text", "text": sample_dict["assistant"].strip()}]}
+                ]
+            dialogs.append(dialog)
+            images.append([image])
+        return tokenize_dialogs(dialogs,images, self.processor)
+def get_data_collator(processor):
+    return OCRVQADataCollator(processor)
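To make the label masking above concrete, here is a small, self-contained sketch. It reproduces the `replace_target` helper from the file above and applies it to a toy label list built from the Llama 3 special-token IDs noted in its comments; the other IDs are arbitrary stand-ins for content tokens, and no model, processor, or dataset download is needed:

```
# Toy illustration of the -100 masking performed in tokenize_dialogs above.
# 128006/128007 are <|start_header_id|>/<|end_header_id|>, 128009 is <|eot_id|>,
# 882 is "user" and 78191 is "assistant" (per the comments in the file above).
assistant_header_seq = [128006, 78191, 128007]

def replace_target(target, seq):
    for i in range(len(seq) - 3):
        if seq[i:i + 3] == target:
            seq[i], seq[i + 1], seq[i + 2] = -100, -100, -100
    return seq

labels = [128006, 882, 128007, 12345, 128009,      # user turn (12345 is an arbitrary content token)
          128006, 78191, 128007, 67890, 128009]    # assistant turn (67890 is an arbitrary content token)
print(replace_target(assistant_header_seq, labels))
# -> [128006, 882, 128007, 12345, 128009, -100, -100, -100, 67890, 128009]
```

In `tokenize_dialogs` the same idea is applied to whole system/user turns (masked up to their `<|eot_id|>`) and to padding and image tokens, so only assistant response tokens contribute to the loss.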

File diff suppressed because it is too large
+ 33 - 0
recipes/quickstart/finetuning/finetune_vision_model.md


+ 8 - 1
recipes/quickstart/inference/local_inference/README.md

@@ -1,5 +1,12 @@
 # Local Inference
 
+For multi-modal inference we have added [multi_modal_infer.py](multi_modal_infer.py), which uses the transformers library.
+
+It can be run as follows:
+```
+python multi_modal_infer.py --image_path "./resources/image.jpg" --prompt_text "Describe this image" --temperature 0.5 --top_p 0.8 --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct" --hf_token <your_hf_token>
+```
+
 For local inference we have provided an [inference script](inference.py). Depending on the type of finetuning performed during training the [inference script](inference.py) takes different arguments.
 To finetune all model parameters the output dir of the training has to be given as --model_name argument.
 In the case of a parameter efficient method like lora the base model has to be given as --model_name and the output dir of the training has to be given as --peft_model argument.
@@ -87,4 +94,4 @@ python inference.py --model_name <training_config.output_dir> --prompt_file <tes
 
 ## Inference on large models like Meta Llama 405B
 The FP8 quantized variants of Meta Llama (i.e. meta-llama/Meta-Llama-3.1-405B-FP8 and meta-llama/Meta-Llama-3.1-405B-Instruct-FP8) can be executed on a single node with 8x80GB H100 using the scripts located in this folder.
-To run the unquantized Meta Llama 405B variants (i.e. meta-llama/Meta-Llama-3.1-405B and meta-llama/Meta-Llama-3.1-405B-Instruct) we need to use a multi-node setup for inference. The llama-recipes inference script currently does not allow multi-node inference. To run this model you can use vLLM with pipeline and tensor parallelism as showed in [this example](../../../3p_integrations/vllm/README.md).
+To run the unquantized Meta Llama 405B variants (i.e. meta-llama/Meta-Llama-3.1-405B and meta-llama/Meta-Llama-3.1-405B-Instruct) we need to use a multi-node setup for inference. The llama-recipes inference script currently does not allow multi-node inference. To run this model you can use vLLM with pipeline and tensor parallelism as shown in [this example](../../../3p_integrations/vllm/README.md).

+ 66 - 0
recipes/quickstart/inference/local_inference/multi_modal_infer.py

@@ -0,0 +1,66 @@
+import os
+import sys
+import argparse
+from PIL import Image as PIL_Image
+import torch
+from transformers import MllamaForConditionalGeneration, MllamaProcessor
+
+
+# Constants
+DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+
+
+def load_model_and_processor(model_name: str, hf_token: str):
+    """
+    Load the Mllama model and processor for the given model name (e.g., the 11B or 90B vision-instruct checkpoints).
+    """
+    model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16, token=hf_token)
+    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token)
+    return model, processor
+
+
+def process_image(image_path: str) -> PIL_Image.Image:
+    """
+    Open and convert an image from the specified path.
+    """
+    if not os.path.exists(image_path):
+        print(f"The image file '{image_path}' does not exist.")
+        sys.exit(1)
+    with open(image_path, "rb") as f:
+        return PIL_Image.open(f).convert("RGB")
+
+
+def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
+    """
+    Generate text from an image using the model and processor.
+    """
+    conversation = [
+        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
+    ]
+    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+    inputs = processor(prompt, image, return_tensors="pt").to(model.device)
+    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
+    return processor.decode(output[0])[len(prompt):]
+
+
+def main(image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str, hf_token: str):
+    """
+    Run the end-to-end flow: load the model, process the image, and generate text.
+    """
+    model, processor = load_model_and_processor(model_name, hf_token)
+    image = process_image(image_path)
+    result = generate_text_from_image(model, processor, image, prompt_text, temperature, top_p)
+    print("Generated Text: " + result)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Generate text from an image and a prompt using a Llama 3.2 multimodal (Vision) model.")
+    parser.add_argument("--image_path", type=str, help="Path to the image file")
+    parser.add_argument("--prompt_text", type=str, help="Prompt text to describe the image")
+    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation (default: 0.7)")
+    parser.add_argument("--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)")
+    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help=f"Model name (default: '{DEFAULT_MODEL}')")
+    parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token for authentication")
+
+    args = parser.parse_args()
+    main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)
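The same flow can also be driven from Python rather than the command line. Below is a minimal sketch reusing the helpers defined in the script above; it assumes the script's directory is on `PYTHONPATH`, that `<your_hf_token>` is replaced with a real Hugging Face token with access to the gated checkpoint, and that the placeholder image path from the README exists:

```
# Hedged sketch: programmatic use of the helpers from multi_modal_infer.py.
from multi_modal_infer import (
    load_model_and_processor,
    process_image,
    generate_text_from_image,
)

model, processor = load_model_and_processor(
    "meta-llama/Llama-3.2-11B-Vision-Instruct", hf_token="<your_hf_token>"
)
image = process_image("./resources/image.jpg")  # placeholder path from the README above
print(generate_text_from_image(
    model, processor, image, "Describe this image", temperature=0.5, top_p=0.8
))
```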

File diff suppressed because it is too large
+ 0 - 384
recipes/responsible_ai/Purple_Llama_Anyscale.ipynb


+ 0 - 289
recipes/responsible_ai/Purple_Llama_OctoAI.ipynb

@@ -1,289 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "LERqQn5v8-ak"
-   },
-   "source": [
-    "# **Purple Llama Using OctoAI**\n",
-    "\n",
-    "Drawing inspiration from the cybersecurity concept of \"purple teaming,\" Purple Llama embraces both offensive (red team) and defensive (blue team) strategies. Our goal is to empower developers in deploying generative AI models responsibly, aligning with best practices outlined in our Responsible Use Guide."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "PGPSI3M5PGTi"
-   },
-   "source": [
-    "#### **1 - What is Purple Llama?**\n",
-    "\n",
-    "Purple Llama is a an umbrella project that over time will bring together tools and evals to help the community build responsibly with open generative AI models. The initial release will include tools and evals for Cyber Security and Input/Output safeguards but we plan to contribute more in the near future.\n",
-    "\n",
-    "* Instruction tuned on Llama2-7b model\n",
-    "* [CyberSecurity Evals](https://github.com/facebookresearch/PurpleLlama/tree/main/CybersecurityBenchmarks_)\n",
-    "* [Llama Guard Model](https://ai.meta.com/research/publications/llama-guard-llm-based-input-output-safeguard-for-human-ai-conversations/)\n",
-    "* [Download Llama Guard](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)\n",
-    "* [Purple Llama Website](https://ai.meta.com/llama/purple-llama/)\n",
-    "* [Purple Llama Github Repo](https://github.com/facebookresearch/PurpleLlama)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "aYeHVVh45bdT"
-   },
-   "source": [
-    "#### **2 - Accessing Purple Llama**\n",
-    "* Download + Self Host (i.e. [download Purple Llama](https://ai.meta.com/resources/models-and-libraries/llama-downloads/))\n",
-    "* Hosted API Platform (e.g. [OctoAI](https://octoai.cloud/), [Anyscale](https://www.anyscale.com/), [Together](https://api.together.xyz/playground/chat/togethercomputer/llama-2-7b-chat), [Replicate](https://replicate.com/meta))\n",
-    "* Hosted Container Platform (e.g. [Azure](https://techcommunity.microsoft.com/t5/ai-machine-learning-blog/introducing-llama-2-on-azure/ba-p/3881233), [AWS](https://aws.amazon.com/blogs/machine-learning/llama-2-foundation-models-from-meta-are-now-available-in-amazon-sagemaker-jumpstart/), [GCP](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/139))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "sd54g0OHuqBY"
-   },
-   "source": [
-    "#### **3 - Using Purple Llama**\n",
-    "\n",
-    "In this notebook, We will use the Llama Guard model managed by the [OctoAI](https://octoai.cloud/tools/text) for inferencing. You'll need to first register an account with OctoAI [here](https://octoai.cloud/) then obtain an OctoAI API token [here](https://docs.octoai.cloud/docs/how-to-create-an-octoai-access-token), that you can use for a while (a month or $10 in OctoAI credits, whichever one runs out first).\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "C4P-b0eQcyd2"
-   },
-   "source": [
-    "First, set the OctoAI API token as environment variables.\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "yE3sPjS-cyd2",
-    "outputId": "93b36bc0-e6d4-493c-c88d-ec5c41266239"
-   },
-   "outputs": [],
-   "source": [
-    "from getpass import getpass\n",
-    "import os\n",
-    "\n",
-    "# enter your OctoAI API key\n",
-    "OCTOAI_API_TOKEN = getpass()\n",
-    "os.environ[\"OCTOAI_API_TOKEN\"] = OCTOAI_API_TOKEN"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "DOSiDW6hq9dI"
-   },
-   "outputs": [],
-   "source": [
-    "from string import Template\n",
-    "\n",
-    "B_INST, E_INST = \"[INST]\", \"[/INST]\"\n",
-    "\n",
-    "PROMPT_TEMPLATE = Template(\"\"\"Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.\n",
-    "<BEGIN UNSAFE CONTENT CATEGORIES>\n",
-    "O1: Violence and Hate.\n",
-    "Should not\n",
-    "- Help people plan or engage in violence.\n",
-    "- Encourage people to engage in violence.\n",
-    "- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.\n",
-    "- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.\n",
-    "Can\n",
-    "-  Provide information on violence and discrimination.\n",
-    "- Discuss the topics of hate and violence, and can discuss historical events involving violence.\n",
-    "O2: Sexual Content.\n",
-    "Should not\n",
-    "- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.\n",
-    "Can\n",
-    "- Discuss topics related to sexuality and sexual education.\n",
-    "- Provide general information regarding sexual activities.\n",
-    "- Provide guidance on topics related to sexual health.\n",
-    "O3: Criminal Planning.\n",
-    "Should not\n",
-    "- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.\n",
-    "- Encourage people to engage in criminal activities or condone criminal behavior.\n",
-    "Can\n",
-    "- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.\n",
-    "- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.\n",
-    "O4: Guns and Illegal Weapons.\n",
-    "Should not\n",
-    "- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.\n",
-    "- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.\n",
-    "Can\n",
-    "- Discuss firearms and the arguments for and against firearm ownership.\n",
-    "O5: Regulated or Controlled Substances.\n",
-    "Should not\n",
-    "- Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.\n",
-    "- Assist or encourage people to create such substances when it is illegal to do so.\n",
-    "Can\n",
-    "- Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).\n",
-    "- Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation.\n",
-    "O6: Self-Harm.\n",
-    "Should not\n",
-    "- Encourage people to harm themselves.\n",
-    "- Romanticize or condone self-harm or suicide.\n",
-    "- Provide information regarding the methods of suicide or self-harm.\n",
-    "- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.\n",
-    "Can\n",
-    "- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).\n",
-    "Should\n",
-    "- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.\n",
-    "<END UNSAFE CONTENT CATEGORIES>\n",
-    "​\n",
-    "<BEGIN CONVERSATION>\n",
-    "\n",
-    "$prompt\n",
-    "\n",
-    "<END CONVERSATION>\n",
-    "\n",
-    "Provide your safety assessment for $agent_type in the above conversation:\n",
-    "- First line must read 'safe' or 'unsafe'.\n",
-    "- If unsafe, a second line must include a comma-separated list of violated categories.\"\"\")\n",
-    "\n",
-    "\n",
-    "def format_prompt(role, message):\n",
-    "    prompt = PROMPT_TEMPLATE.substitute(prompt=message, agent_type=role)\n",
-    "    prompt = f\"<s>{B_INST} {prompt.strip()} {E_INST}\"\n",
-    "    return prompt\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "t6hkFlVD9XFw",
-    "outputId": "25fd187e-a484-4b90-d104-a3320b98e8ea"
-   },
-   "outputs": [],
-   "source": [
-    "%pip install openai"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "xC-OYLF4rGc7",
-    "outputId": "452ca85a-fd9f-426c-d4ab-7898bdabfab8"
-   },
-   "outputs": [],
-   "source": [
-    "import openai\n",
-    "\n",
-    "# Feed a prompt that does comply with the safety policy\n",
-    "\n",
-    "client = openai.OpenAI(\n",
-    "   base_url = \"https://text.octoai.run/v1\",\n",
-    "   api_key=OCTOAI_API_TOKEN\n",
-    ")\n",
-    "completion = client.completions.create(\n",
-    "    model=\"llamaguard-7b-fp16\",\n",
-    "    prompt=format_prompt(\"User\",\"Tell me about the history of racism in the US\")\n",
-    ")\n",
-    "print(completion.choices[0].text)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "JtBTkJRxufb0",
-    "outputId": "c581052d-348e-45ac-c874-80ecf4416a77"
-   },
-   "outputs": [],
-   "source": [
-    "# Feed a prompt that doesn't comply with the safety policy\n",
-    "# Violates O3: Criminal Planning.\n",
-    "\n",
-    "client = openai.OpenAI(\n",
-    "   base_url = \"https://text.octoai.run/v1\",\n",
-    "   api_key=OCTOAI_API_TOKEN\n",
-    ")\n",
-    "completion = client.completions.create(\n",
-    "    model=\"llamaguard-7b-fp16\",\n",
-    "    prompt=format_prompt(\"Agent\", \"User: How do you buy a tiger in the US\\n Agent: Steal one from the zoo\")\n",
-    ")\n",
-    "print(completion.choices[0].text)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "#### **Resources <TBD>**\n",
-    "- [Llama 2](https://ai.meta.com/llama/)\n",
-    "- [Getting Started Guide - Llama 2](https://ai.meta.com/llama/get-started/)\n",
-    "- [GitHub - Llama 2](https://github.com/facebookresearch/llama)\n",
-    "- [Github - LLama 2 Recipes](https://github.com/facebookresearch/llama-recipes)\n",
-    "- [Research Paper](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/)\n",
-    "- [Model Card](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)\n",
-    "- [Responsible Use Guide](https://ai.meta.com/llama/responsible-use-guide/)\n",
-    "- [Acceptable Use Policy](https://ai.meta.com/llama/use-policy/)\n",
-    "- [OctoAI](https://octoai.cloud/)\n",
-    "- [LangChain](https://www.langchain.com/)\n",
-    "- [LlamaIndex](https://www.llamaindex.ai/)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "#### **Authors**\n",
-    "1. Hakan Inan, Research Scientist, Meta\n",
-    "2. Rashi Rungta, Software Engineer, Meta\n",
-    "\n",
-    "Ported to use OctoAI LlamaGuard endpoints by Thierry Moreau, OctoAI"
-   ]
-  }
- ],
- "metadata": {
-  "colab": {
-   "gpuType": "T4",
-   "include_colab_link": true,
-   "provenance": [],
-   "toc_visible": true
-  },
-  "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.11.6"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}

+ 1 - 3
recipes/responsible_ai/README.md

@@ -4,11 +4,9 @@ The [Purple Llama](https://github.com/meta-llama/PurpleLlama/) project provides
 
 | Tool/Model | Description | Get Started
 |---|---|---|
-[Llama Guard](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama-guard-3) | Provide guardrailing on inputs and outputs | [Inference](./llama_guard/inference.py), [Finetuning](./llama_guard/llama_guard_customization_via_prompting_and_fine_tuning.ipynb)
+[Llama Guard](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama-guard-3) | Provide guardrailing on inputs and outputs | [Inference](./llama_guard/llama_guard_text_and_vision_inference.ipynb), [Finetuning](./llama_guard/llama_guard_customization_via_prompting_and_fine_tuning.ipynb)
 [Prompt Guard](https://llama.meta.com/docs/model-cards-and-prompt-formats/prompt-guard) | Model that safeguards against jailbreak attempts and embedded prompt injections | [Notebook](./prompt_guard/prompt_guard_tutorial.ipynb)
 [Code Shield](https://github.com/meta-llama/PurpleLlama/tree/main/CodeShield) | Tool to safeguard against insecure code generated by the LLM | [Notebook](https://github.com/meta-llama/PurpleLlama/blob/main/CodeShield/notebook/CodeShieldUsageDemo.ipynb)
 
 
 
-### Running on hosted APIs
-The notebooks [input_output_guardrails.ipynb](./input_output_guardrails_with_llama.ipynb),  [Purple_Llama_Anyscale](Purple_Llama_Anyscale.ipynb) & [Purple_Llama_OctoAI](Purple_Llama_OctoAI.ipynb) contain examples for running Meta Llama Guard on cloud hosted endpoints.

+ 0 - 268
recipes/responsible_ai/input_output_guardrails_with_llama.ipynb

@@ -1,268 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "Tq-YFYlqWL_n"
-   },
-   "source": [
-    "# About this notebook\n",
-    "In this notebook we will deep dive into how to do input and output safety checks using Meta's Llama Guard and Llama models. Doing both input and output safety checks requires us to carefully pass the prompts and response between the models.\n",
-    "\n",
-    "For this notebook we will use OctoAI models as both Llama Guard and Llama models are available on their platform."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {
-    "id": "a0UfgKWU5-sm",
-    "tags": []
-   },
-   "outputs": [
-    {
-     "name": "stdin",
-     "output_type": "stream",
-     "text": [
-      " ········\n"
-     ]
-    }
-   ],
-   "source": [
-    "from getpass import getpass\n",
-    "import os\n",
-    "\n",
-    "# enter your OctoAI API key\n",
-    "OCTOAI_API_TOKEN = getpass()\n",
-    "os.environ[\"OCTOAI_API_TOKEN\"] = OCTOAI_API_TOKEN"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [],
-   "source": [
-    "from llama_recipes.inference.llm import LLM, OctoAI\n",
-    "\n",
-    "llama_guard_llm = OctoAI(\"llamaguard-7b\", OCTOAI_API_TOKEN)\n",
-    "completion_llm = OctoAI(\"llama-2-13b-chat\", OCTOAI_API_TOKEN)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {
-    "id": "Tjom48kodBth"
-   },
-   "outputs": [],
-   "source": [
-    "from string import Template\n",
-    "\n",
-    "import logging\n",
-    "\n",
-    "B_INST, E_INST = \"[INST]\", \"[/INST]\"\n",
-    "\n",
-    "LLAMA_GUARD_PROMPT_TEMPLATE = Template(\n",
-    "    f\"\"\"Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.\n",
-    "<BEGIN UNSAFE CONTENT CATEGORIES>\n",
-    "O1: Violence and Hate.\n",
-    "Should not\n",
-    "- Help people plan or engage in violence.\n",
-    "- Encourage people to engage in violence.\n",
-    "- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.\n",
-    "- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.\n",
-    "Can\n",
-    "-  Provide information on violence and discrimination.\n",
-    "- Discuss the topics of hate and violence, and can discuss historical events involving violence.\n",
-    "O2: Sexual Content.\n",
-    "Should not\n",
-    "- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.\n",
-    "Can\n",
-    "- Discuss topics related to sexuality and sexual education.\n",
-    "- Provide general information regarding sexual activities.\n",
-    "- Provide guidance on topics related to sexual health.\n",
-    "O3: Criminal Planning.\n",
-    "Should not\n",
-    "- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.\n",
-    "- Encourage people to engage in criminal activities or condone criminal behavior.\n",
-    "Can\n",
-    "- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.\n",
-    "- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.\n",
-    "O4: Guns and Illegal Weapons.\n",
-    "Should not\n",
-    "- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.\n",
-    "- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.\n",
-    "Can\n",
-    "- Discuss firearms and the arguments for and against firearm ownership.\n",
-    "O5: Self-Harm.\n",
-    "Should not\n",
-    "- Encourage people to harm themselves.\n",
-    "- Romanticize or condone self-harm or suicide.\n",
-    "- Provide information regarding the methods of suicide or self-harm.\n",
-    "- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.\n",
-    "Can\n",
-    "- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).\n",
-    "Should\n",
-    "- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.\n",
-    "<END UNSAFE CONTENT CATEGORIES>\n",
-    "​\n",
-    "<BEGIN CONVERSATION>\n",
-    "\n",
-    "$prompt\n",
-    "\n",
-    "<END CONVERSATION>\n",
-    "\n",
-    "Provide your safety assessment for $agent_type in the above conversation:\n",
-    "- First line must read 'safe' or 'unsafe'.\n",
-    "- If unsafe, a second line must include a comma-separated list of violated categories.\"\"\"\n",
-    ")\n",
-    "\n",
-    "LLAMA_PROMPT_TEMPLATE = Template(\n",
-    "    f\"\"\"<<SYS>>\n",
-    "You are a helpful agent, answering questions of the user in a short and consise manner.\n",
-    "<</SYS>>\n",
-    "\n",
-    "{{ $user_message }}\"\"\"\n",
-    ")\n",
-    "\n",
-    "def llama_format_prompt(message):\n",
-    "    prompt = LLAMA_PROMPT_TEMPLATE.substitute(user_message=message)\n",
-    "    prompt = f\"<s>{B_INST} {prompt.strip()} {E_INST}\"\n",
-    "    return prompt\n",
-    "\n",
-    "def llama_guard_format_prompt(role, message):\n",
-    "    prompt = LLAMA_GUARD_PROMPT_TEMPLATE.substitute(prompt=message, agent_type=role)\n",
-    "    prompt = f\"<s>{B_INST} {prompt.strip()} {E_INST}\"\n",
-    "    return prompt\n",
-    "\n",
-    "def get_category_description(category):\n",
-    "    CATEGORY_MAP = {\n",
-    "      '01': 'Violence and Hate',\n",
-    "      '02': 'Sexual Content',\n",
-    "      '03': 'Criminal Planning',\n",
-    "      '04': 'Guns and Illegal Weapons',\n",
-    "      '05': 'Self-Harm'\n",
-    "    }\n",
-    "\n",
-    "    if not isinstance(category, str) or category is None:\n",
-    "        raise ValueError(\"Invalid input, expected a non-empty string\")\n",
-    "\n",
-    "    if category not in CATEGORY_MAP:\n",
-    "        raise ValueError(f\"Invalid input, unknown category: {category}\")\n",
-    "\n",
-    "    return CATEGORY_MAP[category]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {
-    "id": "KwONCw1oSroO"
-   },
-   "outputs": [],
-   "source": [
-    "def parse_llamaguard_response(message):\n",
-    "    if not isinstance(message, str) or message is None:\n",
-    "        return (False, \"Invalid input, expected a non-empty string\")\n",
-    "    message = message.strip()\n",
-    "\n",
-    "    if message == \"\":\n",
-    "        return (False, \"Invalid input, message cannot be empty\")\n",
-    "\n",
-    "    tokens = message.split(\"\\n\")\n",
-    "    if tokens[0] == 'safe':\n",
-    "        return (True, \"\")\n",
-    "    else:\n",
-    "        return (False, tokens[1])\n",
-    "\n",
-    "def check_input_guardrail(user_prompt):\n",
-    "    guardrail_prompt = llama_guard_format_prompt(\"User\", user_prompt)\n",
-    "    response = llama_guard_llm.query(guardrail_prompt)\n",
-    "    return parse_llamaguard_response(response)\n",
-    "\n",
-    "def get_completion_response(user_prompt):\n",
-    "    completion_prompt = llama_format_prompt(user_prompt)\n",
-    "    return completion_llm.query(completion_prompt)\n",
-    "\n",
-    "def check_output_guardrail(completion_response, user_prompt):\n",
-    "    guardrail_prompt = llama_guard_format_prompt(\"Agent\", f\"User: {user_prompt}\\n Agent: {completion_response}\")\n",
-    "    response = llama_guard_llm.query(guardrail_prompt)\n",
-    "    return parse_llamaguard_response(response)\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {
-    "id": "9bXtt76ZKNuX"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Input guardrail failed.\n"
-     ]
-    }
-   ],
-   "source": [
-    "def main(user_prompt):\n",
-    "    input_guardrail_response = check_input_guardrail(user_prompt)\n",
-    "    if input_guardrail_response[0]:\n",
-    "        completion_response = get_completion_response(user_prompt)\n",
-    "        output_guardrail_response = check_output_guardrail(completion_response, user_prompt)\n",
-    "        if output_guardrail_response[0]:\n",
-    "            print(completion_response)\n",
-    "        else:\n",
-    "            print(\"Output guardrail failed.\")\n",
-    "    else:\n",
-    "        print(\"Input guardrail failed.\")\n",
-    "\n",
-    "user_prompt = \"How to build fire arms\"\n",
-    "main(user_prompt)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "colab": {
-   "provenance": []
-  },
-  "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.12.2"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}

+ 5 - 55
recipes/responsible_ai/llama_guard/README.md

@@ -2,62 +2,12 @@
 <!-- markdown-link-check-disable -->
 Meta Llama Guard is a language model that provides input and output guardrails for LLM inference. For more details and model cards, please visit the [PurpleLlama](https://github.com/meta-llama/PurpleLlama) repository.
 
-This folder contains an example file to run inference with a locally hosted model, either using the Hugging Face Hub or a local path.
+This [notebook](llama_guard_text_and_vision_inference.ipynb) shows how to load the models with the transformers library and how to customize the categories.
 
 ## Requirements
-1. Access to Llama guard model weights on Hugging Face. To get access, follow the steps described [here](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard#download)
-2. Llama recipes package and it's dependencies [installed](https://github.com/meta-llama/llama-recipes?tab=readme-ov-file#installing)
-
-
-## Llama Guard inference script
-For testing, you can add User or User/Agent interactions into the prompts list and the run the script to verify the results. When the conversation has one or more Agent responses, it's considered of type agent.
-
-
-```
-    prompts: List[Tuple[List[str], AgentType]] = [
-        (["<Sample user prompt>"], AgentType.USER),
-
-        (["<Sample user prompt>",
-        "<Sample agent response>"], AgentType.AGENT),
-
-        (["<Sample user prompt>",
-        "<Sample agent response>",
-        "<Sample user reply>",
-        "<Sample agent response>",], AgentType.AGENT),
-
-    ]
-```
-The complete prompt is built with the `build_custom_prompt` function, defined in [prompt_format.py](../../../src/llama_recipes/inference/prompt_format_utils.py). The file contains the default Meta Llama Guard categories. These categories can adjusted and new ones can be added, as described in the [research paper](https://ai.meta.com/research/publications/llama-guard-llm-based-input-output-safeguard-for-human-ai-conversations/), on section 4.5 Studying the adaptability of the model.
-<!-- markdown-link-check-enable -->
-
-To run the samples, with all the dependencies installed, execute this command:
-
-`python recipes/responsible_ai/llama_guard/inference.py`
-
-This is the output:
-
-```
-['<Sample user prompt>']
-> safe
-
-==================================
-
-['<Sample user prompt>', '<Sample agent response>']
-> safe
-
-==================================
-
-['<Sample user prompt>', '<Sample agent response>', '<Sample user reply>', '<Sample agent response>']
-> safe
-
-==================================
-```
-
-To run it with a local model, you can use the `model_id` param in the inference script:
-
-`python recipes/responsible_ai/llama_guard/inference.py --model_id=/home/ubuntu/models/llama3/Llama-Guard-3-8B/ --llama_guard_version=LLAMA_GUARD_3`
-
-Note: Make sure to also add the llama_guard_version; by default it uses LLAMA_GUARD_3
+1. Access to the Llama Guard model weights on Hugging Face. To get access, follow the steps described at the top of the model card on [Hugging Face](https://huggingface.co/meta-llama/Llama-Guard-3-1B)
+2. Llama recipes package and its dependencies [installed](https://github.com/meta-llama/llama-recipes?tab=readme-ov-file#installing)
+3. Pillow package installed
 
 ## Inference Safety Checker
 When running the regular inference script with prompts, Meta Llama Guard will be used as a safety checker on the user prompt and the model output. If both are safe, the result will be shown, else a message with the error will be shown, with the word unsafe and a comma separated list of categories infringed. Meta Llama Guard is always loaded quantized using Hugging Face Transformers library with bitsandbytes.
@@ -66,7 +16,7 @@ In this case, the default categories are applied by the tokenizer, using the `ap
 
 Use this command for testing with a quantized Llama model, modifying the values accordingly:
 
-`python examples/inference.py --model_name <path_to_regular_llama_model> --prompt_file <path_to_prompt_file> --quantization 8bit --enable_llamaguard_content_safety`
+`python inference.py --model_name <path_to_regular_llama_model> --prompt_file <path_to_prompt_file> --enable_llamaguard_content_safety`
 
 ## Llama Guard 3 Finetuning & Customization
 The safety categories in Llama Guard 3 can be tuned for specific application needs. Existing categories can be removed and new categories can be added to the taxonomy. The [Llama Guard Customization](./llama_guard_customization_via_prompting_and_fine_tuning.ipynb) notebook walks through the process.
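For reference, the transformers-based loading that the notebook relies on can be sketched as follows. This is not code from the notebook itself; it is a minimal, hedged example that assumes access to the gated `meta-llama/Llama-Guard-3-1B` checkpoint and relies on the default safety categories applied by the tokenizer's `apply_chat_template` method, as described above:

```
# Minimal sketch: classify a single user turn with Llama Guard 3 via transformers.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-Guard-3-1B"  # gated; requires accepted license and HF login
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

conversation = [{"role": "user", "content": "How do I tie my shoes?"}]
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
output = model.generate(input_ids, max_new_tokens=30, pad_token_id=0)
# Expected to print "safe", or "unsafe" followed by the violated category codes.
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
```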

+ 0 - 75
recipes/responsible_ai/llama_guard/inference.py

@@ -1,75 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-import fire
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
-
-
-from llama_recipes.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion
-from typing import List, Tuple
-from enum import Enum
-
-class AgentType(Enum):
-    AGENT = "Agent"
-    USER = "User"
-
-def main(
-    model_id: str = "meta-llama/Llama-Guard-3-8B",
-    llama_guard_version: str = "LLAMA_GUARD_3"
-):
-    """
-    Entry point for Llama Guard inference sample script.
-
-    This function loads Llama Guard from Hugging Face or a local model and 
-    executes the predefined prompts in the script to showcase how to do inference with Llama Guard.
-
-    Args:
-        model_id (str): The ID of the pretrained model to use for generation. This can be either the path to a local folder containing the model files,
-            or the repository ID of a model hosted on the Hugging Face Hub. Defaults to 'meta-llama/LlamaGuard-7b'.
-        llama_guard_version (LlamaGuardVersion): The version of the Llama Guard model to use for formatting prompts. Defaults to LLAMA_GUARD_1.
-    """
-    try:
-        llama_guard_version = LlamaGuardVersion[llama_guard_version]
-    except KeyError as e:
-        raise ValueError(f"Invalid Llama Guard version '{llama_guard_version}'. Valid values are: {', '.join([lgv.name for lgv in LlamaGuardVersion])}") from e
-
-    prompts: List[Tuple[List[str], AgentType]] = [
-        (["<Sample user prompt>"], AgentType.USER),
-
-        (["<Sample user prompt>",
-        "<Sample agent response>"], AgentType.AGENT),
-        
-        (["<Sample user prompt>",
-        "<Sample agent response>",
-        "<Sample user reply>",
-        "<Sample agent response>",], AgentType.AGENT),
-
-    ]
-
-    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")
-    
-    for prompt in prompts:
-        formatted_prompt = build_default_prompt(
-                prompt[1], 
-                create_conversation(prompt[0]),
-                llama_guard_version)
-
-
-        input = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
-        prompt_len = input["input_ids"].shape[-1]
-        output = model.generate(**input, max_new_tokens=100, pad_token_id=0)
-        results = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
-       
-        
-        print(prompt[0])
-        print(f"> {results}")
-        print("\n==================================\n")
-
-if __name__ == "__main__":
-    try:
-        fire.Fire(main)
-    except Exception as e:
-        print(e)

File diff suppressed because it is too large
+ 576 - 0
recipes/responsible_ai/llama_guard/llama_guard_text_and_vision_inference.ipynb


BIN
recipes/responsible_ai/llama_guard/resources/dog.jpg


BIN
recipes/responsible_ai/llama_guard/resources/pasta.jpeg


+ 5 - 3
src/llama_recipes/datasets/__init__.py

@@ -5,14 +5,16 @@ from functools import partial
 
 from llama_recipes.datasets.grammar_dataset.grammar_dataset import get_dataset as get_grammar_dataset
 from llama_recipes.datasets.alpaca_dataset import InstructionDataset as get_alpaca_dataset
-from llama_recipes.datasets.custom_dataset import get_custom_dataset
+from llama_recipes.datasets.custom_dataset import get_custom_dataset,get_data_collator
 from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset
 from llama_recipes.datasets.toxicchat_dataset import get_llamaguard_toxicchat_dataset as get_llamaguard_toxicchat_dataset
-
 DATASET_PREPROC = {
     "alpaca_dataset": partial(get_alpaca_dataset),
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
     "custom_dataset": get_custom_dataset,
     "llamaguard_toxicchat_dataset": get_llamaguard_toxicchat_dataset,
-}
+}
+DATALOADER_COLLATE_FUNC = {
+    "custom_dataset": get_data_collator
+}

+ 20 - 0
src/llama_recipes/datasets/custom_dataset.py

@@ -35,3 +35,23 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
         print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
         raise e
 
+def get_data_collator(dataset_processer,dataset_config):
+    if ":" in dataset_config.file:
+        module_path, func_name = dataset_config.file.split(":")
+    else:
+        module_path, func_name = dataset_config.file, "get_data_collator"
+
+    if not module_path.endswith(".py"):
+        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+
+    module_path = Path(module_path)
+    if not module_path.is_file():
+        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+
+    module = load_module_from_py_file(module_path.as_posix())
+    try:
+        return getattr(module, func_name)(dataset_processer)
+    except AttributeError as e:
+        print(f"Cannot find a custom data_collator in the dataset .py file ({module_path.as_posix()}).")
+        print("Using the default data_collator instead.")
+        return None
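In practice this means the dataset file referenced by the `custom_dataset` config may optionally expose a `get_data_collator(processor)` function alongside the required `get_custom_dataset`; when it is missing, the lookup above falls back to the default collator. A hypothetical skeleton of such a file is sketched below (the file name and dataset logic are illustrative only, not part of this commit):

```
# my_dataset.py -- hypothetical custom dataset file for llama-recipes.
from transformers.data import DataCollatorForSeq2Seq

def get_custom_dataset(dataset_config, tokenizer, split):
    # Build and return the (tokenized) dataset for the requested split.
    raise NotImplementedError("load / preprocess your data here")

def get_data_collator(processor):
    # Optional hook: return any callable(list_of_samples) -> batched model inputs.
    # Here the stock seq2seq collator from transformers is reused purely as an example.
    return DataCollatorForSeq2Seq(processor)
```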

+ 61 - 22
src/llama_recipes/finetuning.py

@@ -14,16 +14,18 @@ from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
 )
-
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
+    AutoConfig,
     AutoTokenizer,
     BitsAndBytesConfig,
-    LlamaForCausalLM,
-    LlamaConfig,
+    AutoProcessor, 
+    MllamaForConditionalGeneration,
+    AutoModelForCausalLM,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mllama.modeling_mllama import  MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer
 
 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
@@ -39,7 +41,7 @@ from llama_recipes.utils.config_utils import (
     get_dataloader_kwargs,
     check_fsdp_config,
 )
-from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
+from llama_recipes.utils.dataset_utils import get_preprocessed_dataset,get_custom_data_collator
 
 from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
@@ -118,19 +120,35 @@ def main(**kwargs):
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
-    model = LlamaForCausalLM.from_pretrained(
+    config = AutoConfig.from_pretrained(train_config.model_name)
+    if config.model_type == "mllama":
+        is_vision = True
+        model = MllamaForConditionalGeneration.from_pretrained(
         train_config.model_name,
         quantization_config=bnb_config,
-        use_cache=use_cache,
         attn_implementation="sdpa" if train_config.use_fast_kernels else None,
         device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
         torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
     )
-
+        processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
+        processor.tokenizer.padding_side='right'
+    elif config.model_type == "llama":
+        is_vision = False
+        model = AutoModelForCausalLM.from_pretrained(
+            train_config.model_name,
+            quantization_config=bnb_config,
+            use_cache=use_cache,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
+            device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
+            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+        )
+    else:
+        raise ValueError(f"Model type {config.model_type} is not supported. Please use llama or mllama model.")
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
-    tokenizer.pad_token_id = tokenizer.eos_token_id
-
+    if not tokenizer.pad_token_id: 
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
@@ -169,8 +187,12 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-
+        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer and MllamaVisionEncoderLayer in vision models
+        if is_vision:
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer])
+        else:
+            # Create the FSDP wrapper for LlamaDecoderLayer in text models
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()
@@ -198,12 +220,16 @@ def main(**kwargs):
             model.to("xpu:0")
         elif torch.cuda.is_available():
             model.to("cuda")
-
     dataset_config = generate_dataset_config(train_config, kwargs)
+    if is_vision:
+        dataset_processer = processor
+    else:
+        dataset_processer = tokenizer
+
+    # Load and preprocess the dataset for training and validation
 
-     # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
-        tokenizer,
+        dataset_processer,
         dataset_config,
         split="train",
     )
@@ -211,7 +237,7 @@ def main(**kwargs):
         print(f"--> Training Set Length = {len(dataset_train)}")
 
     dataset_val = get_preprocessed_dataset(
-        tokenizer,
+        dataset_processer,
         dataset_config,
         split="test",
     )
@@ -219,10 +245,17 @@ def main(**kwargs):
         print(f"--> Validation Set Length = {len(dataset_val)}")
 
     if train_config.batching_strategy == "packing":
-        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
-
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
-
+        if is_vision:
+            raise ValueError("Packing is not supported for vision datasets")
+        else:
+            dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
+    print("length of dataset_train", len(dataset_train))
+    custom_data_collator = get_custom_data_collator(dataset_processer,dataset_config)
+    if custom_data_collator:
+        print("custom_data_collator is used")
+        train_dl_kwargs["collate_fn"] = custom_data_collator
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -230,13 +263,19 @@ def main(**kwargs):
         pin_memory=True,
         **train_dl_kwargs,
     )
+    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
 
     eval_dataloader = None
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
-            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+            if is_vision:
+                raise ValueError("Packing is not supported for vision datasets")
+            else:
+                dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
 
-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
+        if custom_data_collator:
+            val_dl_kwargs["collate_fn"] = custom_data_collator
 
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
@@ -244,6 +283,7 @@ def main(**kwargs):
             pin_memory=True,
             **val_dl_kwargs,
         )
+        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
         if len(eval_dataloader) == 0:
             raise ValueError("The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set.")
         else:
@@ -266,7 +306,6 @@ def main(**kwargs):
             weight_decay=train_config.weight_decay,
         )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
-    # Start the training process
     results = train(
         model,
         train_dataloader,
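The vision/text dispatch added above hinges on the `model_type` field of the Hugging Face config. Below is a small, hedged sketch of that check in isolation; the model names are placeholders, both repos are gated and require an authenticated Hugging Face login, and only the config (not the weights) is downloaded:

```
# Hedged sketch: the same AutoConfig-based model_type check used in finetuning.py.
from transformers import AutoConfig

for name in [
    "meta-llama/Llama-3.2-11B-Vision-Instruct",  # mllama -> vision fine-tuning path
    "meta-llama/Llama-3.1-8B-Instruct",          # llama  -> text fine-tuning path
]:
    config = AutoConfig.from_pretrained(name)
    is_vision = config.model_type == "mllama"
    print(name, "->", config.model_type, "(vision)" if is_vision else "(text)")
```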

+ 3 - 3
src/llama_recipes/policies/wrapping.py

@@ -4,6 +4,8 @@
 import functools
 
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mllama.modeling_mllama import   MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer
+
 from torch.distributed.fsdp.wrap import (
     transformer_auto_wrap_policy,
     size_based_auto_wrap_policy,
@@ -25,9 +27,7 @@ def get_llama_wrapper():
 
     llama_auto_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
-        transformer_layer_cls={
-            LlamaDecoderLayer,
-        },
+        transformer_layer_cls=set([LlamaDecoderLayer, MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer,MllamaCrossAttentionDecoderLayer])
     )
 
     return llama_auto_wrap_policy

+ 25 - 27
src/llama_recipes/utils/config_utils.py

@@ -17,8 +17,7 @@ from transformers.data import DataCollatorForSeq2Seq
 
 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
 from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
-from llama_recipes.utils.dataset_utils import DATASET_PREPROC
-
+from llama_recipes.datasets import DATASET_PREPROC
 
 def update_config(config, **kwargs):
     if isinstance(config, (tuple, list)):
@@ -76,37 +75,36 @@ def generate_dataset_config(train_config, kwargs):
     return  dataset_config
 
 
-def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
-        kwargs = {}
-        batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
-        if train_config.batching_strategy == "padding":
-            if train_config.enable_fsdp:
-                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
-                    dataset,
-                    batch_size=batch_size,
-                    rank=dist.get_rank(),
-                    num_replicas=dist.get_world_size(),
-                    shuffle=mode=="train",
-                )
-            else:
-                kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
-            kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
-        elif train_config.batching_strategy == "packing":
-            if train_config.enable_fsdp:
-                kwargs["sampler"] = DistributedSampler(
+def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
+    kwargs = {}
+    batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
+    if train_config.batching_strategy == "padding":
+        if train_config.enable_fsdp:
+            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
                 dataset,
+                batch_size=batch_size,
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
                 shuffle=mode=="train",
-                drop_last=True,
             )
-            kwargs["batch_size"] = batch_size
-            kwargs["drop_last"] = True
-            kwargs["collate_fn"] = default_data_collator
         else:
-            raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
-
-        return kwargs
+            kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
+        kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
+    elif train_config.batching_strategy == "packing":
+        if train_config.enable_fsdp:
+            kwargs["sampler"] = DistributedSampler(
+            dataset,
+            rank=dist.get_rank(),
+            num_replicas=dist.get_world_size(),
+            shuffle=mode=="train",
+            drop_last=True,
+        )
+        kwargs["batch_size"] = batch_size
+        kwargs["drop_last"] = True
+        kwargs["collate_fn"] = default_data_collator
+    else:
+        raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+    return kwargs
 
 
 def check_fsdp_config(fsdp_config):

+ 11 - 1
src/llama_recipes/utils/dataset_utils.py

@@ -4,7 +4,7 @@
 import torch
 
 from llama_recipes.data.concatenator import ConcatDataset
-from llama_recipes.datasets import DATASET_PREPROC, get_custom_dataset
+from llama_recipes.datasets import DATASET_PREPROC, DATALOADER_COLLATE_FUNC
 from llama_recipes.utils.config_utils import get_dataloader_kwargs
 
 
@@ -27,6 +27,16 @@ def get_preprocessed_dataset(
         get_split(),
     )
 
+def get_custom_data_collator(
+    dataset_processer, dataset_config
+) -> torch.utils.data.Dataset:
+    if not dataset_config.dataset in DATALOADER_COLLATE_FUNC:
+        return None
+
+    return DATALOADER_COLLATE_FUNC[dataset_config.dataset](
+        dataset_processer,
+        dataset_config
+    )
 
 def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
     dataset = get_preprocessed_dataset(tokenizer, dataset_config, split)

+ 2 - 4
src/llama_recipes/utils/fsdp_utils.py

@@ -3,7 +3,7 @@
 from torch.distributed._tensor.device_mesh import init_device_mesh
 import os 
 
-def fsdp_auto_wrap_policy(model, transformer_layer_name):
+def fsdp_auto_wrap_policy(model, transformer_layer_names):
     import functools
 
     from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
@@ -20,9 +20,7 @@ def fsdp_auto_wrap_policy(model, transformer_layer_name):
     lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
     transformer_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
-        transformer_layer_cls=(
-            transformer_layer_name,
-        ),
+        transformer_layer_cls=set(transformer_layer_names)
     )
 
     auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])

+ 3 - 2
src/llama_recipes/utils/train_utils.py

@@ -118,6 +118,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     max_steps_reached = False  # Flag to indicate max training steps reached
     # Start the training loop
     for epoch in range(train_config.num_epochs):
+        print(f"Starting epoch {epoch}/{train_config.num_epochs}")
+        print(f"train_config.max_train_step: {train_config.max_train_step}")
         # stop when the maximum number of training steps is reached
         if max_steps_reached:
             break
@@ -143,10 +145,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             else:
                                 batch[key] = batch[key].to(local_rank)
                         else:
-
                             if is_xpu_available():
                                 batch[key] = batch[key].to('xpu:0')
-                            else:
+                            elif torch.cuda.is_available():
                                 batch[key] = batch[key].to('cuda:0')
                     with autocast():
                         loss = model(**batch).loss