Paddle ocr gpu support. #4
File diff suppressed because it is too large
Load Diff
@@ -188,62 +188,7 @@
|
||||
"id": "trainable",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def trainable_paddle_ocr(config):\n",
|
||||
" \"\"\"Call PaddleOCR REST API with the given hyperparameter config.\n",
|
||||
" \n",
|
||||
" Uses trial index to deterministically assign a worker (round-robin),\n",
|
||||
" ensuring only 1 request per container at a time.\n",
|
||||
" \"\"\"\n",
|
||||
" import requests # Must be inside function for Ray workers\n",
|
||||
" from ray import train\n",
|
||||
"\n",
|
||||
" # Worker URLs - round-robin assignment based on trial index\n",
|
||||
" WORKER_PORTS = [8001, 8002]\n",
|
||||
" NUM_WORKERS = len(WORKER_PORTS)\n",
|
||||
" \n",
|
||||
" # Get trial context for deterministic worker assignment\n",
|
||||
" context = train.get_context()\n",
|
||||
" trial_id = context.get_trial_id() if context else \"0\"\n",
|
||||
" # Extract numeric part from trial ID (e.g., \"trainable_paddle_ocr_abc123_00001\" -> 1)\n",
|
||||
" try:\n",
|
||||
" trial_num = int(trial_id.split(\"_\")[-1])\n",
|
||||
" except (ValueError, IndexError):\n",
|
||||
" trial_num = hash(trial_id)\n",
|
||||
" \n",
|
||||
" worker_idx = trial_num % NUM_WORKERS\n",
|
||||
" api_url = f\"http://localhost:{WORKER_PORTS[worker_idx]}\"\n",
|
||||
"\n",
|
||||
" payload = {\n",
|
||||
" \"pdf_folder\": \"/app/dataset\",\n",
|
||||
" \"use_doc_orientation_classify\": config.get(\"use_doc_orientation_classify\", False),\n",
|
||||
" \"use_doc_unwarping\": config.get(\"use_doc_unwarping\", False),\n",
|
||||
" \"textline_orientation\": config.get(\"textline_orientation\", True),\n",
|
||||
" \"text_det_thresh\": config.get(\"text_det_thresh\", 0.0),\n",
|
||||
" \"text_det_box_thresh\": config.get(\"text_det_box_thresh\", 0.0),\n",
|
||||
" \"text_det_unclip_ratio\": config.get(\"text_det_unclip_ratio\", 1.5),\n",
|
||||
" \"text_rec_score_thresh\": config.get(\"text_rec_score_thresh\", 0.0),\n",
|
||||
" \"start_page\": 5,\n",
|
||||
" \"end_page\": 10,\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" response = requests.post(f\"{api_url}/evaluate\", json=payload, timeout=None) # No timeout\n",
|
||||
" response.raise_for_status()\n",
|
||||
" metrics = response.json()\n",
|
||||
" metrics[\"worker\"] = api_url\n",
|
||||
" train.report(metrics)\n",
|
||||
" except Exception as e:\n",
|
||||
" train.report({\n",
|
||||
" \"CER\": 1.0,\n",
|
||||
" \"WER\": 1.0,\n",
|
||||
" \"TIME\": 0.0,\n",
|
||||
" \"PAGES\": 0,\n",
|
||||
" \"TIME_PER_PAGE\": 0,\n",
|
||||
" \"worker\": api_url,\n",
|
||||
" \"ERROR\": str(e)[:500]\n",
|
||||
" })"
|
||||
]
|
||||
"source": "def trainable_paddle_ocr(config):\n \"\"\"Call PaddleOCR REST API with the given hyperparameter config.\"\"\"\n import random\n import requests\n from ray import tune\n\n # Worker URLs - random selection (load balances with 2 workers, 2 concurrent trials)\n WORKER_PORTS = [8001, 8002]\n api_url = f\"http://localhost:{random.choice(WORKER_PORTS)}\"\n\n payload = {\n \"pdf_folder\": \"/app/dataset\",\n \"use_doc_orientation_classify\": config.get(\"use_doc_orientation_classify\", False),\n \"use_doc_unwarping\": config.get(\"use_doc_unwarping\", False),\n \"textline_orientation\": config.get(\"textline_orientation\", True),\n \"text_det_thresh\": config.get(\"text_det_thresh\", 0.0),\n \"text_det_box_thresh\": config.get(\"text_det_box_thresh\", 0.0),\n \"text_det_unclip_ratio\": config.get(\"text_det_unclip_ratio\", 1.5),\n \"text_rec_score_thresh\": config.get(\"text_rec_score_thresh\", 0.0),\n \"start_page\": 5,\n \"end_page\": 10,\n }\n\n try:\n response = requests.post(f\"{api_url}/evaluate\", json=payload, timeout=None)\n response.raise_for_status()\n metrics = response.json()\n metrics[\"worker\"] = api_url\n tune.report(**metrics)\n except Exception as e:\n tune.report(\n CER=1.0,\n WER=1.0,\n TIME=0.0,\n PAGES=0,\n TIME_PER_PAGE=0,\n worker=api_url,\n ERROR=str(e)[:500]\n )"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
@@ -390,4 +335,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -1,124 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Ray Tune hyperparameter search for PaddleOCR via REST API."""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import requests
|
||||
import pandas as pd
|
||||
|
||||
import ray
|
||||
from ray import tune, train, air
|
||||
from ray.tune.search.optuna import OptunaSearch
|
||||
|
||||
# Configuration
# Ports of the PaddleOCR REST workers; each one is reached on localhost.
WORKER_PORTS = [8001, 8002]
# Directory that receives the CSV with the search results. It is created
# when this module is loaded so later writes cannot fail on a missing folder.
OUTPUT_FOLDER = "results"
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Search space
# Boolean pipeline switches are sampled from {True, False}; the three
# thresholds are sampled uniformly from [0.0, 0.7]. The unclip ratio is a
# single-value choice, i.e. it is effectively pinned to 0.0.
search_space = dict(
    use_doc_orientation_classify=tune.choice([True, False]),
    use_doc_unwarping=tune.choice([True, False]),
    textline_orientation=tune.choice([True, False]),
    text_det_thresh=tune.uniform(0.0, 0.7),
    text_det_box_thresh=tune.uniform(0.0, 0.7),
    text_det_unclip_ratio=tune.choice([0.0]),
    text_rec_score_thresh=tune.uniform(0.0, 0.7),
)
|
||||
|
||||
|
||||
def trainable_paddle_ocr(config):
    """Evaluate one hyperparameter configuration through the PaddleOCR REST API.

    The trial is pinned to one of the workers in ``WORKER_PORTS`` based on the
    numeric suffix of its Ray trial ID (round-robin). The JSON metrics returned
    by the worker's ``/evaluate`` endpoint are reported to Ray Tune with the
    worker URL added under ``"worker"``. If the request fails for any reason,
    worst-case metrics (``CER``/``WER`` of 1.0) plus a truncated ``ERROR``
    message are reported instead so the search can continue.

    Args:
        config: Mapping of hyperparameters sampled by Ray Tune. Missing keys
            fall back to the defaults used in the payload below.

    Returns:
        None. Results are delivered through ``train.report``.
    """
    # Imports are kept local so the function carries its own dependencies
    # when it is executed inside a Ray worker.
    import zlib

    import requests
    from ray import train

    context = train.get_context()
    # Fall back to "0" when there is no context or no trial ID is available.
    trial_id = (context.get_trial_id() if context else None) or "0"

    try:
        # Trial IDs end in a numeric index, e.g. "..._00001" -> 1.
        trial_num = int(trial_id.split("_")[-1])
    except (ValueError, IndexError):
        # The built-in hash() of a str is salted per process, so it would not
        # give the same worker for the same trial in every process; CRC32 is
        # stable everywhere.
        trial_num = zlib.crc32(trial_id.encode("utf-8"))

    worker_idx = trial_num % len(WORKER_PORTS)
    api_url = f"http://localhost:{WORKER_PORTS[worker_idx]}"

    payload = {
        "pdf_folder": "/app/dataset",
        "use_doc_orientation_classify": config.get("use_doc_orientation_classify", False),
        "use_doc_unwarping": config.get("use_doc_unwarping", False),
        "textline_orientation": config.get("textline_orientation", True),
        "text_det_thresh": config.get("text_det_thresh", 0.0),
        "text_det_box_thresh": config.get("text_det_box_thresh", 0.0),
        "text_det_unclip_ratio": config.get("text_det_unclip_ratio", 1.5),
        "text_rec_score_thresh": config.get("text_rec_score_thresh", 0.0),
        "start_page": 5,
        "end_page": 10,
    }

    try:
        # No timeout: the call blocks until the worker answers.
        response = requests.post(f"{api_url}/evaluate", json=payload, timeout=None)
        response.raise_for_status()
        metrics = response.json()
        metrics["worker"] = api_url
    except Exception as e:
        # Deliberate best-effort: a failed evaluation is scored as the worst
        # possible result rather than aborting the whole search.
        metrics = {
            "CER": 1.0,
            "WER": 1.0,
            "TIME": 0.0,
            "PAGES": 0,
            "TIME_PER_PAGE": 0,
            "worker": api_url,
            "ERROR": str(e)[:500],
        }

    # Report exactly once, outside the try block, so a failure while reporting
    # is not mistaken for a failed evaluation.
    train.report(metrics)
|
||||
|
||||
|
||||
def main():
    """Run the Ray Tune search against the PaddleOCR workers and save the results.

    Probes the ``/health`` endpoint of every worker in ``WORKER_PORTS``
    (failures are printed, not fatal), runs the search with one concurrent
    trial per worker, writes all trial results to a timestamped CSV inside
    ``OUTPUT_FOLDER`` and prints the best CER (and WER when available).
    Ray is shut down even if the search or the export fails.
    """
    # Check workers
    print("Checking workers...")
    for port in WORKER_PORTS:
        try:
            r = requests.get(f"http://localhost:{port}/health", timeout=10)
            print(f" Worker {port}: {r.json().get('status', 'unknown')}")
        except Exception as e:
            # Informational only: the search is still attempted.
            print(f" Worker {port}: ERROR - {e}")

    print("\nStarting Ray Tune...")
    ray.init(ignore_reinit_error=True)

    try:
        tuner = tune.Tuner(
            trainable_paddle_ocr,
            tune_config=tune.TuneConfig(
                metric="CER",
                mode="min",
                search_alg=OptunaSearch(),
                num_samples=64,
                # One trial per worker at a time.
                max_concurrent_trials=len(WORKER_PORTS),
            ),
            run_config=air.RunConfig(verbose=2, log_to_file=True),
            param_space=search_space,
        )

        results = tuner.fit()

        # Save results
        df = results.get_dataframe()
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filepath = os.path.join(OUTPUT_FOLDER, f"raytune_paddle_results_{timestamp}.csv")
        df.to_csv(filepath, index=False)
        print(f"\nResults saved: {filepath}")

        # Best config
        if "CER" in df.columns:
            # Ignore rows without a CER value so an empty or all-NaN column
            # cannot break the lookup.
            cer = df["CER"].dropna()
            if not cer.empty:
                best = df.loc[cer.idxmin()]
                print(f"\nBest CER: {best['CER']:.6f}")
                if "WER" in df.columns:
                    print(f"Best WER: {best['WER']:.6f}")
    finally:
        # Always release the Ray runtime, including when the search fails.
        ray.shutdown()
|
||||
|
||||
|
||||
# Run the hyperparameter search only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user