瀏覽代碼

Merge branch 'main' into eval_reproduce

Kai Wu 7 月之前
父節點
當前提交
c1390e8cd8

+ 38 - 10
recipes/quickstart/Getting_to_know_Llama.ipynb

@@ -15,8 +15,8 @@
     "id": "LERqQn5v8-ak"
    },
    "source": [
-    "# **Getting to know Llama 3: Everything you need to start building**\n",
-    "Our goal in this session is to provide a guided tour of Llama 3 with comparison with Llama 2, including understanding different Llama 3 models, how and where to access them, Generative AI and Chatbot architectures, prompt engineering, RAG (Retrieval Augmented Generation), Fine-tuning and more. All this is implemented with a starter code for you to take it and use it in your Llama 3 projects."
+    "# **Getting to know Llama 3.1: Everything you need to start building**\n",
+    "Our goal in this session is to provide a guided tour of Llama 3.1 with comparison with Llama 2, including understanding different Llama 3.1 models, how and where to access them, Generative AI and Chatbot architectures, prompt engineering, RAG (Retrieval Augmented Generation), Fine-tuning and more. All this is implemented with a starter code for you to take it and use it in your Llama 3.1 projects."
    ]
   },
   {
@@ -113,6 +113,20 @@
     "      llama-3-70b --> llama-3-70b-instruct\n",
     "      classDef default fill:#CCE6FF,stroke:#84BCF5,textColor:#1C2B33,fontFamily:trebuchet ms;\n",
     "  \"\"\")\n",
+    "  \n",
+    "def llama3_1_family():\n",
+    "  mm(\"\"\"\n",
+    "  graph LR;\n",
+    "      llama-3-1 --> llama-3-8b\n",
+    "      llama-3-1 --> llama-3-70b\n",
+    "      llama-3-1 --> llama-3-4050b\n",
+    "      llama-3-1-8b --> llama-3-1-8b\n",
+    "      llama-3-1-8b --> llama-3-1-8b-instruct\n",
+    "      llama-3-1-70b --> llama-3-1-70b\n",
+    "      llama-3-1-70b --> llama-3-1-70b-instruct\n",
+    "      llama-3-1-405b --> llama-3-1-405b-instruct\n",
+    "      classDef default fill:#CCE6FF,stroke:#84BCF5,textColor:#1C2B33,fontFamily:trebuchet ms;\n",
+    "  \"\"\")\n",
     "\n",
     "import ipywidgets as widgets\n",
     "from IPython.display import display, Markdown\n",
@@ -184,7 +198,7 @@
     "id": "i4Np_l_KtIno"
    },
    "source": [
-    "### **1 - Understanding Llama 3**"
+    "### **1 - Understanding Llama 3.1**"
    ]
   },
   {
@@ -193,13 +207,13 @@
     "id": "PGPSI3M5PGTi"
    },
    "source": [
-    "### **1.1 - What is Llama 3?**\n",
+    "### **1.1 - What is Llama 3.1?**\n",
     "\n",
     "* State of the art (SOTA), Open Source LLM\n",
-    "* 8B, 70B - base and instruct models\n",
+    "* 8B, 70B, 405B - base and instruct models\n",
     "* Choosing model: Size, Quality, Cost, Speed\n",
     "* Pretrained + Chat\n",
-    "* [Meta Llama 3 Blog](https://ai.meta.com/blog/meta-llama-3/)\n",
+    "* [Meta Llama 3.1 Blog](https://ai.meta.com/blog/meta-llama-3-1/)\n",
     "* [Getting Started with Meta Llama](https://llama.meta.com/docs/get-started)"
    ]
   },
