Created an API inference script with its supporting documentation (#959)

Omar Abdelwahab 5 days ago
parent
commit
ddd7429c61

+ 116 - 0
getting-started/inference/api_inference/README.md

@@ -0,0 +1,116 @@
+# API Inference
+
+This module provides a command-line interface for interacting with Llama models through the Llama API.
+
+## Overview
+
+The `api_inference.py` script allows you to:
+- Connect to Llama's API using your API key
+- Launch a Gradio web interface for sending prompts to Llama models
+- Get completions from models such as Llama-4-Maverick-17B-128E-Instruct-FP8
+
+## Prerequisites
+
+- Python 3.8 or higher
+- A valid Llama API key
+- Required Python packages:
+  - gradio
+  - llama_api_client
+
+## Installation
+
+Ensure you have the required packages installed:
+
+```bash
+pip install gradio llama_api_client
+```
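+
+As a quick sanity check that both packages import cleanly (prints the installed Gradio version):
+
+```bash
+python -c "import llama_api_client, gradio; print(gradio.__version__)"
+```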
+
+## Usage
+
+You can run the script from the command line using:
+
+```bash
+python api_inference.py [OPTIONS]
+```
+
+### Command-line Options
+
+- `--api-key`: Your API key (optional)
+  - If not provided, the script will look for the appropriate environment variable based on the provider
+- `--provider`: API provider to use (optional, default: "Llama")
+  - Available options: "Llama", "OpenAI"
+
+### Setting Up Your API Key
+
+You can provide your API key in one of two ways:
+
+1. **Command-line argument**:
+   ```bash
+   python api_inference.py --api-key YOUR_API_KEY --provider Llama
+   ```
+
+2. **Environment variable**:
+   The environment variable name depends on the provider you choose:
+   ```bash
+   # For Llama (default provider)
+   export LLAMA_API_KEY=YOUR_API_KEY
+
+   # For OpenAI
+   export OPENAI_API_KEY=YOUR_API_KEY
+   ```
+
+   For Windows:
+   ```bash
+   # Command Prompt (example for Llama)
+   set LLAMA_API_KEY=YOUR_API_KEY
+
+   # PowerShell (example for Llama)
+   $env:LLAMA_API_KEY="YOUR_API_KEY"
+   ```
+
+## Example
+
+1. Run the script:
+   ```bash
+   # Using Llama (default provider)
+   python api_inference.py --api-key YOUR_API_KEY
+
+   # Using a different provider
+   python api_inference.py --api-key YOUR_API_KEY --provider OpenAI
+   ```
+
+2. The script will launch a Gradio web interface (typically at http://127.0.0.1:7860; if that port is busy, Gradio picks the next free one; see the sketch after this list to pin an address or port)
+
+3. In the interface:
+   - Enter your prompt in the text box
+   - The default model is "Llama-4-Maverick-17B-128E-Instruct-FP8" for the Llama provider ("Llama-3.3-8B-Instruct" for OpenAI), but you can change it
+   - Click "Submit" to get a response from the model
+
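+To pin the address or port, edit the `demo.launch()` call at the end of `launch_interface` in `api_inference.py`; a minimal sketch using the standard Gradio launch options:
+
+```python
+# Bind to an explicit address and port instead of letting Gradio choose;
+# share=True would additionally create a temporary public URL.
+demo.launch(server_name="127.0.0.1", server_port=7860, share=False)
+```
+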
+## Troubleshooting
+
+### API Key Issues
+
+If you see an error like:
+```
+No API key provided and *_API_KEY environment variable not found
+```
+
+Make sure you've either:
+- Passed the API key using the `--api-key` argument
+- Set the appropriate environment variable for your chosen provider (`LLAMA_API_KEY` for Llama, `OPENAI_API_KEY` for OpenAI)
+
+## Advanced Usage
+
+You can modify the script to use different models or customize the Gradio interface as needed.
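+
+As a minimal sketch of one such customization, the free-text model field can be swapped for a dropdown. The model list below is illustrative (use models your key can access), and `inference` is assumed to be a `LlamaInference` instance created as in `main()`:
+
+```python
+import gradio as gr
+
+# Illustrative model choices; adjust to what your account can access.
+MODELS = [
+    "Llama-4-Maverick-17B-128E-Instruct-FP8",
+    "Llama-3.3-70B-Instruct",
+]
+
+demo = gr.Interface(
+    fn=inference.infer,  # reuse the existing LlamaInference instance
+    inputs=[gr.Textbox(label="Prompt"), gr.Dropdown(MODELS, value=MODELS[0])],
+    outputs=gr.Textbox(label="Response"),
+)
+demo.launch()
+```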
+
+## Implementation Notes
+
+- The script uses type hints for better code readability and IDE support:
+  ```python
+  api_key: Optional[str] = args.api_key
+  ```
+  This annotation uses `Optional` from the `typing` module (imported at the top of the script) to indicate that `api_key` may be either a string or `None`.
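+
+  Since the script begins with `from __future__ import annotations`, the same annotation could equivalently use PEP 604 union syntax:
+  ```python
+  api_key: str | None = args.api_key
+  ```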
+
+## License
+
+[Include license information here]

+ 103 - 0
getting-started/inference/api_inference/api_inference.py

@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+import argparse
+import logging
+import os
+import sys
+from typing import Optional
+
+import gradio as gr
+from llama_api_client import LlamaAPIClient
+from openai import OpenAI
+
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+LOG: logging.Logger = logging.getLogger(__name__)
+
+
+class LlamaInference:
+    def __init__(self, api_key: str, provider: str):
+        self.provider = provider
+        if self.provider == "Llama":
+            # Native Llama API endpoint.
+            self.client = LlamaAPIClient(
+                api_key=api_key,
+                base_url="https://api.llama.com/v1/",
+            )
+        elif self.provider == "OpenAI":
+            # OpenAI-compatible endpoint of the Llama API, used via the
+            # OpenAI SDK.
+            self.client = OpenAI(
+                api_key=api_key,
+                base_url="https://api.llama.com/compat/v1/",
+            )
+        else:
+            raise ValueError(f"Unsupported provider: {provider}")
+
+    def infer(self, user_input: str, model_id: str) -> str:
+        """Send a single-turn chat request and return the response text."""
+        response = self.client.chat.completions.create(
+            model=model_id, messages=[{"role": "user", "content": user_input}]
+        )
+        if self.provider == "Llama":
+            # The native client and the OpenAI-compatible client return
+            # differently shaped responses.
+            return response.completion_message.content.text
+        return response.choices[0].message.content
+
+    def launch_interface(self) -> None:
+        # The default model differs by provider; both fields remain editable
+        # in the UI, so any model the key can access may be entered.
+        default_model = (
+            "Llama-4-Maverick-17B-128E-Instruct-FP8"
+            if self.provider == "Llama"
+            else "Llama-3.3-8B-Instruct"
+        )
+        demo = gr.Interface(
+            fn=self.infer,
+            inputs=[gr.Textbox(), gr.Text(default_model)],
+            outputs=gr.Textbox(),
+        )
+        LOG.info("Launching Gradio interface")
+        demo.launch()
+
+
+def main() -> None:
+    """
+    Main function to handle API-based LLM inference.
+    Parses command-line arguments, sets the API key, and launches the inference UI.
+    """
+    LOG.info("Starting API inference script")
+    parser = argparse.ArgumentParser(
+        description="Perform inference using API-based LLAMA LLMs"
+    )
+
+    parser.add_argument(
+        "--api-key",
+        type=str,
+        help="API key for authentication (if not provided, will look for PROVIDER_API_KEY environment variable)",
+    )
+    parser.add_argument(
+        "--provider",
+        type=str,
+        default="Llama",
+        choices=["Llama", "OpenAI"],
+        help="API provider to use (default: Llama)",
+    )
+    args = parser.parse_args()
+
+    api_key: Optional[str] = args.api_key
+    env_var_name = f"{args.provider.upper()}_API_KEY"
+
+    if api_key is not None:
+        os.environ[env_var_name] = api_key
+    else:
+        api_key = os.environ.get(env_var_name)
+        if api_key is None:
+            LOG.error(
+                f"No API key provided and {env_var_name} environment variable not found"
+            )
+            sys.exit(1)
+    inference = LlamaInference(api_key, args.provider)
+    inference.launch_interface()
+
+
+if __name__ == "__main__":
+    main()

+ 275 - 0
getting-started/inference/api_inference/llama_inference_api.ipynb

@@ -0,0 +1,275 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "dbfff3b2-fbfd-4ba5-929c-c45d028801f3",
+   "metadata": {},
+   "source": [
+    "# Llama LLM API Client Notebook"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6fcc5513-4c08-42e5-9b96-20b37482d1ea",
+   "metadata": {},
+   "source": [
+    "## This Notebook will send a request to the Llama LLM model's API and retrieve the response."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5f98a3c4-81a6-44be-987a-cd4f5adcef77",
+   "metadata": {},
+   "source": [
+    "### First we will import os, gradio and the LlamaAPIClient"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "4c2a01b9-23d3-4f0e-9c1f-92b2b2b37231",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import gradio as gr\n",
+    "from llama_api_client import LlamaAPIClient"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e3e17aea-0422-4ad4-8ada-94e56e5b0fcf",
+   "metadata": {},
+   "source": [
+    "# Use the following [link](https://l.facebook.com/l.php?u=https%3A%2F%2Fllama.developer.meta.com%2F&h=AT1np-Q96F0Qym-qTXCKJuPUqZePWH7EWzib2mRB0XSqjajogDg91FWpn-_MRnaVS7CGCR1o4Z3wff5uQNe55poibnARsHoU_Yf-uKANWztHmiSgl56GTo5ERQR1o8GmHatZ2Eb_hdCGIO4iD3-c0QZN5Sdq) to retrieve your Llama API key. Next, configure your environment by setting the \"LLAMA_API_KEY\" variable to the obtained key using one of the following steps."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "id": "8ff2c746-44ba-4594-9293-6fdae986fa1b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# either run the following command in your cli terminal\n",
+    "# export LLAMA_API_KEY='Llama API key'\n",
+    "# Or\n",
+    "# Directly set your LLAMA_API_KEY environment variable here\n",
+    "# os.environ['LLAMA_API_KEY'] = 'Llama API key' "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e7855614-9a8b-421d-b4ec-32b278c24f90",
+   "metadata": {},
+   "source": [
+    "### Run the \"infer\" function to create an interactive UI where you can enter your input prompt for the LLM in the designated \"user_input\" field and view the generated response in the adjacent \"output\" field."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 53,
+   "id": "fd340fa0-a116-4f13-999b-60f5f99f2dc3",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "* Running on local URL:  http://127.0.0.1:7890\n",
+      "* To create a public link, set `share=True` in `launch()`.\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "<div><iframe src=\"http://127.0.0.1:7890/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": []
+     },
+     "execution_count": 53,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# You can switch between the following models:\n",
+    "# Llama-3.3-70B-Instruct\n",
+    "# Llama-4-Maverick-17B-128E-Instruct-FP8\n",
+    "def infer(user_input:str, model_id:str):\n",
+    "    response = client.chat.completions.create(\n",
+    "    model=model_id,\n",
+    "    messages=[\n",
+    "        {\"role\": \"user\", \"content\": user_input},\n",
+    "    ])\n",
+    "    return response.completion_message.content.text\n",
+    "\n",
+    "\n",
+    "demo = gr.Interface(fn=infer,\n",
+    "             inputs=[gr.Textbox(),gr.Text(\"Llama-4-Maverick-17B-128E-Instruct-FP8\")], \n",
+    "             outputs=gr.Textbox(), #<-- Gradio TextBox component as an output component\n",
+    "             )\n",
+    "\n",
+    "demo.launch()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1ba1007d-5730-4b2e-a52e-ce4a8bb2313b",
+   "metadata": {},
+   "source": [
+    "### The model can be used on a range of use cases, including factoid-based question answering, to demonstrate its effectiveness and versatility in real-world scenarios."
+   ]
+  },
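+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "direct-api-call-sketch",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# A direct, single-shot call without Gradio (reuses `client` from above);\n",
+    "# swap in your own prompt or model.\n",
+    "response = client.chat.completions.create(\n",
+    "    model=\"Llama-4-Maverick-17B-128E-Instruct-FP8\",\n",
+    "    messages=[{\"role\": \"user\", \"content\": \"What is the capital of France?\"}],\n",
+    ")\n",
+    "print(response.completion_message.content.text)"
+   ]
+  },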
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "77d87202-da25-41e3-b3d2-7b0aa979e739",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.11"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}