# Config for multi-device full finetuning in full_finetune_distributed.py
# using a Llama3.1 70B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download meta-llama/Meta-Llama-3.1-70B-Instruct --output-dir /tmp/Meta-Llama-3.1-70B-Instruct --ignore-patterns "original/consolidated*"
#
# To launch on 8 devices, run the following command from root:
#   tune run --nproc_per_node 8 full_finetune_distributed --config llama3_1/70B_full
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training
# you can run:
#   tune run --nproc_per_node 8 full_finetune_distributed --config llama3_1/70B_full checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config is only tested on an 8xA100 machine.
#
output_dir: /tmp/torchtune/llama3_1_70B/full # /tmp may be deleted by your system. Change it to your preference.
seed: 69
shuffle: True
# Parallelism
tensor_parallel_dim: 1
tensor_parallel_plan:
  _component_: torchtune.models.llama3.base_llama_tp_plan
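# With tensor_parallel_dim: 1 the TP plan above is effectively unused and all 8
# processes act as FSDP data-parallel workers; raising it splits each layer across
# that many GPUs (nproc_per_node should be divisible by tensor_parallel_dim).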
# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model
  max_seq_len: 16384
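# Dataset
# `toolcall.custom_dataset` below is assumed to be a user-provided builder (e.g. a
# toolcall.py importable from the launch directory) that returns a torchtune-compatible
# dataset; the commented data_files/split keys are forwarded to it if uncommented.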
dataset:
  _component_: toolcall.custom_dataset
  # data_files: "train_data.json"
  # split: "train"
# Model Arguments
model:
  _component_: torchtune.models.llama3_1.llama3_1_70b
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00030"
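    # filename_format/max_filename expand to the 30 sharded HF checkpoint files
    # (model-00001-of-00030.safetensors ... model-00030-of-00030.safetensors).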
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False
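# resume_from_checkpoint: False starts training from the base weights; setting it to
# True resumes from a previously saved recipe state instead (see recipe_checkpoint above).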
# Fine-tuning arguments
batch_size: 2
epochs: 1
optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  # Note: highly recommended to use fused=True optimizer flag
  # with CPU offload for faster optimizer step.
  fused: False
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase effective batch size
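# Effective global batch size = batch_size x data-parallel workers x gradient_accumulation_steps
# (here 2 x 8 x 1 = 16 sequences per optimizer step).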
# Training env
device: cuda
# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower memory, but lower speed.
fsdp_cpu_offload: True
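# fsdp_cpu_offload: True keeps sharded parameters, gradients and optimizer state in CPU
# RAM, trading slower steps for lower GPU memory; it is also why the fused=True optimizer
# note above matters.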
clip_grad_norm: null
compile: False # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1
# Reduced precision
dtype: bf16
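# bf16 needs hardware bfloat16 support (e.g. the A100s this config targets); fp32 is the
# usual alternative here but roughly doubles memory use.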
# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False
  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs
  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True
  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False
  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
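# If enabled, this schedule waits 5 steps, warms up for 3, records 2 active steps, and
# runs a single cycle (num_cycles: 1); traces are written under ${output_dir}/profiling_outputs.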