目录
一、相关库的引入与解析
二、数据预处理-----1.图像格式转换 2.特征缩放,参数归一化
三、1、加载出MNIST数据集 2、创建数据加载器(包括训练集和测试集)
四、神经网络模型创建(类)----包括全连接层(属性特征)和向前传播(行为)
附:softmax()与log_softmax()
五、调用神经网络,选择损失函数,优化器(min损失函数)
六、模型训练与保存
附:MNIST数据集的形成以及前面的处理操作小总结:
七、模型的评估测试
八、预测自己提供的数字(使用模型)
九、结构拼接
#import
#import···as···更名操作
#from··(大库)··import··(子模块)··(from PIL import Image)
Image.open (打开图像文件)
Image.resize(调整图像大小)
Image.convert (图像格式转换(灰度图像))
(让各个参数的取值范围差不多,便于选择合适的步长,学习率)(对象MNIST)
#transforms.Compose() ------ 可以组合几个变换一起操作
#transforms.ToTensor() ------ 将原始的PILImage格式或者numpy.array格式转为张量,并把灰度范围从0-255变换到(0,1)之间
参数是无吗
#transform.Normalize(mean所处理数据的原始平均值, std所处理数据的原始标准偏差) -----Z-score归一化,把(0,1)变换到(-1,1),经该处理后数据会变为均值为 0,标准差为 1的标准化像素值,每个值x被转换为(x - mean) / std,接下来将用于处理MNIST数据集,使数据在0附近居中,使其具有指定的平均值和标准偏差,确保数据分布会更加符合标准正态分布,可改善训练过程和收敛性,更适合于神经网络的训练。
·均值为 0,标准差为 1意义: 零中心化: 将数据的均值移动到零,这样可以更好地匹配神经网络的激活函数(如 ReLU)的工作区间。 单位标准差: 将数据缩放到标准差为 1,使得数据具有一致的尺度,防止某些特征对训练过程的影响过大。
用于在训练神经网络时迭代数据。
结构:MNIST(参数一:指定存储路径,参数二:加载的是否是训练集,参数三:对图像的转换操作,参数四:若数据集不存在是否从互联网下载) 参数三:应用于每个图像的转换函数。常用的转换包括将图像转换为 PyTorch 张量 ()和标准化图像()
结构:参数一:参数二:提高模型的泛化能力参数三:
DataLoader
创建一个数据加载器
(函数应该返回的)
二维图像张量
(矩阵)
x=x.view(-1,28*28)
二维:
将输入张量 重新变形为大小为 的张量。 表示自动计算维度大小, 表示将 28x28 的图像扁平化为 784 维。
调整张量的形状,不改变其数据
。 表示自动计算维度的大小,使总元素数量保持不变。这里 会自动计算为 。
(batch_size, 1
(通道数)
, 28, 28)(batch_size, 784)
!!!!!!!在模型的前向传播过程中,当输入的特征向量经过 self.layer
(20,10)后,
输出是一个二维张量,其形状为 :
实例
基本的优化算法,每次迭代时使用当前批次数据的梯度更新模型参数。计算相对简单,适合大规模模型。可以实现更精细的控制。易陷入局部最优点,特别是对于高度非凸的损失函数。需要
手动调节学习率和学习率衰减策略。
Adam 是一种
自适应学习率的优化算法,结合了动量梯度下降和 RMSProp 算法。可能需要调整默认参数(如学习率)以获得最佳性能。
enumerate: 内置函数
使用enumerate 遍历train_loader
在遍历迭代对象时,
同时获取
当前元素的索引
和值
。返回的结果是一个包含索引和数据的元组
。
。
创建数据集:
创建数据加载器:
遍历数据加载器:
这里的output是十个值中的一个。
probability,predict=torch.max(output.data,dim=1)
: 返回一个
元组
,第一个为最大概率值,第二个为最大值的下标。
((
))
将图像调整到 28x28 像素,与 MNIST 数据集的标准图像尺寸匹配。image = image.view(1, 1, 28, 28),图像张量调整为适合输入到 CNN 模型中的形状。 表示批量大小(batch size),即一次输入一张图片。 表示通道数(channel),因为手写数字图像是灰度图像,所以通道数为 1。 表示图像的宽和高。
probability, predict = torch.max(output.data, dim=1) #同前
plt.title() 用于设置 Matplotlib 图的标题。标题将显示手写数字图像的预测值。标题使用的字体将为“SimHei”,显示中文字符。.format是 Python 字符串格式化的方法之一,用于将变量的值插入到字符串中的指定位置
( image.view() image.squeeze())!!!!!!!!!squeeze()删除一个维度
if __name__ == '__main__':的作用
一个python文件有两种使用方法,第一作为脚本直接执行,第二是 import 到其他的 python 脚本中被调用。 if __name__ == 'main': 下的代码在第一种情况下会被执行,而 import 到其他脚本中是不会被执行。
版权声明:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若内容造成侵权、违法违规、事实不符,请将相关资料发送至xkadmin@xkablog.com进行投诉反馈,一经查实,立即处理!
转载请注明出处,原文链接:https://www.xkablog.com/cjjbc/24045.html