Pytorch – torch.cat参数详解与使用
1 torch.cat参数详解与使用 1.1 torch.cat 1.函数形式 torch.cat(tensors, dim=0, *, out=None) → Tensor 2.函数功能 在指定的维度串联指定Tensor序列,所有Tensor都必须具有相同的形状(连接维度除外),或者Tensor为…
- Pytorch
- 2022-07-25
Pytorch – torch.chunk参数详解与使用
1 torch.chunk参数详解与使用 1.1 torch.chunk 1.函数形式 torch.chunk(input, chunks, dim=0) → List of Tensors 2.函数功能 将输入Tensor拆分为特定数量的块。 如果给定维度dim上的Tensor大小不能够被整除,则…
- Pytorch
- 2022-07-22
Pytorch – pad_sequence、pack_padded_sequence、pack_sequence、pad_packed_sequence参数详解与使用
当采用 RNN 训练序列样本数据时,会面临序列样本数据长短不一的情况。比如做 NLP 任务、语音处理任务时,每个句子或语音序列的长度经常是不相同。难道要一个序列一个序列的喂给网络进行训练吗?这显然是行不通的。 为了更高效的进行 batch 处理,就需要对样本序列进行填充,保证各个样本长度相同,在 P…
- Pytorch
- 2022-07-21
Pytorch – nn.Transformer、nn.TransformerEncoderLayer、nn.TransformerEncoder、nn.TransformerDecoder、nn.TransformerDecoder参数详解
1 nn.Transformer 1.1 nn.Transformer定义 1.函数形式 torch.nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=20…
- Pytorch
- 2022-07-13
Pytorch – .to()和.cuda()的区别
1 Pytorch中.to()和.cuda()的区别 如果需要指定的设备是GPU则.to()和.cuda()没有区别,如果设备是cpu,则不能使用.cuda()。也就是说.to()既可以指定CPU也可以指定GPU,而.cuda()只能指定GPU。 1.1 .cuda() 1.单GPU os.envi…
- Pytorch
- 2022-07-11
Pytorch – 模型保存与加载以及如何在已保存的模型的基础上继续训练模型
1 模型的保存和加载 1.1 保存与加载整个模型 保存网络的所有模块,代码量少。 但是这种方法缺点是保存模型的时候,序列化的数据被绑定到了特定的类和确切的目录。 这是因为pickle不保存模型类本身,而是保存这个类的路径, 并且在加载的时候会使用。因此, 当在其他项目里使用或者重构的时候,加载模型的…
- Pytorch
- 2022-07-09
深度学习 – 我的深度学习项目代码文件组织结构
1 我的深度学习项目代码文件组织结构 一般来说,深度学习项目需要包含以下内容: 数据集预处理与加载 深度学习模型定义 模型训练 模型推理 根据以上的功能描述,我的深度学习项目代码文件组织结构如下: ├─bin ├─configs ├─data_loader ├─data_preprocess ├─m…
- Pytorch
- 2022-07-02
Pytorch – 为什么要设置随机数种子?
1 Pytorch的随机种子 最近在看一些开源的Pytorch项目时,几乎每一个项目都会设置随机数种子,比如下面这种 class RandomState(object): def __init__(self, seed): torch.set_num_threads(1) torch.backend…
- Pytorch
- 2022-07-01
Pytorch – torch.nn.Conv1d参数详解与使用
1 torch.nn.Conv1d torch.nn.Conv1d主要是对一维输入Tensor应用一维卷积。 如果一维卷积输入为(N,C_{in},L),输出为(N,C_{out},L_{out}),那么这两者的关系可描述为 \operatorname{out}\left(N_{i}, C_{\te…
- Pytorch
- 2022-06-28
Pytorch – 使用torchsummary/torchsummaryX/torchinfo库打印模型结构、输出维度和参数信息
1 torchsummary/torchsummaryX torchsummary Github地址:https://github.com/sksq96/pytorch-summary torchsummaryX Github地址:https://github.com/nmhkahn/torchsu…
- Pytorch
- 2022-06-27
Pytorch – 内置的LSTM网络torch.nn.LSTM参数详解与使用示例
1 torch.nn.LSTM torch.nn.LSTM是pytorch内置的LSTM模块。 对于torch.nn.LSTM输入序列的每一个元素,都使用以下经典的LSTM计算过程: \begin{array}{c} i_{t}=\sigma\left(W_{i i} x_{t}+b_{i i}+W…
- Pytorch
- 2022-06-26
Pytorch – 内置的CTC损失函数torch.nn.CTCLoss参数详解与使用示例
CTC(Connectionist Temporal Classification)主要是处理不定长序列对齐问题,而CTCLoss主要是计算连续未分段的时间序列与目标序列之间的损失。CTCLoss对输入与目标可能对齐的概率求和,产生一个相对于每个输入节点可微分的损失值。假设输入到目标的对应关系是“多…
- Pytorch
- 2022-06-21