@@ -239,12 +253,21 @@
    ]
   },
   {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llama3_1_family()"
+   ]
+  },
+  {
    "cell_type": "markdown",
    "metadata": {
     "id": "aYeHVVh45bdT"
    },
    "source": [
-    "### **1.2 - Accessing Llama 3**\n",
+    "### **1.2 - Accessing Llama 3.1**\n",
     "* Download + Self Host (i.e. [download Llama](https://ai.meta.com/resources/models-and-libraries/llama-downloads))\n",
     "* Hosted API Platform (e.g. [Groq](https://console.groq.com/), [Replicate](https://replicate.com/meta/meta-llama-3-8b-instruct), [Together](https://api.together.xyz/playground/language/meta-llama/Llama-3-8b-hf), [Anyscale](https://app.endpoints.anyscale.com/playground))\n",
     "\n",
@@ -258,7 +281,7 @@
     "id": "kBuSay8vtzL4"
    },
    "source": [
-    "### **1.3 - Use Cases of Llama 3**\n",
+    "### **1.3 - Use Cases of Llama 3.1**\n",
     "* Content Generation\n",
     "* Summarization\n",
     "* General Chatbots\n",
@@ -943,7 +966,7 @@
     "import bs4\n",
     "\n",
     "# Step 1: Load the document from a web url\n",
-    "loader = WebBaseLoader([\"https://huggingface.co/blog/llama3\"])\n",
+    "loader = WebBaseLoader([\"https://huggingface.co/blog/llama31\"])\n",
     "documents = loader.load()\n",
     "\n",
     "# Step 2: Split the document into chunks with a specified chunk size\n",
@@ -1079,7 +1102,7 @@
    },
    "source": [
     "#### **Resources**\n",
-    "- [Meta Llama 3 Blog](https://ai.meta.com/blog/meta-llama-3/)\n",
+    "- [Meta Llama 3.1 Blog](https://ai.meta.com/blog/meta-llama-3-1/)\n",
     "- [Getting Started with Meta Llama](https://llama.meta.com/docs/get-started)\n",
     "- [Llama 3 repo](https://github.com/meta-llama/llama3)\n",
     "- [Llama 3 model card](https://github.com/meta-llama/llama3/blob/main/MODEL_CARD.md)\n",
@@ -1088,6 +1111,11 @@
     "- [Acceptable Use Policy](https://ai.meta.com/llama/use-policy/)\n",
     "\n"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": []
   }
  ],
  "metadata": {

+ 19 - 9
recipes/quickstart/Prompt_Engineering_with_Llama_3.ipynb

@@ -7,11 +7,11 @@
    "source": [
     "<a href=\"https://colab.research.google.com/github/meta-llama/llama-recipes/blob/main/recipes/quickstart/Prompt_Engineering_with_Llama_3.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
     "\n",
-    "# Prompt Engineering with Llama 3\n",
+    "# Prompt Engineering with Llama 3.1\n",
     "\n",
     "Prompt engineering is using natural language to produce a desired response from a large language model (LLM).\n",
     "\n",
-    "This interactive guide covers prompt engineering & best practices with Llama 3."
+    "This interactive guide covers prompt engineering & best practices with Llama 3.1."
    ]
   },
   {
@@ -45,6 +45,15 @@
     "\n",
     "Llama models come in varying parameter sizes. The smaller models are cheaper to deploy and run; the larger models are more capable.\n",
     "\n",
+    "#### Llama 3.1\n",
+    "1. `llama-3.1-8b` - base pretrained 8 billion parameter model\n",
+    "1. `llama-3.1-70b` - base pretrained 70 billion parameter model\n",
+    "1. `llama-3.1-405b` - base pretrained 405 billion parameter model\n",
+    "1. `llama-3.1-8b-instruct` - instruction fine-tuned 8 billion parameter model\n",
+    "1. `llama-3.1-70b-instruct` - instruction fine-tuned 70 billion parameter model\n",
+    "1. `llama-3.1-405b-instruct` - instruction fine-tuned 405 billion parameter model (flagship)\n",
+    "\n",
+    "\n",
     "#### Llama 3\n",
     "1. `llama-3-8b` - base pretrained 8 billion parameter model\n",
     "1. `llama-3-70b` - base pretrained 70 billion parameter model\n",
@@ -133,7 +142,7 @@
     "\n",
     "Tokens matter most when you consider API pricing and internal behavior (ex. hyperparameters).\n",
     "\n",
-    "Each model has a maximum context length that your prompt cannot exceed. That's 8K tokens for Llama 3, 4K for Llama 2, and 100K for Code Llama. \n"
+    "Each model has a maximum context length that your prompt cannot exceed. That's 128k tokens for Llama 3.1, 4K for Llama 2, and 100K for Code Llama.\n"
    ]
   },
   {
@@ -143,7 +152,7 @@
    "source": [
     "## Notebook Setup\n",
     "\n",
-    "The following APIs will be used to call LLMs throughout the guide. As an example, we'll call Llama 3 chat using [Grok](https://console.groq.com/playground?model=llama3-70b-8192).\n",
+    "The following APIs will be used to call LLMs throughout the guide. As an example, we'll call Llama 3.1 chat using [Grok](https://console.groq.com/playground?model=llama3-70b-8192).\n",
     "\n",
     "To install prerequisites run:"
    ]
@@ -171,8 +180,9 @@
     "# Get a free API key from https://console.groq.com/keys\n",
     "os.environ[\"GROQ_API_KEY\"] = \"YOUR_GROQ_API_KEY\"\n",
     "\n",
-    "LLAMA3_70B_INSTRUCT = \"llama3-70b-8192\"\n",
-    "LLAMA3_8B_INSTRUCT = \"llama3-8b-8192\"\n",
+    "LLAMA3_405B_INSTRUCT = \"llama-3.1-405b-reasoning\" # Note: Groq currently only gives access here to paying customers for 405B model\n",
+    "LLAMA3_70B_INSTRUCT = \"llama-3.1-70b-versatile\"\n",
+    "LLAMA3_8B_INSTRUCT = \"llama3.1-8b-instant\"\n",
     "\n",
     "DEFAULT_MODEL = LLAMA3_70B_INSTRUCT\n",
     "\n",
@@ -225,7 +235,7 @@
    "source": [
     "### Completion APIs\n",
     "\n",
-    "Let's try Llama 3!"
+    "Let's try Llama 3.1!"
    ]
   },
   {
@@ -488,7 +498,7 @@
     "\n",
     "Simply adding a phrase encouraging step-by-step thinking \"significantly improves the ability of large language models to perform complex reasoning\" ([Wei et al. (2022)](https://arxiv.org/abs/2201.11903)). This technique is called \"CoT\" or \"Chain-of-Thought\" prompting.\n",
     "\n",
-    "Llama 3 now reasons step-by-step naturally without the addition of the phrase. This section remains for completeness."
+    "Llama 3.1 now reasons step-by-step naturally without the addition of the phrase. This section remains for completeness."
    ]
   },
   {
@@ -704,7 +714,7 @@
    "source": [
     "### Limiting Extraneous Tokens\n",
     "\n",
-    "A common struggle with Llama 2 is getting output without extraneous tokens (ex. \"Sure! Here's more information on...\"), even if explicit instructions are given to Llama 2 to be concise and no preamble. Llama 3 can better follow instructions.\n",
+    "A common struggle with Llama 2 is getting output without extraneous tokens (ex. \"Sure! Here's more information on...\"), even if explicit instructions are given to Llama 2 to be concise and no preamble. Llama 3.x can better follow instructions.\n",
     "\n",
     "Check out this improvement that combines a role, rules and restrictions, explicit instructions, and an example:"
    ]

文件差異過大導致無法顯示
+ 9 - 9
recipes/quickstart/RAG/hello_llama_cloud.ipynb


+ 8 - 8
recipes/quickstart/Running_Llama3_Anywhere/Running_Llama_on_HF_transformers.ipynb

@@ -4,8 +4,8 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Running Meta Llama 3 on Google Colab using Hugging Face transformers library\n",
-    "This notebook goes over how you can set up and run Llama 3 using Hugging Face transformers library\n",
+    "## Running Meta Llama 3.1 on Google Colab using Hugging Face transformers library\n",
+    "This notebook goes over how you can set up and run Llama 3.1 using Hugging Face transformers library\n",
     "<a href=\"https://colab.research.google.com/github/meta-llama/llama-recipes/blob/main/recipes/quickstart/Running_Llama2_Anywhere/Running_Llama_on_HF_transformers.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
    ]
   },
