import json
import logging
import time
from typing import Any, Dict, List, Optional, Union

import groq
import yaml
from openai import OpenAI
from tqdm import tqdm
 
logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler())

with open("config.yaml", "r") as f:
    CFG = yaml.safe_load(f)
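# Expected config.yaml layout, inferred from the accesses below (a sketch,
# not an authoritative schema; the endpoint URL is an example value):
#
#   model:
#     use: vllm                # or: groq
#     vllm:
#       endpoint: http://localhost:8000/v1
#       model_id: <served model name>
#     groq:
#       key: <API key>
#       model_id: <model name>
#   prompts:
#     <prompt_name>:
#       system: <system prompt text>
#       json_schema: <optional JSON schema for guided decoding>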
 
class LlamaVLLM:
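    """Chat client for a vLLM server exposing the OpenAI-compatible API."""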
 
    def __init__(self, endpoint, model_id):
        self.model_id = model_id
        # Placeholder API key; an OpenAI-compatible vLLM endpoint typically
        # does not validate it unless the server is configured to.
        self.client = OpenAI(base_url=endpoint, api_key='token')
 
    def chat(
        self,
        inputs: List[Dict[str, str]],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        guided_decode_json_schema: Optional[str] = None,
    ) -> Optional[str]:
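        """Send a single chat request to the vLLM endpoint.

        Returns the generated message content, or None if the request failed.
        """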
 
        if generation_kwargs is None:
            generation_kwargs = {}
 
        try:
            response = self.client.chat.completions.create(
                model=self.model_id,
                messages=inputs,
                extra_body={
                    "guided_json": guided_decode_json_schema
                },
                **generation_kwargs,
            )
            output = response.choices[0].message.content
        except Exception as e:
            logger.error(
                f"FAILED to generate inference for input {inputs}\nError: {str(e)}"
            )
            output = None
        return output
 
class LlamaGroq:
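    """Chat client for the Groq API with basic rate-limit retry."""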
 
    def __init__(self, key, model_id):
        self.model_id = model_id
        self.client = groq.Groq(api_key=key)
        logger.debug(f"Using Groq:{self.model_id} for inference")
 
    def chat(
        self,
        inputs: List[Dict[str, str]],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        guided_decode_json_schema: Optional[str] = None,
    ) -> Optional[str]:
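        """Send a single chat request to Groq, retrying when rate-limited.

        Returns the generated message content, or None if the request failed.
        """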
 
        if generation_kwargs is None:
            generation_kwargs = {}

        # Groq doesn't currently support guided JSON decoding. As a workaround,
        # append the schema to the system message and request a JSON object
        # via response_format below.
        if guided_decode_json_schema is not None:
            inputs[0]['content'] += (
                "\n\nEnsure your response aligns with the following JSON schema:\n"
                f"{guided_decode_json_schema}\n\n"
            )

        output = None
 
        while True:
            try:
                response = self.client.chat.completions.with_raw_response.create(
                    model=self.model_id,
                    messages=inputs,
                    stream=False,
                    **generation_kwargs,
                    response_format={
                        "type": "json_object"
                        if guided_decode_json_schema is not None
                        else "text"
                    },
                )
                completion = response.parse()
                output = completion.choices[0].message.content
                break
            except groq.RateLimitError as e:
                # Sleep for the server-suggested interval; header values are
                # strings, so coerce to seconds (fall back to 30s if absent).
                wait = float(e.response.headers.get("retry-after", 30))
                logger.warning(f"[groq] rate limited, waiting {wait}s before retrying: {e}")
                time.sleep(wait)
            except Exception as e:
                # Arbitrary exceptions may not carry a `.response`, so log the
                # exception itself rather than a status code.
                logger.error(
                    f"INFERENCE FAILED with Error: {e} for input:\n{inputs[-1]['content'][:300]}"
                )
                break
        return output
 
def run_llm_inference(
    prompt_name: str,
    inputs: Union[str, List[str]],
    generation_kwargs: Optional[Dict] = None,
    guided_decode_json_schema: Optional[str] = None,
) -> Union[str, Dict[str, Any], List[str], List[Dict[str, Any]], None]:
    """
    Run LLM inference on the given input(s).

    Args:
    - prompt_name (str): The name of the prompt (under `prompts` in config.yaml).
    - inputs (str or List[str]): The input(s) to the LLM.
    - generation_kwargs (Dict): Additional keyword arguments to pass to the LLM.
    - guided_decode_json_schema (str): The JSON schema to use for guided decoding.

    Returns:
    - The response(s) from the LLM: a list for batched inputs, a single value
      for a single string input; parsed into dict(s) when a schema is in effect.
    """
 
    # Initialize the appropriate LLM accessor.
    if CFG['model']['use'] == 'vllm':
        LLM = LlamaVLLM(**CFG['model']['vllm'])
    elif CFG['model']['use'] == 'groq':
        LLM = LlamaGroq(**CFG['model']['groq'])
    else:
        raise ValueError("Invalid model type in config.yaml")
 
    logger.debug(f"Running `{prompt_name}` inference with {CFG['model']['use']}")

    # Accept a single string as well as a batch; remember which, so the
    # return shape can mirror the input shape.
    _batch = True
    if isinstance(inputs, str):
        _batch = False
        inputs = [inputs]
 
    # Wrap each input in a two-message conversation carrying the prompt's
    # system message.
    inputs = [
        [
            {"role": "system", "content": CFG["prompts"][prompt_name]["system"]},
            {"role": "user", "content": i},
        ]
        for i in inputs
    ]
 
    # Fall back to the schema configured for this prompt, collapsing any
    # whitespace so it fits on one line.
    if (
        guided_decode_json_schema is None
        and "json_schema" in CFG["prompts"][prompt_name]
    ):
        guided_decode_json_schema = " ".join(
            CFG["prompts"][prompt_name]["json_schema"].split()
        )
 
    responses = [
        LLM.chat(i, generation_kwargs, guided_decode_json_schema)
        for i in tqdm(inputs, desc=f"Inference[{prompt_name}]")
    ]
 
    # When a schema is in effect, parse each response as JSON; failed parses
    # become None so indices stay aligned with the inputs.
    if guided_decode_json_schema is not None:
        responses_json = []
        for r in responses:
            if r is not None:
                try:
                    responses_json.append(json.loads(r, strict=False))
                    continue
                except json.JSONDecodeError:
                    logger.error(f"Error decoding JSON: {r}")
            responses_json.append(None)
        responses = responses_json
 
    if not _batch:
        responses = responses[0]
    return responses
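
if __name__ == "__main__":
    # Minimal usage sketch. Assumes config.yaml defines a prompt named
    # "example" (a hypothetical name) whose `system` message is set as
    # sketched above; generation_kwargs pass through to the chat API.
    result = run_llm_inference(
        "example",
        "Summarize the plot of Hamlet in one sentence.",
        generation_kwargs={"temperature": 0.0, "max_tokens": 256},
    )
    print(result)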
 
 