llm.py

import json
import logging
import time
from typing import Any, Dict, List, Optional, Union

import groq
import yaml
from openai import OpenAI

log = logging.getLogger(__name__)

# Load the shared configuration once at import time.
with open("config.yaml", "r") as f:
    CFG = yaml.safe_load(f)
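# Expected config.yaml layout, reconstructed from how CFG is read below (the
# keys under `vllm`/`groq` must match the constructor arguments; the concrete
# values shown here are illustrative only, not part of this repo):
#
#   model:
#     use: vllm            # or "groq"
#     vllm:
#       endpoint: http://localhost:8000/v1
#       model_id: meta-llama/Llama-3.1-8B-Instruct
#     groq:
#       key: <GROQ_API_KEY>
#       model_id: llama-3.1-8b-instant
#   prompts:
#     <prompt_name>:
#       system: <system prompt text>
#       json_schema: <optional JSON schema for guided decoding>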
class LlamaVLLM:
    """Chat client for a vLLM server exposed through the OpenAI-compatible API."""

    def __init__(self, endpoint, model_id):
        self.model_id = model_id
        self.client = OpenAI(base_url=endpoint, api_key="token")

    def chat(
        self,
        inputs: List[Dict[str, str]],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        guided_decode_json_schema: Optional[str] = None,
    ) -> Optional[str]:
        if generation_kwargs is None:
            generation_kwargs = {}

        try:
            response = self.client.chat.completions.create(
                model=self.model_id,
                messages=inputs,
                extra_body={"guided_json": guided_decode_json_schema},
                **generation_kwargs,
            )
            output = response.choices[0].message.content
        except Exception as e:
            log.error(
                f"FAILED to generate inference for input {inputs}\nError: {str(e)}"
            )
            output = None

        return output
class LlamaGroq:
    """Chat client for the Groq API."""

    def __init__(self, key, model_id):
        self.model_id = model_id
        self.client = groq.Groq(api_key=key)
        print(f"Using Groq:{self.model_id} for inference")

    def chat(
        self,
        inputs: List[Dict[str, str]],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        guided_decode_json_schema: Optional[str] = None,
    ) -> Optional[str]:
        if generation_kwargs is None:
            generation_kwargs = {}

        # Currently Groq doesn't support guided JSON decoding. Workaround:
        # append the schema to the system prompt and request a JSON-object response.
        if guided_decode_json_schema is not None:
            inputs[0]["content"] += (
                f"\n\nEnsure your response aligns with the following JSON schema:"
                f"\n{guided_decode_json_schema}\n\n"
            )

        output = None
        while True:
            try:
                response = self.client.chat.completions.with_raw_response.create(
                    model=self.model_id,
                    messages=inputs,
                    stream=False,
                    **generation_kwargs,
                    response_format={
                        "type": "json_object"
                        if guided_decode_json_schema is not None
                        else "text"
                    },
                )
                completion = response.parse()
                output = completion.choices[0].message.content
                break
            except groq.RateLimitError as e:
                # Back off before retrying; the reset header (assumed to be in
                # seconds) tells us how long, falling back to 30s when it is
                # absent or unparsable.
                try:
                    wait = float(e.response.headers.get("X-Ratelimit-Reset", 30))
                except ValueError:
                    wait = 30.0
                print(e)
                print(f"waiting for {wait}s to prevent ratelimiting")
                time.sleep(wait)
            except Exception as e:
                # Any other failure is not retried; return None for this input.
                print(f"inference failed for input: {inputs}\nError: {str(e)}")
                break

        return output
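# Both clients forward `generation_kwargs` verbatim to the chat-completions
# call, so any sampling parameters accepted by the OpenAI-compatible/Groq APIs
# can be passed through; the values below are illustrative only:
#
#   generation_kwargs = {"temperature": 0.0, "max_tokens": 2048, "top_p": 1.0}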
def run_llm_inference(
    prompt_name: str,
    inputs: Union[str, List[str]],
    generation_kwargs: Optional[Dict] = None,
    guided_decode_json_schema=None,
) -> Union[List[str], List[Dict[str, Any]]]:
    """
    Run the LLM inference on the given inputs.

    Args:
    - prompt_name (str): The name of the prompt to use.
    - inputs (str or List[str]): The input(s) to the LLM.
    - generation_kwargs (Dict): Additional keyword arguments to pass to the LLM.
    - guided_decode_json_schema (str): The JSON schema to use for guided decoding.

    Returns:
    - The response(s) from the LLM: a list when `inputs` is a list, a single
      response otherwise; parsed into dicts when a JSON schema is in effect.
    """
    log.info(f"[run_llm_inference] {prompt_name}")

    # initialize appropriate LLM accessor
    if CFG["model"]["use"] == "vllm":
        LLM = LlamaVLLM(**CFG["model"]["vllm"])
    elif CFG["model"]["use"] == "groq":
        LLM = LlamaGroq(**CFG["model"]["groq"])
    else:
        raise ValueError("Invalid model type in config.yaml")

    # Normalize a single string input into a one-element batch.
    _batch = True
    if isinstance(inputs, str):
        _batch = False
        inputs = [inputs]

    # Wrap each input with the configured system prompt.
    inputs = [
        [
            {"role": "system", "content": CFG["prompts"][prompt_name]["system"]},
            {"role": "user", "content": i},
        ]
        for i in inputs
    ]

    # Fall back to the schema defined in the config, collapsed onto one line.
    if (
        guided_decode_json_schema is None
        and "json_schema" in CFG["prompts"][prompt_name]
    ):
        guided_decode_json_schema = " ".join(
            CFG["prompts"][prompt_name]["json_schema"].split()
        )

    responses = [
        LLM.chat(i, generation_kwargs, guided_decode_json_schema) for i in inputs
    ]

    # When guided decoding is active, parse each raw response into JSON.
    if guided_decode_json_schema is not None:
        responses_json = []
        for r in responses:
            if r is not None:
                try:
                    responses_json.append(json.loads(r, strict=False))
                    continue
                except json.JSONDecodeError:
                    log.error(f"Error decoding JSON: {r}")
            responses_json.append(None)
        responses = responses_json

    if not _batch:
        responses = responses[0]

    return responses
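

# A minimal usage sketch (assumes config.yaml defines a prompt named
# "summarize"; the prompt name, input text, and sampling settings here are
# illustrative only).
if __name__ == "__main__":
    result = run_llm_inference(
        "summarize",
        "Explain what guided JSON decoding is in one sentence.",
        generation_kwargs={"temperature": 0.0},
    )
    print(result)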