@@ -14,7 +14,7 @@
    "metadata": {},
    "source": [
     "### Steps at a glance:\n",
-    "This demo showcases how to run the example with already converted Llama 3 weights on [Hugging Face](https://huggingface.co/meta-llama). Please Note: To use the downloads on Hugging Face, you must first request a download as shown in the steps below making sure that you are using the same email address as your Hugging Face account.\n",
+    "This demo showcases how to run the example with already converted Llama 3.1 weights on [Hugging Face](https://huggingface.co/meta-llama). Please Note: To use the downloads on Hugging Face, you must first request a download as shown in the steps below making sure that you are using the same email address as your Hugging Face account.\n",
     "\n",
     "To use already converted weights, start here:\n",
     "1. Request download of model weights from the Llama website\n",
@@ -45,7 +45,7 @@
     "Request download of model weights from the Llama website\n",
     "Before you can run the model locally, you will need to get the model weights. To get the model weights, visit the [Llama website](https://llama.meta.com/) and click on “download models”. \n",
     "\n",
-    "Fill  the required information, select the models “Meta Llama 3” and accept the terms & conditions. You will receive a URL in your email in a short time."
+    "Fill  the required information, select the models “Meta Llama 3.1” and accept the terms & conditions. You will receive a URL in your email in a short time."
    ]
   },
   {
@@ -94,7 +94,7 @@
    "source": [
     "Then, we will set the model variable to a specific model we’d like to use. In this demo, we will use the 8b chat model `meta-llama/Meta-Llama-3.1-8B-Instruct`. Using Meta models from Hugging Face requires you to\n",
     "\n",
-    "1. Accept Terms of Service for Meta Llama 3 on Meta [website](https://llama.meta.com/llama-downloads).\n",
+    "1. Accept Terms of Service for Meta Llama 3.1 on Meta [website](https://llama.meta.com/llama-downloads).\n",
     "2. Use the same email address from Step (1) to login into Hugging Face.\n",
     "\n",
     "Follow the instructions on this Hugging Face page to login from your [terminal](https://huggingface.co/docs/huggingface_hub/en/quick-start). "
@@ -208,7 +208,7 @@
     "#### 2. Clone the llama repo and get the weights\n",
     "Git clone the [Meta Llama 3 repo](https://github.com/meta-llama/llama3). Run the `download.sh` script and follow the instructions. This will download the model checkpoints and tokenizer.\n",
     "\n",
-    "This example demonstrates a Meta Llama 3 model with 8B-instruct parameters, but the steps we follow would be similar for other llama models, as well as for other parameter models."
+    "This example demonstrates a Meta Llama 3.1 model with 8B-instruct parameters, but the steps we follow would be similar for other llama models, as well as for other parameter models."
    ]
   },
   {
@@ -223,7 +223,7 @@
     "* `cd transformers`\n",
     "* `pip install -e .`\n",
     "* `pip install torch tiktoken blobfile accelerate`\n",
-    "* `python3 src/transformers/models/llama/convert_llama_weights_to_hf.py --input_dir ${path_to_meta_downloaded_model} --output_dir ${path_to_save_converted_hf_model} --model_size 8B --llama_version 3`"
+    "* `python3 src/transformers/models/llama/convert_llama_weights_to_hf.py --input_dir ${path_to_meta_downloaded_model} --output_dir ${path_to_save_converted_hf_model} --model_size 8B --llama_version 3.1`"
    ]
   },
   {
@@ -233,7 +233,7 @@
     "\n",
     "#### 4. Prepare the script\n",
     "Import the following necessary modules in your script: \n",
-    "* `AutoModel` is the Llama 2 model class\n",
+    "* `AutoModel` is the Llama 3 model class\n",
     "* `AutoTokenizer` prepares your prompt for the model to process\n",
     "* `pipeline` is an abstraction to generate model outputs"
    ]

