import os import time import torch.multiprocessing as mp import matplotlib.pyplot as plt import numpy as np from ChexnetTrainer import ChexnetTrainer from DatasetGenerator import DatasetGenerator from visual import HeatmapGenerator # 确保 images 目录存在 if not os.path.exists('images'): os.makedirs('images') # 主函数,负责启动不同的功能(训练、测试或运行演示) def main(): # runDemo() # 运行演示模式 # runTest() # 测试模式(注释掉,可以通过解除注释来运行) runTrain() # 训练模式(注释掉,可以通过解除注释来运行) # -------------------------------------------------------------------------------- # 训练函数,定义训练所需的参数并启动模型训练 def runTrain(): DENSENET121 = 'DENSE-NET-121' # 定义DenseNet121模型名称 DENSENET169 = 'DENSE-NET-169' # 定义DenseNet169模型名称 DENSENET201 = 'DENSE-NET-201' # 定义DenseNet201模型名称 Resnet50 = 'RESNET-50' # 定义Resnet50模型名称 # 获取当前的时间戳,作为训练过程的标记 timestampTime = time.strftime("%H%M%S") timestampDate = time.strftime("%d%m%Y") timestampLaunch = timestampDate + '-' + timestampTime print("Launching " + timestampLaunch) # 图像数据所在的路径 pathDirData = './chest xray14' # 训练、验证和测试数据集文件路径 # 每个文件中包含图像路径及其对应的标签 pathFileTrain = './dataset/train_2.txt' pathFileVal = './dataset/valid_2.txt' pathFileTest = './dataset/test_2.txt' # 神经网络参数:模型架构、是否加载预训练模型、分类的类别数量 nnArchitecture = DENSENET121 nnIsTrained = True # 使用预训练的权重 nnClassCount = 14 # 数据集包含14个分类 # 训练参数:批量大小和最大迭代次数(epochs) trBatchSize = 2 trMaxEpoch = 2 # 图像预处理相关参数:图像缩放的大小和裁剪后的大小 imgtransResize = 256 imgtransCrop = 224 # 保存模型的路径,包含时间戳 pathModel = 'm-' + timestampLaunch + '.pth.tar' print('Training NN architecture = ', nnArchitecture) ChexnetTrainer.train(pathDirData, pathFileTrain, pathFileVal, nnArchitecture, nnIsTrained, nnClassCount, trBatchSize, trMaxEpoch, imgtransResize, imgtransCrop, timestampLaunch, None) print('Testing the trained model') pathRfModel = os.path.join('images', 'random_forest_model.pkl') # 更新路径 labels, rf_preds = ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, pathRfModel, nnArchitecture, nnClassCount, nnIsTrained, trBatchSize, imgtransResize, imgtransCrop, timestampLaunch) # 生成并保存热力图 # 选择一些测试集中的样本来生成热力图 transformSequence = ChexnetTrainer._get_transform(imgtransCrop) datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest, transform=transformSequence, model=None) # 确保测试集中有足够的样本 num_samples = min(8, len(datasetTest)) sample_indices = np.random.choice(len(datasetTest), size=num_samples, replace=False) # 随机选择8个样本 # 定义分类名称 CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'] # 创建一个图形窗口 plt.figure(figsize=(20, 10)) n_cols = 4 n_rows = 2 for idx, sample_idx in enumerate(sample_indices): image_path = datasetTest.listImagePaths[sample_idx] true_labels = labels[sample_idx] pred_labels = rf_preds[sample_idx] # 加载图像 img = plt.imread(image_path) # 创建子图 ax = plt.subplot(n_rows, n_cols, idx + 1) ax.imshow(img) ax.axis('off') # 获取真实标签和预测标签的名称 true_label_names = [CLASS_NAMES[i] for i in range(len(true_labels)) if true_labels[i] == 1] pred_label_names = [CLASS_NAMES[i] for i in range(len(pred_labels)) if pred_labels[i] == 1] # 设置标题 title = f"Predicted: {', '.join(pred_label_names)}\nTrue: {', '.join(true_label_names)}" ax.set_title(title, fontsize=10) plt.tight_layout() output_plot_path = os.path.join('images', 'test_predictions.png') plt.savefig(output_plot_path) plt.show() print(f"预测结果图已保存到 {output_plot_path}") # 生成热力图(可选) for idx in sample_indices: image_path = datasetTest.listImagePaths[idx] output_heatmap_path = os.path.join('images', f'heatmap_test_{idx}.png') h = HeatmapGenerator(pathModel, nnArchitecture, nnClassCount, imgtransCrop, transformSequence) h.generate(image_path, output_heatmap_path, imgtransCrop) print(f"热力图已保存到 {output_heatmap_path}") # -------------------------------------------------------------------------------- # 测试函数,加载预训练模型并在测试数据集上进行测试 def runTest(): pathDirData = '/chest xray14' pathFileTest = './dataset/test.txt' nnArchitecture = 'DENSE-NET-121' nnIsTrained = True nnClassCount = 14 trBatchSize = 4 imgtransResize = 256 imgtransCrop = 224 pathModel = 'm-06102024-235412BCELoss()delete.pth.tar' timestampLaunch = '' # 获取统一的 transformSequence transformSequence = ChexnetTrainer._get_transform(imgtransCrop) pathRfModel = 'images/random_forest_model.pkl' # 确保路径正确 labels, rf_preds = ChexnetTrainer.test(pathDirData, pathFileTest, pathModel, pathRfModel, nnArchitecture, nnClassCount, nnIsTrained, trBatchSize, imgtransResize, imgtransCrop, timestampLaunch) # 生成并保存热力图 datasetTest = DatasetGenerator(pathImageDirectory=pathDirData, pathDatasetFile=pathFileTest, transform=transformSequence, model=None) # 确保测试集中有足够的样本 num_samples = min(8, len(datasetTest)) sample_indices = np.random.choice(len(datasetTest), size=num_samples, replace=False) # 随机选择8个样本 # 定义分类名称 CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'] # 创建一个图形窗口 plt.figure(figsize=(20, 10)) n_cols = 4 n_rows = 2 for idx, sample_idx in enumerate(sample_indices): image_path = datasetTest.listImagePaths[sample_idx] true_labels = labels[sample_idx] pred_labels = rf_preds[sample_idx] # 加载图像 img = plt.imread(image_path) # 创建子图 ax = plt.subplot(n_rows, n_cols, idx + 1) ax.imshow(img) ax.axis('off') # 获取真实标签和预测标签的名称 true_label_names = [CLASS_NAMES[i] for i in range(len(true_labels)) if true_labels[i] == 1] pred_label_names = [CLASS_NAMES[i] for i in range(len(pred_labels)) if pred_labels[i] == 1] # 设置标题 title = f"Predicted: {', '.join(pred_label_names)}\nTrue: {', '.join(true_label_names)}" ax.set_title(title, fontsize=10) plt.tight_layout() output_plot_path = os.path.join('images', 'test_predictions.png') plt.savefig(output_plot_path) plt.show() print(f"预测结果图已保存到 {output_plot_path}") # 生成热力图(可选) for idx in sample_indices: image_path = datasetTest.listImagePaths[idx] output_heatmap_path = os.path.join('images', f'heatmap_test_{idx}.png') h = HeatmapGenerator(pathModel, nnArchitecture, nnClassCount, imgtransCrop, transformSequence) h.generate(image_path, output_heatmap_path, imgtransCrop) print(f"热力图已保存到 {output_heatmap_path}") # -------------------------------------------------------------------------------- # 演示函数,展示模型在测试集上的推理过程 def runDemo(): # 原有代码保持不变 pass # 确保代码在主进程中运行 if __name__ == '__main__': mp.set_start_method('spawn', force=True) main() # 启动主函数