modelUpgradeExample.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# Running the script without any arguments ("python modelUpgradeExample.py") performs inference with the Llama 3 8B Instruct model.
# Passing --model-id "meta-llama/Meta-Llama-3.1-8B-Instruct" switches it to the Llama 3.1 version of the same model.
# The script also prints the input tokens to confirm that both models are responding to the same input.
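# Example usage (assumes the dependencies fire, torch, transformers, and
# accelerate are installed; accelerate is required for device_map="auto"):
#   python modelUpgradeExample.py
#   python modelUpgradeExample.py --model-id "meta-llama/Meta-Llama-3.1-8B-Instruct"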
import fire
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def main(model_id: str = "meta-llama/Meta-Llama-3-8B-Instruct"):
    # Load the tokenizer and model for the requested checkpoint;
    # device_map="auto" places the weights on the available device(s).
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    # A fixed multi-turn conversation, so both model versions see the same prompt.
    messages = [
        {"role": "system", "content": "You are a helpful chatbot"},
        {"role": "user", "content": "Why is the sky blue?"},
        {"role": "assistant", "content": "Because the light is scattered"},
        {"role": "user", "content": "Please tell me more about that"},
    ]
    # Format the conversation with the model's chat template;
    # add_generation_prompt=True appends the assistant header so the
    # model answers the final user turn.
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    print("Input tokens:")
    print(input_ids)
    # The prompt is a single unpadded sequence, so the attention mask is all
    # ones; passing it explicitly avoids the missing-attention-mask warning
    # from transformers.
    attention_mask = torch.ones_like(input_ids)
    outputs = model.generate(
        input_ids,
        max_new_tokens=400,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        attention_mask=attention_mask,
    )
    # Strip the prompt tokens so only the newly generated text is decoded.
    response = outputs[0][input_ids.shape[-1]:]
    print("\nOutput:\n")
    print(tokenizer.decode(response, skip_special_tokens=True))


if __name__ == "__main__":
    fire.Fire(main)