# doctr_tuning_rest.py
# FastAPI REST service for DocTR hyperparameter evaluation
# Usage: uvicorn doctr_tuning_rest:app --host 0.0.0.0 --port 8000

import os
import re
import time
from typing import Optional
from contextlib import asynccontextmanager

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from doctr.models import ocr_predictor
from jiwer import wer, cer

from dataset_manager import ImageTextDataset


def get_gpu_info() -> dict:
    """Get GPU status information from PyTorch."""
    info = {
        "cuda_available": torch.cuda.is_available(),
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "gpu_count": 0,
        "gpu_name": None,
        "gpu_memory_total": None,
        "gpu_memory_used": None,
    }
    if info["cuda_available"]:
        try:
            info["gpu_count"] = torch.cuda.device_count()
            if info["gpu_count"] > 0:
                info["gpu_name"] = torch.cuda.get_device_name(0)
                info["gpu_memory_total"] = f"{torch.cuda.get_device_properties(0).total_memory / (1024**3):.2f} GB"
                info["gpu_memory_used"] = f"{torch.cuda.memory_allocated(0) / (1024**3):.2f} GB"
        except Exception as e:
            info["gpu_error"] = str(e)
    return info


# Model configuration via environment variables
DEFAULT_DET_ARCH = os.environ.get("DOCTR_DET_ARCH", "db_resnet50")
DEFAULT_RECO_ARCH = os.environ.get("DOCTR_RECO_ARCH", "crnn_vgg16_bn")
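
# Example: switching architectures before startup (the names below are
# illustrative; any detection/recognition pair supported by the installed
# DocTR version should work):
#
#   export DOCTR_DET_ARCH=db_mobilenet_v3_large
#   export DOCTR_RECO_ARCH=crnn_mobilenet_v3_small
#   uvicorn doctr_tuning_rest:app --host 0.0.0.0 --port 8000
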
# Global state for model and dataset
class AppState:
    model: Optional[object] = None
    dataset: Optional[ImageTextDataset] = None
    dataset_path: Optional[str] = None
    det_arch: str = DEFAULT_DET_ARCH
    reco_arch: str = DEFAULT_RECO_ARCH
    # Track current model config for cache invalidation
    current_config: Optional[dict] = None
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


state = AppState()


def create_model(
    assume_straight_pages: bool = True,
    straighten_pages: bool = False,
    preserve_aspect_ratio: bool = True,
    symmetric_pad: bool = True,
    disable_page_orientation: bool = False,
    disable_crop_orientation: bool = False,
) -> object:
    """Create DocTR model with given configuration."""
    model = ocr_predictor(
        det_arch=state.det_arch,
        reco_arch=state.reco_arch,
        pretrained=True,
        assume_straight_pages=assume_straight_pages,
        straighten_pages=straighten_pages,
        preserve_aspect_ratio=preserve_aspect_ratio,
        symmetric_pad=symmetric_pad,
    )
    # Apply orientation settings if supported
    if hasattr(model, 'disable_page_orientation'):
        model.disable_page_orientation = disable_page_orientation
    if hasattr(model, 'disable_crop_orientation'):
        model.disable_crop_orientation = disable_crop_orientation
    # Move to GPU if available
    if state.device == "cuda":
        model = model.cuda()
    return model


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load DocTR model at startup with default configuration."""
    gpu_info = get_gpu_info()
    print("=" * 50)
    print("GPU STATUS")
    print("=" * 50)
    print(f" CUDA available: {gpu_info['cuda_available']}")
    print(f" Device: {gpu_info['device']}")
    if gpu_info['cuda_available']:
        print(f" GPU count: {gpu_info['gpu_count']}")
        print(f" GPU name: {gpu_info['gpu_name']}")
        print(f" GPU memory total: {gpu_info['gpu_memory_total']}")
    print("=" * 50)
    print("Loading DocTR models...")
    print(f" Detection: {state.det_arch}")
    print(f" Recognition: {state.reco_arch}")

    # Load with default config
    state.model = create_model()
    state.current_config = {
        "assume_straight_pages": True,
        "straighten_pages": False,
        "preserve_aspect_ratio": True,
        "symmetric_pad": True,
        "disable_page_orientation": False,
        "disable_crop_orientation": False,
    }

    if gpu_info['cuda_available']:
        gpu_after = get_gpu_info()
        print(f" GPU memory after load: {gpu_after.get('gpu_memory_used', 'N/A')}")

    print("Model loaded successfully!")
    yield
    state.model = None
    state.dataset = None


app = FastAPI(
    title="DocTR Tuning API",
    description="REST API for DocTR hyperparameter evaluation",
    version="1.0.0",
    lifespan=lifespan,
)


class EvaluateRequest(BaseModel):
    """Request schema with all tunable DocTR hyperparameters."""
    pdf_folder: str = Field("/app/dataset", description="Path to dataset folder")
    # Processing flags (require model reinit)
    assume_straight_pages: bool = Field(True, description="Skip rotation handling for straight documents")
    straighten_pages: bool = Field(False, description="Pre-straighten pages before detection")
    preserve_aspect_ratio: bool = Field(True, description="Maintain document proportions during resize")
    symmetric_pad: bool = Field(True, description="Use symmetric padding when preserving aspect ratio")
    # Orientation flags
    disable_page_orientation: bool = Field(False, description="Skip page orientation classification")
    disable_crop_orientation: bool = Field(False, description="Skip crop orientation detection")
    # Output grouping
    resolve_lines: bool = Field(True, description="Group words into lines")
    resolve_blocks: bool = Field(False, description="Group lines into blocks")
    paragraph_break: float = Field(0.035, ge=0.0, le=1.0, description="Minimum space ratio separating paragraphs")
    # Page range
    start_page: int = Field(5, ge=0, description="Start page index (inclusive)")
    end_page: int = Field(10, ge=1, description="End page index (exclusive)")


class EvaluateResponse(BaseModel):
    """Response schema matching CLI output."""
    CER: float
    WER: float
    TIME: float
    PAGES: int
    TIME_PER_PAGE: float
    model_reinitialized: bool = False


class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    dataset_loaded: bool
    dataset_size: Optional[int] = None
    det_arch: Optional[str] = None
    reco_arch: Optional[str] = None
    cuda_available: Optional[bool] = None
    device: Optional[str] = None
    gpu_name: Optional[str] = None
    gpu_memory_used: Optional[str] = None
    gpu_memory_total: Optional[str] = None


def doctr_result_to_text(result, resolve_lines: bool = True, resolve_blocks: bool = False) -> str:
    """
    Convert DocTR result to plain text.
    Structure: Document -> pages -> blocks -> lines -> words
    """
    lines = []
    for page in result.pages:
        for block in page.blocks:
            for line in block.lines:
                line_text = " ".join([w.value for w in line.words])
                lines.append(line_text)
            if resolve_blocks:
                lines.append("")  # paragraph separator
    text = " ".join([l for l in lines if l]).strip()
    text = re.sub(r"\s+", " ", text).strip()
    return text


def evaluate_text(reference: str, prediction: str) -> dict:
    """Calculate WER and CER metrics."""
    return {"WER": wer(reference, prediction), "CER": cer(reference, prediction)}


@app.get("/health", response_model=HealthResponse)
def health_check():
    """Check if the service is ready."""
    gpu_info = get_gpu_info()
    return HealthResponse(
        status="ok" if state.model is not None else "initializing",
        model_loaded=state.model is not None,
        dataset_loaded=state.dataset is not None,
        dataset_size=len(state.dataset) if state.dataset else None,
        det_arch=state.det_arch,
        reco_arch=state.reco_arch,
        cuda_available=gpu_info.get("cuda_available"),
        device=gpu_info.get("device"),
        gpu_name=gpu_info.get("gpu_name"),
        gpu_memory_used=gpu_info.get("gpu_memory_used"),
        gpu_memory_total=gpu_info.get("gpu_memory_total"),
    )
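
# Quick readiness check (the port matches the Usage line at the top of this
# file; the JSON shown is a typical reply once the model has loaded and
# before any dataset has been requested):
#
#   curl http://localhost:8000/health
#   -> {"status": "ok", "model_loaded": true, "dataset_loaded": false, ...}
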
@app.post("/evaluate", response_model=EvaluateResponse)
def evaluate(request: EvaluateRequest):
    """
    Evaluate OCR with given hyperparameters.
    Returns CER, WER, and timing metrics.
    Note: Model will be reinitialized if processing flags change.
    """
    if state.model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # Load or reload dataset if path changed
    if state.dataset is None or state.dataset_path != request.pdf_folder:
        if not os.path.isdir(request.pdf_folder):
            raise HTTPException(status_code=400, detail=f"Dataset folder not found: {request.pdf_folder}")
        state.dataset = ImageTextDataset(request.pdf_folder)
        state.dataset_path = request.pdf_folder

    if len(state.dataset) == 0:
        raise HTTPException(status_code=400, detail="Dataset is empty")

    # Check if model needs to be reinitialized
    new_config = {
        "assume_straight_pages": request.assume_straight_pages,
        "straighten_pages": request.straighten_pages,
        "preserve_aspect_ratio": request.preserve_aspect_ratio,
        "symmetric_pad": request.symmetric_pad,
        "disable_page_orientation": request.disable_page_orientation,
        "disable_crop_orientation": request.disable_crop_orientation,
    }
    model_reinitialized = False
    if state.current_config != new_config:
        print("Model config changed, reinitializing...")
        state.model = create_model(**new_config)
        state.current_config = new_config
        model_reinitialized = True

    # Validate page range
    start = request.start_page
    end = min(request.end_page, len(state.dataset))
    if start >= end:
        raise HTTPException(status_code=400, detail=f"Invalid page range: {start}-{end}")

    cer_list, wer_list = [], []
    time_per_page_list = []
    t0 = time.time()

    for idx in range(start, end):
        img, ref = state.dataset[idx]
        arr = np.array(img)

        tp0 = time.time()
        # DocTR expects a list of images
        result = state.model([arr])
        pred = doctr_result_to_text(
            result,
            resolve_lines=request.resolve_lines,
            resolve_blocks=request.resolve_blocks,
        )
        time_per_page_list.append(float(time.time() - tp0))

        m = evaluate_text(ref, pred)
        cer_list.append(m["CER"])
        wer_list.append(m["WER"])

    return EvaluateResponse(
        CER=float(np.mean(cer_list)) if cer_list else 1.0,
        WER=float(np.mean(wer_list)) if wer_list else 1.0,
        TIME=float(time.time() - t0),
        PAGES=len(cer_list),
        TIME_PER_PAGE=float(np.mean(time_per_page_list)) if time_per_page_list else 0.0,
        model_reinitialized=model_reinitialized,
    )


@app.post("/evaluate_full", response_model=EvaluateResponse)
def evaluate_full(request: EvaluateRequest):
    """Evaluate on ALL pages (ignores start_page/end_page)."""
    request.start_page = 0
    request.end_page = 9999
    return evaluate(request)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
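
# Example evaluation request (values are illustrative; the folder must exist
# inside the container and the fields mirror the EvaluateRequest schema above).
# Changing any of the processing flags triggers a model rebuild, so the first
# call with a new combination is slower:
#
#   curl -X POST http://localhost:8000/evaluate \
#        -H "Content-Type: application/json" \
#        -d '{"pdf_folder": "/app/dataset", "assume_straight_pages": true,
#             "straighten_pages": false, "preserve_aspect_ratio": true,
#             "start_page": 0, "end_page": 5}'
#
# /evaluate_full accepts the same body but always covers the whole dataset.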