1 BasicSR训练配置文件各配置项详细说明

本文内容来源于BasicSR-docs ,这里只做摘抄,这个仓库对BasicSR目录层次,代码结构、文件配置的详细说明文档,如果你正在学习和了解BasicSR,我推荐可以先详细阅读这个文档。

1.1 通用配置

# general settings - 这块为通用设置
name: 001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb # 实验名称, 若实验名字中有debug字样, 则会进入debug模式
model_type: SRModel # 使用的 model 类型,模型定义可为你自定义的模型,或者在BasicSR的models目录中
scale: 4 # 输出比输入的倍数, 在SR中是放大倍数; 若有些任务没有这个配置, 则写1
num_gpu: 1 # 指定使用的 GPU 卡数
manual_seed: 0 # 指定随机种子

注意

  • num_gpu:0 表示使用CPU,auto表示自动从可用GPU块数推断

1.2 dataset

# dataset and data loader settings
datasets: # 这块是 dataset 的配置
    train: # 训练 dataset 的配置
        name: DIV2K # 自定义的数据集名称
        type: PairedImageDataset # 读取数据的 Dataset 类
        dataroot_gt: datasets/DF2K/DIV2K_train_HR_sub # GT (Ground-Truth) 图像的文件夹路径
        dataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub # LQ (Low-Quality) 输入图像的文件夹路径
        meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt # 预先生成的 meta_info 文件
        # (for lmdb)
        # dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb
        # dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
        filename_tmpl: '{}' # 文件名字模板, 一般LQ文件会有类似 '_x4' 这样的文件后缀, 这个就是来处理GT和LQ文件后缀不匹配的问题的
        io_backend: # IO 读取的 backend
            type: disk # disk 表示直接从硬盘读取
            # (for lmdb)
            # type: lmdb

        gt_size: 128 # 训练阶段裁剪 (crop) 的GT图像的尺寸大小,即训练的 label 大小
        use_hflip: true # 是否开启水平方向图像增强 (随机水平翻转图像)
        use_rot: true # 是否开启旋转图像增强 (随机旋转图像)

        # data loader - 下面这块是 data loader 的设置
        num_worker_per_gpu: 6 # 每一个 GPU 的 data loader 读取进程数目
        batch_size_per_gpu: 16 # 每块 GPU 上的 batch size
        dataset_enlarge_ratio: 100 # # 放大 dataset 的长度倍数 (默认为1)。可以扩大一个 epoch 所需 iterations
        prefetch_mode: ~ # 预先读取数据的方式

    val: # 第一个validation 数据集的设置
        name: Set5 # 数据集名称
        type: PairedImageDataset # 数据集的类型
        dataroot_gt: datasets/Set5/GTmod12
        dataroot_lq: datasets/Set5/LRbicx4
        io_backend: # IO 读取的 backend
            type: disk # disk 表示直接从硬盘读取

    val_2: # 第二个validation 数据集的设置
        name: Set14 # 数据集名称
        type: PairedImageDataset # 数据集的类型
        dataroot_gt: datasets/Set14/GTmod12
        dataroot_lq: datasets/Set14/LRbicx4
        io_backend: # IO 读取的 backend
            type: disk # disk 表示直接从硬盘读取

注意的配置项:

  • dataset_enlarge_ratio:它代表了手工扩大 dataset 的倍率。例如,如果训练数据集有15张图,设置 dataset_enlarge_ratio 为100,那么程序会重复读取这些图片100次,这样一个 epoch 下来,便会读取1500张图 (事实上是重复读的)。这个方法经常用来加速 data loader, 因为在有的机器上,一个 epoch 结束,会重启进程,导致拖慢训练
  • prefetch_mode:默认为None,即~。cpu表示使用CPU prefetcher。cuda表示使用cuda prefetcher,它会占用GPU显存,在gpu模式下,一定需要设置pin_memory=True
  • 上述配置文件中,使用了两个验证数据集validation sets,通过关键字val,val_2进行区分,如果需要更多的validation sets,可以继续添加val_3、val_4 ...

1.3 网络结构配置

# network structures - 网络结构的设置
network_g:
  type: RRDBNet # 网络结构 (Architecture) 的类型
  num_in_ch: 3 # 模型输入的图像通道数
  num_out_ch: 3 # 模型输出的图像通道数
  num_feat: 64 # 模型内部的 feature map 通道数
  num_block: 23 # 模型内部基础模块的堆叠数
  num_grow_ch: 32

network_d:
  type: UNetDiscriminatorSN # 网络结构 (Architecture) 的类型
  num_in_ch: 3 # 模型输入的图像通道数
  num_feat: 64 # 模型内部的 feature map 通道数
  skip_connection: True

1.4 模型路径配置

# path 配置
path: # 以下为路径和与训练模型、重启训练的设置
    pretrain_network_g: ~ # 预训练模型的路径, 需要以 pth 结尾的模型
    param_key_g: params # 读取的预训练的参数 key。若需要使用 EMA 模型,需要改成params_ema
    strict_load_g: true # 是否严格地根据参数名称一一对应 load 模型参数。如果选择false,那么模型对于找不到的参数,会随机初始化;如果选择 true,假如存在不对应的参数,会报错提示
    resume_state: ~ # # 重启训练的 state 路径, 在experiments/exp_name/training_states 目录下

注意配置信息:

  • resume_state设置后, 会覆盖 pretrain_network_g 的设定

1.5 训练策略相关配置

