evaluate.py 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794
#!/usr/bin/env python3
"""
Script to evaluate a vision-language model on the W2 tax form dataset using an OpenAI-compatible API client.
Leverages the OpenAI-compatible SDK for various endpoints, such as a vLLM server, the Llama API, or any other compatible API.
Supports batch (concurrent) processing.
Loads images from the provided dataset, sends them to the compatible API server,
and compares the model's output with the expected output.
"""
  9. import argparse
  10. import base64
  11. import json
  12. import logging
  13. import os
  14. import pathlib
  15. import re
  16. import time
  17. import traceback
  18. from concurrent.futures import as_completed, ThreadPoolExecutor
  19. from datetime import datetime
  20. from pathlib import Path
  21. from typing import Any, Dict, List, Optional, Tuple
  22. from datasets import load_dataset, load_from_disk
  23. from openai import OpenAI
  24. from PIL import Image
  25. from pydantic import BaseModel
  26. from tqdm import tqdm
  27. # Set up logging
  28. logging.basicConfig(
  29. level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
  30. )
  31. logger = logging.getLogger(__name__)
# Schema of the W-2 fields the model is asked to extract.
# W2Form.model_json_schema() is embedded in the unstructured prompt
# (generate_prompt) and sent as the `response_format` JSON schema in
# structured mode (process_single_sample), so the field names, order and types
# below directly shape the output the model is asked to produce.
# NOTE(review): calculate_metrics compares prediction and ground truth key by
# key, so these names presumably mirror the dataset's ground-truth keys
# (including the misspelled `statutary` / `third_part` fields) — confirm
# against the dataset before renaming anything.
# Documented with comments rather than a class docstring on purpose: pydantic
# copies a model docstring into the schema's "description", which would change
# both the prompt and the structured-output schema.
class W2Form(BaseModel):
    # Employer details (boxes b and c).
    box_b_employer_identification_number: str
    box_c_employer_name: str
    box_c_employer_street_address: str
    box_c_employer_city_state_zip: str
    # Employee details (boxes a and e).
    box_a_employee_ssn: str
    box_e_employee_name: str
    box_e_employee_street_address: str
    box_e_employee_city_state_zip: str
    # Box d is the only identifier typed as an integer.
    box_d_control_number: int
    # Boxes 1-11: floats, except box 9 which is an optional string.
    box_1_wages: float
    box_2_federal_tax_withheld: float
    box_3_social_security_wages: float
    box_4_social_security_tax_withheld: float
    box_5_medicare_wages: float
    box_6_medicare_wages_tax_withheld: float
    box_7_social_security_tips: float
    box_8_allocated_tips: float
    box_9_advance_eic_payment: Optional[str]
    box_10_dependent_care_benefits: float
    box_11_nonqualified_plans: float
    # Box 12 code/value pairs a-d (only 12d's code is optional).
    box_12a_code: str
    box_12a_value: float
    box_12b_code: str
    box_12b_value: float
    box_12c_code: str
    box_12c_value: float
    box_12d_code: Optional[str]
    box_12d_value: float
    # Box 13 checkboxes, typed as optional strings.
    box_13_statutary_employee: Optional[str]
    box_13_retirement_plan: Optional[str]
    box_13_third_part_sick_pay: Optional[str]
    # First state/locality row (boxes 15-20).
    box_15_1_state: str
    box_15_1_employee_state_id: str
    box_16_1_state_wages: float
    box_17_1_state_income_tax: float
    box_18_1_local_wages: float
    box_19_1_local_income_tax: float
    box_20_1_locality: str
    # Second state/locality row (boxes 15-20).
    box_15_2_state: str
    box_15_2_employee_state_id: str
    box_16_2_state_wages: float
    box_17_2_state_income_tax: float
    box_18_2_local_wages: float
    box_19_2_local_income_tax: float
    box_20_2_locality: str
  78. # ----------- Utilities -----------
  79. def encode_image_to_base64(image_path: str) -> str:
  80. """Encode image to base64 string."""
  81. with open(image_path, "rb") as f:
  82. return base64.b64encode(f.read()).decode()
  83. def create_messages(prompt: str, image_path: str) -> List[Dict]:
  84. """Create messages array for API client call."""
  85. content = [
  86. {"type": "text", "text": prompt},
  87. {
  88. "type": "image_url",
  89. "image_url": {
  90. "url": f"data:image/png;base64,{encode_image_to_base64(image_path)}"
  91. },
  92. },
  93. ]
  94. return [{"role": "user", "content": content}]
  95. def clean_json_string(json_str: str) -> str:
  96. """
  97. Clean common JSON formatting issues from LLM responses.
  98. Args:
  99. json_str: Raw JSON string that may contain formatting issues
  100. Returns:
  101. Cleaned JSON string
  102. """
  103. # Remove markdown code block markers
  104. json_str = re.sub(r"```(?:json)?\s*", "", json_str)
  105. json_str = re.sub(r"\s*```", "", json_str)
  106. # Fix malformed string patterns like: "field": ",\n" ,
  107. # This handles the specific error case where strings are malformed with newlines
  108. json_str = re.sub(r':\s*",\s*"\s*,', ': "",', json_str)
  109. # Fix incomplete string literals with control characters
  110. # Pattern: "field": "partial_value\nrest_of_value",
  111. json_str = re.sub(r':\s*"([^"]*)\n([^"]*)",', r': "\1\2",', json_str)
  112. # Fix the specific pattern from the error: "field": "value\n" followed by whitespace and comma
  113. json_str = re.sub(r':\s*"([^"]*)\n"\s*,', r': "\1",', json_str)
  114. # Remove trailing commas in objects and arrays
  115. json_str = re.sub(r",(\s*[}\]])", r"\1", json_str)
  116. # Fix missing quotes around keys (sometimes LLMs output unquoted keys)
  117. json_str = re.sub(r"([{,]\s*)([a-zA-Z_][a-zA-Z0-9_]*)\s*:", r'\1"\2":', json_str)
  118. # Fix single quotes to double quotes (JSON requires double quotes)
  119. json_str = re.sub(r"'([^']*)'", r'"\1"', json_str)
  120. # Remove control characters that are not allowed in JSON strings
  121. # Keep only printable ASCII and basic whitespace
  122. json_str = "".join(char for char in json_str if ord(char) >= 32 or char in "\t\r ")
  123. # Fix null-like values that should be proper JSON null
  124. json_str = re.sub(r":\s*None\s*,", ": null,", json_str, flags=re.IGNORECASE)
  125. json_str = re.sub(r":\s*undefined\s*,", ": null,", json_str, flags=re.IGNORECASE)
  126. return json_str
  127. def extract_json_from_response(response: str) -> Tuple[Dict[str, Any], bool]:
  128. """
  129. Robust JSON extraction from LLM responses with comprehensive error handling.
  130. Args:
  131. response: Raw response text from LLM
  132. Returns:
  133. Tuple of (extracted_json_dict, has_error)
  134. """
  135. if not response or not response.strip():
  136. logger.warning("Empty response provided")
  137. return {}, True
  138. # Strategy 1: Look for JSON content between triple backticks
  139. json_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", response, re.DOTALL)
  140. if json_match:
  141. json_str = json_match.group(1)
  142. else:
  143. # Strategy 2: Look for JSON object pattern (handle nested braces)
  144. json_match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", response, re.DOTALL)
  145. if json_match:
  146. json_str = json_match.group(0)
  147. else:
  148. # Strategy 3: Find content between first { and last }
  149. start_idx = response.find("{")
  150. end_idx = response.rfind("}")
  151. if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
  152. json_str = response[start_idx : end_idx + 1]
  153. else:
  154. logger.warning("No JSON pattern found in response")
  155. logger.debug(f"Response snippet: {response[:200]}...")
  156. return {}, True
  157. # Clean the extracted JSON string
  158. original_json_str = json_str
  159. json_str = clean_json_string(json_str)
  160. # Attempt to parse with multiple strategies
  161. parsing_strategies = [
  162. ("direct", lambda s: json.loads(s)),
  163. ("strip_whitespace", lambda s: json.loads(s.strip())),
  164. (
  165. "fix_escapes",
  166. lambda s: json.loads(s.replace("\\\\", "\\").replace('\\"', '"')),
  167. ),
  168. ]
  169. for strategy_name, parse_func in parsing_strategies:
  170. try:
  171. parsed_json = parse_func(json_str)
  172. # Validate that it's a dictionary (expected for most use cases)
  173. if not isinstance(parsed_json, dict):
  174. logger.warning(
  175. f"Extracted JSON is not a dictionary: {type(parsed_json)}"
  176. )
  177. continue
  178. logger.debug(f"Successfully parsed JSON using strategy: {strategy_name}")
  179. return parsed_json, False
  180. except json.JSONDecodeError as e:
  181. logger.debug(f"Strategy '{strategy_name}' failed: {e}")
  182. continue
  183. except Exception as e:
  184. logger.debug(f"Unexpected error in strategy '{strategy_name}': {e}")
  185. continue
  186. # If all strategies fail, log details for debugging
  187. logger.error("All JSON parsing strategies failed")
  188. logger.debug(f"Original JSON string (first 500 chars): {original_json_str[:500]}")
  189. logger.debug(f"Cleaned JSON string (first 500 chars): {json_str[:500]}")
  190. return {}, True
  191. def generate_prompt(structured=True) -> str:
  192. """Generate prompt for the model."""
  193. json_schema = W2Form.model_json_schema()
  194. prompt = (
  195. "You are an expert document information extraction system. "
  196. "I will show you an image of a W-2 tax form. "
  197. "Please extract all the information from this form and return it in a JSON format. "
  198. "Include all fields such as employee details, employer details, wages, federal income tax withheld, "
  199. "social security wages, social security tax withheld, medicare wages and tips, medicare tax withheld, "
  200. "and any other information present on the form. "
  201. )
  202. if not structured:
  203. prompt += f"Return ONLY the JSON output without any additional text or explanations following this schema {json_schema}"
  204. return prompt
  205. def call_api_client(
  206. client: OpenAI,
  207. messages: List[Dict],
  208. model: str = "meta-llama/Llama-3.2-11B-Vision-Instruct",
  209. temperature: float = 0.0,
  210. max_tokens: int = 8192,
  211. response_format: Optional[Dict] = None,
  212. timeout: int = 300,
  213. seed: Optional[int] = 42,
  214. ):
  215. """
  216. Call compatible API server using OpenAI-compatible client.
  217. """
  218. try:
  219. kwargs = {
  220. "model": model,
  221. "messages": messages,
  222. "temperature": temperature,
  223. "max_tokens": max_tokens,
  224. "timeout": timeout,
  225. }
  226. # Add seed if provided for reproducible generation
  227. if seed is not None:
  228. kwargs["seed"] = seed
  229. # Add response format if structured output is enabled
  230. if response_format:
  231. kwargs["response_format"] = response_format
  232. logger.debug(f"Making API client call with model: {model}")
  233. response = client.chat.completions.create(**kwargs)
  234. logger.debug(f"Received response with {len(response.choices)} choices")
  235. return response
  236. except Exception as e:
  237. logger.error(f"API client call failed: {e}")
  238. raise
  239. def process_single_sample(
  240. client: OpenAI,
  241. sample_data: Tuple[int, Dict],
  242. output_dir: str,
  243. model: str,
  244. structured: bool,
  245. timeout: int,
  246. ) -> Dict[str, Any]:
  247. """Process a single sample using OpenAI SDK."""
  248. idx, sample = sample_data
  249. try:
  250. # Get image
  251. image = sample["image"]
  252. # Save image temporarily
  253. image_path = get_image_path(image, output_dir, idx)
  254. logger.debug(f"Saved image to {image_path}")
  255. # Generate prompt and messages
  256. prompt = generate_prompt(structured)
  257. messages = create_messages(prompt, image_path)
  258. # Prepare response format for structured output
  259. response_format = None
  260. if structured:
  261. json_schema = W2Form.model_json_schema()
  262. response_format = {
  263. "type": "json_schema",
  264. "json_schema": {
  265. "name": "W2Form",
  266. "schema": json_schema,
  267. "strict": True,
  268. },
  269. }
  270. # Call API client
  271. start_time = time.time()
  272. try:
  273. response = call_api_client(
  274. client=client,
  275. messages=messages,
  276. model=model,
  277. response_format=response_format,
  278. timeout=timeout,
  279. )
  280. content = response.choices[0].message.content
  281. usage = response.usage.model_dump() if response.usage else {}
  282. except Exception as e:
  283. logger.error(f"Error calling SDK for sample {idx}: {e}")
  284. content = ""
  285. usage = {}
  286. processing_time = time.time() - start_time
  287. # Extract JSON from response
  288. extracted_json, json_parsing_error = extract_json_from_response(content)
  289. # Get ground truth
  290. ground_truth_raw = json.loads(sample["ground_truth"])
  291. # Handle the gt_parse wrapper structure if present
  292. if "gt_parse" in ground_truth_raw:
  293. ground_truth = ground_truth_raw["gt_parse"]
  294. else:
  295. ground_truth = ground_truth_raw
  296. # Normalize for comparison
  297. normalized_pred = normalize_json(extracted_json)
  298. normalized_gt = normalize_json(ground_truth)
  299. # Save results
  300. result = {
  301. "sample_id": idx,
  302. "prediction": extracted_json,
  303. "ground_truth": ground_truth,
  304. "normalized_prediction": normalized_pred,
  305. "normalized_gt": normalized_gt,
  306. "raw_response": content,
  307. "processing_time": processing_time,
  308. "json_parsing_error": json_parsing_error,
  309. "usage": usage,
  310. }
  311. return result
  312. except Exception as e:
  313. traceback_str = traceback.format_exc()
  314. logger.error(f"Error processing sample {idx}: {str(e)} at line {traceback_str}")
  315. return {
  316. "sample_id": idx,
  317. "prediction": {},
  318. "ground_truth": {},
  319. "normalized_prediction": {},
  320. "normalized_gt": {},
  321. "raw_response": "",
  322. "processing_time": 0.0,
  323. "json_parsing_error": True,
  324. "usage": {},
  325. "error": str(e),
  326. }
  327. def calculate_metrics(results: List[Dict]) -> Dict[str, Any]:
  328. """Calculate accuracy metrics for the predictions."""
  329. if not results:
  330. logger.error("No results provided")
  331. return {"accuracy": 0.0, "field_accuracy": {}}
  332. # Initialize metrics
  333. total_fields = 0
  334. correct_fields = 0
  335. parse_errors = 0
  336. total_records = len(results)
  337. logger.info(f"Total records: {total_records}")
  338. field_counts = {}
  339. field_correct = {}
  340. for result in results:
  341. pred, gt = result["prediction"], result["ground_truth"]
  342. if result["json_parsing_error"]:
  343. parse_errors += 1
  344. total_fields += len(gt)
  345. continue
  346. for field in gt.keys():
  347. # Count total occurrences of this field
  348. field_counts[field] = field_counts.get(field, 0) + 1
  349. total_fields += 1
  350. # Check if field is correct
  351. if field in pred and pred[field] == gt[field]:
  352. correct_fields += 1
  353. field_correct[field] = field_correct.get(field, 0) + 1
  354. # Calculate overall accuracy
  355. accuracy = correct_fields / total_fields if total_fields > 0 else 0.0
  356. errors = parse_errors / total_records if total_records > 0 else 0.0
  357. # Calculate per-field accuracy
  358. field_accuracy = {}
  359. for field in field_counts:
  360. field_accuracy[field] = field_correct.get(field, 0) / field_counts[field]
  361. return {
  362. "accuracy": accuracy,
  363. "field_accuracy": field_accuracy,
  364. "parse_error": errors,
  365. }
  366. def normalize_field_value(value: Any) -> str:
  367. """Normalize field values for comparison."""
  368. if value is None:
  369. return ""
  370. # Convert to string and normalize
  371. value_str = str(value).strip().lower()
  372. # Remove common separators in numbers
  373. value_str = value_str.replace(",", "").replace(" ", "")
  374. # Try to convert to float for numeric comparison
  375. try:
  376. value_float = float(value_str)
  377. return str(value_float)
  378. except ValueError:
  379. return value_str
  380. def normalize_json(json_obj: Dict) -> Dict:
  381. """Normalize JSON object for comparison."""
  382. normalized = {}
  383. for key, value in json_obj.items():
  384. # Normalize key (lowercase, remove spaces)
  385. norm_key = key.lower().replace(" ", "_")
  386. # Normalize value
  387. if isinstance(value, dict):
  388. normalized[norm_key] = normalize_json(value)
  389. elif isinstance(value, list):
  390. normalized[norm_key] = [normalize_field_value(v) for v in value]
  391. else:
  392. normalized[norm_key] = normalize_field_value(value)
  393. return normalized
  394. def get_image_path(image: Image.Image, output_dir: str, idx: int) -> str:
  395. """Get the path to save the image."""
  396. # Create a temporary file for the image
  397. temp_dir = pathlib.Path(output_dir) / "temp"
  398. os.makedirs(temp_dir, exist_ok=True)
  399. image_path = temp_dir / f"temp_{idx}.png"
  400. image_path = str(image_path.resolve())
  401. image.save(image_path)
  402. return image_path
  403. def vllm_openai_sdk_evaluation(
  404. test_set,
  405. output_dir: str,
  406. server_url: str = "http://localhost:8001",
  407. api_key: str = "default-blank-localhost",
  408. model: str = "meta-llama/Llama-3.2-11B-Vision-Instruct",
  409. structured: bool = True,
  410. timeout: int = 300,
  411. max_workers: int = 10,
  412. ):
  413. """
  414. Evaluate the W2 extraction task using OpenAI SDK with batch processing.
  415. """
  416. # Initialize OpenAI client
  417. client = OpenAI(
  418. api_key=api_key, # vLLM doesn't require a real API key
  419. base_url=f"{server_url}",
  420. )
  421. # Prepare sample data for batch processing
  422. sample_data = [(idx, sample) for idx, sample in enumerate(test_set)]
  423. results = []
  424. # Use ThreadPoolExecutor for concurrent processing
  425. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  426. # Submit all tasks
  427. future_to_sample = {
  428. executor.submit(
  429. process_single_sample,
  430. client,
  431. data,
  432. output_dir,
  433. model,
  434. structured,
  435. timeout,
  436. ): data[0]
  437. for data in sample_data
  438. }
  439. # Collect results with progress bar
  440. for future in tqdm(
  441. as_completed(future_to_sample),
  442. total=len(sample_data),
  443. desc="Processing samples (batch)",
  444. ):
  445. sample_idx = future_to_sample[future]
  446. try:
  447. result = future.result()
  448. results.append(result)
  449. except Exception as e:
  450. logger.error(f"Exception in sample {sample_idx}: {e}")
  451. # Add error result
  452. results.append(
  453. {
  454. "sample_id": sample_idx,
  455. "prediction": {},
  456. "ground_truth": {},
  457. "normalized_prediction": {},
  458. "normalized_gt": {},
  459. "raw_response": "",
  460. "processing_time": 0.0,
  461. "json_parsing_error": True,
  462. "usage": {},
  463. "error": str(e),
  464. }
  465. )
  466. # Sort results by sample_id to maintain order
  467. results.sort(key=lambda x: x["sample_id"])
  468. return results
  469. def vllm_openai_sdk_sequential_evaluation(
  470. test_set,
  471. output_dir: str,
  472. server_url: str = "http://localhost:8001",
  473. api_key: str = "default-blank-localhost",
  474. model: str = "meta-llama/Llama-3.2-11B-Vision-Instruct",
  475. structured: bool = True,
  476. timeout: int = 300,
  477. ):
  478. """
  479. Evaluate the W2 extraction task using OpenAI SDK sequentially (for debugging).
  480. """
  481. # Initialize OpenAI client
  482. client = OpenAI(
  483. api_key=api_key, # vLLM doesn't require a real API key
  484. base_url=f"{server_url}",
  485. )
  486. results = []
  487. for idx, sample in enumerate(
  488. tqdm(test_set, desc="Processing samples (sequential)")
  489. ):
  490. result = process_single_sample(
  491. client, (idx, sample), output_dir, model, structured, timeout
  492. )
  493. results.append(result)
  494. return results
  495. def main():
  496. parser = argparse.ArgumentParser(
  497. description="Evaluate vision-language model on W2 tax form dataset"
  498. )
  499. parser.add_argument(
  500. "--server_url",
  501. type=str,
  502. default="http://localhost:8001",
  503. help="URL of the vLLM HTTP server",
  504. )
  505. parser.add_argument(
  506. "--model",
  507. type=str,
  508. default="meta-llama/Llama-3.2-11B-Vision-Instruct",
  509. help="Model name to use for inference",
  510. )
  511. parser.add_argument(
  512. "--dataset_name",
  513. type=str,
  514. default="singhsays/fake-w2-us-tax-form-dataset",
  515. help="Name of the Huggingface dataset",
  516. )
  517. parser.add_argument(
  518. "--output_dir",
  519. type=str,
  520. default="./w2_evaluation_results",
  521. help="Directory to save evaluation results",
  522. )
  523. parser.add_argument(
  524. "--limit",
  525. type=int,
  526. default=10,
  527. help="Number of samples to evaluate (default: 10, use -1 for all)",
  528. )
  529. parser.add_argument(
  530. "--structured",
  531. action="store_true",
  532. default=False,
  533. help="Whether to use structured output (JSON schema)",
  534. )
  535. parser.add_argument(
  536. "--timeout",
  537. type=int,
  538. default=300,
  539. help="Timeout for SDK requests in seconds",
  540. )
  541. parser.add_argument(
  542. "--max_workers",
  543. type=int,
  544. default=10,
  545. help="Maximum number of concurrent workers for batch processing",
  546. )
  547. parser.add_argument(
  548. "--sequential",
  549. action="store_true",
  550. default=False,
  551. help="Process samples sequentially instead of in parallel (for debugging)",
  552. )
  553. args = parser.parse_args()
  554. # Create output directory
  555. os.makedirs(args.output_dir, exist_ok=True)
  556. # Load dataset
  557. logger.info(f"Loading dataset: {args.dataset_name}")
  558. test_set = None
  559. if Path(args.dataset_name, "state.json").exists():
  560. test_set = load_from_disk(args.dataset_name)
  561. else:
  562. dataset = load_dataset(args.dataset_name)
  563. if "test" not in dataset:
  564. logger.error("Dataset does not have a test split")
  565. return 1
  566. test_set = dataset["test"]
  567. logger.info(f"Loaded test set with {len(test_set)} samples")
  568. # Limit number of samples if specified
  569. if args.limit > 0 and args.limit < len(test_set):
  570. test_set = test_set.select(range(args.limit))
  571. logger.info(f"Limited to {args.limit} samples")
  572. # Get API key from environment variable
  573. api_key = os.getenv("TOGETHER_API_KEY") or os.getenv("OPENAI_API_KEY")
  574. if not api_key:
  575. logger.warning(
  576. "No API key found. Please set the TOGETHER_API_KEY or OPENAI_API_KEY environment variable for public APIs."
  577. )
  578. api_key = "default-blank-localhost"
  579. # Test server connection
  580. try:
  581. client = OpenAI(
  582. api_key=api_key,
  583. base_url=f"{args.server_url}",
  584. )
  585. # Test with a simple call
  586. # models = client.models.list()
  587. logger.info(f"Successfully connected to server at {args.server_url}")
  588. # logger.info(f"Available models: {[model.id for model in models.data]}")
  589. except Exception as e:
  590. logger.error(f"Failed to connect to server at {args.server_url}: {e}")
  591. logger.error("Make sure the server is running and accessible")
  592. return 1
  593. # Run evaluation
  594. if args.sequential:
  595. logger.info("Running sequential evaluation...")
  596. results = vllm_openai_sdk_sequential_evaluation(
  597. test_set=test_set,
  598. output_dir=args.output_dir,
  599. server_url=args.server_url,
  600. api_key=api_key,
  601. model=args.model,
  602. structured=args.structured,
  603. timeout=args.timeout,
  604. )
  605. else:
  606. logger.info(f"Running batch evaluation with {args.max_workers} workers...")
  607. results = vllm_openai_sdk_evaluation(
  608. test_set=test_set,
  609. output_dir=args.output_dir,
  610. server_url=args.server_url,
  611. api_key=api_key,
  612. model=args.model,
  613. structured=args.structured,
  614. timeout=args.timeout,
  615. max_workers=args.max_workers,
  616. )
  617. # Save detailed results
  618. timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
  619. try:
  620. results_file = os.path.join(args.output_dir, f"results_{timestamp}.json")
  621. with open(results_file, "w") as f:
  622. json.dump(results, f, indent=2)
  623. logger.info(f"Detailed results saved to {results_file}")
  624. except Exception as e:
  625. logger.error(f"Error saving detailed results: {str(e)}")
  626. return 1
  627. # Calculate metrics
  628. metrics = calculate_metrics(results)
  629. # Save evaluation summary
  630. output_file = os.path.join(args.output_dir, f"evaluation_results_{timestamp}.json")
  631. arguments = {
  632. "server_url": args.server_url,
  633. "model": args.model,
  634. "output_dir": args.output_dir,
  635. "dataset_name": args.dataset_name,
  636. "limit": args.limit,
  637. "structured": args.structured,
  638. "timeout": args.timeout,
  639. "max_workers": args.max_workers,
  640. "sequential": args.sequential,
  641. "prompt": generate_prompt(args.structured),
  642. }
  643. summary = {
  644. "arguments": arguments,
  645. "metrics": metrics,
  646. "timestamp": timestamp,
  647. "total_samples": len(results),
  648. }
  649. with open(output_file, "w") as f:
  650. json.dump(summary, f, indent=2)
  651. # Print summary
  652. logger.info("=" * 50)
  653. logger.info("EVALUATION SUMMARY")
  654. logger.info("=" * 50)
  655. logger.info(f"Overall accuracy: {metrics['accuracy']:.4f}")
  656. logger.info(f"Parse error rate: {metrics['parse_error']:.4f}")
  657. logger.info("Field-level accuracy:")
  658. field_accuracy = metrics["field_accuracy"]
  659. for field, acc in sorted(field_accuracy.items(), key=lambda x: x[1], reverse=True):
  660. logger.info(f" {field}: {acc:.4f}")
  661. logger.info(f"Results saved to {output_file}")
  662. # Clean up temp directory if it exists
  663. temp_dir = os.path.join(args.output_dir, "temp")
  664. if os.path.exists(temp_dir):
  665. import shutil
  666. shutil.rmtree(temp_dir)
  667. return 0
  668. if __name__ == "__main__":
  669. exit(main())