import os import numpy as np from PIL import Image import torch from torch.utils.data import Dataset # -------------------------------------------------------------------------------- # 定义一个数据集类 DatasetGenerator,继承自 PyTorch 的 Dataset 类 class DatasetGenerator(Dataset): # -------------------------------------------------------------------------------- # 初始化函数,传入图像目录路径、数据集文件路径和图像预处理变换 def __init__(self, pathImageDirectory, pathDatasetFile, transform): self.listImagePaths = [] # 用于存储图像路径的列表 self.listImageLabels = [] # 用于存储标签的列表 self.transform = transform # 图像的预处理方法 # ---- 打开文件,获取图像路径和标签 with open(pathDatasetFile, "r") as fileDescriptor: for line in fileDescriptor: lineItems = line.strip().split() imagePath = os.path.join(pathImageDirectory, lineItems[0]) # 获取图像文件的完整路径 imageLabel = lineItems[1:] # 获取对应的标签(位于行的第二部分) # 将标签转换为整数列表,并确保每个标签是整数(有可能是浮点数的情况需要处理) imageLabel = [int(float(i)) for i in imageLabel] # 如果标签数组中至少有一个值为1(即图片至少有一个分类标签) if np.array(imageLabel).sum() >= 1: self.listImagePaths.append(imagePath) # 将图像路径加入列表 self.listImageLabels.append(imageLabel) # 将图像标签加入列表 # -------------------------------------------------------------------------------- # 获取数据集中特定索引的图像及其标签 def __getitem__(self, index): # 根据索引获取图像路径 imagePath = self.listImagePaths[index] # 打开图像文件,并将图像转换为 RGB 模式 imageData = Image.open(imagePath).convert('RGB') # 将对应的标签转换为 PyTorch 的 FloatTensor(浮点数张量) imageLabel = torch.FloatTensor(self.listImageLabels[index]) # 如果有图像预处理操作,应用预处理 if self.transform is not None: imageData = self.transform(imageData) # 返回图像数据和其对应的标签 return imageData, imageLabel # -------------------------------------------------------------------------------- # 返回数据集的样本数量 def __len__(self): return len(self.listImagePaths)