소스 검색

adding cuda:0 for non-fsdp situations

Hamid Shojanazeri 1 년 전
부모
커밋
707af7ea24
1개의 변경된 파일1개의 추가작업 그리고 1개의 파일을 삭제
  1. 1 1
      utils/train_utils.py

+ 1 - 1
utils/train_utils.py

@@ -84,7 +84,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda')       
+                        batch[key] = batch[key].to('cuda:0')       
                 outputs = model(**batch)
                 loss = outputs.loss
                 loss = loss / gradient_accumulation_steps