Paddle ocr gpu support. #4

Merged
Seryusjj merged 40 commits from gpu_support into main 2026-01-19 17:35:25 +00:00
3 changed files with 1756 additions and 299 deletions
Showing only changes of commit 67092e4df0 - Show all commits

1
.gitignore vendored
View File

@@ -7,3 +7,4 @@ results
.claude
node_modules
src/paddle_ocr/wheels
src/*.log

File diff suppressed because it is too large Load Diff

124
src/run_raytune.py Normal file
View File

@@ -0,0 +1,124 @@
#!/usr/bin/env python3
"""Ray Tune hyperparameter search for PaddleOCR via REST API."""
import os
from datetime import datetime
import requests
import pandas as pd
import ray
from ray import tune, train, air
from ray.tune.search.optuna import OptunaSearch
# Configuration
WORKER_PORTS = [8001, 8002]
OUTPUT_FOLDER = "results"
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
# Search space
search_space = {
"use_doc_orientation_classify": tune.choice([True, False]),
"use_doc_unwarping": tune.choice([True, False]),
"textline_orientation": tune.choice([True, False]),
"text_det_thresh": tune.uniform(0.0, 0.7),
"text_det_box_thresh": tune.uniform(0.0, 0.7),
"text_det_unclip_ratio": tune.choice([0.0]),
"text_rec_score_thresh": tune.uniform(0.0, 0.7),
}
def trainable_paddle_ocr(config):
"""Call PaddleOCR REST API with the given hyperparameter config."""
import requests
from ray import train
NUM_WORKERS = len(WORKER_PORTS)
context = train.get_context()
trial_id = context.get_trial_id() if context else "0"
try:
trial_num = int(trial_id.split("_")[-1])
except (ValueError, IndexError):
trial_num = hash(trial_id)
worker_idx = trial_num % NUM_WORKERS
api_url = f"http://localhost:{WORKER_PORTS[worker_idx]}"
payload = {
"pdf_folder": "/app/dataset",
"use_doc_orientation_classify": config.get("use_doc_orientation_classify", False),
"use_doc_unwarping": config.get("use_doc_unwarping", False),
"textline_orientation": config.get("textline_orientation", True),
"text_det_thresh": config.get("text_det_thresh", 0.0),
"text_det_box_thresh": config.get("text_det_box_thresh", 0.0),
"text_det_unclip_ratio": config.get("text_det_unclip_ratio", 1.5),
"text_rec_score_thresh": config.get("text_rec_score_thresh", 0.0),
"start_page": 5,
"end_page": 10,
}
try:
response = requests.post(f"{api_url}/evaluate", json=payload, timeout=None)
response.raise_for_status()
metrics = response.json()
metrics["worker"] = api_url
train.report(metrics)
except Exception as e:
train.report({
"CER": 1.0,
"WER": 1.0,
"TIME": 0.0,
"PAGES": 0,
"TIME_PER_PAGE": 0,
"worker": api_url,
"ERROR": str(e)[:500]
})
def main():
# Check workers
print("Checking workers...")
for port in WORKER_PORTS:
try:
r = requests.get(f"http://localhost:{port}/health", timeout=10)
print(f" Worker {port}: {r.json().get('status', 'unknown')}")
except Exception as e:
print(f" Worker {port}: ERROR - {e}")
print("\nStarting Ray Tune...")
ray.init(ignore_reinit_error=True)
tuner = tune.Tuner(
trainable_paddle_ocr,
tune_config=tune.TuneConfig(
metric="CER",
mode="min",
search_alg=OptunaSearch(),
num_samples=64,
max_concurrent_trials=len(WORKER_PORTS),
),
run_config=air.RunConfig(verbose=2, log_to_file=True),
param_space=search_space,
)
results = tuner.fit()
# Save results
df = results.get_dataframe()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filepath = os.path.join(OUTPUT_FOLDER, f"raytune_paddle_results_{timestamp}.csv")
df.to_csv(filepath, index=False)
print(f"\nResults saved: {filepath}")
# Best config
if len(df) > 0 and "CER" in df.columns:
best = df.loc[df["CER"].idxmin()]
print(f"\nBest CER: {best['CER']:.6f}")
print(f"Best WER: {best['WER']:.6f}")
ray.shutdown()
if __name__ == "__main__":
main()