import numpy as np import pandas as pd import cv2 from PIL import Image from sklearn.model_selection import train_test_split from torch.utils.data import Dataset import os from torchvision import transforms # 定义图像预处理的函数,参数包括是否为训练模式以及自定义参数(args) def build_transform(train, args): if train: # 如果是训练模式,进行一系列数据增强和归一化处理 transform = transforms.Compose(( transforms.RandomResizedCrop(int(args.img_size / 0.875), scale=(0.8, 1.0)), # 随机裁剪图像 transforms.RandomRotation(7), # 随机旋转图像 transforms.RandomHorizontalFlip(), # 随机水平翻转 transforms.CenterCrop(args.img_size), # 中心裁剪 transforms.ToTensor(), # 转换为张量 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 使用Imagenet的均值和方差进行归一化 )) else: # 如果是验证或测试模式,只进行裁剪和归一化处理 transform = transforms.Compose(( transforms.Resize(int(args.img_size / 0.875)), # 调整大小 transforms.CenterCrop(args.img_size), # 中心裁剪 transforms.ToTensor(), # 转换为张量 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) # 归一化 )) return transform # 定义数据集类 photodatatest,继承自 PyTorch 的 Dataset 类 class ChestXray14Dataset(Dataset): def __init__(self, data_root, # 数据集路径 classes=['Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'], mode='train', # 模式:'train', 'valid', 或 'test' split='official', # 数据划分方式:'official' 或 'non-official' has_val_set=True, # 是否包含验证集 transform=None): # 图像预处理方法 super().__init__() self.data_root = data_root # 数据集根目录 self.classes = classes # 多标签类别 self.num_classes = len(self.classes) # 类别数 self.mode = mode # 当前数据集模式 # 根据split的类型选择加载方式 if split == 'official': # 使用官方划分方式加载数据 self.dataframe, self.num_patients = self.load_split_file( self.data_root, self.mode, has_val_set) else: # 使用非官方划分方式加载数据 self.dataframe, self.num_patients = self.load_split_file_non_official( self.data_root, self.mode) self.transform = transform # 图像预处理方法 self.num_samples = len(self.dataframe) # 样本数量 # 加载官方划分文件的方法,根据mode加载不同的数据 def load_split_file(self, folder, mode, has_val=True): df = pd.read_csv( os.path.join( 'F:\chexnet\chexnet-master\photodatatest', 'Data_Entry_2017.csv')) # 使用绝对路径 # 如果模式为训练或验证 if mode in ['train', 'valid']: # 加载训练和验证数据文件 file_name = os.path.join(folder, 'train_val_list.txt') with open(file_name, 'r') as f: lines = f.read().splitlines() # 读取所有图像文件名 df_train_val = df[df['Image Index'].isin(lines)] # 过滤出对应的图像 # 如果需要验证集,将患者ID拆分为训练和验证集 if has_val: patient_ids = df_train_val['Patient ID'].unique() # 获取所有患者ID train_ids, val_ids = train_test_split(patient_ids, test_size=1 - 0.7 / 0.8, random_state=0, shuffle=True) target_ids = train_ids if mode == 'train' else val_ids # 根据模式选择训练或验证集的患者ID df = df_train_val[ df_train_val['Patient ID'].isin(target_ids)] # 根据ID过滤数据 else: df = df_train_val elif mode == 'test': # 如果模式为测试,加载测试数据文件 file_name = os.path.join(folder, 'test_list.txt') with open(file_name, 'r') as f: target_files = f.read().splitlines() # 读取测试集文件名 df = df[df['Image Index'].isin(target_files)] # 过滤测试数据 else: raise NotImplementedError(f'Unidentified split: {mode}') # 未识别的模式报错 num_patients = len(df['Patient ID'].unique()) # 统计患者数 return df, num_patients # 非官方划分文件的加载方法,根据比例拆分数据集 def load_split_file_non_official(self, folder, mode): train_rt, val_rt, test_rt = 0.7, 0.1, 0.2 # 定义训练、验证和测试的比例 df = pd.read_csv( os.path.join(folder, 'Data_Entry_2017.csv')) # 加载数据标签文件 patient_ids = df['Patient ID'].unique() # 获取所有患者ID # 先划分出测试集,然后在剩余数据中划分出验证集和训练集 train_val_ids, test_ids = train_test_split(patient_ids, test_size=test_rt, random_state=0, shuffle=True) train_ids, val_ids = train_test_split(train_val_ids, test_size=val_rt / ( train_rt + val_rt), random_state=0, shuffle=True) # 根据模式选择目标ID target_ids = {'train': train_ids, 'valid': val_ids, 'test': test_ids}[ mode] df = df[df['Patient ID'].isin(target_ids)] # 根据ID过滤数据 num_patients = len(target_ids) # 统计患者数 return df, num_patients # 将疾病标签转换为多标签编码 def encode_label(self, label): encoded_label = np.zeros(self.num_classes, dtype=np.float32) # 初始化全0的标签数组 if label != 'No Finding': # 如果标签不为"No Finding" for l in label.split('|'): # 对每个疾病标签进行处理 encoded_label[self.classes.index(l)] = 1 # 将对应疾病的索引位置置1 return encoded_label # 图像预处理函数,调整图像尺寸 def pre_process(self, img): h, w = img.shape img = cv2.resize(img, dsize=(max(h, w), max(h, w))) # 将图像调整为正方形 return img # 统计每个类别的样本数量 def count_class_dist(self): class_counts = np.zeros(self.num_classes) # 初始化类别计数 for index, row in self.dataframe.iterrows(): # 遍历数据集的每一行 class_counts += self.encode_label( row['Finding Labels']) # 将标签编码加到计数器中 return self.num_samples, class_counts # 返回数据集的样本数量 def __len__(self): return self.num_samples # 获取指定索引的数据 def __getitem__(self, idx): row = self.dataframe.iloc[idx] # 获取对应行的数据 img_file, label = row['Image Index'], row[ 'Finding Labels'] # 获取图像文件名和标签 img = cv2.imread( os.path.join(self.data_root, 'images', img_file)) # 读取图像 img = Image.fromarray(img) # 转换为PIL图像 if self.transform is not None: img = self.transform(img) # 应用预处理 label = self.encode_label(label) # 编码标签 return img, label # 返回图像和标签