+ 12 - 12
recipes/quickstart/Running_Llama3_Anywhere/Running_Llama_on_Mac_Windows_Linux.ipynb

@@ -5,7 +5,7 @@
    "metadata": {},
    "source": [
     "## Running Llama 3 on Mac, Windows or Linux\n",
-    "This notebook goes over how you can set up and run Llama 3 locally on a Mac, Windows or Linux using [Ollama](https://ollama.com/)."
+    "This notebook goes over how you can set up and run Llama 3.1 locally on a Mac, Windows or Linux using [Ollama](https://ollama.com/)."
    ]
   },
   {
@@ -14,9 +14,9 @@
    "source": [
     "### Steps at a glance:\n",
     "1. Download and install Ollama.\n",
-    "2. Download and test run Llama 3.\n",
-    "3. Use local Llama 3 via Python.\n",
-    "4. Use local Llama 3 via LangChain.\n"
+    "2. Download and test run Llama 3.1\n",
+    "3. Use local Llama 3.1 via Python.\n",
+    "4. Use local Llama 3.1 via LangChain.\n"
    ]
   },
   {
@@ -36,16 +36,16 @@
    "source": [
     "#### 2. Download and test run Llama 3\n",
     "\n",
-    "On a terminal or console, run `ollama pull llama3` to download the Llama 3 8b chat model, in the 4-bit quantized format with size about 4.7 GB.\n",
+    "On a terminal or console, run `ollama pull llama3.1` to download the Llama 3.1 8b chat model, in the 4-bit quantized format with size about 4.7 GB.\n",
     "\n",
-    "Run `ollama pull llama3:70b` to download the Llama 3 70b chat model, also in the 4-bit quantized format with size 39GB.\n",
+    "Run `ollama pull llama3.1:70b` to download the Llama 3.1 70b chat model, also in the 4-bit quantized format with size 39GB.\n",
     "\n",
-    "Then you can run `ollama run llama3` and ask Llama 3 questions such as \"who wrote the book godfather?\" or \"who wrote the book godfather? answer in one sentence.\" You can also try `ollama run llama3:70b`, but the inference speed will most likely be too slow - for example, on an Apple M1 Pro with 32GB RAM, it takes over 10 seconds to generate one token using Llama 3 70b chat (vs over 10 tokens per second with Llama 3 8b chat).\n",
+    "Then you can run `ollama run llama3.1` and ask Llama 3.1 questions such as \"who wrote the book godfather?\" or \"who wrote the book godfather? answer in one sentence.\" You can also try `ollama run llama3.1:70b`, but the inference speed will most likely be too slow - for example, on an Apple M1 Pro with 32GB RAM, it takes over 10 seconds to generate one token using Llama 3.1 70b chat (vs over 10 tokens per second with Llama 3.1 8b chat).\n",
     "\n",
-    "You can also run the following command to test Llama 3 8b chat:\n",
+    "You can also run the following command to test Llama 3.1 8b chat:\n",
     "```\n",
     " curl http://localhost:11434/api/chat -d '{\n",
-    "  \"model\": \"llama3\",\n",
+    "  \"model\": \"llama3.1\",\n",
     "  \"messages\": [\n",
     "    {\n",
     "      \"role\": \"user\",\n",
@@ -63,7 +63,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "#### 3. Use local Llama 3 via Python\n",
+    "#### 3. Use local Llama 3.1 via Python\n",
     "\n",
     "The Python code below is the port of the curl command above."
    ]
@@ -114,7 +114,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "#### 4. Use local Llama 3 via LangChain\n",
+    "#### 4. Use local Llama 3.1 via LangChain\n",
     "\n",
     "Code below use LangChain with Ollama to query Llama 3 running locally. For a more advanced example of using local Llama 3 with LangChain and agent-powered RAG, see [this](https://github.com/langchain-ai/langgraph/blob/main/examples/rag/langgraph_rag_agent_llama3_local.ipynb)."
    ]
