浏览代码

Fix numpy seed in finetuning.py (#728)

Kai Wu 6 月之前
父节点
当前提交
a8e9f4eced
共有 1 个文件被更改,包括 2 次插入、0 次删除
  1. +2 -0
      src/llama_recipes/finetuning.py

+ 2 - 0
src/llama_recipes/finetuning.py

@@ -9,6 +9,7 @@ import fire
 import random
 import torch
 import torch.optim as optim
+import numpy as np
 from peft import get_peft_model, PeftModel
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
@@ -82,6 +83,7 @@ def main(**kwargs):
         torch.xpu.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
     random.seed(train_config.seed)
+    np.random.seed(train_config.seed)
 
     if train_config.enable_fsdp:
         setup()