Hyper param serach results
This commit is contained in:
45
src/dataset_manager.py
Normal file
45
src/dataset_manager.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Imports
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class ImageTextDataset:
|
||||
def __init__(self, root):
|
||||
self.samples = []
|
||||
|
||||
for folder in sorted(os.listdir(root)):
|
||||
sub = os.path.join(root, folder)
|
||||
img_dir = os.path.join(sub, "img")
|
||||
txt_dir = os.path.join(sub, "txt")
|
||||
|
||||
if not (os.path.isdir(img_dir) and os.path.isdir(txt_dir)):
|
||||
continue
|
||||
|
||||
for fname in sorted(os.listdir(img_dir)):
|
||||
if not fname.lower().endswith((".png", ".jpg", ".jpeg")):
|
||||
continue
|
||||
|
||||
img_path = os.path.join(img_dir, fname)
|
||||
|
||||
# text file must have same name but .txt
|
||||
txt_name = os.path.splitext(fname)[0] + ".txt"
|
||||
txt_path = os.path.join(txt_dir, txt_name)
|
||||
|
||||
if not os.path.exists(txt_path):
|
||||
continue
|
||||
|
||||
self.samples.append((img_path, txt_path))
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path, txt_path = self.samples[idx]
|
||||
|
||||
# Load image
|
||||
image = Image.open(img_path).convert("RGB")
|
||||
|
||||
# Load text
|
||||
with open(txt_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
|
||||
return image, text
|
||||
Reference in New Issue
Block a user