Parcourir la source

merged from main

Kai Wu il y a 6 mois
Parent
commit
b013d27dc2
56 fichiers modifiés avec 4709 ajouts et 1321 suppressions
  1. 8 0
      .github/scripts/spellcheck_conf/wordlist.txt
  2. 0 0
      .watchman-cookie-devgpu003.cco3.facebook.com-3137776-2746
  3. 25 43
      README.md
  4. 3 2
      pyproject.toml
  5. 2 2
      recipes/3p_integrations/lamini/text2sql_memory_tuning/util/get_default_finetune_args.py
  6. 1 1
      recipes/quickstart/Running_Llama3_Anywhere/Running_Llama_on_Mac_Windows_Linux.ipynb
  7. 17 6
      recipes/quickstart/finetuning/datasets/custom_dataset.py
  8. 90 0
      recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py
  9. 33 0
      recipes/quickstart/finetuning/finetune_vision_model.md
  10. 26 49
      recipes/quickstart/finetuning/quickstart_peft_finetuning.ipynb
  11. 12 1
      recipes/quickstart/inference/local_inference/README.md
  12. 5 1
      recipes/quickstart/inference/local_inference/inference.py
  13. 66 0
      recipes/quickstart/inference/local_inference/multi_modal_infer.py
  14. 0 384
      recipes/responsible_ai/Purple_Llama_Anyscale.ipynb
  15. 0 289
      recipes/responsible_ai/Purple_Llama_OctoAI.ipynb
  16. 1 3
      recipes/responsible_ai/README.md
  17. 0 268
      recipes/responsible_ai/input_output_guardrails_with_llama.ipynb
  18. 5 55
      recipes/responsible_ai/llama_guard/README.md
  19. 0 75
      recipes/responsible_ai/llama_guard/inference.py
  20. 576 0
      recipes/responsible_ai/llama_guard/llama_guard_text_and_vision_inference.ipynb
  21. BIN
      recipes/responsible_ai/llama_guard/resources/dog.jpg
  22. BIN
      recipes/responsible_ai/llama_guard/resources/pasta.jpeg
  23. 3 0
      recipes/use_cases/README.md
  24. 60 0
      recipes/use_cases/github_triage/README.md
  25. 144 0
      recipes/use_cases/github_triage/config.yaml
  26. 165 0
      recipes/use_cases/github_triage/llm.py
  27. 1846 0
      recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/annotated_issues.csv
  28. 6 0
      recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/challenges.csv
  29. 2 0
      recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/overview.csv
  30. BIN
      recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/plots/commits.png
  31. BIN
      recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/plots/engagement_sankey.png
  32. BIN
      recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/plots/expertise.png
  33. BIN
      recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/plots/sentiment.png
  34. BIN
      recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/plots/severity.png
  35. BIN
      recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/plots/themes.png
  36. BIN
      recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/report.pdf
  37. 141 0
      recipes/use_cases/github_triage/pdf_report.py
  38. 178 0
      recipes/use_cases/github_triage/plots.py
  39. 6 0
      recipes/use_cases/github_triage/requirements.txt
  40. 240 0
      recipes/use_cases/github_triage/triage.py
  41. 98 0
      recipes/use_cases/github_triage/utils.py
  42. 659 0
      recipes/use_cases/github_triage/walkthrough.ipynb
  43. 2 6
      requirements.txt
  44. 1 1
      src/llama_recipes/configs/fsdp.py
  45. 14 1
      src/llama_recipes/datasets/__init__.py
  46. 57 0
      src/llama_recipes/datasets/custom_dataset.py
  47. 65 21
      src/llama_recipes/finetuning.py
  48. 2 1
      src/llama_recipes/model_checkpointing/__init__.py
  49. 19 5
      src/llama_recipes/model_checkpointing/checkpoint_handler.py
  50. 3 3
      src/llama_recipes/policies/wrapping.py
  51. 41 27
      src/llama_recipes/utils/config_utils.py
  52. 31 55
      src/llama_recipes/utils/dataset_utils.py
  53. 2 4
      src/llama_recipes/utils/fsdp_utils.py
  54. 25 16
      src/llama_recipes/utils/train_utils.py
  55. 1 1
      src/tests/conftest.py
  56. 28 1
      src/tests/datasets/test_custom_dataset.py

+ 8 - 0
.github/scripts/spellcheck_conf/wordlist.txt

@@ -1458,3 +1458,11 @@ Multistep
 multistep
 algorithmically
 asymptote
+Triaging
+matplotlib
+remediations
+walkthrough
+OCRVQA
+OCRVQADataCollator
+ocrvqa
+langchain

+ 0 - 0
.watchman-cookie-devgpu003.cco3.facebook.com-3137776-2746


Fichier diff supprimé car celui-ci est trop grand
+ 25 - 43
README.md


+ 3 - 2
pyproject.toml

@@ -4,13 +4,13 @@ build-backend = "hatchling.build"
 
 [project]
 name = "llama-recipes"
-version = "0.0.3"
+version = "0.0.4.post1"
 authors = [
   { name="Hamid Shojanazeri", email="hamidnazeri@meta.com" },
   { name="Matthias Reso", email="mreso@meta.com" },
   { name="Geeta Chauhan", email="gchauhan@meta.com" },
 ]
-description = "Llama-recipes is a companion project to the Llama 2 model. It's goal is to provide examples to quickly get started with fine-tuning for domain adaptation and how to run inference for the fine-tuned models. "
+description = "Llama-recipes is a companion project to the Llama models. It's goal is to provide examples to quickly get started with fine-tuning for domain adaptation and how to run inference for the fine-tuned models."
 readme = "README.md"
 requires-python = ">=3.8"
 classifiers = [
@@ -24,6 +24,7 @@ dynamic = ["dependencies"]
 vllm = ["vllm"]
 tests = ["pytest-mock"]
 auditnlg = ["auditnlg"]
+langchain = ["langchain_openai", "langchain", "langchain_community"]
 
 [project.urls]
 "Homepage" = "https://github.com/facebookresearch/llama-recipes/"

+ 2 - 2
recipes/3p_integrations/lamini/text2sql_memory_tuning/util/get_default_finetune_args.py

@@ -1,7 +1,7 @@
 def get_default_finetune_args():
     return {
-        "learning_rate": 3e-4,
-        "max_steps": 360,
+        "learning_rate": 0.0003,
+        "max_steps": 60,
         "early_stopping": False,
         "load_best_model_at_end": False,
         "peft_args": {"r_value": 32},

+ 1 - 1
recipes/quickstart/Running_Llama3_Anywhere/Running_Llama_on_Mac_Windows_Linux.ipynb

@@ -81,7 +81,7 @@
     "\n",
     "def llama3(prompt):\n",
     "    data = {\n",
-    "        \"model\": \"llama3\",\n",
+    "        \"model\": \"llama3.1\",\n",
     "        \"messages\": [\n",
     "            {\n",
     "              \"role\": \"user\",\n",

+ 17 - 6
recipes/quickstart/finetuning/datasets/custom_dataset.py

@@ -9,19 +9,30 @@ import itertools
 
 
 B_INST, E_INST = "[INST]", "[/INST]"
+EOT_ID = 128009 #<|eot_id|>
+
+def mask_target(target,seq):
+    for i in range(len(seq)-len(target)):
+        if seq[i:i+len(target)] == target:
+            seq[i:i+len(target)] = [-100] * len(target)
+    return seq
 
 def tokenize_dialog(dialog, tokenizer):
     if tokenizer.vocab_size >= 128000:
         dialog_tokens = tokenizer.apply_chat_template(dialog)
-        dialog_tokens = dialog_tokens[:-4] # Remove generation prompt <|start_header_id|>assistant<|end_header_id|>\n\n
-        eot_indices = [i for i,n in enumerate(dialog_tokens) if n == 128009]
+        eot_indices = [i for i,n in enumerate(dialog_tokens) if n == EOT_ID]
         labels = copy.copy(dialog_tokens)
-        last_idx = 0
+        #determine token for system and user 
+        system_or_user = (tokenizer.encode("system")[-1], tokenizer.encode("user")[-1])
+        labels[0] = -100 # bos token
+        last_idx = 1
         for n, idx in enumerate(eot_indices):
-            if n % 2 == 1:
-                last_idx = idx
-            else:
+            role_token = labels[last_idx+1]
+            if role_token in system_or_user:
+                # Set labels to -100 for system and user tokens to ignore in loss function
                 labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+            last_idx = idx + 1
+        mask_target(tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False), labels)
 
         dialog_tokens = [dialog_tokens]
         labels_tokens = [labels]

+ 90 - 0
recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py

@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
+
+
+import copy
+from datasets import load_dataset
+import itertools
+import torch
+
+# check system prompt token seq or user prompt token seq is in the current token list
+def check_header(targets,seq):
+    for i in range(len(seq)-3):
+        if seq[i:i+3] in targets:
+            return True
+    return False
+def replace_target(target,seq):
+    for i in range(len(seq)-3):
+        if seq[i:i+3] == target:
+            seq[i],seq[i+1],seq[i+2] = -100,-100,-100
+    return seq
+def tokenize_dialogs(dialogs, images, processor):
+    text_prompt = processor.apply_chat_template(dialogs)
+    batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
+    label_list = []
+    for i in range(len(batch["input_ids"])):
+        dialog_tokens = batch["input_ids"][i].tolist()
+        labels = copy.copy(dialog_tokens)
+        eot_indices = [i for i,n in enumerate(labels) if n == 128009]
+        last_idx = 0
+        # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
+        # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
+        prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
+        for n, idx in enumerate(eot_indices):
+            current_seq = labels[last_idx:idx+1]
+            if check_header(prompt_header_seqs,current_seq):
+                # found prompt header, indicating that this seq should be masked
+                labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+            else:
+                last_idx = idx+1
+            #  Mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
+        assistant_header_seq = [128006, 78191, 128007]
+        labels = replace_target(assistant_header_seq,labels)
+        # Mask the padding token and image token 128256 
+        for i in range(len(labels)):
+            if labels[i] == processor.tokenizer.pad_token_id or labels[i] == 128256: #  128256 is image token index
+                labels[i] = -100
+        label_list.append(labels)
+    batch["labels"] = torch.tensor(label_list)
+    return batch
+
+
+def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
+    # load_dataset will return DatasetDict that contains all the data in the train set
+    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")
+    dataset = dataset_dict['train']
+    # Comment out the following line to use the full dataset, for quick testing only use 2000 samples
+    dataset = dataset.select(range(2000))
+    dataset = dataset.train_test_split(test_size=1-split_ratio, shuffle=True, seed=42)[split]
+    return dataset
+
+class OCRVQADataCollator:
+    def __init__(self, processor):
+        self.processor = processor
+        self.processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
+    def __call__(self, samples):
+        dialogs,images = [],[]
+        for sample in samples:
+            image_list,sample_list = sample["images"],sample["texts"]
+            if len(image_list) > 1:
+                raise ValueError("Only support one image per sample")
+            image = image_list[0].convert("RGB") # only use the first image
+            dialog = []
+            for sample_dict in sample_list:
+                if not dialog:
+                    # only append image to the first sentence
+                    dialog += [
+                    {"role":"user","content":[{"type": "image"},{"type": "text", "text": sample_dict["user"].strip()}]},
+                    {"role":"assistant","content":[{"type": "text", "text": sample_dict["assistant"].strip()}]}
+                ]
+                
+                else:
+                    dialog += [
+                    {"role":"user","content":[{"type": "text", "text": sample_dict["user"].strip()}]},
+                    {"role":"assistant","content":[{"type": "text", "text": sample_dict["assistant"].strip()}]}
+                ]
+            dialogs.append(dialog)
+            images.append([image])
+        return tokenize_dialogs(dialogs,images, self.processor)
+def get_data_collator(processor):
+    return OCRVQADataCollator(processor)

Fichier diff supprimé car celui-ci est trop grand
+ 33 - 0
recipes/quickstart/finetuning/finetune_vision_model.md


+ 26 - 49
recipes/quickstart/finetuning/quickstart_peft_finetuning.ipynb

@@ -65,7 +65,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "c7963d43806d432aaa3d00e2055e355c",
+       "model_id": "68838a4f42f84545912e95b339a31034",
        "version_major": 2,
        "version_minor": 0
       },
@@ -75,13 +75,6 @@
      },
      "metadata": {},
      "output_type": "display_data"
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
-     ]
     }
    ],
    "source": [
@@ -101,6 +94,7 @@
     "train_config.context_length = 1024 if torch.cuda.get_device_properties(0).total_memory < 16e9 else 2048 # T4 16GB or A10 24GB\n",
     "train_config.batching_strategy = \"packing\"\n",
     "train_config.output_dir = \"meta-llama-samsum\"\n",
+    "train_config.use_peft = True\n",
     "\n",
     "from transformers import BitsAndBytesConfig\n",
     "config = BitsAndBytesConfig(\n",
@@ -205,7 +199,7 @@
     "model_input = tokenizer(eval_prompt, return_tensors=\"pt\").to(\"cuda\")\n",
     "\n",
     "model.eval()\n",
-    "with torch.no_grad():\n",
+    "with torch.inference_mode():\n",
     "    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))"
    ]
   },
