# raytune_ocr.py
# Shared Ray Tune utilities for OCR hyperparameter optimization
#
# Usage:
#   from raytune_ocr import check_workers, create_trainable, run_tuner, analyze_results
#
# Environment variables:
#   OCR_HOST: Host for OCR services (default: localhost)

import os
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional

import pandas as pd
import requests

import ray
from ray import tune
from ray.tune.search.optuna import OptunaSearch


def check_workers(
    ports: List[int],
    service_name: str = "OCR",
    timeout: int = 180,
    interval: int = 5,
) -> List[str]:
    """
    Wait for workers to be fully ready (model + dataset loaded) and return healthy URLs.

    Args:
        ports: List of port numbers to check
        service_name: Name for error messages
        timeout: Max seconds to wait for each worker
        interval: Seconds between retries

    Returns:
        List of healthy worker URLs

    Raises:
        RuntimeError if no healthy workers found after timeout
    """
    import time

    host = os.environ.get("OCR_HOST", "localhost")
    worker_urls = [f"http://{host}:{port}" for port in ports]

    healthy_workers = []
    for url in worker_urls:
        print(f"Waiting for {url}...")
        start = time.time()
        while time.time() - start < timeout:
            try:
                health = requests.get(f"{url}/health", timeout=10).json()
                model_ok = health.get('model_loaded', False)
                dataset_ok = health.get('dataset_loaded', False)
                # NOTE(review): readiness requires only the model to be loaded;
                # dataset_ok is shown for visibility only. Requiring it could
                # block forever on workers that never report 'dataset_loaded' —
                # confirm against the worker /health schema before tightening.
                if health.get('status') == 'ok' and model_ok:
                    gpu = health.get('gpu_name', 'CPU')
                    print(f"✓ {url}: ready ({gpu})")
                    healthy_workers.append(url)
                    break
                elapsed = int(time.time() - start)
                print(f"  [{elapsed}s] model={model_ok} dataset={dataset_ok}")
            except requests.exceptions.RequestException:
                elapsed = int(time.time() - start)
                print(f"  [{elapsed}s] not reachable")
            # Sleep on every retry path (not-ready AND unreachable) so a
            # reachable-but-loading worker isn't polled in a tight loop.
            time.sleep(interval)
        else:
            # while-loop exhausted without break: this worker never came up.
            print(f"✗ {url}: timeout after {timeout}s")

    if not healthy_workers:
        raise RuntimeError(
            f"No healthy {service_name} workers found.\n"
            f"Checked ports: {ports}"
        )

    print(f"\n{len(healthy_workers)}/{len(worker_urls)} workers ready\n")
    return healthy_workers


def create_trainable(ports: List[int], payload_fn: Callable[[Dict], Dict]) -> Callable:
    """
    Factory to create a trainable function for Ray Tune.

    Args:
        ports: List of worker ports for load balancing
        payload_fn: Function that takes config dict and returns API payload dict

    Returns:
        Trainable function for Ray Tune

    Note:
        Ray Tune 2.x API: tune.report(metrics_dict) - pass dict directly, NOT kwargs.
        See: https://docs.ray.io/en/latest/tune/api/doc/ray.tune.report.html
    """
    def trainable(config):
        # Imports are local so the function is self-contained when Ray
        # serializes it and runs it inside a worker process.
        import os
        import random

        import requests
        from ray.tune import report  # Ray 2.x: report(dict), not report(**kwargs)

        host = os.environ.get("OCR_HOST", "localhost")
        # Naive load balancing: each trial picks a random worker.
        api_url = f"http://{host}:{random.choice(ports)}"
        payload = payload_fn(config)

        try:
            # timeout=None: OCR evaluation over a page range can take a long
            # time; we deliberately wait indefinitely for the worker.
            response = requests.post(f"{api_url}/evaluate", json=payload, timeout=None)
            response.raise_for_status()
            metrics = response.json()
            metrics["worker"] = api_url
            report(metrics)  # Ray 2.x API: pass dict directly
        except Exception as e:
            # Report a worst-case sentinel result instead of crashing the
            # trial, so the search can continue past transient worker errors.
            report({  # Ray 2.x API: pass dict directly
                "CER": 1.0,
                "WER": 1.0,
                "TIME": 0.0,
                "PAGES": 0,
                "TIME_PER_PAGE": 0,
                "worker": api_url,
                "ERROR": str(e)[:500],
            })

    return trainable