@@ -136,7 +136,7 @@
    "source": [
     "from langchain_community.chat_models import ChatOllama\n",
     "\n",
-    "llm = ChatOllama(model=\"llama3\", temperature=0)\n",
+    "llm = ChatOllama(model=\"llama3.1\", temperature=0)\n",
     "response = llm.invoke(\"who wrote the book godfather?\")\n",
     "print(response.content)\n"
    ]

+ 54 - 12
recipes/responsible_ai/prompt_guard/inference.py

@@ -31,7 +31,45 @@ def load_model_and_tokenizer(model_name='meta-llama/Prompt-Guard-86M'):
     return model, tokenizer
 
 
-def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu'):
+def preprocess_text_for_promptguard(text: str, tokenizer) -> str:
+    """
+    Preprocess the text by removing spaces that break apart larger tokens.
+    This hotfixes a workaround to PromptGuard, where spaces can be inserted into a string
+    to allow the string to be classified as benign.
+
+    Args:
+        text (str): The input text to preprocess.
+        tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
+
+    Returns:
+        str: The preprocessed text.
+    """
+
+    try:
+        cleaned_text = ''
+        index_map = []
+        for i, char in enumerate(text):
+            if not char.isspace():
+                cleaned_text += char
+                index_map.append(i)
+        tokens = tokenizer.tokenize(cleaned_text)
+        result = []
+        last_end = 0
+        for token in tokens:
+            token_str = tokenizer.convert_tokens_to_string([token])
+            start = cleaned_text.index(token_str, last_end)
+            end = start + len(token_str)
+            original_start = index_map[start]
+            if original_start > 0 and text[original_start - 1].isspace():
+                result.append(' ')
+            result.append(token_str)
+            last_end = end
+        return ''.join(result)
+    except Exception:
+        return text
+
+
+def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
     """
     Evaluate the model on the given text with temperature-adjusted softmax.
     Note, as this is a DeBERTa model, the input text should have a maximum length of 512.
@@ -44,6 +82,8 @@ def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu
     Returns:
         torch.Tensor: The probability of each class adjusted by the temperature.
     """
