Paddle ocr, easyocr and doctr gpu support. (#4)
All checks were successful
build_docker / essential (push) Successful in 0s
build_docker / build_cpu (push) Successful in 5m0s
build_docker / build_gpu (push) Successful in 22m55s
build_docker / build_easyocr (push) Successful in 18m47s
build_docker / build_easyocr_gpu (push) Successful in 19m0s
build_docker / build_raytune (push) Successful in 3m27s
build_docker / build_doctr (push) Successful in 19m42s
build_docker / build_doctr_gpu (push) Successful in 14m49s
All checks were successful
build_docker / essential (push) Successful in 0s
build_docker / build_cpu (push) Successful in 5m0s
build_docker / build_gpu (push) Successful in 22m55s
build_docker / build_easyocr (push) Successful in 18m47s
build_docker / build_easyocr_gpu (push) Successful in 19m0s
build_docker / build_raytune (push) Successful in 3m27s
build_docker / build_doctr (push) Successful in 19m42s
build_docker / build_doctr_gpu (push) Successful in 14m49s
This commit was merged in pull request #4.
This commit is contained in:
340
src/paddle_ocr/paddle_ocr_tuning_rest.py
Normal file
340
src/paddle_ocr/paddle_ocr_tuning_rest.py
Normal file
@@ -0,0 +1,340 @@
|
||||
# paddle_ocr_tuning_rest.py
|
||||
# FastAPI REST service for PaddleOCR hyperparameter evaluation
|
||||
# Usage: uvicorn paddle_ocr_tuning_rest:app --host 0.0.0.0 --port 8000
|
||||
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import threading
|
||||
from typing import Optional
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from paddleocr import PaddleOCR
|
||||
from jiwer import wer, cer
|
||||
from dataset_manager import ImageTextDataset
|
||||
|
||||
|
||||
def get_gpu_info() -> dict:
    """Collect a snapshot of the GPU status as reported by PaddlePaddle.

    Returns:
        A dict that always contains ``cuda_available``, ``device``,
        ``gpu_count``, ``gpu_name``, ``gpu_memory_total`` and
        ``gpu_memory_used``. When CUDA is compiled in and at least one device
        is present, the name and memory figures describe device 0 and a
        ``gpu_memory_reserved`` entry is added. If any CUDA query raises, the
        exception message is stored under ``gpu_error`` and whatever was
        collected up to that point is returned.
    """
    gib = 1024**3

    def _as_gb(num_bytes):
        # Memory sizes are reported as human-readable strings, not numbers.
        return f"{num_bytes / gib:.2f} GB"

    cuda_ready = paddle.device.is_compiled_with_cuda()
    status = {
        "cuda_available": cuda_ready,
        "device": str(paddle.device.get_device()),
        "gpu_count": 0,
        "gpu_name": None,
        "gpu_memory_total": None,
        "gpu_memory_used": None,
    }

    if not cuda_ready:
        return status

    try:
        device_total = paddle.device.cuda.device_count()
        status["gpu_count"] = device_total
        if device_total > 0:
            # Static properties of the first device.
            device_props = paddle.device.cuda.get_device_properties(0)
            status["gpu_name"] = device_props.name
            status["gpu_memory_total"] = _as_gb(device_props.total_memory)

            # Current memory usage of the first device.
            reserved_bytes = paddle.device.cuda.memory_reserved(0)
            allocated_bytes = paddle.device.cuda.memory_allocated(0)
            status["gpu_memory_used"] = _as_gb(allocated_bytes)
            status["gpu_memory_reserved"] = _as_gb(reserved_bytes)
    except Exception as exc:
        # Status reporting is best effort: record the failure instead of raising.
        status["gpu_error"] = str(exc)

    return status
|
||||
|
||||
|
||||
# Model names are resolved once at import time. Set PADDLE_DET_MODEL /
# PADDLE_REC_MODEL in the environment to override the defaults below.
DEFAULT_DET_MODEL = os.getenv("PADDLE_DET_MODEL", "PP-OCRv5_server_det")
DEFAULT_REC_MODEL = os.getenv("PADDLE_REC_MODEL", "PP-OCRv5_server_rec")
|
||||
|
||||
|
||||
# Global state for model and dataset
|
||||
class AppState:
    """Process-wide mutable state shared by the request handlers."""

    # OCR pipeline instance; None until the startup hook has loaded it.
    ocr: Optional[PaddleOCR] = None
    # Most recently loaded dataset and the folder it was loaded from.
    dataset: Optional[ImageTextDataset] = None
    dataset_path: Optional[str] = None
    # Model names handed to PaddleOCR when the pipeline is created.
    det_model: str = DEFAULT_DET_MODEL
    rec_model: str = DEFAULT_REC_MODEL
    # Protects the OCR model from concurrent access; created per instance.
    lock: Optional[threading.Lock] = None

    def __init__(self):
        self.lock = threading.Lock()


# Single shared instance used by every endpoint in this module.
state = AppState()
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the OCR model at startup and release shared state at shutdown.

    Before yielding, the GPU status is printed and the PaddleOCR pipeline is
    created with the detection/recognition models configured in ``state``.
    After the application stops, the model and dataset references are cleared.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    # Log GPU status
    gpu_info = get_gpu_info()
    separator = "=" * 50
    print(separator)
    print("GPU STATUS")
    print(separator)
    print(f" CUDA available: {gpu_info['cuda_available']}")
    print(f" Device: {gpu_info['device']}")
    if gpu_info['cuda_available']:
        print(f" GPU count: {gpu_info['gpu_count']}")
        print(f" GPU name: {gpu_info['gpu_name']}")
        print(f" GPU memory total: {gpu_info['gpu_memory_total']}")
    print(separator)

    print("Loading PaddleOCR models...")
    print(f" Detection: {state.det_model}")
    print(f" Recognition: {state.rec_model}")
    state.ocr = PaddleOCR(
        text_detection_model_name=state.det_model,
        text_recognition_model_name=state.rec_model,
    )

    # Log GPU memory after model load
    if gpu_info['cuda_available']:
        gpu_after = get_gpu_info()
        # The key is always present (possibly None), so a .get() default would
        # never apply; fall back explicitly when no value could be read.
        print(f" GPU memory after load: {gpu_after.get('gpu_memory_used') or 'N/A'}")

    print("Model loaded successfully!")
    try:
        yield
    finally:
        # Cleanup on shutdown; runs even if the context is left via an exception.
        state.ocr = None
        state.dataset = None
|
||||
|
||||
|
||||
# Application object served by uvicorn; `lifespan` handles model loading and cleanup.
app = FastAPI(
    lifespan=lifespan,
    title="PaddleOCR Tuning API",
    description="REST API for OCR hyperparameter evaluation",
    version="1.0.0",
)
|
||||
|
||||
|
||||
class EvaluateRequest(BaseModel):
    """Request body for the evaluation endpoints (mirrors the CLI arguments)."""

    # Dataset location
    pdf_folder: str = Field(
        default="/app/dataset",
        description="Path to dataset folder",
    )

    # Optional pipeline stages
    use_doc_orientation_classify: bool = Field(
        default=False,
        description="Use document orientation classification",
    )
    use_doc_unwarping: bool = Field(
        default=False,
        description="Use document unwarping",
    )
    textline_orientation: bool = Field(
        default=True,
        description="Use textline orientation classification",
    )

    # Detection / recognition hyperparameters
    text_det_thresh: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Detection pixel threshold",
    )
    text_det_box_thresh: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Detection box threshold",
    )
    text_det_unclip_ratio: float = Field(
        default=1.5,
        ge=0.0,
        description="Text detection expansion coefficient",
    )
    text_rec_score_thresh: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Recognition score threshold",
    )

    # Page window: half-open range [start_page, end_page)
    start_page: int = Field(
        default=5,
        ge=0,
        description="Start page index (inclusive)",
    )
    end_page: int = Field(
        default=10,
        ge=1,
        description="End page index (exclusive)",
    )

    # Debug output
    save_output: bool = Field(
        default=False,
        description="Save OCR predictions to debugset folder",
    )
