|
@@ -9,6 +9,7 @@ import fire
|
|
|
import random
|
|
|
import torch
|
|
|
import torch.optim as optim
|
|
|
+import numpy as np
|
|
|
from peft import get_peft_model, PeftModel
|
|
|
from torch.distributed.fsdp import (
|
|
|
FullyShardedDataParallel as FSDP,
|
|
@@ -82,6 +83,7 @@ def main(**kwargs):
|
|
|
torch.xpu.manual_seed(train_config.seed)
|
|
|
torch.manual_seed(train_config.seed)
|
|
|
random.seed(train_config.seed)
|
|
|
+ np.random.seed(train_config.seed)
|
|
|
|
|
|
if train_config.enable_fsdp:
|
|
|
setup()
|