+    if preprocess:
+        text = preprocess_text_for_promptguard(text, tokenizer)
     # Encode the text
     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
     inputs = inputs.to(device)
@@ -57,7 +97,7 @@ def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu
     return probabilities
 
 
-def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu'):
+def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
     """
     Evaluate the probability that a given string contains malicious jailbreak or prompt injection.
     Appropriate for filtering dialogue between a user and an LLM.
@@ -70,11 +110,11 @@ def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu'):
     Returns:
         float: The probability of the text containing malicious content.
     """
-    probabilities = get_class_probabilities(model, tokenizer, text, temperature, device)
+    probabilities = get_class_probabilities(model, tokenizer, text, temperature, device, preprocess)
     return probabilities[0, 2].item()
 
 
-def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device='cpu'):
+def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
     """
     Evaluate the probability that a given string contains any embedded instructions (malicious or benign).
     Appropriate for filtering third party inputs (e.g. web searches, tool outputs) into an LLM.
@@ -87,11 +127,11 @@ def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device
     Returns:
         float: The combined probability of the text containing malicious or embedded instructions.
     """
-    probabilities = get_class_probabilities(model, tokenizer, text, temperature, device)
+    probabilities = get_class_probabilities(model, tokenizer, text, temperature, device, preprocess)
     return (probabilities[0, 1] + probabilities[0, 2]).item()
 
 
-def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu'):
+def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu', preprocess=True):
     """
     Process a batch of texts and return their class probabilities.
     Args:
@@ -104,6 +144,8 @@ def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu'):
     Returns:
         torch.Tensor: A tensor containing the class probabilities for each text in the batch.
     """
+    if preprocess:
+        texts = [preprocess_text_for_promptguard(text, tokenizer) for text in texts]
     inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
     inputs = inputs.to(device)
     with torch.no_grad():
@@ -113,7 +155,7 @@ def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu'):
     return probabilities
 
 
-def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0, device='cpu', max_batch_size=16):
+def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
     """
     Compute scores for a list of texts, handling texts of arbitrary length by breaking them into chunks and processing in parallel.
     Args:
@@ -138,7 +180,7 @@ def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0
     for i in range(0, len(all_chunks), max_batch_size):
         batch_chunks = all_chunks[i:i+max_batch_size]
         batch_indices = text_indices[i:i+max_batch_size]
-        probabilities = process_text_batch(model, tokenizer, batch_chunks, temperature, device)
+        probabilities = process_text_batch(model, tokenizer, batch_chunks, temperature, device, preprocess)
         scores = probabilities[:, score_indices].sum(dim=1).tolist()
         
         for idx, score in zip(batch_indices, scores):
@@ -146,7 +188,7 @@ def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0
     return all_scores
 
 
-def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16):
+def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
     """
     Compute jailbreak scores for a list of texts.
     Args:
@@ -160,10 +202,10 @@ def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, dev
     Returns:
         list[float]: A list of jailbreak scores for each text.
     """
-    return get_scores_for_texts(model, tokenizer, texts, [2], temperature, device, max_batch_size)
+    return get_scores_for_texts(model, tokenizer, texts, [2], temperature, device, max_batch_size, preprocess)
 
 
-def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16):
+def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
     """
     Compute indirect injection scores for a list of texts.
     Args:
