							- # Copyright (c) Meta Platforms, Inc. and affiliates.
 
- # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
- import inspect
 
- from dataclasses import fields
 
- from peft import (
 
-     LoraConfig,
 
-     AdaptionPromptConfig,
 
-     PrefixTuningConfig,
 
- )
 
- import configs.datasets as datasets
 
- from configs import lora_config, llama_adapter_config, prefix_config, train_config
 
- from .dataset_utils import DATASET_PREPROC
 
def update_config(config, **kwargs):
    """Overwrite attributes of one or more config objects in place.

    Args:
        config: a single config instance, or a tuple/list of instances; each
            one is updated in place (lists are handled by recursion).
        **kwargs: plain keys (``lr=1e-4``) update any config that already has
            that attribute; dotted keys (``lora_config.r=8``) target only the
            config whose class name matches the prefix.
    """
    if isinstance(config, (tuple, list)):
        # A bundle of configs: apply the same overrides to each member.
        for c in config:
            update_config(c, **kwargs)
        return

    for k, v in kwargs.items():
        if hasattr(config, k):
            setattr(config, k, v)
        elif "." in k:
            # allow --some_config.some_param=True
            # maxsplit=1 so a key containing extra dots cannot raise
            # ValueError during tuple unpacking.
            config_name, param_name = k.split(".", 1)
            if type(config).__name__ == config_name:
                if hasattr(config, param_name):
                    setattr(config, param_name, v)
                else:
                    # In case of a specialized config we can warn the user
                    print(f"Warning: {config_name} does not accept parameter: {k}")
        elif isinstance(config, train_config):
            # Only the main train_config warns on totally unknown keys, so
            # peft/dataset-specific keys don't spam warnings for every config.
            print(f"Warning: unknown parameter {k}")
 
def generate_peft_config(train_config, kwargs):
    """Instantiate the peft config selected by ``train_config.peft_method``.

    Args:
        train_config: training config whose ``peft_method`` ("lora",
            "llama_adapter" or "prefix") picks the local config to use.
        kwargs: extra overrides forwarded to ``update_config``.

    Returns:
        A populated ``LoraConfig``, ``AdaptionPromptConfig`` or
        ``PrefixTuningConfig`` built from the matching local config's fields.

    Raises:
        AssertionError: if ``peft_method`` names no known config.
    """
    configs = (lora_config, llama_adapter_config, prefix_config)
    peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
    # NOTE: str.rstrip("_config") strips a *character set*, not the suffix,
    # and would mangle any name ending in one of "_config"'s letters.
    # Slice the suffix off instead.
    suffix = "_config"
    names = tuple(c.__name__[: -len(suffix)] for c in configs)

    assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"

    # Look the method up once and reuse the index for both tuples.
    idx = names.index(train_config.peft_method)
    config = configs[idx]
    update_config(config, **kwargs)
    params = {f.name: getattr(config, f.name) for f in fields(config)}
    peft_config = peft_configs[idx](**params)

    return peft_config
 
def generate_dataset_config(train_config, kwargs):
    """Look up and customize the dataset config named by ``train_config.dataset``.

    Args:
        train_config: training config whose ``dataset`` attribute names a
            dataset registered in ``DATASET_PREPROC``.
        kwargs: extra overrides forwarded to ``update_config``.

    Returns:
        The matching dataset config object from ``configs.datasets``,
        updated in place with ``kwargs``.

    Raises:
        AssertionError: if the dataset is not registered in DATASET_PREPROC.
    """
    # Membership test against the registry directly — no need to
    # materialize a tuple of keys first.
    assert train_config.dataset in DATASET_PREPROC, f"Unknown dataset: {train_config.dataset}"

    # dict(...) over pairs instead of an identity dict comprehension.
    dataset_config = dict(inspect.getmembers(datasets))[train_config.dataset]
    update_config(dataset_config, **kwargs)

    return dataset_config
 
 