瀏覽代碼

adding cuda:0 for non-fsdp situations

Hamid Shojanazeri 1 年之前
父節點
當前提交
707af7ea24
共有 1 個文件被更改,包括 1 次插入1 次删除
  1. 1 1
      utils/train_utils.py

+ 1 - 1
utils/train_utils.py

@@ -84,7 +84,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda')       
+                        batch[key] = batch[key].to('cuda:0')       
                 outputs = model(**batch)
                 loss = outputs.loss
                 loss = loss / gradient_accumulation_steps