def run_tuner(
    trainable: Callable,
    search_space: Dict[str, Any],
    num_samples: int = 64,
    num_workers: int = 1,
    metric: str = "CER",
    mode: str = "min",
) -> tune.ResultGrid:
    """
    Initialize Ray and run hyperparameter tuning.

    Args:
        trainable: Trainable function from create_trainable()
        search_space: Dict of parameter names to tune.* search spaces
        num_samples: Number of trials to run
        num_workers: Max concurrent trials
        metric: Metric to optimize
        mode: "min" or "max"

    Returns:
        Ray Tune ResultGrid
    """
    ray.init(
        ignore_reinit_error=True,
        include_dashboard=False,
        configure_logging=False,
        _metrics_export_port=0,  # Disable metrics export to avoid connection warnings
    )
    print(f"Ray Tune ready (version: {ray.__version__})")

    tuner = tune.Tuner(
        trainable,
        tune_config=tune.TuneConfig(
            metric=metric,
            mode=mode,
            search_alg=OptunaSearch(),
            num_samples=num_samples,
            max_concurrent_trials=num_workers,
        ),
        param_space=search_space,
    )
    return tuner.fit()


def analyze_results(
    results: tune.ResultGrid,
    output_folder: str = "results",
    prefix: str = "raytune",
    config_keys: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Analyze and save tuning results.

    Args:
        results: Ray Tune ResultGrid
        output_folder: Directory to save CSV
        prefix: Filename prefix
        config_keys: List of config keys to show in best result (without 'config/' prefix)

    Returns:
        Results DataFrame
    """
    os.makedirs(output_folder, exist_ok=True)
    df = results.get_dataframe()

    # Save to CSV
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{prefix}_results_{timestamp}.csv"
    filepath = os.path.join(output_folder, filename)
    df.to_csv(filepath, index=False)
    print(f"Results saved: {filepath}")

    # Guard: idxmin() on an empty/metric-less frame would raise.
    if df.empty or "CER" not in df.columns:
        print("No CER results to analyze")
        return df

    # Best configuration
    best = df.loc[df["CER"].idxmin()]
    print(f"\nBest CER: {best['CER']:.6f}")
    print(f"Best WER: {best['WER']:.6f}")

    if config_keys:
        print(f"\nOptimal Configuration:")
        for key in config_keys:
            col = f"config/{key}"
            if col in best:
                val = best[col]
                if isinstance(val, float):
                    print(f"  {key}: {val:.4f}")
                else:
                    print(f"  {key}: {val}")

    return df


def correlation_analysis(df: pd.DataFrame, param_keys: List[str]) -> None:
    """
    Print correlation of numeric parameters with CER/WER.

    Args:
        df: Results DataFrame
        param_keys: List of config keys (without 'config/' prefix)
    """
    param_cols = [f"config/{k}" for k in param_keys if f"config/{k}" in df.columns]
    # is_numeric_dtype also catches int32/float32 etc., not just the 64-bit dtypes.
    numeric_cols = [c for c in param_cols if pd.api.types.is_numeric_dtype(df[c])]

    if not numeric_cols:
        print("No numeric parameters for correlation analysis")
        return

    corr_cer = df[numeric_cols + ["CER"]].corr()["CER"].sort_values(ascending=False)
    corr_wer = df[numeric_cols + ["WER"]].corr()["WER"].sort_values(ascending=False)

    print("Correlation with CER:")
    print(corr_cer)
    print("\nCorrelation with WER:")
    print(corr_wer)


# =============================================================================
# OCR-specific payload functions
# =============================================================================

def paddle_ocr_payload(config: Dict) -> Dict:
    """Create payload for PaddleOCR API. Uses pages 5-10 (first doc) for tuning."""
    return {
        "pdf_folder": "/app/dataset",
        "use_doc_orientation_classify": config.get("use_doc_orientation_classify", False),
        "use_doc_unwarping": config.get("use_doc_unwarping", False),
        "textline_orientation": config.get("textline_orientation", True),
        "text_det_thresh": config.get("text_det_thresh", 0.0),
        "text_det_box_thresh": config.get("text_det_box_thresh", 0.0),
        "text_det_unclip_ratio": config.get("text_det_unclip_ratio", 1.5),
        "text_rec_score_thresh": config.get("text_rec_score_thresh", 0.0),
        "start_page": 5,
        "end_page": 10,
        "save_output": False,
    }


def doctr_payload(config: Dict) -> Dict:
    """Create payload for DocTR API. Uses pages 5-10 (first doc) for tuning."""
    return {
        "pdf_folder": "/app/dataset",
        "assume_straight_pages": config.get("assume_straight_pages", True),
        "straighten_pages": config.get("straighten_pages", False),
        "preserve_aspect_ratio": config.get("preserve_aspect_ratio", True),
        "symmetric_pad": config.get("symmetric_pad", True),
        "disable_page_orientation": config.get("disable_page_orientation", False),
        "disable_crop_orientation": config.get("disable_crop_orientation", False),
        "resolve_lines": config.get("resolve_lines", True),
        "resolve_blocks": config.get("resolve_blocks", False),
        "paragraph_break": config.get("paragraph_break", 0.035),
        "start_page": 5,
        "end_page": 10,
        "save_output": False,
    }


def easyocr_payload(config: Dict) -> Dict:
    """Create payload for EasyOCR API. Uses pages 5-10 (first doc) for tuning."""
    return {
        "pdf_folder": "/app/dataset",
        "text_threshold": config.get("text_threshold", 0.7),
        "low_text": config.get("low_text", 0.4),
        "link_threshold": config.get("link_threshold", 0.4),
        "slope_ths": config.get("slope_ths", 0.1),
        "ycenter_ths": config.get("ycenter_ths", 0.5),
        "height_ths": config.get("height_ths", 0.5),
        "width_ths": config.get("width_ths", 0.5),
        "add_margin": config.get("add_margin", 0.1),
        "contrast_ths": config.get("contrast_ths", 0.1),
        "adjust_contrast": config.get("adjust_contrast", 0.5),
        "decoder": config.get("decoder", "greedy"),
        "beamWidth": config.get("beamWidth", 5),
        "min_size": config.get("min_size", 10),
        "start_page": 5,
        "end_page": 10,
        "save_output": False,
    }


# =============================================================================
# Search spaces
# =============================================================================

PADDLE_OCR_SEARCH_SPACE = {
    "use_doc_orientation_classify": tune.choice([True, False]),
    "use_doc_unwarping": tune.choice([True, False]),
    "textline_orientation": tune.choice([True, False]),
    "text_det_thresh": tune.uniform(0.0, 0.7),
    "text_det_box_thresh": tune.uniform(0.0, 0.7),
    # NOTE(review): single-value choice pins unclip_ratio to 0.0 even though
    # the payload default is 1.5 — confirm this is intentional and not a
    # leftover from narrowing an earlier search range.
    "text_det_unclip_ratio": tune.choice([0.0]),
    "text_rec_score_thresh": tune.uniform(0.0, 0.7),
}

DOCTR_SEARCH_SPACE = {
    "assume_straight_pages": tune.choice([True, False]),
    "straighten_pages": tune.choice([True, False]),
    "preserve_aspect_ratio": tune.choice([True, False]),
    "symmetric_pad": tune.choice([True, False]),
    "disable_page_orientation": tune.choice([True, False]),
    "disable_crop_orientation": tune.choice([True, False]),
    "resolve_lines": tune.choice([True, False]),
    "resolve_blocks": tune.choice([True, False]),
    "paragraph_break": tune.uniform(0.01, 0.1),
}

EASYOCR_SEARCH_SPACE = {
    "text_threshold": tune.uniform(0.3, 0.9),
    "low_text": tune.uniform(0.2, 0.6),
    "link_threshold": tune.uniform(0.2, 0.6),
    "slope_ths": tune.uniform(0.0, 0.3),
    "ycenter_ths": tune.uniform(0.3, 1.0),
    "height_ths": tune.uniform(0.3, 1.0),
    "width_ths": tune.uniform(0.3, 1.0),
    "add_margin": tune.uniform(0.0, 0.3),
    "contrast_ths": tune.uniform(0.05, 0.3),
    "adjust_contrast": tune.uniform(0.3, 0.8),
    "decoder": tune.choice(["greedy", "beamsearch"]),
    "beamWidth": tune.choice([3, 5, 7, 10]),
    "min_size": tune.choice([5, 10, 15, 20]),
}

# =============================================================================
# Config keys for results display
# =============================================================================

PADDLE_OCR_CONFIG_KEYS = [
    "use_doc_orientation_classify",
    "use_doc_unwarping",
    "textline_orientation",
    "text_det_thresh",
    "text_det_box_thresh",
    "text_det_unclip_ratio",
    "text_rec_score_thresh",
]

DOCTR_CONFIG_KEYS = [
    "assume_straight_pages",
    "straighten_pages",
    "preserve_aspect_ratio",
    "symmetric_pad",
    "disable_page_orientation",
    "disable_crop_orientation",
    "resolve_lines",
    "resolve_blocks",
    "paragraph_break",
]

EASYOCR_CONFIG_KEYS = [
    "text_threshold",
    "low_text",
    "link_threshold",
    "slope_ths",
    "ycenter_ths",
    "height_ths",
    "width_ths",
    "add_margin",
    "contrast_ths",
    "adjust_contrast",
    "decoder",
    "beamWidth",
    "min_size",
]