图像修复、超分、ViT都离不开它:深入浅出图解PyTorch Fold/Unfold的5个实战场景
图像修复、超分、ViT都离不开它深入浅出图解PyTorch Fold/Unfold的5个实战场景在计算机视觉领域PyTorch的Fold和Unfold操作就像瑞士军刀中的万能工具虽然低调却能在关键时刻解决复杂问题。想象一下当你需要处理图像块匹配、实现自定义卷积或构建Transformer输入时这两个操作能让你从繁琐的循环代码中解放出来。本文将带您深入五个实际场景看看这些张量乐高积木如何优雅地拼接出视觉任务的解决方案。1. 非局部均值去噪块匹配的艺术非局部均值去噪的核心思想是利用图像中相似块的加权平均来消除噪声。传统实现需要嵌套循环遍历每个像素和其邻域而Unfold让这一切变得高效import torch import torch.nn as nn def non_local_denoise(image, patch_size3, search_window7): # image: (1, C, H, W) unfold nn.Unfold(kernel_sizepatch_size, paddingpatch_size//2) patches unfold(image) # (1, C*p*p, H*W) # 计算块间相似度简化版 patches_norm patches / (patches.norm(dim1, keepdimTrue) 1e-6) similarity torch.matmul(patches_norm.transpose(1,2), patches_norm) # 加权平均 denoised torch.matmul(similarity.softmax(dim-1), patches.transpose(1,2)) # 还原图像 fold nn.Fold(output_sizeimage.shape[2:], kernel_sizepatch_size, paddingpatch_size//2) return fold(denoised.transpose(1,2))关键优势并行计算所有图像块相似度避免Python循环带来的性能损失保持与卷积操作一致的接口风格2. 超分辨率重建子像素卷积的逆过程在ESPCN等超分网络中PixelShuffle通过周期洗牌操作实现上采样。而Unfold可以看作是其逆向操作将高分辨率图像分解为低分辨率块def prepare_hr_patches(hr_image, scale_factor2): 为生成对抗训练准备HR图像块 b, c, h, w hr_image.shape unfold nn.Unfold(kernel_sizescale_factor, stridescale_factor) patches unfold(hr_image) # (b, c*4, h*w/4) return patches.view(b, -1, h//scale_factor, w//scale_factor) # 与PixelShuffle的对应关系 hr_image torch.randn(1, 3, 32, 32) lr_patches prepare_hr_patches(hr_image) ps nn.PixelShuffle(2) reconstructed ps(lr_patches) # 近似原始HR图像应用场景对比操作类型输入维度输出维度典型用途Unfold(b,c,h,w)(b,ckk,hw/(kk))特征块提取PixelShuffle(b,crr,h,w)(b,c,hr,wr)亚像素上采样3. Vision Transformer图像分块的工程实现ViT将图像划分为16x16的块作为Transformer的输入序列。使用Unfold可以高效实现这一过程class PatchEmbedding(nn.Module): def __init__(self, img_size224, patch_size16, in_chans3, embed_dim768): super().__init__() self.proj nn.Sequential( nn.Unfold(kernel_sizepatch_size, stridepatch_size), nn.Linear(in_chans * patch_size**2, embed_dim) ) self.num_patches (img_size // patch_size) ** 2 def forward(self, x): x self.proj(x) # (b, embed_dim, num_patches) return x.transpose(1, 2) # (b, num_patches, embed_dim)与传统方法的对比手动切片需要复杂的reshape和permute操作Unfold实现自动处理边缘情况padding支持dilation参数与卷积参数完全兼容4. 自定义卷积操作超越标准卷积核当需要实现空洞卷积或可变形卷积时Unfold手动偏移Fold的组合提供了灵活的实现方案def deformable_conv2d(x, offset, kernel_size3): x: input tensor (b,c,h,w) offset: 偏移量 (b,2*k*k,h,w) b, c, h, w x.shape # 生成采样网格 grid create_base_grid(h, w) offset # 双线性采样 sampled F.grid_sample(x, grid) # 展开为块表示 unfold nn.Unfold(kernel_sizekernel_size) patches unfold(sampled) # 自定义卷积核处理 output apply_custom_kernel(patches) # 还原空间结构 fold nn.Fold(output_size(h,w), kernel_sizekernel_size) return fold(output)性能优化技巧使用grid_sample实现亚像素级偏移通过Unfold保持内存访问局部性自定义核函数可替换为任意逐块操作5. 数据增强网格化图像重组超越传统的裁剪翻转Fold/Unfold能实现创新的数据增强方式class GridShuffleAugment: def __init__(self, grid_size4): self.unfold nn.Unfold(kernel_sizegrid_size, stridegrid_size) self.fold nn.Fold(output_size(224,224), kernel_sizegrid_size, stridegrid_size) def __call__(self, x): # x: (c,h,w) patches self.unfold(x.unsqueeze(0)) # (1, c*g*g, n) patches patches[:, :, torch.randperm(patches.size(2))] return self.fold(patches).squeeze(0) # 效果示例 augmenter GridShuffleAugment() aug_img augmenter(original_img) # 创建拼贴画风格图像增强类型对比增强方式实现复杂度效果特点GPU友好度常规裁剪低局部视角高网格重组中结构变异极高混合块高语义混合中

相关新闻

最新新闻

日新闻

周新闻

月新闻