# Source-viewer residue (not code): 45 lines, 1.3 KiB, Python
# Imports
|
|
import os
|
|
from PIL import Image
|
|
|
|
|
|
class ImageTextDataset:
    """Paired image/text dataset built from a two-level directory layout.

    Expected layout (one level of subfolders under *root*)::

        root/
            <folder>/
                img/<name>.png | <name>.jpg | <name>.jpeg
                txt/<name>.txt

    A sample is recorded for every regular image file whose extension is in
    *extensions* (case-insensitive) and that has a regular file with the same
    stem and a ``.txt`` suffix in the sibling ``txt`` directory. Subfolders
    missing either ``img`` or ``txt`` are skipped, as are images without a
    matching text file. Samples are ordered by folder name, then by image
    file name.

    Indexing returns ``(image, text)``: *image* is a PIL image converted to
    RGB and *text* is the full UTF-8 contents of the text file.
    """

    def __init__(self, root, *, extensions=(".png", ".jpg", ".jpeg")):
        """Scan *root* and collect ``(image_path, text_path)`` pairs.

        Args:
            root: Directory whose immediate subfolders each contain ``img``
                and ``txt`` directories.
            extensions: Image file suffix (str) or iterable of suffixes to
                accept, matched case-insensitively against the file name.
                Defaults to PNG and JPEG.

        Raises:
            OSError: if *root* (or a discovered ``img`` dir) cannot be listed.
        """
        if isinstance(extensions, str):
            # A bare string would otherwise be iterated character by character.
            extensions = (extensions,)
        # Normalise once so the per-file check is a single endswith() call.
        exts = tuple(ext.lower() for ext in extensions)

        self.samples = []

        for folder in sorted(os.listdir(root)):
            sub = os.path.join(root, folder)
            img_dir = os.path.join(sub, "img")
            txt_dir = os.path.join(sub, "txt")

            # Plain files in root and incomplete folders are ignored.
            if not (os.path.isdir(img_dir) and os.path.isdir(txt_dir)):
                continue

            for fname in sorted(os.listdir(img_dir)):
                if not fname.lower().endswith(exts):
                    continue

                img_path = os.path.join(img_dir, fname)
                # A directory named like an image cannot be opened as one.
                if not os.path.isfile(img_path):
                    continue

                # The text file must share the image's stem, with a .txt suffix.
                txt_name = os.path.splitext(fname)[0] + ".txt"
                txt_path = os.path.join(txt_dir, txt_name)

                # isfile (not exists): a directory named "<stem>.txt" would
                # otherwise be accepted here and fail later in open().
                if not os.path.isfile(txt_path):
                    continue

                self.samples.append((img_path, txt_path))

    def __len__(self):
        """Return the number of (image, text) pairs found at construction."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load and return sample *idx* as ``(image, text)``.

        Args:
            idx: Index into ``self.samples``. Plain list indexing applies
                (negative indices allowed; IndexError when out of range).

        Returns:
            Tuple of the image converted to RGB and the text file's contents
            decoded as UTF-8.
        """
        img_path, txt_path = self.samples[idx]

        # Open inside a context manager so the file handle is released
        # deterministically; convert() returns a new, fully loaded image
        # that remains valid after the source image is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        with open(txt_path, "r", encoding="utf-8") as f:
            text = f.read()

        return image, text