فهرست منبع

Fix test on non cuda machine

Matthias Reso 9 ماه پیش
والد
کامیت
448af9d7c1
1فایلهای تغییر یافته به همراه3 افزوده شده و 0 حذف شده
  1. 3 0
      src/tests/test_batching.py

+ 3 - 0
src/tests/test_batching.py

@@ -92,6 +92,7 @@ def test_packing(
 
 
 @pytest.mark.skip_missing_tokenizer
+@patch("llama_recipes.utils.train_utils.torch.cuda.is_bf16_supported")
 @patch("llama_recipes.finetuning.torch.cuda.is_available")
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
@@ -119,6 +120,7 @@ def test_distributed_packing(
     tokenizer,
     train,
     cuda_is_available,
+    cuda_is_bf16_supported,
     setup_tokenizer,
     setup_processor,
     llama_version,
@@ -133,6 +135,7 @@ def test_distributed_packing(
     get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
     get_config.return_value = Config(model_type=model_type)
     cuda_is_available.return_value = False
+    cuda_is_bf16_supported.return_value = False
 
     rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'