在 pytorch 中使用鱼眼图像增强数据

在 pytorch 中使用鱼眼图像增强数据 不会编程的老王 2023-12-20 17:38:49 247

引言

近年来,深度学习在各个方面取得了巨大的成就,但是深度学习模型非常依赖数据,并且需要大量标签数据进行训练。除了手动标记更多的数据,还可以使用数据增强来自动为模型生成新的标记数据,并扩展现有的训练数据集。一些流行的图像数据增强是翻转、裁剪、旋转、缩放、剪切、颜色通道失真、模糊等。在本文中,我们将介绍两种在深度学习中不太流行,但对扩展数据集非常有用的数据增强方法。

旋转、剪切、缩放等增量变换都是对图像进行线性变换的仿射变换。与线性变换相比,我在这篇文章中介绍的以下两个变换是非线性的。

  • 鱼眼变换
  • 水平波变换

鱼眼变换

鱼眼变换是一种非线性变换,对于给定的中心像素,基于与给定中心像素的距离,使图像中的像素发生畸变。实际上,靠近中心的像素比远离中心的像素受到的失真要少得多。鱼眼变换采用中心和畸变因子两个参数。中心定义了转换的中心,畸变因子控制用于转换的畸变量。

上图显示了鱼眼变换对棋盘图像的影响,中心聚焦在图像的中点附近。

数学上,给定像素(x,y)的鱼眼变换函数由以下公式给出。

其中  <c_x,c_y> 代表转换的中心,“d”代表失真因子。<t(x),t(y)> 是像素 <x,y> 的转换值
请注意,对于所有的输入图像像素位置都被标准化为一个网格,左上角的像素代表位置 <-1,-1> ,右下角的像素代表位置 <1,1> 。<0,0> 表示图像的精确中心像素。所以 x 轴和 y 轴的范围是从 -11。对于上面的图像,失真因子(d)设置为0.25,中心是随机采样的间隔[-0.5,0.5]

水平波变换

水平波变换是另一种非线性变换,它使像素变形为给定幅度和频率的水平余弦波形状。它需要两个参数,振幅和频率。


上图显示了水平波变换对棋盘图像的影响。

在数学上,给定像素(x,y)的水平波变换函数由下面的公式给出。

其中“a”是给定的余弦波振幅,“f”是预先指定的频率。<t(x) ,t(y)> 是像素 <x,y> 的变换值。注意,水平波不会对像素的 x 坐标产生任何失真。与鱼眼转换示例类似,x 和 y 的范围是从 -11。在上面的例子中,“a”的值是0.2,“f”的值是20

实例

在本节中,我将介绍 PyTorch 中两种转换的矢量化实现。我更喜欢矢量化,因为它比耗时的 for 循环计算转换的速度要快得多。让我们先从鱼眼变换开始。

def get_of_fisheye(H, W, center, magnitude):  
  xx, yy = torch.linspace(-1, 1, W), torch.linspace(-1, 1, H)  
  gridy, gridx  = torch.meshgrid(yy, xx). //create identity grid
  grid = torch.stack([gridx, gridy], dim=-1)  
  d = center - grid      //calculate the distance(cx - x, cy - y)
  d_sum = torch.sqrt((d**2).sum(axis=-1)) //sqrt((cx-x)^2+(cy-y)^2)
  grid += d * d_sum.unsqueeze(-1) * magnitude 
  return grid.unsqueeze(0)
fisheye_grid = get_of_fisheye(H, W, torch.tensor([0,0]), 0.4)
fisheye_output = F.grid_sample(imgs, fisheye_grid)

上面的代码通过以下 4 个步骤来转换图像。

1.创建一个(H, W, 2 )大小的标识网格,其中 x 和 y 的范围从-1 到 1。
2.计算网格中每个像素到给定中心像素的距离
3.计算每个像素与中心像素的欧几里德距离。
4.计算 dist d 幅度并添加到原始网格。
5.使用 PyTorch 的 grid _ sample 函数对图像进行变换。

类似地,下面的代码使用水平波变换转换图像。

def get_of_horizontalwave(H, W, center, magnitude):  
  xx, yy = torch.linspace(-1, 1, W), torch.linspace(-1, 1, H)  
  gridy, gridx  = torch.meshgrid(yy, xx). //create identity grid
  grid = torch.stack([gridx, gridy], dim=-1)  
  dy = amplitude * torch.cos(freq * grid[:,:,0]) //calculate dy
  grid[:,:,1] += dy
  return grid.unsqueeze(0)
hwave_grid = get_of_horizontalwave(H, W, 10, 0.1)
hwave_output = F.grid_sample(imgs, hwave_grid)

下面是所有的代码,并对棋盘图像进行了处理。

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

