@@ -21,8 +21,8 @@ from transformers import (
     AutoTokenizer,
     BitsAndBytesConfig,
     AutoProcessor,
+    LlamaForCausalLM,
     MllamaForConditionalGeneration,
-    AutoModel,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer
@@ -132,9 +132,11 @@ def main(**kwargs):
         )
         processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
         processor.tokenizer.padding_side='right'
+        model.supports_gradient_checkpointing = True
+        model.language_model.supports_gradient_checkpointing = True
     elif config.model_type == "llama":
         is_vision = False
-        model = AutoModel.from_pretrained(
+        model = LlamaForCausalLM.from_pretrained(
             train_config.model_name,
             quantization_config=bnb_config,
             use_cache=use_cache,
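
A minimal, hypothetical sketch of the loading path this patch switches to, shown in isolation rather than as part of the original script. AutoModel.from_pretrained returns the bare Llama backbone without a language-modeling head, whereas LlamaForCausalLM includes the head needed for causal-LM fine-tuning. The checkpoint name, bnb_config, and use_cache values below are placeholders standing in for whatever main() builds from its training config:

# Hypothetical, standalone sketch of the new loading path (not the full script).
import torch
from transformers import BitsAndBytesConfig, LlamaForCausalLM

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # placeholder quantization settings
    bnb_4bit_compute_dtype=torch.bfloat16,
)
use_cache = False  # placeholder; the original script derives this from its training config

model = LlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",              # placeholder checkpoint name
    quantization_config=bnb_config,
    use_cache=use_cache,
)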