# Reimplementing the PraNet Polyp Segmentation Model in PyTorch: A Hand-Holding Code Walkthrough from the Res2Net Backbone to the Reverse Attention Module
Medical image segmentation has long been an important research direction in computer vision. In colonoscopy polyp detection in particular, the accuracy of the segmentation directly affects how efficiently early-stage cancers can be diagnosed. U-Net and its variants perform reasonably well on this task, but they often struggle with the blurry boundary regions of polyps. In this post we dissect PraNet, an architecture designed around exactly this weakness, and reconstruct its core modules line by line in PyTorch.

## 1. Environment Setup and Data Preparation

Before writing any code, set up a suitable environment. Python 3.8 with PyTorch 1.9 is recommended, since this combination has better support for mixed-precision training and custom operators. Key dependencies:

```bash
conda create -n pranet python=3.8
conda activate pranet
conda install pytorch=1.9.0 torchvision=0.10.0 cudatoolkit=11.1 -c pytorch
pip install opencv-python nibabel tensorboardX tqdm
```

Common polyp segmentation datasets such as Kvasir-SEG and CVC-ClinicDB need some dedicated preprocessing:

```python
import glob

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

class PolypDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_paths = sorted(glob.glob(img_dir + '/*'))
        self.mask_paths = sorted(glob.glob(mask_dir + '/*'))
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img = cv2.imread(self.img_paths[idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.mask_paths[idx], 0)  # read mask as grayscale
        if self.transform:
            augmented = self.transform(image=img, mask=mask)
            img, mask = augmented['image'], augmented['mask']
        img = img.transpose(2, 0, 1).astype(np.float32) / 255.0
        mask = mask.astype(np.float32) / 255.0
        return torch.tensor(img), torch.tensor(mask)
```

Note: in practice, multi-scale random cropping (MSRC) is preferable to plain random rescaling, because it better simulates the varying distance between the scope and the tissue during a clinical examination.
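The post recommends multi-scale random cropping without showing any code. Below is a minimal numpy sketch of one way to implement it; the function name, the `crop_size` default, and the scale set are my own illustrative choices, not the author's:

```python
import random

import numpy as np

def multi_scale_random_crop(img, mask, crop_size=352, scales=(0.75, 1.0, 1.25)):
    """Cut a square window whose side is crop_size * (randomly chosen scale).

    The caller is expected to resize the crop back to crop_size afterwards,
    which is what produces the apparent change in camera-to-tissue distance.
    """
    scale = random.choice(scales)
    side = int(crop_size * scale)
    h, w = img.shape[:2]
    side = min(side, h, w)  # guard against crops larger than the image
    top = random.randint(0, h - side)
    left = random.randint(0, w - side)
    img_crop = img[top:top + side, left:left + side]
    mask_crop = mask[top:top + side, left:left + side]
    return img_crop, mask_crop
```

The same window is cut from image and mask, so the pair stays pixel-aligned.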
## 2. Implementing the Res2Net Backbone

PraNet uses Res2Net as its feature-extraction backbone because its hierarchical residual connections capture multi-scale features better. Compared with an ordinary ResNet, Res2Net introduces finer-grained splits inside a single residual block:

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class Bottle2Neck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, scales=4, base_width=26,
                 downsample=None):
        super().__init__()
        width = int(math.floor(planes * (base_width / 64.0)))
        self.conv1 = nn.Conv2d(inplanes, width * scales, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scales)
        self.nums = 1 if scales == 1 else scales - 1
        convs, bns = [], []
        for _ in range(self.nums):
            convs.append(nn.Conv2d(width, width, 3, stride=stride,
                                   padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.conv3 = nn.Conv2d(width * scales, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.scales = scales
        self.width = width

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        spx = torch.split(out, self.width, dim=1)
        for i in range(self.nums):
            # hierarchical residual: each split receives the previous split's output
            sp = spx[i] if (i == 0 or self.stride > 1) else sp + spx[i]
            sp = self.relu(self.bns[i](self.convs[i](sp)))
            out = sp if i == 0 else torch.cat([out, sp], dim=1)
        if self.scales != 1:
            last = spx[self.nums]
            if self.stride > 1:
                last = F.avg_pool2d(last, 3, stride=self.stride, padding=1)
            out = torch.cat([out, last], dim=1)
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            identity = self.downsample(x)
        return self.relu(out + identity)
```

Key parameters:

| Parameter | Role | Recommended value |
| --- | --- | --- |
| `scales` | number of feature splits | 4-8 |
| `base_width` | base channel width | 26 |
| `stride` | downsampling stride | 1 or 2 |

During training, freezing the shallow layers can speed up convergence:

```python
def freeze_layers(model):
    for name, param in model.named_parameters():
        if 'layer' in name and 'layer4' not in name:
            param.requires_grad = False
```

## 3. Designing the Parallel Partial Decoder (PPD)

The PPD module aggregates the high-level features (layers 3-5). Its core idea is dynamic cross-layer fusion rather than naive concatenation. When implementing it, pay particular attention to aligning the feature-map sizes:

```python
class PPD(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels[0], 256, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels[1], 256, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels[2], 256, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True))
        self.fuse = nn.Sequential(
            nn.Conv2d(768, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True))

    def forward(self, f3, f4, f5):
        # bring f4 and f5 up to f3's spatial resolution before fusing
        f3 = self.conv1(f3)
        f4 = self.conv2(F.interpolate(f4, scale_factor=2, mode='bilinear',
                                      align_corners=False))
        f5 = self.conv3(F.interpolate(f5, scale_factor=4, mode='bilinear',
                                      align_corners=False))
        return self.fuse(torch.cat([f3, f4, f5], dim=1))
```

Common problems during feature fusion, and remedies:

- **Edge effects**: add reflection padding around the interpolated upsampling.
- **Feature conflicts**: introduce SE attention to reweight channels dynamically.
- **Information loss**: keep skip connections to the original feature maps.
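One of the remedies above, SE channel attention, is easy to drop into the PPD. The block below is the textbook squeeze-and-excitation formulation, not code from the post:

```python
import torch
import torch.nn as nn

class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention (standard formulation)."""

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid())

    def forward(self, x):
        b, c, _, _ = x.shape
        # squeeze: global average pool; excite: per-channel gate in (0, 1)
        w = self.fc(x.mean(dim=(2, 3))).view(b, c, 1, 1)
        return x * w
```

Applied to each branch before `torch.cat`, it lets the network suppress channels that conflict across levels instead of averaging them blindly.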
## 4. Implementing the Reverse Attention (RA) Module

The RA module is PraNet's most distinctive design: it erases the currently predicted region so that the network is forced to learn the boundary. Its math simplifies to

$$
R_i = f_i \odot \bigl(1 - \sigma(P(S_{i+1}))\bigr)
$$

where $\odot$ denotes element-wise multiplication and $P$ is the upsampling operator. The corresponding PyTorch implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RA(nn.Module):
    def __init__(self, in_channel):
        super().__init__()
        self.convert = nn.Conv2d(in_channel, 256, 1)
        self.down = nn.Sequential(
            nn.Conv2d(256, 64, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        # the guidance map S carries 256 channels; reduce it to 64 as well
        self.down_s = nn.Conv2d(256, 64, 1)
        self.flow_make = nn.Conv2d(64 * 2, 2, 3, padding=1)

    def forward(self, x, S):
        x = self.convert(x)
        # resize S to x's spatial size (handles both up- and downsampling)
        S = F.interpolate(S, size=x.shape[-2:], mode='bilinear',
                          align_corners=False)
        feat = self.down(x)
        flow = self.flow_make(torch.cat([feat, self.down_s(S)], dim=1))
        # interpret the two flow channels as a sampling grid in [-1, 1]
        S_ = F.grid_sample(S, torch.tanh(flow.permute(0, 2, 3, 1)),
                           padding_mode='border', align_corners=False)
        # reverse attention: suppress the already-predicted region
        return x * (1 - torch.sigmoid(S_.mean(dim=1, keepdim=True)))
```

Training tips:

- For the first few epochs, disable the RA modules; enable them once the PPD output has stabilized.
- Add an edge-sensitive constraint on the RA output:

```python
edge_loss = F.mse_loss(sobel(RA_output), sobel(gt_mask))
```

## 5. Loss Functions and Training Strategy

PraNet trains with a compound loss of weighted IoU and weighted BCE; both need custom implementations:

```python
class WeightedBCELoss(nn.Module):
    def __init__(self, pos_weight=1.2):
        super().__init__()
        self.pos_weight = pos_weight

    def forward(self, pred, target):
        pred = torch.clamp(pred, 1e-7, 1 - 1e-7)
        loss = -(self.pos_weight * target * torch.log(pred)
                 + (1 - target) * torch.log(1 - pred))
        return loss.mean()

class WeightedIoULoss(nn.Module):
    def forward(self, pred, target):
        inter = (pred * target).sum(dim=(1, 2, 3))
        union = (pred + target - pred * target).sum(dim=(1, 2, 3))
        iou = (inter + 1e-7) / (union + 1e-7)
        return 1 - iou.mean()
```

The key to multi-scale training:

```python
import random

import cv2

def multi_scale_transform(img, mask, scales=(0.75, 1.0, 1.25)):
    scale = random.choice(scales)
    h, w = int(img.shape[0] * scale), int(img.shape[1] * scale)
    img = cv2.resize(img, (w, h))
    mask = cv2.resize(mask, (w, h))
    return img, mask
```

For the optimizer, a layer-wise learning-rate policy is recommended:

```python
param_groups = [
    {'params': backbone.parameters(), 'lr': base_lr * 0.1},
    {'params': ppd.parameters(), 'lr': base_lr},
    {'params': ra_modules.parameters(), 'lr': base_lr * 1.2},
]
optimizer = torch.optim.AdamW(param_groups, weight_decay=1e-4)
```
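The edge-sensitive constraint in Section 4 calls a `sobel` helper that the post never defines. Here is one plausible implementation, a sketch of my own rather than the author's code, assuming the input is a `(B, 1, H, W)` probability or mask map:

```python
import torch
import torch.nn.functional as F

def sobel(x):
    """Sobel gradient magnitude of a single-channel map (B, 1, H, W)."""
    kx = torch.tensor([[-1., 0., 1.],
                       [-2., 0., 2.],
                       [-1., 0., 1.]], device=x.device).view(1, 1, 3, 3)
    ky = kx.transpose(2, 3)  # vertical-gradient kernel
    gx = F.conv2d(x, kx, padding=1)
    gy = F.conv2d(x, ky, padding=1)
    # small epsilon keeps the sqrt differentiable at zero gradient
    return torch.sqrt(gx ** 2 + gy ** 2 + 1e-8)
```

Because the kernels are fixed tensors rather than learned parameters, the loss only pushes gradients into the prediction, which is exactly what an edge penalty should do.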
## 6. Model Ensembling and Inference Optimization

At deployment time, a few tricks can squeeze out extra performance.

**Test-time augmentation (TTA)**:

```python
def tta_inference(model, img, scales=(0.75, 1.0, 1.25)):
    # img: (B, C, H, W)
    preds = []
    for scale in scales:
        h, w = int(img.shape[2] * scale), int(img.shape[3] * scale)
        scaled_img = F.interpolate(img, size=(h, w), mode='bilinear',
                                   align_corners=False)
        pred = model(scaled_img)
        pred = F.interpolate(pred, size=img.shape[-2:], mode='bilinear',
                             align_corners=False)
        preds.append(pred)
    return torch.mean(torch.stack(preds), dim=0)
```

Model quantization example (note that dynamic quantization in PyTorch covers `nn.Linear` and RNN layers; `nn.Conv2d` requires static post-training quantization instead):

```python
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8)
```

Finally, the full model skeleton that ties everything together:

```python
class PraNet(nn.Module):
    def __init__(self, backbone='res2net50'):
        super().__init__()
        self.backbone = build_backbone(backbone)  # yields features f1..f5
        self.ppd = PPD([512, 1024, 2048])
        self.ra5 = RA(2048)
        self.ra4 = RA(1024)
        self.ra3 = RA(512)
        self.predictor = nn.Conv2d(256, 1, 1)

    def forward(self, x):
        f1, f2, f3, f4, f5 = self.backbone(x)
        Sg = self.ppd(f3, f4, f5)  # global map at 1/8 input resolution
        S5 = self.ra5(f5, Sg) + F.interpolate(
            Sg, scale_factor=0.25, mode='bilinear', align_corners=False)
        S4 = self.ra4(f4, S5) + F.interpolate(
            S5, scale_factor=2, mode='bilinear', align_corners=False)
        S3 = self.ra3(f3, S4) + F.interpolate(
            S4, scale_factor=2, mode='bilinear', align_corners=False)
        pred = torch.sigmoid(self.predictor(S3))
        # S3 lives at 1/8 resolution; bring the map back to full size
        return F.interpolate(pred, scale_factor=8, mode='bilinear',
                             align_corners=False)
```
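The skeleton above stops at a raw sigmoid map; in practice you still need to binarize it and restore the original image resolution before overlaying it on the colonoscopy frame. A small post-processing helper along these lines (the function name and defaults are illustrative, not from the post):

```python
import numpy as np
import torch
import torch.nn.functional as F

def postprocess(pred, orig_hw, threshold=0.5):
    """Convert a (1, 1, h, w) sigmoid map into a uint8 binary mask.

    orig_hw: (H, W) of the original image; threshold: probability cutoff.
    """
    pred = F.interpolate(pred, size=orig_hw, mode='bilinear',
                         align_corners=False)
    mask = (pred.squeeze().cpu().numpy() > threshold).astype(np.uint8) * 255
    return mask
```

Keeping thresholding outside the model makes it easy to sweep the cutoff when tuning Dice or IoU on a validation set.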