raytune as docker
Some checks failed
build_docker / essential (pull_request) Successful in 1s
build_docker / build_cpu (pull_request) Successful in 4m14s
build_docker / build_easyocr (pull_request) Successful in 12m19s
build_docker / build_easyocr_gpu (pull_request) Successful in 14m2s
build_docker / build_doctr (pull_request) Successful in 12m24s
build_docker / build_doctr_gpu (pull_request) Successful in 13m10s
build_docker / build_raytune (pull_request) Successful in 1m50s
build_docker / build_gpu (pull_request) Has been cancelled

This commit is contained in:
2026-01-19 16:32:45 +01:00
parent d67cbd4677
commit 94b25f9752
20 changed files with 7214 additions and 112 deletions

371
src/raytune/raytune_ocr.py Normal file
View File

@@ -0,0 +1,371 @@
# raytune_ocr.py
# Shared Ray Tune utilities for OCR hyperparameter optimization
#
# Usage:
# from raytune_ocr import check_workers, create_trainable, run_tuner, analyze_results
#
# Environment variables:
# OCR_HOST: Host for OCR services (default: localhost)
import os
from datetime import datetime
from typing import List, Dict, Any, Callable
import requests
import pandas as pd
import ray
from ray import tune
from ray.tune.search.optuna import OptunaSearch
def check_workers(
ports: List[int],
service_name: str = "OCR",
timeout: int = 180,
interval: int = 5,
) -> List[str]:
"""
Wait for workers to be fully ready (model + dataset loaded) and return healthy URLs.
Args:
ports: List of port numbers to check
service_name: Name for error messages
timeout: Max seconds to wait for each worker
interval: Seconds between retries
Returns:
List of healthy worker URLs
Raises:
RuntimeError if no healthy workers found after timeout
"""
import time
host = os.environ.get("OCR_HOST", "localhost")
worker_urls = [f"http://{host}:{port}" for port in ports]
healthy_workers = []
for url in worker_urls:
print(f"Waiting for {url}...")
start = time.time()
while time.time() - start < timeout:
try:
health = requests.get(f"{url}/health", timeout=10).json()
model_ok = health.get('model_loaded', False)
dataset_ok = health.get('dataset_loaded', False)
if health.get('status') == 'ok' and model_ok:
gpu = health.get('gpu_name', 'CPU')
print(f"{url}: ready ({gpu})")
healthy_workers.append(url)
break
elapsed = int(time.time() - start)
print(f" [{elapsed}s] model={model_ok}")
except requests.exceptions.RequestException:
elapsed = int(time.time() - start)
print(f" [{elapsed}s] not reachable")
time.sleep(interval)
else:
print(f"{url}: timeout after {timeout}s")
if not healthy_workers:
raise RuntimeError(
f"No healthy {service_name} workers found.\n"
f"Checked ports: {ports}"
)
print(f"\n{len(healthy_workers)}/{len(worker_urls)} workers ready\n")
return healthy_workers
def create_trainable(ports: List[int], payload_fn: Callable[[Dict], Dict]) -> Callable:
"""
Factory to create a trainable function for Ray Tune.
Args:
ports: List of worker ports for load balancing
payload_fn: Function that takes config dict and returns API payload dict
Returns:
Trainable function for Ray Tune
Note:
Ray Tune 2.x API: tune.report(metrics_dict) - pass dict directly, NOT kwargs.
See: https://docs.ray.io/en/latest/tune/api/doc/ray.tune.report.html
"""
def trainable(config):
import os
import random
import requests
from ray.tune import report # Ray 2.x: report(dict), not report(**kwargs)
host = os.environ.get("OCR_HOST", "localhost")
api_url = f"http://{host}:{random.choice(ports)}"
payload = payload_fn(config)
try:
response = requests.post(f"{api_url}/evaluate", json=payload, timeout=None)
response.raise_for_status()
metrics = response.json()
metrics["worker"] = api_url
report(metrics) # Ray 2.x API: pass dict directly
except Exception as e:
report({ # Ray 2.x API: pass dict directly
"CER": 1.0,
"WER": 1.0,
"TIME": 0.0,
"PAGES": 0,
"TIME_PER_PAGE": 0,
"worker": api_url,
"ERROR": str(e)[:500]
})
return trainable
def run_tuner(
trainable: Callable,
search_space: Dict[str, Any],
num_samples: int = 64,
num_workers: int = 1,
metric: str = "CER",
mode: str = "min",
) -> tune.ResultGrid:
"""
Initialize Ray and run hyperparameter tuning.
Args:
trainable: Trainable function from create_trainable()
search_space: Dict of parameter names to tune.* search spaces
num_samples: Number of trials to run
num_workers: Max concurrent trials
metric: Metric to optimize
mode: "min" or "max"
Returns:
Ray Tune ResultGrid
"""
ray.init(
ignore_reinit_error=True,
include_dashboard=False,
configure_logging=False,
_metrics_export_port=0, # Disable metrics export to avoid connection warnings
)
print(f"Ray Tune ready (version: {ray.__version__})")
tuner = tune.Tuner(
trainable,
tune_config=tune.TuneConfig(
metric=metric,
mode=mode,
search_alg=OptunaSearch(),
num_samples=num_samples,
max_concurrent_trials=num_workers,
),
param_space=search_space,
)
return tuner.fit()
def analyze_results(
results: tune.ResultGrid,
output_folder: str = "results",
prefix: str = "raytune",
config_keys: List[str] = None,
) -> pd.DataFrame:
"""
Analyze and save tuning results.
Args:
results: Ray Tune ResultGrid
output_folder: Directory to save CSV
prefix: Filename prefix
config_keys: List of config keys to show in best result (without 'config/' prefix)
Returns:
Results DataFrame
"""
os.makedirs(output_folder, exist_ok=True)
df = results.get_dataframe()
# Save to CSV
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"{prefix}_results_{timestamp}.csv"
filepath = os.path.join(output_folder, filename)
df.to_csv(filepath, index=False)
print(f"Results saved: {filepath}")
# Best configuration
best = df.loc[df["CER"].idxmin()]
print(f"\nBest CER: {best['CER']:.6f}")
print(f"Best WER: {best['WER']:.6f}")
if config_keys:
print(f"\nOptimal Configuration:")
for key in config_keys:
col = f"config/{key}"
if col in best:
val = best[col]
if isinstance(val, float):
print(f" {key}: {val:.4f}")
else:
print(f" {key}: {val}")
return df
def correlation_analysis(df: pd.DataFrame, param_keys: List[str]) -> None:
"""
Print correlation of numeric parameters with CER/WER.
Args:
df: Results DataFrame
param_keys: List of config keys (without 'config/' prefix)
"""
param_cols = [f"config/{k}" for k in param_keys if f"config/{k}" in df.columns]
numeric_cols = [c for c in param_cols if df[c].dtype in ['float64', 'int64']]
if not numeric_cols:
print("No numeric parameters for correlation analysis")
return
corr_cer = df[numeric_cols + ["CER"]].corr()["CER"].sort_values(ascending=False)
corr_wer = df[numeric_cols + ["WER"]].corr()["WER"].sort_values(ascending=False)
print("Correlation with CER:")
print(corr_cer)
print("\nCorrelation with WER:")
print(corr_wer)
# =============================================================================
# OCR-specific payload functions
# =============================================================================
def paddle_ocr_payload(config: Dict) -> Dict:
"""Create payload for PaddleOCR API. Uses pages 5-10 (first doc) for tuning."""
return {
"pdf_folder": "/app/dataset",
"use_doc_orientation_classify": config.get("use_doc_orientation_classify", False),
"use_doc_unwarping": config.get("use_doc_unwarping", False),
"textline_orientation": config.get("textline_orientation", True),
"text_det_thresh": config.get("text_det_thresh", 0.0),
"text_det_box_thresh": config.get("text_det_box_thresh", 0.0),
"text_det_unclip_ratio": config.get("text_det_unclip_ratio", 1.5),
"text_rec_score_thresh": config.get("text_rec_score_thresh", 0.0),
"start_page": 5,
"end_page": 10,
"save_output": False,
}
def doctr_payload(config: Dict) -> Dict:
"""Create payload for DocTR API. Uses pages 5-10 (first doc) for tuning."""
return {
"pdf_folder": "/app/dataset",
"assume_straight_pages": config.get("assume_straight_pages", True),
"straighten_pages": config.get("straighten_pages", False),
"preserve_aspect_ratio": config.get("preserve_aspect_ratio", True),
"symmetric_pad": config.get("symmetric_pad", True),
"disable_page_orientation": config.get("disable_page_orientation", False),
"disable_crop_orientation": config.get("disable_crop_orientation", False),
"resolve_lines": config.get("resolve_lines", True),
"resolve_blocks": config.get("resolve_blocks", False),
"paragraph_break": config.get("paragraph_break", 0.035),
"start_page": 5,
"end_page": 10,
"save_output": False,
}
def easyocr_payload(config: Dict) -> Dict:
"""Create payload for EasyOCR API. Uses pages 5-10 (first doc) for tuning."""
return {
"pdf_folder": "/app/dataset",
"text_threshold": config.get("text_threshold", 0.7),
"low_text": config.get("low_text", 0.4),
"link_threshold": config.get("link_threshold", 0.4),
"slope_ths": config.get("slope_ths", 0.1),
"ycenter_ths": config.get("ycenter_ths", 0.5),
"height_ths": config.get("height_ths", 0.5),
"width_ths": config.get("width_ths", 0.5),
"add_margin": config.get("add_margin", 0.1),
"contrast_ths": config.get("contrast_ths", 0.1),
"adjust_contrast": config.get("adjust_contrast", 0.5),
"decoder": config.get("decoder", "greedy"),
"beamWidth": config.get("beamWidth", 5),
"min_size": config.get("min_size", 10),
"start_page": 5,
"end_page": 10,
"save_output": False,
}
# =============================================================================
# Search spaces
# =============================================================================
PADDLE_OCR_SEARCH_SPACE = {
"use_doc_orientation_classify": tune.choice([True, False]),
"use_doc_unwarping": tune.choice([True, False]),
"textline_orientation": tune.choice([True, False]),
"text_det_thresh": tune.uniform(0.0, 0.7),
"text_det_box_thresh": tune.uniform(0.0, 0.7),
"text_det_unclip_ratio": tune.choice([0.0]),
"text_rec_score_thresh": tune.uniform(0.0, 0.7),
}
DOCTR_SEARCH_SPACE = {
"assume_straight_pages": tune.choice([True, False]),
"straighten_pages": tune.choice([True, False]),
"preserve_aspect_ratio": tune.choice([True, False]),
"symmetric_pad": tune.choice([True, False]),
"disable_page_orientation": tune.choice([True, False]),
"disable_crop_orientation": tune.choice([True, False]),
"resolve_lines": tune.choice([True, False]),
"resolve_blocks": tune.choice([True, False]),
"paragraph_break": tune.uniform(0.01, 0.1),
}
EASYOCR_SEARCH_SPACE = {
"text_threshold": tune.uniform(0.3, 0.9),
"low_text": tune.uniform(0.2, 0.6),
"link_threshold": tune.uniform(0.2, 0.6),
"slope_ths": tune.uniform(0.0, 0.3),
"ycenter_ths": tune.uniform(0.3, 1.0),
"height_ths": tune.uniform(0.3, 1.0),
"width_ths": tune.uniform(0.3, 1.0),
"add_margin": tune.uniform(0.0, 0.3),
"contrast_ths": tune.uniform(0.05, 0.3),
"adjust_contrast": tune.uniform(0.3, 0.8),
"decoder": tune.choice(["greedy", "beamsearch"]),
"beamWidth": tune.choice([3, 5, 7, 10]),
"min_size": tune.choice([5, 10, 15, 20]),
}
# =============================================================================
# Config keys for results display
# =============================================================================
PADDLE_OCR_CONFIG_KEYS = [
"use_doc_orientation_classify", "use_doc_unwarping", "textline_orientation",
"text_det_thresh", "text_det_box_thresh", "text_det_unclip_ratio", "text_rec_score_thresh",
]
DOCTR_CONFIG_KEYS = [
"assume_straight_pages", "straighten_pages", "preserve_aspect_ratio", "symmetric_pad",
"disable_page_orientation", "disable_crop_orientation", "resolve_lines", "resolve_blocks",
"paragraph_break",
]
EASYOCR_CONFIG_KEYS = [
"text_threshold", "low_text", "link_threshold", "slope_ths", "ycenter_ths",
"height_ths", "width_ths", "add_margin", "contrast_ths", "adjust_contrast",
"decoder", "beamWidth", "min_size",
]