@@ -177,4 +219,4 @@ def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature
     Returns:
         list[float]: A list of indirect injection scores for each text.
     """
-    return get_scores_for_texts(model, tokenizer, texts, [1, 2], temperature, device, max_batch_size)
+    return get_scores_for_texts(model, tokenizer, texts, [1, 2], temperature, device, max_batch_size, preprocess)

+ 1 - 0
recipes/use_cases/multilingual/README.md

@@ -2,6 +2,7 @@
 Authored by : Sarvam team
 In this recipe, we will see how to add a new language to the Llama family of models. The steps are quite general and can be easily adapted to other models as well. Using this recipe, you should be able to replicate the findings of [OpenHathi](https://huggingface.co/sarvamai/OpenHathi-7B-Hi-v0.1-Base).
 Please read more about OpenHathi [here](https://web.archive.org/web/20240418103408/https://www.sarvam.ai/blog/announcing-openhathi-series)
+
 ## Data
 The original OpenHathi model uses a combination of [Sangraha](https://huggingface.co/datasets/ai4bharat/sangraha) and Wikipedia as its primary data sources. If the reader is interested in using these sources, they would also have to preprocess the data: clean, filter, and deduplicate. See [Setu](https://github.com/AI4Bharat/setu) for an easy way to do this at scale.
 

+ 2 - 1
src/llama_recipes/configs/datasets.py

@@ -9,6 +9,7 @@ class samsum_dataset:
     dataset: str =  "samsum_dataset"
     train_split: str = "train"
     test_split: str = "validation"
+    trust_remote_code: bool = False
 
 
 @dataclass
@@ -37,4 +38,4 @@ class custom_dataset:
 class llamaguard_toxicchat_dataset:
     dataset: str = "llamaguard_toxicchat_dataset"
     train_split: str = "train"
-    test_split: str = "test"
+    test_split: str = "test"

+ 3 - 1
src/llama_recipes/datasets/samsum_dataset.py

@@ -8,7 +8,9 @@ import datasets
 
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split):
-    dataset = datasets.load_dataset("samsum", split=split)
+    if not hasattr(dataset_config, "trust_remote_code") or not dataset_config.trust_remote_code:
+        raise ValueError("The repository for samsum contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/samsum. To activate `trust_remote_code` option use this config: --samsum_dataset.trust_remote_code=True")
+    dataset = datasets.load_dataset("samsum", split=split, trust_remote_code=dataset_config.trust_remote_code)
 
     prompt = (
         f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"

+ 1 - 0
src/llama_recipes/model_checkpointing/__init__.py

@@ -4,6 +4,7 @@
 from llama_recipes.model_checkpointing.checkpoint_handler import (
     load_model_checkpoint,
     save_model_checkpoint,
+    save_peft_checkpoint,
     load_optimizer_checkpoint,
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,

+ 10 - 1
src/llama_recipes/model_checkpointing/checkpoint_handler.py

@@ -26,6 +26,7 @@ from torch.distributed.checkpoint.default_planner import (
 )
 
 
+from torch.distributed.checkpoint.state_dict import get_model_state_dict, StateDictOptions
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 import torch.distributed._shard.checkpoint as dist_cp
 import torch.distributed as dist
@@ -264,4 +265,12 @@ def load_sharded_model_single_gpu(model,model_path):
     model.load_state_dict(state_dict["model"])
     
     print(f"Sharded state checkpoint loaded from {model_path}")
-    return model
+    return model
+
+def save_peft_checkpoint(model, model_path):
+    """save_pretrained peft model"""
+
+    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
+
+    state_dict = get_model_state_dict(model, options=options)
+    model.save_pretrained(model_path, state_dict=state_dict)

+ 2 - 2
src/llama_recipes/utils/train_utils.py

@@ -20,7 +20,7 @@ from transformers import LlamaTokenizer
 import json
 
 
-from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
+from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_peft_checkpoint
 from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 from accelerate.utils import is_xpu_available, is_ccl_available
@@ -235,7 +235,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             print(f"we are about to save the PEFT modules")
                     else:
                         print(f"we are about to save the PEFT modules")
-                    model.save_pretrained(train_config.output_dir)
+                    save_peft_checkpoint(model, train_config.output_dir)
                     if train_config.enable_fsdp:
                         if rank==0:
                             print(f"PEFT modules are saved in {train_config.output_dir} directory")