# paddle_ocr_tuning_rest.py
# FastAPI REST service for PaddleOCR hyperparameter evaluation
# Usage: uvicorn paddle_ocr_tuning_rest:app --host 0.0.0.0 --port 8000
import os
import re
import time
import threading
from typing import Optional
from contextlib import asynccontextmanager

import numpy as np
import paddle
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from paddleocr import PaddleOCR
from jiwer import wer, cer

from dataset_manager import ImageTextDataset


def get_gpu_info() -> dict:
    """Return GPU status information reported by PaddlePaddle.

    Returns:
        dict with keys ``cuda_available``, ``device``, ``gpu_count``,
        ``gpu_name``, ``gpu_memory_total``, ``gpu_memory_used``; when CUDA
        is available it additionally carries ``gpu_memory_reserved``, or
        ``gpu_error`` if probing the device raised.
    """
    info = {
        "cuda_available": paddle.device.is_compiled_with_cuda(),
        "device": str(paddle.device.get_device()),
        "gpu_count": 0,
        "gpu_name": None,
        "gpu_memory_total": None,
        "gpu_memory_used": None,
    }
    if info["cuda_available"]:
        try:
            info["gpu_count"] = paddle.device.cuda.device_count()
            if info["gpu_count"] > 0:
                # Get GPU properties of device 0 (single-GPU assumption).
                props = paddle.device.cuda.get_device_properties(0)
                info["gpu_name"] = props.name
                info["gpu_memory_total"] = f"{props.total_memory / (1024**3):.2f} GB"
                # Current memory usage of device 0.
                mem_reserved = paddle.device.cuda.memory_reserved(0)
                mem_allocated = paddle.device.cuda.memory_allocated(0)
                info["gpu_memory_used"] = f"{mem_allocated / (1024**3):.2f} GB"
                info["gpu_memory_reserved"] = f"{mem_reserved / (1024**3):.2f} GB"
        except Exception as e:
            # Best-effort probe: report the failure instead of breaking
            # health checks / startup logging.
            info["gpu_error"] = str(e)
    return info


# Model configuration via environment variables (with defaults)
DEFAULT_DET_MODEL = os.environ.get("PADDLE_DET_MODEL", "PP-OCRv5_server_det")
DEFAULT_REC_MODEL = os.environ.get("PADDLE_REC_MODEL", "PP-OCRv5_server_rec")


class AppState:
    """Process-wide mutable state shared by all request handlers.

    A single instance (``state`` below) holds the loaded OCR pipeline and
    the currently loaded dataset; the lifespan hook populates ``ocr``.
    """

    ocr: Optional[PaddleOCR] = None              # set by the lifespan hook
    dataset: Optional[ImageTextDataset] = None   # lazily (re)loaded per request folder
    dataset_path: Optional[str] = None           # folder the dataset was loaded from
    det_model: str = DEFAULT_DET_MODEL
    rec_model: str = DEFAULT_REC_MODEL
    # Protects the OCR model from concurrent access (model is not thread-safe).
    # Annotated Optional because the class-level default is None; every
    # instance gets a real Lock in __init__.
    lock: Optional[threading.Lock] = None

    def __init__(self):
        self.lock = threading.Lock()


