
prep sagemaker settings

Hamid Shojanazeri 1 year ago
parent
commit
d176550ba0
2 files changed with 32 additions and 4 deletions
  1. configs/training.py (+4, -4)
  2. start_job.py (+28, -0)

+ 4 - 4
configs/training.py

@@ -6,8 +6,8 @@ from typing import ClassVar
 
 @dataclass
 class train_config:
-    model_name: str="PATH/to/LLAMA/7B"
-    enable_fsdp: bool=False
+    model_name: str="meta-llama/Llama-2-7b-chat-hf"
+    enable_fsdp: bool=True
     low_cpu_fsdp: bool=False
     run_validation: bool=True
     batch_size_training: int=4
@@ -23,8 +23,8 @@ class train_config:
     val_batch_size: int=1
     dataset = "samsum_dataset"
     peft_method: str = "lora" # None , llama_adapter, prefix
-    use_peft: bool=False
-    output_dir: str = "PATH/to/save/PEFT/model"
+    use_peft: bool=True
+    output_dir: str = "PEFT-7b-model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
     quantization: bool = False
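
Note: hard-coding these defaults (model_name, enable_fsdp, use_peft, output_dir) is a convenience for the SageMaker entry point, which is launched without the usual command-line overrides. As a rough sketch only (the override_config helper below is illustrative, not necessarily the mechanism this repo uses), dataclass configs like train_config are typically overridden at launch time along these lines:

from dataclasses import fields
from configs.training import train_config

def override_config(config, **kwargs):
    # Copy any keyword argument onto a matching dataclass field,
    # e.g. override_config(cfg, enable_fsdp=False).
    valid = {f.name for f in fields(config)}
    for key, value in kwargs.items():
        if key in valid:
            setattr(config, key, value)

cfg = train_config()
override_config(cfg, batch_size_training=2)
print(cfg.model_name, cfg.enable_fsdp, cfg.use_peft)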

+ 28 - 0
start_job.py

@@ -0,0 +1,28 @@
+import datetime
+from sagemaker.pytorch import PyTorch
+import sagemaker
+import os
+import boto3
+sagemaker_session = sagemaker.Session()
+try:  # prefer the execution role when running inside SageMaker
+    role = sagemaker.get_execution_role()
+except ValueError:
+    iam = boto3.client('iam')
+    role = iam.get_role(RoleName='...')['Role']['Arn']  # '...' = name of your SageMaker execution role
+print(role)
+
+volume_size = 500
+pytorch_estimator = PyTorch(
+    entry_point="llama_finetuning.py", # the training script
+    role=role,
+    instance_type="ml.g5.12xlarge",
+    instance_count=2, # number of ml.g5.12xlarge instances
+    source_dir=os.getcwd(),
+    framework_version="1.11.0",
+    py_version="py38",
+    volume_size=volume_size,
+    # dependencies=[''],
+    sagemaker_session=sagemaker_session,  # region (e.g. us-west-2) comes from the session's boto configuration
+)
+pytorch_estimator.fit(
+    job_name='FSDP' + '-' + datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ"))
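
As written, the job runs with the defaults baked into configs/training.py. If you later want to vary settings per job without editing the config, SageMaker framework estimators also accept a hyperparameters dict, which is forwarded to the entry point as "--key value" script arguments; whether llama_finetuning.py actually parses those flags depends on its argument handling, so treat this continuation of start_job.py as an illustrative sketch rather than a tested setup:

# Hypothetical per-job overrides, forwarded to llama_finetuning.py as CLI arguments.
hyperparameters = {
    "model_name": "meta-llama/Llama-2-7b-chat-hf",
    "batch_size_training": 4,
    "use_peft": True,
}
pytorch_estimator = PyTorch(
    entry_point="llama_finetuning.py",
    source_dir=os.getcwd(),
    role=role,
    instance_type="ml.g5.12xlarge",
    instance_count=2,
    framework_version="1.11.0",
    py_version="py38",
    volume_size=volume_size,
    hyperparameters=hyperparameters,
    sagemaker_session=sagemaker_session,
)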