GitHub Link: https://github.com/dingguanglei/jdit

Jdit is a research-oriented framework based on PyTorch that lets you care only about your ideas: you don't need to write long, boring boilerplate code just to run a deep learning project and verify an idea.

You only need to implement your ideas; you don't have to deal with the training framework, multi-GPU support, checkpoints, process visualization, performance evaluation and so on.

If you have any problems, or you find bugs, you can contact the author.

E-mail: dingguanglei.bupt@qq.com

Template of GAN for pix2pix

Because it inherits from SupGanTrainer, this template is much simpler.
For a pix2pix task, the biggest difference from other tasks is that both the input and the ground truth are images.
When you load data from the dataset, you get input images and, as their labels, ground-truth images. The output of the generator is a fake image, which you want to be as close as possible to the ground truth.

This is the defining feature of this type of task.

However, all of this covers only the general case. If you need something special, just rewrite the methods without hesitation.

The key methods of class Pix2pixGanTrainer look like this:

class Pix2pixGanTrainer(SupGanTrainer):
    d_turn = 1

    def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets):
        super(Pix2pixGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets)

    def get_data_from_batch(self, batch_data: list, device: torch.device):
        input_tensor, ground_truth_tensor = batch_data[0], batch_data[1]
        return input_tensor, ground_truth_tensor

    def _watch_images(self, tag, grid_size=(3, 3), shuffle=False, save_file=True):
        self.watcher.image(self.input,
                           self.current_epoch,
                           tag="%s/input" % tag,
                           grid_size=grid_size,
                           shuffle=shuffle,
                           save_file=save_file)
        self.watcher.image(self.fake,
                           self.current_epoch,
                           tag="%s/fake" % tag,
                           grid_size=grid_size,
                           shuffle=shuffle,
                           save_file=save_file)
        self.watcher.image(self.ground_truth,
                           self.current_epoch,
                           tag="%s/real" % tag,
                           grid_size=grid_size,
                           shuffle=shuffle,
                           save_file=save_file)

    @abstractmethod
    def compute_d_loss(self):
        """Rewrite this method to compute your own loss for the Discriminator."""
        pass

    @abstractmethod
    def compute_g_loss(self):
        """Rewrite this method to compute your own loss for the Generator."""
        pass

    @abstractmethod
    def compute_valid(self):
        """Rewrite this method to compute the valid_epoch values."""
        pass

    def valid_epoch(self):
        super(Pix2pixGanTrainer, self).valid_epoch()
        self.netG.eval()
        self.netD.eval()
        if self.fixed_input is None:
            for iteration, batch in enumerate(self.datasets.loader_test, 1):
                if isinstance(batch, list):
                    self.fixed_input, fixed_ground_truth = self.get_data_from_batch(batch, self.device)
                    self.watcher.image(fixed_ground_truth, self.current_epoch, tag="Fixed/groundtruth",
                                       grid_size=(6, 6),
                                       shuffle=False)
                else:
                    self.fixed_input = batch.to(self.device)
                self.watcher.image(self.fixed_input, self.current_epoch, tag="Fixed/input",
                                   grid_size=(6, 6),
                                   shuffle=False)
                break
        # watch how the generator output varies during training, given a fixed input
        with torch.no_grad():
            fake = self.netG(self.fixed_input).detach()
        self.watcher.image(fake, self.current_epoch, tag="Fixed/fake", grid_size=(6, 6), shuffle=False)

        # save the training progress images, which can later be built into a .gif
        self.watcher.set_training_progress_images(fake, grid_size=(6, 6))

        self.netG.train()
        self.netD.train()

    def test(self):
        pass

There is no training logic here any more, because it is already implemented in class SupGanTrainer. So you only need to care about the training losses, the valid function and the test function.

  • compute_d_loss. Compute the loss for the discriminator. This function should return two values. The first is the main loss, on which backward will be called. The second is a dict(); its keys and values will be shown on TensorBoard without any further computation.
  • compute_g_loss. The same as compute_d_loss, but for the generator.
  • compute_valid. The validation function, called at the end of every epoch. Since this is validation, you don't need a loss for backward, so the return value is only a dict(), which works the same way as in compute_d_loss: the values are shown on TensorBoard. Each value is averaged over the whole validation dataset.
  • test. This is called once at the end of the whole training, so just rewrite this method and do whatever you want (a minimal sketch follows this list).
  • _watch_images. This method is defined in SupGanTrainer but rewritten here, because the ground truth is now images rather than labels, so it is shown on TensorBoard as well.
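
For instance, a test() override might run the generator over the test loader and log the results. The following is only a minimal sketch, not part of jdit itself: the subclass name MyPix2pixGanTrainer is hypothetical, and the body reuses only attributes already shown above (self.netG, self.datasets.loader_test, self.get_data_from_batch, self.device, self.watcher).

class MyPix2pixGanTrainer(Pix2pixGanTrainer):
    def test(self):
        self.netG.eval()
        with torch.no_grad():
            for batch in self.datasets.loader_test:
                # reuse the same batch-unpacking logic as in training
                input_tensor, _ = self.get_data_from_batch(batch, self.device)
                fake = self.netG(input_tensor.to(self.device))
                # log a grid of generated images to tensorboard (and disk)
                self.watcher.image(fake, self.current_epoch, tag="Test/fake",
                                   grid_size=(6, 6), shuffle=False, save_file=True)
        self.netG.train()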

Example: Pixel to pixel task

Let me give you an example that applies this template. It lives in jdit.trainer.instances.cifarPix2pixGan. The task converts 1-channel grayscale input images into 3-channel color images.
Here is the code:

class CifarPix2pixGanTrainer(Pix2pixGanTrainer):
    d_turn = 5

    def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets):
        super(CifarPix2pixGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD,
                                                     datasets)

    def get_data_from_batch(self, batch_data, device):
        ground_truth_cpu, label = batch_data[0], batch_data[1]
        # only use one channel: [?,3,32,32] => [?,1,32,32]
        input_cpu = ground_truth_cpu[:, 0, :, :].unsqueeze(1)
        return input_cpu, ground_truth_cpu

    def compute_d_loss(self):
        d_fake = self.netD(self.fake.detach())
        d_real = self.netD(self.ground_truth)
        var_dic = {}
        var_dic["LS_LOSSD"] = loss_d = 0.5 * (torch.mean((d_real - 1) ** 2) + torch.mean(d_fake ** 2))
        return loss_d, var_dic

    def compute_g_loss(self):
        d_fake = self.netD(self.fake)
        var_dic = {}
        var_dic["LS_LOSSG"] = loss_g = 0.5 * torch.mean((d_fake - 1) ** 2)

        return loss_g, var_dic

    def compute_valid(self):
        g_loss, _ = self.compute_g_loss()
        d_loss, _ = self.compute_d_loss()
        mse = ((self.fake.detach() - self.ground_truth) ** 2).mean()
        var_dic = {"LOSS_D": d_loss, "LOSS_G": g_loss, "MSE": mse}
        return var_dic

This final template is clear. You can see there are two loss functions, one for the discriminator and one for the generator, and one more function for validation. It also sets d_turn = 5, so the discriminator is trained more turns than the generator.

  • For the losses, it uses least-squares loss, and the loss values are put into var_dic, which means they will be shown on TensorBoard (the objective is written out after this list).
  • For the valid function, it just calls the loss functions again and collects the values into var_dic. It shows the same values as the losses, but for validation they are computed on the validation dataset and averaged.
  • Because we only need one channel of the input images, we rewrite the get_data_from_batch method so that it returns the two batches of images used for training.
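
For reference, the two losses above implement the least-squares GAN (LSGAN) objective. With D the discriminator, x the ground-truth image and G(input) the fake image, the code computes:

loss_D = 0.5 * ( mean((D(x) - 1)^2) + mean(D(G(input))^2) )
loss_G = 0.5 * mean((D(G(input)) - 1)^2)

So the discriminator is pushed towards outputting 1 on real images and 0 on fakes, while the generator is rewarded when the discriminator outputs 1 on its fakes.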

Have a try

The following is the complete code. You can read it and copy it to run, or you can import the training function for a quick start.

Quick start

from jdit.trainer.instances.cifarPix2pixGan import start_cifarPix2pixGanTrainer
start_cifarPix2pixGanTrainer()
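
The logs and images are written to the logdir passed to the trainer (log/cifar_p2p in this example), so you can follow the training with TensorBoard, e.g. tensorboard --logdir log/cifar_p2p.

Complete code:
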
# coding=utf-8
import torch
import torch.nn as nn
from jdit.trainer import Pix2pixGanTrainer
from jdit.model import Model
from jdit.optimizer import Optimizer
from jdit.dataset import Cifar10


class Discriminator(nn.Module):
    def __init__(self, input_nc=3, output_nc=1, depth=64):
        super(Discriminator, self).__init__()
        # 32 x 32
        self.layer1 = nn.Sequential(
                nn.utils.spectral_norm(nn.Conv2d(input_nc, depth * 1, kernel_size=3, stride=1, padding=1)),
                nn.LeakyReLU(0.1))
        # 32 x 32
        self.layer2 = nn.Sequential(
                nn.utils.spectral_norm(nn.Conv2d(depth * 1, depth * 2, kernel_size=3, stride=1, padding=1)),
                nn.LeakyReLU(0.1),
                nn.MaxPool2d(2, 2))
        # 16 x 16
        self.layer3 = nn.Sequential(
                nn.utils.spectral_norm(nn.Conv2d(depth * 2, depth * 4, kernel_size=3, stride=1, padding=1)),
                nn.LeakyReLU(0.1),
                nn.MaxPool2d(2, 2))
        # 8 x 8
        self.layer4 = nn.Sequential(nn.Conv2d(depth * 4, output_nc, kernel_size=8, stride=1, padding=0))
        # 1 x 1

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out


class Generator(nn.Module):
    def __init__(self, input_nc=1, output_nc=3, depth=32):
        super(Generator, self).__init__()

        self.latent_to_features = nn.Sequential(
                nn.Conv2d(input_nc, 1 * depth, 3, 1, 1),  # 1,32,32 => d,32,32
                nn.ReLU(),
                nn.BatchNorm2d(1 * depth),
                nn.Conv2d(1 * depth, 2 * depth, 4, 2, 1),  # d,32,32 => 2d,16,16
                nn.ReLU(),
                nn.BatchNorm2d(2 * depth),
                nn.Conv2d(2 * depth, 4 * depth, 4, 2, 1),  # 2d,16,16 => 4d,8,8
                nn.ReLU(),
                nn.BatchNorm2d(4 * depth),
                nn.Conv2d(4 * depth, 4 * depth, 4, 2, 1),  # 4d,8,8  => 4d,4,4
                nn.ReLU(),
                nn.BatchNorm2d(4 * depth)
                )
        self.features_to_image = nn.Sequential(
                nn.ConvTranspose2d(4 * depth, 4 * depth, 4, 2, 1),  # 4d,4,4 =>  4d,8,8
                nn.ReLU(),
                nn.BatchNorm2d(4 * depth),
                nn.ConvTranspose2d(4 * depth, 2 * depth, 4, 2, 1),  # 4d,8,8 =>  2d,16,16
                nn.ReLU(),
                nn.BatchNorm2d(2 * depth),
                nn.ConvTranspose2d(2 * depth, depth, 4, 2, 1),  # 2d,16,16 =>  d,32,32
                nn.ReLU(),
                nn.BatchNorm2d(depth),
                nn.ConvTranspose2d(depth, output_nc, 3, 1, 1),  # d,32,32 =>  3,32,32
                )

    def forward(self, input_data):
        out = self.latent_to_features(input_data)
        out = self.features_to_image(out)
        return out


class CifarPix2pixGanTrainer(Pix2pixGanTrainer):
    d_turn = 5

    def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, datasets):
        super(CifarPix2pixGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD,
                                                     datasets)

    def get_data_from_batch(self, batch_data, device):
        ground_truth_cpu, label = batch_data[0], batch_data[1]
        input_cpu = ground_truth_cpu[:, 0, :, :].unsqueeze(1)  # only use one channel [?,3,32,32] =>[?,1,32,32]
        return input_cpu, ground_truth_cpu

    def compute_d_loss(self):
        d_fake = self.netD(self.fake.detach())
        d_real = self.netD(self.ground_truth)
        var_dic = {}
        var_dic["LS_LOSSD"] = loss_d = 0.5 * (torch.mean((d_real - 1) ** 2) + torch.mean(d_fake ** 2))
        return loss_d, var_dic

    def compute_g_loss(self):
        d_fake = self.netD(self.fake)
        var_dic = {}
        var_dic["LS_LOSSG"] = loss_g = 0.5 * torch.mean((d_fake - 1) ** 2)

        return loss_g, var_dic

    def compute_valid(self):
        g_loss, _ = self.compute_g_loss()
        d_loss, _ = self.compute_d_loss()
        mse = ((self.fake.detach() - self.ground_truth) ** 2).mean()
        var_dic = {"LOSS_D": d_loss, "LOSS_G": g_loss, "MSE": mse}
        return var_dic


def start_cifarPix2pixGanTrainer(gpus=(), nepochs=200, lr=1e-3, depth_G=32, depth_D=32, run_type="train"):
    # set `gpus = []` to run on the CPU
    batch_size = 32
    image_channel = 3

    G_hprams = {"optimizer": "Adam", "lr_decay": 0.9,
                "decay_position": 10, "decay_type": "epoch",
                "lr": lr, "weight_decay": 2e-5,
                "betas": (0.9, 0.99)
                }
    D_hprams = {"optimizer": "RMSprop", "lr_decay": 0.9,
                "decay_position": 10, "decay_type": "epoch",
                "lr": lr, "weight_decay": 2e-5,
                "momentum": 0
                }

    print('===> Building dataset')
    cifar10 = Cifar10(root="datasets/cifar10", batch_size=batch_size)
    torch.backends.cudnn.benchmark = True
    print('===> Building model')
    D_net = Discriminator(input_nc=image_channel, depth=depth_D)
    D = Model(D_net, gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=50)
    # -----------------------------------
    G_net = Generator(input_nc=1, output_nc=image_channel, depth=depth_G)
    G = Model(G_net, gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=50)
    print('===> Building optimizer')
    opt_D = Optimizer(D.parameters(), **D_hprams)
    opt_G = Optimizer(G.parameters(), **G_hprams)
    print('===> Training')
    Trainer = CifarPix2pixGanTrainer("log/cifar_p2p", nepochs, gpus, G, D, opt_G, opt_D, cifar10)
    if run_type == "train":
        Trainer.train()
    elif run_type == "debug":
        Trainer.debug()


if __name__ == '__main__':
    start_cifarPix2pixGanTrainer()
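
Passing run_type="debug" makes the script call Trainer.debug() instead of running a full training, which is handy for quickly checking that the models, optimizers and dataset are wired together correctly before a long run.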

More