# training settings
train: # 这块是训练策略相关的配置
    ema_decay: 0.999 # EMA 更新权重
    optim_g: # 这块是优化器的配置
        type: Adam # 选择优化器类型,例如 Adam
        # 以下属性是灵活的, 根据不同优化器有不同的设置
        lr: !!float 2e-4 # 初始学习率
        weight_decay: 0 # 权重衰退参数
        betas: [0.9, 0.99] # Adam 优化器的 beta1 和 beta2
    scheduler: # 这块是学习率调度器的配置
        type: CosineAnnealingRestartLR # 选择学习率更新策略
        # 以下属性是灵活的, 根据学习率 Scheduler 的不同有不同的设置
        periods: [250000, 250000, 250000, 250000] # Cosine Annealing 的更新周期
        restart_weights: [1, 1, 1, 1] # Cosine Annealing 每次 Restart 的权重
        eta_min: !!float 1e-7 # 学习率衰退到的最小值
    total_iter: 1000000 # 总共进行的训练迭代次数
    warmup_iter: -1 # warm up 的迭代次数, 如是-1, 表示没有 warm up

    # losses - 这块是损失函数的设置
    pixel_opt: # loss 名字,这里表示 pixel-wise loss 的 options
        type: L1Loss # 选择 loss 函数,例如 L1Loss
        # 以下属性是灵活的, 根据不同损失函数有不同的设置
        loss_weight: 1.0 # 指定 loss 的权重
        reduction: mean # loss reduction 方式

注意配置信息:

  • 常见的优化器(optimizer)的定义可以在 models/base_model.py 文件中的get_optimizer函数中找到。
  • 常见的损失函数可以在losses目录中定义

1.6 验证相关配置

# validation settings
val: # 这块是 validation 的配置
    val_freq: !!float 5e3 # validation 频率, 每隔 5000 iterations 做一次validation
    save_img: false # 否需要在 validation 的时候保存图片

    metrics: # 这块是 validation 中使用的指标的配置
        psnr: # metric 名字, 这个名字可以是任意的
            type: calculate_psnr # 选择指标类型
            # 以下属性是灵活的, 根据不同 metric 有不同的设置
            crop_border: 4 # 计算指标时 crop 图像边界像素范围 (不纳入计算范围)
            test_y_channel: false # 是否转成在 Y(CbCr) 空间上计算
            better: higher # 该指标是越高越好,还是越低越好。选择 higher或者lower,默认为 higher
        niqe: # 这是在 validation 中使用的另外一个指标
            type: calculate_niqe
            crop_border: 4
            better: lower # the lower, the better

注意配置信息:

  • 验证指标可在 basicsr/metrics 目录中进行定义
  • 支持在 validation 时使用多个指标,只需要在配置文件中添加配置,比如上面的 psnr 和 niqe

1.7 训练日志相关配置

# logging settings
logger: # 这块是 logging 的配置
    print_freq: 100 # 多少次迭代打印一次训练信息
    save_checkpoint_freq: !!float 5e3 # 多少次迭代保存一次模型权重和训练状态
    use_tb_logger: true # 是否使用 tensorboard logger
    wandb: # 是否使用 wandb logger
        project: ~ # wandb 的 project名字。 默认是 None, 即不使用 wandb
        resume_id: ~ # 如果是 resume, 可以输入上次的 wandb id, 则 log 可以接起来

1.8 分布式训练配置

# dist training settings
dist_params: # distributed training 的设置, 目前只在 Slurm 训练下才需要
    backend: nccl
    port: 29500

2 BasicSR测试配置文件各配置项详细说明

# general settings
name: 001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb # 实验名称
model_type: SRModel # 使用的 model 类型
scale: 4 # 输出比输入的倍数, 在SR中是放大倍数; 若有些任务没有这个配置, 则写1
num_gpu: 1 # 测试卡数
manual_seed: 0 # 指定随机种子

# test dataset settings
datasets:
    test_1: # 测试数据集的设置, 后缀1表示第一个测试集
        name: Set5 # 数据集的名称
        type: PairedImageDataset # 读取数据的 Dataset 类
        # GT 和输入 LQ 的根目录
        dataroot_gt: datasets/Set5/GTmod12
        dataroot_lq: datasets/Set5/LRbicx4
        io_backend: # IO 读取的 backend
            type: disk # disk 表示直接从硬盘读取
    test_2: # 测试数据集的设置, 后缀2表示第二个测试集
        name: Set14
        type: PairedImageDataset
        dataroot_gt: datasets/Set14/GTmod12
        dataroot_lq: datasets/Set14/LRbicx4
        io_backend:
            type: disk

# network structures - 网络结构的设置
network_g: # 网络 g 的设置
    type: MSRResNet # 网络结构 (Architecture) 的类型
    num_in_ch: 3
    num_out_ch: 3
    num_feat: 64
    num_block: 16
    upscale: 4

# path
path:
    pretrain_network_g: experiments/pretrained_models/net_g_1000000.pth # 预训练模型的路径, 需要以 pth 结尾的模型
    param_key_g: params # 读取的预训练的参数 key。若需要使用 EMA 模型,需要改成params_ema
    strict_load_g: true # 加载预训练模型时, 是否需要网络参数的名称严格对应

# validation settings - 以下为Validation (也是测试)的设置
val:
    save_img: true # 是否需要在测试的时候保存图片
    suffix: ~ # 对保存的图片添加后缀,如果是 None, 则使用 exp name

    metrics: # 测试时候使用的 metric
        psnr: # metric 名字, 这个名字可以是任意的
            type: calculate_psnr # 选择指标类型
            # 以下属性是灵活的, 根据不同 metric 有不同的设置
            crop_border: 4 # 计算指标时 crop 图像边界像素范围 (不纳入计算范围)
            test_y_channel: false # 是否转成在 Y(CbCr) 空间上计算
            better: higher # the higher, the better. Default: higher
        ssim:
            type: calculate_ssim
            crop_border: 4
            test_y_channel: false
            better: higher

参考链接