模型保存

模型的保存非常简单。

    # 保存模型的结构(不含参数),module
    torch.save(network, "model_path_framework/")
    # 保存模型的参数(不含结构)
    torch.save(network.state_dict(), "model_weights_path/")
    
    # 保存的模型类型为DataParallel
    network_parallel = DataParallel(network, gpu_ids)
    torch.save(network_parallel, "modelpara_weights_path/")

通常network的类型为pytorch中的标准的module类型。
但是,如果传入的module是DataParallel,那么在读取模型的时候得到的也是DataParallel类型。

模型加载

如果模型加载时和保存的所处的环境是一致的,同为GPUs或同为CPU。那么在读取时没有什么问题。如下:

    torch.load("network_pat")

但是,如果模型保存和加载时的环境不同,例如使用GPU训练的模型保存后要在只有CPU的机器上运行,那么就需要使用参数map_location

    # 把所有的张量加载到CPU中     GPU ==> CPU
    torch.load("network_path", map_location=lambda storage, loc: storage)
    # 把所有的张量加载到GPU 1中   
    torch.load("network_path", map_location=lambda storage, loc: storage.cuda(1))
    # 把张量从GPU 1 移动到 GPU 0
    torch.load("network_path", map_location={'cuda:1':'cuda:0'})

如果保存的模型是在多个GPU上并行计算的DataParallel类型。
在读取模型时,需要先取出DataParallel的一般module模型

    # 保存的模型类型为DataParallel
    network_parallel = DataParallel(network, gpu_ids)
    torch.save(network_parallel, "modelpara_weights_path")
    # 读取的模型类型为DataParallel, 现将其读取到CPU上
    model = torch.load("modelpara_weights_path", map_location=lambda storage, loc: storage)
    # 取出DataParallel中的一般module类型
    model_module = model.module

如图

DataParallel
module

到此模型结构就加载完了。但是,模型参数加载时依然会出现问题。

GPU并行时保存的模型参数如下图。

    weight = torch.load(model_weights_path, map_location=lambda storage, loc: storage)
DataParallel Weights

参数字典中多了一个"module"标识。这是因为该参数是DataParalle.module下的参数,相比直接保存的module多了一级标签,因此需要去掉“module.”(不要忘记还有点“.” 一共7个字符)

    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in weight.items():
        name = k[7:]  # 正好去除头7个字符 "module."
        new_state_dict[name] = v
    # 加载参数
    model.load_state_dict(new_state_dict)

这样就可以将在多GPU上并行训练并保存的模型,在CPU上加载并使用了。

总结

一般来说,不建议保存模型结构。

模型结构保存后的文件几乎和参数保存文件一样大。模型结构在用的时候使用代码进行定义,随后直接使用参数初始化即可。这样可以节省很多存储空间。