torchvision.transform提供了很多方法用于图像增强,其中在使用RandomXXX()系列的API时,通常内部会用到random.random()等用来做随机数。这样就可以保证每一次transfrom都是随机的。
然而如果想要应用相同的transform在输入图像和输出图像上呢?

例如:一个pix2pix的任务。需要同时对输入和输出图像进行变换,比如变换相同的角度等。

解决方法:

在pytorch文档上可以看到,除了torchvision.transform 还有torchvision.transforms.functional
相比transformtransforms.functional 更加灵活,该方法只提供了图像的增强变换功能,而并没有随机部分,因此可以自己设计应用的方式。
官网链接:torchvision.transforms.functional

示例

x_file_namesy_file_names 分别为待读取数据路径。
transform 为自己定义的方法,基本上transform里有的,transforms.functional 里都会有对应的,只需要填好参数即可。

有一些transform 是有get_params()方法的,可以用于返回配置参数。如果你嫌自己手动随机一些参数很麻烦,可以直接使用transforms.XXX.get_params()拿到这些随机好的参数,直接给transforms.functional使用即可。

import random
from torch.utils.data import Dataset
rom torchvision import transforms
import torchvision.transforms.functional as tf
from PIL import Image

class TrainDataset(Dataset):
    def __init__(self, x_file_names:list, y_file_names:list):
        self.x_file_names = x_file_names
        self.y_file_names = y_file_names
        self.nums = len(x_file_names)

    def __len__(self):
        return self.nums

    def transform(self, image, mask):
        # 拿到角度的随机数。angle是一个-180到180之间的一个数
        angle = transforms.RandomRotation.get_params([-180, 180])
        # 对image和mask做相同的旋转操作,保证他们都旋转angle角度
        image = tf.rotate(image, angle, resample=Image.NEAREST)
        mask = tf.rotate(mask, angle, resample=Image.NEAREST)
        # 自己写随机部分,50%的概率应用垂直,水平翻转。
        if random.random() > 0.5:
            image = tf.hflip(image)
            mask = tf.hflip(mask)
        if random.random() > 0.5:
            image = tf.vflip(image)
            mask = tf.vflip(mask)
        # 也可以实现一些复杂的操作
        # 50%的概率对图像放大后裁剪固定大小
        # 50%的概率对图像缩小后周边补0,并维持固定大小
        if random.random() > 0.5:
            i, j, h, w = transforms.RandomResizedCrop.get_params(
                    image, scale=(0.25, 1.0), ratio=(1, 1))
            image = tf.resized_crop(image, i, j, h, w, 256)
            mask = tf.resized_crop(mask, i, j, h, w, 256)
        else:
            pad = random.randint(0, 192)
            image = tf.pad(image, pad)
            image = tf.resize(image, 256)
            mask = tf.pad(mask, pad)
            mask = tf.resize(mask, 256)
        # 转换为tensor并做归一化
        image = tf.to_tensor(image)
        image = tf.normalize(image, [0.5], [0.5])
        mask = tf.to_tensor(mask)
        mask = tf.normalize(mask, [0.5], [0.5])
        return image, mask

    def __getitem__(self, index):
        with Image.open(self.x_file_names[index]) as img:
            x_img = img.convert("L")
        with Image.open(self.y_file_names[index]) as img:
            y_img = img.convert("L")
        # 直接应用自己写的transform即可
        x, y = self.transform(x_img, y_img)
        return x, y