|
@@ -328,8 +328,6 @@ def load_sharded_model_single_gpu(model,model_path):
|
|
no_dist=True,
|
|
no_dist=True,
|
|
)
|
|
)
|
|
|
|
|
|
- ck = state_dict["model"].keys()
|
|
|
|
- print(f" checkpoint key len = {len(ck)} and \n keys = {state_dict.keys()}")
|
|
|
|
model.load_state_dict(state_dict["model"])
|
|
model.load_state_dict(state_dict["model"])
|
|
|
|
|
|
print(f"Sharded state checkpoint loaded from {model_path}")
|
|
print(f"Sharded state checkpoint loaded from {model_path}")
|