@@ -45,7 +45,17 @@ def generate_peft_config(train_config, kwargs):
     peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
     names = tuple(c.__name__.rstrip("_config") for c in configs)
 
-    assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
+    assert (
+        train_config.peft_method in names
+    ), f"Peft config not found: {train_config.peft_method}"
+
+    assert (
+        train_config.peft_method != "prefix"
+    ), "PrefixTuning is currently not supported (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089350811)"
+    if train_config.enable_fsdp:
+        assert (
+            train_config.peft_method != "llama_adapter"
+        ), "Llama_adapter is currently not supported in combination with FSDP (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089274425)"
 
     config = configs[names.index(train_config.peft_method)]()
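For reference, a minimal standalone sketch of how the added guards behave. `TrainConfig` and `check_peft_method` below are hypothetical stand-ins (not part of llama-recipes) that mirror the assertions in the hunk: `prefix` is rejected unconditionally, and `llama_adapter` is rejected only when FSDP is enabled.

```python
from dataclasses import dataclass


# Hypothetical stand-in for the llama-recipes train_config fields used by the guards.
@dataclass
class TrainConfig:
    peft_method: str = "lora"
    enable_fsdp: bool = False


def check_peft_method(train_config: TrainConfig) -> None:
    # Mirrors the names derived from (lora_config, llama_adapter_config, prefix_config).
    names = ("lora", "llama_adapter", "prefix")

    assert (
        train_config.peft_method in names
    ), f"Peft config not found: {train_config.peft_method}"

    # Rejected unconditionally, per issue #359.
    assert (
        train_config.peft_method != "prefix"
    ), "PrefixTuning is currently not supported"

    # Rejected only in combination with FSDP, per issue #359.
    if train_config.enable_fsdp:
        assert (
            train_config.peft_method != "llama_adapter"
        ), "Llama_adapter is currently not supported in combination with FSDP"


check_peft_method(TrainConfig(peft_method="lora", enable_fsdp=True))  # passes
check_peft_method(TrainConfig(peft_method="llama_adapter"))           # passes without FSDP

for bad in (TrainConfig(peft_method="prefix"),
            TrainConfig(peft_method="llama_adapter", enable_fsdp=True)):
    try:
        check_peft_method(bad)
    except AssertionError as e:
        print(f"rejected: {e}")
```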