# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# Running the script without any arguments ("python modelUpgradeExample.py") performs inference with the Llama 3 8B Instruct model.
# Passing --model-id "meta-llama/Meta-Llama-3.1-8B-Instruct" to the script switches it to the Llama 3.1 version of the same model.
# The script also prints the input tokens to confirm that both models respond to the same input.
import fire
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def main(model_id="meta-llama/Meta-Llama-3-8B-Instruct"):
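    """Run chat inference with the selected Llama instruct model and print both
    the tokenized prompt and the decoded response for comparison."""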
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
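
    # A short multi-turn conversation in the chat format expected by apply_chat_template.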
    messages = [
        {"role": "system", "content": "You are a helpful chatbot"},
        {"role": "user", "content": "Why is the sky blue?"},
        {"role": "assistant", "content": "Because the light is scattered"},
        {"role": "user", "content": "Please tell me more about that"},
    ]
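
    # Render the conversation with the model's chat template and append the
    # generation prompt so the model continues as the assistant.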
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    print("Input tokens:")
    print(input_ids)

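    # The prompt is a single unpadded sequence, so an all-ones attention mask is correct
    # and avoids the warning transformers emits when no mask is provided.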
    attention_mask = torch.ones_like(input_ids)
    outputs = model.generate(
        input_ids,
        max_new_tokens=400,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        attention_mask=attention_mask,
    )
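
    # Drop the prompt tokens and decode only the newly generated part of the sequence.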
    response = outputs[0][input_ids.shape[-1]:]
    print("\nOutput:\n")
    print(tokenizer.decode(response, skip_special_tokens=True))


if __name__ == "__main__":
    fire.Fire(main)