@@ -0,0 +1,112 @@
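+"""Generate one chat completion from the command line.
+
+Example invocation (the script filename is illustrative):
+    python generate.py --user_message "Tell me a joke about GPUs"
+"""
+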
+import argparse
+
+import torch
+from accelerate import Accelerator
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+DEFAULT_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
+
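+# Accelerator picks the right device (CUDA, MPS, or CPU) for this process.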
+accelerator = Accelerator()
+device = accelerator.device
+
+
+def load_model_and_tokenizer(model_name: str):
+    """
+    Load the causal LM and its tokenizer from the Hugging Face Hub.
+    """
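+    # bfloat16 halves memory versus float32; device_map loads the weights
+    # directly onto the Accelerate-selected device.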
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.bfloat16,
+        use_safetensors=True,
+        device_map=device,
+    )
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    # Only the model needs prepare(); the tokenizer has no device state.
+    model = accelerator.prepare(model)
+    return model, tokenizer
+
+
+def generate_text(model, tokenizer, conversation, temperature: float, top_p: float):
+    """
+    Generate text using the model and tokenizer based on a conversation.
+    """
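+    # add_generation_prompt appends the assistant header so the model starts
+    # a reply instead of continuing the user turn.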
+    prompt = tokenizer.apply_chat_template(
+        conversation, tokenize=False, add_generation_prompt=True
+    )
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+    output = model.generate(
+        **inputs, do_sample=True, temperature=temperature, top_p=top_p, max_new_tokens=512
+    )
+    # Decode only the newly generated tokens, not the echoed prompt.
+    new_tokens = output[0][inputs["input_ids"].shape[1] :]
+    return tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+
+def main(
+    system_message: str,
+    user_message: str,
+    temperature: float,
+    top_p: float,
+    model_name: str,
+):
+    """
+    Load the model, build the conversation, and print the generated reply.
+    """
+    model, tokenizer = load_model_and_tokenizer(model_name)
+    conversation = [
+        {"role": "system", "content": system_message},
+        {"role": "user", "content": user_message},
+    ]
+    result = generate_text(model, tokenizer, conversation, temperature, top_p)
+    print("Generated Text: " + result)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Generate text from a chat model using system and user messages."
+    )
+    parser.add_argument(
+        "--system_message",
+        type=str,
+        default="You are a helpful AI assistant.",
+        help="System message to set the context (default: 'You are a helpful AI assistant.')",
+    )
+    parser.add_argument(
+        "--user_message", type=str, required=True, help="User message for generation"
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.7,
+        help="Temperature for generation (default: 0.7)",
+    )
+    parser.add_argument(
+        "--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)"
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default=DEFAULT_MODEL,
+        help=f"Model name (default: '{DEFAULT_MODEL}')",
+    )
+
+    args = parser.parse_args()
+    main(
+        args.system_message,
+        args.user_message,
+        args.temperature,
+        args.top_p,
+        args.model_name,
+    )