import argparse

import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

DEFAULT_MODEL = "meta-llama/Llama-3.2-1B-Instruct"

accelerator = Accelerator()
device = accelerator.device
def load_model_and_tokenizer(model_name: str):
    """
    Load the model and tokenizer for the requested checkpoint
    (default: meta-llama/Llama-3.2-1B-Instruct).
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        device_map=device,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Only the model carries weights for Accelerate to place; the tokenizer
    # is returned as-is.
    model = accelerator.prepare(model)

    return model, tokenizer
def generate_text(model, tokenizer, conversation, temperature: float, top_p: float):
    """
    Generate a reply from a chat-style conversation using the model and tokenizer.
    """
    # add_generation_prompt=True appends the assistant header so the model
    # starts a fresh reply instead of continuing the user turn.
    prompt = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    output = model.generate(
        **inputs, do_sample=True, temperature=temperature, top_p=top_p,
        max_new_tokens=512,
    )
    # Decode only the newly generated tokens, not the prompt.
    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
def main(
    system_message: str,
    user_message: str,
    temperature: float,
    top_p: float,
    model_name: str,
):
    """
    Load the model, build the conversation, and print the generated reply.
    """
    model, tokenizer = load_model_and_tokenizer(model_name)
    conversation = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    result = generate_text(model, tokenizer, conversation, temperature, top_p)
    print("Generated Text: " + result)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate text with a Llama chat model from system and user messages."
    )
    parser.add_argument(
        "--system_message",
        type=str,
        default="You are a helpful AI assistant.",
        help="System message to set the context (default: 'You are a helpful AI assistant.')",
    )
    parser.add_argument(
        "--user_message", type=str, required=True, help="User message for generation"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Temperature for generation (default: 0.7)",
    )
    parser.add_argument(
        "--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default=DEFAULT_MODEL,
        help=f"Model name (default: '{DEFAULT_MODEL}')",
    )

    args = parser.parse_args()
    main(
        args.system_message,
        args.user_message,
        args.temperature,
        args.top_p,
        args.model_name,
    )
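
A typical invocation looks like the following; the filename generate.py is only a placeholder for wherever the script is saved, and the example user message is illustrative. The flags mirror the argparse options defined above, so only --user_message is required.

python generate.py \
    --user_message "Summarize the plot of Hamlet in two sentences." \
    --temperature 0.7 \
    --top_p 0.9 \
    --model_name meta-llama/Llama-3.2-1B-Instruct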