115 lines
3.8 KiB
Python
115 lines
3.8 KiB
Python
|
|
# test.py - Simple client to test PaddleOCR REST API
|
||
|
|
# Usage: python test.py [--url URL] [--dataset PATH] [--skip-health]
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import requests
|
||
|
|
import time
|
||
|
|
import sys
|
||
|
|
|
||
|
|
|
||
|
|
def wait_for_health(url: str, timeout: int = 120, poll_interval: float = 2.0) -> bool:
    """Poll the API's ``/health`` endpoint until it reports the model as loaded.

    Args:
        url: Base URL of the API (``/health`` is appended).
        timeout: Maximum number of seconds to keep polling.
        poll_interval: Seconds to sleep between polls (default 2.0).

    Returns:
        True as soon as ``/health`` answers HTTP 200 with a truthy
        ``model_loaded`` field in its JSON body; False if ``timeout``
        seconds elapse first (immediately when ``timeout <= 0``).
    """
    health_url = f"{url}/health"
    # monotonic() cannot jump backwards/forwards with wall-clock changes,
    # so the deadline and the printed durations stay correct.
    start = time.monotonic()

    print(f"Waiting for API at {health_url}...")
    while time.monotonic() - start < timeout:
        try:
            resp = requests.get(health_url, timeout=5)
            if resp.status_code == 200:
                data = resp.json()
                if data.get("model_loaded"):
                    print(f"API ready! Model loaded in {time.monotonic() - start:.1f}s")
                    return True
                print(f" Model loading... ({time.monotonic() - start:.0f}s)")
            else:
                # Surface unexpected statuses instead of polling silently.
                print(f" Unexpected status {resp.status_code} ({time.monotonic() - start:.0f}s)")
        except requests.exceptions.ConnectionError:
            # Server not accepting connections yet (e.g. container still starting).
            print(f" Connecting... ({time.monotonic() - start:.0f}s)")
        except Exception as e:
            # Deliberate best-effort: report anything else (read timeout,
            # malformed JSON, ...) and keep polling until the deadline.
            print(f" Error: {e}")

        # Never sleep past the deadline.
        remaining = timeout - (time.monotonic() - start)
        if remaining > 0:
            time.sleep(min(poll_interval, remaining))

    print("Timeout waiting for API")
    return False
|
||
|
|
|
||
|
|
|
||
|
|
def test_evaluate(url: str, config: dict, timeout: float = 600) -> dict:
    """POST ``config`` to the API's ``/evaluate`` endpoint and print its metrics.

    Args:
        url: Base URL of the API (``/evaluate`` is appended).
        config: JSON-serialisable evaluation parameters, sent as the request body.
        timeout: Seconds ``requests`` waits for the server (default 600).

    Returns:
        The decoded JSON response. It is expected to contain the keys
        ``CER``, ``WER``, ``PAGES`` and ``TIME_PER_PAGE`` (all printed here).

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        KeyError: if one of the expected metric keys is missing.
    """
    eval_url = f"{url}/evaluate"

    print(f"\nTesting config: {config}")
    # monotonic() is the right clock for measuring durations.
    start = time.monotonic()

    resp = requests.post(eval_url, json=config, timeout=timeout)
    resp.raise_for_status()

    result = resp.json()
    elapsed = time.monotonic() - start

    cer = result["CER"]
    wer = result["WER"]
    print(f"Results (took {elapsed:.1f}s):")
    print(f" CER: {cer:.4f} ({cer*100:.2f}%)")
    print(f" WER: {wer:.4f} ({wer*100:.2f}%)")
    print(f" Pages: {result['PAGES']}")
    print(f" Time/page: {result['TIME_PER_PAGE']:.2f}s")

    return result
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Compare a baseline and an optimized OCR configuration via the REST API.

    Parses CLI arguments, optionally waits for ``/health``, runs ``/evaluate``
    twice over the same page range (baseline vs. tuned parameters) and prints
    the relative change in character error rate (CER). Exits with status 1
    if the API never becomes ready or an HTTP request fails.
    """
    parser = argparse.ArgumentParser(description="Test PaddleOCR REST API")
    parser.add_argument("--url", default="http://localhost:8000", help="API base URL")
    parser.add_argument("--dataset", default="/app/dataset", help="Dataset path (inside container)")
    parser.add_argument("--skip-health", action="store_true", help="Skip health check wait")
    parser.add_argument("--start-page", type=int, default=5,
                        help="start_page value sent to /evaluate (default: 5)")
    parser.add_argument("--end-page", type=int, default=10,
                        help="end_page value sent to /evaluate (default: 10)")
    args = parser.parse_args()

    def banner(title):
        # Section header used between the test runs.
        print("\n" + "="*50)
        print(title)
        print("="*50)

    # Wait for API to be ready
    if not args.skip_health:
        if not wait_for_health(args.url):
            sys.exit(1)

    try:
        # Test 1: Baseline config (default PaddleOCR)
        banner("TEST 1: Baseline Configuration")
        baseline = test_evaluate(args.url, {
            "pdf_folder": args.dataset,
            "use_doc_orientation_classify": False,
            "use_doc_unwarping": False,
            "textline_orientation": False,  # Baseline: disabled
            "text_det_thresh": 0.0,
            "text_det_box_thresh": 0.0,
            "text_det_unclip_ratio": 1.5,
            "text_rec_score_thresh": 0.0,
            "start_page": args.start_page,
            "end_page": args.end_page,
        })

        # Test 2: Optimized config (from Ray Tune results)
        banner("TEST 2: Optimized Configuration")
        optimized = test_evaluate(args.url, {
            "pdf_folder": args.dataset,
            "use_doc_orientation_classify": False,
            "use_doc_unwarping": False,
            "textline_orientation": True,  # KEY: enabled
            "text_det_thresh": 0.4690,
            "text_det_box_thresh": 0.5412,
            "text_det_unclip_ratio": 0.0,
            "text_rec_score_thresh": 0.6350,
            "start_page": args.start_page,
            "end_page": args.end_page,
        })
    except requests.exceptions.RequestException as e:
        # HTTP errors (raise_for_status), connection loss, timeouts: report
        # them and exit non-zero instead of dumping a traceback.
        print(f"Request failed: {e}")
        sys.exit(1)

    # Summary
    banner("SUMMARY")
    base_cer = baseline["CER"]
    opt_cer = optimized["CER"]
    # Relative CER change in percent; guard against division by zero.
    cer_reduction = (1 - opt_cer / base_cer) * 100 if base_cer > 0 else 0
    print(f"Baseline CER: {base_cer*100:.2f}%")
    print(f"Optimized CER: {opt_cer*100:.2f}%")
    if cer_reduction >= 0:
        print(f"Improvement: {cer_reduction:.1f}% reduction in errors")
    else:
        # A worse optimized CER must not be reported as a "reduction".
        print(f"Regression: {-cer_reduction:.1f}% increase in errors")
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Run the baseline-vs-optimized comparison only when executed as a script,
    # not when this module is imported.
    main()
|