@@ -230,34 +224,20 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/datasets/load.py:1486: FutureWarning: The repository for samsum contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/samsum\n",
-      "You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
-      "Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n",
-      "  warnings.warn(\n",
-      "Preprocessing dataset: 100%|██████████| 14732/14732 [00:02<00:00, 6124.69it/s]\n"
+      "/home/ubuntu/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: `torch.distributed._shard.checkpoint` will be deprecated, use `torch.distributed.checkpoint` instead\n",
+      "  from torch.distributed._shard.checkpoint import (\n",
+      "Preprocessing dataset: 100%|██████████| 14732/14732 [00:02<00:00, 5872.02it/s]\n"
      ]
     }
    ],
    "source": [
     "from llama_recipes.configs.datasets import samsum_dataset\n",
-    "from llama_recipes.data.concatenator import ConcatDataset\n",
-    "from llama_recipes.utils.config_utils import get_dataloader_kwargs\n",
-    "from llama_recipes.utils.dataset_utils import get_preprocessed_dataset\n",
-    "\n",
-    "train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')\n",
-    "\n",
-    "train_dl_kwargs = get_dataloader_kwargs(train_config, train_dataset, tokenizer, \"train\")\n",
+    "from llama_recipes.utils.dataset_utils import get_dataloader\n",
     "\n",
-    "if train_config.batching_strategy == \"packing\":\n",
-    "        train_dataset = ConcatDataset(train_dataset, chunk_size=train_config.context_length)\n",
+    "samsum_dataset.trust_remote_code = True\n",
     "\n",
-    "# Create DataLoaders for the training and validation dataset\n",
-    "train_dataloader = torch.utils.data.DataLoader(\n",
-    "    train_dataset,\n",
-    "    num_workers=train_config.num_workers_dataloader,\n",
-    "    pin_memory=True,\n",
-    "    **train_dl_kwargs,\n",
-    ")"
+    "train_dataloader = get_dataloader(tokenizer, samsum_dataset, train_config)\n",
+    "eval_dataloader = get_dataloader(tokenizer, samsum_dataset, train_config, \"val\")"
    ]
   },
   {
@@ -310,17 +290,23 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.\n",
+      "/home/ubuntu/llama-recipes/src/llama_recipes/utils/train_utils.py:92: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n",
+      "  scaler = torch.cuda.amp.GradScaler()\n",
+      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.\n",
       "  warnings.warn(\n",
       "Training Epoch: 1:   0%|\u001b[34m          \u001b[0m| 0/319 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
       "To disable this warning, you can either:\n",
       "\t- Avoid using `tokenizers` before the fork if possible\n",
       "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
-      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/utils/checkpoint.py:464: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
-      "  warnings.warn(\n",
+      "/home/ubuntu/llama-recipes/src/llama_recipes/utils/train_utils.py:151: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
+      "  with autocast():\n",
+      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:600: UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.\n",
+      "  return fn(*args, **kwargs)\n",
       "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
       "  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
-      "Training Epoch: 1/1, step 1278/1279 completed (loss: 0.27870458364486694): : 320it [2:07:09, 23.84s/it]                      3.94s/it]  \n"
+      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
+      "  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]\n",
+      "Training Epoch: 1/1, step 1278/1279 completed (loss: 0.28094857931137085): : 320it [2:08:50, 24.16s/it]                      4.21s/it]  \n"
      ]
     },
     {
@@ -332,7 +318,7 @@
       "Peak active CUDA memory was 15 GB\n",
       "CUDA Malloc retries : 0\n",
       "CPU Total Peak Memory consumed during the train (max): 2 GB\n",
-      "Epoch 1: train_perplexity=1.3403, train_epoch_loss=0.2929, epoch time 7630.169942979002s\n"
+      "Epoch 1: train_perplexity=1.3404, train_epoch_loss=0.2930, epoch time 7730.981359725998s\n"
      ]
     }
    ],
@@ -354,7 +340,7 @@
     "results = train(\n",
     "    model,\n",
     "    train_dataloader,\n",
-    "    None,\n",
+    "    eval_dataloader,\n",
     "    tokenizer,\n",
     "    optimizer,\n",
     "    scheduler,\n",
@@ -380,16 +366,7 @@
    "cell_type": "code",
    "execution_count": 7,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/home/ubuntu/miniconda3/envs/llama/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
-      "  warnings.warn(\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "model.save_pretrained(train_config.output_dir)"
    ]
@@ -440,13 +417,13 @@
       "A: He said he’d name it after his dead hamster – Lemmy  - he's  a great Motorhead fan :-)))\n",
       "---\n",
       "Summary:\n",
-      "A wants to get a puppy for her son. She will take him to the animal shelter tomorrow. B is not sure if he can go with her, but he's willing to.\n"
+      "A wants to get a puppy for his son. A took him to the animal shelter last Monday and he showed A one he really liked. A wants to get him one of those little dogs. A and B agree that raising a dog is a tough issue.\n"
      ]
     }
    ],
    "source": [
     "model.eval()\n",
-    "with torch.no_grad():\n",
+    "with torch.inference_mode():\n",
     "    print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))\n"
    ]
   }
@@ -467,7 +444,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.14"
+   "version": "3.11.9"
   },
   "vscode": {
    "interpreter": {

+ 12 - 1
recipes/quickstart/inference/local_inference/README.md

@@ -1,6 +1,17 @@
 # Local Inference
 
+## Multimodal Inference
+For Multi-Modal inference we have added [multi_modal_infer.py](multi_modal_infer.py) which uses the transformers library
+
+The way to run this would be
+```
+python multi_modal_infer.py --image_path "./resources/image.jpg" --prompt_text "Describe this image" --temperature 0.5 --top_p 0.8 --model_name "meta-llama/Llama-3.2-11B-Vision-Instruct"
+```
+
+## Text-only Inference
 For local inference we have provided an [inference script](inference.py). Depending on the type of finetuning performed during training the [inference script](inference.py) takes different arguments.
+
+
 To finetune all model parameters the output dir of the training has to be given as --model_name argument.
 In the case of a parameter efficient method like lora the base model has to be given as --model_name and the output dir of the training has to be given as --peft_model argument.
 Additionally, a prompt for the model in the form of a text file has to be provided. The prompt file can either be piped through standard input or given as --prompt_file parameter.
@@ -87,4 +98,4 @@ python inference.py --model_name <training_config.output_dir> --prompt_file <tes
 
 ## Inference on large models like Meta Llama 405B
 The FP8 quantized variants of Meta Llama (i.e. meta-llama/Meta-Llama-3.1-405B-FP8 and meta-llama/Meta-Llama-3.1-405B-Instruct-FP8) can be executed on a single node with 8x80GB H100 using the scripts located in this folder.
-To run the unquantized Meta Llama 405B variants (i.e. meta-llama/Meta-Llama-3.1-405B and meta-llama/Meta-Llama-3.1-405B-Instruct) we need to use a multi-node setup for inference. The llama-recipes inference script currently does not allow multi-node inference. To run this model you can use vLLM with pipeline and tensor parallelism as showed in [this example](../../../3p_integrations/vllm/README.md).
+To run the unquantized Meta Llama 405B variants (i.e. meta-llama/Meta-Llama-3.1-405B and meta-llama/Meta-Llama-3.1-405B-Instruct) we need to use a multi-node setup for inference. The llama-recipes inference script currently does not allow multi-node inference. To run this model you can use vLLM with pipeline and tensor parallelism as showed in [this example](../../../3p_integrations/vllm/README.md).

+ 5 - 1
recipes/quickstart/inference/local_inference/inference.py

@@ -6,7 +6,6 @@ import sys
 import time
 
 import fire
-import gradio as gr
 
 import torch
 
@@ -146,6 +145,11 @@ def main(
         user_prompt = "\n".join(sys.stdin.readlines())
         inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
     else:
+        try:
+            import gradio as gr
+        except ImportError:
+            raise ImportError("This part of the recipe requires gradio. Please run `pip install gradio`")
+            
         gr.Interface(
             fn=inference,
             inputs=[

+ 66 - 0
recipes/quickstart/inference/local_inference/multi_modal_infer.py

@@ -0,0 +1,66 @@
+import os
+import sys
+import argparse
+from PIL import Image as PIL_Image
+import torch
+from transformers import MllamaForConditionalGeneration, MllamaProcessor
+
+
+# Constants
+DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+
+
+def load_model_and_processor(model_name: str, hf_token: str):
+    """
+    Load the model and processor based on the 11B or 90B model.
+    """
+    model = MllamaForConditionalGeneration.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16, token=hf_token)
+    processor = MllamaProcessor.from_pretrained(model_name, token=hf_token)
+    return model, processor
+
+
+def process_image(image_path: str) -> PIL_Image.Image:
+    """
+    Open and convert an image from the specified path.
+    """
+    if not os.path.exists(image_path):
+        print(f"The image file '{image_path}' does not exist.")
+        sys.exit(1)
+    with open(image_path, "rb") as f:
+        return PIL_Image.open(f).convert("RGB")
+
+
+def generate_text_from_image(model, processor, image, prompt_text: str, temperature: float, top_p: float):
+    """
+    Generate text from an image using the model and processor.
+    """
+    conversation = [
+        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt_text}]}
+    ]
+    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+    inputs = processor(image, prompt, return_tensors="pt").to(model.device)
+    output = model.generate(**inputs, temperature=temperature, top_p=top_p, max_new_tokens=512)
+    return processor.decode(output[0])[len(prompt):]
+
+
+def main(image_path: str, prompt_text: str, temperature: float, top_p: float, model_name: str, hf_token: str):
+    """
+    Call all the functions. 
+    """
+    model, processor = load_model_and_processor(model_name, hf_token)
+    image = process_image(image_path)
+    result = generate_text_from_image(model, processor, image, prompt_text, temperature, top_p)
+    print("Generated Text: " + result)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Generate text from an image and prompt using the 3.2 MM Llama model.")
+    parser.add_argument("--image_path", type=str, help="Path to the image file")
+    parser.add_argument("--prompt_text", type=str, help="Prompt text to describe the image")
+    parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation (default: 0.7)")
+    parser.add_argument("--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)")
+    parser.add_argument("--model_name", type=str, default=DEFAULT_MODEL, help=f"Model name (default: '{DEFAULT_MODEL}')")
+    parser.add_argument("--hf_token", type=str, required=True, help="Hugging Face token for authentication")
+
+    args = parser.parse_args()
+    main(args.image_path, args.prompt_text, args.temperature, args.top_p, args.model_name, args.hf_token)

Fichier diff supprimé car celui-ci est trop grand
+ 0 - 384
recipes/responsible_ai/Purple_Llama_Anyscale.ipynb


+ 0 - 289
recipes/responsible_ai/Purple_Llama_OctoAI.ipynb

@@ -1,289 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "LERqQn5v8-ak"
-   },
-   "source": [
-    "# **Purple Llama Using OctoAI**\n",
-    "\n",
-    "Drawing inspiration from the cybersecurity concept of \"purple teaming,\" Purple Llama embraces both offensive (red team) and defensive (blue team) strategies. Our goal is to empower developers in deploying generative AI models responsibly, aligning with best practices outlined in our Responsible Use Guide."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "PGPSI3M5PGTi"
-   },
-   "source": [
-    "#### **1 - What is Purple Llama?**\n",
-    "\n",
-    "Purple Llama is a an umbrella project that over time will bring together tools and evals to help the community build responsibly with open generative AI models. The initial release will include tools and evals for Cyber Security and Input/Output safeguards but we plan to contribute more in the near future.\n",
-    "\n",
-    "* Instruction tuned on Llama2-7b model\n",
-    "* [CyberSecurity Evals](https://github.com/facebookresearch/PurpleLlama/tree/main/CybersecurityBenchmarks_)\n",
-    "* [Llama Guard Model](https://ai.meta.com/research/publications/llama-guard-llm-based-input-output-safeguard-for-human-ai-conversations/)\n",
-    "* [Download Llama Guard](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)\n",
-    "* [Purple Llama Website](https://ai.meta.com/llama/purple-llama/)\n",
-    "* [Purple Llama Github Repo](https://github.com/facebookresearch/PurpleLlama)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "aYeHVVh45bdT"
-   },
-   "source": [
-    "#### **2 - Accessing Purple Llama**\n",
-    "* Download + Self Host (i.e. [download Purple Llama](https://ai.meta.com/resources/models-and-libraries/llama-downloads/))\n",
-    "* Hosted API Platform (e.g. [OctoAI](https://octoai.cloud/), [Anyscale](https://www.anyscale.com/), [Together](https://api.together.xyz/playground/chat/togethercomputer/llama-2-7b-chat), [Replicate](https://replicate.com/meta))\n",
-    "* Hosted Container Platform (e.g. [Azure](https://techcommunity.microsoft.com/t5/ai-machine-learning-blog/introducing-llama-2-on-azure/ba-p/3881233), [AWS](https://aws.amazon.com/blogs/machine-learning/llama-2-foundation-models-from-meta-are-now-available-in-amazon-sagemaker-jumpstart/), [GCP](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/139))"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "sd54g0OHuqBY"
-   },
-   "source": [
-    "#### **3 - Using Purple Llama**\n",
-    "\n",
-    "In this notebook, We will use the Llama Guard model managed by the [OctoAI](https://octoai.cloud/tools/text) for inferencing. You'll need to first register an account with OctoAI [here](https://octoai.cloud/) then obtain an OctoAI API token [here](https://docs.octoai.cloud/docs/how-to-create-an-octoai-access-token), that you can use for a while (a month or $10 in OctoAI credits, whichever one runs out first).\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "C4P-b0eQcyd2"
-   },
-   "source": [
-    "First, set the OctoAI API token as environment variables.\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "yE3sPjS-cyd2",
-    "outputId": "93b36bc0-e6d4-493c-c88d-ec5c41266239"
-   },
-   "outputs": [],
-   "source": [
-    "from getpass import getpass\n",
-    "import os\n",
-    "\n",
-    "# enter your OctoAI API key\n",
-    "OCTOAI_API_TOKEN = getpass()\n",
-    "os.environ[\"OCTOAI_API_TOKEN\"] = OCTOAI_API_TOKEN"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "DOSiDW6hq9dI"
-   },
-   "outputs": [],
-   "source": [
-    "from string import Template\n",
-    "\n",
-    "B_INST, E_INST = \"[INST]\", \"[/INST]\"\n",
-    "\n",
-    "PROMPT_TEMPLATE = Template(\"\"\"Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.\n",
-    "<BEGIN UNSAFE CONTENT CATEGORIES>\n",
-    "O1: Violence and Hate.\n",
-    "Should not\n",
-    "- Help people plan or engage in violence.\n",
-    "- Encourage people to engage in violence.\n",
-    "- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.\n",
-    "- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.\n",
-    "Can\n",
-    "-  Provide information on violence and discrimination.\n",
-    "- Discuss the topics of hate and violence, and can discuss historical events involving violence.\n",
-    "O2: Sexual Content.\n",
-    "Should not\n",
-    "- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.\n",
-    "Can\n",
-    "- Discuss topics related to sexuality and sexual education.\n",
-    "- Provide general information regarding sexual activities.\n",
-    "- Provide guidance on topics related to sexual health.\n",
-    "O3: Criminal Planning.\n",
-    "Should not\n",
-    "- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.\n",
-    "- Encourage people to engage in criminal activities or condone criminal behavior.\n",
-    "Can\n",
-    "- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.\n",
-    "- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.\n",
-    "O4: Guns and Illegal Weapons.\n",
-    "Should not\n",
-    "- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.\n",
-    "- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.\n",
-    "Can\n",
-    "- Discuss firearms and the arguments for and against firearm ownership.\n",
-    "O5: Regulated or Controlled Substances.\n",
-    "Should not\n",
-    "- Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.\n",
-    "- Assist or encourage people to create such substances when it is illegal to do so.\n",
-    "Can\n",
-    "- Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).\n",
-    "- Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation.\n",
-    "O6: Self-Harm.\n",
-    "Should not\n",
-    "- Encourage people to harm themselves.\n",
-    "- Romanticize or condone self-harm or suicide.\n",
-    "- Provide information regarding the methods of suicide or self-harm.\n",
-    "- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.\n",
-    "Can\n",
-    "- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).\n",
-    "Should\n",
-    "- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.\n",
-    "<END UNSAFE CONTENT CATEGORIES>\n",
-    "​\n",
-    "<BEGIN CONVERSATION>\n",
-    "\n",
-    "$prompt\n",
-    "\n",
-    "<END CONVERSATION>\n",
-    "\n",
-    "Provide your safety assessment for $agent_type in the above conversation:\n",
-    "- First line must read 'safe' or 'unsafe'.\n",
-    "- If unsafe, a second line must include a comma-separated list of violated categories.\"\"\")\n",
-    "\n",
-    "\n",
-    "def format_prompt(role, message):\n",
-    "    prompt = PROMPT_TEMPLATE.substitute(prompt=message, agent_type=role)\n",
-    "    prompt = f\"<s>{B_INST} {prompt.strip()} {E_INST}\"\n",
-    "    return prompt\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "t6hkFlVD9XFw",
-    "outputId": "25fd187e-a484-4b90-d104-a3320b98e8ea"
-   },
-   "outputs": [],
-   "source": [
-    "%pip install openai"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "xC-OYLF4rGc7",
-    "outputId": "452ca85a-fd9f-426c-d4ab-7898bdabfab8"
-   },
-   "outputs": [],
-   "source": [
-    "import openai\n",
-    "\n",
-    "# Feed a prompt that does comply with the safety policy\n",
-    "\n",
-    "client = openai.OpenAI(\n",
-    "   base_url = \"https://text.octoai.run/v1\",\n",
-    "   api_key=OCTOAI_API_TOKEN\n",
-    ")\n",
-    "completion = client.completions.create(\n",
-    "    model=\"llamaguard-7b-fp16\",\n",
-    "    prompt=format_prompt(\"User\",\"Tell me about the history of racism in the US\")\n",
-    ")\n",
-    "print(completion.choices[0].text)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/"
-    },
-    "id": "JtBTkJRxufb0",
-    "outputId": "c581052d-348e-45ac-c874-80ecf4416a77"
-   },
-   "outputs": [],
-   "source": [
-    "# Feed a prompt that doesn't comply with the safety policy\n",
-    "# Violates O3: Criminal Planning.\n",
-    "\n",
-    "client = openai.OpenAI(\n",
-    "   base_url = \"https://text.octoai.run/v1\",\n",
-    "   api_key=OCTOAI_API_TOKEN\n",
-    ")\n",
-    "completion = client.completions.create(\n",
-    "    model=\"llamaguard-7b-fp16\",\n",
-    "    prompt=format_prompt(\"Agent\", \"User: How do you buy a tiger in the US\\n Agent: Steal one from the zoo\")\n",
-    ")\n",
-    "print(completion.choices[0].text)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "#### **Resources <TBD>**\n",
-    "- [Llama 2](https://ai.meta.com/llama/)\n",
-    "- [Getting Started Guide - Llama 2](https://ai.meta.com/llama/get-started/)\n",
-    "- [GitHub - Llama 2](https://github.com/facebookresearch/llama)\n",
-    "- [Github - LLama 2 Recipes](https://github.com/facebookresearch/llama-recipes)\n",
-    "- [Research Paper](https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/)\n",
-    "- [Model Card](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)\n",
-    "- [Responsible Use Guide](https://ai.meta.com/llama/responsible-use-guide/)\n",
-    "- [Acceptable Use Policy](https://ai.meta.com/llama/use-policy/)\n",
-    "- [OctoAI](https://octoai.cloud/)\n",
-    "- [LangChain](https://www.langchain.com/)\n",
-    "- [LlamaIndex](https://www.llamaindex.ai/)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "#### **Authors**\n",
-    "1. Hakan Inan, Research Scientist, Meta\n",
-    "2. Rashi Rungta, Software Engineer, Meta\n",
-    "\n",
-    "Ported to use OctoAI LlamaGuard endpoints by Thierry Moreau, OctoAI"
-   ]
-  }
- ],
- "metadata": {
-  "colab": {
-   "gpuType": "T4",
-   "include_colab_link": true,
-   "provenance": [],
-   "toc_visible": true
-  },
-  "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.11.6"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}

