This template is used for generation tasks with a GAN structure.

Github Link: https://github.com/dingguanglei/jdit

Jdit is a research-oriented framework based on PyTorch that lets you focus only on your ideas. You don't need to write long, boring boilerplate code to run a deep learning project and verify your ideas.

You only need to implement your ideas, without dealing with the training framework, multi-GPU support, checkpoints, process visualization, performance evaluation and so on.

If you have any problems or find bugs, you can contact the author.

E-mail: dingguanglei.bupt@qq.com

Template of GAN for generation

Because it inherits from SupGanTrainer, this template is much simpler.
For a generation task, the main difference from other tasks is the input: it is noise, not data from the dataset.

When you load data from the dataset, the behavior is different: you take images as the real data, and you generate the latent (input) data with something like torch.randn.
This is the defining feature of this type of task.

However, all of this covers only the general case. If you need something special, just override the methods without hesitation.
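To make the shape of that latent input concrete, here is a small stdlib-only sketch (the values `latent_shape=(256, 1, 1)` and `batch_size=64` are taken from the example later in this page): the batch dimension comes from the real data, and the remaining dimensions come from `latent_shape`. In the trainer, `torch.randn` fills a tensor of exactly this shape with noise.

```python
# Shapes only; torch.randn((batch_size, *latent_shape)) would fill
# a tensor of this shape with Gaussian noise.
latent_shape = (256, 1, 1)   # shape of one latent vector
batch_size = 64              # e.g. len(ground_truth_tensor)

# The `*` unpacking mirrors the trainer's
# torch.randn((len(ground_truth_tensor), *self.latent_shape)).
input_shape = (batch_size, *latent_shape)
print(input_shape)  # (64, 256, 1, 1)
```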

Class GenerateGanTrainer. The key methods are applied like this:

class GenerateGanTrainer(SupGanTrainer):
    d_turn = 1
    """The training times of Discriminator every ones Generator training.
    """

    def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, 
        datasets, latent_shape):
        """ a gan super class

        :param logdir:Path of log
        :param nepochs:Amount of epochs.
        :param gpu_ids_abs: The id of gpus which t obe used. If use CPU, set ``[]``.
        :param netG:Generator model.
        :param netD:Discrimiator model
        :param optG:Optimizer of Generator.
        :param optD:Optimizer of Discrimiator.
        :param datasets:Datasets.
        :param latent_shape:The shape of input noise.
        """
        super(GenerateGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, 
            netD, optG, optD, datasets)
        self.latent_shape = latent_shape
        self.fixed_input = torch.randn((self.datasets.batch_size, *self.latent_shape)).to(self.device)

    def get_data_from_batch(self, batch_data: list, device: torch.device):
        ground_truth_tensor = batch_data[0]
        input_tensor = torch.randn((len(ground_truth_tensor), *self.latent_shape))
        return input_tensor, ground_truth_tensor

    def valid_epoch(self):
        super(GenerateGanTrainer, self).valid_epoch()
        self.netG.eval()
        # watching the variation during training by a fixed input
        with torch.no_grad():
            fake = self.netG(self.fixed_input).detach()
        self.watcher.image(fake, self.current_epoch, tag="Valid/Fixed_fake", grid_size=(4, 4), shuffle=False)
        # saving training processes to build a .gif.
        self.watcher.set_training_progress_images(fake, grid_size=(4, 4))
        self.netG.train()

    @abstractmethod
    def compute_d_loss(self):
        """ Rewrite this method to compute your own loss Discriminator."""
        pass

    @abstractmethod
    def compute_g_loss(self):
        """Rewrite this method to compute your own loss of Generator."""
        pass

    @abstractmethod
    def compute_valid(self):
        _, d_var_dic = self.compute_d_loss()
        _, g_var_dic = self.compute_g_loss()
        var_dic = dict(d_var_dic, **g_var_dic)
        return var_dic

    def test(self):
        self.input = torch.randn((16, *self.latent_shape)).to(self.device)
        self.netG.eval()
        with torch.no_grad():
            fake = self.netG(self.input).detach()
        self.watcher.image(fake, 
            self.current_epoch, 
            tag="Test/fake", 
            grid_size=(4,4), 
            shuffle=False)
        self.netG.train()

There is no training logic here any more, because the training loop is implemented in the SupGanTrainer class. So you only need to care about the training losses, the valid function and the test function.

  • compute_d_loss. Computes the loss for the discriminator. This function should return two values: the first is the main loss, on which backward is called; the second is a dict() whose keys and values will be shown on tensorboard without any further computation.
  • compute_g_loss. The same as compute_d_loss, but for the generator.
  • compute_valid. The valid function, called at the end of every epoch. Because this is validation, no loss is needed for backward, so the return value is only a dict(), which works like the one from compute_d_loss: its values are shown on tensorboard. The valid values are averaged over the whole valid dataset.
  • test. Called at the end of the whole training. Just override this method and do whatever you want.
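The contract above can be sketched without the framework. The snippet below uses hypothetical plain-float stand-ins for the loss values (the real methods return torch tensors and run on `self.netD`/`self.fake`); it only illustrates the (loss, dict) return shape and how compute_valid merges the two dicts.

```python
# Hypothetical stand-ins: plain floats instead of torch tensors,
# free functions instead of trainer methods.

def compute_d_loss():
    loss_d = 0.025                    # main loss, used for backward()
    var_dic = {"LS_LOSSD": loss_d}    # extra values shown on tensorboard
    return loss_d, var_dic

def compute_g_loss():
    loss_g = 0.3625
    var_dic = {"LS_LOSSG": loss_g}
    return loss_g, var_dic

def compute_valid():
    # Validation only collects the dicts; the loss values returned
    # first are discarded (no backward during validation).
    _, d_var_dic = compute_d_loss()
    _, g_var_dic = compute_g_loss()
    return dict(d_var_dic, **g_var_dic)  # merge both dicts into one

print(compute_valid())  # {'LS_LOSSD': 0.025, 'LS_LOSSG': 0.3625}
```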

Example: Generation task

Let me give you an example of applying this template. This example is in jdit.trainer.instances.fashingGenerateGan.
Here is the code:

class FashingGenerateGenerateGanTrainer(GenerateGanTrainer):
    d_turn = 1

    def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, dataset, latent_shape):
        super(FashingGenerateGenerateGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD,
                                                                dataset,
                                                                latent_shape=latent_shape)

        data, label = self.datasets.samples_train
        self.watcher.embedding(data, data, label, global_step=1)

    def compute_d_loss(self):
        d_fake = self.netD(self.fake.detach())
        d_real = self.netD(self.ground_truth)
        var_dic = dict()
        var_dic["LS_LOSSD"] = loss_d = 0.5 * (torch.mean((d_real - 1) ** 2) + torch.mean(d_fake ** 2))
        return loss_d, var_dic

    def compute_g_loss(self):
        d_fake = self.netD(self.fake)
        var_dic = dict()
        var_dic["LS_LOSSG"] = loss_g = 0.5 * torch.mean((d_fake - 1) ** 2)
        return loss_g, var_dic

    def compute_valid(self):
        _, d_var_dic = self.compute_d_loss()
        _, g_var_dic = self.compute_g_loss()
        var_dic = dict(d_var_dic, **g_var_dic)
        return var_dic

This final template is clear. You can see there are two loss functions, one for the discriminator and one for the generator, plus the valid function.

For the losses, it uses the least-squares loss, and the loss value is stored in var_dic, which means it will be shown on tensorboard.

The valid function just calls the two loss functions again and collects the values in var_dic. It shows the same values as the losses, but for validation these values are computed on the valid dataset and averaged.
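The least-squares loss arithmetic can be checked by hand. The sketch below evaluates the same two formulas as compute_d_loss and compute_g_loss above, but on made-up discriminator scores and with stdlib floats instead of torch tensors.

```python
from statistics import mean

# Made-up discriminator outputs, not real network scores.
d_real = [0.8, 0.9]   # D's scores on real images (target: 1)
d_fake = [0.2, 0.1]   # D's scores on fakes (target: 0 for D, 1 for G)

# Discriminator loss: push real scores toward 1, fake scores toward 0.
loss_d = 0.5 * (mean((x - 1) ** 2 for x in d_real)
                + mean(x ** 2 for x in d_fake))

# Generator loss: push fake scores toward 1.
loss_g = 0.5 * mean((x - 1) ** 2 for x in d_fake)

# loss_d ≈ 0.025, loss_g ≈ 0.3625
```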

Have a try

The following is the complete code. You can read it and copy it to run, or you can import the training function for a quick start.

Quick start

from jdit.trainer.instances.fashingGenerateGan import start_fashingGenerateGanTrainer
start_fashingGenerateGanTrainer()

Code:

# coding=utf-8
import torch
import torch.nn as nn
from jdit.trainer import GenerateGanTrainer
from jdit.model import Model
from jdit.optimizer import Optimizer
from jdit.dataset import FashionMNIST


class Discriminator(nn.Module):
    def __init__(self, input_nc=1, output_nc=1, depth=64):
        super(Discriminator, self).__init__()
        # 32 x 32
        self.layer1 = nn.Sequential(
                nn.utils.spectral_norm(nn.Conv2d(input_nc, depth * 1, kernel_size=3, stride=1, padding=1)),
                nn.LeakyReLU(0.1))
        # 32 x 32
        self.layer2 = nn.Sequential(
                nn.utils.spectral_norm(nn.Conv2d(depth * 1, depth * 2, kernel_size=3, stride=1, padding=1)),
                nn.LeakyReLU(0.1),
                nn.MaxPool2d(2, 2))
        # 16 x 16
        self.layer3 = nn.Sequential(
                nn.utils.spectral_norm(nn.Conv2d(depth * 2, depth * 4, kernel_size=3, stride=1, padding=1)),
                nn.LeakyReLU(0.1),
                nn.MaxPool2d(2, 2))
        # 8 x 8
        self.layer4 = nn.Sequential(nn.Conv2d(depth * 4, output_nc, kernel_size=8, stride=1, padding=0))
        # 1 x 1

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return out


class Generator(nn.Module):
    def __init__(self, input_nc=256, output_nc=1, depth=64):
        super(Generator, self).__init__()
        self.encoder = nn.Sequential(
                nn.ConvTranspose2d(input_nc, 4 * depth, 4, 1, 0),  # 256,1,1 =>  256,4,4
                nn.ReLU())
        self.decoder = nn.Sequential(
                nn.ConvTranspose2d(4 * depth, 4 * depth, 4, 2, 1),  # 256,4,4 =>  256,8,8
                nn.ReLU(),
                nn.BatchNorm2d(4 * depth),
                nn.ConvTranspose2d(4 * depth, 2 * depth, 4, 2, 1),  # 256,8,8 =>  128,16,16
                nn.ReLU(),
                nn.BatchNorm2d(2 * depth),
                nn.ConvTranspose2d(2 * depth, depth, 4, 2, 1),  # 128,16,16 =>  64,32,32
                nn.ReLU(),
                nn.BatchNorm2d(depth),
                nn.ConvTranspose2d(depth, output_nc, 3, 1, 1),  # 64,32,32 =>  1,32,32
                )

    def forward(self, input_data):
        out = self.encoder(input_data)
        out = self.decoder(out)
        return out


class FashingGenerateGenerateGanTrainer(GenerateGanTrainer):
    d_turn = 1

    def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, dataset, latent_shape):
        super(FashingGenerateGenerateGanTrainer, self).__init__(logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD,
                                                                dataset,
                                                                latent_shape=latent_shape)

        data, label = self.datasets.samples_train
        self.watcher.embedding(data, data, label, global_step=1)

    def compute_d_loss(self):
        d_fake = self.netD(self.fake.detach())
        d_real = self.netD(self.ground_truth)
        var_dic = dict()
        var_dic["LS_LOSSD"] = loss_d = 0.5 * (torch.mean((d_real - 1) ** 2) + torch.mean(d_fake ** 2))
        return loss_d, var_dic

    def compute_g_loss(self):
        d_fake = self.netD(self.fake)
        var_dic = dict()
        var_dic["LS_LOSSG"] = loss_g = 0.5 * torch.mean((d_fake - 1) ** 2)
        return loss_g, var_dic

    def compute_valid(self):
        _, d_var_dic = self.compute_d_loss()
        _, g_var_dic = self.compute_g_loss()
        var_dic = dict(d_var_dic, **g_var_dic)
        return var_dic


def start_fashingGenerateGanTrainer(gpus=(), nepochs=50, lr=1e-3, depth_G=32, depth_D=32, latent_shape=(256, 1, 1),
                                    run_type="train"):
    # set `gpus = ()` to use the cpu
    batch_size = 64
    image_channel = 1

    G_hprams = {"optimizer": "Adam", "lr_decay": 0.94,
                "decay_position": 2, "decay_type": "epoch",
                "lr": lr, "weight_decay": 2e-5,
                "betas": (0.9, 0.99)
                }
    D_hprams = {"optimizer": "RMSprop", "lr_decay": 0.94,
                "decay_position": 2, "decay_type": "epoch",
                "lr": lr, "weight_decay": 2e-5,
                "momentum": 0
                }

    # `latent_shape` is the input shape of the Generator
    print('===> Build dataset')
    mnist = FashionMNIST(batch_size=batch_size)
    torch.backends.cudnn.benchmark = True
    print('===> Building model')
    D_net = Discriminator(input_nc=image_channel, depth=depth_D)
    D = Model(D_net, gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=10)
    # -----------------------------------
    G_net = Generator(input_nc=latent_shape[0], output_nc=image_channel, depth=depth_G)
    G = Model(G_net, gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=10)
    print('===> Building optimizer')
    opt_D = Optimizer(D.parameters(), **D_hprams)
    opt_G = Optimizer(G.parameters(), **G_hprams)
    print('===> Training')
    print("using `tensorboard --logdir=log` to see learning curves and net structure."
          "training and valid_epoch data, configures info and checkpoint were save in `log` directory.")
    Trainer = FashingGenerateGenerateGanTrainer("log/fashion_generate", nepochs, gpus, G, D, opt_G, opt_D, mnist,
                                                latent_shape)
    if run_type == "train":
        Trainer.train()
    elif run_type == "debug":
        Trainer.debug()


if __name__ == '__main__':
    start_fashingGenerateGanTrainer()

More