lock model

This commit is contained in:
2026-01-18 17:38:42 +01:00
parent b29df98602
commit 15bfba79a7
6 changed files with 295 additions and 217 deletions

View File

@@ -5,6 +5,7 @@
import os
import re
import time
import threading
from typing import Optional
from contextlib import asynccontextmanager
@@ -57,6 +58,10 @@ class AppState:
# Track current model config for cache invalidation
current_config: Optional[dict] = None
device: str = "cuda" if torch.cuda.is_available() else "cpu"
lock: threading.Lock = None # Protects OCR model from concurrent access
def __init__(self):
self.lock = threading.Lock()
state = AppState()
@@ -253,23 +258,6 @@ def evaluate(request: EvaluateRequest):
if len(state.dataset) == 0:
raise HTTPException(status_code=400, detail="Dataset is empty")
# Check if model needs to be reinitialized
new_config = {
"assume_straight_pages": request.assume_straight_pages,
"straighten_pages": request.straighten_pages,
"preserve_aspect_ratio": request.preserve_aspect_ratio,
"symmetric_pad": request.symmetric_pad,
"disable_page_orientation": request.disable_page_orientation,
"disable_crop_orientation": request.disable_crop_orientation,
}
model_reinitialized = False
if state.current_config != new_config:
print(f"Model config changed, reinitializing...")
state.model = create_model(**new_config)
state.current_config = new_config
model_reinitialized = True
# Validate page range
start = request.start_page
end = min(request.end_page, len(state.dataset))
@@ -280,24 +268,43 @@ def evaluate(request: EvaluateRequest):
time_per_page_list = []
t0 = time.time()
for idx in range(start, end):
img, ref = state.dataset[idx]
arr = np.array(img)
# Lock to prevent concurrent OCR access (model is not thread-safe)
with state.lock:
# Check if model needs to be reinitialized
new_config = {
"assume_straight_pages": request.assume_straight_pages,
"straighten_pages": request.straighten_pages,
"preserve_aspect_ratio": request.preserve_aspect_ratio,
"symmetric_pad": request.symmetric_pad,
"disable_page_orientation": request.disable_page_orientation,
"disable_crop_orientation": request.disable_crop_orientation,
}
tp0 = time.time()
# DocTR expects a list of images
result = state.model([arr])
model_reinitialized = False
if state.current_config != new_config:
print(f"Model config changed, reinitializing...")
state.model = create_model(**new_config)
state.current_config = new_config
model_reinitialized = True
pred = doctr_result_to_text(
result,
resolve_lines=request.resolve_lines,
resolve_blocks=request.resolve_blocks,
)
time_per_page_list.append(float(time.time() - tp0))
for idx in range(start, end):
img, ref = state.dataset[idx]
arr = np.array(img)
m = evaluate_text(ref, pred)
cer_list.append(m["CER"])
wer_list.append(m["WER"])
tp0 = time.time()
# DocTR expects a list of images
result = state.model([arr])
pred = doctr_result_to_text(
result,
resolve_lines=request.resolve_lines,
resolve_blocks=request.resolve_blocks,
)
time_per_page_list.append(float(time.time() - tp0))
m = evaluate_text(ref, pred)
cer_list.append(m["CER"])
wer_list.append(m["WER"])
return EvaluateResponse(
CER=float(np.mean(cer_list)) if cer_list else 1.0,