1 BasicSR训练配置文件各配置项详细说明
本文内容来源于BasicSR-docs ,这里只做摘抄,这个仓库对BasicSR目录层次,代码结构、文件配置的详细说明文档,如果你正在学习和了解BasicSR,我推荐可以先详细阅读这个文档。
1.1 通用配置
# general settings - 这块为通用设置
name: 001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb # 实验名称, 若实验名字中有debug字样, 则会进入debug模式
model_type: SRModel # 使用的 model 类型,模型定义可为你自定义的模型,或者在BasicSR的models目录中
scale: 4 # 输出比输入的倍数, 在SR中是放大倍数; 若有些任务没有这个配置, 则写1
num_gpu: 1 # 指定使用的 GPU 卡数
manual_seed: 0 # 指定随机种子
注意
- num_gpu:0 表示使用CPU,auto表示自动从可用GPU块数推断
1.2 dataset
# dataset and data loader settings
datasets: # 这块是 dataset 的配置
train: # 训练 dataset 的配置
name: DIV2K # 自定义的数据集名称
type: PairedImageDataset # 读取数据的 Dataset 类
dataroot_gt: datasets/DF2K/DIV2K_train_HR_sub # GT (Ground-Truth) 图像的文件夹路径
dataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub # LQ (Low-Quality) 输入图像的文件夹路径
meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt # 预先生成的 meta_info 文件
# (for lmdb)
# dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb
# dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
filename_tmpl: '{}' # 文件名字模板, 一般LQ文件会有类似 '_x4' 这样的文件后缀, 这个就是来处理GT和LQ文件后缀不匹配的问题的
io_backend: # IO 读取的 backend
type: disk # disk 表示直接从硬盘读取
# (for lmdb)
# type: lmdb
gt_size: 128 # 训练阶段裁剪 (crop) 的GT图像的尺寸大小,即训练的 label 大小
use_hflip: true # 是否开启水平方向图像增强 (随机水平翻转图像)
use_rot: true # 是否开启旋转图像增强 (随机旋转图像)
# data loader - 下面这块是 data loader 的设置
num_worker_per_gpu: 6 # 每一个 GPU 的 data loader 读取进程数目
batch_size_per_gpu: 16 # 每块 GPU 上的 batch size
dataset_enlarge_ratio: 100 # # 放大 dataset 的长度倍数 (默认为1)。可以扩大一个 epoch 所需 iterations
prefetch_mode: ~ # 预先读取数据的方式
val: # 第一个validation 数据集的设置
name: Set5 # 数据集名称
type: PairedImageDataset # 数据集的类型
dataroot_gt: datasets/Set5/GTmod12
dataroot_lq: datasets/Set5/LRbicx4
io_backend: # IO 读取的 backend
type: disk # disk 表示直接从硬盘读取
val_2: # 第二个validation 数据集的设置
name: Set14 # 数据集名称
type: PairedImageDataset # 数据集的类型
dataroot_gt: datasets/Set14/GTmod12
dataroot_lq: datasets/Set14/LRbicx4
io_backend: # IO 读取的 backend
type: disk # disk 表示直接从硬盘读取
注意的配置项:
- dataset_enlarge_ratio:它代表了手工扩大 dataset 的倍率。例如,如果训练数据集有15张图,设置 dataset_enlarge_ratio 为100,那么程序会重复读取这些图片100次,这样一个 epoch 下来,便会读取1500张图 (事实上是重复读的)。这个方法经常用来加速 data loader, 因为在有的机器上,一个 epoch 结束,会重启进程,导致拖慢训练
- prefetch_mode:默认为None,即~。cpu表示使用CPU prefetcher。cuda表示使用cuda prefetcher,它会占用GPU显存,在gpu模式下,一定需要设置pin_memory=True
- 上述配置文件中,使用了两个验证数据集validation sets,通过关键字val,val_2进行区分,如果需要更多的validation sets,可以继续添加val_3、val_4 ...
1.3 网络结构配置
# network structures - 网络结构的设置
network_g:
type: RRDBNet # 网络结构 (Architecture) 的类型
num_in_ch: 3 # 模型输入的图像通道数
num_out_ch: 3 # 模型输出的图像通道数
num_feat: 64 # 模型内部的 feature map 通道数
num_block: 23 # 模型内部基础模块的堆叠数
num_grow_ch: 32
network_d:
type: UNetDiscriminatorSN # 网络结构 (Architecture) 的类型
num_in_ch: 3 # 模型输入的图像通道数
num_feat: 64 # 模型内部的 feature map 通道数
skip_connection: True
1.4 模型路径配置
# path 配置
path: # 以下为路径和与训练模型、重启训练的设置
pretrain_network_g: ~ # 预训练模型的路径, 需要以 pth 结尾的模型
param_key_g: params # 读取的预训练的参数 key。若需要使用 EMA 模型,需要改成params_ema
strict_load_g: true # 是否严格地根据参数名称一一对应 load 模型参数。如果选择false,那么模型对于找不到的参数,会随机初始化;如果选择 true,假如存在不对应的参数,会报错提示
resume_state: ~ # # 重启训练的 state 路径, 在experiments/exp_name/training_states 目录下
注意配置信息:
- resume_state设置后, 会覆盖 pretrain_network_g 的设定
1.5 训练策略相关配置
# training settings
train: # 这块是训练策略相关的配置
ema_decay: 0.999 # EMA 更新权重
optim_g: # 这块是优化器的配置
type: Adam # 选择优化器类型,例如 Adam
# 以下属性是灵活的, 根据不同优化器有不同的设置
lr: !!float 2e-4 # 初始学习率
weight_decay: 0 # 权重衰退参数
betas: [0.9, 0.99] # Adam 优化器的 beta1 和 beta2
scheduler: # 这块是学习率调度器的配置
type: CosineAnnealingRestartLR # 选择学习率更新策略
# 以下属性是灵活的, 根据学习率 Scheduler 的不同有不同的设置
periods: [250000, 250000, 250000, 250000] # Cosine Annealing 的更新周期
restart_weights: [1, 1, 1, 1] # Cosine Annealing 每次 Restart 的权重
eta_min: !!float 1e-7 # 学习率衰退到的最小值
total_iter: 1000000 # 总共进行的训练迭代次数
warmup_iter: -1 # warm up 的迭代次数, 如是-1, 表示没有 warm up
# losses - 这块是损失函数的设置
pixel_opt: # loss 名字,这里表示 pixel-wise loss 的 options
type: L1Loss # 选择 loss 函数,例如 L1Loss
# 以下属性是灵活的, 根据不同损失函数有不同的设置
loss_weight: 1.0 # 指定 loss 的权重
reduction: mean # loss reduction 方式
注意配置信息:
- 常见的优化器(optimizer)的定义可以在 models/base_model.py 文件中的get_optimizer函数中找到。
- 常见的损失函数可以在losses目录中定义
1.6 验证相关配置
# validation settings
val: # 这块是 validation 的配置
val_freq: !!float 5e3 # validation 频率, 每隔 5000 iterations 做一次validation
save_img: false # 否需要在 validation 的时候保存图片
metrics: # 这块是 validation 中使用的指标的配置
psnr: # metric 名字, 这个名字可以是任意的
type: calculate_psnr # 选择指标类型
# 以下属性是灵活的, 根据不同 metric 有不同的设置
crop_border: 4 # 计算指标时 crop 图像边界像素范围 (不纳入计算范围)
test_y_channel: false # 是否转成在 Y(CbCr) 空间上计算
better: higher # 该指标是越高越好,还是越低越好。选择 higher或者lower,默认为 higher
niqe: # 这是在 validation 中使用的另外一个指标
type: calculate_niqe
crop_border: 4
better: lower # the lower, the better
注意配置信息:
- 验证指标可在 basicsr/metrics 目录中进行定义
- 支持在 validation 时使用多个指标,只需要在配置文件中添加配置,比如上面的 psnr 和 niqe
1.7 训练日志相关配置
# logging settings
logger: # 这块是 logging 的配置
print_freq: 100 # 多少次迭代打印一次训练信息
save_checkpoint_freq: !!float 5e3 # 多少次迭代保存一次模型权重和训练状态
use_tb_logger: true # 是否使用 tensorboard logger
wandb: # 是否使用 wandb logger
project: ~ # wandb 的 project名字。 默认是 None, 即不使用 wandb
resume_id: ~ # 如果是 resume, 可以输入上次的 wandb id, 则 log 可以接起来
1.8 分布式训练配置
# dist training settings
dist_params: # distributed training 的设置, 目前只在 Slurm 训练下才需要
backend: nccl
port: 29500
2 BasicSR测试配置文件各配置项详细说明
# general settings
name: 001_MSRResNet_x4_f64b16_DIV2K_1000k_B16G1_wandb # 实验名称
model_type: SRModel # 使用的 model 类型
scale: 4 # 输出比输入的倍数, 在SR中是放大倍数; 若有些任务没有这个配置, 则写1
num_gpu: 1 # 测试卡数
manual_seed: 0 # 指定随机种子
# test dataset settings
datasets:
test_1: # 测试数据集的设置, 后缀1表示第一个测试集
name: Set5 # 数据集的名称
type: PairedImageDataset # 读取数据的 Dataset 类
# GT 和输入 LQ 的根目录
dataroot_gt: datasets/Set5/GTmod12
dataroot_lq: datasets/Set5/LRbicx4
io_backend: # IO 读取的 backend
type: disk # disk 表示直接从硬盘读取
test_2: # 测试数据集的设置, 后缀2表示第二个测试集
name: Set14
type: PairedImageDataset
dataroot_gt: datasets/Set14/GTmod12
dataroot_lq: datasets/Set14/LRbicx4
io_backend:
type: disk
# network structures - 网络结构的设置
network_g: # 网络 g 的设置
type: MSRResNet # 网络结构 (Architecture) 的类型
num_in_ch: 3
num_out_ch: 3
num_feat: 64
num_block: 16
upscale: 4
# path
path:
pretrain_network_g: experiments/pretrained_models/net_g_1000000.pth # 预训练模型的路径, 需要以 pth 结尾的模型
param_key_g: params # 读取的预训练的参数 key。若需要使用 EMA 模型,需要改成params_ema
strict_load_g: true # 加载预训练模型时, 是否需要网络参数的名称严格对应
# validation settings - 以下为Validation (也是测试)的设置
val:
save_img: true # 是否需要在测试的时候保存图片
suffix: ~ # 对保存的图片添加后缀,如果是 None, 则使用 exp name
metrics: # 测试时候使用的 metric
psnr: # metric 名字, 这个名字可以是任意的
type: calculate_psnr # 选择指标类型
# 以下属性是灵活的, 根据不同 metric 有不同的设置
crop_border: 4 # 计算指标时 crop 图像边界像素范围 (不纳入计算范围)
test_y_channel: false # 是否转成在 Y(CbCr) 空间上计算
better: higher # the higher, the better. Default: higher
ssim:
type: calculate_ssim
crop_border: 4
test_y_channel: false
better: higher
参考链接
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:BasicSR 模型训练和模型测试配置文件配置信息详细说明
原文链接:https://www.stubbornhuang.com/3145/
发布于:2025年03月25日 10:43:55
修改于:2025年03月25日 10:44:39
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。
评论
56