raytune rest

src/README.md (new file, 43 lines)
@@ -0,0 +1,43 @@
# Running Notebooks in Background

## Option 1: Papermill (Recommended)

Runs notebooks directly without conversion.

```bash
pip install papermill
nohup papermill <notebook>.ipynb output.ipynb > papermill.log 2>&1 &
```

Monitor:

```bash
tail -f papermill.log
```

## Option 2: Convert to Python Script

```bash
jupyter nbconvert --to script <notebook>.ipynb
nohup python <notebook>.py > output.log 2>&1 &
```

**Note:** `%pip install` magic commands need manual removal before running as `.py`.
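
A minimal sketch of that cleanup, assuming the magics survive conversion either literally or as `get_ipython().run_line_magic('pip', ...)` calls (which is how `jupyter nbconvert` typically emits them):

```bash
# Filter pip-magic lines out of the converted script before running it in the background.
# Both patterns are matched because the exact form depends on the nbconvert version.
grep -vE "%pip install|run_line_magic\('pip'" <notebook>.py > <notebook>_clean.py
nohup python <notebook>_clean.py > output.log 2>&1 &
```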

## Important Notes

- Ray Tune notebooks require the OCR service running first (Docker)
- For Ray workers, imports must be inside trainable functions (see the sketch below)
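
A minimal sketch of that pattern, with a hypothetical trainable name; it mirrors what `paddle_ocr_raytune_rest.ipynb` does:

```python
from ray import tune


def trainable_example(config):
    # Imports live inside the trainable so each Ray worker process
    # resolves them itself instead of relying on the driver's globals.
    import requests
    from ray import train

    resp = requests.get("http://localhost:8001/health", timeout=10)
    train.report({"healthy": float(resp.ok)})


# tune.Tuner(trainable_example).fit() would run this as a single trial.
```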

## Example: Ray Tune PaddleOCR

```bash
# 1. Start OCR service
cd src/paddle_ocr && docker compose up -d ocr-cpu

# 2. Run notebook with papermill
cd src
nohup papermill paddle_ocr_raytune_rest.ipynb output_raytune.ipynb > papermill.log 2>&1 &

# 3. Monitor
tail -f papermill.log
```
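
To stop a background run, a generic sketch (it assumes the papermill process started above is the only one on the machine):

```bash
# Find the background papermill process, then stop it
ps aux | grep "[p]apermill"
pkill -f "papermill paddle_ocr_raytune_rest.ipynb"
```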

src/output_raytune.ipynb (new file, 2511 lines)
File diff suppressed because it is too large

@@ -520,6 +520,28 @@ docker load < paddle-ocr-arm64.tar.gz
 
 ## Using with Ray Tune
 
+### Multi-Worker Setup for Parallel Trials
+
+Run multiple workers for parallel hyperparameter tuning:
+
+```bash
+cd src/paddle_ocr
+
+# Start 2 CPU workers (ports 8001-8002)
+sudo docker compose -f docker-compose.workers.yml --profile cpu up -d
+
+# Or for GPU workers (if supported)
+sudo docker compose -f docker-compose.workers.yml --profile gpu up -d
+
+# Check workers are healthy
+curl http://localhost:8001/health
+curl http://localhost:8002/health
+```
+
+Then run the notebook with `max_concurrent_trials=2` to use both workers in parallel.
+
+### Single Worker Setup
+
 Update your notebook's `trainable_paddle_ocr` function:
 
 ```python

src/paddle_ocr/docker-compose.workers.yml (new file, 90 lines)
@@ -0,0 +1,90 @@
# docker-compose.workers.yml - Multiple PaddleOCR workers for parallel Ray Tune
#
# Usage:
#   GPU (2 workers sharing one GPU):
#     docker compose -f docker-compose.workers.yml --profile gpu up
#
#   CPU (2 workers):
#     docker compose -f docker-compose.workers.yml --profile cpu up
#
#   To scale further, add worker services following the same pattern
#   (ports 8003, 8004, ...).
#
# Each worker runs on a separate port: 8001, 8002, ...

x-ocr-gpu-common: &ocr-gpu-common
  image: seryus.ddns.net/unir/paddle-ocr-gpu:latest
  volumes:
    - ../dataset:/app/dataset:ro
    - paddlex-cache:/root/.paddlex
  environment:
    - PYTHONUNBUFFERED=1
    - CUDA_VISIBLE_DEVICES=0
  deploy:
    resources:
      reservations:
        devices:
          - driver: nvidia
            count: 1
            capabilities: [gpu]
  restart: unless-stopped
  healthcheck:
    test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
    interval: 30s
    timeout: 10s
    retries: 3
    start_period: 120s

x-ocr-cpu-common: &ocr-cpu-common
  image: paddle-ocr-api:cpu
  volumes:
    - ../dataset:/app/dataset:ro
    - paddlex-cache:/root/.paddlex
  environment:
    - PYTHONUNBUFFERED=1
  restart: unless-stopped
  healthcheck:
    test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
    interval: 30s
    timeout: 10s
    retries: 3
    start_period: 120s

services:
  # GPU Workers (gpu profile) - share single GPU
  ocr-worker-1:
    <<: *ocr-gpu-common
    container_name: paddle-ocr-worker-1
    ports:
      - "8001:8000"
    profiles:
      - gpu

  ocr-worker-2:
    <<: *ocr-gpu-common
    container_name: paddle-ocr-worker-2
    ports:
      - "8002:8000"
    profiles:
      - gpu

  # CPU Workers (cpu profile) - for systems without GPU
  ocr-cpu-worker-1:
    <<: *ocr-cpu-common
    container_name: paddle-ocr-cpu-worker-1
    ports:
      - "8001:8000"
    profiles:
      - cpu

  ocr-cpu-worker-2:
    <<: *ocr-cpu-common
    container_name: paddle-ocr-cpu-worker-2
    ports:
      - "8002:8000"
    profiles:
      - cpu

volumes:
  paddlex-cache:
    name: paddlex-model-cache
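
The header comment above leaves scaling to hand-written services; as a sketch of one way to do that without touching this file, a hypothetical `docker-compose.workers.extra.yml` override could add a third CPU worker on port 8003 and be passed alongside the main file with a second `-f` flag:

```yaml
# docker-compose.workers.extra.yml (hypothetical) - one extra CPU worker.
# Usage (merged with the main workers file):
#   docker compose -f docker-compose.workers.yml -f docker-compose.workers.extra.yml --profile cpu up -d
services:
  ocr-cpu-worker-3:
    image: paddle-ocr-api:cpu
    container_name: paddle-ocr-cpu-worker-3
    volumes:
      - ../dataset:/app/dataset:ro
      - paddlex-cache:/root/.paddlex
    environment:
      - PYTHONUNBUFFERED=1
    restart: unless-stopped
    ports:
      - "8003:8000"
    profiles:
      - cpu

volumes:
  paddlex-cache:
    name: paddlex-model-cache
```

The extra port would then also need to appear in the notebook's `WORKER_PORTS` list.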

@@ -5,6 +5,7 @@
 import os
 import re
 import time
+import threading
 from typing import Optional
 from contextlib import asynccontextmanager
 
@@ -61,6 +62,10 @@ class AppState:
     dataset_path: Optional[str] = None
     det_model: str = DEFAULT_DET_MODEL
     rec_model: str = DEFAULT_REC_MODEL
+    lock: threading.Lock = None  # Protects OCR model from concurrent access
+
+    def __init__(self):
+        self.lock = threading.Lock()
 
 
 state = AppState()
@@ -281,28 +286,30 @@ def evaluate(request: EvaluateRequest):
     time_per_page_list = []
     t0 = time.time()
 
-    for idx in range(start, end):
-        img, ref = state.dataset[idx]
-        arr = np.array(img)
+    # Lock to prevent concurrent OCR access (model is not thread-safe)
+    with state.lock:
+        for idx in range(start, end):
+            img, ref = state.dataset[idx]
+            arr = np.array(img)
 
-        tp0 = time.time()
-        out = state.ocr.predict(
-            arr,
-            use_doc_orientation_classify=request.use_doc_orientation_classify,
-            use_doc_unwarping=request.use_doc_unwarping,
-            use_textline_orientation=request.textline_orientation,
-            text_det_thresh=request.text_det_thresh,
-            text_det_box_thresh=request.text_det_box_thresh,
-            text_det_unclip_ratio=request.text_det_unclip_ratio,
-            text_rec_score_thresh=request.text_rec_score_thresh,
-        )
+            tp0 = time.time()
+            out = state.ocr.predict(
+                arr,
+                use_doc_orientation_classify=request.use_doc_orientation_classify,
+                use_doc_unwarping=request.use_doc_unwarping,
+                use_textline_orientation=request.textline_orientation,
+                text_det_thresh=request.text_det_thresh,
+                text_det_box_thresh=request.text_det_box_thresh,
+                text_det_unclip_ratio=request.text_det_unclip_ratio,
+                text_rec_score_thresh=request.text_rec_score_thresh,
+            )
 
-        pred = assemble_from_paddle_result(out)
-        time_per_page_list.append(float(time.time() - tp0))
+            pred = assemble_from_paddle_result(out)
+            time_per_page_list.append(float(time.time() - tp0))
 
-        m = evaluate_text(ref, pred)
-        cer_list.append(m["CER"])
-        wer_list.append(m["WER"])
+            m = evaluate_text(ref, pred)
+            cer_list.append(m["CER"])
+            wer_list.append(m["WER"])
 
     return EvaluateResponse(
         CER=float(np.mean(cer_list)) if cer_list else 1.0,
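
For context on why this lock pairs with the multi-worker setup: a hedged client-side sketch (worker URL and payload copied from the notebook; timings hypothetical) of two concurrent `/evaluate` calls hitting the same worker. Because the whole OCR loop runs under `state.lock`, the second call waits for the first, so parallel Ray Tune trials only pay off across separate containers.

```python
# Sketch: fire two /evaluate requests at the SAME worker and time them.
import time
from concurrent.futures import ThreadPoolExecutor

import requests

PAYLOAD = {
    "pdf_folder": "/app/dataset",
    "use_doc_orientation_classify": False,
    "use_doc_unwarping": False,
    "textline_orientation": True,
    "text_det_thresh": 0.0,
    "text_det_box_thresh": 0.0,
    "text_det_unclip_ratio": 1.5,
    "text_rec_score_thresh": 0.0,
    "start_page": 5,
    "end_page": 6,
}

def timed_evaluate(url: str) -> float:
    # Time one blocking /evaluate call against the given worker.
    t0 = time.time()
    requests.post(f"{url}/evaluate", json=PAYLOAD, timeout=None)
    return time.time() - t0

with ThreadPoolExecutor(max_workers=2) as pool:
    # Both calls target the same worker, so its lock serializes them.
    durations = list(pool.map(timed_evaluate, ["http://localhost:8001"] * 2))

print(durations)  # the second duration should be roughly double the first
```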

src/paddle_ocr_raytune_rest.ipynb (new file, 393 lines)
@@ -0,0 +1,393 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "header",
   "metadata": {},
   "source": [
    "# PaddleOCR Hyperparameter Optimization via REST API\n",
    "\n",
    "This notebook runs Ray Tune hyperparameter search calling the PaddleOCR REST API (Docker container).\n",
    "\n",
    "**Benefits:**\n",
    "- No model reload per trial - Model stays loaded in Docker container\n",
    "- Faster trials - Skip ~10s model load time per trial\n",
    "- Cleaner code - REST API replaces subprocess + CLI arg parsing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "prereq",
   "metadata": {},
   "source": [
    "## Prerequisites\n",
    "\n",
    "Start 2 PaddleOCR workers for parallel hyperparameter tuning:\n",
    "\n",
    "```bash\n",
    "cd src/paddle_ocr\n",
    "docker compose -f docker-compose.workers.yml --profile gpu up\n",
    "```\n",
    "\n",
    "This starts 2 GPU workers on ports 8001-8002, allowing 2 concurrent trials.\n",
    "\n",
    "For CPU-only systems:\n",
    "```bash\n",
    "docker compose -f docker-compose.workers.yml --profile cpu up\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ob9fsoilc4",
   "metadata": {},
   "source": [
    "## 0. Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "wyr2nsoj7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install dependencies (run once)\n",
    "%pip install -U \"ray[tune]\"\n",
    "%pip install optuna\n",
    "%pip install requests pandas"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "imports-header",
   "metadata": {},
   "source": [
    "## 1. Imports & Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "imports",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from datetime import datetime\n",
    "\n",
    "import requests\n",
    "import pandas as pd\n",
    "\n",
    "import ray\n",
    "from ray import tune, air\n",
    "from ray.tune.search.optuna import OptunaSearch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "config-header",
   "metadata": {},
   "source": [
    "## 2. API Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "config",
   "metadata": {},
   "outputs": [],
   "source": [
    "# PaddleOCR REST API endpoints - 2 workers for parallel trials\n",
    "# Start workers with: cd src/paddle_ocr && docker compose -f docker-compose.workers.yml --profile gpu up\n",
    "WORKER_PORTS = [8001, 8002]\n",
    "WORKER_URLS = [f\"http://localhost:{port}\" for port in WORKER_PORTS]\n",
    "\n",
    "# Output folder for results\n",
    "OUTPUT_FOLDER = \"results\"\n",
    "os.makedirs(OUTPUT_FOLDER, exist_ok=True)\n",
    "\n",
    "# Number of concurrent trials = number of workers\n",
    "NUM_WORKERS = len(WORKER_URLS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "health-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify all workers are running\n",
    "healthy_workers = []\n",
    "for url in WORKER_URLS:\n",
    "    try:\n",
    "        health = requests.get(f\"{url}/health\", timeout=10).json()\n",
    "        if health['status'] == 'ok' and health['model_loaded']:\n",
    "            healthy_workers.append(url)\n",
    "            print(f\"✓ {url}: {health['status']} (GPU: {health.get('gpu_name', 'N/A')})\")\n",
    "        else:\n",
    "            print(f\"✗ {url}: not ready yet\")\n",
    "    except requests.exceptions.ConnectionError:\n",
    "        print(f\"✗ {url}: not reachable\")\n",
    "\n",
    "if not healthy_workers:\n",
    "    raise RuntimeError(\n",
    "        \"No healthy workers found. Start them with:\\n\"\n",
    "        \"  cd src/paddle_ocr && docker compose -f docker-compose.workers.yml --profile gpu up\"\n",
    "    )\n",
    "\n",
    "print(f\"\\n{len(healthy_workers)}/{len(WORKER_URLS)} workers ready for parallel tuning\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "search-space-header",
   "metadata": {},
   "source": [
    "## 3. Search Space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "search-space",
   "metadata": {},
   "outputs": [],
   "source": [
    "search_space = {\n",
    "    # Whether to use document image orientation classification\n",
    "    \"use_doc_orientation_classify\": tune.choice([True, False]),\n",
    "    # Whether to use text image unwarping\n",
    "    \"use_doc_unwarping\": tune.choice([True, False]),\n",
    "    # Whether to use text line orientation classification\n",
    "    \"textline_orientation\": tune.choice([True, False]),\n",
    "    # Detection pixel threshold (pixels > threshold are considered text)\n",
    "    \"text_det_thresh\": tune.uniform(0.0, 0.7),\n",
    "    # Detection box threshold (average score within border)\n",
    "    \"text_det_box_thresh\": tune.uniform(0.0, 0.7),\n",
    "    # Text detection expansion coefficient\n",
    "    \"text_det_unclip_ratio\": tune.choice([0.0]),\n",
    "    # Text recognition threshold (filter low confidence results)\n",
    "    \"text_rec_score_thresh\": tune.uniform(0.0, 0.7),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "trainable-header",
   "metadata": {},
   "source": [
    "## 4. Trainable Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "trainable",
   "metadata": {},
   "outputs": [],
   "source": [
    "def trainable_paddle_ocr(config):\n",
    "    \"\"\"Call PaddleOCR REST API with the given hyperparameter config.\n",
    "\n",
    "    Uses trial index to deterministically assign a worker (round-robin),\n",
    "    ensuring only 1 request per container at a time.\n",
    "    \"\"\"\n",
    "    import requests  # Must be inside function for Ray workers\n",
    "    from ray import train\n",
    "\n",
    "    # Worker URLs - round-robin assignment based on trial index\n",
    "    WORKER_PORTS = [8001, 8002]\n",
    "    NUM_WORKERS = len(WORKER_PORTS)\n",
    "\n",
    "    # Get trial context for deterministic worker assignment\n",
    "    context = train.get_context()\n",
    "    trial_id = context.get_trial_id() if context else \"0\"\n",
    "    # Extract numeric part from trial ID (e.g., \"trainable_paddle_ocr_abc123_00001\" -> 1)\n",
    "    try:\n",
    "        trial_num = int(trial_id.split(\"_\")[-1])\n",
    "    except (ValueError, IndexError):\n",
    "        trial_num = hash(trial_id)\n",
    "\n",
    "    worker_idx = trial_num % NUM_WORKERS\n",
    "    api_url = f\"http://localhost:{WORKER_PORTS[worker_idx]}\"\n",
    "\n",
    "    payload = {\n",
    "        \"pdf_folder\": \"/app/dataset\",\n",
    "        \"use_doc_orientation_classify\": config.get(\"use_doc_orientation_classify\", False),\n",
    "        \"use_doc_unwarping\": config.get(\"use_doc_unwarping\", False),\n",
    "        \"textline_orientation\": config.get(\"textline_orientation\", True),\n",
    "        \"text_det_thresh\": config.get(\"text_det_thresh\", 0.0),\n",
    "        \"text_det_box_thresh\": config.get(\"text_det_box_thresh\", 0.0),\n",
    "        \"text_det_unclip_ratio\": config.get(\"text_det_unclip_ratio\", 1.5),\n",
    "        \"text_rec_score_thresh\": config.get(\"text_rec_score_thresh\", 0.0),\n",
    "        \"start_page\": 5,\n",
    "        \"end_page\": 10,\n",
    "    }\n",
    "\n",
    "    try:\n",
    "        response = requests.post(f\"{api_url}/evaluate\", json=payload, timeout=None)  # No timeout\n",
    "        response.raise_for_status()\n",
    "        metrics = response.json()\n",
    "        metrics[\"worker\"] = api_url\n",
    "        train.report(metrics)\n",
    "    except Exception as e:\n",
    "        train.report({\n",
    "            \"CER\": 1.0,\n",
    "            \"WER\": 1.0,\n",
    "            \"TIME\": 0.0,\n",
    "            \"PAGES\": 0,\n",
    "            \"TIME_PER_PAGE\": 0,\n",
    "            \"worker\": api_url,\n",
    "            \"ERROR\": str(e)[:500]\n",
    "        })"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "tuner-header",
   "metadata": {},
   "source": [
    "## 5. Run Tuner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ray-init",
   "metadata": {},
   "outputs": [],
   "source": [
    "ray.init(ignore_reinit_error=True)\n",
    "print(f\"Ray Tune ready (version: {ray.__version__})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "tuner",
   "metadata": {},
   "outputs": [],
   "source": [
    "tuner = tune.Tuner(\n",
    "    trainable_paddle_ocr,\n",
    "    tune_config=tune.TuneConfig(\n",
    "        metric=\"CER\",\n",
    "        mode=\"min\",\n",
    "        search_alg=OptunaSearch(),\n",
    "        num_samples=64,\n",
    "        max_concurrent_trials=NUM_WORKERS,  # Run trials in parallel across workers\n",
    "    ),\n",
    "    run_config=air.RunConfig(verbose=2, log_to_file=False),\n",
    "    param_space=search_space,\n",
    ")\n",
    "\n",
    "results = tuner.fit()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "analysis-header",
   "metadata": {},
   "source": [
    "## 6. Results Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "results-df",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = results.get_dataframe()\n",
    "df.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "save-results",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save results to CSV\n",
    "timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
    "filename = f\"raytune_paddle_rest_results_{timestamp}.csv\"\n",
    "filepath = os.path.join(OUTPUT_FOLDER, filename)\n",
    "\n",
    "df.to_csv(filepath, index=False)\n",
    "print(f\"Results saved: {filepath}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "best-config",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Best configuration\n",
    "best = df.loc[df[\"CER\"].idxmin()]\n",
    "\n",
    "print(f\"Best CER: {best['CER']:.6f}\")\n",
    "print(f\"Best WER: {best['WER']:.6f}\")\n",
    "print(f\"\\nOptimal Configuration:\")\n",
    "print(f\"  textline_orientation: {best['config/textline_orientation']}\")\n",
    "print(f\"  use_doc_orientation_classify: {best['config/use_doc_orientation_classify']}\")\n",
    "print(f\"  use_doc_unwarping: {best['config/use_doc_unwarping']}\")\n",
    "print(f\"  text_det_thresh: {best['config/text_det_thresh']:.4f}\")\n",
    "print(f\"  text_det_box_thresh: {best['config/text_det_box_thresh']:.4f}\")\n",
    "print(f\"  text_det_unclip_ratio: {best['config/text_det_unclip_ratio']}\")\n",
    "print(f\"  text_rec_score_thresh: {best['config/text_rec_score_thresh']:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "correlation",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Correlation analysis\n",
    "param_cols = [\n",
    "    \"config/text_det_thresh\",\n",
    "    \"config/text_det_box_thresh\",\n",
    "    \"config/text_det_unclip_ratio\",\n",
    "    \"config/text_rec_score_thresh\",\n",
    "]\n",
    "\n",
    "corr_cer = df[param_cols + [\"CER\"]].corr()[\"CER\"].sort_values(ascending=False)\n",
    "corr_wer = df[param_cols + [\"WER\"]].corr()[\"WER\"].sort_values(ascending=False)\n",
    "\n",
    "print(\"Correlation with CER:\")\n",
    "print(corr_cer)\n",
    "print(\"\\nCorrelation with WER:\")\n",
    "print(corr_wer)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}