#!/usr/bin/env python3
"""
Script to evaluate a vision-language model on the W2 tax form dataset using a
compatible API client. Leverages the OpenAI-compatible SDK for various endpoints,
such as a vLLM server, the Llama API, or any compatible API. Supports batch
processing. Loads images from the provided dataset, sends them to the compatible
API server, and compares the responses with the expected output.
"""
import argparse
import base64
import json
import logging
import os
import pathlib
import re
import shutil
import time
import traceback
from concurrent.futures import as_completed, ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from datasets import load_dataset, load_from_disk
from openai import OpenAI
from PIL import Image
from pydantic import BaseModel
from tqdm import tqdm

# Set up logging
logging.basicConfig(
    level=logging.WARNING, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class W2Form(BaseModel):
    """Schema of the fields extracted from a W-2 form (also used as the structured-output JSON schema)."""

    box_b_employer_identification_number: str
    box_c_employer_name: str
    box_c_employer_street_address: str
    box_c_employer_city_state_zip: str
    box_a_employee_ssn: str
    box_e_employee_name: str
    box_e_employee_street_address: str
    box_e_employee_city_state_zip: str
    box_d_control_number: int
    box_1_wages: float
    box_2_federal_tax_withheld: float
    box_3_social_security_wages: float
    box_4_social_security_tax_withheld: float
    box_5_medicare_wages: float
    box_6_medicare_wages_tax_withheld: float
    box_7_social_security_tips: float
    box_8_allocated_tips: float
    box_9_advance_eic_payment: Optional[str]
    box_10_dependent_care_benefits: float
    box_11_nonqualified_plans: float
    box_12a_code: str
    box_12a_value: float
    box_12b_code: str
    box_12b_value: float
    box_12c_code: str
    box_12c_value: float
    box_12d_code: Optional[str]
    box_12d_value: float
    # NOTE: the spellings "statutary" and "third_part" are preserved verbatim;
    # field names must match the dataset's ground-truth keys exactly for
    # field-level scoring to work.
    box_13_statutary_employee: Optional[str]
    box_13_retirement_plan: Optional[str]
    box_13_third_part_sick_pay: Optional[str]
    box_15_1_state: str
    box_15_1_employee_state_id: str
    box_16_1_state_wages: float
    box_17_1_state_income_tax: float
    box_18_1_local_wages: float
    box_19_1_local_income_tax: float
    box_20_1_locality: str
    box_15_2_state: str
    box_15_2_employee_state_id: str
    box_16_2_state_wages: float
    box_17_2_state_income_tax: float
    box_18_2_local_wages: float
    box_19_2_local_income_tax: float
    box_20_2_locality: str


# ----------- Utilities -----------
def encode_image_to_base64(image_path: str) -> str:
    """Encode image to base64 string."""
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode()


def create_messages(prompt: str, image_path: str) -> List[Dict]:
    """Create messages array for API client call."""
    content = [
        {"type": "text", "text": prompt},
        {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/png;base64,{encode_image_to_base64(image_path)}"
            },
        },
    ]
    return [{"role": "user", "content": content}]
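
# The payload built above follows the OpenAI vision-chat message shape, roughly:
#   [{"role": "user",
#     "content": [{"type": "text", "text": "..."},
#                 {"type": "image_url",
#                  "image_url": {"url": "data:image/png;base64,..."}}]}]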


def clean_json_string(json_str: str) -> str:
    """
    Clean common JSON formatting issues from LLM responses.

    Args:
        json_str: Raw JSON string that may contain formatting issues

    Returns:
        Cleaned JSON string
    """
    # Remove markdown code block markers
    json_str = re.sub(r"```(?:json)?\s*", "", json_str)
    json_str = re.sub(r"\s*```", "", json_str)
    # Fix malformed string patterns like: "field": ",\n" ,
    # This handles the specific error case where strings are malformed with newlines
    json_str = re.sub(r':\s*",\s*"\s*,', ': "",', json_str)
    # Fix incomplete string literals with control characters
    # Pattern: "field": "partial_value\nrest_of_value",
    json_str = re.sub(r':\s*"([^"]*)\n([^"]*)",', r': "\1\2",', json_str)
    # Fix the specific pattern from the error: "field": "value\n" followed by whitespace and a comma
    json_str = re.sub(r':\s*"([^"]*)\n"\s*,', r': "\1",', json_str)
    # Remove trailing commas in objects and arrays
    json_str = re.sub(r",(\s*[}\]])", r"\1", json_str)
    # Fix missing quotes around keys (sometimes LLMs output unquoted keys)
    json_str = re.sub(r"([{,]\s*)([a-zA-Z_][a-zA-Z0-9_]*)\s*:", r'\1"\2":', json_str)
    # Fix single quotes to double quotes (JSON requires double quotes)
    json_str = re.sub(r"'([^']*)'", r'"\1"', json_str)
    # Remove control characters that are not allowed in JSON strings; keep
    # printable characters plus tab, CR, and space (raw newlines inside string
    # values are invalid JSON, so they are dropped here too)
    json_str = "".join(char for char in json_str if ord(char) >= 32 or char in "\t\r ")
    # Fix null-like values that should be proper JSON null
    json_str = re.sub(r":\s*None\s*,", ": null,", json_str, flags=re.IGNORECASE)
    json_str = re.sub(r":\s*undefined\s*,", ": null,", json_str, flags=re.IGNORECASE)
    return json_str
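
# A minimal sketch of what the cleaner handles, on a hypothetical LLM reply:
#
#   clean_json_string("```json\n{'name': 'John',}\n```")
#   # -> '{"name": "John"}'  (fences stripped, quotes fixed, trailing comma removed)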


def extract_json_from_response(response: str) -> Tuple[Dict[str, Any], bool]:
    """
    Robust JSON extraction from LLM responses with comprehensive error handling.

    Args:
        response: Raw response text from the LLM

    Returns:
        Tuple of (extracted_json_dict, has_error)
    """
    if not response or not response.strip():
        logger.warning("Empty response provided")
        return {}, True

    # Strategy 1: Look for JSON content between triple backticks
    json_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", response, re.DOTALL)
    if json_match:
        json_str = json_match.group(1)
    else:
        # Strategy 2: Look for a JSON object pattern (handles one level of nested braces)
        json_match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", response, re.DOTALL)
        if json_match:
            json_str = json_match.group(0)
        else:
            # Strategy 3: Find content between the first { and the last }
            start_idx = response.find("{")
            end_idx = response.rfind("}")
            if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
                json_str = response[start_idx : end_idx + 1]
            else:
                logger.warning("No JSON pattern found in response")
                logger.debug(f"Response snippet: {response[:200]}...")
                return {}, True

    # Clean the extracted JSON string
    original_json_str = json_str
    json_str = clean_json_string(json_str)

    # Attempt to parse with multiple strategies
    parsing_strategies = [
        ("direct", lambda s: json.loads(s)),
        ("strip_whitespace", lambda s: json.loads(s.strip())),
        (
            "fix_escapes",
            lambda s: json.loads(s.replace("\\\\", "\\").replace('\\"', '"')),
        ),
    ]
    for strategy_name, parse_func in parsing_strategies:
        try:
            parsed_json = parse_func(json_str)
            # Validate that it's a dictionary (expected for most use cases)
            if not isinstance(parsed_json, dict):
                logger.warning(
                    f"Extracted JSON is not a dictionary: {type(parsed_json)}"
                )
                continue
            logger.debug(f"Successfully parsed JSON using strategy: {strategy_name}")
            return parsed_json, False
        except json.JSONDecodeError as e:
            logger.debug(f"Strategy '{strategy_name}' failed: {e}")
            continue
        except Exception as e:
            logger.debug(f"Unexpected error in strategy '{strategy_name}': {e}")
            continue

    # If all strategies fail, log details for debugging
    logger.error("All JSON parsing strategies failed")
    logger.debug(f"Original JSON string (first 500 chars): {original_json_str[:500]}")
    logger.debug(f"Cleaned JSON string (first 500 chars): {json_str[:500]}")
    return {}, True
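
# For example, on a hypothetical model reply:
#
#   extract_json_from_response('Sure!\n```json\n{"box_1_wages": 55151.13}\n```')
#   # -> ({"box_1_wages": 55151.13}, False)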


def generate_prompt(structured=True) -> str:
    """Generate the extraction prompt for the model."""
    json_schema = W2Form.model_json_schema()
    prompt = (
        "You are an expert document information extraction system. "
        "I will show you an image of a W-2 tax form. "
        "Please extract all the information from this form and return it in a JSON format. "
        "Include all fields such as employee details, employer details, wages, federal income tax withheld, "
        "social security wages, social security tax withheld, medicare wages and tips, medicare tax withheld, "
        "and any other information present on the form. "
    )
    if not structured:
        # Without server-side structured output, the schema must be spelled out in the prompt.
        prompt += f"Return ONLY the JSON output without any additional text or explanations following this schema {json_schema}"
    return prompt


def call_api_client(
    client: OpenAI,
    messages: List[Dict],
    model: str = "meta-llama/Llama-3.2-11B-Vision-Instruct",
    temperature: float = 0.0,
    max_tokens: int = 8192,
    response_format: Optional[Dict] = None,
    timeout: int = 300,
    seed: Optional[int] = 42,
):
    """
    Call a compatible API server using the OpenAI-compatible client.
    """
    try:
        kwargs = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "timeout": timeout,
        }
        # Add seed if provided, for reproducible generation
        if seed is not None:
            kwargs["seed"] = seed
        # Add response format if structured output is enabled
        if response_format:
            kwargs["response_format"] = response_format
        logger.debug(f"Making API client call with model: {model}")
        response = client.chat.completions.create(**kwargs)
        logger.debug(f"Received response with {len(response.choices)} choices")
        return response
    except Exception as e:
        logger.error(f"API client call failed: {e}")
        raise


def process_single_sample(
    client: OpenAI,
    sample_data: Tuple[int, Dict],
    output_dir: str,
    model: str,
    structured: bool,
    timeout: int,
) -> Dict[str, Any]:
    """Process a single sample using the OpenAI SDK."""
    idx, sample = sample_data
    try:
        # Get the image and save it temporarily
        image = sample["image"]
        image_path = get_image_path(image, output_dir, idx)
        logger.debug(f"Saved image to {image_path}")

        # Generate prompt and messages
        prompt = generate_prompt(structured)
        messages = create_messages(prompt, image_path)

        # Prepare the response format for structured output
        response_format = None
        if structured:
            json_schema = W2Form.model_json_schema()
            response_format = {
                "type": "json_schema",
                "json_schema": {
                    "name": "W2Form",
                    "schema": json_schema,
                    "strict": True,
                },
            }

        # Call the API client
        start_time = time.time()
        try:
            response = call_api_client(
                client=client,
                messages=messages,
                model=model,
                response_format=response_format,
                timeout=timeout,
            )
            content = response.choices[0].message.content
            usage = response.usage.model_dump() if response.usage else {}
        except Exception as e:
            logger.error(f"Error calling SDK for sample {idx}: {e}")
            content = ""
            usage = {}
        processing_time = time.time() - start_time

        # Extract JSON from the response
        extracted_json, json_parsing_error = extract_json_from_response(content)

        # Get the ground truth, unwrapping the "gt_parse" wrapper if present
        ground_truth_raw = json.loads(sample["ground_truth"])
        if "gt_parse" in ground_truth_raw:
            ground_truth = ground_truth_raw["gt_parse"]
        else:
            ground_truth = ground_truth_raw

        # Normalize for comparison
        normalized_pred = normalize_json(extracted_json)
        normalized_gt = normalize_json(ground_truth)

        # Assemble the result
        result = {
            "sample_id": idx,
            "prediction": extracted_json,
            "ground_truth": ground_truth,
            "normalized_prediction": normalized_pred,
            "normalized_gt": normalized_gt,
            "raw_response": content,
            "processing_time": processing_time,
            "json_parsing_error": json_parsing_error,
            "usage": usage,
        }
        return result
    except Exception as e:
        traceback_str = traceback.format_exc()
        logger.error(f"Error processing sample {idx}: {e}\n{traceback_str}")
        return {
            "sample_id": idx,
            "prediction": {},
            "ground_truth": {},
            "normalized_prediction": {},
            "normalized_gt": {},
            "raw_response": "",
            "processing_time": 0.0,
            "json_parsing_error": True,
            "usage": {},
            "error": str(e),
        }


def calculate_metrics(results: List[Dict]) -> Dict[str, Any]:
    """Calculate accuracy metrics for the predictions."""
    if not results:
        logger.error("No results provided")
        return {"accuracy": 0.0, "field_accuracy": {}, "parse_error": 0.0}
    # Initialize metrics
    total_fields = 0
    correct_fields = 0
    parse_errors = 0
    total_records = len(results)
    logger.info(f"Total records: {total_records}")
    field_counts = {}
    field_correct = {}
    for result in results:
        pred, gt = result["prediction"], result["ground_truth"]
        if result["json_parsing_error"]:
            # A parse failure still counts every ground-truth field as incorrect
            parse_errors += 1
            total_fields += len(gt)
            continue
        for field in gt.keys():
            # Count total occurrences of this field
            field_counts[field] = field_counts.get(field, 0) + 1
            total_fields += 1
            # Check if the field is correct (exact match on raw values; the
            # normalized copies are kept in the results for offline analysis)
            if field in pred and pred[field] == gt[field]:
                correct_fields += 1
                field_correct[field] = field_correct.get(field, 0) + 1
    # Calculate overall accuracy
    accuracy = correct_fields / total_fields if total_fields > 0 else 0.0
    errors = parse_errors / total_records if total_records > 0 else 0.0
    # Calculate per-field accuracy
    field_accuracy = {}
    for field in field_counts:
        field_accuracy[field] = field_correct.get(field, 0) / field_counts[field]
    return {
        "accuracy": accuracy,
        "field_accuracy": field_accuracy,
        "parse_error": errors,
    }
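
# Worked example (hypothetical numbers): with two records, one scored record
# whose prediction matches 1 of its 2 ground-truth fields, and one record with
# a JSON parse error and 2 ground-truth fields, we get
#   accuracy    = 1 / (2 + 2) = 0.25
#   parse_error = 1 / 2       = 0.5
# and per-field accuracy is computed only over records that parsed.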


def normalize_field_value(value: Any) -> str:
    """Normalize field values for comparison."""
    if value is None:
        return ""
    # Convert to string and normalize
    value_str = str(value).strip().lower()
    # Remove common separators in numbers
    value_str = value_str.replace(",", "").replace(" ", "")
    # Try to convert to float for numeric comparison
    try:
        value_float = float(value_str)
        return str(value_float)
    except ValueError:
        return value_str
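
# For example:
#   normalize_field_value("1,234.50")  -> "1234.5"   (numeric: separators dropped, float-canonicalized)
#   normalize_field_value("New York")  -> "newyork"  (text: lowercased, spaces removed)
#   normalize_field_value(None)        -> ""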


def normalize_json(json_obj: Dict) -> Dict:
    """Normalize a JSON object for comparison."""
    normalized = {}
    for key, value in json_obj.items():
        # Normalize the key (lowercase, spaces to underscores)
        norm_key = key.lower().replace(" ", "_")
        # Normalize the value, recursing into nested objects
        if isinstance(value, dict):
            normalized[norm_key] = normalize_json(value)
        elif isinstance(value, list):
            normalized[norm_key] = [normalize_field_value(v) for v in value]
        else:
            normalized[norm_key] = normalize_field_value(value)
    return normalized


def get_image_path(image: Image.Image, output_dir: str, idx: int) -> str:
    """Save the image to a temporary file and return its path."""
    temp_dir = pathlib.Path(output_dir) / "temp"
    os.makedirs(temp_dir, exist_ok=True)
    image_path = temp_dir / f"temp_{idx}.png"
    image_path = str(image_path.resolve())
    image.save(image_path)
    return image_path


def vllm_openai_sdk_evaluation(
    test_set,
    output_dir: str,
    server_url: str = "http://localhost:8001",
    api_key: str = "default-blank-localhost",
    model: str = "meta-llama/Llama-3.2-11B-Vision-Instruct",
    structured: bool = True,
    timeout: int = 300,
    max_workers: int = 10,
):
    """
    Evaluate the W2 extraction task using the OpenAI SDK with batch processing.
    """
    # Initialize the OpenAI client
    client = OpenAI(
        api_key=api_key,  # vLLM doesn't require a real API key
        base_url=server_url,
    )
    # Prepare sample data for batch processing
    sample_data = [(idx, sample) for idx, sample in enumerate(test_set)]
    results = []
    # Use ThreadPoolExecutor for concurrent processing
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        future_to_sample = {
            executor.submit(
                process_single_sample,
                client,
                data,
                output_dir,
                model,
                structured,
                timeout,
            ): data[0]
            for data in sample_data
        }
        # Collect results with a progress bar
        for future in tqdm(
            as_completed(future_to_sample),
            total=len(sample_data),
            desc="Processing samples (batch)",
        ):
            sample_idx = future_to_sample[future]
            try:
                result = future.result()
                results.append(result)
            except Exception as e:
                logger.error(f"Exception in sample {sample_idx}: {e}")
                # Add an error result so every sample is accounted for
                results.append(
                    {
                        "sample_id": sample_idx,
                        "prediction": {},
                        "ground_truth": {},
                        "normalized_prediction": {},
                        "normalized_gt": {},
                        "raw_response": "",
                        "processing_time": 0.0,
                        "json_parsing_error": True,
                        "usage": {},
                        "error": str(e),
                    }
                )
    # Sort results by sample_id, since as_completed yields out of order
    results.sort(key=lambda x: x["sample_id"])
    return results
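
# Design note: a thread pool fits here because each worker spends nearly all of
# its time blocked on the HTTP request, so Python's GIL is not a bottleneck for
# this I/O-bound workload; the shared OpenAI client is assumed to be safe to
# use across threads.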


def vllm_openai_sdk_sequential_evaluation(
    test_set,
    output_dir: str,
    server_url: str = "http://localhost:8001",
    api_key: str = "default-blank-localhost",
    model: str = "meta-llama/Llama-3.2-11B-Vision-Instruct",
    structured: bool = True,
    timeout: int = 300,
):
    """
    Evaluate the W2 extraction task using the OpenAI SDK sequentially (for debugging).
    """
    # Initialize the OpenAI client
    client = OpenAI(
        api_key=api_key,  # vLLM doesn't require a real API key
        base_url=server_url,
    )
    results = []
    for idx, sample in enumerate(
        tqdm(test_set, desc="Processing samples (sequential)")
    ):
        result = process_single_sample(
            client, (idx, sample), output_dir, model, structured, timeout
        )
        results.append(result)
    return results


def main():
    parser = argparse.ArgumentParser(
        description="Evaluate a vision-language model on the W2 tax form dataset"
    )
    parser.add_argument(
        "--server_url",
        type=str,
        default="http://localhost:8001",
        help="URL of the vLLM HTTP server",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="meta-llama/Llama-3.2-11B-Vision-Instruct",
        help="Model name to use for inference",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="singhsays/fake-w2-us-tax-form-dataset",
        help="Name of the Huggingface dataset",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./w2_evaluation_results",
        help="Directory to save evaluation results",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=10,
        help="Number of samples to evaluate (default: 10, use -1 for all)",
    )
    parser.add_argument(
        "--structured",
        action="store_true",
        default=False,
        help="Whether to use structured output (JSON schema)",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=300,
        help="Timeout for SDK requests in seconds",
    )
    parser.add_argument(
        "--max_workers",
        type=int,
        default=10,
        help="Maximum number of concurrent workers for batch processing",
    )
    parser.add_argument(
        "--sequential",
        action="store_true",
        default=False,
        help="Process samples sequentially instead of in parallel (for debugging)",
    )
    args = parser.parse_args()

    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load the dataset: either a local save_to_disk directory (detected via its
    # state.json) or a dataset name on the Hugging Face Hub
    logger.info(f"Loading dataset: {args.dataset_name}")
    test_set = None
    if Path(args.dataset_name, "state.json").exists():
        test_set = load_from_disk(args.dataset_name)
    else:
        dataset = load_dataset(args.dataset_name)
        if "test" not in dataset:
            logger.error("Dataset does not have a test split")
            return 1
        test_set = dataset["test"]
    logger.info(f"Loaded test set with {len(test_set)} samples")

    # Limit the number of samples if requested
    if args.limit > 0 and args.limit < len(test_set):
        test_set = test_set.select(range(args.limit))
        logger.info(f"Limited to {args.limit} samples")

    # Get the API key from the environment
    api_key = os.getenv("TOGETHER_API_KEY") or os.getenv("OPENAI_API_KEY")
    if not api_key:
        logger.warning(
            "No API key found. Please set the TOGETHER_API_KEY or OPENAI_API_KEY "
            "environment variable for public APIs."
        )
        api_key = "default-blank-localhost"

    # Test the server connection
    try:
        client = OpenAI(
            api_key=api_key,
            base_url=args.server_url,
        )
        # Test with a simple call
        # models = client.models.list()
        logger.info(f"Successfully connected to server at {args.server_url}")
        # logger.info(f"Available models: {[model.id for model in models.data]}")
    except Exception as e:
        logger.error(f"Failed to connect to server at {args.server_url}: {e}")
        logger.error("Make sure the server is running and accessible")
        return 1

    # Run the evaluation
    if args.sequential:
        logger.info("Running sequential evaluation...")
        results = vllm_openai_sdk_sequential_evaluation(
            test_set=test_set,
            output_dir=args.output_dir,
            server_url=args.server_url,
            api_key=api_key,
            model=args.model,
            structured=args.structured,
            timeout=args.timeout,
        )
    else:
        logger.info(f"Running batch evaluation with {args.max_workers} workers...")
        results = vllm_openai_sdk_evaluation(
            test_set=test_set,
            output_dir=args.output_dir,
            server_url=args.server_url,
            api_key=api_key,
            model=args.model,
            structured=args.structured,
            timeout=args.timeout,
            max_workers=args.max_workers,
        )

    # Save detailed results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    try:
        results_file = os.path.join(args.output_dir, f"results_{timestamp}.json")
        with open(results_file, "w") as f:
            json.dump(results, f, indent=2)
        logger.info(f"Detailed results saved to {results_file}")
    except Exception as e:
        logger.error(f"Error saving detailed results: {e}")
        return 1

    # Calculate metrics
    metrics = calculate_metrics(results)

    # Save the evaluation summary
    output_file = os.path.join(args.output_dir, f"evaluation_results_{timestamp}.json")
    arguments = {
        "server_url": args.server_url,
        "model": args.model,
        "output_dir": args.output_dir,
        "dataset_name": args.dataset_name,
        "limit": args.limit,
        "structured": args.structured,
        "timeout": args.timeout,
        "max_workers": args.max_workers,
        "sequential": args.sequential,
        "prompt": generate_prompt(args.structured),
    }
    summary = {
        "arguments": arguments,
        "metrics": metrics,
        "timestamp": timestamp,
        "total_samples": len(results),
    }
    with open(output_file, "w") as f:
        json.dump(summary, f, indent=2)

    # Print the summary
    logger.info("=" * 50)
    logger.info("EVALUATION SUMMARY")
    logger.info("=" * 50)
    logger.info(f"Overall accuracy: {metrics['accuracy']:.4f}")
    logger.info(f"Parse error rate: {metrics['parse_error']:.4f}")
    logger.info("Field-level accuracy:")
    field_accuracy = metrics["field_accuracy"]
    for field, acc in sorted(field_accuracy.items(), key=lambda x: x[1], reverse=True):
        logger.info(f"  {field}: {acc:.4f}")
    logger.info(f"Results saved to {output_file}")

    # Clean up the temp image directory if it exists
    temp_dir = os.path.join(args.output_dir, "temp")
    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
    return 0


if __name__ == "__main__":
    exit(main())