Some checks failed
build_docker / essential (pull_request) Successful in 1s
build_docker / build_cpu (pull_request) Successful in 4m14s
build_docker / build_easyocr (pull_request) Successful in 12m19s
build_docker / build_easyocr_gpu (pull_request) Successful in 14m2s
build_docker / build_doctr (pull_request) Successful in 12m24s
build_docker / build_doctr_gpu (pull_request) Successful in 13m10s
build_docker / build_raytune (pull_request) Successful in 1m50s
build_docker / build_gpu (pull_request) Has been cancelled
372 lines
13 KiB
Python
372 lines
13 KiB
Python
# raytune_ocr.py
# Shared Ray Tune utilities for OCR hyperparameter optimization
#
# Usage:
#   from raytune_ocr import check_workers, create_trainable, run_tuner, analyze_results
#
# Environment variables:
#   OCR_HOST: Host for OCR services (default: localhost)
|
import os
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional

import pandas as pd
import requests

import ray
from ray import tune
from ray.tune.search.optuna import OptunaSearch
|
|
|
|
|
|
def check_workers(
    ports: List[int],
    service_name: str = "OCR",
    timeout: int = 180,
    interval: int = 5,
) -> List[str]:
    """
    Wait for workers to be fully ready and return healthy worker URLs.

    Polls each worker's /health endpoint until it reports status "ok" with
    the model loaded, or until `timeout` seconds elapse for that worker.

    Args:
        ports: List of port numbers to check
        service_name: Name for error messages
        timeout: Max seconds to wait for each worker
        interval: Seconds between retries

    Returns:
        List of healthy worker URLs

    Raises:
        RuntimeError if no healthy workers found after timeout
    """
    import time

    host = os.environ.get("OCR_HOST", "localhost")
    worker_urls = [f"http://{host}:{port}" for port in ports]
    healthy_workers = []

    for url in worker_urls:
        print(f"Waiting for {url}...")
        start = time.time()

        while time.time() - start < timeout:
            try:
                health = requests.get(f"{url}/health", timeout=10).json()
                model_ok = health.get('model_loaded', False)
                dataset_ok = health.get('dataset_loaded', False)

                # NOTE(review): readiness currently requires only the model;
                # dataset_loaded is reported below for visibility but is not
                # enforced — confirm whether it should gate readiness.
                if health.get('status') == 'ok' and model_ok:
                    gpu = health.get('gpu_name', 'CPU')
                    print(f"✓ {url}: ready ({gpu})")
                    healthy_workers.append(url)
                    break

                elapsed = int(time.time() - start)
                # Fix: dataset_ok was fetched but never shown; include it so
                # operators can see why a worker is still warming up.
                print(f"  [{elapsed}s] model={model_ok} dataset={dataset_ok}")
            except requests.exceptions.RequestException:
                elapsed = int(time.time() - start)
                print(f"  [{elapsed}s] not reachable")

            time.sleep(interval)
        else:
            # while-else: only reached when the loop exhausted the timeout
            # without hitting the `break` above.
            print(f"✗ {url}: timeout after {timeout}s")

    if not healthy_workers:
        raise RuntimeError(
            f"No healthy {service_name} workers found.\n"
            f"Checked ports: {ports}"
        )

    print(f"\n{len(healthy_workers)}/{len(worker_urls)} workers ready\n")
    return healthy_workers
