@@ -26,11 +26,8 @@ from transformers import (
     BitsAndBytesConfig,
     LlamaForCausalLM,
     LlamaConfig,
-    AutoConfig,
-    AutoModel,
-    LlavaNextForConditionalGeneration,
-    LlavaNextProcessor
-
+    AutoProcessor,
+    MllamaForConditionalGeneration
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPEncoderLayer
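The hunk above swaps the LLaVA-NeXT imports for the Llama 3.2 vision classes, which only exist in recent `transformers` releases. A hedged, standalone sketch of a version guard (assuming Mllama support first shipped in `transformers` 4.45.0; not part of the patch):

```python
# Hedged sketch, not part of the patch: fail fast when the installed transformers
# build predates the Mllama classes (assumed to first ship in 4.45.0).
import transformers
from packaging import version  # packaging is already a transformers dependency

MIN_VERSION = "4.45.0"  # assumption: first release exposing MllamaForConditionalGeneration

if version.parse(transformers.__version__) < version.parse(MIN_VERSION):
    raise ImportError(
        f"transformers {transformers.__version__} is too old; "
        f"install >= {MIN_VERSION} to get MllamaForConditionalGeneration and AutoProcessor."
    )

from transformers import AutoProcessor, MllamaForConditionalGeneration  # noqa: F401
```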
@@ -126,7 +123,9 @@ def main(**kwargs):

     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
-    model = LlavaNextForConditionalGeneration.from_pretrained(
+    is_vision = "11B" in train_config.model_name or "90B" in train_config.model_name
+    if is_vision:
+        model = MllamaForConditionalGeneration.from_pretrained(
         train_config.model_name,
         quantization_config=bnb_config,
         #use_cache=use_cache,
@@ -134,12 +133,22 @@ def main(**kwargs):
         device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
         torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
     )
+        processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
+        processor.tokenizer.padding_side='right'
+    else:
+        model = LlamaForCausalLM.from_pretrained(
+            train_config.model_name,
+            quantization_config=bnb_config,
+            use_cache=use_cache,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
+            device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
+            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+        )

     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
-    processor = LlavaNextProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
-    processor.tokenizer.padding_side='right'
+
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
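With the two hunks above, vision checkpoints (names containing "11B" or "90B") are loaded as `MllamaForConditionalGeneration` with an `AutoProcessor`, while text-only checkpoints keep the original `LlamaForCausalLM` path. A minimal smoke test of the vision branch, as a hedged sketch (the checkpoint id, dummy image, and `<|image|>` prompt marker are assumptions, and `device_map="auto"` requires `accelerate`):

```python
# Hedged sketch, not part of the patch: load the vision model/processor pair the
# way the 11B/90B branch does and run a single generate() call as a sanity check.
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

model_id = "meta-llama/Llama-3.2-11B-Vision"  # assumption: the value of train_config.model_name

processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer.padding_side = "right"  # right padding for training, as in the patch

model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

image = Image.new("RGB", (224, 224), color="white")       # dummy image, shape check only
prompt = "<|image|><|begin_of_text|>Describe the image."  # <|image|> marks the image position

inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=16)
print(processor.tokenizer.decode(out[0], skip_special_tokens=True))
```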
@@ -183,18 +192,16 @@ def main(**kwargs):
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
-        # print(dir(model))
-        # for layer in model.named_children():
-        #     print(f"Layer: {layer}")
-
-        # layernorm = model.CLIPVisionTransformer.CLIPEncoder.LayerNorm
-        # for name, param in layernorm.named_parameters():
-        #     print(f"Parameter: {name}, Shape: {param.shape}, Dtype: {param.dtype}")
-        # exit()
+        if train_config.use_peft:
+            wrapping_policy = my_auto_wrapping_policy
+        else:
+            if is_vision:
+                wrapping_policy = ModuleWrapPolicy([CLIPEncoderLayer, LlamaDecoderLayer])
+            else:
+                wrapping_policy = ModuleWrapPolicy([LlamaDecoderLayer])
         model = FSDP(
             model,
-            auto_wrap_policy= ModuleWrapPolicy([CLIPEncoderLayer, LlamaDecoderLayer]),
-            #auto_wrap_policy= my_auto_wrapping_policy, #if train_config.use_peft else wrapping_policy,
+            auto_wrap_policy= wrapping_policy,
             cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
@@ -205,10 +212,9 @@ def main(**kwargs):
             param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
-        #print(model)
         if fsdp_config.fsdp_activation_checkpointing:
             model.enable_input_require_grads()
-            model.gradient_checkpointing_enable()
+            #model.gradient_checkpointing_enable()
             apply_fsdp_checkpointing(model)
     elif not train_config.quantization and not train_config.enable_fsdp:
         if is_xpu_available():
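The wrapping-policy hunk above picks FSDP units by module class: PEFT runs reuse `my_auto_wrapping_policy` (built earlier in this file), full fine-tuning wraps decoder layers, and vision runs additionally wrap vision-encoder layers. Note that `ModuleWrapPolicy` only wraps modules whose class appears in the given list, so the list has to match classes actually present in the loaded model. A hedged standalone sketch of the same selection (the helper name is hypothetical):

```python
# Hedged sketch of the selection logic above; pick_wrapping_policy is a
# hypothetical helper, not something the patch defines.
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from transformers.models.clip.modeling_clip import CLIPEncoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer


def pick_wrapping_policy(use_peft: bool, is_vision: bool, peft_policy=None):
    """Mirror the patched branch: PEFT policy, or ModuleWrapPolicy over layer classes."""
    if use_peft:
        return peft_policy  # my_auto_wrapping_policy in finetuning.py
    layer_classes = [LlamaDecoderLayer]
    if is_vision:
        layer_classes.append(CLIPEncoderLayer)  # wrap vision-encoder blocks as their own FSDP units
    return ModuleWrapPolicy(layer_classes)


# Example: full fine-tuning of a vision checkpoint
policy = pick_wrapping_policy(use_peft=False, is_vision=True)
```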
@@ -217,15 +223,15 @@ def main(**kwargs):
             model.to("cuda")

     dataset_config = generate_dataset_config(train_config, kwargs)
+    if is_vision:
+        dataset_processer = processor
+    else:
+        dataset_processer = tokenizer
+
+    # Load and preprocess the dataset for training and validation

-    # Load and preprocess the dataset for training and validation
-    # dataset_train = get_preprocessed_dataset(
-    #     processor,
-    #     dataset_config,
-    #     split="train",
-    # )
     dataset_train = get_preprocessed_dataset(
-        processor,
+        dataset_processer,
         dataset_config,
         split="train",
     )
@@ -233,7 +239,7 @@ def main(**kwargs):
         print(f"--> Training Set Length = {len(dataset_train)}")

     dataset_val = get_preprocessed_dataset(
-        processor,
+        dataset_processer,
         dataset_config,
         split="test",
     )
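The dataset hunks hand the multimodal `processor` to `get_preprocessed_dataset` for vision runs and the plain `tokenizer` otherwise. The difference in what the two produce is the point of the switch; a hedged illustration (the checkpoint id and exact output keys are assumptions):

```python
# Hedged illustration, not part of the patch: the tokenizer emits text features
# only, while the Mllama processor also emits the image tensors the model needs.
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer

model_id = "meta-llama/Llama-3.2-11B-Vision"  # assumption: the vision checkpoint in use

tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

text = "<|image|><|begin_of_text|>What is in this picture?"
image = Image.new("RGB", (224, 224), color="gray")

text_only = tokenizer(text, return_tensors="pt")
multimodal = processor(images=image, text=text, return_tensors="pt")

print(sorted(text_only.keys()))   # ['attention_mask', 'input_ids']
print(sorted(multimodal.keys()))  # adds pixel_values plus aspect-ratio / cross-attention masks
```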