outcome
Some checks failed
build_docker / essential (pull_request) Has been cancelled
build_docker / build_cpu (linux/amd64) (pull_request) Has been cancelled
build_docker / build_cpu (linux/arm64) (pull_request) Has been cancelled
build_docker / manifest_cpu (pull_request) Has been cancelled
build_docker / manifest_gpu (pull_request) Has been cancelled
build_docker / build_easyocr (linux/amd64) (pull_request) Has been cancelled
build_docker / build_doctr (linux/amd64) (pull_request) Has been cancelled
build_docker / build_doctr (linux/arm64) (pull_request) Has been cancelled
build_docker / manifest_easyocr (pull_request) Has been cancelled
build_docker / manifest_doctr (pull_request) Has been cancelled
build_docker / build_easyocr_gpu (linux/amd64) (pull_request) Has been cancelled
build_docker / build_easyocr_gpu (linux/arm64) (pull_request) Has been cancelled
build_docker / build_gpu (linux/amd64) (pull_request) Has been cancelled
build_docker / build_gpu (linux/arm64) (pull_request) Has been cancelled
build_docker / build_easyocr (linux/arm64) (pull_request) Has been cancelled
build_docker / build_doctr_gpu (linux/amd64) (pull_request) Has been cancelled
build_docker / build_doctr_gpu (linux/arm64) (pull_request) Has been cancelled
build_docker / manifest_easyocr_gpu (pull_request) Has been cancelled
build_docker / manifest_doctr_gpu (pull_request) Has been cancelled
Some checks failed
build_docker / essential (pull_request) Has been cancelled
build_docker / build_cpu (linux/amd64) (pull_request) Has been cancelled
build_docker / build_cpu (linux/arm64) (pull_request) Has been cancelled
build_docker / manifest_cpu (pull_request) Has been cancelled
build_docker / manifest_gpu (pull_request) Has been cancelled
build_docker / build_easyocr (linux/amd64) (pull_request) Has been cancelled
build_docker / build_doctr (linux/amd64) (pull_request) Has been cancelled
build_docker / build_doctr (linux/arm64) (pull_request) Has been cancelled
build_docker / manifest_easyocr (pull_request) Has been cancelled
build_docker / manifest_doctr (pull_request) Has been cancelled
build_docker / build_easyocr_gpu (linux/amd64) (pull_request) Has been cancelled
build_docker / build_easyocr_gpu (linux/arm64) (pull_request) Has been cancelled
build_docker / build_gpu (linux/amd64) (pull_request) Has been cancelled
build_docker / build_gpu (linux/arm64) (pull_request) Has been cancelled
build_docker / build_easyocr (linux/arm64) (pull_request) Has been cancelled
build_docker / build_doctr_gpu (linux/amd64) (pull_request) Has been cancelled
build_docker / build_doctr_gpu (linux/arm64) (pull_request) Has been cancelled
build_docker / manifest_easyocr_gpu (pull_request) Has been cancelled
build_docker / manifest_doctr_gpu (pull_request) Has been cancelled
This commit is contained in:
File diff suppressed because it is too large
Load Diff
124
src/run_raytune.py
Normal file
124
src/run_raytune.py
Normal file
@@ -0,0 +1,124 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Ray Tune hyperparameter search for PaddleOCR via REST API."""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
import pandas as pd
|
||||
|
||||
import ray
|
||||
from ray import tune, train, air
|
||||
from ray.tune.search.optuna import OptunaSearch
|
||||
|
||||
# Configuration
|
||||
WORKER_PORTS = [8001, 8002]
|
||||
OUTPUT_FOLDER = "results"
|
||||
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
|
||||
|
||||
# Search space
|
||||
search_space = {
|
||||
"use_doc_orientation_classify": tune.choice([True, False]),
|
||||
"use_doc_unwarping": tune.choice([True, False]),
|
||||
"textline_orientation": tune.choice([True, False]),
|
||||
"text_det_thresh": tune.uniform(0.0, 0.7),
|
||||
"text_det_box_thresh": tune.uniform(0.0, 0.7),
|
||||
"text_det_unclip_ratio": tune.choice([0.0]),
|
||||
"text_rec_score_thresh": tune.uniform(0.0, 0.7),
|
||||
}
|
||||
|
||||
|
||||
def trainable_paddle_ocr(config):
|
||||
"""Call PaddleOCR REST API with the given hyperparameter config."""
|
||||
import requests
|
||||
from ray import train
|
||||
|
||||
NUM_WORKERS = len(WORKER_PORTS)
|
||||
|
||||
context = train.get_context()
|
||||
trial_id = context.get_trial_id() if context else "0"
|
||||
try:
|
||||
trial_num = int(trial_id.split("_")[-1])
|
||||
except (ValueError, IndexError):
|
||||
trial_num = hash(trial_id)
|
||||
|
||||
worker_idx = trial_num % NUM_WORKERS
|
||||
api_url = f"http://localhost:{WORKER_PORTS[worker_idx]}"
|
||||
|
||||
payload = {
|
||||
"pdf_folder": "/app/dataset",
|
||||
"use_doc_orientation_classify": config.get("use_doc_orientation_classify", False),
|
||||
"use_doc_unwarping": config.get("use_doc_unwarping", False),
|
||||
"textline_orientation": config.get("textline_orientation", True),
|
||||
"text_det_thresh": config.get("text_det_thresh", 0.0),
|
||||
"text_det_box_thresh": config.get("text_det_box_thresh", 0.0),
|
||||
"text_det_unclip_ratio": config.get("text_det_unclip_ratio", 1.5),
|
||||
"text_rec_score_thresh": config.get("text_rec_score_thresh", 0.0),
|
||||
"start_page": 5,
|
||||
"end_page": 10,
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(f"{api_url}/evaluate", json=payload, timeout=None)
|
||||
response.raise_for_status()
|
||||
metrics = response.json()
|
||||
metrics["worker"] = api_url
|
||||
train.report(metrics)
|
||||
except Exception as e:
|
||||
train.report({
|
||||
"CER": 1.0,
|
||||
"WER": 1.0,
|
||||
"TIME": 0.0,
|
||||
"PAGES": 0,
|
||||
"TIME_PER_PAGE": 0,
|
||||
"worker": api_url,
|
||||
"ERROR": str(e)[:500]
|
||||
})
|
||||
|
||||
|
||||
def main():
|
||||
# Check workers
|
||||
print("Checking workers...")
|
||||
for port in WORKER_PORTS:
|
||||
try:
|
||||
r = requests.get(f"http://localhost:{port}/health", timeout=10)
|
||||
print(f" Worker {port}: {r.json().get('status', 'unknown')}")
|
||||
except Exception as e:
|
||||
print(f" Worker {port}: ERROR - {e}")
|
||||
|
||||
print("\nStarting Ray Tune...")
|
||||
ray.init(ignore_reinit_error=True)
|
||||
|
||||
tuner = tune.Tuner(
|
||||
trainable_paddle_ocr,
|
||||
tune_config=tune.TuneConfig(
|
||||
metric="CER",
|
||||
mode="min",
|
||||
search_alg=OptunaSearch(),
|
||||
num_samples=64,
|
||||
max_concurrent_trials=len(WORKER_PORTS),
|
||||
),
|
||||
run_config=air.RunConfig(verbose=2, log_to_file=True),
|
||||
param_space=search_space,
|
||||
)
|
||||
|
||||
results = tuner.fit()
|
||||
|
||||
# Save results
|
||||
df = results.get_dataframe()
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filepath = os.path.join(OUTPUT_FOLDER, f"raytune_paddle_results_{timestamp}.csv")
|
||||
df.to_csv(filepath, index=False)
|
||||
print(f"\nResults saved: {filepath}")
|
||||
|
||||
# Best config
|
||||
if len(df) > 0 and "CER" in df.columns:
|
||||
best = df.loc[df["CER"].idxmin()]
|
||||
print(f"\nBest CER: {best['CER']:.6f}")
|
||||
print(f"Best WER: {best['WER']:.6f}")
|
||||
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user