Source index

Notebook showing how to fine tune llama guard with torchtune

Ankith Gunapal 10 months ago
parent
commit
5d87da36f7

No file diff is shown because this file is too large
+ 1790 - 0
getting-started/responsible_ai/llama_guard/llama_guard_finetuning_multiple_violations_with_torchtune.ipynb


+ 120 - 0
getting-started/responsible_ai/llama_guard/torchtune_configs/8B_guard_full.yaml

@@ -0,0 +1,120 @@
# Config for multi-device full finetuning in full_finetune_distributed.py
# using a Llama Guard 3 8B model.
#
# This config assumes that you've run the following command before launching
# this run (the tokenizer and checkpointer below read from this directory):
#   tune download meta-llama/Llama-Guard-3-8B --output-dir /tmp/Meta-Llama-Guard-3-8B --ignore-patterns "original/consolidated.00.pth"
#
# To launch on 4 devices, run the following command from root:
#   tune run --nproc_per_node 4 full_finetune_distributed --config torchtune_configs/8B_guard_full.yaml
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
#   tune run --nproc_per_node 4 full_finetune_distributed --config torchtune_configs/8B_guard_full.yaml checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
# Single device full finetuning requires more memory optimizations. It's
# best to use 8B_full_single_device.yaml for those cases


output_dir: /tmp/torchtune/llama_guard_3_8B/full # /tmp may be deleted by your system. Change it to your preference.

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-Guard-3-8B/original/tokenizer.model
  max_seq_len: null
  # Custom template (torchtune_configs/custom_template.py) that wraps each
  # user message in the Llama Guard safety-classification prompt.
  prompt_template: torchtune_configs.custom_template.llama_guard_template

# Dataset: local JSON with columns "prompt" (input) and "output" (label).
dataset:
  _component_: torchtune.datasets.instruct_dataset
  packed: False  # True increases speed
  source: json
  data_files: torchtune_configs/toxic_chat.json
  column_map:
    input: prompt
    output: output
seed: null
shuffle: True

# Model Arguments
# NOTE: Llama Guard 3 8B shares the Llama 3.1 8B architecture, so the
# llama3_1_8b builder is used; the Guard weights come from the checkpointer.
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-Guard-3-8B/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  # To resume a run, point recipe_checkpoint at
  # ${output_dir}/recipe_state/recipe_state.pt and set resume_from_checkpoint: True.
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 2
epochs: 1

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  fused: True
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
clip_grad_norm: null
compile: False  # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False  # True saves memory. Requires gradient_accumulation_steps=1
gradient_accumulation_steps: 1  # Use to increase effective batch size

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True  # True reduces memory
enable_activation_offloading: False  # True reduces memory
custom_sharded_layers: ['tok_embeddings', 'output']  # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True


# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1

+ 46 - 0
getting-started/responsible_ai/llama_guard/torchtune_configs/custom_template.py

@@ -0,0 +1,46 @@
+from typing import List
+from pathlib import Path
+from torchtune.data import Message
+from torchtune.data import PromptTemplate
+from torchtune.models.llama3 import llama3_tokenizer
+from torchtune.datasets import instruct_dataset
+
class MyPromptTemplate(PromptTemplate):
    """Prompt template that delegates formatting to ``PromptTemplate``.

    Kept as a thin subclass so the Llama Guard prompt assembly has a
    single extension point: override ``__call__`` here to customize how
    the prefix/suffix are applied around each message.
    """

    def __call__(
        self, messages: List[Message], inference: bool = False
    ) -> List[Message]:
        # No extra behavior yet — the base class wraps each message with
        # the configured role prefix/suffix.
        return super().__call__(messages, inference)
+
def llama_guard_template() -> MyPromptTemplate:
    """Build the prompt template used to fine-tune Llama Guard 3.

    Each user message is wrapped between a fixed prefix (the safety task
    description plus the S1-S13 hazard taxonomy) and a fixed suffix (the
    output-format instructions: 'safe'/'unsafe' plus violated categories).
    The wording — including the original 'according our safety policy'
    phrasing — is kept verbatim to match the prompt format the base
    Llama Guard model expects.
    """
    user_prefix = """Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
S1: Violent Crimes.
S2: Non-Violent Crimes.
S3: Sex Crimes.
S4: Child Exploitation.
S5: Defamation.
S6: Specialized Advice.
S7: Privacy.
S8: Intellectual Property.
S9: Indiscriminate Weapons.
S10: Hate.
S11: Self-Harm.
S12: Sexual Content.
S13: Elections.
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

User: """
    user_suffix = """\n\n<END CONVERSATION>

Provide your safety assessment for ONLY THE LAST User message in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. """
    # Only the "user" role is wrapped; other roles pass through unchanged.
    return MyPromptTemplate(template={"user": (user_prefix, user_suffix)})