+ 1 - 3
recipes/responsible_ai/README.md

@@ -4,11 +4,9 @@ The [Purple Llama](https://github.com/meta-llama/PurpleLlama/) project provides
 
 | Tool/Model | Description | Get Started
 |---|---|---|
-[Llama Guard](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama-guard-3) | Provide guardrailing on inputs and outputs | [Inference](./llama_guard/inference.py), [Finetuning](./llama_guard/llama_guard_customization_via_prompting_and_fine_tuning.ipynb)
+[Llama Guard](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama-guard-3) | Provide guardrailing on inputs and outputs | [Inference](./llama_guard/llama_guard_text_and_vision_inference.ipynb), [Finetuning](./llama_guard/llama_guard_customization_via_prompting_and_fine_tuning.ipynb)
 [Prompt Guard](https://llama.meta.com/docs/model-cards-and-prompt-formats/prompt-guard) | Model to safeguards against jailbreak attempts and embedded prompt injections | [Notebook](./prompt_guard/prompt_guard_tutorial.ipynb)
 [Code Shield](https://github.com/meta-llama/PurpleLlama/tree/main/CodeShield) | Tool to safeguard against insecure code generated by the LLM | [Notebook](https://github.com/meta-llama/PurpleLlama/blob/main/CodeShield/notebook/CodeShieldUsageDemo.ipynb)
 
 
 
-### Running on hosted APIs
-The notebooks [input_output_guardrails.ipynb](./input_output_guardrails_with_llama.ipynb),  [Purple_Llama_Anyscale](Purple_Llama_Anyscale.ipynb) & [Purple_Llama_OctoAI](Purple_Llama_OctoAI.ipynb) contain examples for running Meta Llama Guard on cloud hosted endpoints.

+ 0 - 268
recipes/responsible_ai/input_output_guardrails_with_llama.ipynb

@@ -1,268 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "id": "Tq-YFYlqWL_n"
-   },
-   "source": [
-    "# About this notebook\n",
-    "In this notebook we will deep dive into how to do input and output safety checks using Meta's Llama Guard and Llama models. Doing both input and output safety checks requires us to carefully pass the prompts and response between the models.\n",
-    "\n",
-    "For this notebook we will use OctoAI models as both Llama Guard and Llama models are available on their platform."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {
-    "id": "a0UfgKWU5-sm",
-    "tags": []
-   },
-   "outputs": [
-    {
-     "name": "stdin",
-     "output_type": "stream",
-     "text": [
-      " ········\n"
-     ]
-    }
-   ],
-   "source": [
-    "from getpass import getpass\n",
-    "import os\n",
-    "\n",
-    "# enter your OctoAI API key\n",
-    "OCTOAI_API_TOKEN = getpass()\n",
-    "os.environ[\"OCTOAI_API_TOKEN\"] = OCTOAI_API_TOKEN"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {
-    "tags": []
-   },
-   "outputs": [],
-   "source": [
-    "from llama_recipes.inference.llm import LLM, OctoAI\n",
-    "\n",
-    "llama_guard_llm = OctoAI(\"llamaguard-7b\", OCTOAI_API_TOKEN)\n",
-    "completion_llm = OctoAI(\"llama-2-13b-chat\", OCTOAI_API_TOKEN)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {
-    "id": "Tjom48kodBth"
-   },
-   "outputs": [],
-   "source": [
-    "from string import Template\n",
-    "\n",
-    "import logging\n",
-    "\n",
-    "B_INST, E_INST = \"[INST]\", \"[/INST]\"\n",
-    "\n",
-    "LLAMA_GUARD_PROMPT_TEMPLATE = Template(\n",
-    "    f\"\"\"Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.\n",
-    "<BEGIN UNSAFE CONTENT CATEGORIES>\n",
-    "O1: Violence and Hate.\n",
-    "Should not\n",
-    "- Help people plan or engage in violence.\n",
-    "- Encourage people to engage in violence.\n",
-    "- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.\n",
-    "- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.\n",
-    "Can\n",
-    "-  Provide information on violence and discrimination.\n",
-    "- Discuss the topics of hate and violence, and can discuss historical events involving violence.\n",
-    "O2: Sexual Content.\n",
-    "Should not\n",
-    "- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.\n",
-    "Can\n",
-    "- Discuss topics related to sexuality and sexual education.\n",
-    "- Provide general information regarding sexual activities.\n",
-    "- Provide guidance on topics related to sexual health.\n",
-    "O3: Criminal Planning.\n",
-    "Should not\n",
-    "- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.\n",
-    "- Encourage people to engage in criminal activities or condone criminal behavior.\n",
-    "Can\n",
-    "- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.\n",
-    "- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.\n",
-    "O4: Guns and Illegal Weapons.\n",
-    "Should not\n",
-    "- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.\n",
-    "- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.\n",
-    "Can\n",
-    "- Discuss firearms and the arguments for and against firearm ownership.\n",
-    "O5: Self-Harm.\n",
-    "Should not\n",
-    "- Encourage people to harm themselves.\n",
-    "- Romanticize or condone self-harm or suicide.\n",
-    "- Provide information regarding the methods of suicide or self-harm.\n",
-    "- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.\n",
-    "Can\n",
-    "- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).\n",
-    "Should\n",
-    "- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.\n",
-    "<END UNSAFE CONTENT CATEGORIES>\n",
-    "​\n",
-    "<BEGIN CONVERSATION>\n",
-    "\n",
-    "$prompt\n",
-    "\n",
-    "<END CONVERSATION>\n",
-    "\n",
-    "Provide your safety assessment for $agent_type in the above conversation:\n",
-    "- First line must read 'safe' or 'unsafe'.\n",
-    "- If unsafe, a second line must include a comma-separated list of violated categories.\"\"\"\n",
-    ")\n",
-    "\n",
-    "LLAMA_PROMPT_TEMPLATE = Template(\n",
-    "    f\"\"\"<<SYS>>\n",
-    "You are a helpful agent, answering questions of the user in a short and consise manner.\n",
-    "<</SYS>>\n",
-    "\n",
-    "{{ $user_message }}\"\"\"\n",
-    ")\n",
-    "\n",
-    "def llama_format_prompt(message):\n",
-    "    prompt = LLAMA_PROMPT_TEMPLATE.substitute(user_message=message)\n",
-    "    prompt = f\"<s>{B_INST} {prompt.strip()} {E_INST}\"\n",
-    "    return prompt\n",
-    "\n",
-    "def llama_guard_format_prompt(role, message):\n",
-    "    prompt = LLAMA_GUARD_PROMPT_TEMPLATE.substitute(prompt=message, agent_type=role)\n",
-    "    prompt = f\"<s>{B_INST} {prompt.strip()} {E_INST}\"\n",
-    "    return prompt\n",
-    "\n",
-    "def get_category_description(category):\n",
-    "    CATEGORY_MAP = {\n",
-    "      '01': 'Violence and Hate',\n",
-    "      '02': 'Sexual Content',\n",
-    "      '03': 'Criminal Planning',\n",
-    "      '04': 'Guns and Illegal Weapons',\n",
-    "      '05': 'Self-Harm'\n",
-    "    }\n",
-    "\n",
-    "    if not isinstance(category, str) or category is None:\n",
-    "        raise ValueError(\"Invalid input, expected a non-empty string\")\n",
-    "\n",
-    "    if category not in CATEGORY_MAP:\n",
-    "        raise ValueError(f\"Invalid input, unknown category: {category}\")\n",
-    "\n",
-    "    return CATEGORY_MAP[category]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 4,
-   "metadata": {
-    "id": "KwONCw1oSroO"
-   },
-   "outputs": [],
-   "source": [
-    "def parse_llamaguard_response(message):\n",
-    "    if not isinstance(message, str) or message is None:\n",
-    "        return (False, \"Invalid input, expected a non-empty string\")\n",
-    "    message = message.strip()\n",
-    "\n",
-    "    if message == \"\":\n",
-    "        return (False, \"Invalid input, message cannot be empty\")\n",
-    "\n",
-    "    tokens = message.split(\"\\n\")\n",
-    "    if tokens[0] == 'safe':\n",
-    "        return (True, \"\")\n",
-    "    else:\n",
-    "        return (False, tokens[1])\n",
-    "\n",
-    "def check_input_guardrail(user_prompt):\n",
-    "    guardrail_prompt = llama_guard_format_prompt(\"User\", user_prompt)\n",
-    "    response = llama_guard_llm.query(guardrail_prompt)\n",
-    "    return parse_llamaguard_response(response)\n",
-    "\n",
-    "def get_completion_response(user_prompt):\n",
-    "    completion_prompt = llama_format_prompt(user_prompt)\n",
-    "    return completion_llm.query(completion_prompt)\n",
-    "\n",
-    "def check_output_guardrail(completion_response, user_prompt):\n",
-    "    guardrail_prompt = llama_guard_format_prompt(\"Agent\", f\"User: {user_prompt}\\n Agent: {completion_response}\")\n",
-    "    response = llama_guard_llm.query(guardrail_prompt)\n",
-    "    return parse_llamaguard_response(response)\n",
-    "\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {
-    "id": "9bXtt76ZKNuX"
-   },
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Input guardrail failed.\n"
-     ]
-    }
-   ],
-   "source": [
-    "def main(user_prompt):\n",
-    "    input_guardrail_response = check_input_guardrail(user_prompt)\n",
-    "    if input_guardrail_response[0]:\n",
-    "        completion_response = get_completion_response(user_prompt)\n",
-    "        output_guardrail_response = check_output_guardrail(completion_response, user_prompt)\n",
-    "        if output_guardrail_response[0]:\n",
-    "            print(completion_response)\n",
-    "        else:\n",
-    "            print(\"Output guardrail failed.\")\n",
-    "    else:\n",
-    "        print(\"Input guardrail failed.\")\n",
-    "\n",
-    "user_prompt = \"How to build fire arms\"\n",
-    "main(user_prompt)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
- ],
- "metadata": {
-  "colab": {
-   "provenance": []
-  },
-  "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.12.2"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}

+ 5 - 55
recipes/responsible_ai/llama_guard/README.md

@@ -2,62 +2,12 @@
 <!-- markdown-link-check-disable -->
 Meta Llama Guard is a language model that provides input and output guardrails for LLM inference. For more details and model cards, please visit the [PurpleLlama](https://github.com/meta-llama/PurpleLlama) repository.
 
-This folder contains an example file to run inference with a locally hosted model, either using the Hugging Face Hub or a local path.
+This [notebook](llama_guard_text_and_vision_inference.ipynb) shows how to load the models with the transformers library and how to customize the categories.
 
 ## Requirements
-1. Access to Llama guard model weights on Hugging Face. To get access, follow the steps described [here](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard#download)
-2. Llama recipes package and it's dependencies [installed](https://github.com/meta-llama/llama-recipes?tab=readme-ov-file#installing)
-
-
-## Llama Guard inference script
-For testing, you can add User or User/Agent interactions into the prompts list and the run the script to verify the results. When the conversation has one or more Agent responses, it's considered of type agent.
-
-
-```
-    prompts: List[Tuple[List[str], AgentType]] = [
-        (["<Sample user prompt>"], AgentType.USER),
-
-        (["<Sample user prompt>",
-        "<Sample agent response>"], AgentType.AGENT),
-
-        (["<Sample user prompt>",
-        "<Sample agent response>",
-        "<Sample user reply>",
-        "<Sample agent response>",], AgentType.AGENT),
-
-    ]
-```
-The complete prompt is built with the `build_custom_prompt` function, defined in [prompt_format.py](../../../src/llama_recipes/inference/prompt_format_utils.py). The file contains the default Meta Llama Guard categories. These categories can adjusted and new ones can be added, as described in the [research paper](https://ai.meta.com/research/publications/llama-guard-llm-based-input-output-safeguard-for-human-ai-conversations/), on section 4.5 Studying the adaptability of the model.
-<!-- markdown-link-check-enable -->
-
-To run the samples, with all the dependencies installed, execute this command:
-
-`python recipes/responsible_ai/llama_guard/inference.py`
-
-This is the output:
-
-```
-['<Sample user prompt>']
-> safe
-
-==================================
-
-['<Sample user prompt>', '<Sample agent response>']
-> safe
-
-==================================
-
-['<Sample user prompt>', '<Sample agent response>', '<Sample user reply>', '<Sample agent response>']
-> safe
-
-==================================
-```
-
-To run it with a local model, you can use the `model_id` param in the inference script:
-
-`python recipes/responsible_ai/llama_guard/inference.py --model_id=/home/ubuntu/models/llama3/Llama-Guard-3-8B/ --llama_guard_version=LLAMA_GUARD_3`
-
-Note: Make sure to also add the llama_guard_version; by default it uses LLAMA_GUARD_3
+1. Access to Llama guard model weights on Hugging Face. To get access, follow the steps described in the top of the model card in [Hugging Face](https://huggingface.co/meta-llama/Llama-Guard-3-1B)
+2. Llama recipes package and its dependencies [installed](https://github.com/meta-llama/llama-recipes?tab=readme-ov-file#installing)
+3. Pillow package installed
 
 ## Inference Safety Checker
 When running the regular inference script with prompts, Meta Llama Guard will be used as a safety checker on the user prompt and the model output. If both are safe, the result will be shown, else a message with the error will be shown, with the word unsafe and a comma separated list of categories infringed. Meta Llama Guard is always loaded quantized using Hugging Face Transformers library with bitsandbytes.
@@ -66,7 +16,7 @@ In this case, the default categories are applied by the tokenizer, using the `ap
 
 Use this command for testing with a quantized Llama model, modifying the values accordingly:
 
-`python examples/inference.py --model_name <path_to_regular_llama_model> --prompt_file <path_to_prompt_file> --quantization 8bit --enable_llamaguard_content_safety`
+`python inference.py --model_name <path_to_regular_llama_model> --prompt_file <path_to_prompt_file> --enable_llamaguard_content_safety`
 
 ## Llama Guard 3 Finetuning & Customization
 The safety categories in Llama Guard 3 can be tuned for specific application needs. Existing categories can be removed and new categories can be added to the taxonomy. The [Llama Guard Customization](./llama_guard_customization_via_prompting_and_fine_tuning.ipynb) notebook walks through the process.

+ 0 - 75
recipes/responsible_ai/llama_guard/inference.py

@@ -1,75 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-import fire
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
-
-
-from llama_recipes.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion
-from typing import List, Tuple
-from enum import Enum
-
-class AgentType(Enum):
-    AGENT = "Agent"
-    USER = "User"
-
-def main(
-    model_id: str = "meta-llama/Llama-Guard-3-8B",
-    llama_guard_version: str = "LLAMA_GUARD_3"
-):
-    """
-    Entry point for Llama Guard inference sample script.
-
-    This function loads Llama Guard from Hugging Face or a local model and 
-    executes the predefined prompts in the script to showcase how to do inference with Llama Guard.
-
-    Args:
-        model_id (str): The ID of the pretrained model to use for generation. This can be either the path to a local folder containing the model files,
-            or the repository ID of a model hosted on the Hugging Face Hub. Defaults to 'meta-llama/LlamaGuard-7b'.
-        llama_guard_version (LlamaGuardVersion): The version of the Llama Guard model to use for formatting prompts. Defaults to LLAMA_GUARD_1.
-    """
-    try:
-        llama_guard_version = LlamaGuardVersion[llama_guard_version]
-    except KeyError as e:
-        raise ValueError(f"Invalid Llama Guard version '{llama_guard_version}'. Valid values are: {', '.join([lgv.name for lgv in LlamaGuardVersion])}") from e
-
-    prompts: List[Tuple[List[str], AgentType]] = [
-        (["<Sample user prompt>"], AgentType.USER),
-
-        (["<Sample user prompt>",
-        "<Sample agent response>"], AgentType.AGENT),
-        
-        (["<Sample user prompt>",
-        "<Sample agent response>",
-        "<Sample user reply>",
-        "<Sample agent response>",], AgentType.AGENT),
-
-    ]
-
-    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
-
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")
-    
-    for prompt in prompts:
-        formatted_prompt = build_default_prompt(
-                prompt[1], 
-                create_conversation(prompt[0]),
-                llama_guard_version)
-
-
-        input = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
-        prompt_len = input["input_ids"].shape[-1]
-        output = model.generate(**input, max_new_tokens=100, pad_token_id=0)
-        results = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
-       
-        
-        print(prompt[0])
-        print(f"> {results}")
-        print("\n==================================\n")
-
-if __name__ == "__main__":
-    try:
-        fire.Fire(main)
-    except Exception as e:
-        print(e)

Fichier diff supprimé car celui-ci est trop grand
+ 576 - 0
recipes/responsible_ai/llama_guard/llama_guard_text_and_vision_inference.ipynb


BIN
recipes/responsible_ai/llama_guard/resources/dog.jpg


BIN
recipes/responsible_ai/llama_guard/resources/pasta.jpeg


+ 3 - 0
recipes/use_cases/README.md

@@ -1,3 +1,6 @@
+## [Automatic Triaging of Github Repositories](./github_triage/walkthrough.ipynb): Use Llama to automatically triage issues in an OSS repository and generate insights to improve community experience
+This tool utilizes an off-the-shelf Llama model to analyze, generate insights, and create a report for better understanding of the state of a repository. It serves as a reference implementation for using Llama to develop custom reporting and data analytics applications.
+
 ## [VideoSummary](video_summary.ipynb): Ask Llama 3 to Summarize a Long YouTube Video (using Replicate or [OctoAI](../3p_integrations/octoai/video_summary.ipynb))
 This demo app uses Llama 3 to return a text summary of a YouTube video. It shows how to retrieve the caption of a YouTube video and how to ask Llama to summarize the content in different ways, from the simplest naive way that works for short text to more advanced methods of using LangChain's map_reduce and refine to overcome the 8K context length limit of Llama 3.
 

+ 60 - 0
recipes/use_cases/github_triage/README.md

@@ -0,0 +1,60 @@
+# Automatic Issues Triaging with Llama
+
+This tool utilizes an off-the-shelf Llama model to analyze, generate insights, and create a report for better understanding of the state of a repository. It serves as a reference implementation for using Llama to develop custom reporting and data analytics applications.
+
+## Features
+
+The tool performs the following tasks:
+
+* Fetches issue threads from a specified repository
+* Analyzes issue discussions and generates annotations such as category, severity, component affected, etc.
+* Categorizes all issues by theme
+* Synthesizes key challenges faced by users, along with probable causes and remediations
+* Generates a high-level executive summary providing insights on diagnosing and improving the developer experience
+
+For a step-by-step look, check out the [walkthrough notebook](walkthrough.ipynb).
+
+## Getting Started
+
+
+### Installation
+
+```bash
+pip install -r requirements.txt
+```
+
+### Setup
+
+1. **API Keys and Model Service**: Set your GitHub token for API calls. Some privileged information may not be available if you don't have push-access to the target repository.
+2. **Model Configuration**: Set the appropriate values in the `model` section of [config.yaml](config.yaml) for using Llama via VLLM or Groq.
+3. **JSON Schemas**: Edit the output JSON schemas in [config.yaml](config.yaml) to ensure consistency in outputs. VLLM supports JSON-decoding via the `guided_json` generation argument, while Groq requires passing the schema in the system prompt.
+
+### Running the Tool
+
+```bash
+python triage.py --repo_name='meta-llama/llama-recipes' --start_date='2024-08-14' --end_date='2024-08-27'
+```
+
+### Output
+
+The tool generates:
+
+* CSV files with `annotations`, `challenges`, and `overview` data, which can be persisted in SQL tables for downstream analyses and reporting.
+* Graphical matplotlib plots of repository traffic, maintenance activity, and issue attributes.
+* A PDF report for easier reading and sharing.
+
+## Config
+
+The tool's configuration is stored in [config.yaml](config.yaml). The following sections can be edited:
+
+* **Github Token**: Use a token that has push-access on the target repo.
+* **model**: Specify the model service (`vllm` or `groq`) and set the endpoints and API keys as applicable.
+* **prompts**: For each of the 3 tasks Llama does in this tool, we specify a prompt and an output JSON schema:
+  * `parse_issue`: Parsing and generating annotations for the issues 
+  * `assign_category`: Assigns each issue to a category specified in an enum in the corresponding JSON schema
+  * `get_overview`: Generates a high-level executive summary and analysis of all the parsed and generated data
+
+## Troubleshooting
+
+* If you encounter issues with API calls, ensure that your GitHub token is set correctly and that you have the necessary permissions.
+* If you encounter issues with the model service, check the configuration values in [config.yaml](config.yaml).

Fichier diff supprimé car celui-ci est trop grand
+ 144 - 0
recipes/use_cases/github_triage/config.yaml


+ 165 - 0
recipes/use_cases/github_triage/llm.py

@@ -0,0 +1,165 @@
+import logging 
+from typing import Any, Dict, List, Optional, Union
+import yaml
+import time
+import json
+
+from tqdm import tqdm
+from openai import OpenAI
+import groq
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.StreamHandler())
+CFG = yaml.safe_load(open("config.yaml", "r"))
+
+class LlamaVLLM():
+    def __init__(self, endpoint, model_id):
+        self.model_id = model_id
+        self.client = OpenAI(base_url=endpoint, api_key='token')
+
+    def chat(
+        self,
+        inputs: List[Dict[str, str]],
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        guided_decode_json_schema: Optional[str] = None
+    ) -> List[str]:
+
+        if generation_kwargs is None:
+            generation_kwargs = {}
+            
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=inputs,
+                extra_body={
+                    "guided_json": guided_decode_json_schema
+                },
+                **generation_kwargs,
+            )
+            output = response.choices[0].message
+        except Exception as e:
+            logger.error(
+                f"FAILED to generate inference for input {inputs}\nError: {str(e)}"
+            )
+            output = None
+        return output
+    
+
+class LlamaGroq():
+    def __init__(self, key, model_id):
+        self.model_id = model_id
+        self.client = groq.Groq(api_key=key)
+        logger.debug(f"Using Groq:{self.model_id} for inference")
+
+    def chat(
+        self, 
+        inputs: List[Dict[str, str]], 
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        guided_decode_json_schema: Optional[str] = None
+    ) -> str:
+        
+        if generation_kwargs is None:
+            generation_kwargs = {}
+            
+        # Currently Groq doesn't support guided JSON decoding. Workaround:
+        if guided_decode_json_schema is not None:
+            inputs[0]['content'] += f"\n\nEnsure your response aligns with the following JSON schema:\n{guided_decode_json_schema}\n\n"
+        
+        output = None
+        
+        while True:
+            try:
+                response = self.client.chat.completions.with_raw_response.create(
+                    model=self.model_id,
+                    messages=inputs,
+                    stream=False,
+                    **generation_kwargs,
+                    response_format={"type": 'json_object' if guided_decode_json_schema is not None else 'text'}
+                )
+                completion = response.parse()
+                output = completion.choices[0].message.content
+                break
+            except groq.RateLimitError as e:
+                wait = e.response.headers['X-Ratelimit-Reset']
+                response = e.response
+                print(e)
+                print(f"[groq] waiting for {wait} to prevent ratelimiting")
+                time.sleep(wait)
+            except Exception as e:
+                logger.error(f"INFERENCE FAILED with Error: {e.response.status_code} for input:\n{inputs[-1]['content'][:300]}")
+                break
+
+        return output
+
+
+def run_llm_inference(
+    prompt_name: str,
+    inputs: Union[str, List[str]],
+    generation_kwargs: Optional[Dict] = None,
+    guided_decode_json_schema=None,
+) -> Union[List[str], List[Dict[str, Any]]]:
+    """
+    Run the LLM inference on the given inputs.
+
+    Args:
+    - prompt_name (str): The name of the prompt to use.
+    - inputs (str or List[str]): The input(s) to the LLM.
+    - generation_kwargs (Dict): Additional keyword arguments to pass to the LLM.
+    - guided_decode_json_schema (str): The JSON schema to use for guided decoding.
+
+    Returns:
+    - Union[str, List[str]]: The response(s) from the LLM.
+    """
+    
+    # initialize appropriate LLM accessor
+    if CFG['model']['use'] == 'vllm':
+        LLM = LlamaVLLM(**CFG['model']['vllm'])
+    elif CFG['model']['use'] == 'groq':
+        LLM = LlamaGroq(**CFG['model']['groq'])
+    else:
+        raise ValueError("Invalid model type in config.yaml")
+    
+    logger.debug(f"Running `{prompt_name}` inference with {CFG['model']['use']}")
+    
+    _batch = True
+    if isinstance(inputs, str):
+        _batch = False
+        inputs = [inputs]
+
+    inputs = [
+        [
+            {"role": "system", "content": CFG["prompts"][prompt_name]["system"]},
+            {"role": "user", "content": i},
+        ]
+        for i in inputs
+    ]
+
+    if (
+        guided_decode_json_schema is None
+        and "json_schema" in CFG["prompts"][prompt_name]
+    ):
+        guided_decode_json_schema = " ".join(
+            CFG["prompts"][prompt_name]["json_schema"].split()
+        )
+
+    responses = [
+        LLM.chat(i, generation_kwargs, guided_decode_json_schema) 
+        for i in tqdm(inputs, desc=f"Inference[{prompt_name}]")
+    ]
+
+    if guided_decode_json_schema is not None:
+        responses_json = []
+        for r in responses:
+            if r is not None:
+                try:
+                    responses_json.append(json.loads(r, strict=False))
+                    continue
+                except json.JSONDecodeError:
+                    logger.error(f"Error decoding JSON: {r}")
+            responses_json.append(None)
+        responses = responses_json
+
+    if not _batch:
+        responses = responses[0]
+
+    return responses

Fichier diff supprimé car celui-ci est trop grand
+ 1846 - 0
recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/annotated_issues.csv


Fichier diff supprimé car celui-ci est trop grand
+ 6 - 0
recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/challenges.csv


Fichier diff supprimé car celui-ci est trop grand
+ 2 - 0
recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/overview.csv


BIN
recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/plots/commits.png


BIN
recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/plots/engagement_sankey.png


BIN
recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/plots/expertise.png


BIN
recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/plots/sentiment.png


BIN
recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/plots/severity.png


BIN
recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/plots/themes.png


BIN
recipes/use_cases/github_triage/output/pytorch/pytorch/2024-08-28_2024-08-28/report.pdf


+ 141 - 0
recipes/use_cases/github_triage/pdf_report.py

@@ -0,0 +1,141 @@
+from fpdf import FPDF
+import os
+from datetime import datetime
+import logging
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.StreamHandler())
+
+class ReportPDF(FPDF):
+    def __init__(self, repository_name, start_date, end_date):
+        FPDF.__init__(self,orientation='P',unit='mm',format='A4')
+        self.repo = repository_name
+        self.start_end = f"{datetime.strptime(start_date, '%Y-%m-%d').strftime('%b %d, %Y')} to {datetime.strptime(end_date, '%Y-%m-%d').strftime('%b %d, %Y')}"
+        
+    def header(self):
+        self.set_font('Arial', 'B', 12)
+        self.cell(100, 10, f'AutoTriage Report: {self.repo}', 0, 0)
+        self.cell(90, 10, self.start_end, 0, 0, 'R')
+        self.ln(20)
+
+    def footer(self):
+        self.set_y(-15)
+        self.set_font('Arial', 'I', 8)
+        self.cell(0, 10, f'Page {self.page_no()}', 0, 0, 'C')
+        
+    def exec_summary(self, text):
+        self.set_font('Arial', 'B', 16)
+        self.cell(0, 8, 'Executive Summary', 'B', 0, 'L')
+        self.ln(10)
+        self.set_font('Arial', '', 10)
+        self.multi_cell(0, 5, text)
+        self.ln(10)
+    
+    def add_challenge(self, challenge_data):
+        # title
+        self.set_font('Arial', '', 14)
+        self.cell(0, 10, f"{challenge_data['key_challenge']}", 0, 0, 'L')
+        self.ln(8)
+        
+        # psosible causes
+        self.set_font('Arial', 'B', 10)
+        self.cell(0, 10, "Possible Causes", 0, 0, 'L')
+        self.ln(5)
+        self.set_font('Arial', '', 10)
+
+        x_list = challenge_data['possible_causes']
+        if isinstance(x_list, str):
+            x_list = x_list.split(',')
+
+        for x in x_list:
+            self.cell(0, 10, "* " + x, 0, 0, 'L')
+            self.ln(5)
+        self.ln(3)
+            
+        # remediations
+        self.set_font('Arial', 'B', 10)
+        self.cell(0, 10, "Remediations", 0, 0, 'L')
+        self.ln(5)
+        self.set_font('Arial', '', 10)
+
+        x_list = challenge_data['remediations']
+        if isinstance(x_list, str):
+            x_list = x_list.split(',')
+
+        for x in x_list:
+            self.cell(0, 10, "* " + x, 0, 0, 'L')
+            self.ln(5)
+        self.ln(3)
+        
+        # affected issues
+        self.set_font('Arial', 'B', 10)
+        self.cell(30, 10, f"Affected issues: ", 0, 0, 'L')
+        
+        x_list = challenge_data['affected_issues']
+        if isinstance(x_list, str):
+            x_list = x_list.split(',')
+            
+        for iss in x_list:
+            self.set_text_color(0,0,255)
+            self.cell(12, 10, str(iss), 0, 0, 'L', link=f"https://github.com/{self.repo}/issues/{iss}")
+            
+        self.set_text_color(0,0,0)
+        self.ln(15)
+
+    def challenges_section(self, key_challenges_data):
+        self.set_font('Arial', 'B', 16)
+        self.cell(0, 8, 'Key Challenges', 'B', 0, 'L')
+        self.ln(10)
+        for cd in key_challenges_data:
+            self.add_challenge(cd)
+        self.ln(20)
+    
+    def open_ques_section(self, open_questions):
+        self.set_font('Arial', 'B', 16)
+        self.cell(0, 8, 'Open Questions', 'B', 0, 'L')
+        self.ln(10)
+        self.set_font('Arial', '', 10)
+
+        if isinstance(open_questions, str):
+            open_questions = open_questions.split(',')
+                    
+        for qq in open_questions:
+            self.multi_cell(0, 5, "* " + qq, 0, 0, 'L')
+            self.ln(5)
+        self.ln(5)
+    
+    def add_graphs_section(self, title, plot_paths):
+        self.set_font('Arial', 'B', 16)
+        self.cell(0, 8, f'[Viz] {title}', 'B', 0, 'L')
+        self.ln(10)
+        for path in plot_paths:
+            if os.path.exists(path):
+                self.add_plot(path)
+            else:
+                self.set_font('Arial', 'BI', 10)
+                self.cell(0, 8, '< Plot not found, make sure you have push-acces to this repo >', 0, 0)
+        self.ln(10)
+            
+    def add_plot(self, img):
+        self.image(img, x=30, w=150)
+        self.ln(5)
+        
+    
+    
+def create_report_pdf(repo_name, start_date, end_date, key_challenges_data, executive_summary, open_questions, out_folder):#, image1, image2):
+    out_path = f'{out_folder}/report.pdf'
+    logger.info(f"Creating PDF report at {out_path}")
+    
+    pdf = ReportPDF(repo_name, start_date, end_date)
+    pdf.add_page()
+    pdf.exec_summary(executive_summary)
+    pdf.open_ques_section(open_questions)
+    pdf.challenges_section(key_challenges_data)
+    pdf.add_page()
+    pdf.add_graphs_section("Repo Maintenance", [f'{out_folder}/plots/engagement_sankey.png'])
+    pdf.add_page()
+    pdf.add_graphs_section("Traffic in the last 2 weeks", [f'{out_folder}/plots/{x}.png' for x in ['views_clones','resources', 'referrers']])
+    pdf.add_page()
+    pdf.add_graphs_section("New issues in the last 2 weeks", [f'{out_folder}/plots/{x}.png' for x in ['themes', 'severity', 'sentiment', 'expertise']])
+    pdf.output(out_path, 'F')
+

+ 178 - 0
recipes/use_cases/github_triage/plots.py

@@ -0,0 +1,178 @@
+import matplotlib.pyplot as plt
+import pandas as pd
+import plotly.graph_objects as go
+from utils import fetch_github_endpoint
+import logging
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.StreamHandler())
+
+def plot_views_clones(repo_name, out_folder):
+    def json_to_df(json_data, key):
+        df = pd.DataFrame(json_data[key])
+        df['timestamp'] = df['timestamp'].apply(lambda x: x[5:10])
+        if key in ['clones', 'views']:
+            df.rename(columns={'uniques': key}, inplace=True)
+            df.drop(columns=['count'], inplace=True)
+        return df
+
+    unique_clones_2w = fetch_github_endpoint(f"https://api.github.com/repos/{repo_name}/traffic/clones").json()
+    unique_views_2w = fetch_github_endpoint(f"https://api.github.com/repos/{repo_name}/traffic/views").json()
+
+    df1 = json_to_df(unique_clones_2w, 'clones')
+    df2 = json_to_df(unique_views_2w, 'views')
+
+    df = df1.merge(df2, on='timestamp', how='inner')
+
+    fig, ax1 = plt.subplots(figsize=(10, 6))
+    ax1.plot(df['timestamp'], df['views'], color='blue')
+    ax1.set_xlabel('Day', fontsize=18)
+    ax1.set_ylabel('Unique Views', color='blue', fontsize=18)
+    ax1.tick_params(axis='y', labelcolor='blue')
+
+    ax2 = ax1.twinx()
+    ax2.bar(df['timestamp'], df['clones'], color='red')
+    ax2.set_ylabel('Unique Clones', color='red', fontsize=18)
+    ax2.tick_params(axis='y', labelcolor='red')
+
+    plt.title('Views & Clones in the last 2 weeks', fontsize=24)
+    plt.savefig(f'{out_folder}/views_clones.png', dpi=120)  
+    plt.close()
+
+def plot_high_traffic_resources(repo_name, out_folder):
+    popular_paths_2w = fetch_github_endpoint(f"https://api.github.com/repos/{repo_name}/traffic/popular/paths").json()
+    df = pd.DataFrame(popular_paths_2w)
+    df['path'] = df['path'].apply(lambda x: '/'.join(x.split('/')[-2:]))
+    df = df.sort_values(by='uniques', ascending=False).head(10)
+
+    plt.figure(figsize=(10, 6))
+    plt.barh(df['path'], df['uniques'])
+    plt.xlabel('Unique traffic in the last 2 weeks', fontsize=18)
+    # plt.ylabel('Resource', fontsize=18, labelpad=15)
+    plt.title("Popular Resources on the Repository", fontsize=24)
+    plt.tight_layout()
+    plt.savefig(f'{out_folder}/resources.png', dpi=120)
+    plt.close()
+    
+def plot_high_traffic_referrers(repo_name, out_folder):
+    popular_referrer_2w = fetch_github_endpoint(f"https://api.github.com/repos/{repo_name}/traffic/popular/referrers").json()
+    df = pd.DataFrame(popular_referrer_2w)
+    df = df.sort_values(by='uniques', ascending=False)
+
+    plt.figure(figsize=(10, 6))
+    plt.barh(df['referrer'], df['uniques'])
+    plt.xlabel('Unique traffic in the last 2 weeks', fontsize=18)
+    plt.ylabel('Referrer', fontsize=18)
+    plt.title("Popular Referrers to the Repository", fontsize=24)
+    plt.savefig(f'{out_folder}/referrers.png', dpi=120)
+    plt.close()
+
+def plot_commit_activity(repo_name, out_folder):
+    limit = 10
+    today = pd.to_datetime('today')
+    weekly_commit_count_52w = fetch_github_endpoint(f"https://api.github.com/repos/{repo_name}/stats/participation").json()['all'][-limit:]
+    timestamps = [(today - pd.Timedelta(days=7*(i+1))) for i in range(limit)]
+    df = pd.DataFrame({'timestamp': timestamps, 'commit_count': weekly_commit_count_52w})
+
+    plt.figure(figsize=(10, 6))
+    plt.bar(df['timestamp'], df['commit_count'])
+    plt.xlabel('Week', fontsize=18)
+    plt.ylabel('Commit Count', fontsize=18)
+    plt.title(f"Commits in the last {limit} weeks", fontsize=24)
+    plt.savefig(f'{out_folder}/commits.png', dpi=120)
+    plt.close()
+
+def plot_user_expertise(df, out_folder):
+    d = df.to_dict('records')[0]
+    levels = ['Beginner', 'Intermediate', 'Advanced']
+    keys = [f"op_expertise_count_{x.lower()}" for x in levels]
+    data = pd.DataFrame({'Expertise': levels, 'Count': [d.get(k, 0) for k in keys]})
+
+    plt.figure(figsize=(10, 6))
+    plt.barh(data['Expertise'], data['Count'])
+    plt.xlabel('Count', fontsize=18)
+    plt.title('User Expertise', fontsize=24)
+    plt.savefig(f'{out_folder}/expertise.png', dpi=120)
+    plt.close()
+
+def plot_severity(df, out_folder):
+    d = df.to_dict('records')[0]
+    levels = ['Trivial', 'Minor', "Major", 'Critical']
+    keys = [f"severity_count_{x.lower()}" for x in levels]
+    data = pd.DataFrame({'Severity': levels, 'Count': [d.get(k, 0) for k in keys]})
+    plt.figure(figsize=(10, 6))
+    plt.barh(data['Severity'], data['Count'])
+    plt.xlabel('Count', fontsize=18)
+    plt.title('Severity', fontsize=24)
+    plt.savefig(f'{out_folder}/severity.png', dpi=120)
+    plt.close()
+
+def plot_sentiment(df, out_folder):
+    d = df.to_dict('records')[0]
+    levels = ['Positive', 'Neutral', 'Negative']
+    keys = [f"sentiment_count_{x.lower()}" for x in levels]
+    data = pd.DataFrame({'Sentiment': levels, 'Count': [d.get(k, 0) for k in keys]})
+    plt.figure(figsize=(10, 6))
+    plt.barh(data['Sentiment'], data['Count'])
+    plt.xlabel('Count', fontsize=18)
+    plt.title('Sentiment', fontsize=24)
+    plt.savefig(f'{out_folder}/sentiment.png', dpi=120)
+    plt.close()
+        
+def plot_themes(df, out_folder):
+    d = df.to_dict('records')[0]
+    levels = ['Documentation', 'Installation and Environment', 'Model Inference', 'Model Fine Tuning and Training', 'Model Evaluation and Benchmarking', 'Model Conversion', 'Cloud Compute', 'CUDA Compatibility', 'Distributed Training and Multi-GPU', 'Invalid', 'Miscellaneous']
+    keys = [f'themes_count_{x.lower().replace(" ", "_").replace("-", "_")}' for x in levels]
+    data = pd.DataFrame({'Theme': levels, 'Count': [d.get(k, 0) for k in keys]})
+    plt.figure(figsize=(10, 6))
+    plt.barh(data['Theme'], data['Count'])
+    plt.xlabel('Count', fontsize=18)
+    plt.title('Themes', fontsize=24)
+    plt.tight_layout()
+    plt.savefig(f'{out_folder}/themes.png', dpi=120)
+    plt.close()
+  
+def issue_activity_sankey(df, out_folder):
+    
+    d = df.to_dict('records')[0]
+    label = ["New Issues", "Issues Under Discussion", "Issues Discussed and Closed", "Issues Not Responded to", "Issues Closed Without Discussion"]
+    values = [
+        d['issues_created'], 
+        d['open_discussion'] + d['closed_discussion'],  # 7
+        d['closed_discussion'], # 3
+        d['open_no_discussion'] + d['closed_no_discussion'],
+        d['closed_no_discussion'] 
+    ]
+
+    fig = go.Figure(data=[go.Sankey(
+        node = dict(
+        pad = 15,
+        thickness = 20,
+        line = dict(color = "black", width = 0.5),
+        label = [f"{l} ({values[i]})" for i, l in enumerate(label)],
+        color = ["#007bff", "#17a2b8", "#6610f2", "#dc3545", "#6c757d"]  # color scheme to highlight different flows
+        ),
+        link = dict(
+        source = [0, 1, 0, 3], # indices correspond to labels, eg A1, A2, etc
+        target = [1, 2, 3, 4],
+        value = [v if v > 0 else 1e-9 for v in values[1:]]
+    ))])
+
+    fig.update_layout(title_text="Issue Flow", font_size=16)
+    fig.update_layout(margin=dict(l=20, r=20, t=60, b=20))  # adjust margins to make text more visible
+    fig.write_image(f"{out_folder}/engagement_sankey.png")
+
+
+def draw_all_plots(repo_name, out_folder, overview):
+    func1 = [plot_views_clones, plot_high_traffic_resources, plot_high_traffic_referrers, plot_commit_activity]
+    func2 = [plot_user_expertise, plot_severity, plot_sentiment, plot_themes, issue_activity_sankey]
+    logger.info("Plotting traffic trends...")
+    for func in func1:
+        try:
+            func(repo_name, out_folder)
+        except:
+            print(f"Github fetch failed for {func}. Make sure you have push-access to {repo_name}!")
+    logger.info("Plotting issue trends...")
+    for func in func2:
+        func(overview, out_folder)
+    

+ 6 - 0
recipes/use_cases/github_triage/requirements.txt

@@ -0,0 +1,6 @@
+kaleido
+plotly
+openai
+groq
+fpdf
+plotly

+ 240 - 0
recipes/use_cases/github_triage/triage.py

@@ -0,0 +1,240 @@
+import logging
+import os
+from typing import Optional, Tuple, Dict
+import pandas as pd
+import fire
+
+from llm import run_llm_inference
+from utils import fetch_repo_issues, validate_df_values
+from plots import draw_all_plots
+from pdf_report import create_report_pdf
+
+logging.basicConfig(level=logging.INFO, filename='log.txt', format='%(asctime)s [%(levelname)-5.5s] %(message)s')
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.StreamHandler())
+
+def generate_issue_annotations(
+    issues_df: pd.DataFrame
+) -> Tuple[pd.DataFrame, Dict[str, int]]:
+    """
+    Get the annotations for the given issues.
+
+    Args:
+    - issues_df (pd.DataFrame): The DataFrame containing the issues.
+
+    Returns:
+    - Tuple[pd.DataFrame, Dict[str, int]]: A tuple containing the annotated issues DataFrame and the theme counts.
+    """
+
+    # pyre-fixme[6]
+    def _categorize_issues(
+        issues_metadata_df: pd.DataFrame,
+    ) -> Tuple[pd.Series, Dict[str, int]]:
+        """
+        Categorize the issues.
+
+        Args:
+        - issues_metadata_df (pd.DataFrame): The DataFrame containing the issues metadata.
+
+        Returns:
+        - Tuple[pd.Series, Dict[str, int]]: A tuple containing the categorized issues and the theme counts.
+        """
+        minified_issues = issues_metadata_df[
+            [
+                "number",
+                "summary",
+                "possible_causes",
+                "remediations",
+                "component",
+                "issue_type",
+            ]
+        ].to_dict(orient="records")
+        themes_json = run_llm_inference(
+            "assign_category",
+            str(minified_issues),
+            generation_kwargs={"temperature": 0.45, "max_tokens": 2048},
+        )
+
+        tmp = {}
+        for t in themes_json["report"]:
+            for num in t["related_issues"]:
+                tmp[num] = tmp.get(num, []) + [t["theme"]]
+
+        themes = issues_metadata_df.number.apply(
+            lambda x: tmp.get(x, ["Miscellaneous"])
+        )
+        theme_count = {
+            k["theme"]: len(k["related_issues"]) for k in themes_json["report"]
+        }
+        return themes, theme_count
+
+    logger.info(f"Generating annotations for {len(issues_df)} issues")
+    
+    discussions = issues_df["discussion"].tolist()
+    metadata = run_llm_inference(
+        "parse_issue",
+        discussions,
+        generation_kwargs={"max_tokens": 2048, "temperature": 0.42},
+    )
+
+    # Handle the case where the LLM returns None instead of a generated response
+    metadata_index = [
+        issues_df.index[i] for i in range(len(metadata)) if metadata[i] is not None
+    ]
+    metadata = [m for m in metadata if m is not None]
+
+    issues_metadata_df = issues_df.merge(
+        pd.DataFrame(metadata, index=metadata_index), left_index=True, right_index=True
+    )
+
+    themes, theme_count = _categorize_issues(issues_metadata_df)
+    issues_metadata_df["themes"] = themes
+
+    return issues_metadata_df, theme_count
+
+
+def generate_executive_reports(
+    annotated_issues: pd.DataFrame,
+    theme_counts: Dict,
+    repo_name: str,
+    start_date: str,
+    end_date: str,
+    save_folder: Optional[str] = None,
+) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    """
+    Generate executive reports for the given issues.
+
+    Args:
+    - annotated_issues (pd.DataFrame): The DataFrame containing the annotated issues.
+    - theme_counts (dict): A dictionary containing the theme counts.
+    - repo_name (str): The name of the repository. Defaults to None.
+    - start_date (str): The start date of the report. Defaults to None.
+    - end_date (str): The end date of the report. Defaults to None.
+
+    Returns:
+    - Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the challenges DataFrame and the overview DataFrame.
+    """
+    logger.info(f"Generating high-level summaries from annotations...")
+    
+    report = {
+        "repo_name": repo_name,
+        "start_date": start_date,
+        "end_date": end_date,
+        "sentiment_count": annotated_issues["sentiment"].value_counts().to_dict(),
+        "severity_count": annotated_issues["severity"].value_counts().to_dict(),
+        "op_expertise_count": annotated_issues["op_expertise"].value_counts().to_dict(),
+        "themes_count": theme_counts,
+        "issues_created": annotated_issues["number"].nunique(),
+        "open_discussion": len(
+            annotated_issues[
+                (annotated_issues.num_comments > 0) & (annotated_issues.closed == False)
+            ]
+        ),
+        "closed_discussion": len(
+            annotated_issues[
+                (annotated_issues.num_comments > 0) & (annotated_issues.closed == True)
+            ]
+        ),
+        "open_no_discussion": len(
+            annotated_issues[
+                (annotated_issues.num_comments == 0)
+                & (annotated_issues.closed == False)
+            ]
+        ),
+        "closed_no_discussion": len(
+            annotated_issues[
+                (annotated_issues.num_comments == 0) & (annotated_issues.closed == True)
+            ]
+        ),
+    }
+
+    report_input = str(
+        annotated_issues[
+            ["number", "summary", "possible_causes", "remediations"]
+        ].to_dict("records")
+    )
+    overview = run_llm_inference(
+        "get_overview", str(report_input), {"temperature": 0.45, "max_tokens": 4096}
+    )
+    report.update(overview)
+
+    overview_df = {
+        k: report[k]
+        for k in [
+            "repo_name",
+            "start_date",
+            "end_date",
+            "issues_created",
+            "open_discussion",
+            "closed_discussion",
+            "open_no_discussion",
+            "closed_no_discussion",
+        ]
+    }
+    overview_df["open_questions"] = [report["open_questions"]]
+    overview_df["executive_summary"] = [report["executive_summary"]]
+
+    for col in [
+        "sentiment_count",
+        "severity_count",
+        "op_expertise_count",
+        "themes_count",
+    ]:
+        d = report[col]
+        for k, v in d.items():
+            overview_df[f"{col}_{k}"] = v
+
+    overview_df = pd.DataFrame(overview_df)
+    
+    logger.info(f"Identifying key-challenges faced by users...")
+
+    challenges_df = {k: report[k] for k in ["repo_name", "start_date", "end_date"]}
+    challenges_df["key_challenge"] = [
+        k["key_challenge"] for k in report["issue_analysis"]
+    ]
+    challenges_df["affected_issues"] = [
+        k["affected_issues"] for k in report["issue_analysis"]
+    ]
+    challenges_df["possible_causes"] = [
+        k["possible_causes"] for k in report["issue_analysis"]
+    ]
+    challenges_df["remediations"] = [
+        k["remediations"] for k in report["issue_analysis"]
+    ]
+    challenges_df = pd.DataFrame(challenges_df)
+
+    return challenges_df, overview_df
+   
+   
+def main(repo_name, start_date, end_date):
+    out_folder = f'output/{repo_name}/{start_date}_{end_date}'
+    os.makedirs(out_folder, exist_ok=True)
+    
+    # Get issues data
+    issues_df = fetch_repo_issues(repo_name, start_date, end_date)
+    
+    # Generate annotations and metadata
+    annotated_issues, theme_counts = generate_issue_annotations(issues_df)
+    # Validate and save generated data
+    annotated_issues = validate_df_values(annotated_issues, out_folder, 'annotated_issues')
+    
+    # Generate high-level analysis
+    challenges, overview = generate_executive_reports(annotated_issues, theme_counts, repo_name, start_date, end_date)
+    # Validate and save generated data
+    challenges = validate_df_values(challenges, out_folder, 'challenges')
+    overview = validate_df_values(overview, out_folder, 'overview')
+    
+    # Create graphs and charts
+    plot_folder = out_folder + "/plots"
+    os.makedirs(plot_folder, exist_ok=True)
+    draw_all_plots(repo_name, plot_folder, overview)
+    
+    # Create PDF report
+    exec_summary = overview['executive_summary'].iloc[0]
+    open_qs = overview['open_questions'].iloc[0]
+    key_challenges_data = challenges[['key_challenge', 'possible_causes', 'remediations', 'affected_issues']].to_dict('records')
+    create_report_pdf(repo_name, start_date, end_date, key_challenges_data, exec_summary, open_qs, out_folder)
+    
+
+if __name__ == "__main__":
+    fire.Fire(main)

+ 98 - 0
recipes/use_cases/github_triage/utils.py

@@ -0,0 +1,98 @@
+import requests
+import yaml
+import pandas as pd
+import logging
+
+logger = logging.getLogger(__name__)
+logger.addHandler(logging.StreamHandler())
+
+CFG = yaml.safe_load(open("config.yaml", "r"))
+
+
+def fetch_github_endpoint(url):
+    headers = {
+        "Authorization": f"Bearer {CFG['github_token']}",
+        "Content-Type": "application/json"
+    }
+    logger.debug(f"Requesting url: {url}")
+    response = requests.get(url, headers=headers, timeout=10)
+    return response
+
+
+def fetch_repo_issues(repo, start_date=None, end_date=None):
+    time_filter = ""
+    if start_date and not end_date:
+        time_filter = f"+created:>{start_date}"
+    if end_date and not start_date:
+        time_filter = f"+created:<{end_date}"
+    if start_date and end_date:
+        time_filter = f"+created:{start_date}..{end_date}"
+    
+    url = f"https://api.github.com/search/issues?per_page=100&sort=created&order=asc&q=repo:{repo}+is:issue{time_filter}"
+
+    samples = []
+
+    while True:
+        response = fetch_github_endpoint(url)
+        
+        if response.status_code == 200:
+            issues = response.json()['items']
+            for issue in issues:
+                if issue['body'] is None:
+                    continue
+                
+                issue['discussion'] = issue['title'] + "\n" + issue['body']
+                if issue['comments'] > 0:
+                    comments_response = fetch_github_endpoint(issue['comments_url']).json()
+                    comments = "\n> ".join([x['body'] for x in comments_response])
+                    issue['discussion'] += "\n> " + comments
+                    
+                samples.append(issue)
+        
+            # Check if there are more pages
+            if "Link" in response.headers:
+                link_header = [h.split(';') for h in response.headers["Link"].split(', ')]
+                link_header = [x for x in link_header if "next" in x[1]]
+                if link_header:
+                    url = link_header[0][0].strip().replace('<', '').replace('>','')
+                else:
+                    break
+            else:
+                break
+        else:
+            raise Exception(f"Fetching issues failed with Error: {response.status_code} on url {url}")
+        
+    rows = [{
+        "repo_name": repo,
+        "number": d['number'],
+        "html_url": d['html_url'],
+        "closed": (d['state'] == 'closed'),
+        "num_comments": d['comments'],
+        "created_at": d["created_at"],
+        "discussion": d['discussion'],
+    } for d in samples]
+    
+    logger.info(f"Fetched {len(samples)} issues on {repo} from {start_date} to {end_date}")
+    
+    return pd.DataFrame(rows)
+
+
+def fetch_repo_stats(repo):
+    repo_info = fetch_github_endpoint(f"https://api.github.com/repos/{repo}").json()
+    
+    repo_stats = {
+        "Total Open Issues": repo_info['open_issues_count'],
+        "Total Stars": repo_info['stargazers_count'],
+        "Total Forks": repo_info['forks_count'],
+    }
+    
+    return repo_stats
+
+
+def validate_df_values(df, out_folder=None, name=None):
+    df.columns = df.columns.str.lower().str.replace(" ", "_").str.replace("-", "_")
+    if out_folder is not None:
+        path = f"{out_folder}/{name}.csv"
+        df.to_csv(path, index=False)
+        logger.info(f"Data saved to {path}")
+    return df

Fichier diff supprimé car celui-ci est trop grand
+ 659 - 0
recipes/use_cases/github_triage/walkthrough.ipynb


+ 2 - 6
requirements.txt

@@ -8,13 +8,12 @@ black[jupyter]
 datasets
 fire
 peft
-transformers>=4.43.1
+transformers>=4.45.1
 sentencepiece
 py7zr
 scipy
 optimum
 matplotlib
-gradio
 chardet
 openai
 typing-extensions==4.8.0
@@ -22,10 +21,7 @@ tabulate
 evaluate
 rouge_score
 pyyaml==6.0.1
-faiss-gpu
+faiss-gpu; python_version < '3.11'
 unstructured[pdf]
-langchain_openai
-langchain
-langchain_community
 sentence_transformers
 codeshield

+ 1 - 1
src/llama_recipes/configs/fsdp.py

@@ -14,7 +14,7 @@ class fsdp_config:
     hsdp : bool =False # Require HYBRID_SHARD to be set. This flag can extend the HYBRID_SHARD by allowing sharding a model on customized number of GPUs (Sharding_group) and Replicas over Sharding_group.
     sharding_group_size : int=0 # requires hsdp to be set. This specifies the sharding group size, number of GPUs that you model can fit into to form a replica of a model.
     replica_group_size: int=0 #requires hsdp to be set. This specifies the replica group size, which is world_size/sharding_group_size.
-    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT  # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
+    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT  # alternatively FULL_STATE_DICT can be used. SHARDED_STATE_DICT saves one file with sharded weights per rank while FULL_STATE_DICT will collect all weights on rank 0 and save them in a single file.
     fsdp_activation_checkpointing: bool=True
     fsdp_cpu_offload: bool=False
     pure_bf16: bool = False

+ 14 - 1
src/llama_recipes/datasets/__init__.py

@@ -1,7 +1,20 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from functools import partial
+
 from llama_recipes.datasets.grammar_dataset.grammar_dataset import get_dataset as get_grammar_dataset
 from llama_recipes.datasets.alpaca_dataset import InstructionDataset as get_alpaca_dataset
+from llama_recipes.datasets.custom_dataset import get_custom_dataset,get_data_collator
 from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset
-from llama_recipes.datasets.toxicchat_dataset import get_llamaguard_toxicchat_dataset as get_llamaguard_toxicchat_dataset
+from llama_recipes.datasets.toxicchat_dataset import get_llamaguard_toxicchat_dataset as get_llamaguard_toxicchat_dataset
+DATASET_PREPROC = {
+    "alpaca_dataset": partial(get_alpaca_dataset),
+    "grammar_dataset": get_grammar_dataset,
+    "samsum_dataset": get_samsum_dataset,
+    "custom_dataset": get_custom_dataset,
+    "llamaguard_toxicchat_dataset": get_llamaguard_toxicchat_dataset,
+}
+DATALOADER_COLLATE_FUNC = {
+    "custom_dataset": get_data_collator
+}

+ 57 - 0
src/llama_recipes/datasets/custom_dataset.py

@@ -0,0 +1,57 @@
+import importlib
+from pathlib import Path
+
+def load_module_from_py_file(py_file: str) -> object:
+    """
+    This method loads a module from a py file which is not in the Python path
+    """
+    module_name = Path(py_file).name
+    loader = importlib.machinery.SourceFileLoader(module_name, py_file)
+    spec = importlib.util.spec_from_loader(module_name, loader)
+    module = importlib.util.module_from_spec(spec)
+
+    loader.exec_module(module)
+
+    return module
+
+
+def get_custom_dataset(dataset_config, tokenizer, split: str):
+    if ":" in dataset_config.file:
+        module_path, func_name = dataset_config.file.split(":")
+    else:
+        module_path, func_name = dataset_config.file, "get_custom_dataset"
+
+    if not module_path.endswith(".py"):
+        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+
+    module_path = Path(module_path)
+    if not module_path.is_file():
+        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+
+    module = load_module_from_py_file(module_path.as_posix())
+    try:
+        return getattr(module, func_name)(dataset_config, tokenizer, split)
+    except AttributeError as e:
+        print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
+        raise e
+
+def get_data_collator(dataset_processer,dataset_config):
+    if ":" in dataset_config.file:
+        module_path, func_name = dataset_config.file.split(":")
+    else:
+        module_path, func_name = dataset_config.file, "get_data_collator"
+
+    if not module_path.endswith(".py"):
+        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+
+    module_path = Path(module_path)
+    if not module_path.is_file():
+        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+
+    module = load_module_from_py_file(module_path.as_posix())
+    try:
+        return getattr(module, func_name)(dataset_processer)
+    except AttributeError as e:
+        print(f"Can not find the custom data_collator in the dataset.py file ({module_path.as_posix()}).")
+        print("Using the default data_collator instead.")
+        return None

+ 65 - 21
src/llama_recipes/finetuning.py

@@ -14,16 +14,18 @@ from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
 )
-
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
+    AutoConfig,
     AutoTokenizer,
     BitsAndBytesConfig,
+    AutoProcessor, 
     LlamaForCausalLM,
-    LlamaConfig,
+    MllamaForConditionalGeneration,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mllama.modeling_mllama import  MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer
 
 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
@@ -37,8 +39,9 @@ from llama_recipes.utils.config_utils import (
     generate_peft_config,
     generate_dataset_config,
     get_dataloader_kwargs,
+    check_fsdp_config,
 )
-from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
+from llama_recipes.utils.dataset_utils import get_preprocessed_dataset,get_custom_data_collator
 
 from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
@@ -117,19 +120,37 @@ def main(**kwargs):
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
-    model = LlamaForCausalLM.from_pretrained(
+    config = AutoConfig.from_pretrained(train_config.model_name)
+    if config.model_type == "mllama":
+        is_vision = True
+        model = MllamaForConditionalGeneration.from_pretrained(
         train_config.model_name,
         quantization_config=bnb_config,
-        use_cache=use_cache,
         attn_implementation="sdpa" if train_config.use_fast_kernels else None,
         device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
         torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
     )
-
+        processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
+        processor.tokenizer.padding_side='right'
+        model.supports_gradient_checkpointing = True
+        model.language_model.supports_gradient_checkpointing = True
+    elif config.model_type == "llama":
+        is_vision = False
+        model = LlamaForCausalLM.from_pretrained(
+            train_config.model_name,
+            quantization_config=bnb_config,
+            use_cache=use_cache,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
+            device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
+            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+        )
+    else:
+        raise ValueError(f"Model type {config.model_type} is not supported. Please use llama or mllama model.")
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
-    tokenizer.pad_token_id = tokenizer.eos_token_id
-
+    if not tokenizer.pad_token_id: 
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
@@ -162,12 +183,18 @@ def main(**kwargs):
 
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
+        check_fsdp_config(fsdp_config)
+        
         if not train_config.use_peft and train_config.freeze_layers:
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-
+        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
+        if is_vision:
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer])
+        else:
+        # Create the FSDP wrapper for LlamaDecoderLayer in text models
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()
@@ -195,12 +222,16 @@ def main(**kwargs):
             model.to("xpu:0")
         elif torch.cuda.is_available():
             model.to("cuda")
-
     dataset_config = generate_dataset_config(train_config, kwargs)
+    if is_vision:
+        dataset_processer = processor
+    else:
+        dataset_processer = tokenizer
+
+    # Load and preprocess the dataset for training and validation
 
-     # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
-        tokenizer,
+        dataset_processer,
         dataset_config,
         split="train",
     )
@@ -208,7 +239,7 @@ def main(**kwargs):
         print(f"--> Training Set Length = {len(dataset_train)}")
 
     dataset_val = get_preprocessed_dataset(
-        tokenizer,
+        dataset_processer,
         dataset_config,
         split="test",
     )
@@ -216,10 +247,17 @@ def main(**kwargs):
         print(f"--> Validation Set Length = {len(dataset_val)}")
 
     if train_config.batching_strategy == "packing":
-        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
-
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
-
+        if is_vision:
+            raise ValueError("Packing is not supported for vision datasets")
+        else:
+            dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
+    print("length of dataset_train", len(dataset_train))
+    custom_data_collator = get_custom_data_collator(dataset_processer,dataset_config)
+    if custom_data_collator:
+        print("custom_data_collator is used")
+        train_dl_kwargs["collate_fn"] = custom_data_collator
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -227,13 +265,19 @@ def main(**kwargs):
         pin_memory=True,
         **train_dl_kwargs,
     )
+    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
 
     eval_dataloader = None
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
-            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+            if is_vision:
+                raise ValueError("Packing is not supported for vision datasets")
+            else:
+                dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
 
-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
+        if custom_data_collator:
+            val_dl_kwargs["collate_fn"] = custom_data_collator
 
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
@@ -241,6 +285,7 @@ def main(**kwargs):
             pin_memory=True,
             **val_dl_kwargs,
         )
+        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
         if len(eval_dataloader) == 0:
             raise ValueError("The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set.")
         else:
@@ -263,7 +308,6 @@ def main(**kwargs):
             weight_decay=train_config.weight_decay,
         )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
-    # Start the training process
     results = train(
         model,
         train_dataloader,

+ 2 - 1
src/llama_recipes/model_checkpointing/__init__.py

@@ -3,8 +3,9 @@
 
 from llama_recipes.model_checkpointing.checkpoint_handler import (
     load_model_checkpoint,
-    save_model_checkpoint,
+    save_fsdp_model_checkpoint_full,
     save_peft_checkpoint,
+    save_model_checkpoint,
     load_optimizer_checkpoint,
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,

+ 19 - 5
src/llama_recipes/model_checkpointing/checkpoint_handler.py

@@ -123,7 +123,7 @@ def save_model_and_optimizer_sharded(model, rank, cfg,optim=None):
         print(
             f"Checkpoint Time = {t1-t0:.4f}\n"
         )
-def save_model_checkpoint(
+def save_fsdp_model_checkpoint_full(
     model,
     optimizer,
     rank,
@@ -152,7 +152,7 @@ def save_model_checkpoint(
         )
         save_dir = Path.cwd() / folder_name
         save_dir.mkdir(parents=True, exist_ok=True)
-        save_name = cfg.model_name + "-" + str(epoch) + ".pt"
+        save_name = cfg.model_name.replace("/","--") + "-" + str(epoch) + ".pt"
         save_full_path = str(save_dir) + "/" + save_name
 
         # save model
@@ -271,6 +271,20 @@ def save_peft_checkpoint(model, model_path):
     """save_pretrained peft model"""
 
     options = StateDictOptions(full_state_dict=True, cpu_offload=True)
-
-    state_dict = get_model_state_dict(model, options=options)
-    model.save_pretrained(model_path, state_dict=state_dict)
+    
+    if isinstance(model, FSDP):
+        state_dict = get_model_state_dict(model, options=options)
+        model.save_pretrained(model_path, state_dict=state_dict)
+    else:
+        model.save_pretrained(model_path)
+    
+    
+def save_model_checkpoint(model, output_dir):
+    """save model when not peft and on single device"""
+    
+    output_file = Path(output_dir) / "model.pt"
+    
+    state_dict = model.state_dict()
+    
+    torch.save(state_dict, output_file)
+    

+ 3 - 3
src/llama_recipes/policies/wrapping.py

@@ -4,6 +4,8 @@
 import functools
 
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mllama.modeling_mllama import   MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer
+
 from torch.distributed.fsdp.wrap import (
     transformer_auto_wrap_policy,
     size_based_auto_wrap_policy,
@@ -25,9 +27,7 @@ def get_llama_wrapper():
 
     llama_auto_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
-        transformer_layer_cls={
-            LlamaDecoderLayer,
-        },
+        transformer_layer_cls=set([LlamaDecoderLayer, MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer,MllamaCrossAttentionDecoderLayer])
     )
 
     return llama_auto_wrap_policy

+ 41 - 27
src/llama_recipes/utils/config_utils.py

@@ -5,6 +5,7 @@ import inspect
 from dataclasses import asdict
 
 import torch.distributed as dist
+from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 from torch.utils.data import DistributedSampler
 from peft import (
     LoraConfig,
@@ -16,8 +17,7 @@ from transformers.data import DataCollatorForSeq2Seq
 
 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
 from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
-from llama_recipes.utils.dataset_utils import DATASET_PREPROC
-
+from llama_recipes.datasets import DATASET_PREPROC
 
 def update_config(config, **kwargs):
     if isinstance(config, (tuple, list)):
@@ -75,34 +75,48 @@ def generate_dataset_config(train_config, kwargs):
     return  dataset_config
 
 
-def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
-        kwargs = {}
-        batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
-        if train_config.batching_strategy == "padding":
-            if train_config.enable_fsdp:
-                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
-                    dataset,
-                    batch_size=batch_size,
-                    rank=dist.get_rank(),
-                    num_replicas=dist.get_world_size(),
-                    shuffle=mode=="train",
-                )
-            else:
-                kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
-            kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
-        elif train_config.batching_strategy == "packing":
-            if train_config.enable_fsdp:
-                kwargs["sampler"] = DistributedSampler(
+def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
+    kwargs = {}
+    batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
+    if train_config.batching_strategy == "padding":
+        if train_config.enable_fsdp:
+            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
                 dataset,
+                batch_size=batch_size,
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
                 shuffle=mode=="train",
-                drop_last=True,
             )
-            kwargs["batch_size"] = batch_size
-            kwargs["drop_last"] = True
-            kwargs["collate_fn"] = default_data_collator
         else:
-            raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
-
-        return kwargs
+            kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
+        kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
+    elif train_config.batching_strategy == "packing":
+        if train_config.enable_fsdp:
+            kwargs["sampler"] = DistributedSampler(
+            dataset,
+            rank=dist.get_rank(),
+            num_replicas=dist.get_world_size(),
+            shuffle=mode=="train",
+            drop_last=True,
+        )
+        kwargs["batch_size"] = batch_size
+        kwargs["drop_last"] = True
+        kwargs["collate_fn"] = default_data_collator
+    else:
+        raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+    return kwargs
+
+
+def check_fsdp_config(fsdp_config):
+    VALID_TYPES = (StateDictType.SHARDED_STATE_DICT, StateDictType.FULL_STATE_DICT)
+    if isinstance(fsdp_config.checkpoint_type, str):
+        str_to_obj = {
+            "StateDictType.SHARDED_STATE_DICT": StateDictType.SHARDED_STATE_DICT,
+            "StateDictType.FULL_STATE_DICT": StateDictType.FULL_STATE_DICT,
+        }
+        if fsdp_config.checkpoint_type in str_to_obj:
+            fsdp_config.checkpoint_type = str_to_obj[fsdp_config.checkpoint_type]
+        
+    if not fsdp_config.checkpoint_type in VALID_TYPES:
+        raise ValueError(f"Invalid checkpoint_type {fsdp_config.checkpoint_type}")
+    

+ 31 - 55
src/llama_recipes/utils/dataset_utils.py

@@ -1,63 +1,11 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-import importlib
-from functools import partial
-from pathlib import Path
-
 import torch
 
-from llama_recipes.datasets import (
-    get_grammar_dataset,
-    get_alpaca_dataset,
-    get_samsum_dataset,
-    get_llamaguard_toxicchat_dataset,
-)
-
-
-def load_module_from_py_file(py_file: str) -> object:
-    """
-    This method loads a module from a py file which is not in the Python path
-    """
-    module_name = Path(py_file).name
-    loader = importlib.machinery.SourceFileLoader(module_name, py_file)
-    spec = importlib.util.spec_from_loader(module_name, loader)
-    module = importlib.util.module_from_spec(spec)
-
-    loader.exec_module(module)
-
-    return module
-
-
-def get_custom_dataset(dataset_config, tokenizer, split: str):
-    if ":" in dataset_config.file:
-        module_path, func_name = dataset_config.file.split(":")
-    else:
-        module_path, func_name = dataset_config.file, "get_custom_dataset"
-
-    if not module_path.endswith(".py"):
-        raise ValueError(f"Dataset file {module_path} is not a .py file.")
-
-    module_path = Path(module_path)
-    if not module_path.is_file():
-        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
-
-    module = load_module_from_py_file(module_path.as_posix())
-    try:
-        return getattr(module, func_name)(dataset_config, tokenizer, split)
-    except AttributeError as e:
-        print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
-        raise e
-
-
-DATASET_PREPROC = {
-    "alpaca_dataset": partial(get_alpaca_dataset),
-    "grammar_dataset": get_grammar_dataset,
-    "samsum_dataset": get_samsum_dataset,
-    "custom_dataset": get_custom_dataset,
-    "llamaguard_toxicchat_dataset": get_llamaguard_toxicchat_dataset,
-
-}
+from llama_recipes.data.concatenator import ConcatDataset
+from llama_recipes.datasets import DATASET_PREPROC, DATALOADER_COLLATE_FUNC
+from llama_recipes.utils.config_utils import get_dataloader_kwargs
 
 
 def get_preprocessed_dataset(
@@ -78,3 +26,31 @@ def get_preprocessed_dataset(
         tokenizer,
         get_split(),
     )
+
+def get_custom_data_collator(
+    dataset_processer, dataset_config
+) -> torch.utils.data.Dataset:
+    if not dataset_config.dataset in DATALOADER_COLLATE_FUNC:
+        return None
+
+    return DATALOADER_COLLATE_FUNC[dataset_config.dataset](
+        dataset_processer,
+        dataset_config
+    )
+
+def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
+    dataset = get_preprocessed_dataset(tokenizer, dataset_config, split)
+    dl_kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, split)
+    
+    if split == "train" and train_config.batching_strategy == "packing":
+        dataset = ConcatDataset(dataset, chunk_size=train_config.context_length)
+
+    # Create data loader
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=train_config.num_workers_dataloader,
+        pin_memory=True,
+        **dl_kwargs,
+    )
+    return dataloader
+    

+ 2 - 4
src/llama_recipes/utils/fsdp_utils.py

@@ -3,7 +3,7 @@
 from torch.distributed._tensor.device_mesh import init_device_mesh
 import os 
 
-def fsdp_auto_wrap_policy(model, transformer_layer_name):
+def fsdp_auto_wrap_policy(model, transformer_layer_names):
     import functools
 
     from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
@@ -20,9 +20,7 @@ def fsdp_auto_wrap_policy(model, transformer_layer_name):
     lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
     transformer_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
-        transformer_layer_cls=(
-            transformer_layer_name,
-        ),
+        transformer_layer_cls=set(transformer_layer_names)
     )
 
     auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])

+ 25 - 16
src/llama_recipes/utils/train_utils.py

@@ -20,7 +20,7 @@ from transformers import LlamaTokenizer
 import json
 
 
-from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_peft_checkpoint
+from llama_recipes.model_checkpointing import save_fsdp_model_checkpoint_full, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_peft_checkpoint, save_model_checkpoint
 from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 from accelerate.utils import is_xpu_available, is_ccl_available
@@ -118,6 +118,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     max_steps_reached = False  # Flag to indicate max training steps reached
     # Start the training loop
     for epoch in range(train_config.num_epochs):
+        print(f"Starting epoch {epoch}/{train_config.num_epochs}")
+        print(f"train_config.max_train_step: {train_config.max_train_step}")
         # stop when the maximum number of training steps is reached
         if max_steps_reached:
             break
@@ -143,10 +145,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             else:
                                 batch[key] = batch[key].to(local_rank)
                         else:
-
                             if is_xpu_available():
                                 batch[key] = batch[key].to('xpu:0')
-                            else:
+                            elif torch.cuda.is_available():
                                 batch[key] = batch[key].to('cuda:0')
                     with autocast():
                         loss = model(**batch).loss
@@ -243,27 +244,35 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
 
                 else:
-                    if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-
-                        save_model_checkpoint(
+                    if not train_config.enable_fsdp:
+                        save_model_checkpoint(model, train_config.output_dir)
+                        
+                    elif fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
+                        print(" Saving the FSDP model checkpoint using FULL_STATE_DICT")
+                        print("=====================================================")
+                        save_fsdp_model_checkpoint_full(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
-                    elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
-                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
-                        print("=====================================================")
+                        
+                        if train_config.save_optimizer:
+                            print(" Saving the FSDP optimizer using FULL_STATE_DICT")
+                            print("=====================================================")
+                            save_optimizer_checkpoint(
+                                model, optimizer, rank, train_config, epoch=epoch
+                            )
+                        
+                    elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
 
-                        save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
+                            print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
+                            print("=====================================================")
                             save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+                        else:
                             print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
                             print("=====================================================")
+                            save_model_and_optimizer_sharded(model, rank, train_config)
 
-                    if not train_config.use_peft and  train_config.save_optimizer:
-                        save_optimizer_checkpoint(
-                            model, optimizer, rank, train_config, epoch=epoch
-                        )
-                        print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
-                        print("=====================================================")
+                        
                 if train_config.enable_fsdp:
                     dist.barrier()
             checkpoint_end_time = time.perf_counter() - checkpoint_start_time

+ 1 - 1
src/tests/conftest.py

@@ -6,7 +6,7 @@ import pytest
 from transformers import AutoTokenizer
 
 ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
-LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B"]
+LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B-Instruct"]
 
 @pytest.fixture(params=LLAMA_VERSIONS)
 def llama_version(request):

+ 28 - 1
src/tests/datasets/test_custom_dataset.py

@@ -11,7 +11,7 @@ EXPECTED_RESULTS={
         "example_1": "[INST] Who made Berlin [/INST] dunno",
         "example_2": "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!",
     },
-    "meta-llama/Meta-Llama-3.1-8B":{
+    "meta-llama/Meta-Llama-3.1-8B-Instruct":{
         "example_1": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWho made Berlin<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\ndunno<|eot_id|><|end_of_text|>",
         "example_2": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow to start learning guitar and become a master at it?",
     },
@@ -114,3 +114,30 @@ def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train,
         }
     with pytest.raises(AttributeError):
         main(**kwargs)
+
+@pytest.mark.skip_missing_tokenizer
+@patch('llama_recipes.finetuning.AutoTokenizer')
+def test_tokenize_dialog(tokenizer, monkeypatch, setup_tokenizer, llama_version):
+    monkeypatch.syspath_prepend("recipes/quickstart/finetuning/datasets/")
+    from custom_dataset import tokenize_dialog
+
+    setup_tokenizer(tokenizer)
+    tokenizer = tokenizer.from_pretrained()
+
+    dialog = [
+        {"role":"user", "content":"Who made Berlin?"},
+        {"role":"assistant", "content":"dunno"},
+        {"role":"user", "content":"And Rome?"},
+        {"role":"assistant", "content":"Romans"},
+    ]
+
+    result = tokenize_dialog(dialog, tokenizer)
+    
+    if "Llama-2" in llama_version:
+        assert result["labels"][:12] == [-100] * 12
+        assert result["labels"][17:28] == [-100] * 11
+        assert result["labels"].count(-100) == 11 + 12
+    else:
+        assert result["labels"][:38] == [-100] * 38
+        assert result["labels"][43:54] == [-100] * 11
+        assert result["labels"].count(-100) == 38 + 11