# 1. 加载数据 (Load the data)

import numpy as np
import random
import struct

import torch
import torch.functional as F
import torch.nn as nn
from torch import optim

# 1. Load the data
# IDX element-type code (third header byte) -> struct / numpy format character.
code2type = {
    0x08: 'B',  # unsigned byte
    0x09: 'b',  # signed byte
    0x0B: 'h',  # 2-byte signed integer
    0x0C: 'i',  # 4-byte signed integer
    0x0D: 'f',  # 4-byte float
    0x0E: 'd',  # 8-byte float
}


def readMatrix(filename):
    """Read an IDX-format file (as used by the MNIST data files) into a numpy array.

    Header layout, all big-endian: 2 bytes (ignored), 1 byte element-type code,
    1 byte number of dimensions, then one unsigned 32-bit size per dimension.
    The element data follows immediately after the header.

    Args:
        filename: path of the IDX file.

    Returns:
        numpy array with the shape from the header and the native-byte-order
        dtype given by ``code2type[elem_code]``.

    Raises:
        KeyError: if the element-type code is not listed in ``code2type``.
    """
    with open(filename, 'rb') as f:
        data_buff = f.read()

    off_set = 0
    file_head_fmt = '>HBB'
    _, elem_code, dimlen = struct.unpack_from(file_head_fmt, data_buff, off_set)
    off_set += struct.calcsize(file_head_fmt)

    # One big-endian uint32 per dimension ({} is filled with the dimension count).
    shape_fmt = '>{}I'.format(dimlen)
    shapes = struct.unpack_from(shape_fmt, data_buff, off_set)
    off_set += struct.calcsize(shape_fmt)

    type_char = code2type[elem_code]
    # Total element count is the product of all dimensions.
    count = int(np.prod(shapes))
    # Decode the payload in a single vectorised call instead of unpacking every
    # element into a Python object; '>' because IDX data is big-endian.
    matrix = np.frombuffer(data_buff, dtype=np.dtype('>' + type_char),
                           count=count, offset=off_set)
    # astype converts to native byte order and returns a writable copy.
    return matrix.reshape(shapes).astype(type_char)


def dataReader(img_file, label_file, batch_size=24, drop_last=False):
    """Build a reusable, shuffling mini-batch reader over an image/label IDX pair.

    Args:
        img_file: IDX file of images (e.g. MNIST training images, (60000, 28, 28)).
        label_file: IDX file of labels, one per image (e.g. (60000,)).
        batch_size: number of samples per yielded batch; must be >= 1.
        drop_last: if True, discard a final batch smaller than ``batch_size``;
            if False, yield it as a short batch.

    Returns:
        A zero-argument generator function. Each call reshuffles the samples
        in place and yields lists of ``(image, label)`` tuples.

    Raises:
        ValueError: if ``batch_size`` < 1 or the two files hold a different
            number of samples.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))

    mnist_matrix = readMatrix(img_file)
    mnist_label = readMatrix(label_file)
    if len(mnist_matrix) != len(mnist_label):
        # zip() would otherwise silently truncate to the shorter input.
        raise ValueError('image/label count mismatch: {} vs {}'.format(
            len(mnist_matrix), len(mnist_label)))

    # Pair images with labels: [(image_0, label_0), (image_1, label_1), ...]
    buff = list(zip(mnist_matrix, mnist_label))

    def batch_reader():
        # Reshuffle on every call so each epoch sees a different order.
        random.shuffle(buff)
        b = []
        for sample in buff:
            b.append(sample)
            if len(b) == batch_size:
                yield b
                b = []
        # Trailing partial batch: keep it unless the caller asked to drop it.
        if b and not drop_last:
            yield b

    return batch_reader

# 2.1 定义全连接神经网络 (Define a fully connected network)

class MnistLinearNet(nn.Module):
    """Fully connected MNIST classifier: 784 -> 128 -> 64 -> 10.

    ``forward`` returns raw class logits (no softmax), which is what
    ``nn.CrossEntropyLoss`` expects; ``argmax`` over them gives the prediction.
    """

    def __init__(self):
        super(MnistLinearNet, self).__init__()
        # Hidden layers with 128 and 64 units.
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        # Output layer, 10 units - one for each digit.
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return logits of shape (batch, 10).

        Any input whose elements regroup into rows of 28 * 28 values is
        accepted, e.g. (N, 1, 28, 28) or (N, 784).
        """
        x = x.reshape((-1, 28 * 28))
        # Use nn.functional: the module-level ``F`` alias is torch.functional,
        # which has no relu/softmax.
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        return self.fc3(x)

# 2.2 定义卷积神经网络 (Define convolutional networks)

