目录
前言
识别方法
手写输入图片:
首先将图片处理成 0/1 序列:
- 图片长 28 宽 28, 组成一个二维序列
0, 0, 1, ..., 1, 0, 0 0, 1, 0, ..., 0, 1, 0 1, 0, 0, ..., 0, 0, 1 . . . . . . 1, 0, 0, ..., 0, 0, 1 0, 1, 0, ..., 0, 1, 0 0, 0, 1, ..., 1, 0, 0
- 然后处理成一个线性,由 28*28 = 784 个 0/1 元素组成的一维矩阵
[0, 0, 1, ..., 1, 0, 0, 0, 1, 0, ..., 0, 1, 0, ..., 0, 0, 1, ..., 1, 0, 0]
- 最后经过线性方程层层求解。
计算方法
设:X 为原图序列 X = [v1, v2, ..., v784]
再将 X 增加一个维度并设为 1。
X = [1, v1, v2, ..., v784] = [1, dx]
之后(推导过程没看懂,总之先记住,日后需要的话可以再看一遍)
W1 = [d1, dx], b1 = [d1], H1 = XW1 + b1 = ... = [1, d1]
W2 = [d2, d1], b2 = [d2], H2 = H1W2 + b2 = ... = [1, d2]
W3 = [10, d2], b3 = [10], H3 = H2W3 + b3 = ... = [1, d3]
经过训练得出一个模型,即为:pred = W3{W2[W1X + b1] + b2} + b3 = H3
接下来我们通过改变 W 和 b 的值得到一个模型 Y,使得 objective = Σ(pred - Y)^2
越小越好。
计算方法
- 将要识别的图片转换成序列 X
- 对于新的 X 使用构建模型时使用的
W1, W2, W3, b1, b2, b3
得到pred = W3{W2[W1X + b1] + b2} + b3
- 得到一个含有 10 个元素的列表:例
[0.1, 0.8, 0.1, ..., 0.1]
表示识别为对应数字的概率(P(0/x) = 0.1, P(1/x) = 0.8, ...
) max(pred) = 0.8, argmax(pred) = 1
手写数字识别
前面我们简单说明了分类算法的流程:
H1 = relu( XW1 + b1)
H2 = relu(H1W1 + b2)
H3 = f(H2W1 + b3)
通常最后一层函数会使用非线性函数进行收敛。
识别流程
- 下载数据
- 建立模型
- 训练模型
- 测试模型
工具函数
创建 utils.py
准备我们会用到的工具函数:
import torch
from matplotlib import pyplot as plt
# 画曲线
def plot_curve(data):
fig = plt.figure()
plt.plot(range(len(data)), data, color='blue')
plt.legend(['value'], loc='upper right')
plt.xlabel('step')
plt.ylabel('value')
plt.show()
# 画图片
def plot_image(img, label, name):
fig = plt.figure()
for i in range(6):
plt.subplot(2, 3, i + 1)
plt.tight_layout()
plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
plt.title(f"{name}: {label[i].item()}")
plt.xticks([])
plt.yticks([])
plt.show()
# 使用 scatter_ 为 pytorch 实现 one hot 编码功能
def one_hot(label, depth=10):
out = torch.zeros(label.size(0), depth)
idx = torch.LongTensor(label).view(-1, 1)
out.scatter_(dim=1, index=idx, value=1)
return out
训练代码
导入需要的函数:
import torch
# 神经网络相关工作
from torch import nn
# 常用函数
from torch.nn import functional as F
# 优化工具包
from torch import optim
# 视觉处理
import torchvision
# 展示图片使用
from matplotlib import pyplot as plt
# 自定义的工具函数
from utils import plot_image, plot_curve, one_hot
第一步:加载数据包
# 同时处理图片数
batch_size = 512
# 加载训练模型
train_loader = torch.utils.data.DataLoader(
# 加载 mnist 数据集
torchvision.datasets.MNIST(
# 'mnist_data' 表示下载路径
# train=True 这个参数表示下载的是训练文件还是测试文件
# download=True 表示如果本地没有 'mnist_data' 则会从网上下载
'mnist_data', train=True, download=True,
transform=torchvision.transforms.Compose([
# 通常来说下载的文件是 numpy 格式,这里表示转换为 Tensor 格式,此为 torch 处理使用的载体
torchvision.transforms.ToTensor(),
# 正则化过程,像素为 0 or 1,通过 (x-0.1307)/0.3081 可以得到一个在 0 附近均匀分布的数据
# 注释掉这一行性能会变差
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
),
# batch_size 表示一次加载多少图片
# shuffle 表示是否打散
batch_size=batch_size, shuffle=True
)
# 加载测试模型
test_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST(
# train=False 表示当前下载的是测试图片
'mnist_data', train=False, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
),
# 训练时需要打散,测试时就没必要了
batch_size=batch_size, shuffle=False
)
查看获得的数据
# 首先从训练中取一个样例
sample = next(iter(train_loader))
x, y = sample
# 查看一下我们得到的数据
print(x.shape, y.shape)
# torch.Size([512, 1, 28, 28]) torch.Size([512])
# 512 张图片,1 个通道,28 行,28 列
# 512 是我们上面设置的