api_inference.py

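"""Gradio demo for API-based Llama inference.

Sends a single prompt to the Llama API, either through the native
llama-api-client SDK or through the OpenAI-compatible endpoint. Typical
invocation (assumes LLAMA_API_KEY is set or --api-key is passed):

    python api_inference.py --provider Llama
    python api_inference.py --provider OpenAI
"""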
from __future__ import annotations

import argparse
import logging
import os
import sys
from typing import Optional

import gradio as gr
from llama_api_client import LlamaAPIClient
from openai import OpenAI

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
LOG: logging.Logger = logging.getLogger(__name__)
class LlamaInference:
    """Thin wrapper around the Llama API, via the native SDK or the
    OpenAI-compatible endpoint."""

    def __init__(self, api_key: str, provider: str):
        self.provider = provider
        if self.provider == "Llama":
            # Native Llama API client.
            self.client = LlamaAPIClient(
                api_key=api_key,
                base_url="https://api.llama.com/v1/",
            )
        elif self.provider == "OpenAI":
            # OpenAI SDK pointed at the Llama API's OpenAI-compatible endpoint.
            self.client = OpenAI(
                api_key=api_key,
                base_url="https://api.llama.com/compat/v1/",
            )
        else:
            # Fail fast instead of leaving self.client unset.
            raise ValueError(f"Unsupported provider: {provider}")

    def infer(self, user_input: str, model_id: str) -> str:
        """Send a single-turn chat completion request and return the reply text."""
        response = self.client.chat.completions.create(
            model=model_id, messages=[{"role": "user", "content": user_input}]
        )
        # The two SDKs shape their responses differently.
        if self.provider == "Llama":
            return response.completion_message.content.text
        return response.choices[0].message.content
    def launch_interface(self):
        """Launch a Gradio UI with a provider-appropriate default model ID."""
        if self.provider == "Llama":
            demo = gr.Interface(
                fn=self.infer,
                inputs=[
                    gr.Textbox(),
                    gr.Text("Llama-4-Maverick-17B-128E-Instruct-FP8"),
                ],
                outputs=gr.Textbox(),
            )
        else:  # "OpenAI" (any other value is rejected in __init__)
            demo = gr.Interface(
                fn=self.infer,
                inputs=[gr.Textbox(), gr.Text("Llama-3.3-8B-Instruct")],
                outputs=gr.Textbox(),
            )
        LOG.info("Launching Gradio interface")
        demo.launch()
def main() -> None:
    """
    Main function to handle API-based LLM inference.

    Parses command-line arguments, resolves the API key, and launches the
    inference UI.
    """
    LOG.info("Starting API-based inference demo")
    parser = argparse.ArgumentParser(
        description="Perform inference using API-based Llama LLMs"
    )
    parser.add_argument(
        "--api-key",
        type=str,
        help="API key for authentication (if not provided, the LLAMA_API_KEY "
        "environment variable is used)",
    )
    parser.add_argument(
        "--provider",
        type=str,
        default="Llama",
        choices=["Llama", "OpenAI"],
        help="API provider to use (default: Llama)",
    )
    args = parser.parse_args()

    # Prefer the command-line key; fall back to the environment variable.
    api_key: Optional[str] = args.api_key
    env_var_name = "LLAMA_API_KEY"
    if api_key is not None:
        os.environ[env_var_name] = api_key
    else:
        api_key = os.environ.get(env_var_name)
        if api_key is None:
            LOG.error(
                f"No API key provided and {env_var_name} environment variable not found"
            )
            sys.exit(1)

    inference = LlamaInference(api_key, args.provider)
    inference.launch_interface()
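
# Programmatic use, without the CLI (a sketch; assumes LLAMA_API_KEY is set
# and the model ID is available to your account):
#
#     inference = LlamaInference(os.environ["LLAMA_API_KEY"], "Llama")
#     print(inference.infer("Hello!", "Llama-4-Maverick-17B-128E-Instruct-FP8"))
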
if __name__ == "__main__":
    main()