| 
					
				 | 
			
			
				@@ -65,7 +65,7 @@ def load_model_sharded(model, rank, cfg): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     reader = FileSystemReader(load_dir) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        checkpoint = model.state_dict() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        checkpoint = {"model": model.state_dict()} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if rank == 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             ck = checkpoint.keys() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}") 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -78,7 +78,7 @@ def load_model_sharded(model, rank, cfg): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             print(f"checkpoint after load_state_dict()") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             ck = checkpoint.keys() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             print(f" checkpoint key len = {len(ck)} and \n keys =  {ck}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        model.load_state_dict(checkpoint) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        model.load_state_dict(checkpoint["model"]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     if rank == 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         print(f"Sharded state checkpoint loaded from {load_dir}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 |