|
||||
|
||||
|
||||
class EvaluateResponse(BaseModel):
    """Response body of the evaluation endpoints (mirrors the CLI output)."""

    # Error rates, averaged over the evaluated pages.
    CER: float
    WER: float
    # Elapsed seconds for the whole request.
    TIME: float
    # Number of pages that were evaluated.
    PAGES: int
    # Mean seconds spent per page on OCR prediction and text assembly.
    TIME_PER_PAGE: float
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    """Response body of the health endpoint."""

    # Service readiness
    status: str
    model_loaded: bool

    # Dataset cache
    dataset_loaded: bool
    dataset_size: Optional[int] = None

    # Configured model names
    det_model: Optional[str] = None
    rec_model: Optional[str] = None

    # GPU info
    cuda_available: Optional[bool] = None
    device: Optional[str] = None
    gpu_name: Optional[str] = None
    gpu_memory_used: Optional[str] = None
    gpu_memory_total: Optional[str] = None
|
||||
|
||||
|
||||
def _normalize_box_xyxy(box):
    """Normalize a bounding box to an axis-aligned ``(x0, y0, x1, y1)`` tuple.

    Accepted layouts:
      * a sequence of points ``[[x, y], ...]`` (polygon) -> its bounding rectangle;
      * a flat 4-element sequence of two corners, in any order;
      * a flat 8-element sequence ``[x, y, x, y, x, y, x, y]``.

    Array-like inputs exposing ``tolist()`` (e.g. NumPy arrays) are converted
    to nested lists first, so they are handled like the list layouts above
    instead of being rejected.

    Args:
        box: The box in one of the layouts described above.

    Returns:
        A tuple ``(x0, y0, x1, y1)`` with ``x0 <= x1`` and ``y0 <= y1``.

    Raises:
        ValueError: If the layout is not recognized (including empty input).
    """
    if hasattr(box, "tolist"):
        box = box.tolist()

    if isinstance(box, (list, tuple)) and box:
        # Polygon: the first element being a point decides the layout.
        if isinstance(box[0], (list, tuple)):
            xs = [point[0] for point in box]
            ys = [point[1] for point in box]
            return min(xs), min(ys), max(xs), max(ys)

        if len(box) == 4:
            x0, y0, x1, y1 = box
            return min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)

        if len(box) == 8:
            # Flat polygon: even indices are x, odd indices are y.
            xs = box[0::2]
            ys = box[1::2]
            return min(xs), min(ys), max(xs), max(ys)

    raise ValueError(f"Unrecognized box format: {box!r}")
