llm.py

import logging
from typing import Any, Dict, List, Optional, Union

import yaml
import time
import json

from openai import OpenAI
import groq

logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler())

with open("config.yaml", "r") as _cfg_file:
    CFG = yaml.safe_load(_cfg_file)
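# Expected shape of config.yaml, inferred from how CFG is accessed in this
# file (the field names come from the code; the concrete values below are
# placeholders, not from the original source):
#
#   model:
#     use: vllm                 # or "groq"
#     vllm:
#       endpoint: http://localhost:8000/v1
#       model_id: <model name served by vLLM>
#     groq:
#       key: <GROQ_API_KEY>
#       model_id: <Groq model name>
#   prompts:
#     <prompt_name>:
#       system: <system prompt text>
#       json_schema: <optional JSON schema for guided decoding>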
class LlamaVLLM:
    """Chat client for a vLLM server exposing the OpenAI-compatible API."""

    def __init__(self, endpoint, model_id):
        self.model_id = model_id
        self.client = OpenAI(base_url=endpoint, api_key="token")

    def chat(
        self,
        inputs: List[Dict[str, str]],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        guided_decode_json_schema: Optional[str] = None,
    ) -> Optional[str]:
        if generation_kwargs is None:
            generation_kwargs = {}
        try:
            response = self.client.chat.completions.create(
                model=self.model_id,
                messages=inputs,
                extra_body={"guided_json": guided_decode_json_schema},
                **generation_kwargs,
            )
            output = response.choices[0].message.content
        except Exception as e:
            logger.error(
                f"FAILED to generate inference for input {inputs}\nError: {str(e)}"
            )
            output = None
        return output
class LlamaGroq:
    """Chat client for the Groq API with basic rate-limit retry handling."""

    def __init__(self, key, model_id):
        self.model_id = model_id
        self.client = groq.Groq(api_key=key)
        logger.debug(f"Using Groq:{self.model_id} for inference")

    def chat(
        self,
        inputs: List[Dict[str, str]],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        guided_decode_json_schema: Optional[str] = None,
    ) -> Optional[str]:
        if generation_kwargs is None:
            generation_kwargs = {}
        # Groq doesn't currently support guided JSON decoding; as a
        # workaround, append the schema to the system prompt instead.
        if guided_decode_json_schema is not None:
            inputs[0]["content"] += (
                f"\n\nEnsure your response aligns with the following JSON schema:"
                f"\n{guided_decode_json_schema}\n\n"
            )
        output = None
        while True:
            try:
                response = self.client.chat.completions.with_raw_response.create(
                    model=self.model_id,
                    messages=inputs,
                    stream=False,
                    response_format={
                        "type": "json_object"
                        if guided_decode_json_schema is not None
                        else "text"
                    },
                    **generation_kwargs,
                )
                completion = response.parse()
                output = completion.choices[0].message.content
                break
            except groq.RateLimitError as e:
                wait = e.response.headers["X-Ratelimit-Reset"]
                logger.warning(f"[groq] rate limited, waiting {wait} before retrying: {e}")
                # Assumes the header value is a number of seconds.
                time.sleep(float(wait))
            except Exception as e:
                logger.error(
                    f"INFERENCE FAILED with error: {e} for input:\n{inputs[-1]['content'][:300]}"
                )
                break
        return output
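# Both chat() methods take OpenAI-style chat messages, e.g.:
#
#   [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hello!"},
#   ]
#
# run_llm_inference() below builds this structure from the configured system
# prompt and the caller's input string(s).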
def run_llm_inference(
    prompt_name: str,
    inputs: Union[str, List[str]],
    generation_kwargs: Optional[Dict] = None,
    guided_decode_json_schema=None,
) -> Union[List[str], List[Dict[str, Any]]]:
    """
    Run LLM inference on the given inputs.

    Args:
    - prompt_name (str): The name of the prompt to use.
    - inputs (str or List[str]): The input(s) to the LLM.
    - generation_kwargs (Dict): Additional keyword arguments to pass to the LLM.
    - guided_decode_json_schema (str): The JSON schema to use for guided decoding.

    Returns:
    - The response(s) from the LLM: raw strings, or parsed JSON objects when a
      JSON schema is used. A single (non-list) input yields a single response.
    """
    # initialize the appropriate LLM accessor
    if CFG["model"]["use"] == "vllm":
        LLM = LlamaVLLM(**CFG["model"]["vllm"])
    elif CFG["model"]["use"] == "groq":
        LLM = LlamaGroq(**CFG["model"]["groq"])
    else:
        raise ValueError("Invalid model type in config.yaml")

    logger.debug(f"Running `{prompt_name}` inference with {CFG['model']['use']}")

    # wrap a single string input so it can be handled like a batch
    _batch = True
    if isinstance(inputs, str):
        _batch = False
        inputs = [inputs]

    # build chat messages from the configured system prompt and each input
    inputs = [
        [
            {"role": "system", "content": CFG["prompts"][prompt_name]["system"]},
            {"role": "user", "content": i},
        ]
        for i in inputs
    ]

    # fall back to the JSON schema configured for this prompt, if any
    if (
        guided_decode_json_schema is None
        and "json_schema" in CFG["prompts"][prompt_name]
    ):
        guided_decode_json_schema = " ".join(
            CFG["prompts"][prompt_name]["json_schema"].split()
        )

    responses = [
        LLM.chat(i, generation_kwargs, guided_decode_json_schema) for i in inputs
    ]

    # when guided decoding was requested, parse each response as JSON
    if guided_decode_json_schema is not None:
        responses_json = []
        for r in responses:
            if r is not None:
                try:
                    responses_json.append(json.loads(r, strict=False))
                    continue
                except json.JSONDecodeError:
                    logger.error(f"Error decoding JSON: {r}")
            responses_json.append(None)
        responses = responses_json

    if not _batch:
        responses = responses[0]

    return responses
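# Minimal usage sketch. The prompt name "summarize" below is hypothetical:
# it must correspond to an entry under `prompts:` in config.yaml.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    result = run_llm_inference("summarize", "Summarize: LLM inference helpers.")
    print(result)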