PaddleOCR GPU support. #4

Merged
Seryusjj merged 40 commits from gpu_support into main 2026-01-19 17:35:25 +00:00
5 changed files with 70 additions and 40 deletions
Showing only changes of commit e2cca72cf2 - Show all commits

View File

@@ -25,9 +25,7 @@
"id": "deps", "id": "deps",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": "# Pin Ray version for API stability (tune.report takes dict, not kwargs in 2.x)\n%pip install -q \"ray[tune]==2.53.0\" optuna requests pandas"
"%pip install -q -U \"ray[tune]\" optuna requests pandas"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
@@ -108,4 +106,4 @@
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 5 "nbformat_minor": 5
} }

View File

@@ -25,9 +25,7 @@
"id": "deps", "id": "deps",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": "# Pin Ray version for API stability (tune.report takes dict, not kwargs in 2.x)\n%pip install -q \"ray[tune]==2.53.0\" optuna requests pandas"
"%pip install -q -U \"ray[tune]\" optuna requests pandas"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
@@ -108,4 +106,4 @@
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 5 "nbformat_minor": 5
} }

View File

@@ -70,6 +70,30 @@ services:
profiles: profiles:
- gpu - gpu
ocr-worker-3:
<<: *ocr-gpu-common
container_name: paddle-ocr-worker-3
ports:
- "8003:8000"
profiles:
- gpu
ocr-worker-4:
<<: *ocr-gpu-common
container_name: paddle-ocr-worker-4
ports:
- "8004:8000"
profiles:
- gpu
ocr-worker-5:
<<: *ocr-gpu-common
container_name: paddle-ocr-worker-5
ports:
- "8005:8000"
profiles:
- gpu
# CPU Workers (cpu profile) - for systems without GPU # CPU Workers (cpu profile) - for systems without GPU
ocr-cpu-worker-1: ocr-cpu-worker-1:
<<: *ocr-cpu-common <<: *ocr-cpu-common
@@ -87,6 +111,30 @@ services:
profiles: profiles:
- cpu - cpu
ocr-cpu-worker-3:
<<: *ocr-cpu-common
container_name: paddle-ocr-cpu-worker-3
ports:
- "8003:8000"
profiles:
- cpu
ocr-cpu-worker-4:
<<: *ocr-cpu-common
container_name: paddle-ocr-cpu-worker-4
ports:
- "8004:8000"
profiles:
- cpu
ocr-cpu-worker-5:
<<: *ocr-cpu-common
container_name: paddle-ocr-cpu-worker-5
ports:
- "8005:8000"
profiles:
- cpu
volumes: volumes:
paddlex-cache: paddlex-cache:
name: paddlex-model-cache name: paddlex-model-cache

View File

@@ -24,9 +24,7 @@
"id": "deps", "id": "deps",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": "# Pin Ray version for API stability (tune.report takes dict, not kwargs in 2.x)\n%pip install -q \"ray[tune]==2.53.0\" optuna requests pandas"
"%pip install -q -U \"ray[tune]\" optuna requests pandas"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
@@ -34,18 +32,7 @@
"id": "setup", "id": "setup",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": "from raytune_ocr import (\n check_workers, create_trainable, run_tuner, analyze_results, correlation_analysis,\n paddle_ocr_payload, PADDLE_OCR_SEARCH_SPACE, PADDLE_OCR_CONFIG_KEYS,\n)\n\n# Worker ports (3 workers to avoid OOM)\nPORTS = [8001, 8002, 8003]\n\n# Check workers are running\nhealthy = check_workers(PORTS, \"PaddleOCR\")"
"from raytune_ocr import (\n",
" check_workers, create_trainable, run_tuner, analyze_results, correlation_analysis,\n",
" paddle_ocr_payload, PADDLE_OCR_SEARCH_SPACE, PADDLE_OCR_CONFIG_KEYS,\n",
")\n",
"\n",
"# Worker ports\n",
"PORTS = [8001, 8002]\n",
"\n",
"# Check workers are running\n",
"healthy = check_workers(PORTS, \"PaddleOCR\")"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
@@ -53,17 +40,7 @@
"id": "tune", "id": "tune",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": "# Create trainable and run tuning\ntrainable = create_trainable(PORTS, paddle_ocr_payload)\n\nresults = run_tuner(\n trainable=trainable,\n search_space=PADDLE_OCR_SEARCH_SPACE,\n num_samples=128,\n num_workers=len(healthy),\n)"
"# Create trainable and run tuning\n",
"trainable = create_trainable(PORTS, paddle_ocr_payload)\n",
"\n",
"results = run_tuner(\n",
" trainable=trainable,\n",
" search_space=PADDLE_OCR_SEARCH_SPACE,\n",
" num_samples=64,\n",
" num_workers=len(healthy),\n",
")"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
@@ -107,4 +84,4 @@
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 5 "nbformat_minor": 5
} }

View File

@@ -12,7 +12,7 @@ import requests
import pandas as pd import pandas as pd
import ray import ray
from ray import tune, train from ray import tune
from ray.tune.search.optuna import OptunaSearch from ray.tune.search.optuna import OptunaSearch
@@ -65,11 +65,15 @@ def create_trainable(ports: List[int], payload_fn: Callable[[Dict], Dict]) -> Ca
Returns: Returns:
Trainable function for Ray Tune Trainable function for Ray Tune
Note:
Ray Tune 2.x API: tune.report(metrics_dict) - pass dict directly, NOT kwargs.
See: https://docs.ray.io/en/latest/tune/api/doc/ray.tune.report.html
""" """
def trainable(config): def trainable(config):
import random import random
import requests import requests
from ray import train from ray.tune import report # Ray 2.x: report(dict), not report(**kwargs)
api_url = f"http://localhost:{random.choice(ports)}" api_url = f"http://localhost:{random.choice(ports)}"
payload = payload_fn(config) payload = payload_fn(config)
@@ -79,9 +83,9 @@ def create_trainable(ports: List[int], payload_fn: Callable[[Dict], Dict]) -> Ca
response.raise_for_status() response.raise_for_status()
metrics = response.json() metrics = response.json()
metrics["worker"] = api_url metrics["worker"] = api_url
train.report(metrics) report(metrics) # Ray 2.x API: pass dict directly
except Exception as e: except Exception as e:
train.report({ report({ # Ray 2.x API: pass dict directly
"CER": 1.0, "CER": 1.0,
"WER": 1.0, "WER": 1.0,
"TIME": 0.0, "TIME": 0.0,
@@ -116,7 +120,12 @@ def run_tuner(
Returns: Returns:
Ray Tune ResultGrid Ray Tune ResultGrid
""" """
ray.init(ignore_reinit_error=True, include_dashboard=False) ray.init(
ignore_reinit_error=True,
include_dashboard=False,
configure_logging=False,
_metrics_export_port=0, # Disable metrics export to avoid connection warnings
)
print(f"Ray Tune ready (version: {ray.__version__})") print(f"Ray Tune ready (version: {ray.__version__})")
tuner = tune.Tuner( tuner = tune.Tuner(