Paddle ocr gpu support. #4

Merged
Seryusjj merged 40 commits from gpu_support into main 2026-01-19 17:35:25 +00:00
3 changed files with 129 additions and 3264 deletions
Showing only changes of commit b29df98602 - Show all commits

File diff suppressed because it is too large Load Diff

View File

@@ -188,62 +188,7 @@
"id": "trainable",
"metadata": {},
"outputs": [],
"source": [
"def trainable_paddle_ocr(config):\n",
" \"\"\"Call PaddleOCR REST API with the given hyperparameter config.\n",
" \n",
" Uses trial index to deterministically assign a worker (round-robin),\n",
" ensuring only 1 request per container at a time.\n",
" \"\"\"\n",
" import requests # Must be inside function for Ray workers\n",
" from ray import train\n",
"\n",
" # Worker URLs - round-robin assignment based on trial index\n",
" WORKER_PORTS = [8001, 8002]\n",
" NUM_WORKERS = len(WORKER_PORTS)\n",
" \n",
" # Get trial context for deterministic worker assignment\n",
" context = train.get_context()\n",
" trial_id = context.get_trial_id() if context else \"0\"\n",
" # Extract numeric part from trial ID (e.g., \"trainable_paddle_ocr_abc123_00001\" -> 1)\n",
" try:\n",
" trial_num = int(trial_id.split(\"_\")[-1])\n",
" except (ValueError, IndexError):\n",
" trial_num = hash(trial_id)\n",
" \n",
" worker_idx = trial_num % NUM_WORKERS\n",
" api_url = f\"http://localhost:{WORKER_PORTS[worker_idx]}\"\n",
"\n",
" payload = {\n",
" \"pdf_folder\": \"/app/dataset\",\n",
" \"use_doc_orientation_classify\": config.get(\"use_doc_orientation_classify\", False),\n",
" \"use_doc_unwarping\": config.get(\"use_doc_unwarping\", False),\n",
" \"textline_orientation\": config.get(\"textline_orientation\", True),\n",
" \"text_det_thresh\": config.get(\"text_det_thresh\", 0.0),\n",
" \"text_det_box_thresh\": config.get(\"text_det_box_thresh\", 0.0),\n",
" \"text_det_unclip_ratio\": config.get(\"text_det_unclip_ratio\", 1.5),\n",
" \"text_rec_score_thresh\": config.get(\"text_rec_score_thresh\", 0.0),\n",
" \"start_page\": 5,\n",
" \"end_page\": 10,\n",
" }\n",
"\n",
" try:\n",
" response = requests.post(f\"{api_url}/evaluate\", json=payload, timeout=None) # No timeout\n",
" response.raise_for_status()\n",
" metrics = response.json()\n",
" metrics[\"worker\"] = api_url\n",
" train.report(metrics)\n",
" except Exception as e:\n",
" train.report({\n",
" \"CER\": 1.0,\n",
" \"WER\": 1.0,\n",
" \"TIME\": 0.0,\n",
" \"PAGES\": 0,\n",
" \"TIME_PER_PAGE\": 0,\n",
" \"worker\": api_url,\n",
" \"ERROR\": str(e)[:500]\n",
" })"
]
"source": "def trainable_paddle_ocr(config):\n \"\"\"Call PaddleOCR REST API with the given hyperparameter config.\"\"\"\n import random\n import requests\n from ray import tune\n\n # Worker URLs - random selection (load balances with 2 workers, 2 concurrent trials)\n WORKER_PORTS = [8001, 8002]\n api_url = f\"http://localhost:{random.choice(WORKER_PORTS)}\"\n\n payload = {\n \"pdf_folder\": \"/app/dataset\",\n \"use_doc_orientation_classify\": config.get(\"use_doc_orientation_classify\", False),\n \"use_doc_unwarping\": config.get(\"use_doc_unwarping\", False),\n \"textline_orientation\": config.get(\"textline_orientation\", True),\n \"text_det_thresh\": config.get(\"text_det_thresh\", 0.0),\n \"text_det_box_thresh\": config.get(\"text_det_box_thresh\", 0.0),\n \"text_det_unclip_ratio\": config.get(\"text_det_unclip_ratio\", 1.5),\n \"text_rec_score_thresh\": config.get(\"text_rec_score_thresh\", 0.0),\n \"start_page\": 5,\n \"end_page\": 10,\n }\n\n try:\n response = requests.post(f\"{api_url}/evaluate\", json=payload, timeout=None)\n response.raise_for_status()\n metrics = response.json()\n metrics[\"worker\"] = api_url\n tune.report(**metrics)\n except Exception as e:\n tune.report(\n CER=1.0,\n WER=1.0,\n TIME=0.0,\n PAGES=0,\n TIME_PER_PAGE=0,\n worker=api_url,\n ERROR=str(e)[:500]\n )"
},
{
"cell_type": "markdown",
@@ -390,4 +335,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -1,124 +0,0 @@
#!/usr/bin/env python3
"""Ray Tune hyperparameter search for PaddleOCR via REST API."""
import os
from datetime import datetime
import requests
import pandas as pd
import ray
from ray import tune, train, air
from ray.tune.search.optuna import OptunaSearch
# Configuration
# Ports queried on localhost for the PaddleOCR REST workers; the list length
# also caps the number of concurrent trials in main().
WORKER_PORTS = [8001, 8002]
# Folder that receives the CSV export of the tuning results.
OUTPUT_FOLDER = "results"
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Search space
# Boolean switches are sampled from {True, False}; the three score thresholds
# are sampled uniformly from [0.0, 0.7]; the unclip ratio is pinned to the
# single value 0.0.
search_space = dict(
    use_doc_orientation_classify=tune.choice([True, False]),
    use_doc_unwarping=tune.choice([True, False]),
    textline_orientation=tune.choice([True, False]),
    text_det_thresh=tune.uniform(0.0, 0.7),
    text_det_box_thresh=tune.uniform(0.0, 0.7),
    text_det_unclip_ratio=tune.choice([0.0]),
    text_rec_score_thresh=tune.uniform(0.0, 0.7),
)
def trainable_paddle_ocr(config):
    """Call the PaddleOCR REST API with the given hyperparameter config.

    The trial is assigned to one of ``WORKER_PORTS`` by taking the numeric
    suffix of the Ray trial ID modulo the number of workers. When the suffix
    is not numeric, a CRC32 of the trial ID is used instead; unlike the
    built-in ``hash()`` on strings, CRC32 is not salted per process, so the
    assignment stays the same in every Ray worker process.

    Args:
        config: Mapping of hyperparameter names to sampled values. Keys that
            are missing fall back to the defaults in the payload below.

    Returns:
        None. The metrics returned by the ``/evaluate`` endpoint (plus the
        worker URL) are passed to ``train.report``. If the request or the
        response decoding fails, a fallback record with ``CER``/``WER`` of
        1.0 and the truncated error text is reported instead.
    """
    # Function-scope imports so the names resolve inside the Ray worker
    # process that executes the trial.
    import zlib

    import requests
    from ray import train

    num_workers = len(WORKER_PORTS)

    context = train.get_context()
    trial_id = context.get_trial_id() if context else None
    # Defensive: outside a Tune session there may be no trial ID at all.
    trial_id = "0" if trial_id is None else str(trial_id)

    try:
        trial_num = int(trial_id.split("_")[-1])
    except ValueError:
        # Stable across processes, unlike hash(str) with hash randomization.
        trial_num = zlib.crc32(trial_id.encode("utf-8"))

    worker_idx = trial_num % num_workers
    api_url = f"http://localhost:{WORKER_PORTS[worker_idx]}"

    payload = {
        "pdf_folder": "/app/dataset",
        "use_doc_orientation_classify": config.get("use_doc_orientation_classify", False),
        "use_doc_unwarping": config.get("use_doc_unwarping", False),
        "textline_orientation": config.get("textline_orientation", True),
        "text_det_thresh": config.get("text_det_thresh", 0.0),
        "text_det_box_thresh": config.get("text_det_box_thresh", 0.0),
        "text_det_unclip_ratio": config.get("text_det_unclip_ratio", 1.5),
        "text_rec_score_thresh": config.get("text_rec_score_thresh", 0.0),
        "start_page": 5,
        "end_page": 10,
    }

    try:
        # timeout=None waits indefinitely for the evaluation to finish
        # (kept as in the original; presumably evaluations are long-running).
        response = requests.post(f"{api_url}/evaluate", json=payload, timeout=None)
        response.raise_for_status()
        metrics = response.json()
        metrics["worker"] = api_url
        train.report(metrics)
    except Exception as e:
        # Deliberate catch-all: a failed trial is reported with worst-case
        # error rates so the search continues instead of aborting.
        train.report({
            "CER": 1.0,
            "WER": 1.0,
            "TIME": 0.0,
            "PAGES": 0,
            "TIME_PER_PAGE": 0,
            "worker": api_url,
            "ERROR": str(e)[:500]
        })
def main():
    """Probe the workers, run the Ray Tune search, and export the results.

    Steps:
      1. Query ``/health`` on every port in ``WORKER_PORTS`` and print the
         reported status. Failures are printed but do not stop the run.
      2. Run a 64-sample Optuna-driven search minimising ``CER``, with at
         most one concurrent trial per worker port.
      3. Write the results dataframe to a timestamped CSV in
         ``OUTPUT_FOLDER`` and print the best CER (and WER when available).

    Ray is shut down in a ``finally`` clause so it is released even when the
    search or the export raises.
    """
    # Check workers
    print("Checking workers...")
    for port in WORKER_PORTS:
        try:
            r = requests.get(f"http://localhost:{port}/health", timeout=10)
            print(f" Worker {port}: {r.json().get('status', 'unknown')}")
        except Exception as e:
            # Best-effort probe: report the problem and keep going.
            print(f" Worker {port}: ERROR - {e}")

    print("\nStarting Ray Tune...")
    ray.init(ignore_reinit_error=True)
    try:
        tuner = tune.Tuner(
            trainable_paddle_ocr,
            tune_config=tune.TuneConfig(
                metric="CER",
                mode="min",
                search_alg=OptunaSearch(),
                num_samples=64,
                # One in-flight trial per worker port.
                max_concurrent_trials=len(WORKER_PORTS),
            ),
            run_config=air.RunConfig(verbose=2, log_to_file=True),
            param_space=search_space,
        )
        results = tuner.fit()

        # Save results
        df = results.get_dataframe()
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = os.path.join(OUTPUT_FOLDER, f"raytune_paddle_results_{timestamp}.csv")
        df.to_csv(filepath, index=False)
        print(f"\nResults saved: {filepath}")

        # Best config
        if "CER" in df.columns:
            # Ignore rows without a CER value so idxmin has something to pick.
            cer = df["CER"].dropna()
            if not cer.empty:
                best = df.loc[cer.idxmin()]
                print(f"\nBest CER: {best['CER']:.6f}")
                # WER is only printed when the results include that column.
                if "WER" in df.columns:
                    print(f"Best WER: {best['WER']:.6f}")
    finally:
        # Always release Ray, even if the search or export failed.
        ray.shutdown()


if __name__ == "__main__":
    main()