avatar
文章
23
标签
8
分类
0
首页
音乐
照片
友链
说说
关于
LogoLuckyMNIST手写数字识别总结(pytorch)
搜索
首页
音乐
照片
友链
说说
关于

MNIST手写数字识别总结(pytorch)

发表于2022-02-16|更新于2025-02-26
|总字数:1.3k|阅读时长:6分钟

此博客并不是教程,只是一个练习总结

代码汇总放在文末

1.首先导入所需要的库

1
2
3
4
5
6
7
8
9
10
11
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision

from matplotlib import pyplot as plt
import pandas as pd
import numpy as np

from Util import plot_image,pd_one_hot #辅助函数,在博客末尾附上

2.数据集

此数据集总共包含70K张图片,其中60K作为训练集,10K作为测试集。
更多消息可以查看官网官网链接:官网

3.加载数据

batch_size设置一次处理多少图片,此处设置为512张图片,这样并行处理可以cpu,gpu加快处理速度

1
batch_size = 512

加载训练集,测试集图片

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data' #数据集文件夹名
,train=True
,download=True #当电脑没此数据的时候会自动下载数据集
,transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor() #矩阵转化为张量
,torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
)
,batch_size=batch_size
,shuffle=True # 设置随机打散
)

test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data'
,train=False
,download=True
,transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
,torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
),
batch_size=batch_size
,shuffle=False)

4.数据可视化

只查看9张图片,可在辅助函数内修改为其它值

1
2
3
4
x, y = next(iter(train_loader))
print(x.shape, y.shape) # 查看数据集大小
plot_image(x, y, 'image sample')
torch.Size([512, 1, 28, 28]) torch.Size([512])

注:
512, 1, 28, 28:四维矩阵,512张图片,1个通道,大小为28*28
1个通道的意思为单色,若改为3则是RGB彩色

null

5.定义神经网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Net(nn.Module):

def __init__(self):
super(Net , self).__init__()

self.fc1 = nn.Linear(28*28 , 256) # 输入和输出的维度,根据经验自己设置
self.fc2 = nn.Linear(256 , 64) # 输入维度要等于上层的输出维度
self.fc3 = nn.Linear(64 , 10) # 数字结果为0~9,所以最后输出值为10个维度

def forward(self , x):
x = F.relu(self.fc1(x)) # relu 激活函数
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

6.训练测试集

1.初始化网络

1
net = Net()

2.设置学习率

1
optimizer = optim.SGD(net.parameters() , lr = 0.01 , momentum= 0.9)

3.迭代

此处没有调用GPU处理数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
loss_s = [ ]   # 存储损失值
for each in range(3): # 迭代三次
for location , (x,y) in enumerate(train_loader):
x = x.view(x.size(0) , 28*28) #将图片矩阵打平
out = net(x)
y_onehot = pd_one_hot(y)

loss = F.mse_loss(out , torch.from_numpy(y_onehot).float())

# 清零梯度
optimizer.zero_grad()
# 计算梯度
loss.backward()
# 更新梯度
optimizer.step()

if(location % 5 == 0): # 每处理5*512张图片记录一次损失函数值
loss_s.append(loss.item()) # .item的意思为只输出值
print('第' , each+1 , '次迭代完成')
第 1 次迭代完成
第 2 次迭代完成
第 3 次迭代完成

4.查看损失值

1
2
plt.plot(range(len(loss_s)) , loss_s , 'y')
plt.show()

损失值递减且趋于稳定,训练过程正确
null

7.预测训练集

1
2
3
4
5
6
7
8
9
10
11
# 存储预测正确图片的数量
total_correct = 0
for x,y in test_loader:
x = x.view(x.size(0), 28*28)
out = net(x)
pred = out.argmax(dim=1) # 取最大值概率所在的位置
correct = pred.eq(y).sum().float().item()
total_correct += correct

print('正确率:' , total_correct/len(test_loader.dataset))
正确率: 0.8903

8.查看测试集图片

1
2
x , y = next(iter(test_loader))
plot_image(x , y , 'test')

null




Util.py代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

def plot_image(img, label, name):
fig = plt.figure()
for i in range(9):
plt.subplot(3, 3, i + 1)
plt.tight_layout()
plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
plt.title("{}: {}".format(name, label[i].item()))
plt.xticks([])
plt.yticks([])
plt.show()


