Sanyam Bhutani 5 months ago
parent
commit
76762685c0

+ 1 - 0
.github/scripts/spellcheck_conf/wordlist.txt

@@ -1505,3 +1505,4 @@ locallama
 myshell
 parler
 xTTS
+pydantic

+ 13 - 0
recipes/3p_integrations/togetherai/README.md

@@ -0,0 +1,13 @@
+# Building LLM apps using Llama on Together.ai
+
+This folder contains demos on how to use Llama on [Together.ai](https://www.together.ai/) to quickly build LLM apps.
+
+The first demo is a notebook that converts a PDF into a podcast using Llama 3.1 70B or 8B hosted on Together.ai. It differs from and complements [Meta's implementation](https://github.com/meta-llama/llama-recipes/tree/main/recipes/quickstart/NotebookLlama) in several ways:
+
+1. You don't need to download the Llama models from Hugging Face or have a GPU to run the notebook - you can quickly get a free Together API key and run the whole Colab notebook in a browser, in about 10 minutes;
+2. A single system prompt is used to generate a natural-sounding podcast script from the PDF, with support for pydantic models, a scratchpad, and a JSON response format to keep the whole flow simple yet powerful (a minimal sketch of this pattern is shown below);
+3. A different TTS service, also with an easy-to-get free API key, is used.
+
+The whole Colab notebook can run with a single "Runtime - Run all" click, generating the podcast audio from the Transformer paper that started the GenAI revolution. 
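+
+At the core of the notebook is the combination of a pydantic schema with Together's JSON mode, so the generated podcast script comes back as structured data. Here is a minimal sketch of that pattern, condensed from the notebook (the system prompt and input text below are placeholders - see the notebook for the full working version):
+
+```python
+from typing import List, Literal
+from pydantic import BaseModel
+from together import Together
+
+class DialogueItem(BaseModel):
+    speaker: Literal["Host (Jane)", "Guest"]
+    text: str
+
+class Dialogue(BaseModel):
+    scratchpad: str  # lets the model plan the episode before writing the lines
+    name_of_guest: str
+    dialogue: List[DialogueItem]
+
+client = Together(api_key="your_together_api_key")
+response = client.chat.completions.create(
+    model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+    messages=[
+        {"role": "system", "content": "You are a world-class podcast producer..."},  # full prompt in the notebook
+        {"role": "user", "content": "<text extracted from the PDF>"},
+    ],
+    response_format={"type": "json_object", "schema": Dialogue.model_json_schema()},
+)
+script = Dialogue.model_validate_json(response.choices[0].message.content)
+```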
+
+ 

+ 508 - 0
recipes/3p_integrations/togetherai/pdf_to_podcast_using_llama_on_together.ipynb

@@ -0,0 +1,508 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/meta-llama/llama-recipes/blob/main/recipes/3p_integrations/togetherai/pdf_to_podcast_using_llama_on_together.ipynb)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "1FXUu7Ydf2p3"
+   },
+   "source": [
+    "# A Quick Implementation of PDF to Podcast Using Llama 3.1 on Together.ai"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "_cuH3nHpkZal"
+   },
+   "source": [
+    "### Introduction\n",
+    "\n",
+    "In this notebook we will see how to easily create a podcast from PDF using [Llama 3.1 70b (or 8b) hosted on Together.ai](https://api.together.ai/models/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo).\n",
+    "\n",
+    "**The quickest way to try the whole notebook is to open the Colab link above, then select Runtime - Run all.**"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "yA6mSWAcf2p6"
+   },
+   "source": [
+    "Inspired by [Notebook LM's](https://notebooklm.google/) podcast generation feature and a recent open source implementation of [Open Notebook LM](https://github.com/gabrielchua/open-notebooklm). In this cookbook we will implement a walkthrough of how you can build a PDF to podcast pipeline.\n",
+    "\n",
+    "Given any PDF we will generate a conversation between a host and a guest discussing and explaining the contents of the PDF.\n",
+    "\n",
+    "In doing so we will learn the following:\n",
+    "1. How we can use JSON mode and structured generation with open models like Llama 3 70b to extract a script for the Podcast given text from the PDF.\n",
+    "2. How we can use TTS models to bring this script to life as a conversation.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "cN0Tpr76ssM1",
+    "outputId": "4a2e1eed-4ce6-4bff-c6f0-c59f4730725b"
+   },
+   "outputs": [],
+   "source": [
+    "!apt install -q libasound2-dev portaudio19-dev libportaudio2 libportaudiocpp0 ffmpeg\n",
+    "!pip install -q ffmpeg-python\n",
+    "!pip install -q PyAudio\n",
+    "!pip install -q pypdf #to read PDF content\n",
+    "!pip install -q together #to access open source LLMs\n",
+    "!pip install -q cartesia #to access TTS model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "iWea6go4r72c"
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "# Standard library imports\n",
+    "from pathlib import Path\n",
+    "from tempfile import NamedTemporaryFile\n",
+    "from typing import List, Literal, Tuple, Optional\n",
+    "\n",
+    "# Third-party imports\n",
+    "from pydantic import BaseModel\n",
+    "from pypdf import PdfReader\n",
+    "\n",
+    "from together import Together\n",
+    "from cartesia import Cartesia\n",
+    "from pydantic import ValidationError"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can easily get free trial API keys at [Together.ai](https://api.together.ai/settings/api-keys) and [cartesia.ai](https://play.cartesia.ai/keys). After that, replace the keys below."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "7GYTmdx_s6QL"
+   },
+   "outputs": [],
+   "source": [
+    "client_together = Together(api_key=\"xxx\")\n",
+    "client_cartesia = Cartesia(api_key=\"yyy\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "LGWv-oZ2f2p8"
+   },
+   "source": [
+    "### Define Dialogue Schema with Pydantic\n",
+    "\n",
+    "We need a way of telling the LLM what the structure of the podcast script between the guest and host will look like. We will do this using `pydantic` models.\n",
+    "\n",
+    "Below we define the required classes.\n",
+    "\n",
+    "- The overall conversation consists of lines said by either the host or the guest. The `DialogueItem` class specifies the structure of these lines.\n",
+    "- The full script is a combination of multiple lines performed by the speakers, here we also include a scratchpad field to allow the LLM to ideate and brainstorm the overall flow of the script prior to actually generating the lines. The `Dialogue` class specifies this."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "zYOq3bdntLgl"
+   },
+   "outputs": [],
+   "source": [
+    "class DialogueItem(BaseModel):\n",
+    "    \"\"\"A single dialogue item.\"\"\"\n",
+    "\n",
+    "    speaker: Literal[\"Host (Jane)\", \"Guest\"]\n",
+    "    text: str\n",
+    "\n",
+    "\n",
+    "class Dialogue(BaseModel):\n",
+    "    \"\"\"The dialogue between the host and guest.\"\"\"\n",
+    "\n",
+    "    scratchpad: str\n",
+    "    name_of_guest: str\n",
+    "    dialogue: List[DialogueItem]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "6ZzYFsNXuDN0"
+   },
+   "outputs": [],
+   "source": [
+    "# Adapted and modified from https://github.com/gabrielchua/open-notebooklm\n",
+    "SYSTEM_PROMPT = \"\"\"\n",
+    "You are a world-class podcast producer tasked with transforming the provided input text into an engaging and informative podcast script. The input may be unstructured or messy, sourced from PDFs or web pages. Your goal is to extract the most interesting and insightful content for a compelling podcast discussion.\n",
+    "\n",
+    "# Steps to Follow:\n",
+    "\n",
+    "1. **Analyze the Input:**\n",
+    "   Carefully examine the text, identifying key topics, points, and interesting facts or anecdotes that could drive an engaging podcast conversation. Disregard irrelevant information or formatting issues.\n",
+    "\n",
+    "2. **Brainstorm Ideas:**\n",
+    "   In the `<scratchpad>`, creatively brainstorm ways to present the key points engagingly. Consider:\n",
+    "   - Analogies, storytelling techniques, or hypothetical scenarios to make content relatable\n",
+    "   - Ways to make complex topics accessible to a general audience\n",
+    "   - Thought-provoking questions to explore during the podcast\n",
+    "   - Creative approaches to fill any gaps in the information\n",
+    "\n",
+    "3. **Craft the Dialogue:**\n",
+    "   Develop a natural, conversational flow between the host (Jane) and the guest speaker (the author or an expert on the topic). Incorporate:\n",
+    "   - The best ideas from your brainstorming session\n",
+    "   - Clear explanations of complex topics\n",
+    "   - An engaging and lively tone to captivate listeners\n",
+    "   - A balance of information and entertainment\n",
+    "\n",
+    "   Rules for the dialogue:\n",
+    "   - The host (Jane) always initiates the conversation and interviews the guest\n",
+    "   - Include thoughtful questions from the host to guide the discussion\n",
+    "   - Incorporate natural speech patterns, including verbal fillers such as Uhh, Hmmm, um, well\n",
+    "   - Allow for natural interruptions and back-and-forth between host and guest - this is very important to make the conversation feel authentic\n",
+    "   - Ensure the guest's responses are substantiated by the input text, avoiding unsupported claims\n",
+    "   - Maintain a PG-rated conversation appropriate for all audiences\n",
+    "   - Avoid any marketing or self-promotional content from the guest\n",
+    "   - The host concludes the conversation\n",
+    "\n",
+    "4. **Summarize Key Insights:**\n",
+    "   Naturally weave a summary of key points into the closing part of the dialogue. This should feel like a casual conversation rather than a formal recap, reinforcing the main takeaways before signing off.\n",
+    "\n",
+    "5. **Maintain Authenticity:**\n",
+    "   Throughout the script, strive for authenticity in the conversation. Include:\n",
+    "   - Moments of genuine curiosity or surprise from the host\n",
+    "   - Instances where the guest might briefly struggle to articulate a complex idea\n",
+    "   - Light-hearted moments or humor when appropriate\n",
+    "   - Brief personal anecdotes or examples that relate to the topic (within the bounds of the input text)\n",
+    "\n",
+    "6. **Consider Pacing and Structure:**\n",
+    "   Ensure the dialogue has a natural ebb and flow:\n",
+    "   - Start with a strong hook to grab the listener's attention\n",
+    "   - Gradually build complexity as the conversation progresses\n",
+    "   - Include brief \"breather\" moments for listeners to absorb complex information\n",
+    "   - For complicated concepts, reasking similar questions framed from a different perspective is recommended\n",
+    "   - End on a high note, perhaps with a thought-provoking question or a call-to-action for listeners\n",
+    "\n",
+    "IMPORTANT RULE:\n",
+    "1. Each line of dialogue should be no more than 100 characters (e.g., can finish within 5-8 seconds)\n",
+    "2. Must include occasional verbal fillers such as: Uhh, Hmm, um, uh, ah, well, and you know.\n",
+    "\n",
+    "Remember: Always reply in valid JSON format, without code blocks. Begin directly with the JSON output.\n",
+    "\"\"\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "sdo7pZvgf2p9"
+   },
+   "source": [
+    "### Call Llama 3.1 to Generate Podcast Script\n",
+    "\n",
+    "Below we call `Llama-3.1-70B` to generate a script for our podcast. We will also be able to read it's `scratchpad` and see how it structured the overall conversation. We can also call `Llama-3.1-8B`, but the output may not be as good as calling 70B - e.g. using 70B with the system prompt above, more natural output with occasional occasional verbal fillers such as Uhh, Hmm, Ah, Well will be generated."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "Y0RtJZ9VtVut"
+   },
+   "outputs": [],
+   "source": [
+    "def call_llm(system_prompt: str, text: str, dialogue_format):\n",
+    "    \"\"\"Call the LLM with the given prompt and dialogue format.\"\"\"\n",
+    "    response = client_together.chat.completions.create(\n",
+    "        messages=[\n",
+    "            {\"role\": \"system\", \"content\": system_prompt},\n",
+    "            {\"role\": \"user\", \"content\": text},\n",
+    "        ],\n",
+    "        model=\"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\",  # can also use \"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo\"\n",
+    "        response_format={\n",
+    "            \"type\": \"json_object\",\n",
+    "            \"schema\": dialogue_format.model_json_schema(),\n",
+    "        },\n",
+    "    )\n",
+    "    return response"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "FvW4J7W3tOow"
+   },
+   "outputs": [],
+   "source": [
+    "def generate_script(system_prompt: str, input_text: str, output_model):\n",
+    "    \"\"\"Get the dialogue from the LLM.\"\"\"\n",
+    "    # Load as python object\n",
+    "    try:\n",
+    "        response = call_llm(system_prompt, input_text, output_model)\n",
+    "        dialogue = output_model.model_validate_json(\n",
+    "            response.choices[0].message.content\n",
+    "        )\n",
+    "    except ValidationError as e:\n",
+    "        error_message = f\"Failed to parse dialogue JSON: {e}\"\n",
+    "        system_prompt_with_error = f\"{system_prompt}\\n\\nPlease return a VALID JSON object. This was the earlier error: {error_message}\"\n",
+    "        response = call_llm(system_prompt_with_error, input_text, output_model)\n",
+    "        dialogue = output_model.model_validate_json(\n",
+    "            response.choices[0].message.content\n",
+    "        )\n",
+    "    return dialogue"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "eYLRkNiqf2p-"
+   },
+   "source": [
+    "### Load in PDF of Choice\n",
+    "\n",
+    "Here we will load in an academic paper that proposes the use of many open source language models in a collaborative manner together to outperform proprietary models that are much larger!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "6c2nbb7Hu2jV",
+    "outputId": "03cb849e-0ef1-4d1a-d274-739752b1d456"
+   },
+   "outputs": [],
+   "source": [
+    "# the transformer paper!\n",
+    "!wget https://arxiv.org/pdf/1706.03762\n",
+    "!mv 1706.03762 attention.pdf"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "Rn-lhgqmueWM"
+   },
+   "outputs": [],
+   "source": [
+    "def get_PDF_text(file : str):\n",
+    "    text = ''\n",
+    "\n",
+    "    # Read the PDF file and extract text\n",
+    "    try:\n",
+    "        with Path(file).open(\"rb\") as f:\n",
+    "            reader = PdfReader(f)\n",
+    "            text = \"\\n\\n\".join([page.extract_text() for page in reader.pages])\n",
+    "    except Exception as e:\n",
+    "        raise f\"Error reading the PDF file: {str(e)}\"\n",
+    "\n",
+    "    if len(text) > 400000:\n",
+    "        raise \"The PDF is too long. Please upload a smaller PDF.\"\n",
+    "\n",
+    "    return text"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "D9BzDxmgvS2V",
+    "outputId": "fb6cc3a5-2a7d-4289-bbcb-2f1d4bf4674d"
+   },
+   "outputs": [],
+   "source": [
+    "text = get_PDF_text('attention.pdf')\n",
+    "len(text), text[:1000]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "vevOUMXJf2p_"
+   },
+   "source": [
+    "### Generate Script\n",
+    "\n",
+    "Below we generate the script and print out the lines."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "f5rBXur8vXnP"
+   },
+   "outputs": [],
+   "source": [
+    "script = generate_script(SYSTEM_PROMPT, text, Dialogue)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "inFgEVeBtCOR",
+    "outputId": "a84fabbc-62d4-43de-b966-72b29979bb9f"
+   },
+   "outputs": [],
+   "source": [
+    "script.dialogue"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "WqsYHpTwf2p_"
+   },
+   "source": [
+    "### Generate Podcast Using TTS\n",
+    "\n",
+    "Below we read through the script and parse choose the TTS voice depending on the speaker. We define a speaker and guest voice id.\n",
+    "\n",
+    "We can loop through the lines in the script and generate them by a call to the TTS model with specific voice and lines configurations. The lines all appended to the same buffer and once the script finishes we write this out to a `wav` file, ready to be played.\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "qKnQnoYNvx3k"
+   },
+   "outputs": [],
+   "source": [
+    "import subprocess\n",
+    "import ffmpeg\n",
+    "\n",
+    "host_id = \"694f9389-aac1-45b6-b726-9d9369183238\" # Jane - host\n",
+    "guest_id = \"a0e99841-438c-4a64-b679-ae501e7d6091\" # Guest\n",
+    "\n",
+    "model_id = \"sonic-english\" # The Sonic Cartesia model for English TTS\n",
+    "\n",
+    "output_format = {\n",
+    "    \"container\": \"raw\",\n",
+    "    \"encoding\": \"pcm_f32le\",\n",
+    "    \"sample_rate\": 44100,\n",
+    "    }\n",
+    "\n",
+    "# Set up a WebSocket connection.\n",
+    "ws = client_cartesia.tts.websocket()\n",
+    "\n",
+    "# Open a file to write the raw PCM audio bytes to.\n",
+    "f = open(\"podcast.pcm\", \"wb\")\n",
+    "\n",
+    "# Generate and stream audio.\n",
+    "for line in script.dialogue:\n",
+    "    if line.speaker == \"Guest\":\n",
+    "        voice_id = guest_id\n",
+    "    else:\n",
+    "        voice_id = host_id\n",
+    "\n",
+    "    for output in ws.send(\n",
+    "        model_id=model_id,\n",
+    "        transcript='-' + line.text, # the \"-\"\" is to add a pause between speakers\n",
+    "        voice_id=voice_id,\n",
+    "        stream=True,\n",
+    "        output_format=output_format,\n",
+    "    ):\n",
+    "        buffer = output[\"audio\"]  # buffer contains raw PCM audio bytes\n",
+    "        f.write(buffer)\n",
+    "\n",
+    "# Close the connection to release resources\n",
+    "ws.close()\n",
+    "f.close()\n",
+    "\n",
+    "# Convert the raw PCM bytes to a WAV file.\n",
+    "ffmpeg.input(\"podcast.pcm\", format=\"f32le\").output(\"podcast.wav\").run()\n",
+    "\n",
+    "# Play the file\n",
+    "subprocess.run([\"ffplay\", \"-autoexit\", \"-nodisp\", \"podcast.wav\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 75
+    },
+    "id": "STWaJf_ySctY",
+    "outputId": "63f5c555-2a4a-4d9e-9d3f-f9063289eb1d"
+   },
+   "outputs": [],
+   "source": [
+    "# Play the podcast\n",
+    "\n",
+    "import IPython\n",
+    "\n",
+    "IPython.display.Audio(\"podcast.wav\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "rx8ZV9Jj_AB5"
+   },
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "provenance": [],
+   "toc_visible": true
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.15"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}

+ 5 - 5
recipes/quickstart/agents/DeepLearningai_Course_Notebooks/AI_Agentic_Design_Patterns_with_AutoGen_L4_Tool_Use_and_Conversational_Chess.ipynb

@@ -5,7 +5,7 @@
    "id": "7a4b75bb-d60a-41e3-abca-1ca0f0bf1201",
    "metadata": {},
    "source": [
-    "<a href=\"https://colab.research.google.com/github/meta-llama/llama-recipes/blob/main/recipes/quickstart/agents/dlai/AI_Agentic_Design_Patterns_with_AutoGen_L4_Tool_Use_and_Conversational_Chess.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+    "<a href=\"https://colab.research.google.com/github/meta-llama/llama-recipes/blob/main/recipes/quickstart/agents/DeepLearningai_Course_Notebooks/AI_Agentic_Design_Patterns_with_AutoGen_L4_Tool_Use_and_Conversational_Chess.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
    ]
   },
   {
@@ -63,7 +63,7 @@
    "outputs": [],
    "source": [
     "def get_legal_moves(\n",
-    "    \n",
+    "\n",
     ") -> Annotated[str, \"A list of legal moves in UCI format\"]:\n",
     "    return \"Possible moves are: \" + \",\".join(\n",
     "        [str(move) for move in board.legal_moves]\n",
@@ -86,7 +86,7 @@
     "    board.push_uci(str(move))\n",
     "    global made_move\n",
     "    made_move = True\n",
-    "    \n",
+    "\n",
     "    svg_str = chess.svg.board(\n",
     "            board,\n",
     "            arrows=[(move.from_square, move.to_square)],\n",
@@ -96,7 +96,7 @@
     "    display(\n",
     "        SVG(data=svg_str)\n",
     "    )\n",
-    "    \n",
+    "\n",
     "    # Get the piece name.\n",
     "    piece = board.piece_at(move.to_square)\n",
     "    piece_symbol = piece.unicode_symbol()\n",
@@ -223,7 +223,7 @@
     "        name=\"get_legal_moves\",\n",
     "        description=\"Call this tool to get all legal moves in UCI format.\",\n",
     "    )\n",
-    "    \n",
+    "\n",
     "    register_function(\n",
     "        make_move,\n",
     "        caller=caller,\n",

+ 7 - 8
recipes/quickstart/agents/DeepLearningai_Course_Notebooks/AI_Agents_in_LangGraph_L1_Build_an_Agent_from_Scratch.ipynb

@@ -5,7 +5,7 @@
    "id": "de56ee05-3b71-43c9-8cbf-6ad9b3233f38",
    "metadata": {},
    "source": [
-    "<a href=\"https://colab.research.google.com/github/meta-llama/llama-recipes/blob/main/recipes/quickstart/agents/dlai/AI_Agents_in_LangGraph_L1_Build_an_Agent_from_Scratch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+    "<a href=\"https://colab.research.google.com/github/meta-llama/llama-recipes/blob/main/recipes/quickstart/agents/DeepLearningai_Course_Notebooks/AI_Agents_in_LangGraph_L1_Build_an_Agent_from_Scratch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
    ]
   },
   {
@@ -35,7 +35,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import os \n",
+    "import os\n",
     "from groq import Groq\n",
     "\n",
     "os.environ['GROQ_API_KEY'] = 'your_groq_api_key' # get a free key at https://console.groq.com/keys"
@@ -48,7 +48,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# a quick sanity test of calling Llama 3 70b on Groq \n",
+    "# a quick sanity test of calling Llama 3 70b on Groq\n",
     "# see https://console.groq.com/docs/text-chat for more info\n",
     "client = Groq()\n",
     "chat_completion = client.chat.completions.create(\n",
@@ -75,7 +75,7 @@
    "source": [
     "client = Groq()\n",
     "model = \"llama3-8b-8192\" # this model works with the prompt below only for the first simpler example; you'll see how to modify the prompt to make it work for a more complicated question\n",
-    "#model = \"llama3-70b-8192\" # this model works with the prompt below for both example questions \n",
+    "#model = \"llama3-70b-8192\" # this model works with the prompt below for both example questions\n",
     "\n",
     "class Agent:\n",
     "    def __init__(self, system=\"\"):\n",
@@ -95,8 +95,7 @@
     "                        model=model,\n",
     "                        temperature=0,\n",
     "                        messages=self.messages)\n",
-    "        return completion.choices[0].message.content\n",
-    "    "
+    "        return completion.choices[0].message.content\n"
    ]
   },
   {
@@ -151,7 +150,7 @@
     "    return eval(what)\n",
     "\n",
     "def average_dog_weight(name):\n",
-    "    if name in \"Scottish Terrier\": \n",
+    "    if name in \"Scottish Terrier\":\n",
     "        return(\"Scottish Terriers average 20 lbs\")\n",
     "    elif name in \"Border Collie\":\n",
     "        return(\"a Border Collies average weight is 37 lbs\")\n",
@@ -423,7 +422,7 @@
     "\n",
     "            # key to make the agent process fully automated:\n",
     "            # programtically call the external func with arguments, with the info returned by LLM\n",
-    "            observation = known_actions[action](action_input) \n",
+    "            observation = known_actions[action](action_input)\n",
     "\n",
     "            print(\"Observation:\", observation)\n",
     "            next_prompt = \"Observation: {}\".format(observation)\n",

+ 1 - 1
recipes/quickstart/agents/DeepLearningai_Course_Notebooks/Building_Agentic_RAG_with_Llamaindex_L1_Router_Engine.ipynb

@@ -45,7 +45,7 @@
    },
    "outputs": [],
    "source": [
-    "import os \n",
+    "import os\n",
     "os.environ['GROQ_API_KEY'] = 'your_groq_api_key' # get a free key at https://console.groq.com/keys"
    ]
   },

+ 4 - 4
recipes/quickstart/agents/DeepLearningai_Course_Notebooks/Functions_Tools_and_Agents_with_LangChain_L1_Function_Calling.ipynb

@@ -5,7 +5,7 @@
    "id": "2ba1b4ef-3b96-4e7e-b5d0-155b839db73c",
    "metadata": {},
    "source": [
-    "<a href=\"https://colab.research.google.com/github/meta-llama/llama-recipes/blob/main/recipes/quickstart/agents/dlai/Functions_Tools_and_Agents_with_LangChain_L1_Function_Calling.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+    "<a href=\"https://colab.research.google.com/github/meta-llama/llama-recipes/blob/main/recipes/quickstart/agents/DeepLearningai_Course_Notebooks/Functions_Tools_and_Agents_with_LangChain_L1_Function_Calling.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
    ]
   },
   {
@@ -62,7 +62,7 @@
    "outputs": [],
    "source": [
     "# https://console.groq.com/docs/tool-use#models\n",
-    "# Groq API endpoints support tool use for programmatic execution of specified operations through requests with explicitly defined \n",
+    "# Groq API endpoints support tool use for programmatic execution of specified operations through requests with explicitly defined\n",
     "# operations. With tool use, Groq API model endpoints deliver structured JSON output that can be used to directly invoke functions.\n",
     "\n",
     "from groq import Groq\n",
@@ -145,8 +145,8 @@
     "    model=\"llama3-70b-8192\",\n",
     "    messages=messages,\n",
     "    functions=functions,\n",
-    "    #tools=tools, # you can also replace functions with tools, as specified in https://console.groq.com/docs/tool-use \n",
-    "    max_tokens=4096, \n",
+    "    #tools=tools, # you can also replace functions with tools, as specified in https://console.groq.com/docs/tool-use\n",
+    "    max_tokens=4096,\n",
     "    temperature=0\n",
     ")"
    ]

+ 1 - 1
recipes/quickstart/finetuning/LLM_finetuning_overview.md

@@ -61,4 +61,4 @@ To boost the performance of fine-tuning with FSDP, we can make use a number of f
 
 - **Activation Checkpointing**  which is a technique to save memory by discarding the intermediate activation in forward pass instead of keeping it in the memory with the cost recomputing them in the backward pass. FSDP Activation checkpointing is shard aware meaning we need to apply it after wrapping the model with FSDP. In our script we are making use of that.
 
-- **auto_wrap_policy** Which is the way to specify how FSDP would partition the model, there is default support for transformer wrapping policy. This allows FSDP to form each FSDP unit ( partition of the  model ) based on the transformer class in the model. To identify this layer in the model, you need to look at the layer that wraps both the attention layer and  MLP. This helps FSDP have more fine-grained units for communication that help with optimizing the communication cost.
+- **auto_wrap_policy** Which is the way to specify how FSDP would partition the model, there is default support for transformer wrapping policy. This allows FSDP to form each FSDP unit ( partition of the  model ) based on the transformer class in the model. To identify this layer in the model, you need to look at the layer that wraps both the attention layer and  MLP. This helps FSDP have more fine-grained units for communication that help with optimizing the communication cost.

+ 1 - 0
recipes/quickstart/finetuning/README.md

@@ -54,6 +54,7 @@ It lets us specify the training settings for everything from `model_name` to `da
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
+    freeze_LLM_only: bool = False # Freeze self-attention layers in the language_model. Vision model, multi_modal_projector, cross-attention will be fine-tuned
     quantization: str = None
     one_gpu: bool = False
     save_model: bool = True

File diff suppressed because it is too large
+ 6 - 0
recipes/quickstart/finetuning/finetune_vision_model.md


+ 30 - 13
recipes/quickstart/inference/local_inference/README.md

@@ -3,26 +3,43 @@
 ## Hugging face setup
 **Important Note**: Before running the inference, you'll need your Hugging Face access token, which you can get at your Settings page [here](https://huggingface.co/settings/tokens). Then run `huggingface-cli login` and copy and paste your Hugging Face access token to complete the login to make sure the scripts can download Hugging Face models if needed.
 
-## Multimodal Inference
-For Multi-Modal inference we have added [multi_modal_infer.py](multi_modal_infer.py) which uses the transformers library.
+## Multimodal Inference and CLI inference with or without PEFT LoRA weights
 
-The way to run this would be:
-```
-python multi_modal_infer.py --image_path PATH_TO_IMAGE --prompt_text "Describe this image" --temperature 0.5 --top_p 0.8 --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct"
-```
----
-## Multi-modal Inferencing Using gradio UI for inferencing
-For multi-modal inferencing using gradio UI we have added [multi_modal_infer_gradio_UI.py](multi_modal_infer_gradio_UI.py) which used gradio and transformers library.
+### Model Overview
+- Base model: `meta-llama/Llama-3.2-11B-Vision-Instruct`
+- Uses PEFT library (v0.13.1) for efficient fine-tuning
+- Supports vision-language tasks with instruction capabilities
 
-### Steps to Run
+### Features in `multi_modal_infer.py`
 
-The way to run this would be:
-- Ensure having proper access to llama 3.2 vision models, then run the command given below
+All functionality has been consolidated into a single file with three main modes. Make sure you are logged in via `huggingface-cli login` before running them.
+
+### Steps to run
+1. **Basic Inference**
+```bash
+python multi_modal_infer.py \
+    --image_path "path/to/image.jpg" \
+    --prompt_text "Describe this image" \
+    --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct"
+```
 
+2. **Gradio UI Mode**
+```bash
+python multi_modal_infer.py \
+    --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct" \
+    --gradio_ui
 ```
-python multi_modal_infer_gradio_UI.py --hf_token <your hf_token here>
+
+3. **LoRA Fine-tuning Integration**
+```bash
+python multi_modal_infer.py \
+    --image_path "path/to/image.jpg" \
+    --prompt_text "Describe this image" \
+    --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct" \
+    --finetuning_path "path/to/lora/weights"
 ```
 
+
 ## Text-only Inference
 For local inference we have provided an [inference script](inference.py). Depending on the type of finetuning performed during training the [inference script](inference.py) takes different arguments.
 

+ 176 - 78
recipes/quickstart/inference/local_inference/multi_modal_infer.py

@@ -1,108 +1,206 @@
 import argparse
 import os
 import sys
-
 import torch
 from accelerate import Accelerator
 from PIL import Image as PIL_Image
 from transformers import MllamaForConditionalGeneration, MllamaProcessor
-
+from peft import PeftModel
+import gradio as gr
+from huggingface_hub import HfFolder
+# Initialize accelerator
 accelerator = Accelerator()
-
 device = accelerator.device
 
 # Constants
 DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+MAX_OUTPUT_TOKENS = 2048
+MAX_IMAGE_SIZE = (1120, 1120)
+
+
+def get_hf_token():
+    """Retrieve Hugging Face token from the cache or environment."""
+    # Check if a token is explicitly set in the environment
+    token = os.getenv("HUGGINGFACE_TOKEN")
+    if token:
+        return token
+
+    # Automatically retrieve the token from the Hugging Face cache (set via huggingface-cli login)
+    token = HfFolder.get_token()
+    if token:
+        return token
 
+    print("Hugging Face token not found. Please login using `huggingface-cli login`.")
+    sys.exit(1)
 
-def load_model_and_processor(model_name: str):
-    """
-    Load the model and processor based on the 11B or 90B model.
-    """
+
+def load_model_and_processor(model_name: str, finetuning_path: str = None):
+    """Load model and processor with optional LoRA adapter"""
+    print(f"Loading model: {model_name}")
+    hf_token = get_hf_token()
     model = MllamaForConditionalGeneration.from_pretrained(
         model_name,
         torch_dtype=torch.bfloat16,
         use_safetensors=True,
         device_map=device,
+        token=hf_token
     )
-    processor = MllamaProcessor.from_pretrained(model_name, use_safetensors=True)
-
+    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token, use_safetensors=True)
+
+    if finetuning_path and os.path.exists(finetuning_path):
+        print(f"Loading LoRA adapter from '{finetuning_path}'...")
+        model = PeftModel.from_pretrained(
+            model,
+            finetuning_path,
+            is_adapter=True,
+            torch_dtype=torch.bfloat16
+        )
+        print("LoRA adapter merged successfully")
+    
     model, processor = accelerator.prepare(model, processor)
     return model, processor
 
+def process_image(image_path: str = None, image = None) -> PIL_Image.Image:
+    """Process and validate image input"""
+    if image is not None:
+        return image.convert("RGB")
+    if image_path and os.path.exists(image_path):
+        return PIL_Image.open(image_path).convert("RGB")
+    raise ValueError("No valid image provided")
 
-def process_image(image_path: str) -> PIL_Image.Image:
-    """
-    Open and convert an image from the specified path.
-    """
-    if not os.path.exists(image_path):
-        print(f"The image file '{image_path}' does not exist.")
-        sys.exit(1)
-    with open(image_path, "rb") as f:
-        return PIL_Image.open(f).convert("RGB")
-
-
-def generate_text_from_image(
-    model, processor, image, prompt_text: str, temperature: float, top_p: float
-):
-    """
-    Generate text from an image using the model and processor.
-    """
+def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
+    """Generate text from image using model"""
     conversation = [
-        {
-            "role": "user",
-            "content": [{"type": "image"}, {"type": "text", "text": prompt_text}],
-        }
+        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
     ]
-    prompt = processor.apply_chat_template(
-        conversation, add_generation_prompt=True, tokenize=False
-    )
+    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
     inputs = processor(image, prompt, return_tensors="pt").to(device)
-    output = model.generate(
-        **inputs, temperature=temperature, top_p=top_p, max_new_tokens=512
-    )
-    return processor.decode(output[0])[len(prompt) :]
-
-
-def main(
-    image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str
-):
-    """
-    Call all the functions.
-    """
-    model, processor = load_model_and_processor(model_name)
-    image = process_image(image_path)
-    result = generate_text_from_image(
-        model, processor, image, prompt_text, temperature, top_p
-    )
-    print("Generated Text: " + result)
-
+    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=MAX_OUTPUT_TOKENS)
+    return processor.decode(output[0])[len(prompt):]
+
+def gradio_interface(model_name: str):
+    """Create Gradio UI with LoRA support"""
+    # Initialize model state
+    current_model = {"model": None, "processor": None}
+    
+    def load_or_reload_model(enable_lora: bool, lora_path: str = None):
+        current_model["model"], current_model["processor"] = load_model_and_processor(
+            model_name, 
+            lora_path if enable_lora else None
+        )
+        return "Model loaded successfully" + (" with LoRA" if enable_lora else "")
+
+    def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
+        if image is not None:
+            try:
+                processed_image = process_image(image=image)
+                result = generate_text_from_image(
+                    current_model["model"],
+                    current_model["processor"],
+                    processed_image,
+                    user_prompt,
+                    temperature,
+                    top_p
+                )
+                history.append((user_prompt, result))
+            except Exception as e:
+                history.append((user_prompt, f"Error: {str(e)}"))
+        return history
+
+    def clear_chat():
+        return []
+
+    with gr.Blocks() as demo:
+        gr.HTML("<h1 style='text-align: center'>Llama Vision Model Interface</h1>")
+        
+        with gr.Row():
+            with gr.Column(scale=1):
+                # Model loading controls
+                with gr.Group():
+                    enable_lora = gr.Checkbox(label="Enable LoRA", value=False)
+                    lora_path = gr.Textbox(
+                        label="LoRA Weights Path",
+                        placeholder="Path to LoRA weights folder",
+                        visible=False
+                    )
+                    load_status = gr.Textbox(label="Load Status", interactive=False)
+                    load_button = gr.Button("Load/Reload Model")
+
+                # Image and parameter controls
+                image_input = gr.Image(label="Image", type="pil", image_mode="RGB", height=512, width=512)
+                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1)
+                top_k = gr.Slider(label="Top-k", minimum=1, maximum=100, value=50, step=1)
+                top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1)
+                max_tokens = gr.Slider(label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50)
+
+            with gr.Column(scale=2):
+                chat_history = gr.Chatbot(label="Chat", height=512)
+                user_prompt = gr.Textbox(
+                    show_label=False,
+                    placeholder="Enter your prompt",
+                    lines=2
+                )
+                
+                with gr.Row():
+                    generate_button = gr.Button("Generate")
+                    clear_button = gr.Button("Clear")
+
+        # Event handlers
+        enable_lora.change(
+            fn=lambda x: gr.update(visible=x),
+            inputs=[enable_lora],
+            outputs=[lora_path]
+        )
+        
+        load_button.click(
+            fn=load_or_reload_model,
+            inputs=[enable_lora, lora_path],
+            outputs=[load_status]
+        )
+
+        generate_button.click(
+            fn=describe_image,
+            inputs=[
+                image_input, user_prompt, temperature,
+                top_k, top_p, max_tokens, chat_history
+            ],
+            outputs=[chat_history]
+        )
+        
+        clear_button.click(fn=clear_chat, outputs=[chat_history])
+
+    # Initial model load
+    load_or_reload_model(False)
+    return demo
+
+def main(args):
+    """Main execution flow"""
+    if args.gradio_ui:
+        demo = gradio_interface(args.model_name)
+        demo.launch()
+    else:
+        model, processor = load_model_and_processor(
+            args.model_name,
+            args.finetuning_path
+        )
+        image = process_image(image_path=args.image_path)
+        result = generate_text_from_image(
+            model, processor, image,
+            args.prompt_text,
+            args.temperature,
+            args.top_p
+        )
+        print("Generated Text:", result)
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Generate text from an image and prompt using the 3.2 MM Llama model."
-    )
-    parser.add_argument("--image_path", type=str, help="Path to the image file")
-    parser.add_argument(
-        "--prompt_text", type=str, help="Prompt text to describe the image"
-    )
-    parser.add_argument(
-        "--temperature",
-        type=float,
-        default=0.7,
-        help="Temperature for generation (default: 0.7)",
-    )
-    parser.add_argument(
-        "--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)"
-    )
-    parser.add_argument(
-        "--model_name",
-        type=str,
-        default=DEFAULT_MODEL,
-        help=f"Model name (default: '{DEFAULT_MODEL}')",
-    )
-
+    parser = argparse.ArgumentParser(description="Multi-modal inference with optional Gradio UI and LoRA support")
+    parser.add_argument("--image_path", type=str, help="Path to the input image")
+    parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
+    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
+    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
+    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help="Model name")
+    parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
+    parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")
+    
     args = parser.parse_args()
-    main(
-        args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name
-    )
+    main(args)

+ 0 - 157
recipes/quickstart/inference/local_inference/multi_modal_infer_gradio_UI.py

@@ -1,157 +0,0 @@
-import gradio as gr
-import torch
-import os
-from PIL import Image
-from accelerate import Accelerator
-from transformers import MllamaForConditionalGeneration, AutoProcessor
-import argparse  # Import argparse
-
-# Parse the command line arguments
-parser = argparse.ArgumentParser(description="Run Gradio app with Hugging Face model")
-parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face authentication token")
-args = parser.parse_args()
-
-# Hugging Face token
-hf_token = args.hf_token
-
-# Initialize Accelerator
-accelerate = Accelerator()
-device = accelerate.device
-
-# Set memory management for PyTorch
-os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'  # or adjust size as needed
-
-# Model ID
-model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
-
-# Load model with the Hugging Face token
-model = MllamaForConditionalGeneration.from_pretrained(
-    model_id,
-    torch_dtype=torch.bfloat16,
-    device_map=device,
-    use_auth_token=hf_token  # Pass the Hugging Face token here
-)
-
-# Load the processor
-processor = AutoProcessor.from_pretrained(model_id, use_auth_token=hf_token)
-
-# Visual theme
-visual_theme = gr.themes.Default()  # Default, Soft or Monochrome
-
-# Constants
-MAX_OUTPUT_TOKENS = 2048
-MAX_IMAGE_SIZE = (1120, 1120)
-
-# Function to process the image and generate a description
-def describe_image(image, user_prompt, temperature, top_k, top_p, max_tokens, history):
-    # Initialize cleaned_output variable
-    cleaned_output = ""
-
-    if image is not None:
-        # Resize image if necessary
-        image = image.resize(MAX_IMAGE_SIZE)
-        prompt = f"<|image|><|begin_of_text|>{user_prompt} Answer:"
-        # Preprocess the image and prompt
-        inputs = processor(image, prompt, return_tensors="pt").to(device)
-    else:
-        # Text-only input if no image is provided
-        prompt = f"<|begin_of_text|>{user_prompt} Answer:"
-        # Preprocess the prompt only (no image)
-        inputs = processor(prompt, return_tensors="pt").to(device)
-
-    # Generate output with model
-    output = model.generate(
-        **inputs,
-        max_new_tokens=min(max_tokens, MAX_OUTPUT_TOKENS),
-        temperature=temperature,
-        top_k=top_k,
-        top_p=top_p
-    )
-
-    # Decode the raw output
-    raw_output = processor.decode(output[0])
-
-    # Clean up the output to remove system tokens
-    cleaned_output = raw_output.replace("<|image|><|begin_of_text|>", "").strip().replace(" Answer:", "")
-
-    # Ensure the prompt is not repeated in the output
-    if cleaned_output.startswith(user_prompt):
-        cleaned_output = cleaned_output[len(user_prompt):].strip()
-
-    # Append the new conversation to the history
-    history.append((user_prompt, cleaned_output))
-
-    return history
-
-
-# Function to clear the chat history
-def clear_chat():
-    return []
-
-# Gradio Interface
-def gradio_interface():
-    with gr.Blocks(visual_theme) as demo:
-        gr.HTML(
-        """
-    <h1 style='text-align: center'>
-    meta-llama/Llama-3.2-11B-Vision-Instruct
-    </h1>
-    """)
-        with gr.Row():
-            # Left column with image and parameter inputs
-            with gr.Column(scale=1):
-                image_input = gr.Image(
-                    label="Image", 
-                    type="pil", 
-                    image_mode="RGB", 
-                    height=512,  # Set the height
-                    width=512   # Set the width
-                )
-
-                # Parameter sliders
-                temperature = gr.Slider(
-                    label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1, interactive=True)
-                top_k = gr.Slider(
-                    label="Top-k", minimum=1, maximum=100, value=50, step=1, interactive=True)
-                top_p = gr.Slider(
-                    label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1, interactive=True)
-                max_tokens = gr.Slider(
-                    label="Max Tokens", minimum=50, maximum=MAX_OUTPUT_TOKENS, value=100, step=50, interactive=True)
-
-            # Right column with the chat interface
-            with gr.Column(scale=2):
-                chat_history = gr.Chatbot(label="Chat", height=512)
-
-                # User input box for prompt
-                user_prompt = gr.Textbox(
-                    show_label=False,
-                    container=False,
-                    placeholder="Enter your prompt", 
-                    lines=2
-                )
-
-                # Generate and Clear buttons
-                with gr.Row():
-                    generate_button = gr.Button("Generate")
-                    clear_button = gr.Button("Clear")
-
-                # Define the action for the generate button
-                generate_button.click(
-                    fn=describe_image, 
-                    inputs=[image_input, user_prompt, temperature, top_k, top_p, max_tokens, chat_history],
-                    outputs=[chat_history]
-                )
-
-                # Define the action for the clear button
-                clear_button.click(
-                    fn=clear_chat,
-                    inputs=[],
-                    outputs=[chat_history]
-                )
-
-    return demo
-
-# Launch the interface
-demo = gradio_interface()
-# demo.launch(server_name="0.0.0.0", server_port=12003)
-demo.launch()

+ 1 - 0
src/llama_recipes/configs/training.py

@@ -35,6 +35,7 @@ class train_config:
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
+    freeze_LLM_only: bool = False # Freeze self-attention layers in the language_model. Vision model, multi_modal_projector, cross-attention will be fine-tuned
     quantization: str = None
     one_gpu: bool = False
     save_model: bool = True

+ 18 - 3
src/llama_recipes/finetuning.py

@@ -38,8 +38,10 @@ from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
     clear_gpu_cache,
     freeze_transformer_layers,
+    freeze_LLM_only,
     get_policies,
     print_model_size,
+    print_frozen_model_status,
     setup,
     setup_environ_flags,
     train,
@@ -194,7 +196,7 @@ def main(**kwargs):
         model.resize_token_embeddings(len(tokenizer))
 
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-
+    
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if (
         train_config.enable_fsdp
@@ -235,7 +237,14 @@ def main(**kwargs):
 
         if not train_config.use_peft and train_config.freeze_layers:
             freeze_transformer_layers(model, train_config.num_freeze_layers)
-
+            # print model size and frozen layers after freezing layers
+            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)
+            
+        if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
+            freeze_LLM_only(model)
+            # print model size and frozen layers after freezing layers
+            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)
+        
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
         if is_vision:
@@ -255,6 +264,11 @@ def main(**kwargs):
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
+        
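+        # freeze_LLM_only leaves both frozen and trainable parameters inside the same FSDP units, which requires use_orig_params=True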
+        if train_config.freeze_LLM_only:
+            use_orig_params = True
+        else:
+            use_orig_params = False
         model = FSDP(
             model,
             auto_wrap_policy=(
@@ -282,6 +296,7 @@ def main(**kwargs):
                 if train_config.low_cpu_fsdp and rank != 0
                 else None
             ),
+            use_orig_params=use_orig_params,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             model.enable_input_require_grads()
@@ -297,7 +312,7 @@ def main(**kwargs):
         dataset_processer = processor
     else:
         dataset_processer = tokenizer
-
+    
     # Load and preprocess the dataset for training and validation
 
     dataset_train = get_preprocessed_dataset(

+ 56 - 2
src/llama_recipes/utils/train_utils.py

@@ -409,7 +409,17 @@ def freeze_transformer_layers(model, num_layer):
             if i < num_layer:
                 for param in layer.parameters():
                     param.requires_grad = False
-
+                    
+def freeze_LLM_only(model):
+    """
+    Freeze self-attention layers in the language_model. vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
+    """
+    for name, param in model.language_model.named_parameters():
+        param.requires_grad = False
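+    # Re-enable gradients on the cross-attention layers, which are fine-tuned together with the vision modules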
+    for i, layer in enumerate(model.language_model.model.layers):
+        if i in model.language_model.model.cross_attention_layers:
+            for param in layer.parameters():
+                param.requires_grad = True
 
 def check_frozen_layers_peft_model(model):
      for i, layer in enumerate(model.base_model.model.model.layers):
@@ -476,8 +486,52 @@ def print_model_size(model, config, rank: int = 0) -> None:
         total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
         print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
 
+def print_frozen_model_status(model, config, rank: int = 0) -> None:
+    """
+    Print the frozen status of the model's modules and the number of trainable parameters after freezing.
 
-
+    Args:
+        model: The PyTorch model.
+        config: The training config containing the model_name.
+        rank (int, optional): Current process's rank. Defaults to 0.
+    """
+    if rank == 0:
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print("After freezing the model:")
+        print(f"--> {config.model_name} has {trainable_params / 1e6} Million trainable params\n")
+
+        module_states = {}
+        # Iterate over all parameters
+        for name, param in model.named_parameters():
+            # Extract the top-level module name (e.g., "vision_model", "language_model")
+            top_module = name.split(".")[0]
+
+            # Initialize a record for the top-level module
+            if top_module not in module_states:
+                module_states[top_module] = {"frozen": [], "unfrozen": []}
+
+            # Group parameters into frozen or unfrozen
+            if param.requires_grad:
+                module_states[top_module]["unfrozen"].append(name)
+            else:
+                module_states[top_module]["frozen"].append(name)
+
+        print("--> Model state after freezing:")
+        # Analyze and print the results
+        for module, states in module_states.items():
+            frozen_params = states["frozen"]
+            unfrozen_params = states["unfrozen"]
+
+            if frozen_params and unfrozen_params:
+                # Mixed state: both frozen and unfrozen parameters
+                print(f"    {module}: Mixed")
+            elif frozen_params:
+                # All parameters are frozen
+                print(f"    {module}: Frozen")
+            else:
+                # All parameters are unfrozen
+                print(f"    {module}: Unfrozen")
+        print("")
 
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""