@@ -20,9 +20,9 @@ from transformers import (
     AutoConfig,
     AutoTokenizer,
     BitsAndBytesConfig,
-    LlamaForCausalLM,
     AutoProcessor,
-    MllamaForConditionalGeneration
+    MllamaForConditionalGeneration,
+    AutoModel,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer
@@ -134,7 +134,7 @@ def main(**kwargs):
         processor.tokenizer.padding_side='right'
     elif config.model_type == "llama":
         is_vision = False
-        model = LlamaForCausalLM.from_pretrained(
+        model = AutoModel.from_pretrained(
             train_config.model_name,
             quantization_config=bnb_config,
             use_cache=use_cache,
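
For reference, a minimal sketch of the config-driven loading path after this change, run standalone outside the training script. The model name and quantization settings below are illustrative placeholders, not values taken from `finetuning.py` or `train_config`:

```python
import torch
from transformers import AutoConfig, AutoModel, BitsAndBytesConfig

# Illustrative stand-ins (assumptions, not values from finetuning.py).
model_name = "meta-llama/Llama-3.1-8B"
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

# The config identifies the architecture family, so the script can branch
# on model_type instead of hard-coding a concrete model class.
config = AutoConfig.from_pretrained(model_name)

if config.model_type == "llama":
    # AutoModel resolves to the backbone registered for this config
    # (LlamaModel for a Llama checkpoint); AutoModelForCausalLM would
    # resolve to LlamaForCausalLM, the class this diff stops importing.
    model = AutoModel.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        use_cache=False,  # stand-in for the script's use_cache variable
    )
```

Dispatching through the Auto classes keeps the text branch working for any checkpoint whose `model_type` is `llama` without importing the concrete Llama model class.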