|
|
|
|
|
|
def create_trainable(ports: List[int], payload_fn: Callable[[Dict], Dict]) -> Callable:
    """
    Factory to create a trainable function for Ray Tune.

    Args:
        ports: List of worker ports for load balancing
        payload_fn: Function that takes config dict and returns API payload dict

    Returns:
        Trainable function for Ray Tune

    Note:
        Ray Tune 2.x API: tune.report(metrics_dict) - pass dict directly, NOT kwargs.
        See: https://docs.ray.io/en/latest/tune/api/doc/ray.tune.report.html
    """
    def trainable(config):
        # Imports are local because Ray serializes this closure and runs it
        # in separate worker processes.
        import os
        import random
        import requests
        from ray.tune import report  # Ray 2.x: report(dict), not report(**kwargs)

        host = os.environ.get("OCR_HOST", "localhost")
        # Pick a worker at random for simple load balancing.
        api_url = f"http://{host}:{random.choice(ports)}"

        try:
            # Fix: build the payload inside the try block so a payload_fn
            # failure is reported as worst-case metrics instead of crashing
            # the trial without reporting anything.
            payload = payload_fn(config)
            # timeout=None: evaluating a page range can take many minutes.
            response = requests.post(f"{api_url}/evaluate", json=payload, timeout=None)
            response.raise_for_status()
            metrics = response.json()
            metrics["worker"] = api_url
            report(metrics)  # Ray 2.x API: pass dict directly
        except Exception as e:
            # Report worst-case metrics so the search marks this trial as bad
            # rather than aborting the whole tuning run.
            report({  # Ray 2.x API: pass dict directly
                "CER": 1.0,
                "WER": 1.0,
                "TIME": 0.0,
                "PAGES": 0,
                "TIME_PER_PAGE": 0,
                "worker": api_url,
                "ERROR": str(e)[:500]
            })

    return trainable
|
|
|
|
|
|
def run_tuner(
    trainable: Callable,
    search_space: Dict[str, Any],
    num_samples: int = 64,
    num_workers: int = 1,
    metric: str = "CER",
    mode: str = "min",
) -> tune.ResultGrid:
    """
    Initialize Ray and run hyperparameter tuning.

    Args:
        trainable: Trainable function from create_trainable()
        search_space: Dict of parameter names to tune.* search spaces
        num_samples: Number of trials to run
        num_workers: Max concurrent trials
        metric: Metric to optimize
        mode: "min" or "max"

    Returns:
        Ray Tune ResultGrid
    """
    # Bring up (or reuse) the local Ray runtime with quiet settings.
    ray.init(
        ignore_reinit_error=True,
        include_dashboard=False,
        configure_logging=False,
        _metrics_export_port=0,  # Disable metrics export to avoid connection warnings
    )
    print(f"Ray Tune ready (version: {ray.__version__})")

    # Optuna-driven search, capped at num_workers concurrent trials.
    tuning_config = tune.TuneConfig(
        metric=metric,
        mode=mode,
        search_alg=OptunaSearch(),
        num_samples=num_samples,
        max_concurrent_trials=num_workers,
    )
    return tune.Tuner(
        trainable,
        tune_config=tuning_config,
        param_space=search_space,
    ).fit()
