Files
MastersThesis/src/paddle_ocr/paddle_ocr_tuning_rest.py
sergio c7ed7b2b9c
All checks were successful
build_docker / essential (push) Successful in 0s
build_docker / build_cpu (push) Successful in 5m0s
build_docker / build_gpu (push) Successful in 22m55s
build_docker / build_easyocr (push) Successful in 18m47s
build_docker / build_easyocr_gpu (push) Successful in 19m0s
build_docker / build_raytune (push) Successful in 3m27s
build_docker / build_doctr (push) Successful in 19m42s
build_docker / build_doctr_gpu (push) Successful in 14m49s
Paddle ocr, easyocr and doctr gpu support. (#4)
2026-01-19 17:35:24 +00:00

341 lines
12 KiB
Python

# paddle_ocr_tuning_rest.py
# FastAPI REST service for PaddleOCR hyperparameter evaluation
# Usage: uvicorn paddle_ocr_tuning_rest:app --host 0.0.0.0 --port 8000
import os
import re
import sys
import threading
import time
from contextlib import asynccontextmanager
from typing import Optional

import numpy as np
import paddle
from fastapi import FastAPI, HTTPException
from jiwer import wer, cer
from paddleocr import PaddleOCR
from pydantic import BaseModel, Field

from dataset_manager import ImageTextDataset
def get_gpu_info() -> dict:
    """Query PaddlePaddle for CUDA availability and basic GPU memory statistics.

    Returns a dict with keys: cuda_available, device, gpu_count, gpu_name,
    gpu_memory_total, gpu_memory_used (and, when CUDA is usable,
    gpu_memory_reserved; on probe failure, gpu_error).
    """
    status = {
        "cuda_available": paddle.device.is_compiled_with_cuda(),
        "device": str(paddle.device.get_device()),
        "gpu_count": 0,
        "gpu_name": None,
        "gpu_memory_total": None,
        "gpu_memory_used": None,
    }
    if not status["cuda_available"]:
        return status
    try:
        status["gpu_count"] = paddle.device.cuda.device_count()
        if status["gpu_count"] > 0:
            # Properties of device 0 only; multi-GPU details are not reported.
            props = paddle.device.cuda.get_device_properties(0)
            status["gpu_name"] = props.name
            status["gpu_memory_total"] = f"{props.total_memory / (1024**3):.2f} GB"
            # Current allocator counters for device 0.
            reserved = paddle.device.cuda.memory_reserved(0)
            allocated = paddle.device.cuda.memory_allocated(0)
            status["gpu_memory_used"] = f"{allocated / (1024**3):.2f} GB"
            status["gpu_memory_reserved"] = f"{reserved / (1024**3):.2f} GB"
    except Exception as exc:
        # Best-effort probe: surface the failure instead of crashing the caller.
        status["gpu_error"] = str(exc)
    return status
# Model configuration via environment variables (with defaults).
# PADDLE_DET_MODEL / PADDLE_REC_MODEL override the detection / recognition model names.
DEFAULT_DET_MODEL = os.environ.get("PADDLE_DET_MODEL", "PP-OCRv5_server_det")
DEFAULT_REC_MODEL = os.environ.get("PADDLE_REC_MODEL", "PP-OCRv5_server_rec")
# Global state for model and dataset
class AppState:
    """Process-wide mutable state shared by all request handlers.

    The OCR model is loaded at startup (see the lifespan handler); the
    dataset is loaded lazily on the first /evaluate call that needs it.
    """

    # Forward-referenced annotations so the class definition does not depend
    # on import order of the heavyweight libraries.
    ocr: Optional["PaddleOCR"] = None               # loaded OCR pipeline
    dataset: Optional["ImageTextDataset"] = None    # lazily loaded eval dataset
    dataset_path: Optional[str] = None              # folder the dataset was loaded from
    det_model: str = DEFAULT_DET_MODEL
    rec_model: str = DEFAULT_REC_MODEL

    def __init__(self) -> None:
        # Created per instance (a class-level default of None mis-typed the
        # attribute); protects the OCR model from concurrent access.
        self.lock: threading.Lock = threading.Lock()


state = AppState()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: log GPU status, load the OCR model at startup,
    and drop model/dataset references on shutdown."""
    # Log GPU status before loading so the later memory reading shows the model's footprint.
    gpu_info = get_gpu_info()
    print("=" * 50)
    print("GPU STATUS")
    print("=" * 50)
    print(f" CUDA available: {gpu_info['cuda_available']}")
    print(f" Device: {gpu_info['device']}")
    if gpu_info['cuda_available']:
        print(f" GPU count: {gpu_info['gpu_count']}")
        print(f" GPU name: {gpu_info['gpu_name']}")
        print(f" GPU memory total: {gpu_info['gpu_memory_total']}")
    print("=" * 50)
    # Fix: plain string (was an f-string with no placeholders).
    print("Loading PaddleOCR models...")
    print(f" Detection: {state.det_model}")
    print(f" Recognition: {state.rec_model}")
    state.ocr = PaddleOCR(
        text_detection_model_name=state.det_model,
        text_recognition_model_name=state.rec_model,
    )
    # Log GPU memory after model load (shows how much the model consumed).
    if gpu_info['cuda_available']:
        gpu_after = get_gpu_info()
        print(f" GPU memory after load: {gpu_after.get('gpu_memory_used', 'N/A')}")
    print("Model loaded successfully!")
    yield
    # Cleanup on shutdown: release references so memory can be reclaimed.
    state.ocr = None
    state.dataset = None
# FastAPI application; `lifespan` loads the OCR model at startup and clears
# shared state on shutdown.
app = FastAPI(
    title="PaddleOCR Tuning API",
    description="REST API for OCR hyperparameter evaluation",
    version="1.0.0",
    lifespan=lifespan,
)
class EvaluateRequest(BaseModel):
    """Request schema matching CLI arguments.

    Threshold ranges mirror PaddleOCR's predict() hyperparameters; the page
    indices select the half-open slice [start_page, end_page) of the dataset.
    """
    pdf_folder: str = Field("/app/dataset", description="Path to dataset folder")
    use_doc_orientation_classify: bool = Field(False, description="Use document orientation classification")
    use_doc_unwarping: bool = Field(False, description="Use document unwarping")
    textline_orientation: bool = Field(True, description="Use textline orientation classification")
    text_det_thresh: float = Field(0.0, ge=0.0, le=1.0, description="Detection pixel threshold")
    text_det_box_thresh: float = Field(0.0, ge=0.0, le=1.0, description="Detection box threshold")
    text_det_unclip_ratio: float = Field(1.5, ge=0.0, description="Text detection expansion coefficient")
    text_rec_score_thresh: float = Field(0.0, ge=0.0, le=1.0, description="Recognition score threshold")
    start_page: int = Field(5, ge=0, description="Start page index (inclusive)")
    end_page: int = Field(10, ge=1, description="End page index (exclusive)")
    save_output: bool = Field(False, description="Save OCR predictions to debugset folder")
class EvaluateResponse(BaseModel):
    """Response schema matching CLI output."""
    CER: float  # mean character error rate over the evaluated pages
    WER: float  # mean word error rate over the evaluated pages
    TIME: float  # total wall-clock seconds for the whole evaluation
    PAGES: int  # number of pages actually evaluated
    TIME_PER_PAGE: float  # mean seconds spent per page (predict + assembly)
class HealthResponse(BaseModel):
    """Readiness payload: service status plus model, dataset, and GPU details."""
    status: str  # "ok" once the OCR model is loaded, "initializing" before
    model_loaded: bool
    dataset_loaded: bool
    dataset_size: Optional[int] = None  # None until a dataset has been loaded
    det_model: Optional[str] = None
    rec_model: Optional[str] = None
    # GPU info (mirrors get_gpu_info(); None when CUDA is unavailable)
    cuda_available: Optional[bool] = None
    device: Optional[str] = None
    gpu_name: Optional[str] = None
    gpu_memory_used: Optional[str] = None
    gpu_memory_total: Optional[str] = None
def _normalize_box_xyxy(box):
"""Normalize bounding box to (x0, y0, x1, y1) format."""
if isinstance(box, (list, tuple)) and box and isinstance(box[0], (list, tuple)):
xs = [p[0] for p in box]
ys = [p[1] for p in box]
return min(xs), min(ys), max(xs), max(ys)
if isinstance(box, (list, tuple)):
if len(box) == 4:
x0, y0, x1, y1 = box
return min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)
if len(box) == 8:
xs = box[0::2]
ys = box[1::2]
return min(xs), min(ys), max(xs), max(ys)
raise ValueError(f"Unrecognized box format: {box!r}")
def assemble_from_paddle_result(paddleocr_predict, min_score=0.0, line_tol_factor=0.6):
"""
Robust line grouping for PaddleOCR outputs.
Normalizes boxes, groups by line, and returns assembled text.
"""
boxes_all = []
for item in paddleocr_predict:
res = item.json.get("res", {})
boxes = res.get("rec_boxes", []) or []
texts = res.get("rec_texts", []) or []
scores = res.get("rec_scores", None)
for i, (box, text) in enumerate(zip(boxes, texts)):
try:
x0, y0, x1, y1 = _normalize_box_xyxy(box)
except Exception:
continue
y_mid = 0.5 * (y0 + y1)
score = float(scores[i]) if (scores is not None and i < len(scores)) else 1.0
t = re.sub(r"\s+", " ", str(text)).strip()
if not t:
continue
boxes_all.append((x0, y0, x1, y1, y_mid, t, score))
if min_score > 0:
boxes_all = [b for b in boxes_all if b[6] >= min_score]
if not boxes_all:
return ""
# Adaptive line tolerance
heights = [b[3] - b[1] for b in boxes_all]
median_h = float(np.median(heights)) if heights else 20.0
line_tol = max(8.0, line_tol_factor * median_h)
# Sort by vertical mid, then x0
boxes_all.sort(key=lambda b: (b[4], b[0]))
# Group into lines
lines, cur, last_y = [], [], None
for x0, y0, x1, y1, y_mid, text, score in boxes_all:
if last_y is None or abs(y_mid - last_y) <= line_tol:
cur.append((x0, text))
else:
cur.sort(key=lambda t: t[0])
lines.append(" ".join(t[1] for t in cur))
cur = [(x0, text)]
last_y = y_mid
if cur:
cur.sort(key=lambda t: t[0])
lines.append(" ".join(t[1] for t in cur))
res = "\n".join(lines)
res = re.sub(r"\s+\n", "\n", res).strip()
return res
def evaluate_text(reference: str, prediction: str) -> dict:
    """Score a prediction against its reference transcript.

    Returns a dict with word error rate ("WER") and character error rate ("CER").
    """
    metrics = {
        "WER": wer(reference, prediction),
        "CER": cer(reference, prediction),
    }
    return metrics
@app.get("/health", response_model=HealthResponse)
def health_check():
    """Report service readiness plus model, dataset, and GPU details."""
    gpu = get_gpu_info()
    model_ready = state.ocr is not None
    return HealthResponse(
        status="ok" if model_ready else "initializing",
        model_loaded=model_ready,
        dataset_loaded=state.dataset is not None,
        dataset_size=len(state.dataset) if state.dataset else None,
        det_model=state.det_model,
        rec_model=state.rec_model,
        cuda_available=gpu.get("cuda_available"),
        device=gpu.get("device"),
        gpu_name=gpu.get("gpu_name"),
        gpu_memory_used=gpu.get("gpu_memory_used"),
        gpu_memory_total=gpu.get("gpu_memory_total"),
    )
@app.post("/evaluate", response_model=EvaluateResponse)
def evaluate(request: EvaluateRequest):
    """
    Evaluate OCR with given hyperparameters.
    Returns CER, WER, and timing metrics.
    """
    if state.ocr is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    # (Re)load the dataset when none is cached or the requested folder changed.
    if state.dataset is None or state.dataset_path != request.pdf_folder:
        if not os.path.isdir(request.pdf_folder):
            raise HTTPException(status_code=400, detail=f"Dataset folder not found: {request.pdf_folder}")
        state.dataset = ImageTextDataset(request.pdf_folder)
        state.dataset_path = request.pdf_folder
    if len(state.dataset) == 0:
        raise HTTPException(status_code=400, detail="Dataset is empty")
    # Clamp the requested page window to the dataset size and validate it.
    first = request.start_page
    last = min(request.end_page, len(state.dataset))
    if first >= last:
        raise HTTPException(status_code=400, detail=f"Invalid page range: {first}-{last}")
    page_cers, page_wers, page_times = [], [], []
    started = time.time()
    # The OCR model is not thread-safe; serialize access across requests.
    with state.lock:
        for page_idx in range(first, last):
            image, reference = state.dataset[page_idx]
            pixels = np.array(image)
            page_started = time.time()
            raw_result = state.ocr.predict(
                pixels,
                use_doc_orientation_classify=request.use_doc_orientation_classify,
                use_doc_unwarping=request.use_doc_unwarping,
                use_textline_orientation=request.textline_orientation,
                text_det_thresh=request.text_det_thresh,
                text_det_box_thresh=request.text_det_box_thresh,
                text_det_unclip_ratio=request.text_det_unclip_ratio,
                text_rec_score_thresh=request.text_rec_score_thresh,
            )
            prediction = assemble_from_paddle_result(raw_result)
            # Per-page time covers prediction plus text assembly.
            page_times.append(float(time.time() - page_started))
            if request.save_output:
                # Persist the prediction next to the dataset for later inspection.
                out_path = state.dataset.get_output_path(page_idx, "paddle_text")
                with open(out_path, "w", encoding="utf-8") as f:
                    f.write(prediction)
            metrics = evaluate_text(reference, prediction)
            page_cers.append(metrics["CER"])
            page_wers.append(metrics["WER"])
    # Worst-case metrics (1.0) if, somehow, nothing was evaluated.
    return EvaluateResponse(
        CER=float(np.mean(page_cers)) if page_cers else 1.0,
        WER=float(np.mean(page_wers)) if page_wers else 1.0,
        TIME=float(time.time() - started),
        PAGES=len(page_cers),
        TIME_PER_PAGE=float(np.mean(page_times)) if page_times else 0.0,
    )
@app.post("/evaluate_full", response_model=EvaluateResponse)
def evaluate_full(request: EvaluateRequest):
    """Evaluate on ALL pages (ignores start_page/end_page).

    Delegates to /evaluate with the widest possible page window; evaluate()
    clamps end_page to the dataset size.
    """
    request.start_page = 0
    # Fix: the previous hard-coded 9999 silently truncated datasets with more
    # than 9999 pages; sys.maxsize guarantees full coverage after clamping.
    request.end_page = sys.maxsize
    return evaluate(request)
if __name__ == "__main__":
    # Convenience entry point: run the API directly with an embedded uvicorn
    # server instead of the external `uvicorn ...` invocation in the header.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)