class MnistConvNet(nn.Module):
    """LeNet-style CNN for single-channel 28x28 images, returning (batch, 10) logits.

    Spatial sizes: 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4, so the
    flattened feature vector has 16 * 4 * 4 values.
    """

    def __init__(self):
        super(MnistConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)  # after pooling: n c h w = [n 16 4 4]
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # nn.functional.relu: the module-level ``F`` alias is torch.functional,
        # which has no relu.
        relu = nn.functional.relu
        x = self.pool(relu(self.conv1(x)))
        x = self.pool(relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = relu(self.fc1(x))
        x = relu(self.fc2(x))
        return self.fc3(x)


def LeNet2():
    """Build a classic sigmoid/avg-pool LeNet for single-channel 28x28 inputs.

    Spatial sizes: 28 -conv5,pad2-> 28 -pool-> 14 -conv5-> 10 -pool-> 5, so the
    flattened feature vector has 16 * 5 * 5 values. Outputs (batch, 10) logits.
    Intended training settings: lr = 0.1, momentum = 0.9.
    """
    net = nn.Sequential(
        nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),
        nn.AvgPool2d(kernel_size=2, stride=2),
        # kernel_size=5 (not 2/stride 2) so the map is 16x5x5, matching Linear(400, ...).
        nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
        nn.AvgPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
        nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),
        nn.Linear(120, 84), nn.Sigmoid(),
        nn.Linear(84, 10)
    )
    return net

# 3. 训练网络 (Train the network)

def train(loader, num_epoch=2, net_cls='LeNet'):
    """Train an MNIST classifier with momentum SGD and cross-entropy loss.

    Args:
        loader: zero-argument callable returning an iterable of batches, each a
            list of ``(image, label)`` pairs (as produced by ``dataReader``).
        num_epoch: number of passes over ``loader()``.
        net_cls: 'LeNet', 'LinearNet' or 'LeNet2'; any other value selects
            ``MnistConvNet``. 'LeNet2' uses lr = 0.1, the others lr = 0.001.

    Returns:
        The trained model.
    """
    lr = 0.001
    momentum = 0.9
    if net_cls == 'LeNet':
        # NOTE(review): no ``LeNet`` is defined in this file; confirm it exists.
        model = LeNet()
    elif net_cls == 'LinearNet':
        model = MnistLinearNet()
    elif net_cls == 'LeNet2':
        model = LeNet2()
        lr = 0.1
    else:
        model = MnistConvNet()
    # Categorical cross-entropy loss with momentum SGD.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    model.train()
    for _ in range(num_epoch):
        running_loss = 0.0
        for i, data in enumerate(loader()):
            inputs, labels = zip(*data)
            # Stack images into one float32 array and add a channel axis at dim 1.
            inputs = torch.from_numpy(np.array(inputs).astype('float32')).unsqueeze(1)
            labels = torch.from_numpy(np.array(labels).astype('int64'))
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Report the mean loss over each window of 100 batches.
            running_loss += loss.item()
            if i % 100 == 99:
                last_loss = running_loss / 100  # loss per batch
                print(' batch {} loss: {}'.format(i + 1, last_loss))
                running_loss = 0.0
    print('Finished Training')
    return model

# 4. 测试网络 (Test the network)

def test(PATH, loader, net_cls='LeNet'):
    """Load saved weights into a fresh model and print its accuracy on ``loader``.

    Args:
        PATH: path of a state dict saved with ``torch.save(model.state_dict(), ...)``.
        loader: zero-argument callable yielding batches of ``(image, label)`` pairs.
        net_cls: architecture name; must match the one used for training
            ('LeNet', 'LinearNet', 'LeNet2', anything else -> ``MnistConvNet``).

    Returns:
        The model with the loaded weights.
    """
    if net_cls == 'LeNet':
        # NOTE(review): no ``LeNet`` is defined in this file; confirm it exists.
        model = LeNet()
    elif net_cls == 'LinearNet':
        model = MnistLinearNet()
    elif net_cls == 'LeNet2':
        model = LeNet2()
    else:
        model = MnistConvNet()
    # Map weights to CPU: the input tensors built below are CPU tensors.
    model.load_state_dict(torch.load(PATH, map_location='cpu'))
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in loader():
            images, labels = zip(*data)
            images = torch.from_numpy(np.array(images).astype('float32')).unsqueeze(1)
            labels = torch.from_numpy(np.array(labels).astype('int64'))
            outputs = model(images)
            # Predicted class = index of the largest output per sample.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        # Avoid ZeroDivisionError when the loader produced no samples.
        print('No test images were evaluated')
    else:
        print('Accuracy of the network on the {:d} test images: {:f}%'.format(total, 100 * correct / total))
    print('test over')
    return model

# 5. 训练与测试 (Train, save, and evaluate)

# Entry point: was `== 'main'`, which is never true, so the script never ran.
if __name__ == '__main__':
    BATCH_SIZE = 16
    train_loader = dataReader('../data/mnist/train-images-idx3-ubyte',
                              '../data/mnist/train-labels-idx1-ubyte',
                              BATCH_SIZE, True
                              )
    test_loader = dataReader('../data/mnist/t10k-images-idx3-ubyte',
                             '../data/mnist/t10k-labels-idx1-ubyte',
                             BATCH_SIZE, True
                             )
    # Choices: 'LeNet', 'LeNet2', 'LinearNet', 'ConvNet'
    net_cls = 'LeNet2'
    PATH = './mnist_pytorch.' + net_cls + '.pth'
    model = train(train_loader, 5, net_cls)
    # Save the trained weights, then reload them into a fresh model for evaluation.
    torch.save(model.state_dict(), PATH)
    test(PATH, test_loader, net_cls)