llm.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. #
  3. # This source code is licensed under the MIT license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. # pyre-strict
  6. from __future__ import annotations
  7. import logging
  8. import time
  9. from abc import ABC, abstractmethod
  10. from typing import Callable
  11. import openai
  12. from langchain_together import Together
  13. from typing_extensions import override
  14. NUM_LLM_RETRIES = 10
  15. MAX_TOKENS = 1000
  16. LOG: logging.Logger = logging.getLogger(__name__)
  17. class LLM(ABC):
  18. def __init__(self, model: str, api_key: str | None = None) -> None:
  19. if model not in self.valid_models():
  20. LOG.warning(
  21. f"{model} is not in the valid model list for {type(self).__name__}. Valid models are: {', '.join(self.valid_models())}."
  22. )
  23. self.model: str = model
  24. self.api_key: str | None = api_key
  25. @abstractmethod
  26. def query(self, prompt: str) -> str:
  27. """
  28. Abstract method to query an LLM with a given prompt and return the response.
  29. Args:
  30. prompt (str): The prompt to send to the LLM.
  31. Returns:
  32. str: The response from the LLM.
  33. """
  34. pass
  35. def query_with_system_prompt(self, system_prompt: str, prompt: str) -> str:
  36. """
  37. Abstract method to query an LLM with a given prompt and system prompt and return the response.
  38. Args:
  39. system prompt (str): The system prompt to send to the LLM.
  40. prompt (str): The prompt to send to the LLM.
  41. Returns:
  42. str: The response from the LLM.
  43. """
  44. return self.query(system_prompt + "\n" + prompt)
  45. def _query_with_retries(
  46. self,
  47. func: Callable[..., str],
  48. *args: str,
  49. retries: int = NUM_LLM_RETRIES,
  50. backoff_factor: float = 0.5,
  51. ) -> str:
  52. last_exception = None
  53. for retry in range(retries):
  54. try:
  55. return func(*args)
  56. except Exception as exception:
  57. last_exception = exception
  58. sleep_time = backoff_factor * (2**retry)
  59. time.sleep(sleep_time)
  60. LOG.debug(
  61. f"LLM Query failed with error: {exception}. Sleeping for {sleep_time} seconds..."
  62. )
  63. raise RuntimeError(
  64. f"Unable to query LLM after {retries} retries: {last_exception}"
  65. )
  66. def query_with_retries(self, prompt: str) -> str:
  67. return self._query_with_retries(self.query, prompt)
  68. def query_with_system_prompt_with_retries(
  69. self, system_prompt: str, prompt: str
  70. ) -> str:
  71. return self._query_with_retries(
  72. self.query_with_system_prompt, system_prompt, prompt
  73. )
  74. def valid_models(self) -> list[str]:
  75. """List of valid model parameters, e.g. 'gpt-3.5-turbo' for GPT"""
  76. return []
  77. class OPENAI(LLM):
  78. """Accessing OPENAI"""
  79. def __init__(self, model: str, api_key: str) -> None:
  80. super().__init__(model, api_key)
  81. self.client = openai.OpenAI(api_key=api_key) # noqa
  82. @override
  83. def query(self, prompt: str) -> str:
  84. # Best-level effort to suppress openai log-spew.
  85. # Likely not work well in multi-threaded environment.
  86. level = logging.getLogger().level
  87. logging.getLogger().setLevel(logging.WARNING)
  88. response = self.client.chat.completions.create(
  89. model=self.model,
  90. messages=[
  91. {"role": "user", "content": prompt},
  92. ],
  93. max_tokens=MAX_TOKENS,
  94. )
  95. logging.getLogger().setLevel(level)
  96. return response.choices[0].message.content
  97. @override
  98. def valid_models(self) -> list[str]:
  99. return ["gpt-3.5-turbo", "gpt-4"]
  100. class ANYSCALE(LLM):
  101. """Accessing ANYSCALE"""
  102. def __init__(self, model: str, api_key: str) -> None:
  103. super().__init__(model, api_key)
  104. self.client = openai.OpenAI(base_url="https://api.endpoints.anyscale.com/v1", api_key=api_key) # noqa
  105. @override
  106. def query(self, prompt: str) -> str:
  107. # Best-level effort to suppress openai log-spew.
  108. # Likely not work well in multi-threaded environment.
  109. level = logging.getLogger().level
  110. logging.getLogger().setLevel(logging.WARNING)
  111. response = self.client.chat.completions.create(
  112. model=self.model,
  113. messages=[
  114. {"role": "user", "content": prompt},
  115. ],
  116. max_tokens=MAX_TOKENS,
  117. )
  118. logging.getLogger().setLevel(level)
  119. return response.choices[0].message.content
  120. @override
  121. def valid_models(self) -> list[str]:
  122. return [
  123. "meta-llama/Llama-2-7b-chat-hf",
  124. "meta-llama/Llama-2-13b-chat-hf",
  125. "meta-llama/Llama-2-70b-chat-hf",
  126. "codellama/CodeLlama-34b-Instruct-hf",
  127. "mistralai/Mistral-7B-Instruct-v0.1",
  128. "HuggingFaceH4/zephyr-7b-beta",
  129. ]
  130. class TOGETHER(LLM):
  131. """Accessing TOGETHER"""
  132. @override
  133. def query(self, prompt: str) -> str:
  134. llm = Together(
  135. model=self.model,
  136. temperature=0.75,
  137. top_p=1,
  138. max_tokens=MAX_TOKENS,
  139. together_api_key=self.api_key,
  140. )
  141. response = llm(prompt)
  142. return "".join(response)
  143. @override
  144. def valid_models(self) -> list[str]:
  145. return [
  146. "mistralai/Mistral-7B-v0.1",
  147. "lmsys/vicuna-7b-v1.5",
  148. "togethercomputer/CodeLlama-7b",
  149. "togethercomputer/CodeLlama-7b-Python",
  150. "togethercomputer/CodeLlama-7b-Instruct",
  151. "togethercomputer/CodeLlama-13b",
  152. "togethercomputer/CodeLlama-13b-Python",
  153. "togethercomputer/CodeLlama-13b-Instruct",
  154. "togethercomputer/falcon-40b",
  155. "togethercomputer/llama-2-7b",
  156. "togethercomputer/llama-2-7b-chat",
  157. "togethercomputer/llama-2-13b",
  158. "togethercomputer/llama-2-13b-chat",
  159. "togethercomputer/llama-2-70b",
  160. "togethercomputer/llama-2-70b-chat",
  161. ]