|
|
|
|
|
|
def analyze_results(
    results: "tune.ResultGrid",
    output_folder: str = "results",
    prefix: str = "raytune",
    config_keys: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Analyze and save tuning results.

    Args:
        results: Ray Tune ResultGrid
        output_folder: Directory to save CSV
        prefix: Filename prefix
        config_keys: List of config keys to show in best result (without 'config/' prefix)

    Returns:
        Results DataFrame
    """
    os.makedirs(output_folder, exist_ok=True)
    df = results.get_dataframe()

    # Save to CSV
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{prefix}_results_{timestamp}.csv"
    filepath = os.path.join(output_folder, filename)
    df.to_csv(filepath, index=False)
    print(f"Results saved: {filepath}")

    # Fix: idxmin() on an empty frame (no trials ran / none reported CER)
    # raised; bail out gracefully after saving the CSV instead.
    if df.empty or "CER" not in df.columns:
        print("No completed trials with a CER metric to analyze")
        return df

    # Best configuration = row with the lowest character error rate.
    best = df.loc[df["CER"].idxmin()]
    print(f"\nBest CER: {best['CER']:.6f}")
    print(f"Best WER: {best['WER']:.6f}")

    if config_keys:
        print("\nOptimal Configuration:")
        for key in config_keys:
            col = f"config/{key}"
            if col in best:
                val = best[col]
                if isinstance(val, float):
                    print(f"  {key}: {val:.4f}")
                else:
                    print(f"  {key}: {val}")

    return df
|
|
|
|
|
|
def correlation_analysis(df: pd.DataFrame, param_keys: List[str]) -> None:
    """
    Print correlation of numeric parameters with CER/WER.

    Args:
        df: Results DataFrame
        param_keys: List of config keys (without 'config/' prefix)
    """
    param_cols = [f"config/{k}" for k in param_keys if f"config/{k}" in df.columns]
    # Fix: comparing dtype names to ['float64', 'int64'] silently dropped
    # other numeric dtypes (float32, int32, ...). Use pandas dtype
    # introspection instead; bool columns stay excluded, matching the
    # original name-based check.
    numeric_cols = [
        c for c in param_cols
        if pd.api.types.is_numeric_dtype(df[c]) and not pd.api.types.is_bool_dtype(df[c])
    ]

    if not numeric_cols:
        print("No numeric parameters for correlation analysis")
        return

    corr_cer = df[numeric_cols + ["CER"]].corr()["CER"].sort_values(ascending=False)
    corr_wer = df[numeric_cols + ["WER"]].corr()["WER"].sort_values(ascending=False)

    print("Correlation with CER:")
    print(corr_cer)
    print("\nCorrelation with WER:")
    print(corr_wer)
|
|
|
|
|
|
# =============================================================================
|
|
# OCR-specific payload functions
|
|
# =============================================================================
|
|
|
|
def paddle_ocr_payload(config: Dict) -> Dict:
    """Create payload for PaddleOCR API. Uses pages 5-10 (first doc) for tuning."""
    # Tunable parameters and the fallback used when a key is absent.
    tunables = {
        "use_doc_orientation_classify": False,
        "use_doc_unwarping": False,
        "textline_orientation": True,
        "text_det_thresh": 0.0,
        "text_det_box_thresh": 0.0,
        "text_det_unclip_ratio": 1.5,
        "text_rec_score_thresh": 0.0,
    }
    payload = {"pdf_folder": "/app/dataset"}
    for name, fallback in tunables.items():
        payload[name] = config.get(name, fallback)
    # Fixed evaluation window: pages 5-10 of the first document.
    payload["start_page"] = 5
    payload["end_page"] = 10
    payload["save_output"] = False
    return payload
|
|
|
|
|
|
def doctr_payload(config: Dict) -> Dict:
    """Create payload for DocTR API. Uses pages 5-10 (first doc) for tuning."""
    # Tunable parameters and the fallback used when a key is absent.
    tunables = {
        "assume_straight_pages": True,
        "straighten_pages": False,
        "preserve_aspect_ratio": True,
        "symmetric_pad": True,
        "disable_page_orientation": False,
        "disable_crop_orientation": False,
        "resolve_lines": True,
        "resolve_blocks": False,
        "paragraph_break": 0.035,
    }
    payload = {"pdf_folder": "/app/dataset"}
    for name, fallback in tunables.items():
        payload[name] = config.get(name, fallback)
    # Fixed evaluation window: pages 5-10 of the first document.
    payload["start_page"] = 5
    payload["end_page"] = 10
    payload["save_output"] = False
    return payload
|
|
|
|
|
|
def easyocr_payload(config: Dict) -> Dict:
    """Create payload for EasyOCR API. Uses pages 5-10 (first doc) for tuning."""
    # Tunable parameters and the fallback used when a key is absent.
    tunables = {
        "text_threshold": 0.7,
        "low_text": 0.4,
        "link_threshold": 0.4,
        "slope_ths": 0.1,
        "ycenter_ths": 0.5,
        "height_ths": 0.5,
        "width_ths": 0.5,
        "add_margin": 0.1,
        "contrast_ths": 0.1,
        "adjust_contrast": 0.5,
        "decoder": "greedy",
        "beamWidth": 5,
        "min_size": 10,
    }
    payload = {"pdf_folder": "/app/dataset"}
    for name, fallback in tunables.items():
        payload[name] = config.get(name, fallback)
    # Fixed evaluation window: pages 5-10 of the first document.
    payload["start_page"] = 5
    payload["end_page"] = 10
    payload["save_output"] = False
    return payload
|
|
|
|
|
|
# =============================================================================
|
|
# Search spaces
|
|
# =============================================================================
|
|
|
|
# Search space for PaddleOCR: preprocessing toggles plus detection/recognition
# thresholds. Keys match paddle_ocr_payload() / PADDLE_OCR_CONFIG_KEYS.
PADDLE_OCR_SEARCH_SPACE = {
    "use_doc_orientation_classify": tune.choice([True, False]),
    "use_doc_unwarping": tune.choice([True, False]),
    "textline_orientation": tune.choice([True, False]),
    "text_det_thresh": tune.uniform(0.0, 0.7),
    "text_det_box_thresh": tune.uniform(0.0, 0.7),
    # NOTE(review): single-value choice pins this to 0.0, while the payload
    # fallback default is 1.5 — confirm this is intentional.
    "text_det_unclip_ratio": tune.choice([0.0]),
    "text_rec_score_thresh": tune.uniform(0.0, 0.7),
}
|
|
|
|
# Search space for DocTR: page/crop orientation and layout flags plus the
# paragraph_break distance. Keys match doctr_payload() / DOCTR_CONFIG_KEYS.
DOCTR_SEARCH_SPACE = {
    "assume_straight_pages": tune.choice([True, False]),
    "straighten_pages": tune.choice([True, False]),
    "preserve_aspect_ratio": tune.choice([True, False]),
    "symmetric_pad": tune.choice([True, False]),
    "disable_page_orientation": tune.choice([True, False]),
    "disable_crop_orientation": tune.choice([True, False]),
    "resolve_lines": tune.choice([True, False]),
    "resolve_blocks": tune.choice([True, False]),
    "paragraph_break": tune.uniform(0.01, 0.1),
}
|
|
|
|
# Search space for EasyOCR: detection thresholds, box-merge tolerances,
# contrast handling, and decoder settings. Keys match easyocr_payload() /
# EASYOCR_CONFIG_KEYS.
EASYOCR_SEARCH_SPACE = {
    "text_threshold": tune.uniform(0.3, 0.9),
    "low_text": tune.uniform(0.2, 0.6),
    "link_threshold": tune.uniform(0.2, 0.6),
    "slope_ths": tune.uniform(0.0, 0.3),
    "ycenter_ths": tune.uniform(0.3, 1.0),
    "height_ths": tune.uniform(0.3, 1.0),
    "width_ths": tune.uniform(0.3, 1.0),
    "add_margin": tune.uniform(0.0, 0.3),
    "contrast_ths": tune.uniform(0.05, 0.3),
    "adjust_contrast": tune.uniform(0.3, 0.8),
    "decoder": tune.choice(["greedy", "beamsearch"]),
    # beamWidth only takes effect for the "beamsearch" decoder — presumably
    # ignored by the service when decoder="greedy"; verify against the API.
    "beamWidth": tune.choice([3, 5, 7, 10]),
    "min_size": tune.choice([5, 10, 15, 20]),
}
|
|
|
|
|
|
# =============================================================================
|
|
# Config keys for results display
|
|
# =============================================================================
|
|
|
|
# Config keys shown by analyze_results() for each backend; names mirror the
# corresponding *_SEARCH_SPACE and *_payload keys.
PADDLE_OCR_CONFIG_KEYS = [
    "use_doc_orientation_classify", "use_doc_unwarping", "textline_orientation",
    "text_det_thresh", "text_det_box_thresh", "text_det_unclip_ratio", "text_rec_score_thresh",
]

DOCTR_CONFIG_KEYS = [
    "assume_straight_pages", "straighten_pages", "preserve_aspect_ratio", "symmetric_pad",
    "disable_page_orientation", "disable_crop_orientation", "resolve_lines", "resolve_blocks",
    "paragraph_break",
]

EASYOCR_CONFIG_KEYS = [
    "text_threshold", "low_text", "link_threshold", "slope_ths", "ycenter_ths",
    "height_ths", "width_ths", "add_margin", "contrast_ths", "adjust_contrast",
    "decoder", "beamWidth", "min_size",
]
|