# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

import logging
import time
from abc import ABC, abstractmethod
from typing import Callable

import openai
from typing_extensions import override

NUM_LLM_RETRIES = 10  # Default attempt count for the retry helpers below.
MAX_TOKENS = 1000  # Completion cap passed as max_tokens on every request.
TEMPERATURE = 0.1  # Sampling settings; currently only used by OctoAI.query.
TOP_P = 0.9

LOG: logging.Logger = logging.getLogger(__name__)


class LLM(ABC):
    def __init__(self, model: str, api_key: str | None = None) -> None:
        if model not in self.valid_models():
            LOG.warning(
                f"{model} is not in the valid model list for {type(self).__name__}. Valid models are: {', '.join(self.valid_models())}."
            )
        self.model: str = model
        self.api_key: str | None = api_key

    @abstractmethod
    def query(self, prompt: str) -> str:
        """
        Abstract method to query an LLM with a given prompt and return the response.

        Args:
            prompt (str): The prompt to send to the LLM.

        Returns:
            str: The response from the LLM.
        """
        pass

    def query_with_system_prompt(self, system_prompt: str, prompt: str) -> str:
        """
        Query an LLM with a given system prompt and user prompt and return the response.

        Args:
            system_prompt (str): The system prompt to send to the LLM.
            prompt (str): The prompt to send to the LLM.

        Returns:
            str: The response from the LLM.
        """
        return self.query(system_prompt + "\n" + prompt)
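
    # Note: the default implementation above simply prepends the system prompt
    # to the user prompt, since this provider-agnostic base class cannot assume
    # a chat API. A subclass backed by a chat-completions endpoint could
    # override it to send a real {"role": "system"} message instead (a possible
    # refinement; this module does not currently do so).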

    def _query_with_retries(
        self,
        func: Callable[..., str],
        *args: str,
        retries: int = NUM_LLM_RETRIES,
        backoff_factor: float = 0.5,
    ) -> str:
        last_exception = None
        for retry in range(retries):
            try:
                return func(*args)
            except Exception as exception:
                last_exception = exception
                sleep_time = backoff_factor * (2**retry)
                # Log before sleeping so the message matches what happens next.
                LOG.debug(
                    f"LLM query failed with error: {exception}. Sleeping for {sleep_time} seconds..."
                )
                time.sleep(sleep_time)
        raise RuntimeError(
            f"Unable to query LLM after {retries} retries: {last_exception}"
        )
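
    # With the defaults (retries=10, backoff_factor=0.5) the sleep between
    # attempts doubles each time: 0.5 s, 1 s, 2 s, ..., 256 s, i.e. up to
    # backoff_factor * 2**(retries - 1) seconds after the final failure.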

    def query_with_retries(self, prompt: str) -> str:
        return self._query_with_retries(self.query, prompt)

    def query_with_system_prompt_with_retries(
        self, system_prompt: str, prompt: str
    ) -> str:
        return self._query_with_retries(
            self.query_with_system_prompt, system_prompt, prompt
        )

    def valid_models(self) -> list[str]:
        """List of valid model parameters, e.g. 'gpt-3.5-turbo' for GPT"""
        return []
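
    # Subclasses must override query() (it is abstract) and should override
    # valid_models(); the base implementation returns an empty list, so the
    # constructor warning above would fire for every model name otherwise.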


class OPENAI(LLM):
    """Accessing OPENAI"""

    def __init__(self, model: str, api_key: str) -> None:
        super().__init__(model, api_key)
        self.client = openai.OpenAI(api_key=api_key)  # noqa

    @override
    def query(self, prompt: str) -> str:
        # Best-effort attempt to suppress openai log spew.
        # Likely does not work well in multi-threaded environments.
        level = logging.getLogger().level
        logging.getLogger().setLevel(logging.WARNING)
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "user", "content": prompt},
            ],
            max_tokens=MAX_TOKENS,
        )
        logging.getLogger().setLevel(level)
        return response.choices[0].message.content

    @override
    def valid_models(self) -> list[str]:
        return ["gpt-3.5-turbo", "gpt-4"]
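

# A minimal usage sketch (hypothetical: the model name and key placeholder are
# illustrative values, not part of this module):
#
#     llm = OPENAI("gpt-3.5-turbo", api_key="sk-...")
#     print(llm.query_with_retries("Summarize what this module does."))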


class ANYSCALE(LLM):
    """Accessing ANYSCALE"""

    def __init__(self, model: str, api_key: str) -> None:
        super().__init__(model, api_key)
        self.client = openai.OpenAI(base_url="https://api.endpoints.anyscale.com/v1", api_key=api_key)  # noqa

    @override
    def query(self, prompt: str) -> str:
        # Best-effort attempt to suppress openai log spew.
        # Likely does not work well in multi-threaded environments.
        level = logging.getLogger().level
        logging.getLogger().setLevel(logging.WARNING)
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "user", "content": prompt},
            ],
            max_tokens=MAX_TOKENS,
        )
        logging.getLogger().setLevel(level)
        return response.choices[0].message.content

    @override
    def valid_models(self) -> list[str]:
        return [
            "meta-llama/Llama-2-7b-chat-hf",
            "meta-llama/Llama-2-13b-chat-hf",
            "meta-llama/Llama-2-70b-chat-hf",
            "codellama/CodeLlama-34b-Instruct-hf",
            "mistralai/Mistral-7B-Instruct-v0.1",
            "HuggingFaceH4/zephyr-7b-beta",
        ]


class OctoAI(LLM):
    """Accessing OctoAI"""

    def __init__(self, model: str, api_key: str) -> None:
        super().__init__(model, api_key)
        self.client = openai.OpenAI(base_url="https://text.octoai.run/v1", api_key=api_key)  # noqa

    @override
    def query(self, prompt: str) -> str:
        # Best-effort attempt to suppress openai log spew.
        # Likely does not work well in multi-threaded environments.
        level = logging.getLogger().level
        logging.getLogger().setLevel(logging.WARNING)
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "system",
                    "content": "You are a helpful assistant. Keep your responses limited to one short paragraph if possible.",
                },
                {"role": "user", "content": prompt},
            ],
            max_tokens=MAX_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
        )
        logging.getLogger().setLevel(level)
        return response.choices[0].message.content

    @override
    def valid_models(self) -> list[str]:
        return [
            "llamaguard-2-8b",
            "meta-llama-3-8b-instruct",
            "meta-llama-3-70b-instruct",
        ]
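

# A minimal factory sketch (assumption: create_llm and the provider strings
# below are illustrative, not helpers this module actually defines):
def create_llm(provider: str, model: str, api_key: str) -> LLM:
    """Instantiate one of the LLM subclasses above by provider name."""
    providers: dict[str, type[LLM]] = {
        "OPENAI": OPENAI,
        "ANYSCALE": ANYSCALE,
        "OCTOAI": OctoAI,
    }
    if provider.upper() not in providers:
        raise ValueError(f"Unknown provider: {provider}")
    return providers[provider.upper()](model, api_key)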