## Fisheye Transformation
def get_of_fisheye(height, width, center, magnitude):
  xx, yy = torch.linspace(-1, 1, width), torch.linspace(-1, 1, height)
  gridy, gridx  = torch.meshgrid(yy, xx)   #create identity grid
  grid = torch.stack([gridx, gridy], dim=-1)
  d = center - grid         #calculate the distance(cx - x, cy - y) 
  d_sum = torch.sqrt((d**2).sum(axis=-1)) # sqrt((cx-x)**2 + (cy-y)**2)
  grid += d * d_sum.unsqueeze(-1) * magnitude #calculate dx & dy and add to original values
  return grid.unsqueeze(0)    #unsqueeze(0) since the grid needs to be 4D.

## Horizontal Wave Transformation
def get_of_horizontalwave(height, width, freq, amplitude):
  xx, yy = torch.linspace(-1, 1, width), torch.linspace(-1, 1, height)
  gridy, gridx = torch.meshgrid(yy, xx) #create identity grid
  grid = torch.stack([gridx, gridy], dim=-1)
  dy = amplitude * torch.cos(freq * grid[:,:,0]) #calculate dy
  grid[:,:,1] += dy
  return grid.unsqueeze(0)  #unsqueeze(0) since the grid needs to be 4D.

## UTILITY FUNCTIONS 
## Create Image Batch
def get_image_batch(img):
  transform = transforms.Compose([transforms.ToTensor()])
  tfms_img = transform(img)
  imgs = torch.unsqueeze(tfms_img, dim=0)
  return imgs

def plot(img, fisheye_output, hwave_output):
  fisheye_out = fisheye_output[0].numpy()
  fisheye_out = np.moveaxis(fisheye_out, 0,-1)

  hwave_out = hwave_output[0].numpy()
  hwave_out = np.moveaxis(hwave_out, 0,-1)

  fig, ax = plt.subplots(1,3, figsize=(16,4))
  ax[0].imshow(img)
  ax[1].imshow(fisheye_out)
  ax[2].imshow(hwave_out)

  ax[0].set_title('Input Image(Checkerboard)')
  ax[1].set_title('Fisheye')
  ax[2].set_title('Horizontal Wave Tfms')
  plt.show()

img = Image.open('checkerboard.png')
imgs = get_image_batch(img)
N, C, H, W = imgs.shape
fisheye_grid = get_of_fisheye(H, W, torch.tensor([0,0]), 0.4)
hwave_grid = get_of_horizontalwave(H, W, 10, 0.1)

fisheye_output = F.grid_sample(imgs, fisheye_grid, align_corners=True)
hwave_output = F.grid_sample(imgs, hwave_grid, align_corners=True)
plot(img, fisheye_output, hwave_output)

总结

本文介绍了两种用于增强图像数据的非线性增强方法,即鱼眼和水平波变换。鱼眼是一种非线性变换,它根据与固定中心像素的欧氏距离对像素进行变换。水平波变换是另一种非线性变换,它将像素扭曲成水平余弦波的形状。

声明:本文内容由易百纳平台入驻作者撰写,文章观点仅代表作者本人,不代表易百纳立场。如有内容侵权或者其他问题,请联系本站进行删除。
红包 点赞 收藏 评论 打赏
评论
0个
内容存在敏感词
手气红包
    易百纳技术社区暂无数据
相关专栏
置顶时间设置
结束时间
删除原因
  • 广告/SPAM
  • 恶意灌水
  • 违规内容
  • 文不对题
  • 重复发帖
打赏作者
易百纳技术社区
不会编程的老王
您的支持将鼓励我继续创作!
打赏金额:
¥1易百纳技术社区
¥5易百纳技术社区
¥10易百纳技术社区
¥50易百纳技术社区
¥100易百纳技术社区
支付方式:
微信支付
支付宝支付
易百纳技术社区微信支付
易百纳技术社区
打赏成功!

感谢您的打赏,如若您也想被打赏,可前往 发表专栏 哦~

举报反馈

举报类型

  • 内容涉黄/赌/毒
  • 内容侵权/抄袭
  • 政治相关
  • 涉嫌广告
  • 侮辱谩骂
  • 其他

详细说明

审核成功

发布时间设置
发布时间:
是否关联周任务-专栏模块

审核失败

失败原因
备注
拼手气红包 红包规则
祝福语
恭喜发财,大吉大利!
红包金额
红包最小金额不能低于5元
红包数量
红包数量范围10~50个
余额支付
当前余额:
可前往问答、专栏板块获取收益 去获取
取 消 确 定

小包子的红包

恭喜发财,大吉大利

已领取20/40,共1.6元 红包规则

    易百纳技术社区