# 1B-chat-start.py

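"""Minimal single-turn chat starter for a Llama instruct model.

Loads a causal LM with Transformers + Accelerate, applies the model's chat
template to a system/user message pair, and prints the sampled reply.
"""
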
import argparse

import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

DEFAULT_MODEL = "meta-llama/Llama-3.2-1B-Instruct"

accelerator = Accelerator()
device = accelerator.device


def load_model_and_tokenizer(model_name: str):
    """
    Load the causal LM and its tokenizer, placing the model on the
    Accelerate-selected device.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        device_map=device,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Accelerator.prepare() is meant for nn.Module/optimizer/dataloader
    # objects; the tokenizer needs no device placement.
    model = accelerator.prepare(model)
    return model, tokenizer


def generate_text(model, tokenizer, conversation, temperature: float, top_p: float):
    """
    Generate a reply to `conversation` using the model's chat template.
    """
    prompt = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    # The template already inserts special tokens, and sampling must be
    # enabled for temperature/top_p to take effect.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    output = model.generate(
        **inputs, do_sample=True, temperature=temperature,
        top_p=top_p, max_new_tokens=512,
    )
    # Decode only the new tokens; slicing the decoded string by len(prompt)
    # miscounts once skip_special_tokens strips the template tokens.
    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
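
# Multi-turn chat is a natural extension (hypothetical sketch, not wired into
# the CLI below): append the assistant reply and the next user message to
# `conversation`, then call generate_text() again, e.g.
#
#   conversation.append({"role": "assistant", "content": reply})
#   conversation.append({"role": "user", "content": "And a follow-up?"})
#   reply = generate_text(model, tokenizer, conversation, 0.7, 0.9)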


def main(
    system_message: str,
    user_message: str,
    temperature: float,
    top_p: float,
    model_name: str,
):
    """
    Load the model, run one system/user exchange, and print the reply.
    """
    model, tokenizer = load_model_and_tokenizer(model_name)
    conversation = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    result = generate_text(model, tokenizer, conversation, temperature, top_p)
    print("Generated Text: " + result)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate text with a Llama chat model from system and user messages."
    )
    parser.add_argument(
        "--system_message",
        type=str,
        default="You are a helpful AI assistant.",
        help="System message to set the context (default: 'You are a helpful AI assistant.')",
    )
    parser.add_argument(
        "--user_message", type=str, required=True, help="User message for generation"
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Temperature for generation (default: 0.7)",
    )
    parser.add_argument(
        "--top_p", type=float, default=0.9, help="Top p for generation (default: 0.9)"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default=DEFAULT_MODEL,
        help=f"Model name (default: '{DEFAULT_MODEL}')",
    )
    args = parser.parse_args()
    main(
        args.system_message,
        args.user_message,
        args.temperature,
        args.top_p,
        args.model_name,
    )
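
# Example invocation (assumes the gated meta-llama checkpoint is accessible
# from your Hugging Face account and that the host supports bfloat16):
#
#   python 1B-chat-start.py --user_message "Summarize attention in one line."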