|
||||
|
||||
|
||||
def assemble_from_paddle_result(paddleocr_predict, min_score=0.0, line_tol_factor=0.6):
    """
    Rebuild line-ordered text from PaddleOCR prediction results.

    Each result item is read through ``item.json["res"]`` using the
    ``rec_boxes``, ``rec_texts`` and optional ``rec_scores`` entries.
    Boxes are normalized, fragments are sorted top-to-bottom, grouped into
    lines by vertical proximity, and each line is joined left-to-right.

    Args:
        paddleocr_predict: Iterable of result items exposing a ``json`` mapping.
        min_score: When greater than 0, fragments scoring below it are dropped.
        line_tol_factor: Fraction of the median box height used as the line
            tolerance; the tolerance never goes below 8.0.

    Returns:
        The assembled text with one line per detected row, or "" when no
        usable fragment remains.
    """
    fragments = []
    for page in paddleocr_predict:
        payload = page.json.get("res", {})
        rec_boxes = payload.get("rec_boxes", []) or []
        rec_texts = payload.get("rec_texts", []) or []
        rec_scores = payload.get("rec_scores", None)

        for idx, (raw_box, raw_text) in enumerate(zip(rec_boxes, rec_texts)):
            try:
                left, top, right, bottom = _normalize_box_xyxy(raw_box)
            except Exception:
                # Best effort: a box that cannot be normalized is skipped.
                continue

            center_y = 0.5 * (top + bottom)
            # Fragments without a matching score count as fully confident.
            if rec_scores is not None and idx < len(rec_scores):
                confidence = float(rec_scores[idx])
            else:
                confidence = 1.0

            cleaned = re.sub(r"\s+", " ", str(raw_text)).strip()
            if cleaned:
                fragments.append((left, top, right, bottom, center_y, cleaned, confidence))

    if min_score > 0:
        fragments = [frag for frag in fragments if frag[6] >= min_score]

    if not fragments:
        return ""

    # Adaptive line tolerance derived from the median box height.
    median_height = float(np.median([frag[3] - frag[1] for frag in fragments]))
    line_tol = max(8.0, line_tol_factor * median_height)

    # Order by vertical centre first, then by left edge.
    fragments.sort(key=lambda frag: (frag[4], frag[0]))

    def _join_row(row):
        # Words of one row are emitted left-to-right.
        return " ".join(word for _, word in sorted(row, key=lambda entry: entry[0]))

    rows = []
    current_row = []
    previous_center = None
    for left, _top, _right, _bottom, center_y, word, _confidence in fragments:
        # The reference is the previous fragment's centre, not the row's first.
        same_row = previous_center is None or abs(center_y - previous_center) <= line_tol
        if not same_row:
            rows.append(_join_row(current_row))
            current_row = []
        current_row.append((left, word))
        previous_center = center_y

    if current_row:
        rows.append(_join_row(current_row))

    joined = "\n".join(rows)
    return re.sub(r"\s+\n", "\n", joined).strip()
