#!/usr/bin/env python3 """ Script to evaluate a vision-language model on the W2 tax form dataset using compatible API client. Leverages the OpenAI-compatible SDK for various endpoints, like vLLM server, Llama API, or any compatible API. Support batch processing. Loads images from the provided dataset, sends them to the compatible API server, and compares with the expected output. """ import argparse import base64 import json import logging import os import pathlib import re import time import traceback from concurrent.futures import as_completed, ThreadPoolExecutor from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Optional, Tuple from datasets import load_dataset, load_from_disk from openai import OpenAI from PIL import Image from pydantic import BaseModel from tqdm import tqdm # Set up logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) class W2Form(BaseModel): box_b_employer_identification_number: str box_c_employer_name: str box_c_employer_street_address: str box_c_employer_city_state_zip: str box_a_employee_ssn: str box_e_employee_name: str box_e_employee_street_address: str box_e_employee_city_state_zip: str box_d_control_number: int box_1_wages: float box_2_federal_tax_withheld: float box_3_social_security_wages: float box_4_social_security_tax_withheld: float box_5_medicare_wages: float box_6_medicare_wages_tax_withheld: float box_7_social_security_tips: float box_8_allocated_tips: float box_9_advance_eic_payment: Optional[str] box_10_dependent_care_benefits: float box_11_nonqualified_plans: float box_12a_code: str box_12a_value: float box_12b_code: str box_12b_value: float box_12c_code: str box_12c_value: float box_12d_code: Optional[str] box_12d_value: float box_13_statutary_employee: Optional[str] box_13_retirement_plan: Optional[str] box_13_third_part_sick_pay: Optional[str] box_15_1_state: str box_15_1_employee_state_id: str box_16_1_state_wages: float box_17_1_state_income_tax: float box_18_1_local_wages: float box_19_1_local_income_tax: float box_20_1_locality: str box_15_2_state: str box_15_2_employee_state_id: str box_16_2_state_wages: float box_17_2_state_income_tax: float box_18_2_local_wages: float box_19_2_local_income_tax: float box_20_2_locality: str # ----------- Utilities ----------- def encode_image_to_base64(image_path: str) -> str: """Encode image to base64 string.""" with open(image_path, "rb") as f: return base64.b64encode(f.read()).decode() def create_messages(prompt: str, image_path: str) -> List[Dict]: """Create messages array for API client call.""" content = [ {"type": "text", "text": prompt}, { "type": "image_url", "image_url": { "url": f"data:image/png;base64,{encode_image_to_base64(image_path)}" }, }, ] return [{"role": "user", "content": content}] def clean_json_string(json_str: str) -> str: """ Clean common JSON formatting issues from LLM responses. Args: json_str: Raw JSON string that may contain formatting issues Returns: Cleaned JSON string """ # Remove markdown code block markers json_str = re.sub(r"```(?:json)?\s*", "", json_str) json_str = re.sub(r"\s*```", "", json_str) # Fix malformed string patterns like: "field": ",\n" , # This handles the specific error case where strings are malformed with newlines json_str = re.sub(r':\s*",\s*"\s*,', ': "",', json_str) # Fix incomplete string literals with control characters # Pattern: "field": "partial_value\nrest_of_value", json_str = re.sub(r':\s*"([^"]*)\n([^"]*)",', r': "\1\2",', json_str) # Fix the specific pattern from the error: "field": "value\n" followed by whitespace and comma json_str = re.sub(r':\s*"([^"]*)\n"\s*,', r': "\1",', json_str) # Remove trailing commas in objects and arrays json_str = re.sub(r",(\s*[}\]])", r"\1", json_str) # Fix missing quotes around keys (sometimes LLMs output unquoted keys) json_str = re.sub(r"([{,]\s*)([a-zA-Z_][a-zA-Z0-9_]*)\s*:", r'\1"\2":', json_str) # Fix single quotes to double quotes (JSON requires double quotes) json_str = re.sub(r"'([^']*)'", r'"\1"', json_str) # Remove control characters that are not allowed in JSON strings # Keep only printable ASCII and basic whitespace json_str = "".join(char for char in json_str if ord(char) >= 32 or char in "\t\r ") # Fix null-like values that should be proper JSON null json_str = re.sub(r":\s*None\s*,", ": null,", json_str, flags=re.IGNORECASE) json_str = re.sub(r":\s*undefined\s*,", ": null,", json_str, flags=re.IGNORECASE) return json_str def extract_json_from_response(response: str) -> Tuple[Dict[str, Any], bool]: """ Robust JSON extraction from LLM responses with comprehensive error handling. Args: response: Raw response text from LLM Returns: Tuple of (extracted_json_dict, has_error) """ if not response or not response.strip(): logger.warning("Empty response provided") return {}, True # Strategy 1: Look for JSON content between triple backticks json_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", response, re.DOTALL) if json_match: json_str = json_match.group(1) else: # Strategy 2: Look for JSON object pattern (handle nested braces) json_match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", response, re.DOTALL) if json_match: json_str = json_match.group(0) else: # Strategy 3: Find content between first { and last } start_idx = response.find("{") end_idx = response.rfind("}") if start_idx != -1 and end_idx != -1 and end_idx > start_idx: json_str = response[start_idx : end_idx + 1] else: logger.warning("No JSON pattern found in response") logger.debug(f"Response snippet: {response[:200]}...") return {}, True # Clean the extracted JSON string original_json_str = json_str json_str = clean_json_string(json_str) # Attempt to parse with multiple strategies parsing_strategies = [ ("direct", lambda s: json.loads(s)), ("strip_whitespace", lambda s: json.loads(s.strip())), ( "fix_escapes", lambda s: json.loads(s.replace("\\\\", "\\").replace('\\"', '"')), ), ] for strategy_name, parse_func in parsing_strategies: try: parsed_json = parse_func(json_str) # Validate that it's a dictionary (expected for most use cases) if not isinstance(parsed_json, dict): logger.warning( f"Extracted JSON is not a dictionary: {type(parsed_json)}" ) continue logger.debug(f"Successfully parsed JSON using strategy: {strategy_name}") return parsed_json, False except json.JSONDecodeError as e: logger.debug(f"Strategy '{strategy_name}' failed: {e}") continue except Exception as e: logger.debug(f"Unexpected error in strategy '{strategy_name}': {e}") continue # If all strategies fail, log details for debugging logger.error("All JSON parsing strategies failed") logger.debug(f"Original JSON string (first 500 chars): {original_json_str[:500]}") logger.debug(f"Cleaned JSON string (first 500 chars): {json_str[:500]}") return {}, True def generate_prompt(structured=True) -> str: """Generate prompt for the model.""" json_schema = W2Form.model_json_schema() prompt = ( "You are an expert document information extraction system. " "I will show you an image of a W-2 tax form. " "Please extract all the information from this form and return it in a JSON format. " "Include all fields such as employee details, employer details, wages, federal income tax withheld, " "social security wages, social security tax withheld, medicare wages and tips, medicare tax withheld, " "and any other information present on the form. " ) if not structured: prompt += f"Return ONLY the JSON output without any additional text or explanations following this schema {json_schema}" return prompt def call_api_client( client: OpenAI, messages: List[Dict], model: str = "meta-llama/Llama-3.2-11B-Vision-Instruct", temperature: float = 0.0, max_tokens: int = 8192, response_format: Optional[Dict] = None, timeout: int = 300, seed: Optional[int] = 42, ): """ Call compatible API server using OpenAI-compatible client. """ try: kwargs = { "model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "timeout": timeout, } # Add seed if provided for reproducible generation if seed is not None: kwargs["seed"] = seed # Add response format if structured output is enabled if response_format: kwargs["response_format"] = response_format logger.debug(f"Making API client call with model: {model}") response = client.chat.completions.create(**kwargs) logger.debug(f"Received response with {len(response.choices)} choices") return response except Exception as e: logger.error(f"API client call failed: {e}") raise def process_single_sample( client: OpenAI, sample_data: Tuple[int, Dict], output_dir: str, model: str, structured: bool, timeout: int, ) -> Dict[str, Any]: """Process a single sample using OpenAI SDK.""" idx, sample = sample_data try: # Get image image = sample["image"] # Save image temporarily image_path = get_image_path(image, output_dir, idx) logger.debug(f"Saved image to {image_path}") # Generate prompt and messages prompt = generate_prompt(structured) messages = create_messages(prompt, image_path) # Prepare response format for structured output response_format = None if structured: json_schema = W2Form.model_json_schema() response_format = { "type": "json_schema", "json_schema": { "name": "W2Form", "schema": json_schema, "strict": True, }, } # Call API client start_time = time.time() try: response = call_api_client( client=client, messages=messages, model=model, response_format=response_format, timeout=timeout, ) content = response.choices[0].message.content usage = response.usage.model_dump() if response.usage else {} except Exception as e: logger.error(f"Error calling OpenAI SDK for sample {idx}: {e}") content = "" usage = {} processing_time = time.time() - start_time # Extract JSON from response extracted_json, json_parsing_error = extract_json_from_response(content) # Get ground truth ground_truth_raw = json.loads(sample["ground_truth"]) # Handle the gt_parse wrapper structure if present if "gt_parse" in ground_truth_raw: ground_truth = ground_truth_raw["gt_parse"] else: ground_truth = ground_truth_raw # Normalize for comparison normalized_pred = normalize_json(extracted_json) normalized_gt = normalize_json(ground_truth) # Save results result = { "sample_id": idx, "prediction": extracted_json, "ground_truth": ground_truth, "normalized_prediction": normalized_pred, "normalized_gt": normalized_gt, "raw_response": content, "processing_time": processing_time, "json_parsing_error": json_parsing_error, "usage": usage, } return result except Exception as e: traceback_str = traceback.format_exc() logger.error(f"Error processing sample {idx}: {str(e)} at line {traceback_str}") return { "sample_id": idx, "prediction": {}, "ground_truth": {}, "normalized_prediction": {}, "normalized_gt": {}, "raw_response": "", "processing_time": 0.0, "json_parsing_error": True, "usage": {}, "error": str(e), } def calculate_metrics(results: List[Dict]) -> Dict[str, Any]: """Calculate accuracy metrics for the predictions.""" if not results: logger.error("No results provided") return {"accuracy": 0.0, "field_accuracy": {}} # Initialize metrics total_fields = 0 correct_fields = 0 parse_errors = 0 total_records = len(results) logger.info(f"Total records: {total_records}") field_counts = {} field_correct = {} for result in results: pred, gt = result["prediction"], result["ground_truth"] if result["json_parsing_error"]: parse_errors += 1 total_fields += len(gt) continue for field in gt.keys(): # Count total occurrences of this field field_counts[field] = field_counts.get(field, 0) + 1 total_fields += 1 # Check if field is correct if field in pred and pred[field] == gt[field]: correct_fields += 1 field_correct[field] = field_correct.get(field, 0) + 1 # Calculate overall accuracy accuracy = correct_fields / total_fields if total_fields > 0 else 0.0 errors = parse_errors / total_records if total_records > 0 else 0.0 # Calculate per-field accuracy field_accuracy = {} for field in field_counts: field_accuracy[field] = field_correct.get(field, 0) / field_counts[field] return { "accuracy": accuracy, "field_accuracy": field_accuracy, "parse_error": errors, } def normalize_field_value(value: Any) -> str: """Normalize field values for comparison.""" if value is None: return "" # Convert to string and normalize value_str = str(value).strip().lower() # Remove common separators in numbers value_str = value_str.replace(",", "").replace(" ", "") # Try to convert to float for numeric comparison try: value_float = float(value_str) return str(value_float) except ValueError: return value_str def normalize_json(json_obj: Dict) -> Dict: """Normalize JSON object for comparison.""" normalized = {} for key, value in json_obj.items(): # Normalize key (lowercase, remove spaces) norm_key = key.lower().replace(" ", "_") # Normalize value if isinstance(value, dict): normalized[norm_key] = normalize_json(value) elif isinstance(value, list): normalized[norm_key] = [normalize_field_value(v) for v in value] else: normalized[norm_key] = normalize_field_value(value) return normalized def get_image_path(image: Image.Image, output_dir: str, idx: int) -> str: """Get the path to save the image.""" # Create a temporary file for the image temp_dir = pathlib.Path(output_dir) / "temp" os.makedirs(temp_dir, exist_ok=True) image_path = temp_dir / f"temp_{idx}.png" image_path = str(image_path.resolve()) image.save(image_path) return image_path def vllm_openai_sdk_evaluation( test_set, output_dir: str, server_url: str = "http://localhost:8001", api_key: str = "default-blank-localhost", model: str = "meta-llama/Llama-3.2-11B-Vision-Instruct", structured: bool = True, timeout: int = 300, max_workers: int = 10, ): """ Evaluate the W2 extraction task using OpenAI SDK with batch processing. """ # Initialize OpenAI client client = OpenAI( api_key=api_key, # vLLM doesn't require a real API key base_url=f"{server_url}/v1", ) # Prepare sample data for batch processing sample_data = [(idx, sample) for idx, sample in enumerate(test_set)] results = [] # Use ThreadPoolExecutor for concurrent processing with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all tasks future_to_sample = { executor.submit( process_single_sample, client, data, output_dir, model, structured, timeout, ): data[0] for data in sample_data } # Collect results with progress bar for future in tqdm( as_completed(future_to_sample), total=len(sample_data), desc="Processing samples with OpenAI SDK (batch)", ): sample_idx = future_to_sample[future] try: result = future.result() results.append(result) except Exception as e: logger.error(f"Exception in sample {sample_idx}: {e}") # Add error result results.append( { "sample_id": sample_idx, "prediction": {}, "ground_truth": {}, "normalized_prediction": {}, "normalized_gt": {}, "raw_response": "", "processing_time": 0.0, "json_parsing_error": True, "usage": {}, "error": str(e), } ) # Sort results by sample_id to maintain order results.sort(key=lambda x: x["sample_id"]) return results def vllm_openai_sdk_sequential_evaluation( test_set, output_dir: str, server_url: str = "http://localhost:8001", api_key: str = "default-blank-localhost", model: str = "meta-llama/Llama-3.2-11B-Vision-Instruct", structured: bool = True, timeout: int = 300, ): """ Evaluate the W2 extraction task using OpenAI SDK sequentially (for debugging). """ # Initialize OpenAI client client = OpenAI( api_key=api_key, # vLLM doesn't require a real API key base_url=f"{server_url}/v1", ) results = [] for idx, sample in enumerate( tqdm(test_set, desc="Processing samples with OpenAI SDK (sequential)") ): result = process_single_sample( client, (idx, sample), output_dir, model, structured, timeout ) results.append(result) return results def main(): parser = argparse.ArgumentParser( description="Evaluate vision-language model on W2 tax form dataset using OpenAI SDK" ) parser.add_argument( "--server_url", type=str, default="http://localhost:8001", help="URL of the vLLM HTTP server", ) parser.add_argument( "--model", type=str, default="meta-llama/Llama-3.2-11B-Vision-Instruct", help="Model name to use for inference", ) parser.add_argument( "--dataset_name", type=str, default="singhsays/fake-w2-us-tax-form-dataset", help="Name of the Huggingface dataset", ) parser.add_argument( "--output_dir", type=str, default="./w2_evaluation_results", help="Directory to save evaluation results", ) parser.add_argument( "--limit", type=int, default=10, help="Number of samples to evaluate (default: 10, use -1 for all)", ) parser.add_argument( "--structured", action="store_true", default=False, help="Whether to use structured output (JSON schema)", ) parser.add_argument( "--timeout", type=int, default=300, help="Timeout for SDK requests in seconds", ) parser.add_argument( "--max_workers", type=int, default=10, help="Maximum number of concurrent workers for batch processing", ) parser.add_argument( "--sequential", action="store_true", default=False, help="Process samples sequentially instead of in parallel (for debugging)", ) args = parser.parse_args() # Create output directory os.makedirs(args.output_dir, exist_ok=True) # Load dataset logger.info(f"Loading dataset: {args.dataset_name}") test_set = None if Path(args.dataset_name, "state.json").exists(): test_set = load_from_disk(args.dataset_name) else: dataset = load_dataset(args.dataset_name) if "test" not in dataset: logger.error("Dataset does not have a test split") return 1 test_set = dataset["test"] logger.info(f"Loaded test set with {len(test_set)} samples") # Limit number of samples if specified if args.limit > 0 and args.limit < len(test_set): test_set = test_set.select(range(args.limit)) logger.info(f"Limited to {args.limit} samples") # Get API key from environment variable api_key = os.getenv("LLAMA_API_KEY") or os.getenv("OPENAI_API_KEY") if not api_key: logger.warning( "No API key found. Please set the LLAMA_API_KEY or OPENAI_API_KEY environment variable for public APIs." ) api_key = "default-blank-localhost" # Test server connection try: client = OpenAI( api_key=api_key, base_url=f"{args.server_url}/v1", ) # Test with a simple call models = client.models.list() logger.info(f"Successfully connected to vLLM server at {args.server_url}") logger.info(f"Available models: {[model.id for model in models.data]}") except Exception as e: logger.error(f"Failed to connect to vLLM server at {args.server_url}: {e}") logger.error("Make sure the vLLM server is running and accessible") return 1 # Run evaluation if args.sequential: logger.info("Running sequential evaluation...") results = vllm_openai_sdk_sequential_evaluation( test_set=test_set, output_dir=args.output_dir, server_url=args.server_url, api_key=api_key, model=args.model, structured=args.structured, timeout=args.timeout, ) else: logger.info(f"Running batch evaluation with {args.max_workers} workers...") results = vllm_openai_sdk_evaluation( test_set=test_set, output_dir=args.output_dir, server_url=args.server_url, api_key=api_key, model=args.model, structured=args.structured, timeout=args.timeout, max_workers=args.max_workers, ) # Save detailed results timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") try: results_file = os.path.join(args.output_dir, f"results_{timestamp}.json") with open(results_file, "w") as f: json.dump(results, f, indent=2) logger.info(f"Detailed results saved to {results_file}") except Exception as e: logger.error(f"Error saving detailed results: {str(e)}") return 1 # Calculate metrics metrics = calculate_metrics(results) # Save evaluation summary output_file = os.path.join(args.output_dir, f"evaluation_results_{timestamp}.json") arguments = { "server_url": args.server_url, "model": args.model, "output_dir": args.output_dir, "dataset_name": args.dataset_name, "limit": args.limit, "structured": args.structured, "timeout": args.timeout, "max_workers": args.max_workers, "sequential": args.sequential, "prompt": generate_prompt(args.structured), } summary = { "arguments": arguments, "metrics": metrics, "timestamp": timestamp, "total_samples": len(results), } with open(output_file, "w") as f: json.dump(summary, f, indent=2) # Print summary logger.info("=" * 50) logger.info("EVALUATION SUMMARY") logger.info("=" * 50) logger.info(f"Overall accuracy: {metrics['accuracy']:.4f}") logger.info(f"Parse error rate: {metrics['parse_error']:.4f}") logger.info("Field-level accuracy:") field_accuracy = metrics["field_accuracy"] for field, acc in sorted(field_accuracy.items(), key=lambda x: x[1], reverse=True): logger.info(f" {field}: {acc:.4f}") logger.info(f"Results saved to {output_file}") # Clean up temp directory if it exists temp_dir = os.path.join(args.output_dir, "temp") if os.path.exists(temp_dir): import shutil shutil.rmtree(temp_dir) return 0 if __name__ == "__main__": exit(main())