def pd_one_hot(y):
y = y.reshape(-1 , 1)
y = pd.Series(y) # 使用pandas的one-hot处理
y= y.astype(str)
y = pd.get_dummies(y)
return y.values



正文代码汇总

项目github链接: github.com/2979083263/mnist

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision

from matplotlib import pyplot as plt
import pandas as pd
import numpy as np


from Util import plot_curve,plot_image,one_hot,pd_one_hot


batch_size = 512

train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data'
, train=True
,download=True
,transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor() #矩阵转化为张量
,torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
)
,batch_size=batch_size
,shuffle=True # 设置随机打散
)

test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('mnist_data'
,train=False
,download=True
,transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
,torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
),
batch_size=batch_size
,shuffle=False)


x, y = next(iter(train_loader)) #暂时看作迭代
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')


class Net(nn.Module):

def __init__(self):
super(Net , self).__init__()

self.fc1 = nn.Linear(28*28 , 256) #输入和输出的维度
self.fc2 = nn.Linear(256 , 64)
self.fc3 = nn.Linear(64 , 10)

def forward(self , x):

x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)

return x

net = Net()


optimizer = optim.SGD(net.parameters() , lr = 0.01 , momentum= 0.9)


loss_s = [ ]


for each in range(3):
for location , (x,y) in enumerate(train_loader):
x = x.view(x.size(0) , 28*28)
out = net(x)
y_onehot = pd_one_hot(y)

loss = F.mse_loss(out , torch.from_numpy(y_onehot).float())

# 清零梯度
optimizer.zero_grad()
# 计算梯度
loss.backward()
# 更新梯度
optimizer.step()

if(location % 5 == 0):
loss_s.append(loss.item())
print('第' , each+1 , '次迭代完成')

plt.plot(range(len(loss_s)) , loss_s , 'y')
plt.show()


# 存储正确的数量
total_correct = 0


for x,y in test_loader:
x = x.view(x.size(0), 28*28)
out = net(x)
# out: [b, 10] => pred: [b]
pred = out.argmax(dim=1)
correct = pred.eq(y).sum().float().item()
total_correct += correct

print('正确率:' , total_correct/len(test_loader.dataset))

x , y = next(iter(test_loader))
plot_image(x , y , 'test')
文章作者: 刘同学
文章链接: https://mouhorse.github.io/2022-02-16/MNIST%E6%89%8B%E5%86%99%E6%95%B0%E5%AD%97%E8%AF%86%E5%88%AB%E6%80%BB%E7%BB%93-pytorch/
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来源 Lucky!
Python
赞助
  • wechat
    wechat
  • alipay
    alipay
