PaddleOCR GPU support #4
@@ -25,9 +25,7 @@
 "id": "deps",
 "metadata": {},
 "outputs": [],
-"source": [
-"%pip install -q -U \"ray[tune]\" optuna requests pandas"
-]
+"source": "# Pin Ray version for API stability (tune.report takes dict, not kwargs in 2.x)\n%pip install -q \"ray[tune]==2.53.0\" optuna requests pandas"
 },
 {
 "cell_type": "code",

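The pin to ray[tune]==2.53.0 exists because the 2.x tune.report() call takes a metrics dict, not keyword arguments. A small sanity check you could run after the install cell, a sketch that is not part of this PR, to catch a stale kernel still holding an older Ray:

    import ray

    # The notebooks pin ray[tune]==2.53.0 because tune.report() takes a dict in 2.x.
    major = int(ray.__version__.split(".")[0])
    assert major >= 2, (
        f"Found Ray {ray.__version__}; restart the kernel after the pinned %pip install "
        "so the 2.x tune.report(dict) API is available."
    )
    print(f"Ray {ray.__version__} OK")
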
@@ -70,6 +70,30 @@ services:
     profiles:
       - gpu
 
+  ocr-worker-3:
+    <<: *ocr-gpu-common
+    container_name: paddle-ocr-worker-3
+    ports:
+      - "8003:8000"
+    profiles:
+      - gpu
+
+  ocr-worker-4:
+    <<: *ocr-gpu-common
+    container_name: paddle-ocr-worker-4
+    ports:
+      - "8004:8000"
+    profiles:
+      - gpu
+
+  ocr-worker-5:
+    <<: *ocr-gpu-common
+    container_name: paddle-ocr-worker-5
+    ports:
+      - "8005:8000"
+    profiles:
+      - gpu
+
   # CPU Workers (cpu profile) - for systems without GPU
   ocr-cpu-worker-1:
     <<: *ocr-cpu-common
@@ -87,6 +111,30 @@ services:
     profiles:
       - cpu
 
+  ocr-cpu-worker-3:
+    <<: *ocr-cpu-common
+    container_name: paddle-ocr-cpu-worker-3
+    ports:
+      - "8003:8000"
+    profiles:
+      - cpu
+
+  ocr-cpu-worker-4:
+    <<: *ocr-cpu-common
+    container_name: paddle-ocr-cpu-worker-4
+    ports:
+      - "8004:8000"
+    profiles:
+      - cpu
+
+  ocr-cpu-worker-5:
+    <<: *ocr-cpu-common
+    container_name: paddle-ocr-cpu-worker-5
+    ports:
+      - "8005:8000"
+    profiles:
+      - cpu
+
 volumes:
   paddlex-cache:
     name: paddlex-model-cache

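Both profiles expose the workers on host ports 8001-8005, each mapped to container port 8000, so the tuning notebooks can use the same PORTS list whether the gpu or cpu profile is running. A minimal sketch, not part of this PR, of how a client could confirm those ports are reachable before tuning; it only assumes the host/port mapping above, not any particular HTTP endpoint:

    import socket
    from typing import List

    WORKER_PORTS: List[int] = [8001, 8002, 8003, 8004, 8005]  # host ports from docker-compose.yml

    def reachable_ports(ports: List[int], host: str = "localhost", timeout: float = 2.0) -> List[int]:
        """Return the subset of ports that accept a TCP connection (worker is up)."""
        up = []
        for port in ports:
            try:
                with socket.create_connection((host, port), timeout=timeout):
                    up.append(port)
            except OSError:
                pass  # worker not started under this profile, or still booting
        return up

    print("Reachable OCR workers:", reachable_ports(WORKER_PORTS))

The repo's own check_workers(PORTS, ...) helper presumably does a richer health check; this only verifies that the compose port mappings are live.
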
@@ -24,9 +24,7 @@
 "id": "deps",
 "metadata": {},
 "outputs": [],
-"source": [
-"%pip install -q -U \"ray[tune]\" optuna requests pandas"
-]
+"source": "# Pin Ray version for API stability (tune.report takes dict, not kwargs in 2.x)\n%pip install -q \"ray[tune]==2.53.0\" optuna requests pandas"
 },
 {
 "cell_type": "code",
@@ -34,18 +32,7 @@
 "id": "setup",
 "metadata": {},
 "outputs": [],
-"source": [
-"from raytune_ocr import (\n",
-"    check_workers, create_trainable, run_tuner, analyze_results, correlation_analysis,\n",
-"    paddle_ocr_payload, PADDLE_OCR_SEARCH_SPACE, PADDLE_OCR_CONFIG_KEYS,\n",
-")\n",
-"\n",
-"# Worker ports\n",
-"PORTS = [8001, 8002]\n",
-"\n",
-"# Check workers are running\n",
-"healthy = check_workers(PORTS, \"PaddleOCR\")"
-]
+"source": "from raytune_ocr import (\n    check_workers, create_trainable, run_tuner, analyze_results, correlation_analysis,\n    paddle_ocr_payload, PADDLE_OCR_SEARCH_SPACE, PADDLE_OCR_CONFIG_KEYS,\n)\n\n# Worker ports (3 workers to avoid OOM)\nPORTS = [8001, 8002, 8003]\n\n# Check workers are running\nhealthy = check_workers(PORTS, \"PaddleOCR\")"
 },
 {
 "cell_type": "code",
@@ -53,17 +40,7 @@
 "id": "tune",
 "metadata": {},
 "outputs": [],
-"source": [
-"# Create trainable and run tuning\n",
-"trainable = create_trainable(PORTS, paddle_ocr_payload)\n",
-"\n",
-"results = run_tuner(\n",
-"    trainable=trainable,\n",
-"    search_space=PADDLE_OCR_SEARCH_SPACE,\n",
-"    num_samples=64,\n",
-"    num_workers=len(healthy),\n",
-")"
-]
+"source": "# Create trainable and run tuning\ntrainable = create_trainable(PORTS, paddle_ocr_payload)\n\nresults = run_tuner(\n    trainable=trainable,\n    search_space=PADDLE_OCR_SEARCH_SPACE,\n    num_samples=128,\n    num_workers=len(healthy),\n)"
 },
 {
 "cell_type": "code",

@@ -12,7 +12,7 @@ import requests
 import pandas as pd
 
 import ray
-from ray import tune, train
+from ray import tune
 from ray.tune.search.optuna import OptunaSearch
 
 
@@ -65,11 +65,15 @@ def create_trainable(ports: List[int], payload_fn: Callable[[Dict], Dict]) -> Ca
 
     Returns:
         Trainable function for Ray Tune
+
+    Note:
+        Ray Tune 2.x API: tune.report(metrics_dict) - pass dict directly, NOT kwargs.
+        See: https://docs.ray.io/en/latest/tune/api/doc/ray.tune.report.html
     """
     def trainable(config):
         import random
         import requests
-        from ray import train
+        from ray.tune import report  # Ray 2.x: report(dict), not report(**kwargs)
 
         api_url = f"http://localhost:{random.choice(ports)}"
         payload = payload_fn(config)
@@ -79,9 +83,9 @@ def create_trainable(ports: List[int], payload_fn: Callable[[Dict], Dict]) -> Ca
             response.raise_for_status()
             metrics = response.json()
             metrics["worker"] = api_url
-            train.report(metrics)
+            report(metrics)  # Ray 2.x API: pass dict directly
         except Exception as e:
-            train.report({
+            report({  # Ray 2.x API: pass dict directly
                 "CER": 1.0,
                 "WER": 1.0,
                 "TIME": 0.0,
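For context on the train.report -> report change above: in Ray Tune 2.x the function-trainable API expects a single metrics dict rather than keyword arguments, which is the kwargs form this PR removes. A minimal, self-contained sketch with assumed metric names, not this repository's trainable:

    from ray import tune

    def toy_trainable(config):
        # Ray Tune 2.x style: pass one metrics dict to report().
        cer = config["x"] * 0.1
        tune.report({"CER": cer, "WER": cer * 1.5, "TIME": 0.01})
        # The older kwargs form, e.g. tune.report(CER=cer, WER=...), is what this PR replaces.
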
@@ -116,7 +120,12 @@ def run_tuner(
     Returns:
         Ray Tune ResultGrid
     """
-    ray.init(ignore_reinit_error=True, include_dashboard=False)
+    ray.init(
+        ignore_reinit_error=True,
+        include_dashboard=False,
+        configure_logging=False,
+        _metrics_export_port=0,  # Disable metrics export to avoid connection warnings
+    )
     print(f"Ray Tune ready (version: {ray.__version__})")
 
     tuner = tune.Tuner(
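The diff cuts off at tune.Tuner(, so the body of run_tuner is not shown here. Below is a generic sketch of how such a Tuner is typically wired up with OptunaSearch (which the file imports); it is not this repository's actual implementation, and the CER metric, "min" mode, and per-worker concurrency are assumptions taken from the notebook cells above:

    import ray
    from ray import tune
    from ray.tune.search.optuna import OptunaSearch

    def example_run_tuner(trainable, search_space, num_samples=128, num_workers=3):
        # Sketch only: mirrors the shape of run_tuner() as called from the notebooks.
        ray.init(ignore_reinit_error=True, include_dashboard=False, configure_logging=False)
        tuner = tune.Tuner(
            trainable,
            param_space=search_space,
            tune_config=tune.TuneConfig(
                search_alg=OptunaSearch(metric="CER", mode="min"),
                num_samples=num_samples,
                max_concurrent_trials=num_workers,  # one in-flight trial per OCR worker (assumption)
            ),
        )
        return tuner.fit()  # ResultGrid; results.get_dataframe() is handy for analysis
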