浏览代码

Add Llama 3.1 example upgrade script (#5)

* Add example upgrade script

* Modify modelUpgradeExample.py as per suggestions

* Reference files in README

---------

Co-authored-by: Thomas Robinson <trobinson@meta.com>
Thomas Robinson 8 月之前
父节点
当前提交
d17e678659
共有 2 个文件被更改,包括 55 次插入3 次删除
  1. 4 3
      recipes/quickstart/inference/README.md
  2. 51 0
      recipes/quickstart/inference/modelUpgradeExample.py

+ 4 - 3
recipes/quickstart/inference/README.md

@@ -2,6 +2,7 @@
 
 This folder contains scripts to get you started with inference on Meta Llama models.
 
-* [](./code_llama/) contains scripts for tasks relating to code generation using CodeLlama
-* [](./local_inference/) contsin scripts to do memory efficient inference on servers and local machines
-* [](./mobile_inference/) has scripts using MLC to serve Llama on Android (h/t to OctoAI for the contribution!)
+* [Code Llama](./code_llama/) contains scripts for tasks relating to code generation using CodeLlama
+* [Local Inference](./local_inference/) contains scripts to do memory efficient inference on servers and local machines
+* [Mobile Inference](./mobile_inference/) has scripts using MLC to serve Llama on Android (h/t to OctoAI for the contribution!)
+* [Model Update Example](./modelUpgradeExample.py) shows an example of replacing a Llama 3 model with a Llama 3.1 model. 

+ 51 - 0
recipes/quickstart/inference/modelUpgradeExample.py

@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+# Running the script without any arguments "python modelUpgradeExample.py" performs inference with the Llama 3 8B Instruct model. 
+# Passing  --model-id "meta-llama/Meta-Llama-3.1-8B-Instruct" to the script will switch it to using the Llama 3.1 version of the same model. 
+# The script also shows the input tokens to confirm that the models are responding to the same input
+
+import fire
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
+def main(model_id = "meta-llama/Meta-Llama-3-8B-Instruct"):
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+    )
+
+    messages = [
+        {"role": "system", "content": "You are a helpful chatbot"},
+        {"role": "user", "content": "Why is the sky blue?"},
+        {"role": "assistant", "content": "Because the light is scattered"},
+        {"role": "user", "content": "Please tell me more about that"},
+    ]
+
+    input_ids = tokenizer.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        return_tensors="pt",
+    ).to(model.device)
+
+    print("Input tokens:")
+    print(input_ids)
+    
+    attention_mask = torch.ones_like(input_ids)
+    outputs = model.generate(
+        input_ids,
+        max_new_tokens=400,
+        eos_token_id=tokenizer.eos_token_id,
+        do_sample=True,
+        temperature=0.6,
+        top_p=0.9,
+        attention_mask=attention_mask,
+    )
+    response = outputs[0][input_ids.shape[-1]:]
+    print("\nOutput:\n")
+    print(tokenizer.decode(response, skip_special_tokens=True))
+
+if __name__ == "__main__":
+  fire.Fire(main)