
Merging latest from main

Beto 1 year ago
parent
commit
b974c87035
45 changed files with 4763 additions and 326 deletions
  1. 26 6
      README.md
  2. 433 0
      demo_apps/HelloLlamaCloud.ipynb
  3. 347 0
      demo_apps/HelloLlamaLocal.ipynb
  4. 306 0
      demo_apps/LiveData.ipynb
  5. 120 0
      demo_apps/Llama2_Gradio.ipynb
  6. 101 0
      demo_apps/README.md
  7. 559 0
      demo_apps/StructuredLlama.ipynb
  8. 698 0
      demo_apps/VideoSummary.ipynb
  9. 38 0
      demo_apps/csv2db.py
  10. BIN
      demo_apps/llama2-gradio.png
  11. BIN
      demo_apps/llama2-streamlit.png
  12. BIN
      demo_apps/llama2-streamlit2.png
  13. BIN
      demo_apps/llama2.pdf
  14. 1294 0
      demo_apps/nba.txt
  15. 22 0
      demo_apps/streamlit_llama2.py
  16. 53 0
      demo_apps/txt2csv.py
  17. 15 3
      docs/Dataset.md
  18. 21 7
      docs/FAQ.md
  19. 1 0
      examples/Getting_to_know_Llama.ipynb
  20. 1 1
      examples/README.md
  21. 23 30
      examples/custom_dataset.py
  22. 3 6
      examples/quickstart.ipynb
  23. 29 1
      scripts/spellcheck_conf/wordlist.txt
  24. 0 2
      src/llama_recipes/configs/datasets.py
  25. 2 4
      src/llama_recipes/configs/training.py
  26. 2 0
      src/llama_recipes/data/__init__.py
  27. 34 0
      src/llama_recipes/data/concatenator.py
  28. 57 0
      src/llama_recipes/data/sampler.py
  29. 4 14
      src/llama_recipes/datasets/alpaca_dataset.py
  30. 13 18
      src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py
  31. 3 3
      src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb
  32. 19 13
      src/llama_recipes/datasets/samsum_dataset.py
  33. 0 66
      src/llama_recipes/datasets/utils.py
  34. 25 37
      src/llama_recipes/finetuning.py
  35. 49 11
      src/llama_recipes/utils/config_utils.py
  36. 6 6
      src/llama_recipes/utils/dataset_utils.py
  37. 37 33
      src/llama_recipes/utils/train_utils.py
  38. 18 0
      tests/conftest.py
  39. 40 13
      tests/datasets/test_custom_dataset.py
  40. 54 0
      tests/datasets/test_grammar_datasets.py
  41. 30 14
      tests/datasets/test_samsum_datasets.py
  42. 94 0
      tests/test_batching.py
  43. 82 33
      tests/test_finetuning.py
  44. 86 0
      tests/test_sampler.py
  45. 18 5
      tests/test_train_utils.py

+ 26 - 6
README.md

@@ -1,7 +1,13 @@
-# Llama 2 Fine-tuning / Inference Recipes and Examples
+# Llama 2 Fine-tuning / Inference Recipes, Examples and Demo Apps
+
+**[Update Oct. 20, 2023] We have just released a series of Llama 2 demo apps [here](./demo_apps). These apps show how to run Llama 2 locally and in the cloud to chat about data (PDF, DB, or live) and generate video summaries.**
+
 
 
 The 'llama-recipes' repository is a companion to the [Llama 2 model](https://github.com/facebookresearch/llama). The goal of this repository is to provide examples to quickly get started with fine-tuning for domain adaptation and how to run inference for the fine-tuned models. For ease of use, the examples use Hugging Face converted versions of the models. See steps for conversion of the model [here](#model-conversion-to-hugging-face).
 
 
+In addition, we also provide a number of demo apps that showcase Llama 2 usage along with other ecosystem solutions for running Llama 2 locally on your Mac and in the cloud.
+
+
 Llama 2 is a new technology that carries potential risks with use. Testing conducted to date has not — and could not — cover all scenarios. In order to help developers address these risks, we have created the [Responsible Use Guide](https://github.com/facebookresearch/llama/blob/main/Responsible-Use-Guide.pdf). More details can be found in our research paper as well. For downloading the models, follow the instructions on [Llama 2 repo](https://github.com/facebookresearch/llama).
 
 
 
 
@@ -13,8 +19,9 @@ Llama 2 is a new technology that carries potential risks with use. Testing condu
     - [Multi GPU One Node](#multiple-gpus-one-node)
     - [Multi GPU Multi Node](#multi-gpu-multi-node)
 4. [Inference](./docs/inference.md)
-5. [Repository Organization](#repository-organization)
-6. [License and Acceptable Use Policy](#license)
+5. [Demo Apps](#demo-apps)
+6. [Repository Organization](#repository-organization)
+7. [License and Acceptable Use Policy](#license)
 
 
 
 
 
 
@@ -130,7 +137,7 @@ Here we make use of Parameter Efficient Methods (PEFT) as described in the next
 
 
 ```bash
 
 
-torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model
+torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model
 
 
 ```
 
 
@@ -141,7 +148,7 @@ Here we use FSDP as discussed in the next section which can be used along with P
 Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
 
 
 ```bash
-torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
+torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
 ```
 
 
 ### Fine-tuning using FSDP Only
@@ -160,7 +167,7 @@ If you are interested in running full parameter fine-tuning on the 70B model, yo
 
 
 ```bash
 
 
-torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /patht_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --fsdp_config.pure_bf16 --model_name /patht_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
 
 
 ```
 
 
@@ -174,6 +181,17 @@ sbatch multi_node.slurm
 ```
 You can read more about our fine-tuning strategies [here](./docs/LLM_finetuning.md).
 
 
+# Demo Apps
+The [demo_apps](./demo_apps) folder contains a series of Llama 2-powered apps:
+* Quickstart Llama deployments and basic interactions with Llama
+  1. Run Llama on your Mac and ask Llama general questions
+  2. Run Llama on Google Colab
+  3. Run Llama in the cloud and ask Llama questions about unstructured data in a PDF
+
+* Specialized Llama use cases:
+  1. Ask Llama to summarize video content
+  2. Ask Llama questions about structured data in a DB
+  3. Ask Llama questions about live data on the web
 
 
 # Repository Organization
 This repository is organized in the following way:
@@ -184,6 +202,8 @@ This repository is organized in the following way:
 
 
 [datasets](src/llama_recipes/datasets/): Contains individual scripts for each dataset to download and process. Note: Use of any of the datasets should be in compliance with the dataset's underlying licenses (including but not limited to non-commercial uses)
 
 
+[demo_apps](./demo_apps): Contains a series of Llama 2-powered apps, from quickstart deployments to asking Llama questions about unstructured data, structured data, and live data, as well as video summarization.
+
 [examples](./examples/): Contains examples script for finetuning and inference of the Llama 2 model as well as how to use them safely.
 
 
 [inference](src/llama_recipes/inference/): Includes modules for inference for the fine-tuned models.

File diff suppressed because it is too large
+ 433 - 0
demo_apps/HelloLlamaCloud.ipynb


File diff suppressed because it is too large
+ 347 - 0
demo_apps/HelloLlamaLocal.ipynb


+ 306 - 0
demo_apps/LiveData.ipynb

@@ -0,0 +1,306 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "30eb1704-8d76-4bc9-9308-93243aeb69cb",
+   "metadata": {},
+   "source": [
+    "## This demo app shows:\n",
+    "* How to use LlamaIndex, an open source library to help you build custom data augmented LLM applications\n",
+    "* How to ask Llama questions about recent live data via the You.com live search API and LlamaIndex\n",
+    "\n",
+    "The LangChain package is used to facilitate the call to Llama2 hosted on Replicate\n",
+    "\n",
+    "**Note** We will be using Replicate to run the examples here. You will need to first sign in to Replicate with your GitHub account, then create a free API token [here](https://replicate.com/account/api-tokens) that you can use for a while.\n",
+    "After the free trial ends, you will need to enter billing info to continue to use Llama2 hosted on Replicate."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "68cf076e",
+   "metadata": {},
+   "source": [
+    "We start by installing the necessary packages:\n",
+    "- [langchain](https://python.langchain.com/docs/get_started/introduction) which provides RAG capabilities\n",
+    "- [llama-index](https://docs.llamaindex.ai/en/stable/) for data augmentation."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1d0005d6-e928-4d1a-981b-534a40e19e56",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install llama-index langchain"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "21fe3849",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# use ServiceContext to configure the LLM used and the custom embeddings \n",
+    "from llama_index import ServiceContext\n",
+    "\n",
+    "# VectorStoreIndex is used to index custom data \n",
+    "from llama_index import VectorStoreIndex\n",
+    "\n",
+    "from langchain.llms import Replicate"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "73e8e661",
+   "metadata": {},
+   "source": [
+    "Next we set up the Replicate token."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d9d76e33",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from getpass import getpass\n",
+    "import os\n",
+    "\n",
+    "REPLICATE_API_TOKEN = getpass()\n",
+    "os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f8ff812b",
+   "metadata": {},
+   "source": [
+    "In this example we will use the [YOU.com](https://you.com/) search engine to augment the LLM's responses.\n",
+    "To use the You.com Search API, you can email api@you.com to request an API key. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "75275628-5235-4b55-8033-601c76107528",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "YOUCOM_API_KEY = getpass()\n",
+    "os.environ[\"YOUCOM_API_KEY\"] = YOUCOM_API_KEY"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "cb210c7c",
+   "metadata": {},
+   "source": [
+    "We then call the Llama 2 model from replicate. \n",
+    "\n",
+    "We will use the llama 2 13b chat model. You can find more Llama 2 models by searching for them on the [Replicate model explore page](https://replicate.com/explore?query=llama).\n",
+    "You can add them here in the format: model_name/version"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c12fc2cb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# set llm to be using Llama2 hosted on Replicate\n",
+    "llama2_13b_chat = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n",
+    "\n",
+    "llm = Replicate(\n",
+    "    model=llama2_13b_chat,\n",
+    "    model_kwargs={\"temperature\": 0.01, \"top_p\": 1, \"max_new_tokens\":500}\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "476d72da",
+   "metadata": {},
+   "source": [
+    "Using the API key we set up earlier, we make a request to YOU.com for live data on a particular topic."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "effc9656-b18d-4d24-a80b-6066564a838b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "import requests\n",
+    "\n",
+    "query = \"Meta Connect\" # you can try other live data query about sports score, stock market and weather info \n",
+    "headers = {\"X-API-Key\": os.environ[\"YOUCOM_API_KEY\"]}\n",
+    "data = requests.get(\n",
+    "    f\"https://api.ydc-index.io/search?query={query}\",\n",
+    "    headers=headers,\n",
+    ").json()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8bed3baf-742e-473c-ada1-4459012a8a2c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# check the query result in JSON\n",
+    "import json\n",
+    "\n",
+    "print(json.dumps(data, indent=2))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b196e697",
+   "metadata": {},
+   "source": [
+    "We then use the [`JSONLoader`](https://llamahub.ai/l/file-json) to extract the text from the returned data. The `JSONLoader` gives us the ability to load the data into LlamaIndex.\n",
+    "In the next cell we show how to load the JSON result with key info stored as \"snippets\".\n",
+    "\n",
+    "However, you can also add the snippets in the query result to documents like below:\n",
+    "```python \n",
+    "from llama_index import Document\n",
+    "snippets = [snippet for hit in data[\"hits\"] for snippet in hit[\"snippets\"]]\n",
+    "documents = [Document(text=s) for s in snippets]\n",
+    "```\n",
+    "This can be handy if you just need to add a list of text strings to documents."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7c40e73f-ca13-4f4a-a753-e613df3d389e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# one way to load the JSON result with key info stored as \"snippets\"\n",
+    "from llama_index import download_loader\n",
+    "\n",
+    "JsonDataReader = download_loader(\"JsonDataReader\")\n",
+    "loader = JsonDataReader()\n",
+    "documents = loader.load_data([hit[\"snippets\"] for hit in data[\"hits\"]])\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8e5e3b4e",
+   "metadata": {},
+   "source": [
+    "With the data set up, we create a vector store for the data and a query engine for it.\n",
+    "\n",
+    "For our embeddings we will use `HuggingFaceEmbeddings` whose default embedding model is sentence-transformers/all-mpnet-base-v2. This model provides a good balance between speed and performance.\n",
+    "To change the default model, call `HuggingFaceEmbeddings(model_name=<another_embedding_model>)`. \n",
+    "\n",
+    "For more info see https://huggingface.co/blog/mteb. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a5de3080-2c4b-479c-baba-793b3bee36ed",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# use HuggingFace embeddings \n",
+    "from langchain.embeddings.huggingface import HuggingFaceEmbeddings\n",
+    "from llama_index import LangchainEmbedding\n",
+    "\n",
+    "\n",
+    "embeddings = LangchainEmbedding(HuggingFaceEmbeddings())\n",
+    "print(embeddings)\n",
+    "\n",
+    "# create a ServiceContext instance to use Llama2 and custom embeddings\n",
+    "service_context = ServiceContext.from_defaults(llm=llm, chunk_size=800, chunk_overlap=20, embed_model=embeddings)\n",
+    "\n",
+    "# create vector store index from the documents created above\n",
+    "index = VectorStoreIndex.from_documents(documents, service_context=service_context)\n",
+    "\n",
+    "# create query engine from the index\n",
+    "query_engine = index.as_query_engine(streaming=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2c4ea012",
+   "metadata": {},
+   "source": [
+    "We are now ready to ask Llama 2 a question about the live data using our query engine."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "de91a191-d0f2-498e-88dc-b2b43423e0e5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ask Llama2 a summary question about the search result\n",
+    "response = query_engine.query(\"give me a summary\")\n",
+    "response.print_response_stream()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "72814b20-06aa-4da8-b4dd-f0b0d74a2ea0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# more questions\n",
+    "query_engine.query(\"what products were announced\").print_response_stream()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a65bc037-a689-476d-b529-0059a27bc949",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "query_engine.query(\"tell me more about Meta AI assistant\").print_response_stream()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "16a56542",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "query_engine.query(\"what are Generative AI stickers\").print_response_stream()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.18"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}

+ 120 - 0
demo_apps/Llama2_Gradio.ipynb

@@ -0,0 +1,120 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "47a9adb3",
+   "metadata": {},
+   "source": [
+    "## This demo app shows how to query Llama 2 using the Gradio UI.\n",
+    "\n",
+    "Since we are using Replicate in this example, you will need to replace `<your replicate api token>` with your API token.\n",
+    "\n",
+    "To get the Replicate token: \n",
+    "\n",
+    "- You will need to first sign in to Replicate with your GitHub account\n",
+    "- Then create a free API token [here](https://replicate.com/account/api-tokens) that you can use for a while \n",
+    "\n",
+    "**Note** After the free trial ends, you will need to enter billing info to continue to use Llama2 hosted on Replicate.\n",
+    "\n",
+    "To run this example:\n",
+    "- Set up your Replicate API token and enter it in place of `<your replicate api token>`\n",
+    "- Run the notebook\n",
+    "- Enter your question and click Submit\n",
+    "\n",
+    "In the notebook, or in a browser at http://127.0.0.1:7860, you should see a UI with your answer."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "928041cc",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Init param `input` is deprecated, please use `model_kwargs` instead.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Running on local URL:  http://127.0.0.1:7860\n",
+      "\n",
+      "To create a public link, set `share=True` in `launch()`.\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": []
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from langchain.schema import AIMessage, HumanMessage\n",
+    "import gradio as gr\n",
+    "from langchain.llms import Replicate\n",
+    "import os\n",
+    "\n",
+    "os.environ[\"REPLICATE_API_TOKEN\"] = \"<your replicate api token>\"\n",
+    "\n",
+    "llama2_13b_chat = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n",
+    "\n",
+    "llm = Replicate(\n",
+    "    model=llama2_13b_chat,\n",
+    "    model_kwargs={\"temperature\": 0.01, \"top_p\": 1, \"max_new_tokens\":500}\n",
+    ")\n",
+    "\n",
+    "\n",
+    "def predict(message, history):\n",
+    "    history_langchain_format = []\n",
+    "    for human, ai in history:\n",
+    "        history_langchain_format.append(HumanMessage(content=human))\n",
+    "        history_langchain_format.append(AIMessage(content=ai))\n",
+    "    history_langchain_format.append(HumanMessage(content=message))\n",
+    "    gpt_response = llm(message) #history_langchain_format)\n",
+    "    return gpt_response#.content\n",
+    "\n",
+    "gr.ChatInterface(predict).launch()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.18"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}

File diff suppressed because it is too large
+ 101 - 0
demo_apps/README.md


File diff suppressed because it is too large
+ 559 - 0
demo_apps/StructuredLlama.ipynb


File diff suppressed because it is too large
+ 698 - 0
demo_apps/VideoSummary.ipynb


+ 38 - 0
demo_apps/csv2db.py

@@ -0,0 +1,38 @@
+import sqlite3
+import csv
+
+# Define the input CSV file and the SQLite database file
+input_csv = 'nba_roster.csv'
+database_file = 'nba_roster.db'
+
+# Connect to the SQLite database
+conn = sqlite3.connect(database_file)
+cursor = conn.cursor()
+
+# Create a table to store the data
+cursor.execute('''CREATE TABLE IF NOT EXISTS nba_roster (
+                    Team TEXT,
+                    NAME TEXT,
+                    Jersey TEXT,
+                    POS TEXT,
+                    AGE INT,
+                    HT TEXT,
+                    WT TEXT,
+                    COLLEGE TEXT,
+                    SALARY TEXT
+                )''')
+
+# Read data from the CSV file and insert it into the SQLite table
+with open(input_csv, 'r', newline='') as csvfile:
+    csv_reader = csv.reader(csvfile)
+    next(csv_reader)  # Skip the header row
+    
+    for row in csv_reader:
+        cursor.execute('INSERT INTO nba_roster VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', row)
+
+# Commit the changes and close the database connection
+conn.commit()
+conn.close()
+
+print(f'Data from {input_csv} has been successfully imported into {database_file}')
+
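A quick sanity check for the script above (a minimal sketch, assuming `csv2db.py` has already been run in the same directory so that `nba_roster.db` exists):

```python
import sqlite3

# Query the database produced by csv2db.py above (assumes nba_roster.db exists).
conn = sqlite3.connect('nba_roster.db')
cursor = conn.cursor()

# Count the imported rows and print a few sample players.
cursor.execute('SELECT COUNT(*) FROM nba_roster')
print('rows imported:', cursor.fetchone()[0])

cursor.execute('SELECT Team, NAME, POS, AGE FROM nba_roster LIMIT 5')
for row in cursor.fetchall():
    print(row)

conn.close()
```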

BIN
demo_apps/llama2-gradio.png


BIN
demo_apps/llama2-streamlit.png


BIN
demo_apps/llama2-streamlit2.png


BIN
demo_apps/llama2.pdf


File diff suppressed because it is too large
+ 1294 - 0
demo_apps/nba.txt


+ 22 - 0
demo_apps/streamlit_llama2.py

@@ -0,0 +1,22 @@
+import streamlit as st
+from langchain.llms import Replicate
+import os
+
+st.title("Llama2-powered Streamlit App")
+
+with st.sidebar:
+    os.environ["REPLICATE_API_TOKEN"] = "<your replicate api token>"
+
+def generate_response(input_text):
+    llama2_13b_chat = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d"
+
+    llm = Replicate(
+        model=llama2_13b_chat,
+        model_kwargs={"temperature": 0.01, "top_p": 1, "max_new_tokens":500}
+    )
+    st.info(llm(input_text))
+
+with st.form("my_form"):
+    text = st.text_area("Enter text:", "What is Generative AI?")
+    submitted = st.form_submit_button("Submit")
+    generate_response(text)
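To try the app above (assuming a valid Replicate API token has been pasted in place of `<your replicate api token>`), it can be launched with, for example, `streamlit run demo_apps/streamlit_llama2.py`, which serves the form-based UI in the browser.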

+ 53 - 0
demo_apps/txt2csv.py

@@ -0,0 +1,53 @@
+import csv
+
+# Define the input and output file names
+input_file = 'nba.txt'
+output_file = 'nba_roster.csv'
+
+# Initialize lists to store data
+roster_data = []
+current_team = None
+
+# Open the input file
+with open(input_file, 'r') as file:
+    for line in file:
+        # Remove leading and trailing whitespaces from the line
+        line = line.strip()
+        
+        # Check if the line starts with 'https', skip it
+        if line.startswith('https'):
+            continue
+        
+        # Check if the line contains the team name
+        if 'Roster' in line:
+            current_team = line.split(' Roster ')[0]
+        elif line and "NAME" not in line:  # Skip empty lines and header lines
+            # Split the line using tabs as the delimiter
+            player_info = line.split('\t')
+            
+            # Remove any numbers from the player's name and set Jersey accordingly
+            name = ''.join([c for c in player_info[0] if not c.isdigit()])
+            jersey = ''.join([c for c in player_info[0] if c.isdigit()])
+            
+            # If no number found, set Jersey to "NA"
+            if not jersey:
+                jersey = "NA"
+            
+            # Append the team name, name, and jersey to the player's data
+            player_info = [current_team, name, jersey] + player_info[1:]
+            
+            # Append the player's data to the roster_data list
+            roster_data.append(player_info)
+
+# Write the data to a CSV file
+with open(output_file, 'w', newline='') as csvfile:
+    writer = csv.writer(csvfile)
+    
+    # Write the header row
+    writer.writerow(['Team', 'NAME', 'Jersey', 'POS', 'AGE', 'HT', 'WT', 'COLLEGE', 'SALARY'])
+    
+    # Write the player data
+    writer.writerows(roster_data)
+
+print(f'Conversion completed. Data saved to {output_file}')
+

+ 15 - 3
docs/Dataset.md

@@ -7,6 +7,18 @@ The provided fine tuning script allows you to select between three datasets by p
 * [samsum_dataset](https://huggingface.co/datasets/samsum) contains about 16k messenger-like conversations with summaries.
 * [OpenAssistant/oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1/) contains about 88k messages from assistant-style conversations.
 
 
+## Batching Strategies
+Llama-recipes supports two strategies to batch samples together.
+The default setting is `packing`, which concatenates the tokenized samples into long sequences filling up the context length of the model.
+This is the most compute-efficient variant as it avoids any padding and all sequences have the same length.
+Samples at the boundary of the context length are truncated and the remainder of the cut sequence is used as the start of the next long sequence.
+
+If the amount of training data is small, this procedure might introduce a lot of noise into the training data, which can hurt the prediction performance of the fine-tuned model.
+Therefore, we also support a `padding` strategy which does not introduce the additional noise due to truncated sequences.
+This strategy tries to minimize the efficiency loss by batching samples of similar length together so only minimal padding is necessary.
+
+The batching strategy can be selected through the command line parameter `--batching_strategy [packing]/[padding]`.
+
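As a rough illustration of the `packing` strategy described above (a minimal sketch with made-up token ids, not the library implementation):

```python
# Sketch of packing: tokenized samples are concatenated into one stream and
# cut into chunks of the model context length; token ids here are made up.
samples = [[1, 5, 9], [1, 7, 7, 2], [1, 3, 8, 8, 4]]
context_length = 4

stream = [tok for sample in samples for tok in sample]      # concatenate all samples
packed = [stream[i:i + context_length] for i in range(0, len(stream), context_length)]
packed = [chunk for chunk in packed if len(chunk) == context_length]  # keep only full chunks

print(packed)  # [[1, 5, 9, 1], [7, 7, 2, 1], [3, 8, 8, 4]]
```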
 ## Using custom datasets
 
 
 The list of available datasets in llama-recipes is supposed to give users a quick start on training their Llama model.
@@ -23,9 +35,9 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
 For an example `get_custom_dataset` you can look at the provided datasets in llama_recipes.datasets or [examples/custom_dataset.py](../examples/custom_dataset.py).
 The `dataset_config` in the above signature will be an instance of llama_recipes.configs.dataset.custom_dataset with the modifications made through the command line.
 The split signals wether to return the training or validation dataset.
-The default function name is `get_custom_dataset` but this can be changes as described below.
+The default function name is `get_custom_dataset` but this can be changed as described below.
 
 
-In order to start a training with the custom dataset we need to set the `--dataset` as well as the `--custom_dataset.file` parameter. 
+In order to start a training with the custom dataset we need to set the `--dataset` as well as the `--custom_dataset.file` parameter.
 ```
 python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.file "examples/custom_dataset.py" [TRAINING PARAMETERS]
 ```
@@ -35,7 +47,7 @@ python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.f
 ```
 This will call the function `get_foo` instead of `get_custom_dataset` when retrieving the dataset.
 
 
-### Adding new dataset 
+### Adding new dataset
 Each dataset has a corresponding configuration (dataclass) in [configs/datasets.py](../src/llama_recipes/configs/datasets.py) which contains the dataset name, training/validation split names, as well as optional parameters like datafiles etc.
 
 
 Additionally, there is a preprocessing function for each dataset in the [datasets](../src/llama_recipes/datasets) folder.

File diff suppressed because it is too large
+ 21 - 7
docs/FAQ.md


File diff suppressed because it is too large
+ 1 - 0
examples/Getting_to_know_Llama.ipynb


+ 1 - 1
examples/README.md

@@ -10,7 +10,7 @@ After installing the llama-recipes package through [pip](../README.md#installati
 ```
 python -m llama_recipes.finetuning <parameters>
 
 
-python examnples/finetuning.py <parameters>
+python examples/finetuning.py <parameters>
 ```
 Please see [README.md](../README.md) for details.
 
 

+ 23 - 30
examples/custom_dataset.py

@@ -7,33 +7,27 @@ import copy
 import datasets
 import datasets
 import itertools
 import itertools
 
 
-from llama_recipes.datasets.utils import Concatenator
-
 
 
 B_INST, E_INST = "[INST]", "[/INST]"
 B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 
 
 def tokenize_dialog(dialog, tokenizer):
 def tokenize_dialog(dialog, tokenizer):
-    dialog_tokens = [
-            tokenizer(
-                f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-            )
-            for prompt, answer in zip(dialog[::2], dialog[1::2])
-        ]
-    if len(dialog) % 2:    
-        dialog_tokens += [tokenizer(
-            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
-        )]
-    
-    combined_tokens = {}  
-    for k in dialog_tokens[0].keys():
-        combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
-    return combined_tokens
+    prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
+    answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
+    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
+    #Add labels, convert prompt token to -100 in order to ignore in loss function
+    labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
+
+    combined_tokens = {
+        "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+        "labels": list(itertools.chain(*(t for t in labels_tokens))),
+    }
+
+    return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"]))
 
 
 
 
 def get_custom_dataset(dataset_config, tokenizer, split):
 def get_custom_dataset(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)
     dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)
-    
+
     dataset = dataset.map(lambda sample: {
     dataset = dataset.map(lambda sample: {
         "message_id": sample["message_id"],
         "message_id": sample["message_id"],
         "parent_id": sample["parent_id"],
         "parent_id": sample["parent_id"],
@@ -41,19 +35,19 @@ def get_custom_dataset(dataset_config, tokenizer, split):
         },
         },
         batched=True,
         batched=True,
         remove_columns=list(dataset.features),)
         remove_columns=list(dataset.features),)
-    
+
     nodes = {}
     nodes = {}
-    
+
     messages = {}
     messages = {}
     root_ids = []
     root_ids = []
-    
+
     for data in dataset:
     for data in dataset:
         if data["parent_id"]:
         if data["parent_id"]:
             nodes[data["parent_id"]] = nodes.get(data["parent_id"], []) + [data["message_id"]]
             nodes[data["parent_id"]] = nodes.get(data["parent_id"], []) + [data["message_id"]]
         else:
         else:
             root_ids.append(data["message_id"])
             root_ids.append(data["message_id"])
         messages[data["message_id"]]=data["text"]
         messages[data["message_id"]]=data["text"]
-           
+
     def follow(thread, current_id):
     def follow(thread, current_id):
         thread = copy.copy(thread) + [messages[current_id]]
         thread = copy.copy(thread) + [messages[current_id]]
         if current_id in nodes:
         if current_id in nodes:
@@ -63,18 +57,18 @@ def get_custom_dataset(dataset_config, tokenizer, split):
             return new_threads
             return new_threads
         else:
         else:
             return [thread]
             return [thread]
-        
+
     def get_threads_from_root(root_id):
     def get_threads_from_root(root_id):
         all_threads = []
         all_threads = []
         thread = [messages[root_id]]
         thread = [messages[root_id]]
         for cid in nodes[root_id]:
         for cid in nodes[root_id]:
             all_threads += follow(thread, cid)
             all_threads += follow(thread, cid)
         return all_threads
         return all_threads
-            
+
     dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
     dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
     dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)
     dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)
-    
+
     def to_dialog(thread):
     def to_dialog(thread):
         dialog = []
         dialog = []
         for i, content in enumerate(thread):
         for i, content in enumerate(thread):
@@ -83,9 +77,8 @@ def get_custom_dataset(dataset_config, tokenizer, split):
                 "content": content,
                 "content": content,
             })
             })
         return {"dialog": dialog}
         return {"dialog": dialog}
-            
+
     dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))
-    dataset = dataset.map(Concatenator(), batched=True)
-    
-    return dataset
+
+    return dataset

+ 3 - 6
examples/quickstart.ipynb

@@ -32,7 +32,7 @@
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
     "# %%bash\n",
     "# %%bash\n",
-    "# pip install transformers datasets accelerate sentencepiece protobuf==3.20 py7zr scipy peft bitsandbytes fire torch_tb_profiler ipywidgets\n",
+    "# pip install llama-recipes transformers datasets accelerate sentencepiece protobuf==3.20 py7zr scipy peft bitsandbytes fire torch_tb_profiler ipywidgets\n",
     "# TRANSFORM=`python -c \"import transformers;print('/'.join(transformers.__file__.split('/')[:-1])+'/models/llama/convert_llama_weights_to_hf.py')\"`\n",
     "# TRANSFORM=`python -c \"import transformers;print('/'.join(transformers.__file__.split('/')[:-1])+'/models/llama/convert_llama_weights_to_hf.py')\"`\n",
     "# python ${TRANSFORM} --input_dir models --model_size 7B --output_dir models_hf/7B"
     "# python ${TRANSFORM} --input_dir models --model_size 7B --output_dir models_hf/7B"
    ]
    ]
@@ -130,11 +130,8 @@
     }
     }
    ],
    ],
    "source": [
    "source": [
-    "from pathlib import Path\n",
-    "import os\n",
-    "import sys\n",
-    "from utils.dataset_utils import get_preprocessed_dataset\n",
-    "from configs.datasets import samsum_dataset\n",
+    "from llama_recipes.utils.dataset_utils import get_preprocessed_dataset\n",
+    "from llama_recipes.configs.datasets import samsum_dataset\n",
     "\n",
     "\n",
     "train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')"
     "train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')"
    ]
    ]

+ 29 - 1
scripts/spellcheck_conf/wordlist.txt

@@ -1156,4 +1156,32 @@ Autocast
 FN
 FN
 GBs
 GBs
 MLP
 MLP
-learnable
+learnable
+tokenized
+Colab
+GenAI
+Gradio
+HelloLlama
+HelloLlamaCloud
+HelloLlamaLocal
+LLM's
+LangChain
+LangChain's
+LiveData
+LlamaIndex
+MBP
+MLC
+Replicate's
+StructuredLlama
+VideoSummary
+cpp
+envinronment
+ggml
+gguf
+gradio
+minnutes
+pdf
+quantized
+serarch
+streamlit
+

+ 0 - 2
src/llama_recipes/configs/datasets.py

@@ -9,7 +9,6 @@ class samsum_dataset:
     dataset: str =  "samsum_dataset"
     dataset: str =  "samsum_dataset"
     train_split: str = "train"
     train_split: str = "train"
     test_split: str = "validation"
     test_split: str = "validation"
-    input_length: int = 2048
     
     
     
     
 @dataclass
 @dataclass
@@ -17,7 +16,6 @@ class grammar_dataset:
     dataset: str = "grammar_dataset"
     dataset: str = "grammar_dataset"
     train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv" 
     train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv" 
     test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv"
     test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv"
-    input_length: int = 2048
 
 
     
     
 @dataclass
 @dataclass

+ 2 - 4
src/llama_recipes/configs/training.py

@@ -11,6 +11,8 @@ class train_config:
     low_cpu_fsdp: bool=False
     low_cpu_fsdp: bool=False
     run_validation: bool=True
     run_validation: bool=True
     batch_size_training: int=4
     batch_size_training: int=4
+    batching_strategy: str="packing" #alternative: padding
+    context_length: int=4096
     gradient_accumulation_steps: int=1
     gradient_accumulation_steps: int=1
     num_epochs: int=3
     num_epochs: int=3
     num_workers_dataloader: int=1
     num_workers_dataloader: int=1
@@ -34,7 +36,3 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
-
-    
-    
-    
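With the two new config fields added above, the batching behaviour can be chosen at launch time, for example (illustrative command, other parameters omitted): `python -m llama_recipes.finetuning --dataset samsum_dataset --batching_strategy packing --context_length 2048 [TRAINING PARAMETERS]`, mirroring the `--batching_strategy [packing]/[padding]` option documented in docs/Dataset.md above.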

+ 2 - 0
src/llama_recipes/data/__init__.py

@@ -0,0 +1,2 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

+ 34 - 0
src/llama_recipes/data/concatenator.py

@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from tqdm import tqdm
+from itertools import chain
+
+from torch.utils.data import Dataset
+
+
+class ConcatDataset(Dataset):
+    def __init__(self, dataset, chunk_size=4096):
+        self.dataset = dataset
+        self.chunk_size = chunk_size
+
+        self.samples = []
+
+        buffer = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": [],
+            }
+
+        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
+            buffer = {k: v + sample[k] for k,v in buffer.items()}
+
+            while len(next(iter(buffer.values()))) > self.chunk_size:
+                self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
+                buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
+
+    def __getitem__(self, idx):
+        return self.samples[idx]
+
+    def __len__(self):
+        return len(self.samples)
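A minimal usage sketch for the new `ConcatDataset` (illustrative only; it assumes a tiny, already tokenized dataset and the module path introduced by this commit):

```python
from llama_recipes.data.concatenator import ConcatDataset

# Tiny pre-tokenized dataset; token ids are made up for illustration.
tokenized_samples = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
    {"input_ids": [4, 5, 6, 7], "attention_mask": [1, 1, 1, 1], "labels": [4, 5, 6, 7]},
]

packed = ConcatDataset(tokenized_samples, chunk_size=4)
print(len(packed), packed[0]["input_ids"])  # -> 1 [1, 2, 3, 4]; the leftover 3 tokens never form a full chunk
```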

+ 57 - 0
src/llama_recipes/data/sampler.py

@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import random
+from itertools import islice
+
+import numpy as np
+import torch
+
+
+class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool=True) -> None:
+        if isinstance(next(iter(data_source)), dict):
+            first_key = next(iter(next(iter(data_source)).keys()))
+            self.lengths = [len(d[first_key]) for d in data_source]
+        else:
+            self.lengths = [len(d) for d in data_source]
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        ids = np.argsort(self.lengths)
+        if self.drop_last:
+            ids = ids[:len(ids) // self.batch_size * self.batch_size]
+
+        batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)]
+
+        if self.shuffle:
+            random.shuffle(batches)
+
+        for b in batches:
+            yield b
+
+    def __len__(self):
+        if self.drop_last:
+            return len(self.lengths) // self.batch_size
+        else:
+            return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
+
+
+class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None:
+        random.seed(seed)
+        self.batch_sampler = LengthBasedBatchSampler(
+            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
+            )
+        self.num_replicas = num_replicas
+        self.rank = rank
+        
+    def __iter__(self):
+        max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
+        return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
+         
+    def __len__(self):
+        return len(self.batch_sampler) // self.num_replicas
+            
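A minimal usage sketch for `LengthBasedBatchSampler` above (illustrative only; toy list data rather than a real tokenized dataset, and assuming the new `llama_recipes.data.sampler` module path):

```python
from torch.utils.data import DataLoader
from llama_recipes.data.sampler import LengthBasedBatchSampler

# Toy variable-length samples; in practice these are dicts of token id lists.
data = [[0] * n for n in (5, 12, 7, 3, 9, 11)]

sampler = LengthBasedBatchSampler(data, batch_size=2, drop_last=True, shuffle=False)
loader = DataLoader(data, batch_sampler=sampler, collate_fn=lambda batch: batch)

for batch in loader:
    print([len(x) for x in batch])  # samples of similar length land in the same batch
```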

+ 4 - 14
src/llama_recipes/datasets/alpaca_dataset.py

@@ -24,17 +24,14 @@ PROMPT_DICT = {
 }
 }
 
 
 class InstructionDataset(Dataset):
 class InstructionDataset(Dataset):
-    def __init__(self, dataset_config, tokenizer, partition="train", max_words=30):
+    def __init__(self, dataset_config, tokenizer, partition="train"):
         self.ann = json.load(open(dataset_config.data_path))
         self.ann = json.load(open(dataset_config.data_path))
         if partition == "train":
         if partition == "train":
             self.ann = self.ann
             self.ann = self.ann
         else:
         else:
             self.ann = self.ann[:200]
             self.ann = self.ann[:200]
 
 
-        self.max_words = max_words
-        # tokenizer = Tokenizer(model_path=model_path + "./tokenizer.model")
         self.tokenizer = tokenizer
         self.tokenizer = tokenizer
-        # self.tokenizer1 = tokenizer
 
 
     def __len__(self):
     def __len__(self):
         return len(self.ann)
         return len(self.ann)
@@ -57,22 +54,15 @@ class InstructionDataset(Dataset):
         example = torch.tensor(
         example = torch.tensor(
             example, dtype=torch.int64
             example, dtype=torch.int64
         )
         )
-        padding = self.max_words - example.shape[0]
-        if padding > 0:
-            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
-        elif padding < 0:
-            example = example[: self.max_words]
         labels = copy.deepcopy(example)
         labels = copy.deepcopy(example)
         labels[: len(prompt)] = -1
         labels[: len(prompt)] = -1
         example_mask = example.ge(0)
         example_mask = example.ge(0)
         label_mask = labels.ge(0)
         label_mask = labels.ge(0)
         example[~example_mask] = 0
         example[~example_mask] = 0
         labels[~label_mask] = IGNORE_INDEX
         labels[~label_mask] = IGNORE_INDEX
-        example_mask = example_mask.float()
-        label_mask = label_mask.float()
 
 
         return {
         return {
-            "input_ids": example,
-            "labels": labels,
-            "attention_mask":example_mask,
+            "input_ids": example.tolist(),
+            "labels": labels.tolist(),
+            "attention_mask":example_mask.tolist(),
         }
         }

+ 13 - 18
src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py

@@ -10,8 +10,6 @@ from pathlib import Path
 
 
 from torch.utils.data import Dataset
 from torch.utils.data import Dataset
 
 
-from llama_recipes.datasets.utils import ConcatDataset
-
 
 
 class grammar(Dataset):
 class grammar(Dataset):
     def __init__(
     def __init__(
@@ -48,24 +46,22 @@ class grammar(Dataset):
 
 
         input_ = example_batch["input"]
         input_ = example_batch["input"]
         target_ = example_batch["target"]
         target_ = example_batch["target"]
-        
-        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}"
-        sample = self.tokenizer(prompt)
-        
-        return sample
-
-    def __getitem__(self, index):
-        sample = self.convert_to_features(self.dataset["train"][index])
-        source_ids = sample["input_ids"]
 
 
-        src_mask = sample["attention_mask"]
+        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: "
+        prompt_ids = self.tokenizer.encode(self.tokenizer.bos_token + prompt, add_special_tokens=False)
+        label_ids = self.tokenizer.encode(target_ + self.tokenizer.eos_token, add_special_tokens=False)
 
 
-        return {
-            "input_ids": source_ids,
-            "attention_mask": src_mask,
-            "labels": source_ids.copy(),
+        sample = {
+            "input_ids": prompt_ids + label_ids,
+            "attention_mask": [1] * len(prompt_ids + label_ids),
+            "labels": [-100] * len(prompt_ids) + label_ids
         }
         }
 
 
+        return sample
+
+    def __getitem__(self, index):
+        return self.convert_to_features(self.dataset["train"][int(index)])
+
 
 
 def get_dataset(
 def get_dataset(
     dataset_config, tokenizer, csv_name=None
     dataset_config, tokenizer, csv_name=None
@@ -80,6 +76,5 @@ def get_dataset(
         tokenizer=tokenizer,
         tokenizer=tokenizer,
         csv_name=csv_name,
         csv_name=csv_name,
     )
     )
-    
-    return ConcatDataset(dataset, chunk_size=dataset_config.input_length)
 
 
+    return dataset

+ 3 - 3
src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb

@@ -35,10 +35,10 @@
     "  (\" '\", \"'\"),\n",
     "  (\" '\", \"'\"),\n",
     "  (\" ?\", \"?\"),\n",
     "  (\" ?\", \"?\"),\n",
     "  (\" !\", \"!\"),\n",
     "  (\" !\", \"!\"),\n",
-    "  (\" :\", \"!\"),\n",
-    "  (\" ;\", \"!\"),\n",
+    "  (\" :\", \":\"),\n",
+    "  (\" ;\", \";\"),\n",
     "  (\" n't\", \"n't\"),\n",
     "  (\" n't\", \"n't\"),\n",
-    "  (\" v\", \"n't\"),\n",
+    "  (\" v\", \"v\"),\n",
     "  (\"2 0 0 6\", \"2006\"),\n",
     "  (\"2 0 0 6\", \"2006\"),\n",
     "  (\"5 5\", \"55\"),\n",
     "  (\"5 5\", \"55\"),\n",
     "  (\"4 0 0\", \"400\"),\n",
     "  (\"4 0 0\", \"400\"),\n",

+ 19 - 13
src/llama_recipes/datasets/samsum_dataset.py

@@ -3,31 +3,37 @@
 
 
 # For dataset details visit: https://huggingface.co/datasets/samsum
 # For dataset details visit: https://huggingface.co/datasets/samsum
 
 
+import copy
 import datasets
 import datasets
 
 
-from llama_recipes.datasets.utils import Concatenator
 
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split):
 def get_preprocessed_samsum(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("samsum", split=split)
     dataset = datasets.load_dataset("samsum", split=split)
 
 
     prompt = (
     prompt = (
-        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
+        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
     )
     )
 
 
     def apply_prompt_template(sample):
     def apply_prompt_template(sample):
         return {
         return {
-            "text": prompt.format(
-                dialog=sample["dialogue"],
-                summary=sample["summary"],
-                eos_token=tokenizer.eos_token,
-            )
+            "prompt": prompt.format(dialog=sample["dialogue"]),
+            "summary": sample["summary"],
         }
         }
 
 
     dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
     dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
-        
-    dataset = dataset.map(
-        lambda sample: tokenizer(sample["text"]),
-        batched=True,
-        remove_columns=list(dataset.features),
-    ).map(Concatenator(), batched=True)
+
+    def tokenize_add_label(sample):
+        prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
+        summary = tokenizer.encode(sample["summary"] +  tokenizer.eos_token, add_special_tokens=False)
+
+        sample = {
+            "input_ids": prompt + summary,
+            "attention_mask" : [1] * (len(prompt) + len(summary)),
+            "labels": [-100] * len(prompt) + summary,
+            }
+
+        return sample
+
+    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
+
     return dataset
     return dataset
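The effect of the `tokenize_add_label` function above, sketched with made-up token ids: only the summary tokens contribute to the loss because the prompt portion of `labels` is masked with -100.

```python
# Made-up token ids, for illustration of the label masking above.
prompt_ids  = [1, 6991, 1475, 263]    # tokenized "<s> Summarize this dialog: ..."
summary_ids = [450, 6483, 2]          # tokenized summary + eos

sample = {
    "input_ids": prompt_ids + summary_ids,
    "attention_mask": [1] * (len(prompt_ids) + len(summary_ids)),
    "labels": [-100] * len(prompt_ids) + summary_ids,
}
print(sample["labels"])  # [-100, -100, -100, -100, 450, 6483, 2]
```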

+ 0 - 66
src/llama_recipes/datasets/utils.py

@@ -1,66 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-from tqdm import tqdm
-from itertools import chain
-
-from torch.utils.data import Dataset
-
-class Concatenator(object):
-    def __init__(self, chunk_size=2048):
-        self.chunk_size=chunk_size
-        self.residual = {"input_ids": [], "attention_mask": []}
-        
-    def __call__(self, batch):
-        concatenated_samples = {
-            k: v + list(chain(*batch[k])) for k, v in self.residual.items()
-        }
-
-        total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]])
-
-        if total_length >= self.chunk_size:
-            chunk_num = total_length // self.chunk_size
-            result = {
-                k: [
-                    v[i : i + self.chunk_size]
-                    for i in range(0, chunk_num * self.chunk_size, self.chunk_size)
-                ]
-                for k, v in concatenated_samples.items()
-            }
-            self.residual = {
-                k: v[(chunk_num * self.chunk_size) :]
-                for k, v in concatenated_samples.items()
-            }
-        else:
-            result = concatenated_samples
-            self.residual = {k: [] for k in concatenated_samples.keys()}
-
-        result["labels"] = result["input_ids"].copy()
-
-        return result
-
-class ConcatDataset(Dataset):
-    def __init__(self, dataset, chunk_size=4096):
-        self.dataset = dataset
-        self.chunk_size = chunk_size
-        
-        self.samples = []
-        
-        buffer = {
-            "input_ids": [],
-            "attention_mask": [],
-            "labels": [],
-            }
-        
-        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
-            buffer = {k: v + sample[k] for k,v in buffer.items()}
-            
-            while len(next(iter(buffer.values()))) > self.chunk_size:
-                self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
-                buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
-                
-    def __getitem__(self, idx):
-        return self.samples[idx]
-    
-    def __len__(self):
-        return len(self.samples)

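The `Concatenator` and `ConcatDataset` removed here are superseded by the new `ConcatDataset` in `src/llama_recipes/data/concatenator.py` (imported by `finetuning.py` below). That module is not shown in this diff; assuming it keeps the same packing behaviour, the idea is simply to concatenate tokenized samples and slice the stream into fixed-size chunks:

```python
# Rough sketch of fixed-length packing (behaviour of the ConcatDataset deleted above;
# the replacement in llama_recipes/data/concatenator.py is assumed to work the same way).
def pack(token_lists, chunk_size):
    buffer = []
    for ids in token_lists:
        buffer.extend(ids)
        while len(buffer) >= chunk_size:
            yield buffer[:chunk_size]
            buffer = buffer[chunk_size:]

chunks = list(pack([[1, 2, 3], [4, 5], [6, 7, 8, 9]], chunk_size=4))
assert chunks == [[1, 2, 3, 4], [5, 6, 7, 8]]  # the trailing token 9 stays in the buffer
```
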
+ 25 - 37
src/llama_recipes/finetuning.py

@@ -5,8 +5,8 @@ import os
 from pkg_resources import packaging
 from pkg_resources import packaging
 
 
 import fire
 import fire
+import random
 import torch
 import torch
-import torch.distributed as dist
 import torch.optim as optim
 import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
 from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
 from torch.distributed.fsdp import (
@@ -14,16 +14,16 @@ from torch.distributed.fsdp import (
 )
 )
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from torch.optim.lr_scheduler import StepLR
-from torch.utils.data import DistributedSampler
 from transformers import (
 from transformers import (
     LlamaForCausalLM,
     LlamaForCausalLM,
     LlamaTokenizer,
     LlamaTokenizer,
     LlamaConfig,
     LlamaConfig,
-    default_data_collator,
 )
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
 
-from llama_recipes.configs import fsdp_config, train_config
+from llama_recipes.configs import fsdp_config as FSDP_CONFIG
+from llama_recipes.configs import train_config as TRAIN_CONFIG
+from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
 
 from llama_recipes.utils import fsdp_auto_wrap_policy
 from llama_recipes.utils import fsdp_auto_wrap_policy
@@ -31,6 +31,7 @@ from llama_recipes.utils.config_utils import (
     update_config,
     update_config,
     generate_peft_config,
     generate_peft_config,
     generate_dataset_config,
     generate_dataset_config,
+    get_dataloader_kwargs,
 )
 )
 from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
 from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
 
 
@@ -47,11 +48,13 @@ from llama_recipes.utils.train_utils import (
 
 
 def main(**kwargs):
 def main(**kwargs):
     # Update the configuration for the training and sharding process
     # Update the configuration for the training and sharding process
+    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
     update_config((train_config, fsdp_config), **kwargs)
 
 
     # Set the seeds for reproducibility
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(train_config.seed)
     torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
+    random.seed(train_config.seed)
 
 
     if train_config.enable_fsdp:
     if train_config.enable_fsdp:
         setup()
         setup()
@@ -102,14 +105,19 @@ def main(**kwargs):
     if train_config.enable_fsdp and train_config.use_fast_kernels:
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
         """
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
-        using of Flash Attention or Xformer memory-efficient kernels 
+        using of Flash Attention or Xformer memory-efficient kernels
         based on the hardware being used. This would speed up fine-tuning.
         based on the hardware being used. This would speed up fine-tuning.
         """
         """
         try:
         try:
             from optimum.bettertransformer import BetterTransformer
             from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model) 
+            model = BetterTransformer.transform(model)
         except ImportError:
         except ImportError:
             print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
             print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
+
+    # Load the tokenizer and add special tokens
+    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
 
     # Prepare the model for int8 training if quantization is enabled
     # Prepare the model for int8 training if quantization is enabled
@@ -120,14 +128,6 @@ def main(**kwargs):
     if train_config.enable_fsdp and fsdp_config.pure_bf16:
     if train_config.enable_fsdp and fsdp_config.pure_bf16:
         model.to(torch.bfloat16)
         model.to(torch.bfloat16)
 
 
-    # Load the tokenizer and add special tokens
-    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
-    tokenizer.add_special_tokens(
-            {
-
-                "pad_token": "<PAD>",
-            }
-        )
     if train_config.use_peft:
     if train_config.use_peft:
         peft_config = generate_peft_config(train_config, kwargs)
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model = get_peft_model(model, peft_config)
@@ -179,43 +179,31 @@ def main(**kwargs):
     if not train_config.enable_fsdp or rank == 0:
     if not train_config.enable_fsdp or rank == 0:
             print(f"--> Validation Set Length = {len(dataset_val)}")
             print(f"--> Validation Set Length = {len(dataset_val)}")
 
 
-    train_sampler = None
-    val_sampler = None
-    if train_config.enable_fsdp:
-        train_sampler = DistributedSampler(
-            dataset_train,
-            rank=dist.get_rank(),
-            num_replicas=dist.get_world_size(),
-            shuffle=True,
-        )
-        if train_config.run_validation:
-            val_sampler = DistributedSampler(
-                dataset_val,
-                rank=dist.get_rank(),
-                num_replicas=dist.get_world_size(),
-            )
+    if train_config.batching_strategy == "packing":
+        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
 
 
     # Create DataLoaders for the training and validation dataset
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
         dataset_train,
-        batch_size=train_config.batch_size_training,
         num_workers=train_config.num_workers_dataloader,
         num_workers=train_config.num_workers_dataloader,
         pin_memory=True,
         pin_memory=True,
-        sampler=train_sampler if train_sampler else None,
-        drop_last=True,
-        collate_fn=default_data_collator,
+        **train_dl_kwargs,
     )
     )
 
 
     eval_dataloader = None
     eval_dataloader = None
     if train_config.run_validation:
     if train_config.run_validation:
+        if train_config.batching_strategy == "packing":
+            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+
         eval_dataloader = torch.utils.data.DataLoader(
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
             dataset_val,
-            batch_size=train_config.val_batch_size,
             num_workers=train_config.num_workers_dataloader,
             num_workers=train_config.num_workers_dataloader,
             pin_memory=True,
             pin_memory=True,
-            sampler=val_sampler if val_sampler else None,
-            drop_last=True,
-            collate_fn=default_data_collator,
+            **val_dl_kwargs,
         )
         )
 
 
     # Initialize the optimizer and learning rate scheduler
     # Initialize the optimizer and learning rate scheduler

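With the sampler and collator choice moved into `get_dataloader_kwargs`, the new `batching_strategy` and `context_length` options are ordinary training-config fields. A hedged sketch of how they are passed (mirroring the tests further down; the model path is a placeholder for a converted Hugging Face checkpoint):

```python
from llama_recipes.finetuning import main

main(
    model_name="meta-llama/Llama-2-7b-hf",  # placeholder path
    dataset="samsum_dataset",
    batching_strategy="packing",            # or "padding" to pad each batch instead
    context_length=4096,                    # chunk size used by the packing strategy
    batch_size_training=8,
    val_batch_size=1,
    use_peft=True,
    peft_method="lora",
)
```
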
+ 49 - 11
src/llama_recipes/utils/config_utils.py

@@ -3,13 +3,19 @@
 
 
 import inspect
 import inspect
 from dataclasses import asdict
 from dataclasses import asdict
+
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler
 from peft import (
 from peft import (
     LoraConfig,
     LoraConfig,
     AdaptionPromptConfig,
     AdaptionPromptConfig,
     PrefixTuningConfig,
     PrefixTuningConfig,
 )
 )
+from transformers import default_data_collator
+from transformers.data import DataCollatorForSeq2Seq
 
 
 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
+from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
 from llama_recipes.utils.dataset_utils import DATASET_PREPROC
 from llama_recipes.utils.dataset_utils import DATASET_PREPROC
 
 
 
 
@@ -32,31 +38,63 @@ def update_config(config, **kwargs):
                         print(f"Warning: {config_name} does not accept parameter: {k}")
                         print(f"Warning: {config_name} does not accept parameter: {k}")
             elif isinstance(config, train_config):
             elif isinstance(config, train_config):
                 print(f"Warning: unknown parameter {k}")
                 print(f"Warning: unknown parameter {k}")
-                        
-                        
+
+
 def generate_peft_config(train_config, kwargs):
 def generate_peft_config(train_config, kwargs):
     configs = (lora_config, llama_adapter_config, prefix_config)
     configs = (lora_config, llama_adapter_config, prefix_config)
     peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
     peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
     names = tuple(c.__name__.rstrip("_config") for c in configs)
     names = tuple(c.__name__.rstrip("_config") for c in configs)
-    
+
     assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
     assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
-    
+
     config = configs[names.index(train_config.peft_method)]()
     config = configs[names.index(train_config.peft_method)]()
-    
+
     update_config(config, **kwargs)
     update_config(config, **kwargs)
     params = asdict(config)
     params = asdict(config)
     peft_config = peft_configs[names.index(train_config.peft_method)](**params)
     peft_config = peft_configs[names.index(train_config.peft_method)](**params)
-    
+
     return peft_config
     return peft_config
 
 
 
 
 def generate_dataset_config(train_config, kwargs):
 def generate_dataset_config(train_config, kwargs):
     names = tuple(DATASET_PREPROC.keys())
     names = tuple(DATASET_PREPROC.keys())
-        
+
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
-    
+
     dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
     dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
-        
+
     update_config(dataset_config, **kwargs)
     update_config(dataset_config, **kwargs)
-    
-    return  dataset_config
+
+    return  dataset_config
+
+
+def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
+        kwargs = {}
+        batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
+        if train_config.batching_strategy == "padding":
+            if train_config.enable_fsdp:
+                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
+                    dataset,
+                    batch_size=batch_size,
+                    rank=dist.get_rank(),
+                    num_replicas=dist.get_world_size(),
+                    shuffle=mode=="train",
+                )
+            else:
+                kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
+            kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
+        elif train_config.batching_strategy == "packing":
+            if train_config.enable_fsdp:
+                kwargs["sampler"] = DistributedSampler(
+                dataset,
+                rank=dist.get_rank(),
+                num_replicas=dist.get_world_size(),
+                shuffle=mode=="train",
+            )
+            kwargs["batch_size"] = batch_size
+            kwargs["drop_last"] = True
+            kwargs["collate_fn"] = default_data_collator
+        else:
+            raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+
+        return kwargs

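`get_dataloader_kwargs` now encodes the two strategies: "padding" batches samples of similar length via `LengthBasedBatchSampler` and lets `DataCollatorForSeq2Seq` pad each batch (by default to its longest sequence, filling label positions with -100), while "packing" keeps fixed-size chunks with the default collator. A rough sketch of the length-grouping idea behind the padding strategy (the real sampler in `llama_recipes/data/sampler.py` also handles shuffling and dict-style samples):

```python
# Group indices by sample length so each batch needs little padding.
def length_grouped_batches(lengths, batch_size):
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]

batches = length_grouped_batches([5, 50, 7, 48, 6, 52], batch_size=2)
assert batches == [[0, 4], [2, 3], [1, 5]]  # short samples together, long samples together
```
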
+ 6 - 6
src/llama_recipes/utils/dataset_utils.py

@@ -33,24 +33,24 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
         module_path, func_name = dataset_config.file.split(":")
         module_path, func_name = dataset_config.file.split(":")
     else:
     else:
         module_path, func_name = dataset_config.file, "get_custom_dataset"
         module_path, func_name = dataset_config.file, "get_custom_dataset"
-        
+
     if not module_path.endswith(".py"):
     if not module_path.endswith(".py"):
         raise ValueError(f"Dataset file {module_path} is not a .py file.")
         raise ValueError(f"Dataset file {module_path} is not a .py file.")
-    
+
     module_path = Path(module_path)
     module_path = Path(module_path)
     if not module_path.is_file():
     if not module_path.is_file():
         raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
         raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
-    
+
     module = load_module_from_py_file(module_path.as_posix())
     module = load_module_from_py_file(module_path.as_posix())
     try:
     try:
         return getattr(module, func_name)(dataset_config, tokenizer, split)
         return getattr(module, func_name)(dataset_config, tokenizer, split)
     except AttributeError as e:
     except AttributeError as e:
         print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
         print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
         raise e
         raise e
-    
+
 
 
 DATASET_PREPROC = {
 DATASET_PREPROC = {
-    "alpaca_dataset": partial(get_alpaca_dataset, max_words=224),
+    "alpaca_dataset": partial(get_alpaca_dataset),
     "grammar_dataset": get_grammar_dataset,
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
     "samsum_dataset": get_samsum_dataset,
     "custom_dataset": get_custom_dataset,
     "custom_dataset": get_custom_dataset,
@@ -69,7 +69,7 @@ def get_preprocessed_dataset(
             if split == "train"
             if split == "train"
             else dataset_config.test_split
             else dataset_config.test_split
         )
         )
-    
+
     return DATASET_PREPROC[dataset_config.dataset](
     return DATASET_PREPROC[dataset_config.dataset](
         dataset_config,
         dataset_config,
         tokenizer,
         tokenizer,

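For reference, the custom-dataset hook above expects a Python file exposing `get_custom_dataset(dataset_config, tokenizer, split)`, or another function selected with the `"path/to/file.py:func"` syntax. A hypothetical minimal module (file name, data paths and field names are placeholders):

```python
# my_dataset.py (hypothetical) -- select it with:
#   --dataset custom_dataset --custom_dataset.file "my_dataset.py:get_custom_dataset"
import datasets

def get_custom_dataset(dataset_config, tokenizer, split):
    ds = datasets.load_dataset(
        "json",
        data_files={"train": "data/train.json", "validation": "data/val.json"},
        split=split,
    )

    def to_features(sample):
        ids = tokenizer.encode(
            tokenizer.bos_token + sample["text"] + tokenizer.eos_token,
            add_special_tokens=False,
        )
        return {"input_ids": ids, "attention_mask": [1] * len(ids), "labels": ids.copy()}

    return ds.map(to_features, remove_columns=list(ds.features))
```
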
+ 37 - 33
src/llama_recipes/utils/train_utils.py

@@ -4,6 +4,7 @@
 import os
 import os
 import time
 import time
 import yaml
 import yaml
+from contextlib import nullcontext
 from pathlib import Path
 from pathlib import Path
 from pkg_resources import packaging
 from pkg_resources import packaging
 from datetime import datetime
 from datetime import datetime
@@ -20,14 +21,14 @@ import json
 
 
 
 
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
-from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
+from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 from llama_recipes.utils.memory_utils import MemoryTrace
 
 
 
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
     tokenizer.padding_side = "left"
-    
+
 # Converting Bytes to Megabytes
 # Converting Bytes to Megabytes
 def byte2mb(x):
 def byte2mb(x):
     return int(x / 2**20)
     return int(x / 2**20)
@@ -35,7 +36,7 @@ def byte2mb(x):
 def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
 def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
     """
     """
     Trains the model on the given dataloader
     Trains the model on the given dataloader
-    
+
     Args:
     Args:
         model: The model to be trained
         model: The model to be trained
         train_dataloader: The dataloader containing the training data
         train_dataloader: The dataloader containing the training data
@@ -47,18 +48,20 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         train_config: The training configuration
         train_config: The training configuration
         eval_dataloader: The dataloader containing the eval data
         eval_dataloader: The dataloader containing the eval data
         tokenizer: tokenizer used in the eval for decoding the predicitons
         tokenizer: tokenizer used in the eval for decoding the predicitons
-    
+
     Returns: results dictionary containing average training and validation perplexity and loss
     Returns: results dictionary containing average training and validation perplexity and loss
     """
     """
     # Create a gradient scaler for fp16
     # Create a gradient scaler for fp16
     if train_config.use_fp16 and train_config.enable_fsdp:
     if train_config.use_fp16 and train_config.enable_fsdp:
         scaler = ShardedGradScaler()
         scaler = ShardedGradScaler()
     elif train_config.use_fp16 and not train_config.enable_fsdp:
     elif train_config.use_fp16 and not train_config.enable_fsdp:
-        scaler = torch.cuda.amp.GradScaler() 
+        scaler = torch.cuda.amp.GradScaler()
     if train_config.enable_fsdp:
     if train_config.enable_fsdp:
         world_size = int(os.environ["WORLD_SIZE"]) 
         world_size = int(os.environ["WORLD_SIZE"]) 
 
 
     metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
     metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
+    autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
+
     train_prep = []
     train_prep = []
     train_loss = []
     train_loss = []
     train_step_perplexity = []
     train_step_perplexity = []
@@ -85,8 +88,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                         batch[key] = batch[key].to(local_rank)
                     else:
                     else:
-                        batch[key] = batch[key].to('cuda:0')              
-                loss = model(**batch).loss
+                        batch[key] = batch[key].to('cuda:0')
+                with autocast():
+                    loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 loss = loss / gradient_accumulation_steps
                 train_step_loss.append(loss.detach().float().item())
                 train_step_loss.append(loss.detach().float().item())
                 train_step_perplexity.append(float(torch.exp(loss.detach().float())))
                 train_step_perplexity.append(float(torch.exp(loss.detach().float())))
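
The loss computation above now runs under an autocast context chosen once per run: `torch.cuda.amp.autocast` when `use_fp16` is set, otherwise `contextlib.nullcontext`, so the fp16 and full-precision paths share one code path. The pattern in isolation:

```python
from contextlib import nullcontext

import torch

use_fp16 = False  # stand-in for train_config.use_fp16
autocast = torch.cuda.amp.autocast if use_fp16 else nullcontext

with autocast():
    pass  # loss = model(**batch).loss in the real training loop
```
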
@@ -109,9 +113,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
 
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
             pbar.close()
             pbar.close()
-                
+
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_end_time = time.perf_counter()-epoch_start_time
-        epoch_times.append(epoch_end_time)    
+        epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
@@ -136,10 +140,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
             print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
             print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
             print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
             print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
             print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-        
+
         # Update the learning rate as needed
         # Update the learning rate as needed
         lr_scheduler.step()
         lr_scheduler.step()
-          
+
         if train_config.run_validation:
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
             eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
             val_step_loss.extend(temp_val_loss)
             val_step_loss.extend(temp_val_loss)
@@ -154,23 +158,23 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             print(f"we are about to save the PEFT modules")
                             print(f"we are about to save the PEFT modules")
                     else:
                     else:
                         print(f"we are about to save the PEFT modules")
                         print(f"we are about to save the PEFT modules")
-                    model.save_pretrained(train_config.output_dir)  
+                    model.save_pretrained(train_config.output_dir)
                     if train_config.enable_fsdp:
                     if train_config.enable_fsdp:
-                        if rank==0: 
+                        if rank==0:
                             print(f"PEFT modules are saved in {train_config.output_dir} directory")
                             print(f"PEFT modules are saved in {train_config.output_dir} directory")
                     else:
                     else:
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                        
+
                 else:
                 else:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-                        
+
                         save_model_checkpoint(
                         save_model_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                         print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
                         print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
                         print("=====================================================")
                         print("=====================================================")
-                        
+
                         save_model_and_optimizer_sharded(model, rank, train_config)
                         save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
                         if train_config.save_optimizer:
                             save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
                             save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
@@ -182,7 +186,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             model, optimizer, rank, train_config, epoch=epoch
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                         )
                         print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
                         print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
-                        print("=====================================================")                     
+                        print("=====================================================")
                 if train_config.enable_fsdp:
                 if train_config.enable_fsdp:
                     dist.barrier()
                     dist.barrier()
             checkpoint_end_time = time.perf_counter() - checkpoint_start_time
             checkpoint_end_time = time.perf_counter() - checkpoint_start_time
@@ -210,8 +214,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
     avg_train_loss = sum(train_loss)/len(train_loss)
     if train_config.run_validation:
     if train_config.run_validation:
-        avg_eval_prep = sum(val_prep)/len(val_prep) 
-        avg_eval_loss = sum(val_loss)/len(val_loss) 
+        avg_eval_prep = sum(val_prep)/len(val_prep)
+        avg_eval_loss = sum(val_loss)/len(val_loss)
 
 
     results['avg_train_prep'] = avg_train_prep
     results['avg_train_prep'] = avg_train_prep
     results['avg_train_loss'] = avg_train_loss
     results['avg_train_loss'] = avg_train_loss
@@ -220,27 +224,27 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_loss'] = avg_eval_loss
         results['avg_eval_loss'] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
-    
+
     #saving the training params including fsdp setting for reference.
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft:
     if train_config.enable_fsdp and not train_config.use_peft:
         save_train_params(train_config, fsdp_config, rank)
         save_train_params(train_config, fsdp_config, rank)
-        
+
     return results
     return results
 
 
 def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
 def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     """
     """
     Evaluates the model on the given dataloader
     Evaluates the model on the given dataloader
-    
+
     Args:
     Args:
         model: The model to evaluate
         model: The model to evaluate
         eval_dataloader: The dataloader containing the evaluation data
         eval_dataloader: The dataloader containing the evaluation data
         local_rank: The rank of the current node in a distributed setting
         local_rank: The rank of the current node in a distributed setting
         tokenizer: The tokenizer used to decode predictions
         tokenizer: The tokenizer used to decode predictions
-    
+
     Returns: eval_ppl, eval_epoch_loss
     Returns: eval_ppl, eval_epoch_loss
     """
     """
     if train_config.enable_fsdp:
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"]) 
+        world_size = int(os.environ["WORLD_SIZE"])
     model.eval()
     model.eval()
     eval_preds = []
     eval_preds = []
     val_step_loss = []
     val_step_loss = []
@@ -266,17 +270,17 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             eval_preds.extend(
             eval_preds.extend(
                 tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
                 tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
             )
             )
-    
+
     # If there's more than one CUDA device, reduce evaluation loss across all devices
     # If there's more than one CUDA device, reduce evaluation loss across all devices
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
-    
+
     # Compute average loss and perplexity
     # Compute average loss and perplexity
     eval_epoch_loss = eval_loss / len(eval_dataloader)
     eval_epoch_loss = eval_loss / len(eval_dataloader)
     if train_config.enable_fsdp:
     if train_config.enable_fsdp:
         eval_epoch_loss = eval_epoch_loss/world_size
         eval_epoch_loss = eval_epoch_loss/world_size
     eval_ppl = torch.exp(eval_epoch_loss)
     eval_ppl = torch.exp(eval_epoch_loss)
-    
+
     # Print evaluation metrics
     # Print evaluation metrics
     if train_config.enable_fsdp:
     if train_config.enable_fsdp:
         if local_rank==0:
         if local_rank==0:
@@ -297,8 +301,8 @@ def check_frozen_layers_peft_model(model):
      for i, layer in enumerate(model.base_model.model.model.layers):
      for i, layer in enumerate(model.base_model.model.model.layers):
             for name, param in layer.named_parameters():
             for name, param in layer.named_parameters():
                 print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
                 print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
-                
-                
+
+
 def setup():
 def setup():
     """Initialize the process group for distributed training"""
     """Initialize the process group for distributed training"""
     dist.init_process_group("nccl")
     dist.init_process_group("nccl")
@@ -311,7 +315,7 @@ def setup_environ_flags(rank):
     # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
     # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
     # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
     # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
     # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
     # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
-    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True' 
+    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
     if rank == 0:
     if rank == 0:
         print(f"--> Running with torch dist debug set to detail")
         print(f"--> Running with torch dist debug set to detail")
 
 
@@ -356,7 +360,7 @@ def print_model_size(model, config, rank: int = 0) -> None:
 
 
 def get_policies(cfg, rank):
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
     """Get the policies for mixed precision and fsdp wrapping"""
-    
+
     verify_bfloat_support = (
     verify_bfloat_support = (
     torch.version.cuda
     torch.version.cuda
     and torch.cuda.is_bf16_supported()
     and torch.cuda.is_bf16_supported()
@@ -374,7 +378,7 @@ def get_policies(cfg, rank):
         bf16_ready = verify_bfloat_support
         bf16_ready = verify_bfloat_support
 
 
         if bf16_ready and not cfg.use_fp16:
         if bf16_ready and not cfg.use_fp16:
-            mixed_precision_policy = bfSixteen_mixed
+            mixed_precision_policy = bfSixteen
             if rank == 0:
             if rank == 0:
                 print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
                 print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
         elif cfg.use_fp16:
         elif cfg.use_fp16:
@@ -392,7 +396,7 @@ def save_train_params(train_config, fsdp_config, rank):
     This will be used by converter script in the inference folder to fetch the HF model name or path.
     This will be used by converter script in the inference folder to fetch the HF model name or path.
     It also would be hepful as a log for future references.
     It also would be hepful as a log for future references.
     """
     """
-    # Convert the train_config and fsdp_config objects to dictionaries, 
+    # Convert the train_config and fsdp_config objects to dictionaries,
     # converting all values to strings to ensure they can be serialized into a YAML file
     # converting all values to strings to ensure they can be serialized into a YAML file
     train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
     train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
     fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
     fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}

+ 18 - 0
tests/conftest.py

@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import pytest
+
+from transformers import LlamaTokenizer
+
+
+@pytest.fixture
+def setup_tokenizer():
+    def _helper(tokenizer):
+        #Align with Llama 2 tokenizer
+        tokenizer.from_pretrained.return_value = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+        tokenizer.from_pretrained.return_value.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>'})
+        tokenizer.from_pretrained.return_value.bos_token_id = 1
+        tokenizer.from_pretrained.return_value.eos_token_id = 2
+
+    return _helper

+ 40 - 13
tests/datasets/test_custom_dataset.py

@@ -4,21 +4,38 @@
 import pytest
 import pytest
 from unittest.mock import patch
 from unittest.mock import patch
 
 
+from transformers import LlamaTokenizer
+
+def check_padded_entry(batch):
+    seq_len = sum(batch["attention_mask"][0])
+    assert seq_len < len(batch["attention_mask"][0])
+
+    assert batch["labels"][0][0] == -100
+    assert batch["labels"][0][seq_len-1] == 2
+    assert batch["labels"][0][-1] == -100
+    assert batch["input_ids"][0][0] == 1
+    assert batch["input_ids"][0][-1] == 2
+
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
+def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
     from llama_recipes.finetuning import main
     from llama_recipes.finetuning import main
 
 
+    setup_tokenizer(tokenizer)
+
     kwargs = {
     kwargs = {
         "dataset": "custom_dataset",
         "dataset": "custom_dataset",
         "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
         "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
         "custom_dataset.file": "examples/custom_dataset.py",
         "custom_dataset.file": "examples/custom_dataset.py",
         "custom_dataset.train_split": "validation",
         "custom_dataset.train_split": "validation",
         "batch_size_training": 2,
         "batch_size_training": 2,
+        "val_batch_size": 4,
         "use_peft": False,
         "use_peft": False,
+        "batching_strategy": "padding"
         }
         }
 
 
     main(**kwargs)
     main(**kwargs)
@@ -30,24 +47,34 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     eval_dataloader = args[2]
     eval_dataloader = args[2]
     tokenizer = args[3]
     tokenizer = args[3]
 
 
-    assert len(train_dataloader) == 226
-    assert len(eval_dataloader) == 2*226
+    assert len(train_dataloader) == 1120
+    assert len(eval_dataloader) == 1120 //2
+
+    it = iter(eval_dataloader)
+    batch = next(it)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_STRING = "[INST] Who made Berlin [/INST] dunno"
+    assert STRING.startswith(EXPECTED_STRING)
+
+    assert batch["input_ids"].size(0) == 4
+    assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
+
+    check_padded_entry(batch)
 
 
     it = iter(train_dataloader)
     it = iter(train_dataloader)
-    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "
+    for _ in range(5):
+        next(it)
 
 
+    batch = next(it)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_STRING = "[INST] How do I initialize a Typescript project using npm and git? [/INST] # Initialize a new NPM project"
     assert STRING.startswith(EXPECTED_STRING)
     assert STRING.startswith(EXPECTED_STRING)
 
 
-    next(it)
-    next(it)
-    next(it)
-    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_SUBSTRING_1 = "Therefore you are correct.  [INST] How can L’Hopital’s Rule be"
-    EXPECTED_SUBSTRING_2 = "a circular path around the turn.  [INST] How on earth is that related to L’Hopital’s Rule?"
+    assert batch["input_ids"].size(0) == 2
+    assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
+
+    check_padded_entry(batch)
 
 
-    assert EXPECTED_SUBSTRING_1 in STRING
-    assert EXPECTED_SUBSTRING_2 in STRING
 
 
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')

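`check_padded_entry` spells out what one padded row of a batch should look like under the padding strategy: `finetuning.py` sets `tokenizer.pad_token_id = tokenizer.eos_token_id` (id 2), and `DataCollatorForSeq2Seq` fills padded label positions with -100. Schematically, with invented token ids:

```python
# One padded batch row, as the assertions above expect it (token ids are made up).
entry = {
    "input_ids":      [1, 306, 822, 2, 2],         # BOS, prompt, answer, EOS, pad (pad == EOS id)
    "attention_mask": [1,   1,   1, 1, 0],         # padding is masked out
    "labels":         [-100, -100, 822, 2, -100],  # prompt masked, answer + EOS kept, pad masked
}
seq_len = sum(entry["attention_mask"])
assert entry["labels"][seq_len - 1] == 2 and entry["labels"][-1] == -100
```
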
+ 54 - 0
tests/datasets/test_grammar_datasets.py

@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from unittest.mock import patch
+
+from transformers import LlamaTokenizer
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    BATCH_SIZE = 8
+    kwargs = {
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": BATCH_SIZE,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "grammar_dataset",
+        "batching_strategy": "padding",
+        }
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    VAL_SAMPLES = 2988
+    TRAIN_SAMPLES = 13016
+
+    assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
+    assert len(eval_dataloader) == VAL_SAMPLES
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0][29] == -100
+    assert batch["labels"][0][30] == 29871
+
+    assert batch["input_ids"][0][0] == 1
+    assert batch["labels"][0][-1] == 2
+    assert batch["input_ids"][0][-1] == 2

+ 30 - 14
tests/datasets/test_samsum_datasets.py

@@ -1,37 +1,53 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 
+from functools import partial
 from unittest.mock import patch
 from unittest.mock import patch
 
 
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
     from llama_recipes.finetuning import main
     from llama_recipes.finetuning import main
-        
-    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
-    
-    
+
+    setup_tokenizer(tokenizer)
+
+    BATCH_SIZE = 8
     kwargs = {
     kwargs = {
-        "batch_size_training": 1,
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": BATCH_SIZE,
+        "val_batch_size": 1,
         "use_peft": False,
         "use_peft": False,
         "dataset": "samsum_dataset",
         "dataset": "samsum_dataset",
+        "batching_strategy": "padding",
         }
         }
-    
+
     main(**kwargs)
     main(**kwargs)
-    
+
     assert train.call_count == 1
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     args, kwargs = train.call_args
     train_dataloader = args[1]
     train_dataloader = args[1]
     eval_dataloader = args[2]
     eval_dataloader = args[2]
-    
+
     VAL_SAMPLES = 818
     VAL_SAMPLES = 818
     TRAIN_SAMPLES = 14732
     TRAIN_SAMPLES = 14732
-    CONCAT_SIZE = 2048
-    assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
+
+    assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
     assert len(eval_dataloader) == VAL_SAMPLES
     assert len(eval_dataloader) == VAL_SAMPLES
-    
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0][268] == -100
+    assert batch["labels"][0][269] == 22291
+
+    assert batch["input_ids"][0][0] == 1
+    assert batch["labels"][0][-1] == 2
+    assert batch["input_ids"][0][-1] == 2

+ 94 - 0
tests/test_batching.py

@@ -0,0 +1,94 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import pytest
+from unittest.mock import patch
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    kwargs = {
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": 8,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "samsum_dataset",
+        "batching_strategy": "packing",
+        }
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    assert len(train_dataloader) == 96
+    assert len(eval_dataloader) == 42
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0].size(0) == 4096
+    assert batch["input_ids"][0].size(0) == 4096
+    assert batch["attention_mask"][0].size(0) == 4096
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+@patch('llama_recipes.finetuning.setup')
+@patch('llama_recipes.finetuning.FSDP')
+@patch('llama_recipes.finetuning.torch.distributed.is_initialized')
+@patch('llama_recipes.utils.config_utils.dist')
+def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer):
+    import os
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    rank = 0
+    os.environ['LOCAL_RANK'] = f'{rank}'
+    os.environ['RANK'] = f'{rank}'
+    os.environ['WORLD_SIZE'] = '2'
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12345'
+
+    kwargs = {
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": 8,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "samsum_dataset",
+        "batching_strategy": "packing",
+        "enable_fsdp": True
+        }
+
+    is_initialized.return_value = True
+    dist.get_rank.return_value = rank
+    dist.get_world_size.return_value = 2
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    assert len(train_dataloader) == 96 //2
+    assert len(eval_dataloader) == 42 //2

+ 82 - 33
tests/test_finetuning.py

@@ -1,14 +1,26 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 
+import pytest
 from pytest import approx
 from pytest import approx
 from unittest.mock import patch
 from unittest.mock import patch
 
 
 from torch.nn import Linear
 from torch.nn import Linear
 from torch.optim import AdamW
 from torch.optim import AdamW
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.sampler import BatchSampler
 
 
 from llama_recipes.finetuning import main
 from llama_recipes.finetuning import main
+from llama_recipes.data.sampler import LengthBasedBatchSampler
+
+
+def get_fake_dataset():
+    return [{
+        "input_ids":[1],
+        "attention_mask":[1],
+        "labels":[1],
+        }]
+
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -18,23 +30,23 @@ from llama_recipes.finetuning import main
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": False}
     kwargs = {"run_validation": False}
-    
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
     main(**kwargs)
-    
+
     assert train.call_count == 1
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     args, kwargs = train.call_args
     train_dataloader = args[1]
     train_dataloader = args[1]
     eval_dataloader = args[2]
     eval_dataloader = args[2]
-    
+
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(train_dataloader, DataLoader)
     assert eval_dataloader is None
     assert eval_dataloader is None
-    
+
     assert get_model.return_value.to.call_args.args[0] == "cuda"
     assert get_model.return_value.to.call_args.args[0] == "cuda"
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -43,21 +55,22 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": True}
     kwargs = {"run_validation": True}
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
     main(**kwargs)
-    
+
     assert train.call_count == 1
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     args, kwargs = train.call_args
     train_dataloader = args[1]
     train_dataloader = args[1]
     eval_dataloader = args[2]
     eval_dataloader = args[2]
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
-    
+
     assert get_model.return_value.to.call_args.args[0] == "cuda"
     assert get_model.return_value.to.call_args.args[0] == "cuda"
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -68,37 +81,73 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
 def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
     kwargs = {"use_peft": True}
     kwargs = {"use_peft": True}
-    
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
     main(**kwargs)
-    
+
     assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
     assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.get_peft_model')
 @patch('llama_recipes.finetuning.get_peft_model')
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
+def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
     kwargs = {"weight_decay": 0.01}
     kwargs = {"weight_decay": 0.01}
-    
-    get_dataset.return_value = [1]
-    
-    get_peft_model.return_value = Linear(1,1)
-    get_peft_model.return_value.print_trainable_parameters=lambda:None
+
+    get_dataset.return_value = get_fake_dataset()
+
+    get_model.return_value = Linear(1,1)
+
     main(**kwargs)
     main(**kwargs)
-    
+
     assert train.call_count == 1
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     args, kwargs = train.call_args
     optimizer = args[4]
     optimizer = args[4]
-    
+
     print(optimizer.state_dict())
     print(optimizer.state_dict())
-    
+
     assert isinstance(optimizer, AdamW)
     assert isinstance(optimizer, AdamW)
     assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
     assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
-    
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.get_preprocessed_dataset')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+    kwargs = {"batching_strategy": "packing"}
+
+    get_dataset.return_value = get_fake_dataset()
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader, eval_dataloader = args[1:3]
+    assert isinstance(train_dataloader.batch_sampler, BatchSampler)
+    assert isinstance(eval_dataloader.batch_sampler, BatchSampler)
+
+    kwargs["batching_strategy"] = "padding"
+    train.reset_mock()
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader, eval_dataloader = args[1:3]
+    assert isinstance(train_dataloader.batch_sampler, LengthBasedBatchSampler)
+    assert isinstance(eval_dataloader.batch_sampler, LengthBasedBatchSampler)
+
+    kwargs["batching_strategy"] = "none"
+
+    with pytest.raises(ValueError):
+        main(**kwargs)
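
The new `test_batching_strategy` test above switches between the two supported `batching_strategy` values by passing overrides to `main`, the same way a user would. A rough usage sketch follows; the `model_name` value is only an illustrative placeholder and is not taken from this diff:

```python
from llama_recipes.finetuning import main

# "packing" concatenates tokenized samples into fixed-size blocks and uses a
# plain BatchSampler; "padding" keeps samples separate and batches them by
# similar length via LengthBasedBatchSampler, as asserted in the test above.
main(
    model_name="meta-llama/Llama-2-7b-hf",  # illustrative placeholder
    batching_strategy="packing",            # or "padding"; other values raise ValueError
)
```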

+ 86 - 0
tests/test_sampler.py

@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import random
+import pytest
+
+import torch
+
+from llama_recipes.data.sampler import LengthBasedBatchSampler
+from llama_recipes.data.sampler import DistributedLengthBasedBatchSampler
+
+SAMPLES = 33
+
+@pytest.fixture
+def dataset():
+    random.seed(42)
+    dataset = []
+    def add_samples(ds, n, a, b):
+        for _ in range(n):
+            ds.append(random.randint(a,b) * [1,])
+    add_samples(dataset, SAMPLES // 2, 1,9)
+    add_samples(dataset, (SAMPLES // 2) + (SAMPLES % 2), 10,20)
+    
+    return random.sample(dataset, len(dataset))
+    
+    
+@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
+def test_batch_sampler_array(dataset, batch_size, drop_last):
+    
+    sampler = LengthBasedBatchSampler(dataset, batch_size, drop_last)
+    
+    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
+    
+    all_ids = [i for b in sampler for i in b]
+    assert len(set(all_ids)) == (EXPECTED_LENGTH * batch_size if drop_last else len(dataset))
+    
+    assert len(sampler) == EXPECTED_LENGTH
+    is_long = [len(d)>=10 for d in dataset]
+    
+    def check_batch(batch):
+        return all(batch) or not any(batch)
+    
+    assert all(check_batch(is_long[i] for i in b) for b in sampler)
+    
+    
+@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
+def test_batch_sampler_dict(dataset, batch_size, drop_last):
+    
+    dist_dataset = [{"input_ids": d, "attention_mask": d} for d in dataset]
+    
+    sampler = LengthBasedBatchSampler(dist_dataset, batch_size, drop_last)
+    
+    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
+    
+    assert len(sampler) == EXPECTED_LENGTH
+    is_long = [len(d)>=10 for d in dataset]
+    
+    def check_batch(batch):
+        return all(batch) or not any(batch)
+    
+    assert all(check_batch(is_long[i] for i in b) for b in sampler)
+    
+    
+@pytest.mark.parametrize("batch_size", [2, 8])
+def test_dist_batch_sampling(dataset, batch_size):
+    sampler_1 = DistributedLengthBasedBatchSampler(
+        dataset,
+        batch_size=batch_size,
+        rank=0,
+        num_replicas=2,
+        shuffle=False,
+    )
+    sampler_2 = DistributedLengthBasedBatchSampler(
+        dataset,
+        batch_size=batch_size,
+        rank=1,
+        num_replicas=2,
+        shuffle=False,
+    )
+    
+    ids_1 = set(i for b in sampler_1 for i in b)
+    ids_2 = set(i for b in sampler_2 for i in b)
+    
+    assert ids_1.isdisjoint(ids_2)
+    assert len(ids_1)+len(ids_2) > 0
+    assert len(ids_1)+len(ids_2) == len(dataset) // batch_size  *  batch_size 
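
For readers of this new test: a minimal sketch of how `LengthBasedBatchSampler` might be wired into a `DataLoader`. The DataLoader wiring is not part of this diff and is an assumption here; the sampler's positional arguments mirror how the test constructs it:

```python
from torch.utils.data import DataLoader

from llama_recipes.data.sampler import LengthBasedBatchSampler

# Toy dataset of variable-length token-id lists, mirroring the fixture above.
data = [[1] * n for n in (3, 12, 5, 15, 4, 11)]

# The sampler yields lists of indices grouped by similar sequence length,
# so batch size and drop_last are handed to it rather than to the DataLoader.
batch_sampler = LengthBasedBatchSampler(data, 2, False)  # batch_size=2, drop_last=False
loader = DataLoader(data, batch_sampler=batch_sampler, collate_fn=lambda b: b)

for batch in loader:
    print([len(x) for x in batch])  # lengths within a batch should be similar
```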

+ 18 - 5
tests/test_train_utils.py

@@ -1,17 +1,22 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from unittest.mock import patch
+
 import torch
 
 from llama_recipes.utils.train_utils import train
 
-def test_gradient_accumulation(mocker):
-    # import sys
-    # sys.path.append('/home/ubuntu/llama-recipes/')
+@patch("llama_recipes.utils.train_utils.MemoryTrace")
+@patch("llama_recipes.utils.train_utils.nullcontext")
+@patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
+@patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
+def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
     
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
-    batch = {"input": torch.zeros(1)}
+    mock_tensor = mocker.MagicMock(name="tensor")
+    batch = {"input": mock_tensor}
     train_dataloader = [batch, batch, batch, batch, batch]
     eval_dataloader = None
     tokenizer = mocker.MagicMock()
@@ -37,7 +42,13 @@ def test_gradient_accumulation(mocker):
     assert optimizer.zero_grad.call_count == 5
     optimizer.zero_grad.reset_mock()
     
+    assert nullcontext.call_count == 5
+    nullcontext.reset_mock()
+    
+    assert autocast.call_count == 0
+    
     gradient_accumulation_steps = 2
+    train_config.use_fp16 = True
     train(
         model,
         train_dataloader,
@@ -48,4 +59,6 @@ def test_gradient_accumulation(mocker):
         gradient_accumulation_steps,
         train_config,
     )
-    assert optimizer.zero_grad.call_count == 3
+    assert optimizer.zero_grad.call_count == 3
+    assert nullcontext.call_count == 0
+    assert autocast.call_count == 5
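
The assertions above pin down the accumulation and fp16 behaviour that `train` is expected to have: 5 optimizer steps over 5 batches with `gradient_accumulation_steps=1`, 3 steps with `gradient_accumulation_steps=2`, and `autocast` entered only when `use_fp16` is set. A simplified sketch of that pattern, not the actual `train_utils` implementation:

```python
import torch
from contextlib import nullcontext

def accumulate_and_step(model, train_dataloader, optimizer, gradient_accumulation_steps, use_fp16):
    # autocast is only entered when fp16 is enabled; otherwise nullcontext is
    # used, which is what the call-count assertions above verify.
    ctx = torch.cuda.amp.autocast if use_fp16 else nullcontext
    for step, batch in enumerate(train_dataloader):
        with ctx():
            loss = model(**batch).loss / gradient_accumulation_steps
        loss.backward()
        # The optimizer steps (and gradients reset) every
        # gradient_accumulation_steps batches or on the final batch, giving
        # 5 zero_grad calls with steps=1 and 3 calls with steps=2 over 5 batches.
        if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
            optimizer.step()
            optimizer.zero_grad()
```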