import argparse
import os
import sys

import gradio as gr
import torch
from accelerate import Accelerator
from huggingface_hub import HfFolder
from peft import PeftModel
from PIL import Image as PIL_Image
from transformers import MllamaForConditionalGeneration, MllamaProcessor

# Initialize accelerator
accelerator = Accelerator()
device = accelerator.device

# Constants
DEFAULT_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
MAX_OUTPUT_TOKENS = 2048
MAX_IMAGE_SIZE = (1120, 1120)


def get_hf_token():
    """Retrieve Hugging Face token from the cache or environment."""
    # Check if a token is explicitly set in the environment
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        return token

    # Automatically retrieve the token from the Hugging Face cache (set via huggingface-cli login)
    token = HfFolder.get_token()
    if token:
        return token

    print("Hugging Face token not found. Please login using `huggingface-cli login`.")
    sys.exit(1)


def load_model_and_processor(model_name: str, finetuning_path: str = None):
    """Load model and processor with optional LoRA adapter"""
    print(f"Loading model: {model_name}")
    hf_token = get_hf_token()
    model = MllamaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        device_map=device,
        token=hf_token,
    )
    processor = MllamaProcessor.from_pretrained(
        model_name, token=hf_token, use_safetensors=True
    )

    if finetuning_path and os.path.exists(finetuning_path):
        print(f"Loading LoRA adapter from '{finetuning_path}'...")
        model = PeftModel.from_pretrained(
            model, finetuning_path, is_adapter=True, torch_dtype=torch.bfloat16
        )
        print("LoRA adapter loaded successfully")

    model, processor = accelerator.prepare(model, processor)
    return model, processor


def process_image(image_path: str = None, image=None) -> PIL_Image.Image:
    """Process and validate image input"""
    if image is not None:
        return image.convert("RGB")
    if image_path and os.path.exists(image_path):
        return PIL_Image.open(image_path).convert("RGB")
    raise ValueError("No valid image provided")


def generate_text_from_image(
    model, processor, image, prompt_text: str, temperature: float, top_p: float
):
    """Generate text from image using model"""
    conversation = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": prompt_text}],
        }
    ]
    prompt = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=False
    )
    inputs = processor(
        image, prompt, text_kwargs={"add_special_tokens": False}, return_tensors="pt"
    ).to(device)
    print("Input Prompt:\n", processor.tokenizer.decode(inputs.input_ids[0]))
    output = model.generate(
        **inputs, temperature=temperature, top_p=top_p, max_new_tokens=MAX_OUTPUT_TOKENS
    )
    return processor.decode(output[0])[len(prompt) :]


def gradio_interface(model_name: str):
    """Create Gradio UI with LoRA support"""
    # Initialize model state
    current_model = {"model": None, "processor": None}

    def load_or_reload_model(enable_lora: bool, lora_path: str = None):
        current_model["model"], current_model["processor"] = load_model_and_processor(
            model_name, lora_path if enable_lora else None
        )
        return "Model loaded successfully" + (" with LoRA" if enable_lora else "")

    def describe_image(
        image, user_prompt, temperature, top_k, top_p, max_tokens, history
    ):
        if image is not None:
            try:
                processed_image = process_image(image=image)
                result = generate_text_from_image(
                    current_model["model"],
                    current_model["processor"],
                    processed_image,
                    user_prompt,
                    temperature,
                    top_p,
                )
                history.append((user_prompt, result))
            except Exception as e:
                history.append((user_prompt, f"Error: {str(e)}"))
        return history

    def clear_chat():
        return []

    with gr.Blocks() as demo:
        gr.HTML("<h1 style='text-align: center'>Llama Vision Model Interface</h1>")

        with gr.Row():
            with gr.Column(scale=1):
                # Model loading controls
                with gr.Group():
                    enable_lora = gr.Checkbox(label="Enable LoRA", value=False)
                    lora_path = gr.Textbox(
                        label="LoRA Weights Path",
                        placeholder="Path to LoRA weights folder",
                        visible=False,
                    )
                    load_status = gr.Textbox(label="Load Status", interactive=False)
                    load_button = gr.Button("Load/Reload Model")

                # Image and parameter controls
                image_input = gr.Image(
                    label="Image", type="pil", image_mode="RGB", height=512, width=512
                )
                temperature = gr.Slider(
                    label="Temperature", minimum=0.1, maximum=1.0, value=0.6, step=0.1
                )
                top_k = gr.Slider(
                    label="Top-k", minimum=1, maximum=100, value=50, step=1
                )
                top_p = gr.Slider(
                    label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.1
                )
                max_tokens = gr.Slider(
                    label="Max Tokens",
                    minimum=50,
                    maximum=MAX_OUTPUT_TOKENS,
                    value=100,
                    step=50,
                )

            with gr.Column(scale=2):
                chat_history = gr.Chatbot(label="Chat", height=512)
                user_prompt = gr.Textbox(
                    show_label=False, placeholder="Enter your prompt", lines=2
                )

                with gr.Row():
                    generate_button = gr.Button("Generate")
                    clear_button = gr.Button("Clear")

        # Event handlers
        enable_lora.change(
            fn=lambda x: gr.update(visible=x), inputs=[enable_lora], outputs=[lora_path]
        )
        load_button.click(
            fn=load_or_reload_model,
            inputs=[enable_lora, lora_path],
            outputs=[load_status],
        )
        generate_button.click(
            fn=describe_image,
            inputs=[
                image_input,
                user_prompt,
                temperature,
                top_k,
                top_p,
                max_tokens,
                chat_history,
            ],
            outputs=[chat_history],
        )
        clear_button.click(fn=clear_chat, outputs=[chat_history])

    # Initial model load
    load_or_reload_model(False)

    return demo


def main(args):
    """Main execution flow"""
    if args.gradio_ui:
        demo = gradio_interface(args.model_name)
        demo.launch()
    else:
        model, processor = load_model_and_processor(
            args.model_name, args.finetuning_path
        )
        image = process_image(image_path=args.image_path)
        result = generate_text_from_image(
            model, processor, image, args.prompt_text, args.temperature, args.top_p
        )
        print("Generated Text:", result)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Multi-modal inference with optional Gradio UI and LoRA support"
    )
    parser.add_argument("--image_path", type=str, help="Path to the input image")
    parser.add_argument("--prompt_text", type=str, help="Prompt text for the image")
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p sampling")
    parser.add_argument(
        "--model_name", type=str, default=DEFAULT_MODEL, help="Model name"
    )
    parser.add_argument("--finetuning_path", type=str, help="Path to LoRA weights")
    parser.add_argument("--gradio_ui", action="store_true", help="Launch Gradio UI")
    args = parser.parse_args()
    main(args)
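
# Example invocations (a usage sketch for the argparse interface above; the
# script filename, image path, prompt, and LoRA folder are hypothetical
# placeholders, not values defined in this file):
#
#   # Single-shot CLI inference on one image:
#   python multi_modal_infer.py --image_path ./sample.jpg \
#       --prompt_text "Describe this image." --temperature 0.7 --top_p 0.9
#
#   # Same, loading a LoRA adapter from a local folder:
#   python multi_modal_infer.py --image_path ./sample.jpg \
#       --prompt_text "Describe this image." --finetuning_path ./lora_weights
#
#   # Launch the interactive Gradio UI (LoRA can also be toggled from the UI):
#   python multi_modal_infer.py --gradio_ui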