Framework Introduction

If you want to validate a new idea, such as a new model, what would you do?

  • Write code to implement your model
  • Set up a simple deep learning project?
  • Watch the training and validation curves?
  • Use multiple tasks and multiple GPUs?
  • Save model checkpoints?
  • Tune hyperparameters?
  • …

The core task is only the model itself, yet far more effort goes into the surrounding work needed to get it validated.

Jdit frees researchers from this tedious peripheral work so they can focus on the core task.

Counterexample: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/train.py


import time
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer

if __name__ == '__main__':
    opt = TrainOptions().parse()   # get training options
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)    # get the number of images in the dataset.
    print('The number of training images = %d' % dataset_size)

    model = create_model(opt)      # create a model given opt.model and other options
    model.setup(opt)               # regular setup: load and print networks; create schedulers
    visualizer = Visualizer(opt)   # create a visualizer that display/save images and plots
    total_iters = 0                # the total number of training iterations

    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):    # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        # epoch_start_time = time.time()  # timer for entire epoch
        # iter_data_time = time.time()    # timer for data loading per iteration
        # epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch

        for i, data in enumerate(dataset):  # inner loop within one epoch
            # iter_start_time = time.time()  # timer for computation per iteration
            # if total_iters % opt.print_freq == 0:
            #    t_data = iter_start_time - iter_data_time
            # visualizer.reset()
            # total_iters += opt.batch_size
            # epoch_iter += opt.batch_size
            model.set_input(data)         # unpack data from dataset and apply preprocessing
            model.optimize_parameters()   # calculate loss functions, get gradients, update network weights

            # if total_iters % opt.display_freq == 0:   # display images on visdom and save images to a HTML file
            #    save_result = total_iters % opt.update_html_freq == 0
            #    model.compute_visuals()
                # visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

            #if total_iters % opt.print_freq == 0:    # print training losses and save logging information to the disk
            #    losses = model.get_current_losses()
            #    t_comp = (time.time() - iter_start_time) / opt.batch_size
            #    visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
            #    if opt.display_id > 0:
            #        visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

            #if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
                #print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                #save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                #model.save_networks(save_suffix)

            #iter_data_time = time.time()
        #if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
            #print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
            #model.save_networks('latest')
            #model.save_networks(epoch)

        # print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
        model.update_learning_rate()

Jdit is built on PyTorch.
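For contrast with the script above, here is a condensed sketch of what the same kind of entry point looks like with Jdit. It mirrors the constructor signatures and attribute names used in the quick-start example further below; the tiny network and the hyperparameter values are placeholders, not anything defined by the library.

import torch.nn as nn
from jdit import Model
from jdit.optimizer import Optimizer
from jdit.dataset import FashionMNIST
from jdit.trainer.single.classification import ClassificationTrainer


class TinyNet(nn.Module):
    """Placeholder network: one conv layer mapping a 1x28x28 image to 10 logits."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv = nn.Conv2d(1, 10, 28)

    def forward(self, x):
        return self.conv(x).view(-1, 10)


class TinyTrainer(ClassificationTrainer):
    """Only the losses are user code; the loop, curves and checkpoints live in the Trainer."""
    def compute_loss(self):
        loss = nn.CrossEntropyLoss()(self.output, self.ground_truth.squeeze().long())
        return loss, {"CEP": loss}

    def compute_valid(self):
        return {"CEP": nn.CrossEntropyLoss()(self.output, self.labels.squeeze().long())}


net = Model(TinyNet(), gpu_ids_abs=(), init_method="kaiming")           # placeholder hyperparameters
opt = Optimizer(net.parameters(), optimizer="Adam", lr=1e-3,
                lr_decay=0.9, decay_position=10, position_type="epoch")
data = FashionMNIST(batch_size=64)
TinyTrainer("log/tiny", 10, (), net, opt, data, 10).train()             # no hand-written epoch loop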

Module Overview

The main modules are Model, Trainer, Optimizer, Dataset, and Loss.

Together they cover the minimal functionality a deep learning task needs; the Loss is implemented entirely by the user.

Model, Optimizer, and Dataset use Python's delegation mechanism, so these new modules behave the same way as the native PyTorch objects they wrap.
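The delegation itself is plain Python. The sketch below is only an illustration of the mechanism, not Jdit's actual implementation: a wrapper forwards any attribute it does not define to the wrapped nn.Module, so code written against the original object keeps working.

import torch.nn as nn


class DelegatingModel(object):
    """Illustration only: unknown attribute lookups fall through to the wrapped module."""

    def __init__(self, module: nn.Module):
        self.module = module

    def __getattr__(self, name):
        # __getattr__ is called only when `name` is not found on the wrapper itself,
        # so e.g. wrapper.parameters() resolves to the wrapped nn.Module's method.
        return getattr(self.module, name)

    def __call__(self, *args, **kwargs):
        return self.module(*args, **kwargs)


wrapper = DelegatingModel(nn.Linear(4, 2))
print(sum(p.numel() for p in wrapper.parameters()))  # 10 parameters, delegated to nn.Linear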

The Trainer provides the basic utilities for data visualization, saving, model checkpointing, and so on.

GitHub: https://github.com/dingguanglei/jdit

Quick start:

Installation (macOS, Linux, Windows):

pip install jdit

Other dependencies (an install command follows this list):

  • pytorch
  • tensorflow
  • tensorboard
  • tensorboardX
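If these are not already present, they can be installed from PyPI as well (the PyTorch package is published as `torch`; choose versions that match your environment):

pip install torch tensorflow tensorboard tensorboardX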

A complete example (Fashion-MNIST classification):

# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from jdit.trainer.single.classification import ClassificationTrainer
from jdit import Model
from jdit.optimizer import Optimizer
from jdit.dataset import FashionMNIST


class SimpleModel(nn.Module):
    def __init__(self, depth=64, num_class=10):
        super(SimpleModel, self).__init__()
        self.num_class = num_class
        self.layer1 = nn.Conv2d(1, depth, 3, 1, 1)
        self.layer2 = nn.Conv2d(depth, depth * 2, 4, 2, 1)
        self.layer3 = nn.Conv2d(depth * 2, depth * 4, 4, 2, 1)
        self.layer4 = nn.Conv2d(depth * 4, depth * 8, 4, 2, 1)
        self.layer5 = nn.Conv2d(depth * 8, num_class, 4, 1, 0)

    def forward(self, input):
        out = F.relu(self.layer1(input))
        out = F.relu(self.layer2(out))
        out = F.relu(self.layer3(out))
        out = F.relu(self.layer4(out))
        out = self.layer5(out)
        out = out.view(-1, self.num_class)
        return out


class FashingClassTrainer(ClassificationTrainer):
    def __init__(self, logdir, nepochs, gpu_ids, net, opt, datasets, num_class):
        super(FashingClassTrainer, self).__init__(logdir, nepochs, gpu_ids, net, opt, datasets, num_class)
        data, label = self.datasets.samples_train
        self.watcher.embedding(data, data, label, 1)

    def compute_loss(self):
        var_dic = {}
        var_dic["CEP"] = loss = nn.CrossEntropyLoss()(self.output, self.ground_truth.squeeze().long())

        _, predict = torch.max(self.output.detach(), 1)  # 0100=>1  0010=>2
        total = predict.size(0) * 1.0
        labels = self.ground_truth.squeeze().long()
        correct = predict.eq(labels).cpu().sum().float()
        acc = correct / total
        var_dic["ACC"] = acc
        return loss, var_dic

    def compute_valid(self):
        var_dic = {}
        var_dic["CEP"] = nn.CrossEntropyLoss()(self.output, self.labels.squeeze().long())

        _, predict = torch.max(self.output.detach(), 1)  # 0100=>1  0010=>2
        total = predict.size(0) * 1.0
        labels = self.labels.squeeze().long()
        correct = predict.eq(labels).cpu().sum().float()
        acc = correct / total
        var_dic["ACC"] = acc
        return var_dic


def start_fashingClassTrainer(gpus=(), nepochs=10, run_type="train"):
    """" An example of fashing-mnist classification
    """
    num_class = 10
    depth = 32
    batch_size = 4
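    # Optimizer hyperparameters; they are forwarded unchanged as keyword arguments
    # to jdit's Optimizer below (see `Optimizer(net.parameters(), **opt_hpm)`).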
    opt_hpm = {"optimizer": "Adam",
               "lr_decay": 0.94,
               "decay_position": 10,
               "position_type": "epoch",
               "lr_reset": {2: 5e-4, 3: 1e-3},
               "lr": 1e-4,
               "weight_decay": 2e-5,
               "betas": (0.9, 0.99)}

    print('===> Build dataset')
    mnist = FashionMNIST(batch_size=batch_size)
    # mnist.dataset_train = mnist.dataset_test
    torch.backends.cudnn.benchmark = True
    print('===> Building model')
    net = Model(SimpleModel(depth=depth), gpu_ids_abs=gpus, init_method="kaiming", check_point_pos=1)
    print('===> Building optimizer')
    opt = Optimizer(net.parameters(), **opt_hpm)
    print('===> Training')
    print("using `tensorboard --logdir=log` to see learning curves and net structure."
          "training and valid_epoch data, configures info and checkpoint were save in `log` directory.")
    trainer = FashingClassTrainer("log/fashion_classify", nepochs, gpus, net, opt, mnist, num_class)
    if run_type == "train":
        trainer.train()
    elif run_type == "debug":
        trainer.debug()



if __name__ == '__main__':
    start_fashingClassTrainer()

Alternatively, the packaged example can be run directly from a Python shell:

>>> import jdit
>>> jdit.trainer.instances.start_fashingClassTrainer()
===> Build dataset
use 8 thread!
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to datasets/fashion_data\FashionMNIST\raw\train-images-idx3-ubyte.gz
Extracting datasets/fashion_data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz
Processing...
Done!
===> Building model
SimpleModel Total number of parameters: 729866
SimpleModel model use CPU!
apply kaiming weight init!
===> Building optimizer
===> Training
Use `tensorboard --logdir=log` to see the learning curves and net structure. Training and validation data, configuration info, and checkpoints are saved in the `log` directory.
  0%|                                      | 0/10 [00:00<?, ?epoch/s]

Use TensorBoard to view the training data and curves:

tensorboard --logdir=log
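
Since the example trainer writes its logs to log/fashion_classify, TensorBoard can also be pointed at that subdirectory directly:

tensorboard --logdir=log/fashion_classify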