cover of previous post
上一篇
pytorch基础
持续补充 1import torch 1.随机相关 123456789a = torch.randn(2,3) 正态分布b = torch.rand(2,3) 0~1范围内随机c = torch.rand_like(a) 模仿a的形状生成随机矩阵d = torch.randint(1,10,[2,3,4]) 在1~10内生成形状为(2,3,4)的矩阵a = a.cuda() 将数据加载到gpu内 2.查看数据形状 12345a.type() 输出数据类型a.size() == a.shape 输出张量形状a.dim() 输出维度 3.生成矩阵 1234567891011121314151617181920212223242526272829303132a = torch.full([3,4,5],6) (3,4,5)形状矩阵,用6填满tensor([[[6., 6., 6., 6., 6.], [6., 6., 6., 6., 6.], [6., 6., 6., 6., 6.], ...
cover of next post
下一篇
[本科毕设]pytorch-人脸表情识别
两年前的存货 ==不提供源码,以后某天可能会开源到github上,本文只是向你提供我的思路,自己动手丰衣足食==此作品诞生于公元2022年,天临四年,卢雷元年😅😅 基于卷积神经网络的人脸表情识别概述功能实现了对图片、视频和摄像头三种情况下的人脸表情进行检测。可以检测出七种表情:[‘生气’, ‘厌恶’, ‘害怕’, ‘开心’, ‘自然’, ‘伤心’,...
相关推荐
cover
2021-02-26
pytorch基础
持续补充 1import torch 1.随机相关 123456789a = torch.randn(2,3) 正态分布b = torch.rand(2,3) 0~1范围内随机c = torch.rand_like(a) 模仿a的形状生成随机矩阵d = torch.randint(1,10,[2,3,4]) 在1~10内生成形状为(2,3,4)的矩阵a = a.cuda() 将数据加载到gpu内 2.查看数据形状 12345a.type() 输出数据类型a.size() == a.shape 输出张量形状a.dim() 输出维度 3.生成矩阵 1234567891011121314151617181920212223242526272829303132a = torch.full([3,4,5],6) (3,4,5)形状矩阵,用6填满tensor([[[6., 6., 6., 6., 6.], [6., 6., 6., 6., 6.], [6., 6., 6., 6., 6.], ...
cover
2018-08-21
3D图像颜色
viridis 翡翠色 #中文会慢慢给出,也会逐渐整理 只有[ ‘ ‘]内的代码才是颜色。# # cmaps = [('Perceptually Uniform Sequential', # # ['viridis', 'inferno', 'plasma', 'magma']), # # ('Sequential', ['Blues', 'BuGn', 'BuPu', # # 'GnBu', 'Greens', 'Greys', 'Oranges', 'OrRd', # # 'PuBu', 'PuBuGn', 'PuRd', 'Purples',...
cover
2018-12-28
Python函数
函数 之前由于水平不足博客写的很不好,现于2019/12/17进行大更改。 1.内置函数 ①python与其他编程语言一样有许多内置函数,我们可以直接调用使用,只需要知道参数和函数名。内置函数有很多,需要的话自行百度了解 ②利用内置函数进行数据类型转换,int,float… ③可以把函数名复制给一个变量,然后用这个变量也可以调用该函数。 a=abs    #用a代替abs函数a(-1)    #调用取绝对值函数 1     #产生结果与直接调用abs函数相同 ==2.用户自定函数== ①用def语句定义函数,后面跟函数名,括号,括号里写参数名,还要加冒号。用return语句返回。没有return语句函数也可以,只不过返回值为none。 ②python与c一样,也可以设置空函数,等到以后再补充内容,需要用pass语句。 def nop()...
cover
2018-12-15
Python基础
之前由于水平不足博客写的很不好,现于2019/11/30进行大更改。 python1.基础①代码中单引号与双引号效果一样,不像其他语言一样有所区别,但是要成对使用,不能落单。  ②命名可以用数字,字母,下划线,汉字,不能以数字开头。注意不要和关键字重名。  ③字符串:每个元素的位置可以正着数,也可以倒着数。正着数从0开始,倒着从-1开始。 字符串访问区间方式:[m,n],访问从m开始到n的字符串内元素(不包括n,左闭右开) m,n可以是正着数的,也可以是倒着数的,还可以混合用。  ④交换x,y的值:x,y=y,x(比c简单多了)   ⑤#后面跟的是注释,大程序中注释很重要,便于自己和别人的理解。  ⑥为了简化避免造成机器误解,Python还允许用 r’ ‘表示 ‘ ‘ 内部的字符串默认不转义,可以自己试试:  (1)输入:>>> print( ‘ \\\t\\ ‘)     输出:\    \  (2)输入:>>> print ( r ‘ \\\t\\‘)     输出:\\\t\\ 2.输出...
avatar
刘同学
欢迎光临我的博客
文章
23
标签
8
分类
0
Follow Me
公告
欢迎来到我的博客!
可以交换友链
联系方式:485182274@qq.com
目录
  1. 1. 1.首先导入所需要的库
  2. 2. 2.数据集
  3. 3. 3.加载数据
  4. 4. 4.数据可视化
  5. 5. 5.定义神经网络
  6. 6. 6.训练测试集
    1. 6.1. 1.初始化网络
    2. 6.2. 2.设置学习率
    3. 6.3. 3.迭代
    4. 6.4. 4.查看损失值
  7. 7. 7.预测训练集
  8. 8. 8.查看测试集图片
  9. 9. Util.py代码
  10. 10. 正文代码汇总
最新文章
Hexo本地与云端布局不同处理办法
Hexo本地与云端布局不同处理办法2025-02-25
解决 Hexo 部署到 GitHub Pages 自定义域名失效
解决 Hexo 部署到 GitHub Pages 自定义域名失效2025-02-24
Butterfly 个性化配置教程
Butterfly 个性化配置教程2025-02-23
Hexo安装并修改主题
Hexo安装并修改主题2025-02-23
MNE脑电预处理
MNE脑电预处理2024-10-07
©2018 - 2025 By 刘同学
框架 Hexo 7.3.0|主题 Butterfly 5.3.3
活出个样子给自己看
搜索
数据加载中