# MastersThesis/paddle_ocr_tuning.py

# Imports
import argparse, json, os, sys, time
from typing import List

import numpy as np
from PIL import Image

try:
    import fitz  # PyMuPDF (optional; availability is checked before use)
except ImportError:
    fitz = None

from paddleocr import PaddleOCR
import re
from jiwer import wer, cer


def export_config(paddleocr_model):
    yaml_path = "paddleocr_pipeline_dump.yaml"
    paddleocr_model.export_paddlex_config_to_yaml(yaml_path)
    print("Exported:", yaml_path)


def pdf_to_images(pdf_path: str, dpi: int = 300, pages: List[int] = None) -> List[Image.Image]:
    """
    Render a PDF into a list of PIL Images using PyMuPDF.
    'pages' is 1-based (e.g., range(1, 10) -> pages 1-9).
    """
    images = []
    if fitz is not None:
        doc = fitz.open(pdf_path)
        total_pages = len(doc)
        # Adjust page indices (PyMuPDF uses 0-based indexing)
        if pages is None:
            page_indices = list(range(total_pages))
        else:
            # Filter out invalid pages and convert to 0-based
            page_indices = [p - 1 for p in pages if 1 <= p <= total_pages]
        for i in page_indices:
            page = doc.load_page(i)
            mat = fitz.Matrix(dpi / 72.0, dpi / 72.0)  # 72 dpi is the PDF default resolution
            pix = page.get_pixmap(matrix=mat, alpha=False)
            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            images.append(img)
        doc.close()
    else:
        raise RuntimeError("Install PyMuPDF to convert PDFs.")
    return images
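
# Usage sketch (hypothetical file name):
#   pdf_to_images("document.pdf", dpi=300, pages=range(1, 3))  # renders pages 1-2 at ~300 dpi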


def pdf_extract_text(pdf_path, page_num, line_tolerance=15) -> str:
    """
    Extracts text from a specific PDF page in proper reading order.
    Inserts a blank line when consecutive blocks are vertically separated
    by more than line_tolerance.
    Removes bullet-like characters (•, ▪, ◦, etc.).
    """
    doc = fitz.open(pdf_path)
    if page_num < 1 or page_num > len(doc):
        doc.close()
        return ""
    page = doc[page_num - 1]
    blocks = page.get_text("blocks")  # (x0, y0, x1, y1, text, block_no, block_type)
    # Sort blocks: top-to-bottom, left-to-right
    blocks_sorted = sorted(blocks, key=lambda b: (b[1], b[0]))
    text_lines = []
    last_y = None
    for b in blocks_sorted:
        y0 = b[1]
        text_block = b[4].strip()
        # Remove bullet-like characters
        text_block = re.sub(r"[•▪◦●❖▶■]", "", text_block)
        # If the vertical gap to the previous block is large, add a blank line
        if last_y is not None and abs(y0 - last_y) > line_tolerance:
            text_lines.append("")  # blank line for spacing
        text_lines.append(text_block.strip())
        last_y = y0
    # Join all lines with real newlines
    text = "\n".join(text_lines)
    # Normalize spaces
    text = re.sub(r"\s*\n\s*", "\n", text).strip()  # remove spaces around newlines
    text = re.sub(r" +", " ", text).strip()         # collapse multiple spaces to one
    text = re.sub(r"\n{3,}", "\n\n", text).strip()  # avoid runs of blank lines
    doc.close()
    return text
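
# The text extracted by pdf_extract_text() serves as the reference ("ground truth")
# that main() compares against the OCR prediction via evaluate_text().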


def evaluate_text(reference, prediction):
    return {'WER': wer(reference, prediction), 'CER': cer(reference, prediction)}
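
# Illustrative example: evaluate_text("hola mundo", "hola mundos")
# -> WER 0.5 (1 of 2 words substituted), CER 0.1 (1 edit over 10 reference characters)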


def _normalize_box_xyxy(box):
    """
    Accepts:
      - [[x,y],[x,y],[x,y],[x,y]]        (quad)
      - [x0, y0, x1, y1]                 (flat)
      - [x0, y0, x1, y1, x2, y2, x3, y3] (flat quad)
    Returns (x0, y0, x1, y1)
    """
    # NumPy arrays (e.g., rows of rec_boxes) -> plain lists for uniform handling
    if isinstance(box, np.ndarray):
        box = box.tolist()
    # Quad as list of points?
    if isinstance(box, (list, tuple)) and box and isinstance(box[0], (list, tuple)):
        xs = [p[0] for p in box]
        ys = [p[1] for p in box]
        return min(xs), min(ys), max(xs), max(ys)
    # Flat list
    if isinstance(box, (list, tuple)):
        if len(box) == 4:
            x0, y0, x1, y1 = box
            # ensure order
            return min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)
        if len(box) == 8:
            xs = box[0::2]
            ys = box[1::2]
            return min(xs), min(ys), max(xs), max(ys)
    # Fallback
    raise ValueError(f"Unrecognized box format: {box!r}")
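
# Sanity checks (illustrative values only):
#   _normalize_box_xyxy([[10, 20], [110, 20], [110, 45], [10, 45]])  -> (10, 20, 110, 45)
#   _normalize_box_xyxy([110, 45, 10, 20])                           -> (10, 20, 110, 45)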


def assemble_from_paddle_result(paddleocr_predict, min_score=0.0, line_tol_factor=0.6):
    """
    Robust line grouping for PaddleOCR outputs:
      - normalizes boxes to (x0, y0, x1, y1)
      - adaptive line tolerance based on median box height
      - optional confidence filter
      - inserts '\n' between lines and preserves left→right order
    """
    result = paddleocr_predict
    boxes_all = []  # (x0, y0, x1, y1, y_mid, text, score)
    for item in result:
        res = item.json.get("res", {})
        # Be defensive: fields may be missing or None (and may arrive as NumPy arrays)
        boxes = res.get("rec_boxes", [])
        texts = res.get("rec_texts", [])
        scores = res.get("rec_scores", None)
        if boxes is None:
            boxes = []
        if texts is None:
            texts = []
        for i, (box, text) in enumerate(zip(boxes, texts)):
            try:
                x0, y0, x1, y1 = _normalize_box_xyxy(box)
            except Exception:
                # Skip weird boxes gracefully
                continue
            y_mid = 0.5 * (y0 + y1)
            score = float(scores[i]) if (scores is not None and i < len(scores)) else 1.0
            t = re.sub(r"\s+", " ", str(text)).strip()
            if not t:
                continue
            boxes_all.append((x0, y0, x1, y1, y_mid, t, score))
    if min_score > 0:
        boxes_all = [b for b in boxes_all if b[6] >= min_score]
    if not boxes_all:
        return ""
    # Adaptive line tolerance
    heights = [b[3] - b[1] for b in boxes_all]
    median_h = float(np.median(heights)) if heights else 20.0
    line_tol = max(8.0, line_tol_factor * median_h)
    # Sort by vertical mid, then x0
    boxes_all.sort(key=lambda b: (b[4], b[0]))
    # Group into lines
    lines, cur, last_y = [], [], None
    for x0, y0, x1, y1, y_mid, text, score in boxes_all:
        if last_y is None or abs(y_mid - last_y) <= line_tol:
            cur.append((x0, text))
        else:
            cur.sort(key=lambda t: t[0])
            lines.append(" ".join(t[1] for t in cur))
            cur = [(x0, text)]
        last_y = y_mid
    if cur:
        cur.sort(key=lambda t: t[0])
        lines.append(" ".join(t[1] for t in cur))
    res = "\n".join(lines)
    res = re.sub(r"\s+\n", "\n", res).strip()
    return res
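
# Assumed shape of each `item.json["res"]` dict consumed above, inferred from the keys
# this script reads (the actual PaddleOCR result may carry additional fields):
#   {"rec_texts": ["Primera línea", ...],
#    "rec_scores": [0.98, ...],
#    "rec_boxes": [[x0, y0, x1, y1], ...]}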


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--pdf-folder", required=True)
    parser.add_argument("--dpi", type=int, default=300)
    parser.add_argument("--textline-orientation", type=lambda s: s.lower() == "true", default=True)
    parser.add_argument("--text-det-box-thresh", type=float, default=0.6)
    parser.add_argument("--text-det-unclip-ratio", type=float, default=1.5)
    parser.add_argument("--text-rec-score-thresh", type=float, default=0.0)
    parser.add_argument("--line-tolerance", type=float, default=0.6)
    parser.add_argument("--min-box-score", type=float, default=0.0)
    parser.add_argument("--pages-per-pdf", type=int, default=2)
    parser.add_argument("--lang", default="es")
    args = parser.parse_args()

    ocr = PaddleOCR(
        text_detection_model_name="PP-OCRv5_server_det",
        text_recognition_model_name="PP-OCRv5_server_rec",
        lang=args.lang,
    )

    cer_list, wer_list = [], []
    time_per_page_list = []
    t0 = time.time()
    for fname in os.listdir(args.pdf_folder):
        if not fname.lower().endswith(".pdf"):
            continue
        pdf_path = os.path.join(args.pdf_folder, fname)
        images = pdf_to_images(pdf_path, dpi=args.dpi, pages=range(1, args.pages_per_pdf + 1))
        for i, img in enumerate(images):
            ref = pdf_extract_text(pdf_path, i + 1)
            arr = np.array(img)
            tp0 = time.time()
            out = ocr.predict(
                arr,
                text_det_box_thresh=args.text_det_box_thresh,
                text_det_unclip_ratio=args.text_det_unclip_ratio,
                text_rec_score_thresh=args.text_rec_score_thresh,
                use_textline_orientation=args.textline_orientation,
            )
            pred = assemble_from_paddle_result(out, args.min_box_score, args.line_tolerance)
            time_per_page_list.append(float(time.time() - tp0))
            m = evaluate_text(ref, pred)
            cer_list.append(m["CER"])
            wer_list.append(m["WER"])

    metrics = {
        "CER": float(np.mean(cer_list) if cer_list else 1.0),
        "WER": float(np.mean(wer_list) if wer_list else 1.0),
        "TIME": float(time.time() - t0),
        "PAGES": int(len(cer_list)),
        "TIME_PER_PAGE": float(np.mean(time_per_page_list) if time_per_page_list else float(time.time() - t0)),
    }
    print(json.dumps(metrics))


if __name__ == "__main__":
    main()
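
# Example invocation (hypothetical folder path; the remaining flags keep the defaults above):
#   python paddle_ocr_tuning.py --pdf-folder ./pdfs --dpi 300 --pages-per-pdf 2 --lang es
# Prints a single JSON object:
#   {"CER": ..., "WER": ..., "TIME": ..., "PAGES": ..., "TIME_PER_PAGE": ...}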