1 LeNet5的MindSpore实现
MindSpore技术白皮书中LeNet5网络的MindSpore版本实现,与Pytorch和Tensorflow的版本相比可以让人更快的熟悉MindSpore的使用方式。
以下代码定义以及训练LeNet神经网络的过程。
# -*- coding: utf-8 -*-
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.network.optim import Momentum
from mindspore.train import Model
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
import mindspore.dataset as de
class LeNet5(nn.Cell):
"""
Lenet网络结构
"""
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
# 定义所需要的运算
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
# 使用定义好的运算构建前向网络
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
if __name__ == '__main__':
ds = de.MnistDataset(dataset_dir="./MNIST_Data")
ds = ds.batch(batch_size=64)
network = LeNet5()
loss = SoftmaxCrossEntropyWithLogits()
optimizer = nn.Momentum(network.trainable_params(),learning_rate=0.1, momentum=0.9)
model = Model(network, loss, optimizer)
model.train(epoch=10, train_dataset=ds)
在上述代码的
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.network.optim import Momentum
from mindspore.train import Model
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
import mindspore.dataset as de
的部分,导入了MindSpore的相关库和模块。
class LeNet5(nn.Cell):
def __init__(self):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2)
self.flatten = P.Flatten()
def construct(self,x):
x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
以上代码部分定义了LeNet5的网络结果。__init__
函数实例化了LeNet所用到的所有算子,construct
函数定义了LeNet的计算逻辑。
ds = de.MnistDataset(dataset_dir="./MNIST_Data")
ds = ds.batch(batch_size=64)
以上代码部分从Mnist数据集中读取数据,并生成一个迭代器ds用作训练的输入。
network = LeNet5()
将LeNet5类实例化为network。
loss = SoftmaxCrossEntropyWithLogits()
optimizer = nn.Momentum(network.trainable_params(),learning_rate=0.1, momentum=0.9)
model = Model(network, loss, optimizer)
使用SoftmaxCrossEntropyWithLogits
计算损失loss,并使用momentum优化参数,最后使用定义的损失函数loss和优化器optimizer创建模型。
model.train(epoch=10, train_dataset=ds)
最后使用epoch控制训练迭代次数,调用模型的训练方法,并在每个eval_step对模型进行评估。
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:MindSpore – LeNet5的MindSpore实现
原文链接:https://www.stubbornhuang.com/1979/
发布于:2022年02月23日 15:40:47
修改于:2023年06月26日 20:37:00
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。
评论
52