Paddle ocr, easyicr and doctr gpu support. (#4)
All checks were successful
build_docker / essential (push) Successful in 0s
build_docker / build_cpu (push) Successful in 5m0s
build_docker / build_gpu (push) Successful in 22m55s
build_docker / build_easyocr (push) Successful in 18m47s
build_docker / build_easyocr_gpu (push) Successful in 19m0s
build_docker / build_raytune (push) Successful in 3m27s
build_docker / build_doctr (push) Successful in 19m42s
build_docker / build_doctr_gpu (push) Successful in 14m49s
All checks were successful
build_docker / essential (push) Successful in 0s
build_docker / build_cpu (push) Successful in 5m0s
build_docker / build_gpu (push) Successful in 22m55s
build_docker / build_easyocr (push) Successful in 18m47s
build_docker / build_easyocr_gpu (push) Successful in 19m0s
build_docker / build_raytune (push) Successful in 3m27s
build_docker / build_doctr (push) Successful in 19m42s
build_docker / build_doctr_gpu (push) Successful in 14m49s
This commit was merged in pull request #4.
This commit is contained in:
371
src/raytune/raytune_ocr.py
Normal file
371
src/raytune/raytune_ocr.py
Normal file
@@ -0,0 +1,371 @@
|
||||
# raytune_ocr.py
|
||||
# Shared Ray Tune utilities for OCR hyperparameter optimization
|
||||
#
|
||||
# Usage:
|
||||
# from raytune_ocr import check_workers, create_trainable, run_tuner, analyze_results
|
||||
#
|
||||
# Environment variables:
|
||||
# OCR_HOST: Host for OCR services (default: localhost)
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Callable
|
||||
|
||||
import requests
|
||||
import pandas as pd
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune.search.optuna import OptunaSearch
|
||||
|
||||
|
||||
def check_workers(
|
||||
ports: List[int],
|
||||
service_name: str = "OCR",
|
||||
timeout: int = 180,
|
||||
interval: int = 5,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Wait for workers to be fully ready (model + dataset loaded) and return healthy URLs.
|
||||
|
||||
Args:
|
||||
ports: List of port numbers to check
|
||||
service_name: Name for error messages
|
||||
timeout: Max seconds to wait for each worker
|
||||
interval: Seconds between retries
|
||||
|
||||
Returns:
|
||||
List of healthy worker URLs
|
||||
|
||||
Raises:
|
||||
RuntimeError if no healthy workers found after timeout
|
||||
"""
|
||||
import time
|
||||
|
||||
host = os.environ.get("OCR_HOST", "localhost")
|
||||
worker_urls = [f"http://{host}:{port}" for port in ports]
|
||||
healthy_workers = []
|
||||
|
||||
for url in worker_urls:
|
||||
print(f"Waiting for {url}...")
|
||||
start = time.time()
|
||||
|
||||
while time.time() - start < timeout:
|
||||
try:
|
||||
health = requests.get(f"{url}/health", timeout=10).json()
|
||||
model_ok = health.get('model_loaded', False)
|
||||
dataset_ok = health.get('dataset_loaded', False)
|
||||
|
||||
if health.get('status') == 'ok' and model_ok:
|
||||
gpu = health.get('gpu_name', 'CPU')
|
||||
print(f"✓ {url}: ready ({gpu})")
|
||||
healthy_workers.append(url)
|
||||
break
|
||||
|
||||
elapsed = int(time.time() - start)
|
||||
print(f" [{elapsed}s] model={model_ok}")
|
||||
except requests.exceptions.RequestException:
|
||||
elapsed = int(time.time() - start)
|
||||
print(f" [{elapsed}s] not reachable")
|
||||
|
||||
time.sleep(interval)
|
||||
else:
|
||||
print(f"✗ {url}: timeout after {timeout}s")
|
||||
|
||||
if not healthy_workers:
|
||||
raise RuntimeError(
|
||||
f"No healthy {service_name} workers found.\n"
|
||||
f"Checked ports: {ports}"
|
||||
)
|
||||
|
||||
print(f"\n{len(healthy_workers)}/{len(worker_urls)} workers ready\n")
|
||||
return healthy_workers
|
||||
|
||||
|
||||
def create_trainable(ports: List[int], payload_fn: Callable[[Dict], Dict]) -> Callable:
|
||||
"""
|
||||
Factory to create a trainable function for Ray Tune.
|
||||
|
||||
Args:
|
||||
ports: List of worker ports for load balancing
|
||||
payload_fn: Function that takes config dict and returns API payload dict
|
||||
|
||||
Returns:
|
||||
Trainable function for Ray Tune
|
||||
|
||||
Note:
|
||||
Ray Tune 2.x API: tune.report(metrics_dict) - pass dict directly, NOT kwargs.
|
||||
See: https://docs.ray.io/en/latest/tune/api/doc/ray.tune.report.html
|
||||
"""
|
||||
def trainable(config):
|
||||
import os
|
||||
import random
|
||||
import requests
|
||||
from ray.tune import report # Ray 2.x: report(dict), not report(**kwargs)
|
||||
|
||||
host = os.environ.get("OCR_HOST", "localhost")
|
||||
api_url = f"http://{host}:{random.choice(ports)}"
|
||||
payload = payload_fn(config)
|
||||
|
||||
try:
|
||||
response = requests.post(f"{api_url}/evaluate", json=payload, timeout=None)
|
||||
response.raise_for_status()
|
||||
metrics = response.json()
|
||||
metrics["worker"] = api_url
|
||||
report(metrics) # Ray 2.x API: pass dict directly
|
||||
except Exception as e:
|
||||
report({ # Ray 2.x API: pass dict directly
|
||||
"CER": 1.0,
|
||||
"WER": 1.0,
|
||||
"TIME": 0.0,
|
||||
"PAGES": 0,
|
||||
"TIME_PER_PAGE": 0,
|
||||
"worker": api_url,
|
||||
"ERROR": str(e)[:500]
|
||||
})
|
||||
|
||||
return trainable
|
||||
|
||||
|
||||
def run_tuner(
|
||||
trainable: Callable,
|
||||
search_space: Dict[str, Any],
|
||||
num_samples: int = 64,
|
||||
num_workers: int = 1,
|
||||
metric: str = "CER",
|
||||
mode: str = "min",
|
||||
) -> tune.ResultGrid:
|
||||
"""
|
||||
Initialize Ray and run hyperparameter tuning.
|
||||
|
||||
Args:
|
||||
trainable: Trainable function from create_trainable()
|
||||
search_space: Dict of parameter names to tune.* search spaces
|
||||
num_samples: Number of trials to run
|
||||
num_workers: Max concurrent trials
|
||||
metric: Metric to optimize
|
||||
mode: "min" or "max"
|
||||
|
||||
Returns:
|
||||
Ray Tune ResultGrid
|
||||
"""
|
||||
ray.init(
|
||||
ignore_reinit_error=True,
|
||||
include_dashboard=False,
|
||||
configure_logging=False,
|
||||
_metrics_export_port=0, # Disable metrics export to avoid connection warnings
|
||||
)
|
||||
print(f"Ray Tune ready (version: {ray.__version__})")
|
||||
|
||||
tuner = tune.Tuner(
|
||||
trainable,
|
||||
tune_config=tune.TuneConfig(
|
||||
metric=metric,
|
||||
mode=mode,
|
||||
search_alg=OptunaSearch(),
|
||||
num_samples=num_samples,
|
||||
max_concurrent_trials=num_workers,
|
||||
),
|
||||
param_space=search_space,
|
||||
)
|
||||
|
||||
return tuner.fit()
|
||||
|
||||
|
||||
def analyze_results(
|
||||
results: tune.ResultGrid,
|
||||
output_folder: str = "results",
|
||||
prefix: str = "raytune",
|
||||
config_keys: List[str] = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
Analyze and save tuning results.
|
||||
|
||||
Args:
|
||||
results: Ray Tune ResultGrid
|
||||
output_folder: Directory to save CSV
|
||||
prefix: Filename prefix
|
||||
config_keys: List of config keys to show in best result (without 'config/' prefix)
|
||||
|
||||
Returns:
|
||||
Results DataFrame
|
||||
"""
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
df = results.get_dataframe()
|
||||
|
||||
# Save to CSV
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{prefix}_results_{timestamp}.csv"
|
||||
filepath = os.path.join(output_folder, filename)
|
||||
df.to_csv(filepath, index=False)
|
||||
print(f"Results saved: {filepath}")
|
||||
|
||||
# Best configuration
|
||||
best = df.loc[df["CER"].idxmin()]
|
||||
print(f"\nBest CER: {best['CER']:.6f}")
|
||||
print(f"Best WER: {best['WER']:.6f}")
|
||||
|
||||
if config_keys:
|
||||
print(f"\nOptimal Configuration:")
|
||||
for key in config_keys:
|
||||
col = f"config/{key}"
|
||||
if col in best:
|
||||
val = best[col]
|
||||
if isinstance(val, float):
|
||||
print(f" {key}: {val:.4f}")
|
||||
else:
|
||||
print(f" {key}: {val}")
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def correlation_analysis(df: pd.DataFrame, param_keys: List[str]) -> None:
|
||||
"""
|
||||
Print correlation of numeric parameters with CER/WER.
|
||||
|
||||
Args:
|
||||
df: Results DataFrame
|
||||
param_keys: List of config keys (without 'config/' prefix)
|
||||
"""
|
||||
param_cols = [f"config/{k}" for k in param_keys if f"config/{k}" in df.columns]
|
||||
numeric_cols = [c for c in param_cols if df[c].dtype in ['float64', 'int64']]
|
||||
|
||||
if not numeric_cols:
|
||||
print("No numeric parameters for correlation analysis")
|
||||
return
|
||||
|
||||
corr_cer = df[numeric_cols + ["CER"]].corr()["CER"].sort_values(ascending=False)
|
||||
corr_wer = df[numeric_cols + ["WER"]].corr()["WER"].sort_values(ascending=False)
|
||||
|
||||
print("Correlation with CER:")
|
||||
print(corr_cer)
|
||||
print("\nCorrelation with WER:")
|
||||
print(corr_wer)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OCR-specific payload functions
|
||||
# =============================================================================
|
||||
|
||||
def paddle_ocr_payload(config: Dict) -> Dict:
|
||||
"""Create payload for PaddleOCR API. Uses pages 5-10 (first doc) for tuning."""
|
||||
return {
|
||||
"pdf_folder": "/app/dataset",
|
||||
"use_doc_orientation_classify": config.get("use_doc_orientation_classify", False),
|
||||
"use_doc_unwarping": config.get("use_doc_unwarping", False),
|
||||
"textline_orientation": config.get("textline_orientation", True),
|
||||
"text_det_thresh": config.get("text_det_thresh", 0.0),
|
||||
"text_det_box_thresh": config.get("text_det_box_thresh", 0.0),
|
||||
"text_det_unclip_ratio": config.get("text_det_unclip_ratio", 1.5),
|
||||
"text_rec_score_thresh": config.get("text_rec_score_thresh", 0.0),
|
||||
"start_page": 5,
|
||||
"end_page": 10,
|
||||
"save_output": False,
|
||||
}
|
||||
|
||||
|
||||
def doctr_payload(config: Dict) -> Dict:
|
||||
"""Create payload for DocTR API. Uses pages 5-10 (first doc) for tuning."""
|
||||
return {
|
||||
"pdf_folder": "/app/dataset",
|
||||
"assume_straight_pages": config.get("assume_straight_pages", True),
|
||||
"straighten_pages": config.get("straighten_pages", False),
|
||||
"preserve_aspect_ratio": config.get("preserve_aspect_ratio", True),
|
||||
"symmetric_pad": config.get("symmetric_pad", True),
|
||||
"disable_page_orientation": config.get("disable_page_orientation", False),
|
||||
"disable_crop_orientation": config.get("disable_crop_orientation", False),
|
||||
"resolve_lines": config.get("resolve_lines", True),
|
||||
"resolve_blocks": config.get("resolve_blocks", False),
|
||||
"paragraph_break": config.get("paragraph_break", 0.035),
|
||||
"start_page": 5,
|
||||
"end_page": 10,
|
||||
"save_output": False,
|
||||
}
|
||||
|
||||
|
||||
def easyocr_payload(config: Dict) -> Dict:
|
||||
"""Create payload for EasyOCR API. Uses pages 5-10 (first doc) for tuning."""
|
||||
return {
|
||||
"pdf_folder": "/app/dataset",
|
||||
"text_threshold": config.get("text_threshold", 0.7),
|
||||
"low_text": config.get("low_text", 0.4),
|
||||
"link_threshold": config.get("link_threshold", 0.4),
|
||||
"slope_ths": config.get("slope_ths", 0.1),
|
||||
"ycenter_ths": config.get("ycenter_ths", 0.5),
|
||||
"height_ths": config.get("height_ths", 0.5),
|
||||
"width_ths": config.get("width_ths", 0.5),
|
||||
"add_margin": config.get("add_margin", 0.1),
|
||||
"contrast_ths": config.get("contrast_ths", 0.1),
|
||||
"adjust_contrast": config.get("adjust_contrast", 0.5),
|
||||
"decoder": config.get("decoder", "greedy"),
|
||||
"beamWidth": config.get("beamWidth", 5),
|
||||
"min_size": config.get("min_size", 10),
|
||||
"start_page": 5,
|
||||
"end_page": 10,
|
||||
"save_output": False,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Search spaces
|
||||
# =============================================================================
|
||||
|
||||
PADDLE_OCR_SEARCH_SPACE = {
|
||||
"use_doc_orientation_classify": tune.choice([True, False]),
|
||||
"use_doc_unwarping": tune.choice([True, False]),
|
||||
"textline_orientation": tune.choice([True, False]),
|
||||
"text_det_thresh": tune.uniform(0.0, 0.7),
|
||||
"text_det_box_thresh": tune.uniform(0.0, 0.7),
|
||||
"text_det_unclip_ratio": tune.choice([0.0]),
|
||||
"text_rec_score_thresh": tune.uniform(0.0, 0.7),
|
||||
}
|
||||
|
||||
DOCTR_SEARCH_SPACE = {
|
||||
"assume_straight_pages": tune.choice([True, False]),
|
||||
"straighten_pages": tune.choice([True, False]),
|
||||
"preserve_aspect_ratio": tune.choice([True, False]),
|
||||
"symmetric_pad": tune.choice([True, False]),
|
||||
"disable_page_orientation": tune.choice([True, False]),
|
||||
"disable_crop_orientation": tune.choice([True, False]),
|
||||
"resolve_lines": tune.choice([True, False]),
|
||||
"resolve_blocks": tune.choice([True, False]),
|
||||
"paragraph_break": tune.uniform(0.01, 0.1),
|
||||
}
|
||||
|
||||
EASYOCR_SEARCH_SPACE = {
|
||||
"text_threshold": tune.uniform(0.3, 0.9),
|
||||
"low_text": tune.uniform(0.2, 0.6),
|
||||
"link_threshold": tune.uniform(0.2, 0.6),
|
||||
"slope_ths": tune.uniform(0.0, 0.3),
|
||||
"ycenter_ths": tune.uniform(0.3, 1.0),
|
||||
"height_ths": tune.uniform(0.3, 1.0),
|
||||
"width_ths": tune.uniform(0.3, 1.0),
|
||||
"add_margin": tune.uniform(0.0, 0.3),
|
||||
"contrast_ths": tune.uniform(0.05, 0.3),
|
||||
"adjust_contrast": tune.uniform(0.3, 0.8),
|
||||
"decoder": tune.choice(["greedy", "beamsearch"]),
|
||||
"beamWidth": tune.choice([3, 5, 7, 10]),
|
||||
"min_size": tune.choice([5, 10, 15, 20]),
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Config keys for results display
|
||||
# =============================================================================
|
||||
|
||||
PADDLE_OCR_CONFIG_KEYS = [
|
||||
"use_doc_orientation_classify", "use_doc_unwarping", "textline_orientation",
|
||||
"text_det_thresh", "text_det_box_thresh", "text_det_unclip_ratio", "text_rec_score_thresh",
|
||||
]
|
||||
|
||||
DOCTR_CONFIG_KEYS = [
|
||||
"assume_straight_pages", "straighten_pages", "preserve_aspect_ratio", "symmetric_pad",
|
||||
"disable_page_orientation", "disable_crop_orientation", "resolve_lines", "resolve_blocks",
|
||||
"paragraph_break",
|
||||
]
|
||||
|
||||
EASYOCR_CONFIG_KEYS = [
|
||||
"text_threshold", "low_text", "link_threshold", "slope_ths", "ycenter_ths",
|
||||
"height_ths", "width_ths", "add_margin", "contrast_ths", "adjust_contrast",
|
||||
"decoder", "beamWidth", "min_size",
|
||||
]
|
||||
Reference in New Issue
Block a user