state = AppState()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the OCR model at startup; drop references at shutdown."""
    # Log GPU status before the (potentially large) model allocation.
    gpu_info = get_gpu_info()
    print("=" * 50)
    print("GPU STATUS")
    print("=" * 50)
    print(f" CUDA available: {gpu_info['cuda_available']}")
    print(f" Device: {gpu_info['device']}")
    if gpu_info['cuda_available']:
        print(f" GPU count: {gpu_info['gpu_count']}")
        print(f" GPU name: {gpu_info['gpu_name']}")
        print(f" GPU memory total: {gpu_info['gpu_memory_total']}")
    print("=" * 50)
    print("Loading PaddleOCR models...")
    print(f" Detection: {state.det_model}")
    print(f" Recognition: {state.rec_model}")
    state.ocr = PaddleOCR(
        text_detection_model_name=state.det_model,
        text_recognition_model_name=state.rec_model,
    )
    # Log GPU memory again so the delta shows what the weights consumed.
    if gpu_info['cuda_available']:
        gpu_after = get_gpu_info()
        print(f" GPU memory after load: {gpu_after.get('gpu_memory_used', 'N/A')}")
    print("Model loaded successfully!")
    yield
    # Cleanup on shutdown: release the model and dataset references.
    state.ocr = None
    state.dataset = None


app = FastAPI(
    title="PaddleOCR Tuning API",
    description="REST API for OCR hyperparameter evaluation",
    version="1.0.0",
    lifespan=lifespan,
)


class EvaluateRequest(BaseModel):
    """Request schema matching CLI arguments."""

    pdf_folder: str = Field("/app/dataset", description="Path to dataset folder")
    use_doc_orientation_classify: bool = Field(False, description="Use document orientation classification")
    use_doc_unwarping: bool = Field(False, description="Use document unwarping")
    textline_orientation: bool = Field(True, description="Use textline orientation classification")
    text_det_thresh: float = Field(0.0, ge=0.0, le=1.0, description="Detection pixel threshold")
    text_det_box_thresh: float = Field(0.0, ge=0.0, le=1.0, description="Detection box threshold")
    text_det_unclip_ratio: float = Field(1.5, ge=0.0, description="Text detection expansion coefficient")
    text_rec_score_thresh: float = Field(0.0, ge=0.0, le=1.0, description="Recognition score threshold")
    start_page: int = Field(5, ge=0, description="Start page index (inclusive)")
    end_page: int = Field(10, ge=1, description="End page index (exclusive)")
    save_output: bool = Field(False, description="Save OCR predictions to debugset folder")


class EvaluateResponse(BaseModel):
    """Response schema matching CLI output."""

    CER: float
    WER: float
    TIME: float
    PAGES: int
    TIME_PER_PAGE: float


class HealthResponse(BaseModel):
    """Service readiness plus model/dataset/GPU diagnostics."""

    status: str
    model_loaded: bool
    dataset_loaded: bool
    dataset_size: Optional[int] = None
    det_model: Optional[str] = None
    rec_model: Optional[str] = None
    # GPU info
    cuda_available: Optional[bool] = None
    device: Optional[str] = None
    gpu_name: Optional[str] = None
    gpu_memory_used: Optional[str] = None
    gpu_memory_total: Optional[str] = None


def _normalize_box_xyxy(box):
    """Normalize a bounding box to (x0, y0, x1, y1).

    Accepts a polygon as a sequence of (x, y) points, a flat 4-value pair
    of corners, or a flat 8-value quadrilateral.

    Raises:
        ValueError: if the box matches none of the known layouts.
    """
    # Polygon given as nested (x, y) points.
    if isinstance(box, (list, tuple)) and box and isinstance(box[0], (list, tuple)):
        xs = [p[0] for p in box]
        ys = [p[1] for p in box]
        return min(xs), min(ys), max(xs), max(ys)
    if isinstance(box, (list, tuple)):
        if len(box) == 4:
            # Two corners whose ordering is not guaranteed.
            x0, y0, x1, y1 = box
            return min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)
        if len(box) == 8:
            # Flat quadrilateral: x at even indices, y at odd.
            xs = box[0::2]
            ys = box[1::2]
            return min(xs), min(ys), max(xs), max(ys)
    raise ValueError(f"Unrecognized box format: {box!r}")


def assemble_from_paddle_result(paddleocr_predict, min_score=0.0, line_tol_factor=0.6):
    """
    Robust line grouping for PaddleOCR outputs.
    Normalizes boxes, groups by line, and returns assembled text.

    Args:
        paddleocr_predict: iterable of PaddleOCR result items; each item's
            ``.json["res"]`` is expected to carry ``rec_boxes``,
            ``rec_texts`` and optionally ``rec_scores``.
        min_score: drop detections scoring below this threshold (0 = keep all).
        line_tol_factor: fraction of the median box height used as the
            vertical tolerance when grouping boxes into the same line.

    Returns:
        Assembled text, one grouped line per '\\n'; "" when nothing usable.
    """
    boxes_all = []
    for item in paddleocr_predict:
        res = item.json.get("res", {})
        boxes = res.get("rec_boxes", []) or []
        texts = res.get("rec_texts", []) or []
        scores = res.get("rec_scores", None)
        for i, (box, text) in enumerate(zip(boxes, texts)):
            try:
                x0, y0, x1, y1 = _normalize_box_xyxy(box)
            except Exception:
                # Skip malformed boxes rather than failing the whole page.
                continue
            y_mid = 0.5 * (y0 + y1)
            # Missing/short score lists default to full confidence.
            score = float(scores[i]) if (scores is not None and i < len(scores)) else 1.0
            t = re.sub(r"\s+", " ", str(text)).strip()
            if not t:
                continue
            boxes_all.append((x0, y0, x1, y1, y_mid, t, score))
    if min_score > 0:
        boxes_all = [b for b in boxes_all if b[6] >= min_score]
    if not boxes_all:
        return ""
    # Adaptive line tolerance derived from the median box height.
    heights = [b[3] - b[1] for b in boxes_all]
    median_h = float(np.median(heights)) if heights else 20.0
    line_tol = max(8.0, line_tol_factor * median_h)
    # Sort by vertical mid, then x0
    boxes_all.sort(key=lambda b: (b[4], b[0]))
    # Group into lines: boxes whose vertical centers fall within line_tol
    # of the previous box belong to the same line.
    lines, cur, last_y = [], [], None
    for x0, y0, x1, y1, y_mid, text, score in boxes_all:
        if last_y is None or abs(y_mid - last_y) <= line_tol:
            cur.append((x0, text))
        else:
            cur.sort(key=lambda t: t[0])
            lines.append(" ".join(t[1] for t in cur))
            cur = [(x0, text)]
        last_y = y_mid
    if cur:
        cur.sort(key=lambda t: t[0])
        lines.append(" ".join(t[1] for t in cur))
    res = "\n".join(lines)
    res = re.sub(r"\s+\n", "\n", res).strip()
    return res


def evaluate_text(reference: str, prediction: str) -> dict:
    """Calculate WER and CER metrics via jiwer."""
    return {"WER": wer(reference, prediction), "CER": cer(reference, prediction)}


@app.get("/health", response_model=HealthResponse)
def health_check():
    """Check if the service is ready."""
    gpu_info = get_gpu_info()
    return HealthResponse(
        status="ok" if state.ocr is not None else "initializing",
        model_loaded=state.ocr is not None,
        dataset_loaded=state.dataset is not None,
        # Use an explicit None check (not truthiness) so a loaded-but-empty
        # dataset reports size 0 instead of None.
        dataset_size=len(state.dataset) if state.dataset is not None else None,
        det_model=state.det_model,
        rec_model=state.rec_model,
        cuda_available=gpu_info.get("cuda_available"),
        device=gpu_info.get("device"),
        gpu_name=gpu_info.get("gpu_name"),
        gpu_memory_used=gpu_info.get("gpu_memory_used"),
        gpu_memory_total=gpu_info.get("gpu_memory_total"),
    )


@app.post("/evaluate", response_model=EvaluateResponse)
def evaluate(request: EvaluateRequest):
    """
    Evaluate OCR with given hyperparameters.
    Returns CER, WER, and timing metrics.

    Raises:
        HTTPException 503: model not loaded yet.
        HTTPException 400: missing/empty dataset folder or bad page range.
    """
    if state.ocr is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    # Load or reload dataset if path changed
    if state.dataset is None or state.dataset_path != request.pdf_folder:
        if not os.path.isdir(request.pdf_folder):
            raise HTTPException(status_code=400, detail=f"Dataset folder not found: {request.pdf_folder}")
        state.dataset = ImageTextDataset(request.pdf_folder)
        state.dataset_path = request.pdf_folder
    if len(state.dataset) == 0:
        raise HTTPException(status_code=400, detail="Dataset is empty")
    # Validate page range; end is clamped to the dataset size.
    start = request.start_page
    end = min(request.end_page, len(state.dataset))
    if start >= end:
        raise HTTPException(status_code=400, detail=f"Invalid page range: {start}-{end}")
    cer_list, wer_list = [], []
    time_per_page_list = []
    t0 = time.time()
    # Lock to prevent concurrent OCR access (model is not thread-safe)
    with state.lock:
        for idx in range(start, end):
            img, ref = state.dataset[idx]
            arr = np.array(img)
            tp0 = time.time()
            out = state.ocr.predict(
                arr,
                use_doc_orientation_classify=request.use_doc_orientation_classify,
                use_doc_unwarping=request.use_doc_unwarping,
                use_textline_orientation=request.textline_orientation,
                text_det_thresh=request.text_det_thresh,
                text_det_box_thresh=request.text_det_box_thresh,
                text_det_unclip_ratio=request.text_det_unclip_ratio,
                text_rec_score_thresh=request.text_rec_score_thresh,
            )
            pred = assemble_from_paddle_result(out)
            # Per-page time covers OCR inference plus text assembly.
            time_per_page_list.append(float(time.time() - tp0))
            # Save prediction to debugset if requested
            if request.save_output:
                out_path = state.dataset.get_output_path(idx, "paddle_text")
                with open(out_path, "w", encoding="utf-8") as f:
                    f.write(pred)
            m = evaluate_text(ref, pred)
            cer_list.append(m["CER"])
            wer_list.append(m["WER"])
    return EvaluateResponse(
        CER=float(np.mean(cer_list)) if cer_list else 1.0,
        WER=float(np.mean(wer_list)) if wer_list else 1.0,
        TIME=float(time.time() - t0),
        PAGES=len(cer_list),
        TIME_PER_PAGE=float(np.mean(time_per_page_list)) if time_per_page_list else 0.0,
    )


@app.post("/evaluate_full", response_model=EvaluateResponse)
def evaluate_full(request: EvaluateRequest):
    """Evaluate on ALL pages (ignores start_page/end_page)."""
    request.start_page = 0
    # Sentinel far beyond any realistic page count; evaluate() clamps it to
    # len(dataset). (The previous 9999 silently truncated larger datasets.)
    request.end_page = 10**9
    return evaluate(request)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)