llm.py

import json
import logging
import time
from typing import Any, Dict, List, Optional, Union

import yaml
from openai import OpenAI
from tqdm import tqdm
import groq

logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler())

CFG = yaml.safe_load(open("config.yaml", "r"))
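
# Expected layout of config.yaml, inferred from how CFG is accessed in this
# module (all values below are illustrative placeholders, not real settings):
#
#   model:
#     use: vllm                 # or "groq"
#     vllm:
#       endpoint: http://localhost:8000/v1
#       model_id: <served model name>
#     groq:
#       key: <GROQ_API_KEY>
#       model_id: <Groq model name>
#   prompts:
#     <prompt_name>:
#       system: <system prompt text>
#       json_schema: <optional JSON schema for guided decoding>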


class LlamaVLLM:
    def __init__(self, endpoint, model_id):
        self.model_id = model_id
        self.client = OpenAI(base_url=endpoint, api_key='token')

    def chat(
        self,
        inputs: List[Dict[str, str]],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        guided_decode_json_schema: Optional[str] = None
    ) -> Optional[str]:
        if generation_kwargs is None:
            generation_kwargs = {}
        try:
            response = self.client.chat.completions.create(
                model=self.model_id,
                messages=inputs,
                extra_body={
                    "guided_json": guided_decode_json_schema
                },
                **generation_kwargs,
            )
            output = response.choices[0].message.content
        except Exception as e:
            logger.error(
                f"FAILED to generate inference for input {inputs}\nError: {str(e)}"
            )
            output = None
        return output


class LlamaGroq:
    def __init__(self, key, model_id):
        self.model_id = model_id
        self.client = groq.Groq(api_key=key)
        logger.debug(f"Using Groq:{self.model_id} for inference")

    def chat(
        self,
        inputs: List[Dict[str, str]],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        guided_decode_json_schema: Optional[str] = None
    ) -> Optional[str]:
        if generation_kwargs is None:
            generation_kwargs = {}

        # Groq doesn't currently support guided JSON decoding. Workaround:
        # append the schema to the system prompt and request JSON output
        # via response_format below.
        if guided_decode_json_schema is not None:
            inputs[0]['content'] += (
                "\n\nEnsure your response aligns with the following JSON schema:\n"
                f"{guided_decode_json_schema}\n\n"
            )

        output = None
        while True:
            try:
                response = self.client.chat.completions.with_raw_response.create(
                    model=self.model_id,
                    messages=inputs,
                    stream=False,
                    **generation_kwargs,
                    response_format={
                        "type": 'json_object' if guided_decode_json_schema is not None else 'text'
                    },
                )
                completion = response.parse()
                output = completion.choices[0].message.content
                break
            except groq.RateLimitError as e:
                # NOTE: assumes the rate-limit header holds a numeric number of
                # seconds; adjust the parsing if the API returns a duration string.
                wait = float(e.response.headers['X-Ratelimit-Reset'])
                logger.warning(f"[groq] rate limited ({e}); waiting {wait}s before retrying")
                time.sleep(wait)
            except Exception as e:
                logger.error(
                    f"INFERENCE FAILED with error: {e} for input:\n{inputs[-1]['content'][:300]}"
                )
                break
        return output


def run_llm_inference(
    prompt_name: str,
    inputs: Union[str, List[str]],
    generation_kwargs: Optional[Dict] = None,
    guided_decode_json_schema=None,
) -> Union[List[str], List[Dict[str, Any]]]:
    """
    Run LLM inference on the given input(s).

    Args:
        prompt_name (str): The name of the prompt (from config.yaml) to use.
        inputs (str or List[str]): The input(s) to the LLM.
        generation_kwargs (Dict): Additional keyword arguments to pass to the LLM.
        guided_decode_json_schema (str): The JSON schema to use for guided decoding.

    Returns:
        The response(s) from the LLM: a string (or a parsed JSON dict when a
        schema is supplied) for a single input, or a list of such values for
        a batch of inputs.
    """
    # initialize the appropriate LLM accessor
    if CFG['model']['use'] == 'vllm':
        LLM = LlamaVLLM(**CFG['model']['vllm'])
    elif CFG['model']['use'] == 'groq':
        LLM = LlamaGroq(**CFG['model']['groq'])
    else:
        raise ValueError("Invalid model type in config.yaml")

    logger.debug(f"Running `{prompt_name}` inference with {CFG['model']['use']}")

    # wrap a single input in a list so both cases share one code path
    _batch = True
    if isinstance(inputs, str):
        _batch = False
        inputs = [inputs]

    # build the chat messages: system prompt from config + user input
    inputs = [
        [
            {"role": "system", "content": CFG["prompts"][prompt_name]["system"]},
            {"role": "user", "content": i},
        ]
        for i in inputs
    ]

    # fall back to the prompt's JSON schema from config, collapsing whitespace
    if (
        guided_decode_json_schema is None
        and "json_schema" in CFG["prompts"][prompt_name]
    ):
        guided_decode_json_schema = " ".join(
            CFG["prompts"][prompt_name]["json_schema"].split()
        )

    responses = [
        LLM.chat(i, generation_kwargs, guided_decode_json_schema)
        for i in tqdm(inputs, desc=f"Inference[{prompt_name}]")
    ]

    # when guided decoding was requested, parse each response as JSON
    if guided_decode_json_schema is not None:
        responses_json = []
        for r in responses:
            if r is not None:
                try:
                    responses_json.append(json.loads(r, strict=False))
                    continue
                except json.JSONDecodeError:
                    logger.error(f"Error decoding JSON: {r}")
            responses_json.append(None)
        responses = responses_json

    if not _batch:
        responses = responses[0]

    return responses
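

# Minimal usage sketch. NOTE: "summarize" is a hypothetical prompt name used
# only for illustration; it must correspond to an entry under `prompts` in
# config.yaml for this to run.
if __name__ == "__main__":
    single = run_llm_inference("summarize", "Explain what this module does.")
    print(single)

    batch = run_llm_inference(
        "summarize",
        ["First document text.", "Second document text."],
        generation_kwargs={"temperature": 0.0, "max_tokens": 512},
    )
    print(batch)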