  1. """
  2. Structured data extraction module for processing images with LLMs.
  3. This module provides functionality to extract structured data from images using
  4. local or API-based LLMs. It handles the preparation of requests, batching for
  5. efficient inference, and parsing of responses into structured formats.
  6. """
  7. import json
  8. import logging
  9. import os
  10. from datetime import datetime
  11. from pathlib import Path
  12. from typing import Any, Dict, List, Optional, Tuple, Union
  13. import fire
  14. from json_to_table import flatten_json_to_sql, json_to_csv
  15. from tqdm import tqdm
  16. from typedicts import ArtifactCollection, ExtractedPage, InferenceRequest
  17. from utils import (
  18. config,
  19. export_csvs_to_excel_tabs,
  20. ImageUtils,
  21. InferenceUtils,
  22. JSONUtils,
  23. PDFUtils,
  24. )
  25. # Constants
  26. EXTRACTED_DATA_KEY = "extracted_data"
  27. SUPPORTED_BACKENDS = ["offline-vllm", "openai-compat"]
  28. SUPPORTED_FILE_TYPES = [".pdf"]
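
# For reference, a minimal sketch of the config shape this module reads
# (illustrative only; the authoritative schema ships with `config` in utils):
#
#   config = {
#       "model": {"backend": "openai-compat"},  # or "offline-vllm"
#       "extraction_inference": {...},          # base kwargs for each request
#       "artifacts": {
#           "tables": {
#               "prompts": {"system": "...", "user": "... {schema} ..."},
#               "output_schema": {...},
#               "use_json_decoding": True,
#           },
#           # "text", "charts", ... follow the same shape
#       },
#       "database": {"sql_db_path": "...", "vector_db_path": "..."},
#   }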


def setup_logger(logfile: str, verbose: bool = False) -> logging.Logger:
    """
    Set up a logger for the application with file and optional console output.

    Args:
        logfile: Path to the log file
        verbose: If True, also log to console

    Returns:
        Configured logger instance
    """
    # Create a logger
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    # Create a file handler
    file_handler = logging.FileHandler(logfile)
    file_handler.setLevel(logging.DEBUG)

    # Create a formatter and set it for the file handler
    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )
    file_handler.setFormatter(formatter)

    # Add the file handler to the logger
    logger.addHandler(file_handler)

    # If verbose, also add a console handler
    if verbose:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.DEBUG)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    return logger


logger = setup_logger("app.log", verbose=False)


class RequestBuilder:
    """Builder for LLM inference requests."""

    @staticmethod
    def build(
        img_path: str,
        system_prompt: str,
        user_prompt: str,
        output_schema: Dict[str, Any],
        use_json_decoding: bool = False,
        model: Optional[str] = None,
    ) -> InferenceRequest:
        """
        Build an inference request for an image.

        Args:
            img_path: Path to the image file
            system_prompt: System prompt for the LLM
            user_prompt: User prompt for the LLM
            output_schema: JSON schema for the output
            use_json_decoding: Whether to use JSON-guided decoding
            model: Optional model override

        Returns:
            InferenceRequest: Formatted request for the LLM

        Raises:
            FileNotFoundError: If the image file doesn't exist
        """
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")
        img_b64 = ImageUtils.encode_image(img_path)

        # Create a copy of the inference config to avoid modifying the original
        request_params = dict(config["extraction_inference"])
        request_params["messages"] = [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{img_b64}"},
                    },
                    {"type": "text", "text": user_prompt},
                ],
            },
        ]
        if use_json_decoding:
            request_params["response_format"] = {
                "type": "json_schema",
                "json_schema": {"name": "OutputSchema", "schema": output_schema},
            }
        if model:
            request_params["model"] = model
        return request_params
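
# Example (illustrative) of using the builder for a single page image; the
# path, prompts, and schema below are placeholders:
#
#   request = RequestBuilder.build(
#       img_path="page_1.png",
#       system_prompt="You extract structured data from document images.",
#       user_prompt="Extract every table using this schema: {schema}",
#       output_schema={"type": "object", "properties": {}},
#       use_json_decoding=True,
#   )
#
# The result is a copy of config["extraction_inference"] with OpenAI-style
# "messages" (and, when requested, a json_schema "response_format") attached.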


class ArtifactExtractor:
    """Extractor for document artifacts."""

    @staticmethod
    def _prepare_inference_requests(
        img_path: str, artifact_types: List[str]
    ) -> List[Tuple[str, InferenceRequest]]:
        """
        Prepare inference requests for each artifact type.

        Args:
            img_path: Path to the image file
            artifact_types: Types of artifacts to extract

        Returns:
            List of tuples containing (artifact_type, inference_request)
        """
        requests = []
        for artifact in artifact_types:
            art_config = config["artifacts"].get(artifact)
            if not art_config:
                logger.warning(f"No configuration found for artifact type: {artifact}")
                continue
            system_prompt = art_config["prompts"].get("system", "")
            user_prompt = art_config["prompts"].get("user", "")
            output_schema = art_config.get("output_schema", None)
            use_json_decoding = art_config.get("use_json_decoding", False)
            if user_prompt and output_schema is not None:
                user_prompt = user_prompt.format(schema=json.dumps(output_schema))
            request = RequestBuilder.build(
                img_path,
                system_prompt,
                user_prompt,
                output_schema,
                use_json_decoding,
            )
            requests.append((artifact, request))
        return requests

    @staticmethod
    def _run_inference(
        requests: List[Tuple[str, InferenceRequest]],
    ) -> List[Tuple[str, str]]:
        """
        Run inference for all requests.

        Args:
            requests: List of tuples containing (artifact_type, inference_request)

        Returns:
            List of tuples containing (artifact_type, response)

        Raises:
            ValueError: If the backend is not supported
        """
        backend = config["model"].get("backend")
        if backend not in SUPPORTED_BACKENDS:
            raise ValueError(
                f"Allowed config.model.backend: {SUPPORTED_BACKENDS}, "
                f"got unknown value: {backend}"
            )
        artifact_types = [r[0] for r in requests]
        inference_requests = [r[1] for r in requests]
        response_batch = []
        if backend == "offline-vllm":
            request_batch = InferenceUtils.make_vllm_batch(inference_requests)
            response_batch = InferenceUtils.run_vllm_inference(request_batch)
        elif backend == "openai-compat":
            response_batch = [
                InferenceUtils.run_openai_inference(request)
                for request in inference_requests
            ]
        return list(zip(artifact_types, response_batch))

    @staticmethod
    def _process_responses(responses: List[Tuple[str, str]]) -> ArtifactCollection:
        """
        Process responses into a structured artifact collection.

        Args:
            responses: List of tuples containing (artifact_type, response)

        Returns:
            ArtifactCollection: Extracted artifacts
        """
        extracted = {}
        for artifact_type, raw_response in responses:
            try:
                json_response = JSONUtils.extract_json_from_response(raw_response)
                if EXTRACTED_DATA_KEY in json_response:
                    json_response = json_response[EXTRACTED_DATA_KEY]
                extracted.update(json_response)
            except Exception as e:
                logger.error(f"Failed to process response for {artifact_type}: {e}")
                extracted.update({artifact_type: {"error": str(e)}})
        return extracted
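
    # The raw model output is expected to contain a JSON object, optionally
    # wrapped under EXTRACTED_DATA_KEY, e.g. (illustrative):
    #
    #   {"extracted_data": {"tables": [...], "charts": [...]}}
    #
    # JSONUtils.extract_json_from_response is assumed to recover that object
    # from the surrounding model text.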

    @staticmethod
    def from_image(
        img_path: str,
        artifact_types: Union[List[str], str],
    ) -> ArtifactCollection:
        """
        Extract artifacts from an image.

        Args:
            img_path: Path to the image file
            artifact_types: Type(s) of artifacts to extract

        Returns:
            ArtifactCollection: Extracted artifacts

        Raises:
            ValueError: If the backend is not supported
            FileNotFoundError: If the image file doesn't exist
        """
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")
        if isinstance(artifact_types, str):
            artifact_types = [artifact_types]

        # Prepare inference requests
        requests = ArtifactExtractor._prepare_inference_requests(
            img_path, artifact_types
        )
        # Run inference
        responses = ArtifactExtractor._run_inference(requests)
        # Process responses
        return ArtifactExtractor._process_responses(responses)

    @staticmethod
    def from_pdf(pdf_path: str, artifact_types: List[str]) -> List[ExtractedPage]:
        """
        Extract artifacts from all pages in a PDF.

        Args:
            pdf_path: Path to the PDF file
            artifact_types: Types of artifacts to extract

        Returns:
            List[ExtractedPage]: Extracted pages with artifacts

        Raises:
            FileNotFoundError: If the PDF file doesn't exist
        """
        if not os.path.exists(pdf_path):
            raise FileNotFoundError(f"PDF file not found: {pdf_path}")
        pdf_pages = PDFUtils.extract_pages(pdf_path)
        logger.info(f"Processing {len(pdf_pages)} pages from {pdf_path}")
        for page in tqdm(pdf_pages, desc="Processing PDF pages"):
            try:
                page_artifacts = ArtifactExtractor.from_image(
                    page["image_path"], artifact_types
                )
                # Round-trip through JSON to ensure the artifacts are plain,
                # serializable Python objects
                page_artifacts = json.loads(json.dumps(page_artifacts))
                page["artifacts"] = page_artifacts
            except Exception as e:
                logger.error(
                    f"Error processing page {page['page_num']} in {pdf_path}: {e}"
                )
                page["artifacts"] = {"error": f"Error {e} in artifact extraction"}
        return pdf_pages
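
# Example (illustrative) end-to-end usage on one document; "report.pdf" is a
# placeholder, and artifact names must match keys under config["artifacts"]:
#
#   pages = ArtifactExtractor.from_pdf("report.pdf", ["tables", "charts"])
#   for page in pages:
#       print(page["page_num"], list(page["artifacts"].keys()))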


def get_target_files(target_path: str) -> List[Path]:
    """
    Get the list of files to process.

    Args:
        target_path: Path to a file or directory

    Returns:
        List of Path objects to process

    Raises:
        FileNotFoundError: If the target path doesn't exist
        ValueError: If the file type is unsupported
    """
    if not os.path.exists(target_path):
        raise FileNotFoundError(f"Target path not found: {target_path}")
    target_path_obj = Path(target_path)
    if (
        target_path_obj.is_file()
        and target_path_obj.suffix not in SUPPORTED_FILE_TYPES
    ):
        raise ValueError(
            f"Unsupported file type: {target_path_obj.suffix}. "
            f"Supported types: {SUPPORTED_FILE_TYPES}"
        )
    targets = (
        [target_path_obj]
        if target_path_obj.is_file()
        else [f for f in target_path_obj.iterdir() if f.suffix in SUPPORTED_FILE_TYPES]
    )
    logger.debug(f"Processing {len(targets)} files")
    if not targets:
        logger.warning(f"No supported files found in {target_path}")
    return targets


def process_files(
    targets: List[Path], artifact_types: List[str]
) -> List[Dict[str, Any]]:
    """
    Process files and extract artifacts.

    Args:
        targets: List of files to process
        artifact_types: Types of artifacts to extract

    Returns:
        List of extracted artifacts
    """
    out_json = []
    for target in targets:
        try:
            artifacts = ArtifactExtractor.from_pdf(str(target), artifact_types)
            out_json.extend(artifacts)
        except Exception as e:
            logger.error(f"Failed to process {target}: {e}")
    return out_json


def save_results(
    output_dir: Path,
    data: List[Dict[str, Any]],
    save_to_db: bool = False,
    save_tables_as_csv: bool = False,
    export_excel: bool = False,
) -> None:
    """
    Save extraction results to a JSON file and optionally to CSV files,
    an Excel workbook, and SQL/vector databases.

    Args:
        output_dir: Directory to write outputs to
        data: Extracted pages to save
        save_to_db: Whether to save to the SQL (and optionally vector) database
        save_tables_as_csv: Whether to export each table/chart as a CSV file
        export_excel: Whether to combine the CSVs into a single Excel workbook
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save to JSON file
    output_path = None
    try:
        output_path = output_dir / f"artifacts_{timestamp}.json"
        json_content = json.dumps(data, indent=2)
        output_path.write_text(json_content)
        logger.info(f"Extracted artifacts written to {output_path}")
    except Exception as e:
        logger.error(f"Failed to write output file: {e}")

    if save_tables_as_csv or export_excel:
        # Collect tables and charts from all pages; pages that failed
        # extraction carry only an "error" key, so default to empty lists
        tables_charts = sum(
            [x["artifacts"].get("tables", []) for x in data], []
        ) + sum([x["artifacts"].get("charts", []) for x in data], [])
        for tab in tables_charts:
            # Convert each table to a CSV string
            csv_string, filename = json_to_csv(tab)
            outfile = output_dir / f"tables_{timestamp}" / filename
            outfile.parent.mkdir(parents=True, exist_ok=True)
            outfile.write_text(csv_string)
            logger.info(f"Extracted table written to {outfile}")
        if export_excel:
            excel_path = output_dir / f"tables_{timestamp}.xlsx"
            export_csvs_to_excel_tabs(output_dir / f"tables_{timestamp}", excel_path)

    # Save to SQL and vector databases
    if save_to_db:
        assert (
            output_path is not None
        ), "Save to database failed; extracted artifacts were not written to JSON"
        # Get database paths from config
        sql_db_path = config.get("database", {}).get("sql_db_path", None)
        vector_db_path = config.get("database", {}).get("vector_db_path", None)
        assert (
            sql_db_path is not None
        ), "Save to SQL failed; SQL database path not found in config"
        # Save to SQL and optionally to vector database
        counts = flatten_json_to_sql(str(output_path), sql_db_path, vector_db_path)
        logger.info(
            f"Extracted {counts.get('text', 0)} text artifacts, "
            f"{counts.get('image', 0)} image artifacts, and "
            f"{counts.get('table', 0)} table artifacts from {len(data)} pages."
        )
        logger.info(f"Extracted artifacts saved to SQL database: {sql_db_path}")
        if vector_db_path:
            logger.info(
                f"Extracted artifacts indexed in vector database: {vector_db_path}"
            )
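
# With every export option enabled, save_results lays out (timestamps vary):
#
#   <output_dir>/artifacts_<timestamp>.json       all extracted artifacts
#   <output_dir>/tables_<timestamp>/<name>.csv    one CSV per table/chart
#   <output_dir>/tables_<timestamp>.xlsx          the CSVs as Excel tabs
#
# plus SQL rows (and vector-index entries, if configured) when save_to_db is set.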


def main(
    target_path: str,
    artifacts: str,
    save_to_db: bool = False,
    save_tables_as_csv: bool = True,
    export_excel: bool = False,
) -> None:
    """
    Extract structured data from PDF documents using LLM-powered extraction.

    Processes PDFs to extract text, tables, images, and charts as structured
    JSON. Outputs are saved to timestamped files and optionally to databases.

    Args:
        target_path: PDF file or directory path to process
        artifacts: Comma-separated artifact types (e.g. "text,tables,images,charts")
        save_to_db: Save to SQL/vector databases if True
        save_tables_as_csv: Export tables as individual CSV files if True
        export_excel: Combine all tables into a single Excel workbook if True

    Output:
        - JSON file with all extracted artifacts
        - CSV files for each table (if save_tables_as_csv=True)
        - Excel workbook with all tables (if export_excel=True)
        - Database records (if save_to_db=True)

    Raises:
        ValueError: Invalid artifact types or unsupported file format
        FileNotFoundError: Target path does not exist
    """
    ALLOWED_ARTIFACTS = list(config["artifacts"].keys())
    # Split the comma-separated artifact list and keep only configured types
    artifact_types = [
        x.strip() for x in artifacts.split(",") if x.strip() in ALLOWED_ARTIFACTS
    ]
    if not artifact_types:
        raise ValueError(
            f"No valid artifact types in {artifacts!r}; allowed types: {ALLOWED_ARTIFACTS}"
        )
    print("Extracting artifacts: ", artifact_types, "\n")

    # Get files to process
    targets = get_target_files(target_path)
    if not targets:
        return

    # Process files
    results = process_files(targets, artifact_types)

    # Save results
    target_path_obj = Path(target_path)
    output_dir = target_path_obj.parent / "extracted"
    save_results(
        output_dir,
        results,
        save_to_db=save_to_db,
        save_tables_as_csv=save_tables_as_csv,
        export_excel=export_excel,
    )


if __name__ == "__main__":
    fire.Fire(main)
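
# Example invocations via python-fire (paths are placeholders):
#
#   python structured_extraction.py --target_path report.pdf --artifacts "text,tables"
#   python structured_extraction.py ./pdfs "tables,charts" --save_to_db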