import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class DatasetGenerator(Dataset):
    """Multi-label image dataset driven by a whitespace-separated listing file.

    Each non-blank line of the listing file has the form
    ``<relative image path> <label> <label> ...``. A sample is kept only when
    its image file exists on disk and the sum of its labels is at least 1.
    """

    def __init__(self, pathImageDirectory, pathDatasetFile, transform, model=None, nnClassCount=14):
        """Scan the listing file and collect the usable samples.

        Args:
            pathImageDirectory: Directory that the relative image paths in the
                listing file are joined onto.
            pathDatasetFile: Path of the listing file described on the class.
            transform: Callable applied to each RGB ``PIL.Image`` in
                ``__getitem__``.
            model: Optional module; when given it is moved to the selected
                device. It is not otherwise used by this class.
            nnClassCount: Accepted for interface compatibility; not used by
                this class.

        Raises:
            ValueError: If no line of the listing file yields a usable sample.
        """
        super().__init__()
        self.listImagePaths = []
        self.listImageLabels = []
        self.transform = transform
        self.model = model
        # Read by get_features_and_labels(). Nothing in this class fills them,
        # so they start empty and that method reports the condition cleanly
        # instead of failing with AttributeError.
        self.features = []
        self.labels = []

        # Prefer the GPU when one is available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Explicit None check: container modules define __len__, so an empty
        # one would be falsy under a plain truthiness test.
        if self.model is not None:
            self.model.to(self.device)

        # Walk the listing file and keep only entries that point at real files.
        with open(pathDatasetFile, "r") as fileDescriptor:
            for line in fileDescriptor:
                lineItems = line.split()
                if not lineItems:
                    # Blank or whitespace-only line: nothing to parse.
                    continue

                # Join and normalise so the path is well-formed on any platform.
                imagePath = os.path.normpath(os.path.join(pathImageDirectory, lineItems[0]))

                if not os.path.isfile(imagePath):
                    print(f"Warning: Path {imagePath} does not exist or is not a file, skipping this file.")
                    continue

                imageLabel = [int(float(i)) for i in lineItems[1:]]
                # Keep the sample only when its labels sum to at least 1.
                if sum(imageLabel) >= 1:
                    self.listImagePaths.append(imagePath)
                    self.listImageLabels.append(imageLabel)

        if not self.listImagePaths:
            raise ValueError("No valid samples found. Please check your dataset file and image paths.")

    def __getitem__(self, index):
        """Return ``(transformed image, float32 label tensor)`` for ``index``."""
        # The context manager releases the file handle that Image.open keeps
        # for lazy loading; convert() returns an independent, fully loaded image.
        with Image.open(self.listImagePaths[index]) as rawImage:
            imageData = rawImage.convert('RGB')
        imageData = self.transform(imageData)
        return imageData, torch.tensor(self.listImageLabels[index], dtype=torch.float32)

    def __len__(self):
        """Return the number of samples that passed filtering."""
        return len(self.listImagePaths)

    def get_features_and_labels(self):
        """Return ``(self.features, self.labels)``.

        Raises:
            ValueError: If ``self.features`` is ``None`` or empty.
        """
        # len() rather than truthiness so a tensor or array assigned by a
        # caller does not trigger an "ambiguous truth value" error.
        if self.features is None or len(self.features) == 0:
            raise ValueError("No features extracted. Ensure `model` is provided during initialization.")
        return self.features, self.labels