# CBAM 模块 import torch import torch.nn as nn # 通道注意力模块,用于学习输入特征在通道维度上的重要性 class ChannelAttention(nn.Module): def __init__(self, in_channels, reduction=16): super(ChannelAttention, self).__init__() # in_channels:输入的通道数,表示特征图有多少个通道 # reduction:缩减率,用于减少通道的维度,通常设置为16,表示将通道数缩小 16 倍后再扩展回来 self.avg_pool = nn.AdaptiveAvgPool2d(1) # 自适应平均池化到 (1, 1),用于生成全局通道信息 self.max_pool = nn.AdaptiveMaxPool2d(1) # 自适应最大池化到 (1, 1),与平均池化结合使用 # 全连接层使用1x1卷积替代,全连接层的作用是通过线性变换来学习不同通道的重要性 # Conv2d(in_channels, in_channels // reduction, 1, bias=False): # 将输入通道数减小到原来的 1/reduction(即 1/16),用来降低计算复杂度 self.fc = nn.Sequential( nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False), nn.ReLU(inplace=True), # 激活函数,用于增加网络的非线性特性 nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False) # 将通道数还原为原始大小 ) self.sigmoid = nn.Sigmoid() # Sigmoid 激活函数,用于将输出压缩到 (0, 1) 之间 def forward(self, x): # 使用平均池化和最大池化得到两个特征图 avg_out = self.fc(self.avg_pool(x)) # 对平均池化后的特征图应用全连接层 max_out = self.fc(self.max_pool(x)) # 对最大池化后的特征图应用全连接层 out = avg_out + max_out # 将两个特征图相加,融合两种不同池化方式的信息 return self.sigmoid(out) * x # 使用 sigmoid 将结果压缩到 (0, 1) 之间,并乘以输入特征图得到加权后的输出 # 空间注意力模块,用于学习输入特征在空间维度(H 和 W)上的重要性 class SpatialAttention(nn.Module): def __init__(self, kernel_size=7): super(SpatialAttention, self).__init__() # kernel_size:卷积核大小,通常为 3 或 7,用于控制注意力机制的感受野 assert kernel_size in (3, 7), 'kernel size must be 3 or 7' # 检查 kernel_size 的合法性 padding = (kernel_size - 1) // 2 # 计算 padding 大小,以保持卷积前后特征图的大小一致 self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) # 卷积层,输入通道为2,输出通道为1 self.sigmoid = nn.Sigmoid() # Sigmoid 激活函数,用于将输出压缩到 (0, 1) 之间 def forward(self, x): # 平均池化和最大池化在通道维度上进行,得到两个单通道特征图 avg_out = torch.mean(x, dim=1, keepdim=True) # 对输入的特征图在通道维度取平均值 max_out, _ = torch.max(x, dim=1, keepdim=True) # 对输入的特征图在通道维度取最大值 # 将平均池化和最大池化的结果拼接在一起,得到形状为 (batch_size, 2, H, W) 的张量 mask = torch.cat([avg_out, max_out], dim=1) # 通过卷积层生成空间注意力权重 mask = self.conv(mask) mask = self.sigmoid(mask) # 使用 sigmoid 将结果压缩到 (0, 1) 之间 return mask * x # 使用注意力权重与输入特征图相乘,得到加权后的输出 # CBAM 模块,结合通道注意力和空间注意力 class CBAM(nn.Module): def __init__(self, in_channels, reduction=16, kernel_size=7): super(CBAM, self).__init__() # 通道注意力模块,首先对通道维度进行加权 self.channel_attention = ChannelAttention(in_channels, reduction) # 空间注意力模块,然后对空间维度进行加权 self.spatial_attention = SpatialAttention(kernel_size) def forward(self, x): x = self.channel_attention(x) # 先通过通道注意力模块 x = self.spatial_attention(x) # 再通过空间注意力模块 return x # 返回经过 CBAM 处理的特征图