PaddleOCR GPU support #4
.gitignore (+2 lines)
@@ -8,3 +8,5 @@ results
 node_modules
 src/paddle_ocr/wheels
 src/*.log
+src/output_*.ipynb
+debugset/
@@ -42,4 +42,33 @@ class ImageTextDataset:
         with open(txt_path, "r", encoding="utf-8") as f:
             text = f.read()
 
         return image, text
+
+    def get_output_path(self, idx, output_subdir, debugset_root="/app/debugset"):
+        """Get output path for saving OCR result to debugset folder.
+
+        Args:
+            idx: Sample index
+            output_subdir: Subdirectory name (e.g., 'paddle_text', 'doctr_text')
+            debugset_root: Root folder for debug output (default: /app/debugset)
+
+        Returns:
+            Path like /app/debugset/doc1/{output_subdir}/page_001.txt
+        """
+        img_path, _ = self.samples[idx]
+        # img_path: /app/dataset/doc1/img/page_001.png
+        # Extract relative path: doc1/img/page_001.png
+        parts = img_path.split("/dataset/", 1)
+        if len(parts) == 2:
+            rel_path = parts[1]  # doc1/img/page_001.png
+        else:
+            rel_path = os.path.basename(img_path)
+
+        # Replace /img/ with /{output_subdir}/
+        rel_parts = rel_path.rsplit("/img/", 1)
+        doc_folder = rel_parts[0]  # doc1
+        fname = os.path.splitext(rel_parts[1])[0] + ".txt"  # page_001.txt
+
+        out_dir = os.path.join(debugset_root, doc_folder, output_subdir)
+        os.makedirs(out_dir, exist_ok=True)
+        return os.path.join(out_dir, fname)
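The path mapping this helper performs, traced by hand — a standalone sketch of the same transformation (not an import from the repo), useful for sanity-checking the expected debugset layout:

```python
import os

# Mirrors get_output_path(): a dataset image path maps to a debugset text path.
img_path = "/app/dataset/doc1/img/page_001.png"
rel_path = img_path.split("/dataset/", 1)[1]      # doc1/img/page_001.png
doc_folder, page = rel_path.rsplit("/img/", 1)    # doc1, page_001.png
fname = os.path.splitext(page)[0] + ".txt"        # page_001.txt
out_path = os.path.join("/app/debugset", doc_folder, "paddle_text", fname)
print(out_path)  # /app/debugset/doc1/paddle_text/page_001.txt
```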
src/doctr_raytune_rest.ipynb (new file, 111 lines)
@@ -0,0 +1,111 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "header",
+   "metadata": {},
+   "source": [
+    "# DocTR Hyperparameter Optimization via REST API\n",
+    "\n",
+    "Uses Ray Tune + Optuna to find optimal DocTR parameters.\n",
+    "\n",
+    "## Prerequisites\n",
+    "\n",
+    "```bash\n",
+    "cd src/doctr_service\n",
+    "docker compose up ocr-cpu  # or ocr-gpu\n",
+    "```\n",
+    "\n",
+    "Service runs on port 8003."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "deps",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install -q -U \"ray[tune]\" optuna requests pandas"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "setup",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from raytune_ocr import (\n",
+    "    check_workers, create_trainable, run_tuner, analyze_results, correlation_analysis,\n",
+    "    doctr_payload, DOCTR_SEARCH_SPACE, DOCTR_CONFIG_KEYS,\n",
+    ")\n",
+    "\n",
+    "# Worker ports\n",
+    "PORTS = [8003]\n",
+    "\n",
+    "# Check workers are running\n",
+    "healthy = check_workers(PORTS, \"DocTR\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "tune",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create trainable and run tuning\n",
+    "trainable = create_trainable(PORTS, doctr_payload)\n",
+    "\n",
+    "results = run_tuner(\n",
+    "    trainable=trainable,\n",
+    "    search_space=DOCTR_SEARCH_SPACE,\n",
+    "    num_samples=64,\n",
+    "    num_workers=len(healthy),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "analysis",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Analyze results\n",
+    "df = analyze_results(\n",
+    "    results,\n",
+    "    prefix=\"raytune_doctr\",\n",
+    "    config_keys=DOCTR_CONFIG_KEYS,\n",
+    ")\n",
+    "\n",
+    "df.describe()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "correlation",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Correlation analysis\n",
+    "correlation_analysis(df, DOCTR_CONFIG_KEYS)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.10.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
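For reference, a minimal sketch of the trial payload this notebook ends up POSTing to the service: doctr_payload() (added in src/raytune_ocr.py further down in this diff) fills in defaults for any key the search space leaves out.

```python
from raytune_ocr import doctr_payload

# One trial's payload for POST /evaluate; unspecified keys fall back to the
# defaults defined in doctr_payload() (assume_straight_pages=True, ...).
payload = doctr_payload({"paragraph_break": 0.05}, save_output=True)
print(payload["pdf_folder"], payload["paragraph_break"], payload["save_output"])
# /app/dataset 0.05 True
```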
@@ -42,4 +42,33 @@ class ImageTextDataset:
         with open(txt_path, "r", encoding="utf-8") as f:
             text = f.read()
 
         return image, text
+
+    def get_output_path(self, idx, output_subdir, debugset_root="/app/debugset"):
+        """Get output path for saving OCR result to debugset folder.
+
+        Args:
+            idx: Sample index
+            output_subdir: Subdirectory name (e.g., 'paddle_text', 'doctr_text')
+            debugset_root: Root folder for debug output (default: /app/debugset)
+
+        Returns:
+            Path like /app/debugset/doc1/{output_subdir}/page_001.txt
+        """
+        img_path, _ = self.samples[idx]
+        # img_path: /app/dataset/doc1/img/page_001.png
+        # Extract relative path: doc1/img/page_001.png
+        parts = img_path.split("/dataset/", 1)
+        if len(parts) == 2:
+            rel_path = parts[1]  # doc1/img/page_001.png
+        else:
+            rel_path = os.path.basename(img_path)
+
+        # Replace /img/ with /{output_subdir}/
+        rel_parts = rel_path.rsplit("/img/", 1)
+        doc_folder = rel_parts[0]  # doc1
+        fname = os.path.splitext(rel_parts[1])[0] + ".txt"  # page_001.txt
+
+        out_dir = os.path.join(debugset_root, doc_folder, output_subdir)
+        os.makedirs(out_dir, exist_ok=True)
+        return os.path.join(out_dir, fname)
@@ -14,6 +14,7 @@ services:
       - "8003:8000"
     volumes:
       - ../dataset:/app/dataset:ro
+      - ../debugset:/app/debugset:rw
       - doctr-cache:/root/.cache/doctr
     environment:
       - PYTHONUNBUFFERED=1
@@ -35,6 +36,7 @@ services:
       - "8003:8000"
     volumes:
       - ../dataset:/app/dataset:ro
+      - ../debugset:/app/debugset:rw
       - doctr-cache:/root/.cache/doctr
     environment:
      - PYTHONUNBUFFERED=1
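The new :rw bind mount is what lets the container write predictions back to the host. A host-side sketch (assuming at least one /evaluate run with save_output=true) to confirm files land where get_output_path() puts them:

```python
from pathlib import Path

# Saved DocTR predictions should appear under ../debugset/<doc>/doctr_text/
# on the host, mirroring the container's /app/debugset.
for txt in sorted(Path("../debugset").glob("*/doctr_text/*.txt")):
    print(txt, f"({txt.stat().st_size} bytes)")
```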
@@ -169,6 +169,7 @@ class EvaluateRequest(BaseModel):
     # Page range
     start_page: int = Field(5, ge=0, description="Start page index (inclusive)")
     end_page: int = Field(10, ge=1, description="End page index (exclusive)")
+    save_output: bool = Field(False, description="Save OCR predictions to debugset folder")
 
 
 class EvaluateResponse(BaseModel):
@@ -302,6 +303,12 @@ def evaluate(request: EvaluateRequest):
             )
             time_per_page_list.append(float(time.time() - tp0))
 
+            # Save prediction to debugset if requested
+            if request.save_output:
+                out_path = state.dataset.get_output_path(idx, "doctr_text")
+                with open(out_path, "w", encoding="utf-8") as f:
+                    f.write(pred)
+
             m = evaluate_text(ref, pred)
             cer_list.append(m["CER"])
             wer_list.append(m["WER"])
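A hedged client-side sketch of the new flag (fields as declared by this EvaluateRequest; remaining parameters keep their declared defaults):

```python
import requests

# Evaluate pages 5-9 on the DocTR service (port 8003 per its compose file)
# and persist each prediction to /app/debugset/<doc>/doctr_text/.
payload = {"pdf_folder": "/app/dataset", "start_page": 5, "end_page": 10,
           "save_output": True}
resp = requests.post("http://localhost:8003/evaluate", json=payload, timeout=None)
resp.raise_for_status()
print(resp.json())  # aggregate metrics, including CER and WER
```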
src/easyocr_raytune_rest.ipynb (new file, 111 lines)
@@ -0,0 +1,111 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "header",
+   "metadata": {},
+   "source": [
+    "# EasyOCR Hyperparameter Optimization via REST API\n",
+    "\n",
+    "Uses Ray Tune + Optuna to find optimal EasyOCR parameters.\n",
+    "\n",
+    "## Prerequisites\n",
+    "\n",
+    "```bash\n",
+    "cd src/easyocr_service\n",
+    "docker compose up ocr-cpu  # or ocr-gpu\n",
+    "```\n",
+    "\n",
+    "Service runs on port 8002."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "deps",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install -q -U \"ray[tune]\" optuna requests pandas"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "setup",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from raytune_ocr import (\n",
+    "    check_workers, create_trainable, run_tuner, analyze_results, correlation_analysis,\n",
+    "    easyocr_payload, EASYOCR_SEARCH_SPACE, EASYOCR_CONFIG_KEYS,\n",
+    ")\n",
+    "\n",
+    "# Worker ports\n",
+    "PORTS = [8002]\n",
+    "\n",
+    "# Check workers are running\n",
+    "healthy = check_workers(PORTS, \"EasyOCR\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "tune",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create trainable and run tuning\n",
+    "trainable = create_trainable(PORTS, easyocr_payload)\n",
+    "\n",
+    "results = run_tuner(\n",
+    "    trainable=trainable,\n",
+    "    search_space=EASYOCR_SEARCH_SPACE,\n",
+    "    num_samples=64,\n",
+    "    num_workers=len(healthy),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "analysis",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Analyze results\n",
+    "df = analyze_results(\n",
+    "    results,\n",
+    "    prefix=\"raytune_easyocr\",\n",
+    "    config_keys=EASYOCR_CONFIG_KEYS,\n",
+    ")\n",
+    "\n",
+    "df.describe()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "correlation",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Correlation analysis\n",
+    "correlation_analysis(df, EASYOCR_CONFIG_KEYS)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.10.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
@@ -42,4 +42,33 @@ class ImageTextDataset:
         with open(txt_path, "r", encoding="utf-8") as f:
             text = f.read()
 
         return image, text
+
+    def get_output_path(self, idx, output_subdir, debugset_root="/app/debugset"):
+        """Get output path for saving OCR result to debugset folder.
+
+        Args:
+            idx: Sample index
+            output_subdir: Subdirectory name (e.g., 'paddle_text', 'doctr_text')
+            debugset_root: Root folder for debug output (default: /app/debugset)
+
+        Returns:
+            Path like /app/debugset/doc1/{output_subdir}/page_001.txt
+        """
+        img_path, _ = self.samples[idx]
+        # img_path: /app/dataset/doc1/img/page_001.png
+        # Extract relative path: doc1/img/page_001.png
+        parts = img_path.split("/dataset/", 1)
+        if len(parts) == 2:
+            rel_path = parts[1]  # doc1/img/page_001.png
+        else:
+            rel_path = os.path.basename(img_path)
+
+        # Replace /img/ with /{output_subdir}/
+        rel_parts = rel_path.rsplit("/img/", 1)
+        doc_folder = rel_parts[0]  # doc1
+        fname = os.path.splitext(rel_parts[1])[0] + ".txt"  # page_001.txt
+
+        out_dir = os.path.join(debugset_root, doc_folder, output_subdir)
+        os.makedirs(out_dir, exist_ok=True)
+        return os.path.join(out_dir, fname)
@@ -14,6 +14,7 @@ services:
       - "8002:8000"
     volumes:
       - ../dataset:/app/dataset:ro
+      - ../debugset:/app/debugset:rw
       - easyocr-cache:/root/.EasyOCR
     environment:
       - PYTHONUNBUFFERED=1
@@ -34,6 +35,7 @@ services:
       - "8002:8000"
     volumes:
       - ../dataset:/app/dataset:ro
+      - ../debugset:/app/debugset:rw
       - easyocr-cache:/root/.EasyOCR
     environment:
      - PYTHONUNBUFFERED=1
@@ -133,6 +133,7 @@ class EvaluateRequest(BaseModel):
     # Page range
     start_page: int = Field(5, ge=0, description="Start page index (inclusive)")
     end_page: int = Field(10, ge=1, description="End page index (exclusive)")
+    save_output: bool = Field(False, description="Save OCR predictions to debugset folder")
 
 
 class EvaluateResponse(BaseModel):
@@ -301,6 +302,12 @@ def evaluate(request: EvaluateRequest):
             pred = assemble_easyocr_result(result)
             time_per_page_list.append(float(time.time() - tp0))
 
+            # Save prediction to debugset if requested
+            if request.save_output:
+                out_path = state.dataset.get_output_path(idx, "easyocr_text")
+                with open(out_path, "w", encoding="utf-8") as f:
+                    f.write(pred)
+
             m = evaluate_text(ref, pred)
             cer_list.append(m["CER"])
             wer_list.append(m["WER"])
File diff suppressed because it is too large
@@ -42,4 +42,33 @@ class ImageTextDataset:
         with open(txt_path, "r", encoding="utf-8") as f:
             text = f.read()
 
         return image, text
+
+    def get_output_path(self, idx, output_subdir, debugset_root="/app/debugset"):
+        """Get output path for saving OCR result to debugset folder.
+
+        Args:
+            idx: Sample index
+            output_subdir: Subdirectory name (e.g., 'paddle_text', 'doctr_text')
+            debugset_root: Root folder for debug output (default: /app/debugset)
+
+        Returns:
+            Path like /app/debugset/doc1/{output_subdir}/page_001.txt
+        """
+        img_path, _ = self.samples[idx]
+        # img_path: /app/dataset/doc1/img/page_001.png
+        # Extract relative path: doc1/img/page_001.png
+        parts = img_path.split("/dataset/", 1)
+        if len(parts) == 2:
+            rel_path = parts[1]  # doc1/img/page_001.png
+        else:
+            rel_path = os.path.basename(img_path)
+
+        # Replace /img/ with /{output_subdir}/
+        rel_parts = rel_path.rsplit("/img/", 1)
+        doc_folder = rel_parts[0]  # doc1
+        fname = os.path.splitext(rel_parts[1])[0] + ".txt"  # page_001.txt
+
+        out_dir = os.path.join(debugset_root, doc_folder, output_subdir)
+        os.makedirs(out_dir, exist_ok=True)
+        return os.path.join(out_dir, fname)
@@ -9,6 +9,7 @@ services:
       - "8001:8000"
     volumes:
       - ../dataset:/app/dataset:ro
+      - ../debugset:/app/debugset:rw
       - paddlex-cache:/root/.paddlex
     environment:
      - PYTHONUNBUFFERED=1
@@ -11,6 +11,7 @@ services:
       - "8002:8000"
     volumes:
       - ../dataset:/app/dataset:ro
+      - ../debugset:/app/debugset:rw
       - paddlex-cache:/root/.paddlex
       - ./scripts:/app/scripts:ro
     environment:
@@ -16,6 +16,7 @@ x-ocr-gpu-common: &ocr-gpu-common
   image: seryus.ddns.net/unir/paddle-ocr-gpu:latest
   volumes:
     - ../dataset:/app/dataset:ro
+    - ../debugset:/app/debugset:rw
     - paddlex-cache:/root/.paddlex
   environment:
     - PYTHONUNBUFFERED=1
@@ -39,6 +40,7 @@ x-ocr-cpu-common: &ocr-cpu-common
   image: seryus.ddns.net/unir/paddle-ocr-cpu:latest
   volumes:
     - ../dataset:/app/dataset:ro
+    - ../debugset:/app/debugset:rw
     - paddlex-cache:/root/.paddlex
   environment:
     - PYTHONUNBUFFERED=1
@@ -45,7 +45,8 @@ services:
|
|||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
volumes:
|
volumes:
|
||||||
- ../dataset:/app/dataset:ro # Your dataset
|
- ../dataset:/app/dataset:ro
|
||||||
|
- ../debugset:/app/debugset:rw # Your dataset
|
||||||
- paddlex-cache:/root/.paddlex # For additional models at runtime
|
- paddlex-cache:/root/.paddlex # For additional models at runtime
|
||||||
environment:
|
environment:
|
||||||
- PYTHONUNBUFFERED=1
|
- PYTHONUNBUFFERED=1
|
||||||
@@ -74,6 +75,7 @@ services:
|
|||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
volumes:
|
volumes:
|
||||||
- ../dataset:/app/dataset:ro
|
- ../dataset:/app/dataset:ro
|
||||||
|
- ../debugset:/app/debugset:rw
|
||||||
- paddlex-cache:/root/.paddlex
|
- paddlex-cache:/root/.paddlex
|
||||||
environment:
|
environment:
|
||||||
- PYTHONUNBUFFERED=1
|
- PYTHONUNBUFFERED=1
|
||||||
|
|||||||
@@ -127,6 +127,7 @@ class EvaluateRequest(BaseModel):
     text_rec_score_thresh: float = Field(0.0, ge=0.0, le=1.0, description="Recognition score threshold")
     start_page: int = Field(5, ge=0, description="Start page index (inclusive)")
     end_page: int = Field(10, ge=1, description="End page index (exclusive)")
+    save_output: bool = Field(False, description="Save OCR predictions to debugset folder")
 
 
 class EvaluateResponse(BaseModel):
@@ -307,6 +308,12 @@ def evaluate(request: EvaluateRequest):
             pred = assemble_from_paddle_result(out)
             time_per_page_list.append(float(time.time() - tp0))
 
+            # Save prediction to debugset if requested
+            if request.save_output:
+                out_path = state.dataset.get_output_path(idx, "paddle_text")
+                with open(out_path, "w", encoding="utf-8") as f:
+                    f.write(pred)
+
             m = evaluate_text(ref, pred)
             cer_list.append(m["CER"])
             wer_list.append(m["WER"])
@@ -7,263 +7,81 @@
    "source": [
     "# PaddleOCR Hyperparameter Optimization via REST API\n",
     "\n",
-    "This notebook runs Ray Tune hyperparameter search calling the PaddleOCR REST API (Docker container).\n",
+    "Uses Ray Tune + Optuna to find optimal PaddleOCR parameters.\n",
     "\n",
-    "**Benefits:**\n",
-    "- No model reload per trial - Model stays loaded in Docker container\n",
-    "- Faster trials - Skip ~10s model load time per trial\n",
-    "- Cleaner code - REST API replaces subprocess + CLI arg parsing"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "prereq",
-   "metadata": {},
-   "source": [
     "## Prerequisites\n",
     "\n",
-    "Start 2 PaddleOCR workers for parallel hyperparameter tuning:\n",
-    "\n",
     "```bash\n",
     "cd src/paddle_ocr\n",
-    "docker compose -f docker-compose.workers.yml up\n",
-    "```\n",
-    "\n",
-    "This starts 2 GPU workers on ports 8001-8002, allowing 2 concurrent trials.\n",
-    "\n",
-    "For CPU-only systems:\n",
-    "```bash\n",
-    "docker compose -f docker-compose.workers.yml --profile cpu up\n",
+    "docker compose -f docker-compose.workers.yml up  # GPU workers on 8001-8002\n",
+    "# or: docker compose -f docker-compose.workers.yml --profile cpu up\n",
     "```"
    ]
   },
   {
-   "cell_type": "markdown",
-   "id": "3ob9fsoilc4",
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "deps",
    "metadata": {},
+   "outputs": [],
    "source": [
-    "## 0. Dependencies"
+    "%pip install -q -U \"ray[tune]\" optuna requests pandas"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "wyr2nsoj7",
+   "id": "setup",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Install dependencies (run once)\n",
-    "%pip install -U \"ray[tune]\"\n",
-    "%pip install optuna\n",
-    "%pip install requests pandas"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "imports-header",
-   "metadata": {},
-   "source": [
-    "## 1. Imports & Setup"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "imports",
-   "metadata": {},
-   "outputs": [],
-   "source": "import os\nfrom datetime import datetime\n\nimport requests\nimport pandas as pd\n\nimport ray\nfrom ray import tune, train\nfrom ray.tune.search.optuna import OptunaSearch"
-  },
-  {
-   "cell_type": "markdown",
-   "id": "config-header",
-   "metadata": {},
-   "source": [
-    "## 2. API Configuration"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "config",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# PaddleOCR REST API endpoints - 2 workers for parallel trials\n",
-    "# Start workers with: cd src/paddle_ocr && docker compose -f docker-compose.workers.yml up\n",
-    "WORKER_PORTS = [8001, 8002]\n",
-    "WORKER_URLS = [f\"http://localhost:{port}\" for port in WORKER_PORTS]\n",
+    "from raytune_ocr import (\n",
+    "    check_workers, create_trainable, run_tuner, analyze_results, correlation_analysis,\n",
+    "    paddle_ocr_payload, PADDLE_OCR_SEARCH_SPACE, PADDLE_OCR_CONFIG_KEYS,\n",
+    ")\n",
     "\n",
-    "# Output folder for results\n",
-    "OUTPUT_FOLDER = \"results\"\n",
-    "os.makedirs(OUTPUT_FOLDER, exist_ok=True)\n",
+    "# Worker ports\n",
+    "PORTS = [8001, 8002]\n",
     "\n",
-    "# Number of concurrent trials = number of workers\n",
-    "NUM_WORKERS = len(WORKER_URLS)"
+    "# Check workers are running\n",
+    "healthy = check_workers(PORTS, \"PaddleOCR\")"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "health-check",
+   "id": "tune",
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Verify all workers are running\n",
-    "healthy_workers = []\n",
-    "for url in WORKER_URLS:\n",
-    "    try:\n",
-    "        health = requests.get(f\"{url}/health\", timeout=10).json()\n",
-    "        if health['status'] == 'ok' and health['model_loaded']:\n",
-    "            healthy_workers.append(url)\n",
-    "            print(f\"✓ {url}: {health['status']} (GPU: {health.get('gpu_name', 'N/A')})\")\n",
-    "        else:\n",
-    "            print(f\"✗ {url}: not ready yet\")\n",
-    "    except requests.exceptions.ConnectionError:\n",
-    "        print(f\"✗ {url}: not reachable\")\n",
+    "# Create trainable and run tuning\n",
+    "trainable = create_trainable(PORTS, paddle_ocr_payload)\n",
     "\n",
-    "if not healthy_workers:\n",
-    "    raise RuntimeError(\n",
-    "        \"No healthy workers found. Start them with:\\n\"\n",
-    "        \"  cd src/paddle_ocr && docker compose -f docker-compose.workers.yml up\"\n",
-    "    )\n",
+    "results = run_tuner(\n",
+    "    trainable=trainable,\n",
+    "    search_space=PADDLE_OCR_SEARCH_SPACE,\n",
+    "    num_samples=64,\n",
+    "    num_workers=len(healthy),\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "analysis",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Analyze results\n",
+    "df = analyze_results(\n",
+    "    results,\n",
+    "    prefix=\"raytune_paddle\",\n",
+    "    config_keys=PADDLE_OCR_CONFIG_KEYS,\n",
+    ")\n",
     "\n",
-    "print(f\"\\n{len(healthy_workers)}/{len(WORKER_URLS)} workers ready for parallel tuning\")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "search-space-header",
-   "metadata": {},
-   "source": [
-    "## 3. Search Space"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "search-space",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "search_space = {\n",
-    "    # Whether to use document image orientation classification\n",
-    "    \"use_doc_orientation_classify\": tune.choice([True, False]),\n",
-    "    # Whether to use text image unwarping\n",
-    "    \"use_doc_unwarping\": tune.choice([True, False]),\n",
-    "    # Whether to use text line orientation classification\n",
-    "    \"textline_orientation\": tune.choice([True, False]),\n",
-    "    # Detection pixel threshold (pixels > threshold are considered text)\n",
-    "    \"text_det_thresh\": tune.uniform(0.0, 0.7),\n",
-    "    # Detection box threshold (average score within border)\n",
-    "    \"text_det_box_thresh\": tune.uniform(0.0, 0.7),\n",
-    "    # Text detection expansion coefficient\n",
-    "    \"text_det_unclip_ratio\": tune.choice([0.0]),\n",
-    "    # Text recognition threshold (filter low confidence results)\n",
-    "    \"text_rec_score_thresh\": tune.uniform(0.0, 0.7),\n",
-    "}"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "trainable-header",
-   "metadata": {},
-   "source": [
-    "## 4. Trainable Function"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "trainable",
-   "metadata": {},
-   "outputs": [],
-   "source": "def trainable_paddle_ocr(config):\n    \"\"\"Call PaddleOCR REST API with the given hyperparameter config.\"\"\"\n    import random\n    import requests\n    from ray import train\n\n    # Worker URLs - random selection (load balances with 2 workers, 2 concurrent trials)\n    WORKER_PORTS = [8001, 8002]\n    api_url = f\"http://localhost:{random.choice(WORKER_PORTS)}\"\n\n    payload = {\n        \"pdf_folder\": \"/app/dataset\",\n        \"use_doc_orientation_classify\": config.get(\"use_doc_orientation_classify\", False),\n        \"use_doc_unwarping\": config.get(\"use_doc_unwarping\", False),\n        \"textline_orientation\": config.get(\"textline_orientation\", True),\n        \"text_det_thresh\": config.get(\"text_det_thresh\", 0.0),\n        \"text_det_box_thresh\": config.get(\"text_det_box_thresh\", 0.0),\n        \"text_det_unclip_ratio\": config.get(\"text_det_unclip_ratio\", 1.5),\n        \"text_rec_score_thresh\": config.get(\"text_rec_score_thresh\", 0.0),\n        \"start_page\": 5,\n        \"end_page\": 10,\n    }\n\n    try:\n        response = requests.post(f\"{api_url}/evaluate\", json=payload, timeout=None)\n        response.raise_for_status()\n        metrics = response.json()\n        metrics[\"worker\"] = api_url\n        train.report(metrics)\n    except Exception as e:\n        train.report({\n            \"CER\": 1.0,\n            \"WER\": 1.0,\n            \"TIME\": 0.0,\n            \"PAGES\": 0,\n            \"TIME_PER_PAGE\": 0,\n            \"worker\": api_url,\n            \"ERROR\": str(e)[:500]\n        })"
-  },
-  {
-   "cell_type": "markdown",
-   "id": "tuner-header",
-   "metadata": {},
-   "source": [
-    "## 5. Run Tuner"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "ray-init",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "ray.init(ignore_reinit_error=True)\n",
-    "print(f\"Ray Tune ready (version: {ray.__version__})\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "tuner",
-   "metadata": {},
-   "outputs": [],
-   "source": "tuner = tune.Tuner(\n    trainable_paddle_ocr,\n    tune_config=tune.TuneConfig(\n        metric=\"CER\",\n        mode=\"min\",\n        search_alg=OptunaSearch(),\n        num_samples=64,\n        max_concurrent_trials=NUM_WORKERS,  # Run trials in parallel across workers\n    ),\n    param_space=search_space,\n)\n\nresults = tuner.fit()"
-  },
-  {
-   "cell_type": "markdown",
-   "id": "analysis-header",
-   "metadata": {},
-   "source": [
-    "## 6. Results Analysis"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "results-df",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "df = results.get_dataframe()\n",
     "df.describe()"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "save-results",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Save results to CSV\n",
-    "timestamp = datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
-    "filename = f\"raytune_paddle_rest_results_{timestamp}.csv\"\n",
-    "filepath = os.path.join(OUTPUT_FOLDER, filename)\n",
-    "\n",
-    "df.to_csv(filepath, index=False)\n",
-    "print(f\"Results saved: {filepath}\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "best-config",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Best configuration\n",
-    "best = df.loc[df[\"CER\"].idxmin()]\n",
-    "\n",
-    "print(f\"Best CER: {best['CER']:.6f}\")\n",
-    "print(f\"Best WER: {best['WER']:.6f}\")\n",
-    "print(f\"\\nOptimal Configuration:\")\n",
-    "print(f\"  textline_orientation: {best['config/textline_orientation']}\")\n",
-    "print(f\"  use_doc_orientation_classify: {best['config/use_doc_orientation_classify']}\")\n",
-    "print(f\"  use_doc_unwarping: {best['config/use_doc_unwarping']}\")\n",
-    "print(f\"  text_det_thresh: {best['config/text_det_thresh']:.4f}\")\n",
-    "print(f\"  text_det_box_thresh: {best['config/text_det_box_thresh']:.4f}\")\n",
-    "print(f\"  text_det_unclip_ratio: {best['config/text_det_unclip_ratio']}\")\n",
-    "print(f\"  text_rec_score_thresh: {best['config/text_rec_score_thresh']:.4f}\")"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -272,42 +90,21 @@
    "outputs": [],
    "source": [
     "# Correlation analysis\n",
-    "param_cols = [\n",
-    "    \"config/text_det_thresh\",\n",
-    "    \"config/text_det_box_thresh\",\n",
-    "    \"config/text_det_unclip_ratio\",\n",
-    "    \"config/text_rec_score_thresh\",\n",
-    "]\n",
-    "\n",
-    "corr_cer = df[param_cols + [\"CER\"]].corr()[\"CER\"].sort_values(ascending=False)\n",
-    "corr_wer = df[param_cols + [\"WER\"]].corr()[\"WER\"].sort_values(ascending=False)\n",
-    "\n",
-    "print(\"Correlation with CER:\")\n",
-    "print(corr_cer)\n",
-    "print(\"\\nCorrelation with WER:\")\n",
-    "print(corr_wer)"
+    "correlation_analysis(df, PADDLE_OCR_CONFIG_KEYS)"
    ]
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": ".venv",
+   "display_name": "Python 3",
    "language": "python",
    "name": "python3"
   },
   "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
    "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.12.3"
+   "version": "3.10.0"
   }
  },
  "nbformat": 4,
 "nbformat_minor": 5
 }
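Net effect of this rewrite: the notebook drops its inline health check, search space, trainable, tuner, CSV export, and correlation cells in favor of the shared module. The equivalent end-to-end flow, as the slimmed notebook now runs it:

```python
from raytune_ocr import (
    check_workers, create_trainable, run_tuner, analyze_results,
    correlation_analysis, paddle_ocr_payload,
    PADDLE_OCR_SEARCH_SPACE, PADDLE_OCR_CONFIG_KEYS,
)

PORTS = [8001, 8002]                                 # one concurrent trial per worker
healthy = check_workers(PORTS, "PaddleOCR")          # raises if none respond
trainable = create_trainable(PORTS, paddle_ocr_payload)
results = run_tuner(trainable, PADDLE_OCR_SEARCH_SPACE,
                    num_samples=64, num_workers=len(healthy))
df = analyze_results(results, prefix="raytune_paddle",
                     config_keys=PADDLE_OCR_CONFIG_KEYS)
correlation_analysis(df, PADDLE_OCR_CONFIG_KEYS)
```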
src/raytune_ocr.py (new file, 333 lines)
@@ -0,0 +1,333 @@
+# raytune_ocr.py
+# Shared Ray Tune utilities for OCR hyperparameter optimization
+#
+# Usage:
+#   from raytune_ocr import check_workers, create_trainable, run_tuner, analyze_results
+
+import os
+from datetime import datetime
+from typing import List, Dict, Any, Callable
+
+import requests
+import pandas as pd
+
+import ray
+from ray import tune, train
+from ray.tune.search.optuna import OptunaSearch
+
+
+def check_workers(ports: List[int], service_name: str = "OCR") -> List[str]:
+    """
+    Verify workers are running and return healthy URLs.
+
+    Args:
+        ports: List of port numbers to check
+        service_name: Name for error messages
+
+    Returns:
+        List of healthy worker URLs
+
+    Raises:
+        RuntimeError if no healthy workers found
+    """
+    worker_urls = [f"http://localhost:{port}" for port in ports]
+    healthy_workers = []
+
+    for url in worker_urls:
+        try:
+            health = requests.get(f"{url}/health", timeout=10).json()
+            if health.get('status') == 'ok' and health.get('model_loaded'):
+                healthy_workers.append(url)
+                gpu = health.get('gpu_name', 'CPU')
+                print(f"✓ {url}: {health['status']} ({gpu})")
+            else:
+                print(f"✗ {url}: not ready yet")
+        except requests.exceptions.ConnectionError:
+            print(f"✗ {url}: not reachable")
+
+    if not healthy_workers:
+        raise RuntimeError(
+            f"No healthy {service_name} workers found.\n"
+            f"Checked ports: {ports}"
+        )
+
+    print(f"\n{len(healthy_workers)}/{len(worker_urls)} workers ready")
+    return healthy_workers
+
+
+def create_trainable(ports: List[int], payload_fn: Callable[[Dict], Dict]) -> Callable:
+    """
+    Factory to create a trainable function for Ray Tune.
+
+    Args:
+        ports: List of worker ports for load balancing
+        payload_fn: Function that takes config dict and returns API payload dict
+
+    Returns:
+        Trainable function for Ray Tune
+    """
+    def trainable(config):
+        import random
+        import requests
+        from ray import train
+
+        api_url = f"http://localhost:{random.choice(ports)}"
+        payload = payload_fn(config)
+
+        try:
+            response = requests.post(f"{api_url}/evaluate", json=payload, timeout=None)
+            response.raise_for_status()
+            metrics = response.json()
+            metrics["worker"] = api_url
+            train.report(metrics)
+        except Exception as e:
+            train.report({
+                "CER": 1.0,
+                "WER": 1.0,
+                "TIME": 0.0,
+                "PAGES": 0,
+                "TIME_PER_PAGE": 0,
+                "worker": api_url,
+                "ERROR": str(e)[:500]
+            })
+
+    return trainable
+
+
+def run_tuner(
+    trainable: Callable,
+    search_space: Dict[str, Any],
+    num_samples: int = 64,
+    num_workers: int = 1,
+    metric: str = "CER",
+    mode: str = "min",
+) -> tune.ResultGrid:
+    """
+    Initialize Ray and run hyperparameter tuning.
+
+    Args:
+        trainable: Trainable function from create_trainable()
+        search_space: Dict of parameter names to tune.* search spaces
+        num_samples: Number of trials to run
+        num_workers: Max concurrent trials
+        metric: Metric to optimize
+        mode: "min" or "max"
+
+    Returns:
+        Ray Tune ResultGrid
+    """
+    ray.init(ignore_reinit_error=True, include_dashboard=False)
+    print(f"Ray Tune ready (version: {ray.__version__})")
+
+    tuner = tune.Tuner(
+        trainable,
+        tune_config=tune.TuneConfig(
+            metric=metric,
+            mode=mode,
+            search_alg=OptunaSearch(),
+            num_samples=num_samples,
+            max_concurrent_trials=num_workers,
+        ),
+        param_space=search_space,
+    )
+
+    return tuner.fit()
+
+
+def analyze_results(
+    results: tune.ResultGrid,
+    output_folder: str = "results",
+    prefix: str = "raytune",
+    config_keys: List[str] = None,
+) -> pd.DataFrame:
+    """
+    Analyze and save tuning results.
+
+    Args:
+        results: Ray Tune ResultGrid
+        output_folder: Directory to save CSV
+        prefix: Filename prefix
+        config_keys: List of config keys to show in best result (without 'config/' prefix)
+
+    Returns:
+        Results DataFrame
+    """
+    os.makedirs(output_folder, exist_ok=True)
+    df = results.get_dataframe()
+
+    # Save to CSV
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    filename = f"{prefix}_results_{timestamp}.csv"
+    filepath = os.path.join(output_folder, filename)
+    df.to_csv(filepath, index=False)
+    print(f"Results saved: {filepath}")
+
+    # Best configuration
+    best = df.loc[df["CER"].idxmin()]
+    print(f"\nBest CER: {best['CER']:.6f}")
+    print(f"Best WER: {best['WER']:.6f}")
+
+    if config_keys:
+        print(f"\nOptimal Configuration:")
+        for key in config_keys:
+            col = f"config/{key}"
+            if col in best:
+                val = best[col]
+                if isinstance(val, float):
+                    print(f"  {key}: {val:.4f}")
+                else:
+                    print(f"  {key}: {val}")
+
+    return df
+
+
+def correlation_analysis(df: pd.DataFrame, param_keys: List[str]) -> None:
+    """
+    Print correlation of numeric parameters with CER/WER.
+
+    Args:
+        df: Results DataFrame
+        param_keys: List of config keys (without 'config/' prefix)
+    """
+    param_cols = [f"config/{k}" for k in param_keys if f"config/{k}" in df.columns]
+    numeric_cols = [c for c in param_cols if df[c].dtype in ['float64', 'int64']]
+
+    if not numeric_cols:
+        print("No numeric parameters for correlation analysis")
+        return
+
+    corr_cer = df[numeric_cols + ["CER"]].corr()["CER"].sort_values(ascending=False)
+    corr_wer = df[numeric_cols + ["WER"]].corr()["WER"].sort_values(ascending=False)
+
+    print("Correlation with CER:")
+    print(corr_cer)
+    print("\nCorrelation with WER:")
+    print(corr_wer)
+
+
+# =============================================================================
+# OCR-specific payload functions
+# =============================================================================
+
+def paddle_ocr_payload(config: Dict, start_page: int = 5, end_page: int = 10, save_output: bool = False) -> Dict:
+    """Create payload for PaddleOCR API."""
+    return {
+        "pdf_folder": "/app/dataset",
+        "use_doc_orientation_classify": config.get("use_doc_orientation_classify", False),
+        "use_doc_unwarping": config.get("use_doc_unwarping", False),
+        "textline_orientation": config.get("textline_orientation", True),
+        "text_det_thresh": config.get("text_det_thresh", 0.0),
+        "text_det_box_thresh": config.get("text_det_box_thresh", 0.0),
+        "text_det_unclip_ratio": config.get("text_det_unclip_ratio", 1.5),
+        "text_rec_score_thresh": config.get("text_rec_score_thresh", 0.0),
+        "start_page": start_page,
+        "end_page": end_page,
+        "save_output": save_output,
+    }
+
+
+def doctr_payload(config: Dict, start_page: int = 5, end_page: int = 10, save_output: bool = False) -> Dict:
+    """Create payload for DocTR API."""
+    return {
+        "pdf_folder": "/app/dataset",
+        "assume_straight_pages": config.get("assume_straight_pages", True),
+        "straighten_pages": config.get("straighten_pages", False),
+        "preserve_aspect_ratio": config.get("preserve_aspect_ratio", True),
+        "symmetric_pad": config.get("symmetric_pad", True),
+        "disable_page_orientation": config.get("disable_page_orientation", False),
+        "disable_crop_orientation": config.get("disable_crop_orientation", False),
+        "resolve_lines": config.get("resolve_lines", True),
+        "resolve_blocks": config.get("resolve_blocks", False),
+        "paragraph_break": config.get("paragraph_break", 0.035),
+        "start_page": start_page,
+        "end_page": end_page,
+        "save_output": save_output,
+    }
+
+
+def easyocr_payload(config: Dict, start_page: int = 5, end_page: int = 10, save_output: bool = False) -> Dict:
+    """Create payload for EasyOCR API."""
+    return {
+        "pdf_folder": "/app/dataset",
+        "text_threshold": config.get("text_threshold", 0.7),
+        "low_text": config.get("low_text", 0.4),
+        "link_threshold": config.get("link_threshold", 0.4),
+        "slope_ths": config.get("slope_ths", 0.1),
+        "ycenter_ths": config.get("ycenter_ths", 0.5),
+        "height_ths": config.get("height_ths", 0.5),
+        "width_ths": config.get("width_ths", 0.5),
+        "add_margin": config.get("add_margin", 0.1),
+        "contrast_ths": config.get("contrast_ths", 0.1),
+        "adjust_contrast": config.get("adjust_contrast", 0.5),
+        "decoder": config.get("decoder", "greedy"),
+        "beamWidth": config.get("beamWidth", 5),
+        "min_size": config.get("min_size", 10),
+        "start_page": start_page,
+        "end_page": end_page,
+        "save_output": save_output,
+    }
+
+
+# =============================================================================
+# Search spaces
+# =============================================================================
+
+PADDLE_OCR_SEARCH_SPACE = {
+    "use_doc_orientation_classify": tune.choice([True, False]),
+    "use_doc_unwarping": tune.choice([True, False]),
+    "textline_orientation": tune.choice([True, False]),
+    "text_det_thresh": tune.uniform(0.0, 0.7),
+    "text_det_box_thresh": tune.uniform(0.0, 0.7),
+    "text_det_unclip_ratio": tune.choice([0.0]),
+    "text_rec_score_thresh": tune.uniform(0.0, 0.7),
+}
+
+DOCTR_SEARCH_SPACE = {
+    "assume_straight_pages": tune.choice([True, False]),
+    "straighten_pages": tune.choice([True, False]),
+    "preserve_aspect_ratio": tune.choice([True, False]),
+    "symmetric_pad": tune.choice([True, False]),
+    "disable_page_orientation": tune.choice([True, False]),
+    "disable_crop_orientation": tune.choice([True, False]),
+    "resolve_lines": tune.choice([True, False]),
+    "resolve_blocks": tune.choice([True, False]),
+    "paragraph_break": tune.uniform(0.01, 0.1),
+}
+
+EASYOCR_SEARCH_SPACE = {
+    "text_threshold": tune.uniform(0.3, 0.9),
+    "low_text": tune.uniform(0.2, 0.6),
+    "link_threshold": tune.uniform(0.2, 0.6),
+    "slope_ths": tune.uniform(0.0, 0.3),
+    "ycenter_ths": tune.uniform(0.3, 1.0),
+    "height_ths": tune.uniform(0.3, 1.0),
+    "width_ths": tune.uniform(0.3, 1.0),
+    "add_margin": tune.uniform(0.0, 0.3),
+    "contrast_ths": tune.uniform(0.05, 0.3),
+    "adjust_contrast": tune.uniform(0.3, 0.8),
+    "decoder": tune.choice(["greedy", "beamsearch"]),
+    "beamWidth": tune.choice([3, 5, 7, 10]),
+    "min_size": tune.choice([5, 10, 15, 20]),
+}
+
+
+# =============================================================================
+# Config keys for results display
+# =============================================================================
+
+PADDLE_OCR_CONFIG_KEYS = [
+    "use_doc_orientation_classify", "use_doc_unwarping", "textline_orientation",
+    "text_det_thresh", "text_det_box_thresh", "text_det_unclip_ratio", "text_rec_score_thresh",
+]
+
+DOCTR_CONFIG_KEYS = [
+    "assume_straight_pages", "straighten_pages", "preserve_aspect_ratio", "symmetric_pad",
+    "disable_page_orientation", "disable_crop_orientation", "resolve_lines", "resolve_blocks",
+    "paragraph_break",
+]
+
+EASYOCR_CONFIG_KEYS = [
+    "text_threshold", "low_text", "link_threshold", "slope_ths", "ycenter_ths",
+    "height_ths", "width_ths", "add_margin", "contrast_ths", "adjust_contrast",
+    "decoder", "beamWidth", "min_size",
+]
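The module's seam for new engines is the payload function: each engine only needs a payload builder, a search space, and a list of config keys. A hypothetical sketch of adding a fourth engine — the tesseract_payload function, its psm parameter, and port 8004 are illustrative assumptions, not part of this PR:

```python
from typing import Dict
from ray import tune
from raytune_ocr import check_workers, create_trainable, run_tuner

def tesseract_payload(config: Dict, start_page: int = 5, end_page: int = 10,
                      save_output: bool = False) -> Dict:
    """Hypothetical payload builder mirroring paddle_ocr_payload()."""
    return {
        "pdf_folder": "/app/dataset",
        "psm": config.get("psm", 3),   # assumed engine parameter
        "start_page": start_page,
        "end_page": end_page,
        "save_output": save_output,
    }

SEARCH_SPACE = {"psm": tune.choice([3, 4, 6, 11])}   # assumed search space

healthy = check_workers([8004], "Tesseract")         # assumed worker port
trainable = create_trainable([8004], tesseract_payload)
results = run_tuner(trainable, SEARCH_SPACE,
                    num_samples=16, num_workers=len(healthy))
```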