Notebook showing how to fine tune llama guard with torchtune

Ankith Gunapal, 1 month ago
commit 5d87da36f7

File diff suppressed because it is too large
+ 1790 - 0
getting-started/responsible_ai/llama_guard/llama_guard_finetuning_multiple_violations_with_torchtune.ipynb


+ 120 - 0
getting-started/responsible_ai/llama_guard/torchtune_configs/8B_guard_full.yaml

@@ -0,0 +1,120 @@
+# Config for multi-device full finetuning in full_finetune_distributed.py
+# using a Llama Guard 3 8B model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+#   tune download meta-llama/Llama-Guard-3-8B --output-dir /tmp/Meta-Llama-Guard-3-8B --ignore-patterns "original/consolidated.00.pth"
+#
+# To launch on 4 devices, run the following command from root:
+#   tune run --nproc_per_node 4 full_finetune_distributed --config torchtune_configs/8B_guard_full.yaml
+#
+# You can add specific overrides through the command line. For example,
+# to override the checkpointer directory while launching training
+# you can run:
+#   tune run --nproc_per_node 4 full_finetune_distributed --config torchtune_configs/8B_guard_full.yaml checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
+#
+# This config works best when the model is being fine-tuned on 2+ GPUs.
+# Single-device full finetuning requires more memory optimizations; for
+# that case it's best to start from torchtune's 8B_full_single_device.yaml.
+
+
+output_dir: /tmp/torchtune/llama_guard_3_8B/full # /tmp may be deleted by your system. Change it to your preference.
+
+# Tokenizer
+tokenizer:
+  _component_: torchtune.models.llama3.llama3_tokenizer
+  path: /tmp/Meta-Llama-Guard-3-8B/original/tokenizer.model
+  max_seq_len: null
+  prompt_template: torchtune_configs.custom_template.llama_guard_template
+
+# Dataset
+dataset:
+  _component_: torchtune.datasets.instruct_dataset
+  packed: False  # True increases speed
+  source: json
+  data_files: torchtune_configs/toxic_chat.json
+  column_map:
+    input: prompt
+    output: output
+seed: null
+shuffle: True
+
+# Model Arguments
+model:
+  _component_: torchtune.models.llama3_1.llama3_1_8b
+
+checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_dir: /tmp/Meta-Llama-Guard-3-8B/
+  checkpoint_files: [
+    model-00001-of-00004.safetensors,
+    model-00002-of-00004.safetensors,
+    model-00003-of-00004.safetensors,
+    model-00004-of-00004.safetensors
+  ]
+  recipe_checkpoint: null
+  #recipe_checkpoint: /tmp/torchtune/llama_guard_3_8B/full/recipe_state/recipe_state.pt
+  output_dir: ${output_dir}
+  model_type: LLAMA3
+resume_from_checkpoint: False
+
+# Fine-tuning arguments
+batch_size: 2
+epochs: 1
+
+optimizer:
+  _component_: torch.optim.AdamW
+  lr: 2e-5
+  fused: True
+loss:
+  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+max_steps_per_epoch: null
+clip_grad_norm: null
+compile: False  # torch.compile the model + loss, True increases speed + decreases memory
+optimizer_in_bwd: False  # True saves memory. Requires gradient_accumulation_steps=1
+gradient_accumulation_steps: 1  # Use to increase effective batch size
+
+# Training env
+device: cuda
+
+# Memory management
+enable_activation_checkpointing: True  # True reduces memory
+enable_activation_offloading: False  # True reduces memory
+custom_sharded_layers: ['tok_embeddings', 'output']  # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.
+
+# Reduced precision
+dtype: bf16
+
+# Logging
+metric_logger:
+  _component_: torchtune.training.metric_logging.DiskLogger
+  log_dir: ${output_dir}/logs
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+
+# Profiler (disabled)
+profiler:
+  _component_: torchtune.training.setup_torch_profiler
+  enabled: False
+
+  #Output directory of trace artifacts
+  output_dir: ${output_dir}/profiling_outputs
+
+  #`torch.profiler.ProfilerActivity` types to trace
+  cpu: True
+  cuda: True
+
+  #trace options passed to `torch.profiler.profile`
+  profile_memory: False
+  with_stack: False
+  record_shapes: True
+  with_flops: False
+
+  # `torch.profiler.schedule` options:
+  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+  wait_steps: 5
+  warmup_steps: 3
+  active_steps: 2
+  num_cycles: 1
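
The dataset section of this config points torchtune.datasets.instruct_dataset at a local JSON file (torchtune_configs/toxic_chat.json) and maps the "prompt" column to the model input and the "output" column to the training target. The notebook is what actually produces that file (presumably from the ToxicChat data, given the filename); as a rough sketch, with purely illustrative records, the shape it expects looks like this:

    import json

    # Hypothetical records -- the real file is built in the notebook.
    # "output" follows the Llama Guard label format: "safe", or "unsafe"
    # followed by a newline and the violated category codes.
    records = [
        {"prompt": "How do I bake sourdough bread at home?", "output": "safe"},
        {"prompt": "<a prompt that violates the policy>", "output": "unsafe\nS10"},
    ]

    with open("torchtune_configs/toxic_chat.json", "w") as f:
        json.dump(records, f, indent=2)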

+ 46 - 0
getting-started/responsible_ai/llama_guard/torchtune_configs/custom_template.py

@@ -0,0 +1,46 @@
+from typing import List
+
+from torchtune.data import Message, PromptTemplate
+
+class MyPromptTemplate(PromptTemplate):
+    """Prompt template that wraps each user message with the Llama Guard task prompt."""
+
+    def __call__(
+        self, messages: List[Message], inference: bool = False
+    ) -> List[Message]:
+        messages = super().__call__(messages, inference)
+        return messages
+
+def llama_guard_template() -> MyPromptTemplate:
+    return MyPromptTemplate(
+        template={
+            "user": (
+                """Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
+
+<BEGIN UNSAFE CONTENT CATEGORIES>
+S1: Violent Crimes.
+S2: Non-Violent Crimes.
+S3: Sex Crimes.
+S4: Child Exploitation.
+S5: Defamation.
+S6: Specialized Advice.
+S7: Privacy.
+S8: Intellectual Property.
+S9: Indiscriminate Weapons.
+S10: Hate.
+S11: Self-Harm.
+S12: Sexual Content.
+S13: Elections.
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+User: """,
+                """\n\n<END CONVERSATION>
+
+Provide your safety assessment for ONLY THE LAST User message in the above conversation:
+ - First line must read 'safe' or 'unsafe'.
+ - If unsafe, a second line must include a comma-separated list of violated categories. """),
+        },
+    )
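
To sanity-check the custom template, a minimal usage sketch (assuming it is run from the llama_guard directory so that torchtune_configs is importable; the sample messages are made up):

    from torchtune.data import Message
    from torchtune_configs.custom_template import llama_guard_template

    template = llama_guard_template()
    messages = [
        Message(role="user", content="How do I pick a lock?"),
        Message(role="assistant", content="unsafe\nS2"),
    ]

    # The template prepends the Llama Guard task/category block to the user
    # message and appends the assessment instructions; the assistant message
    # (the training label) is left untouched.
    for msg in template(messages):
        print(msg.role, msg.content)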