Переглянути джерело

Fix numpy seed in finetuning.py (#728)

Kai Wu 6 місяців тому
батько
коміт
a8e9f4eced
1 змінених файлів з 2 додано та 0 видалено
  1. 2 0
      src/llama_recipes/finetuning.py

+ 2 - 0
src/llama_recipes/finetuning.py

@@ -9,6 +9,7 @@ import fire
 import random
 import random
 import torch
 import torch
 import torch.optim as optim
 import torch.optim as optim
+import numpy as np
 from peft import get_peft_model, PeftModel
 from peft import get_peft_model, PeftModel
 from torch.distributed.fsdp import (
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     FullyShardedDataParallel as FSDP,
@@ -82,6 +83,7 @@ def main(**kwargs):
         torch.xpu.manual_seed(train_config.seed)
         torch.xpu.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
     random.seed(train_config.seed)
     random.seed(train_config.seed)
+    np.random.seed(train_config.seed)
 
 
     if train_config.enable_fsdp:
     if train_config.enable_fsdp:
         setup()
         setup()