lock model
This commit is contained in:
@@ -72,17 +72,7 @@
|
||||
"id": "imports",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from datetime import datetime\n",
|
||||
"\n",
|
||||
"import requests\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"import ray\n",
|
||||
"from ray import tune, air\n",
|
||||
"from ray.tune.search.optuna import OptunaSearch"
|
||||
]
|
||||
"source": "import os\nfrom datetime import datetime\n\nimport requests\nimport pandas as pd\n\nimport ray\nfrom ray import tune, train\nfrom ray.tune.search.optuna import OptunaSearch"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
@@ -188,7 +178,7 @@
|
||||
"id": "trainable",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": "def trainable_paddle_ocr(config):\n \"\"\"Call PaddleOCR REST API with the given hyperparameter config.\"\"\"\n import random\n import requests\n from ray import tune\n\n # Worker URLs - random selection (load balances with 2 workers, 2 concurrent trials)\n WORKER_PORTS = [8001, 8002]\n api_url = f\"http://localhost:{random.choice(WORKER_PORTS)}\"\n\n payload = {\n \"pdf_folder\": \"/app/dataset\",\n \"use_doc_orientation_classify\": config.get(\"use_doc_orientation_classify\", False),\n \"use_doc_unwarping\": config.get(\"use_doc_unwarping\", False),\n \"textline_orientation\": config.get(\"textline_orientation\", True),\n \"text_det_thresh\": config.get(\"text_det_thresh\", 0.0),\n \"text_det_box_thresh\": config.get(\"text_det_box_thresh\", 0.0),\n \"text_det_unclip_ratio\": config.get(\"text_det_unclip_ratio\", 1.5),\n \"text_rec_score_thresh\": config.get(\"text_rec_score_thresh\", 0.0),\n \"start_page\": 5,\n \"end_page\": 10,\n }\n\n try:\n response = requests.post(f\"{api_url}/evaluate\", json=payload, timeout=None)\n response.raise_for_status()\n metrics = response.json()\n metrics[\"worker\"] = api_url\n tune.report(**metrics)\n except Exception as e:\n tune.report(\n CER=1.0,\n WER=1.0,\n TIME=0.0,\n PAGES=0,\n TIME_PER_PAGE=0,\n worker=api_url,\n ERROR=str(e)[:500]\n )"
|
||||
"source": "def trainable_paddle_ocr(config):\n \"\"\"Call PaddleOCR REST API with the given hyperparameter config.\"\"\"\n import random\n import requests\n from ray import train\n\n # Worker URLs - random selection (load balances with 2 workers, 2 concurrent trials)\n WORKER_PORTS = [8001, 8002]\n api_url = f\"http://localhost:{random.choice(WORKER_PORTS)}\"\n\n payload = {\n \"pdf_folder\": \"/app/dataset\",\n \"use_doc_orientation_classify\": config.get(\"use_doc_orientation_classify\", False),\n \"use_doc_unwarping\": config.get(\"use_doc_unwarping\", False),\n \"textline_orientation\": config.get(\"textline_orientation\", True),\n \"text_det_thresh\": config.get(\"text_det_thresh\", 0.0),\n \"text_det_box_thresh\": config.get(\"text_det_box_thresh\", 0.0),\n \"text_det_unclip_ratio\": config.get(\"text_det_unclip_ratio\", 1.5),\n \"text_rec_score_thresh\": config.get(\"text_rec_score_thresh\", 0.0),\n \"start_page\": 5,\n \"end_page\": 10,\n }\n\n try:\n response = requests.post(f\"{api_url}/evaluate\", json=payload, timeout=None)\n response.raise_for_status()\n metrics = response.json()\n metrics[\"worker\"] = api_url\n train.report(metrics)\n except Exception as e:\n train.report({\n \"CER\": 1.0,\n \"WER\": 1.0,\n \"TIME\": 0.0,\n \"PAGES\": 0,\n \"TIME_PER_PAGE\": 0,\n \"worker\": api_url,\n \"ERROR\": str(e)[:500]\n })"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
@@ -215,22 +205,7 @@
|
||||
"id": "tuner",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tuner = tune.Tuner(\n",
|
||||
" trainable_paddle_ocr,\n",
|
||||
" tune_config=tune.TuneConfig(\n",
|
||||
" metric=\"CER\",\n",
|
||||
" mode=\"min\",\n",
|
||||
" search_alg=OptunaSearch(),\n",
|
||||
" num_samples=64,\n",
|
||||
" max_concurrent_trials=NUM_WORKERS, # Run trials in parallel across workers\n",
|
||||
" ),\n",
|
||||
" run_config=air.RunConfig(verbose=2, log_to_file=False),\n",
|
||||
" param_space=search_space,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"results = tuner.fit()"
|
||||
]
|
||||
"source": "tuner = tune.Tuner(\n trainable_paddle_ocr,\n tune_config=tune.TuneConfig(\n metric=\"CER\",\n mode=\"min\",\n search_alg=OptunaSearch(),\n num_samples=64,\n max_concurrent_trials=NUM_WORKERS, # Run trials in parallel across workers\n ),\n param_space=search_space,\n)\n\nresults = tuner.fit()"
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
|
||||
Reference in New Issue
Block a user