import logging
from typing import Any, Dict, List, Optional, Union
import yaml
import time
import json

from tqdm import tqdm
from openai import OpenAI
import groq

logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler())

with open("config.yaml", "r") as f:
    CFG = yaml.safe_load(f)


class LlamaVLLM():
    def __init__(self, endpoint, model_id):
        self.model_id = model_id
        self.client = OpenAI(base_url=endpoint, api_key='token')

    def chat(
        self,
        inputs: List[Dict[str, str]],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        guided_decode_json_schema: Optional[str] = None
    ) -> Optional[str]:
        if generation_kwargs is None:
            generation_kwargs = {}

        try:
            response = self.client.chat.completions.create(
                model=self.model_id,
                messages=inputs,
                extra_body={
                    "guided_json": guided_decode_json_schema
                },
                **generation_kwargs,
            )
            output = response.choices[0].message.content
        except Exception as e:
            logger.error(
                f"FAILED to generate inference for input {inputs}\nError: {str(e)}"
            )
            output = None
        return output


class LlamaGroq():
    def __init__(self, key, model_id):
        self.model_id = model_id
        self.client = groq.Groq(api_key=key)
        logger.debug(f"Using Groq:{self.model_id} for inference")

    def chat(
        self,
        inputs: List[Dict[str, str]],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        guided_decode_json_schema: Optional[str] = None
    ) -> Optional[str]:
        if generation_kwargs is None:
            generation_kwargs = {}

        # Groq doesn't currently support guided JSON decoding. Workaround:
        # append the schema to the system prompt and request a JSON response.
        if guided_decode_json_schema is not None:
            inputs[0]['content'] += (
                "\n\nEnsure your response aligns with the following JSON schema:"
                f"\n{guided_decode_json_schema}\n\n"
            )

        output = None

        while True:
            try:
                response = self.client.chat.completions.with_raw_response.create(
                    model=self.model_id,
                    messages=inputs,
                    stream=False,
                    **generation_kwargs,
                    response_format={
                        "type": 'json_object' if guided_decode_json_schema is not None else 'text'
                    }
                )
                completion = response.parse()
                output = completion.choices[0].message.content
                break
            except groq.RateLimitError as e:
                # The reset header is a string; fall back to a fixed delay if it
                # is missing or cannot be parsed as a number of seconds.
                wait = e.response.headers.get('X-Ratelimit-Reset', '30')
                try:
                    wait_seconds = float(wait)
                except ValueError:
                    wait_seconds = 30.0
                print(e)
                print(f"[groq] waiting for {wait_seconds}s to prevent ratelimiting")
                time.sleep(wait_seconds)
            except Exception as e:
                logger.error(
                    f"INFERENCE FAILED with Error: {e} for input:\n{inputs[-1]['content'][:300]}"
                )
                break

        return output


def run_llm_inference(
    prompt_name: str,
    inputs: Union[str, List[str]],
    generation_kwargs: Optional[Dict] = None,
    guided_decode_json_schema: Optional[str] = None,
) -> Union[str, List[str], Dict[str, Any], List[Dict[str, Any]], None]:
    """
    Run LLM inference on the given inputs.

    Args:
    - prompt_name (str): The name of the prompt to use.
    - inputs (str or List[str]): The input(s) to the LLM.
    - generation_kwargs (Dict): Additional keyword arguments to pass to the LLM.
    - guided_decode_json_schema (str): The JSON schema to use for guided decoding.

    Returns:
    - The response(s) from the LLM: a string (or a parsed JSON dict when a schema
      is in effect) for a single input, or a list of them for a batch of inputs.
    """
    # Initialize the appropriate LLM accessor based on config.yaml.
    if CFG['model']['use'] == 'vllm':
        LLM = LlamaVLLM(**CFG['model']['vllm'])
    elif CFG['model']['use'] == 'groq':
        LLM = LlamaGroq(**CFG['model']['groq'])
    else:
        raise ValueError("Invalid model type in config.yaml")

    logger.debug(f"Running `{prompt_name}` inference with {CFG['model']['use']}")

    # Accept a single string as well as a batch of strings.
    _batch = True
    if isinstance(inputs, str):
        _batch = False
        inputs = [inputs]

    # Build chat messages: the configured system prompt plus each user input.
    inputs = [
        [
            {"role": "system", "content": CFG["prompts"][prompt_name]["system"]},
            {"role": "user", "content": i},
        ]
        for i in inputs
    ]

    # Fall back to the prompt's configured JSON schema, whitespace-normalized.
    if (
        guided_decode_json_schema is None
        and "json_schema" in CFG["prompts"][prompt_name]
    ):
        guided_decode_json_schema = " ".join(
            CFG["prompts"][prompt_name]["json_schema"].split()
        )

    responses = [
        LLM.chat(i, generation_kwargs, guided_decode_json_schema)
        for i in tqdm(inputs, desc=f"Inference[{prompt_name}]")
    ]

    # When guided decoding was requested, parse each response as JSON.
    if guided_decode_json_schema is not None:
        responses_json = []
        for r in responses:
            if r is not None:
                try:
                    responses_json.append(json.loads(r, strict=False))
                    continue
                except json.JSONDecodeError:
                    logger.error(f"Error decoding JSON: {r}")
            responses_json.append(None)
        responses = responses_json

    if not _batch:
        responses = responses[0]

    return responses
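

# The sketch below is a minimal, illustrative usage example, not part of the
# module's actual entry points. The config layout it assumes (model.use,
# model.vllm / model.groq, prompts.<name>.system and prompts.<name>.json_schema)
# mirrors how CFG is read above; the prompt name "summarize_issue", the schema,
# and the model IDs are hypothetical placeholders.
if __name__ == "__main__":
    # Assumed config.yaml shape (inferred from the accesses above):
    #
    # model:
    #   use: groq                     # or "vllm"
    #   vllm:
    #     endpoint: http://localhost:8000/v1
    #     model_id: <served model name>
    #   groq:
    #     key: <GROQ_API_KEY>
    #     model_id: <groq model name>
    # prompts:
    #   summarize_issue:
    #     system: You are a helpful triage assistant ...
    #     json_schema: '{"type": "object", "properties": {"summary": {"type": "string"}}}'

    # Single input -> single response (a parsed dict here, since the prompt
    # defines a json_schema).
    single = run_llm_inference(
        "summarize_issue",
        "Crash on startup when config.yaml is missing.",
        generation_kwargs={"temperature": 0.0, "max_tokens": 512},
    )
    print(single)

    # A list of inputs returns a list of responses in the same order.
    batch = run_llm_inference(
        "summarize_issue",
        ["Issue A: flaky test in CI.", "Issue B: docs typo in README."],
        generation_kwargs={"temperature": 0.0},
    )
    print(batch)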