|
||||
|
||||
|
||||
def evaluate_text(reference: str, prediction: str) -> dict:
    """Score a prediction against its reference text.

    Returns:
        A dict with the word error rate under ``"WER"`` and the character
        error rate under ``"CER"``.
    """
    word_error_rate = wer(reference, prediction)
    char_error_rate = cer(reference, prediction)
    return {"WER": word_error_rate, "CER": char_error_rate}
|
||||
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
def health_check():
    """Report service readiness, dataset cache state and GPU status.

    Returns:
        A ``HealthResponse`` whose ``status`` is "ok" once the OCR model is
        loaded and "initializing" otherwise. ``dataset_size`` is the length of
        the cached dataset, or None when no dataset has been loaded yet.
    """
    gpu_info = get_gpu_info()
    model_ready = state.ocr is not None
    # Read the shared reference once so the None check and len() see the same object.
    dataset = state.dataset

    return HealthResponse(
        status="ok" if model_ready else "initializing",
        model_loaded=model_ready,
        dataset_loaded=dataset is not None,
        # Test against None rather than truthiness: a loaded dataset whose
        # len() is 0 would normally be falsy and must still report size 0.
        dataset_size=len(dataset) if dataset is not None else None,
        det_model=state.det_model,
        rec_model=state.rec_model,
        cuda_available=gpu_info.get("cuda_available"),
        device=gpu_info.get("device"),
        gpu_name=gpu_info.get("gpu_name"),
        gpu_memory_used=gpu_info.get("gpu_memory_used"),
        gpu_memory_total=gpu_info.get("gpu_memory_total"),
    )
