# utils.py

  1. """
  2. Utility functions for structured data extraction.
  3. This module provides helper functions for working with JSON schemas, encoding images,
  4. extracting structured data from LLM responses, and logging.
  5. """
  6. import ast
  7. import base64
  8. import json
  9. import logging
  10. import os
  11. import re
  12. from pathlib import Path
  13. from typing import Any, Dict, List, Optional, Union
  14. import pandas as pd
  15. import pymupdf
  16. import yaml
  17. from openai import OpenAI
  18. from typedicts import InferenceRequest, VLLMInferenceRequest
  19. from vllm import LLM, SamplingParams
  20. from vllm.sampling_params import GuidedDecodingParams


def setup_logger(logfile, verbose=False):
    """
    Configure a logger that writes DEBUG-level records to a file.

    Args:
        logfile: Path to the log file
        verbose: If True, also echo log records to the console

    Returns:
        logging.Logger: The configured logger
    """
    # Create a logger
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    # Create a file handler
    file_handler = logging.FileHandler(logfile)
    file_handler.setLevel(logging.DEBUG)
    # Create a formatter and set it for the file handler
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler.setFormatter(formatter)
    # Add the file handler to the logger
    logger.addHandler(file_handler)
    # If verbose, also add a console handler
    if verbose:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.DEBUG)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    return logger
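
# Example usage (illustrative; the log file name is hypothetical):
#
#   logger = setup_logger("extraction.log", verbose=True)
#   logger.info("Starting extraction run")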

logger = logging.getLogger(__name__)

# Compile regex patterns once for better performance
JSON_BLOCK_OPEN = re.compile(r"```json")
JSON_BLOCK_CLOSE = re.compile(r"}\s+```")


# Configuration management
def load_config(config_path: Optional[str] = None) -> Dict[str, Any]:
    """
    Load configuration from YAML file.

    Args:
        config_path: Path to the configuration file. If None, uses default path.

    Returns:
        Dict containing configuration values

    Raises:
        FileNotFoundError: If the configuration file doesn't exist
        yaml.YAMLError: If the configuration file is invalid
    """
    if config_path is None:
        config_path = os.path.join(os.path.dirname(__file__), "config.yaml")
    try:
        with open(config_path, "r") as f:
            return yaml.safe_load(f)
    except FileNotFoundError:
        logger.error(f"Configuration file not found: {config_path}")
        raise
    except yaml.YAMLError as e:
        logger.error(f"Invalid YAML in configuration file: {e}")
        raise

# Load configuration
config = load_config()
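
# Illustrative config.yaml shape, inferred from the keys read in this module
# (model.path, tensor_parallel_size, max_model_len, max_num_seqs, base_url,
# api_key, model_id). The values below are placeholders, not real settings:
#
#   model:
#     path: /path/to/local/model
#     tensor_parallel_size: 1
#     max_model_len: 8192
#     max_num_seqs: 16
#     base_url: http://localhost:8000/v1
#     api_key: EMPTY
#     model_id: null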


# LLM Singleton
class LLMSingleton:
    """Singleton class for managing LLM instances."""

    _instance = None

    @classmethod
    def get_instance(cls) -> LLM:
        """
        Get or create the LLM instance.

        Returns:
            LLM: An initialized VLLM model instance
        """
        if cls._instance is None:
            try:
                cls._instance = LLM(
                    config["model"]["path"],
                    tensor_parallel_size=config["model"]["tensor_parallel_size"],
                    max_model_len=config["model"]["max_model_len"],
                    max_num_seqs=config["model"]["max_num_seqs"],
                )
                logger.info(f"Initialized LLM with model: {config['model']['path']}")
            except Exception as e:
                logger.error(f"Failed to initialize LLM: {e}")
                raise
        return cls._instance


class ImageUtils:
    """Utility functions for working with images."""

    @staticmethod
    def encode_image(image_path: Union[Path, str]) -> str:
        """
        Encode an image to base64.

        Args:
            image_path: Path to the image file

        Returns:
            Base64-encoded string representation of the image

        Raises:
            FileNotFoundError: If the image file doesn't exist
        """
        if isinstance(image_path, str):
            image_path = Path(image_path)
        try:
            return base64.b64encode(image_path.read_bytes()).decode("utf-8")
        except FileNotFoundError:
            logger.error(f"Image file not found: {image_path}")
            raise
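
# Example usage (illustrative; the image path is hypothetical), mirroring how
# request_builder below embeds a page image as a data URL:
#
#   img_b64 = ImageUtils.encode_image("/tmp/pdf_images/report_0.png")
#   image_url = {"url": f"data:image/png;base64,{img_b64}"}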


class JSONUtils:
    """Utility functions for working with JSON data."""

    @staticmethod
    def extract_json_blocks(content: str) -> List[str]:
        """
        Extract JSON code blocks from markdown-formatted text.

        Parses a string containing markdown-formatted text and extracts all JSON blocks
        that are enclosed in ```json ... ``` code blocks. This is useful for extracting
        structured data from LLM responses.

        Args:
            content: The markdown-formatted text containing JSON code blocks

        Returns:
            List[str]: A list of extracted JSON strings (without the markdown delimiters)
        """
        blocks_ix = []
        str_ptr = 0
        while str_ptr < len(content):
            start_ix = content.find("```json", str_ptr)
            if start_ix == -1:
                break
            start_ix += len("```json")
            end_match = JSON_BLOCK_CLOSE.search(content[start_ix:])
            if end_match:
                end_ix = start_ix + end_match.start() + 1
            else:
                end_ix = len(content)  # no closing tag, take the rest of the string
            blocks_ix.append((start_ix, end_ix))
            str_ptr = end_ix + 1
        return [content[ix[0] : ix[1]].strip() for ix in blocks_ix]

    @staticmethod
    def load_json_from_str(json_str: str) -> Dict[str, Any]:
        """
        Parse a JSON string into a Python dictionary.

        Attempts to parse a string as JSON using multiple methods. First tries standard
        json.loads(), then falls back to ast.literal_eval() if that fails. This provides
        more robust JSON parsing for LLM outputs that might not be perfectly formatted.

        Args:
            json_str: The JSON string to parse

        Returns:
            Dict[str, Any]: The parsed JSON as a dictionary

        Raises:
            ValueError: If parsing fails
        """
        if not isinstance(json_str, str):
            return json_str
        try:
            return json.loads(json_str)
        except json.decoder.JSONDecodeError:
            # Fall back to Python literal parsing; map JSON null to Python None first
            json_str = json_str.replace("null", "None")
            try:
                return ast.literal_eval(json_str)
            except Exception:
                raise ValueError(f"Failed to load valid JSON from string: {json_str}")

    @staticmethod
    def extract_json_from_response(content: str) -> Dict[str, Any]:
        """
        Extract and parse JSON from an LLM response.

        Processes a response from an LLM that may contain JSON in a markdown code block.
        First checks if the response contains markdown-formatted JSON blocks and extracts
        them, then parses the JSON string into a Python dictionary.

        Args:
            content: The LLM response text that may contain JSON

        Returns:
            Dict[str, Any]: The parsed JSON as a dictionary

        Raises:
            ValueError: If extraction or parsing fails
        """
        try:
            if "```json" in content:
                json_blocks = JSONUtils.extract_json_blocks(content)
                if not json_blocks:
                    raise ValueError("No JSON blocks found in response")
                content = json_blocks[-1]
            return JSONUtils.load_json_from_str(content)
        except Exception as e:
            raise ValueError(f"Failed to extract JSON from response: {str(e)}")
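
    # Example usage (illustrative; the reply text is hypothetical):
    #
    #   reply = 'Here you go:\n```json\n{"invoice_total": 42.5}\n```'
    #   JSONUtils.extract_json_from_response(reply)  # -> {"invoice_total": 42.5}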

    @staticmethod
    def make_all_fields_required(schema: Dict[str, Any]) -> None:
        """
        Make all fields in a JSON schema required.

        Recursively modifies the JSON schema in-place, so that every property in each
        'properties' block is added to the 'required' list at that schema level. This
        ensures that the LLM will attempt to extract all fields defined in the schema.

        Args:
            schema: The JSON schema to modify
        """

        def _process_schema_node(subschema):
            """Process a single node in the schema."""
            if not isinstance(subschema, dict):
                return
            schema_type = subschema.get("type")
            if schema_type == "object" or (
                isinstance(schema_type, list) and "object" in schema_type
            ):
                props = subschema.get("properties")
                if isinstance(props, dict):
                    subschema["required"] = list(props.keys())
            # Recurse into sub-schemas
            for key in ("properties", "definitions", "patternProperties"):
                children = subschema.get(key)
                if isinstance(children, dict):
                    for v in children.values():
                        _process_schema_node(v)
            # Recurse into schema arrays
            for key in ("allOf", "anyOf", "oneOf"):
                children = subschema.get(key)
                if isinstance(children, list):
                    for v in children:
                        _process_schema_node(v)
            # 'items' can be a schema or list of schemas
            items = subschema.get("items")
            if isinstance(items, dict):
                _process_schema_node(items)
            elif isinstance(items, list):
                for v in items:
                    _process_schema_node(v)
            # Extras: 'not', 'if', 'then', 'else'
            for key in ["not", "if", "then", "else"]:
                sub = subschema.get(key)
                if isinstance(sub, dict):
                    _process_schema_node(sub)

        _process_schema_node(schema)
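
    # Example usage (illustrative): after the call, every property is required.
    #
    #   schema = {
    #       "type": "object",
    #       "properties": {"name": {"type": "string"}, "total": {"type": "number"}},
    #   }
    #   JSONUtils.make_all_fields_required(schema)
    #   schema["required"]  # -> ["name", "total"]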


class PDFUtils:
    """Utility functions for working with PDF files."""

    @staticmethod
    def extract_pages(
        pdf_path: Union[str, Path], output_dir: Optional[Union[str, Path]] = None
    ) -> List[Dict[str, Any]]:
        """
        Extract pages from a PDF file as images saved to disk.

        Args:
            pdf_path: Path to the PDF file
            output_dir: Directory to save extracted images (defaults to /tmp/pdf_images)

        Returns:
            List of dictionaries containing doc_path, image_path, and page_num

        Raises:
            FileNotFoundError: If the PDF file doesn't exist
        """
        if isinstance(pdf_path, str):
            pdf_path = Path(pdf_path)
        if not pdf_path.exists():
            logger.error(f"PDF file not found: {pdf_path}")
            raise FileNotFoundError(f"PDF file not found: {pdf_path}")
        stem = pdf_path.stem
        if output_dir is None:
            output_dir = Path("/tmp/pdf_images")
        elif isinstance(output_dir, str):
            output_dir = Path(output_dir)
        output_dir.mkdir(exist_ok=True, parents=True)
        pages = []
        try:
            pdf_document = pymupdf.open(pdf_path)
            for page_num, page in enumerate(pdf_document):
                image_path = output_dir / f"{stem}_{page_num}.png"
                pix = page.get_pixmap(dpi=100)
                pix.save(str(image_path))
                pages.append(
                    {
                        "doc_path": str(pdf_path),
                        "image_path": str(image_path),
                        "page_num": page_num,
                    }
                )
            pdf_document.close()
            return pages
        except Exception as e:
            logger.error(f"Failed to extract pages from PDF: {e}")
            raise
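
# Example usage (illustrative; the file names are hypothetical):
#
#   pages = PDFUtils.extract_pages("report.pdf", output_dir="/tmp/pdf_images")
#   pages[0]["image_path"]  # -> "/tmp/pdf_images/report_0.png"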


class InferenceUtils:
    """Utility functions for running inference with LLMs."""

    @staticmethod
    def get_offline_llm() -> LLM:
        """
        Initialize and return a local LLM instance using the singleton pattern.

        Returns:
            LLM: An initialized VLLM model instance
        """
        return LLMSingleton.get_instance()

    @staticmethod
    def make_vllm_batch(
        request_params_batch: Union[InferenceRequest, List[InferenceRequest]],
    ) -> VLLMInferenceRequest:
        """
        Convert one or more inference requests to VLLM batch format.

        Args:
            request_params_batch: Single request parameters or a list of request parameters

        Returns:
            VLLMInferenceRequest: Formatted request for VLLM
        """
        if isinstance(request_params_batch, dict):
            request_params_batch = [request_params_batch]
        sampling_params = []
        messages = []
        for req in request_params_batch:
            params = {
                "top_p": req["top_p"],
                "temperature": req["temperature"],
                "max_tokens": req["max_completion_tokens"],
                "seed": req["seed"],
            }
            if "response_format" in req:
                gd_params = GuidedDecodingParams(
                    json=req["response_format"]["json_schema"]["schema"]
                )
                sampling_params.append(
                    SamplingParams(guided_decoding=gd_params, **params)
                )
            else:
                sampling_params.append(SamplingParams(**params))
            messages.append(req["messages"])
        return {"messages": messages, "sampling_params": sampling_params}

    @staticmethod
    def run_vllm_inference(
        vllm_request: VLLMInferenceRequest,
    ) -> List[str]:
        """
        Run inference on a batch of requests using the local LLM.

        This function processes one or more requests through the local LLM,
        handling the conversion to VLLM format and extracting the raw text
        responses.

        Args:
            vllm_request: Formatted request for VLLM

        Returns:
            List[str]: Raw text responses from the LLM for each request in the batch
        """
        try:
            local_llm = InferenceUtils.get_offline_llm()
            out = local_llm.chat(
                vllm_request["messages"], vllm_request["sampling_params"], use_tqdm=True
            )
            raw_responses = [r.outputs[0].text for r in out]
            return raw_responses
        except Exception as e:
            logger.error(f"VLLM inference failed: {e}")
            raise

    @staticmethod
    def run_openai_inference(request: InferenceRequest) -> str:
        """
        Run inference using an OpenAI-compatible API.

        Args:
            request: Inference request parameters

        Returns:
            str: Model response text
        """
        try:
            client = OpenAI(
                base_url=config["model"]["base_url"], api_key=config["model"]["api_key"]
            )
            model_id = config["model"]["model_id"] or client.models.list().data[0].id
            r = client.chat.completions.create(model=model_id, **request)
            return r.choices[0].message.content
        except Exception as e:
            logger.error(f"OpenAI inference failed: {e}")
            raise

    @staticmethod
    def request_builder(
        user_prompt: str,
        system_prompt: Optional[str] = None,
        img_path: Optional[str] = None,
        use_json_decoding: bool = False,
        output_schema: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> InferenceRequest:
        """
        Build an inference request from a prompt, an optional image, and an optional schema.

        Args:
            user_prompt: The user prompt text
            system_prompt: Optional system prompt prepended to the messages
            img_path: Optional path to an image included as a base64 data URL
            use_json_decoding: If True, attach a json_schema response_format for guided decoding
            output_schema: JSON schema to use when use_json_decoding is True
            **kwargs: Additional request parameters passed through unchanged

        Returns:
            InferenceRequest: The assembled request

        Raises:
            FileNotFoundError: If img_path is provided but the file doesn't exist
        """
        request = kwargs
        msgs = []
        if system_prompt:
            msgs.append({"role": "system", "content": system_prompt})
        user_content = []
        if img_path:
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"Image file not found: {img_path}")
            img_b64 = ImageUtils.encode_image(img_path)
            user_content.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{img_b64}"},
                }
            )
        user_content.append({"type": "text", "text": user_prompt})
        msgs.append({"role": "user", "content": user_content})
        request["messages"] = msgs
        if use_json_decoding:
            request["response_format"] = {
                "type": "json_schema",
                "json_schema": {"name": "OutputSchema", "schema": output_schema},
            }
        return request
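
    # Example usage (illustrative; prompts, paths, schema, and sampling values are hypothetical):
    #
    #   request = InferenceUtils.request_builder(
    #       user_prompt="Extract the invoice fields from this page.",
    #       system_prompt="You are a data extraction assistant.",
    #       img_path="/tmp/pdf_images/report_0.png",
    #       use_json_decoding=True,
    #       output_schema=schema,
    #       top_p=0.9,
    #       temperature=0.0,
    #       max_completion_tokens=1024,
    #       seed=42,
    #   )
    #   text = InferenceUtils.run_openai_inference(request)
    #   data = JSONUtils.extract_json_from_response(text)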


def export_csvs_to_excel_tabs(csv_folder_path, output_excel_path):
    """
    Export multiple CSV files from a specified folder into a single Excel
    workbook, with each CSV appearing as a separate tab (sheet).

    Args:
        csv_folder_path (str): The path to the folder containing the CSV files.
        output_excel_path (str): The desired path for the output Excel file.
    """
    try:
        # Create an ExcelWriter object
        with pd.ExcelWriter(output_excel_path, engine="xlsxwriter") as writer:
            # Iterate through all files in the specified folder
            for filename in os.listdir(csv_folder_path):
                if filename.endswith(".csv"):
                    csv_file_path = os.path.join(csv_folder_path, filename)
                    # Excel sheet names are limited to 31 characters
                    sheet_name = os.path.splitext(filename)[0][:31]
                    # Read the CSV file into a pandas DataFrame
                    df = pd.read_csv(csv_file_path)
                    # Write the DataFrame to a new sheet in the Excel file
                    df.to_excel(writer, sheet_name=sheet_name, index=False)
        logger.info(f"Successfully exported CSV files to '{output_excel_path}'")
    except Exception as e:
        logger.error(f"Failed to export CSV files to Excel: {e}")