|
||||
|
||||
|
||||
@app.post("/evaluate", response_model=EvaluateResponse)
def evaluate(request: EvaluateRequest):
    """
    Evaluate OCR with given hyperparameters.

    Runs the loaded OCR model on pages ``[start_page, end_page)`` of the
    dataset in ``request.pdf_folder`` (``end_page`` is clamped to the dataset
    size), assembles the recognized text per page and scores it against the
    reference text.

    Returns:
        An ``EvaluateResponse`` with mean CER and WER, the elapsed time of the
        evaluation (including any wait for the model lock), the number of
        pages processed, and the mean per-page OCR time.

    Raises:
        HTTPException: 503 if the model is not loaded yet; 400 if the dataset
            folder does not exist, the dataset is empty, or the page range is
            empty after clamping.
    """
    ocr = state.ocr
    if ocr is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # Load or reload the dataset if the path changed. This happens under the
    # lock so concurrent requests cannot interleave the dataset/dataset_path
    # updates, and the local reference keeps this request on the dataset it
    # resolved even if another request swaps the cached one afterwards.
    with state.lock:
        if state.dataset is None or state.dataset_path != request.pdf_folder:
            if not os.path.isdir(request.pdf_folder):
                raise HTTPException(status_code=400, detail=f"Dataset folder not found: {request.pdf_folder}")
            state.dataset = ImageTextDataset(request.pdf_folder)
            state.dataset_path = request.pdf_folder
        dataset = state.dataset

    if len(dataset) == 0:
        raise HTTPException(status_code=400, detail="Dataset is empty")

    # Validate page range
    start = request.start_page
    end = min(request.end_page, len(dataset))
    if start >= end:
        raise HTTPException(status_code=400, detail=f"Invalid page range: {start}-{end}")

    cer_list, wer_list = [], []
    time_per_page_list = []
    # perf_counter is monotonic, so intervals are immune to wall-clock changes.
    t0 = time.perf_counter()

    # Lock to prevent concurrent OCR access (model is not thread-safe)
    with state.lock:
        for idx in range(start, end):
            img, ref = dataset[idx]
            arr = np.array(img)

            tp0 = time.perf_counter()
            out = ocr.predict(
                arr,
                use_doc_orientation_classify=request.use_doc_orientation_classify,
                use_doc_unwarping=request.use_doc_unwarping,
                use_textline_orientation=request.textline_orientation,
                text_det_thresh=request.text_det_thresh,
                text_det_box_thresh=request.text_det_box_thresh,
                text_det_unclip_ratio=request.text_det_unclip_ratio,
                text_rec_score_thresh=request.text_rec_score_thresh,
            )

            pred = assemble_from_paddle_result(out)
            # Per-page time covers prediction and text assembly only.
            time_per_page_list.append(float(time.perf_counter() - tp0))

            # Save prediction to debugset if requested
            if request.save_output:
                out_path = dataset.get_output_path(idx, "paddle_text")
                with open(out_path, "w", encoding="utf-8") as f:
                    f.write(pred)

            m = evaluate_text(ref, pred)
            cer_list.append(m["CER"])
            wer_list.append(m["WER"])

    return EvaluateResponse(
        CER=float(np.mean(cer_list)) if cer_list else 1.0,
        WER=float(np.mean(wer_list)) if wer_list else 1.0,
        TIME=float(time.perf_counter() - t0),
        PAGES=len(cer_list),
        TIME_PER_PAGE=float(np.mean(time_per_page_list)) if time_per_page_list else 0.0,
    )
|
||||
|
||||
|
||||
@app.post("/evaluate_full", response_model=EvaluateResponse)
def evaluate_full(request: EvaluateRequest):
    """Evaluate on ALL pages (ignores start_page/end_page).

    The page window on ``request`` is overwritten to span the whole dataset
    and the work is delegated to ``evaluate``, which clamps the end of the
    range to the dataset size.

    Returns:
        The ``EvaluateResponse`` produced by ``evaluate``.

    Raises:
        HTTPException: Propagated from ``evaluate``.
    """
    import sys

    request.start_page = 0
    # Use the largest practical int rather than a fixed cap such as 9999, so
    # datasets of any size are covered; evaluate() clamps it to the dataset size.
    request.end_page = sys.maxsize
    return evaluate(request)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Direct-execution entry point: serve `app` on all interfaces, port 8000.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=8000)
|
||||
Reference in New Issue
Block a user