LeNet算法复现¶

LeNet-5

卷积层输出:

$o= \lfloor \frac{n+2p-f}{s} \rfloor+ 1$

池化层输出:

$o= \frac{n+2p-f}{s}+ 1$

$n代表图片大小,p代表填充,f代表卷积核,s代表步长,o代表输出图片大小$

池化输出大小=[(输入大小-卷积核(过滤器)大小)/步长]+1

In [2]:
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
device = ('cuda' if torch.cuda.is_available() else 'cpu')
In [22]:
#获取数据
train_data = torchvision.datasets.MNIST(
    root='../data/', 
    train=True, 
    transform=torchvision.transforms.ToTensor(), 
    download=False)

test_data = torchvision.datasets.MNIST(
    root='../data/', 
    train=False, 
    transform=torchvision.transforms.ToTensor(), 
    download=False)

#对数据进行分批次训练
batch_size = 64
train_dataloader = DataLoader(train_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
In [14]:
# 定义模型
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding= 2), #1*28*28-- 6*28*28
            # 设置 padding=2 使得结果为 28*28 (28+2*2-5)/1+ 1
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2), # 6*14*14
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding= 0), # 16*10*10
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2), # 16*5*5
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(16, 120, kernel_size=5, stride=1, padding=0), #120*1
            nn.ReLU()
        )
        self.fc1 = nn.Sequential(
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
        )
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

def evaluate_accuracy(data, model):
    """
    计算测试集训练效果
    """
    acc_sum, n = 0.0, 0
    model.eval()
    with torch.no_grad():
        for x, y in data:
            x, y = x.to(device), y.to(device)
            acc_sum += (model(x).argmax(1)== y).float().sum().item() #计算正确的个数
            n += y.shape[0] #计算全部数据个数
    return acc_sum/ n
In [15]:
#定义损失函数以及优化函数
model = LeNet().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
In [6]:
def train(data, model, loss_fn, optimizer):
    size = len(data.dataset)
    model.train()
    for batch, (x,y) in enumerate(data):
        x, y = x.to(device), y.to(device)

        pred = model(x)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch %100 == 0:
            loss, current = loss.item(), (batch+ 1)* len(x) #loss为:tensor(127.4510, device='cuda:0', grad_fn=<DivBackward1>)所以通过item()去得到他们的具体数值
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
    

def test(data, model, loss_fn):
    size = len(data.dataset)
    num_batches = len(data)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for x, y in data:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
In [8]:
epochs = 21
for t in range(epochs):
    if epochs/10 == 0:
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer)
        test(test_dataloader, model, loss_fn)
print("Done!")
Epoch 1
-------------------------------
loss: 2.272300  [   64/60000]
loss: 2.269048  [ 6464/60000]
loss: 2.278383  [12864/60000]
loss: 2.273706  [19264/60000]
loss: 2.274111  [25664/60000]
loss: 2.270103  [32064/60000]
loss: 2.254435  [38464/60000]
loss: 2.277729  [44864/60000]
loss: 2.270250  [51264/60000]
loss: 2.255527  [57664/60000]
Test Error: 
 Accuracy: 30.1%, Avg loss: 2.258554 

Epoch 2
-------------------------------
loss: 2.257055  [   64/60000]
loss: 2.248479  [ 6464/60000]
loss: 2.262434  [12864/60000]
loss: 2.252451  [19264/60000]
loss: 2.254298  [25664/60000]
loss: 2.245104  [32064/60000]
loss: 2.219618  [38464/60000]
loss: 2.255465  [44864/60000]
loss: 2.239263  [51264/60000]
loss: 2.216094  [57664/60000]
Test Error: 
 Accuracy: 38.5%, Avg loss: 2.220405 

Epoch 3
-------------------------------
loss: 2.220529  [   64/60000]
loss: 2.198015  [ 6464/60000]
loss: 2.226407  [12864/60000]
loss: 2.196418  [19264/60000]
loss: 2.201220  [25664/60000]
loss: 2.175020  [32064/60000]
loss: 2.122310  [38464/60000]
loss: 2.191850  [44864/60000]
loss: 2.142621  [51264/60000]
loss: 2.095963  [57664/60000]
Test Error: 
 Accuracy: 44.2%, Avg loss: 2.098208 

Epoch 4
-------------------------------
loss: 2.103936  [   64/60000]
loss: 2.031757  [ 6464/60000]
loss: 2.105584  [12864/60000]
loss: 2.008843  [19264/60000]
loss: 2.016967  [25664/60000]
loss: 1.924333  [32064/60000]
loss: 1.748101  [38464/60000]
loss: 1.972561  [44864/60000]
loss: 1.804982  [51264/60000]
loss: 1.649397  [57664/60000]
Test Error: 
 Accuracy: 54.8%, Avg loss: 1.677555 

Epoch 5
-------------------------------
loss: 1.742907  [   64/60000]
loss: 1.542034  [ 6464/60000]
loss: 1.671715  [12864/60000]
loss: 1.490320  [19264/60000]
loss: 1.450628  [25664/60000]
loss: 1.320244  [32064/60000]
loss: 1.038140  [38464/60000]
loss: 1.442216  [44864/60000]
loss: 1.164547  [51264/60000]
loss: 1.009380  [57664/60000]
Test Error: 
 Accuracy: 70.4%, Avg loss: 1.018695 

Epoch 6
-------------------------------
loss: 1.148516  [   64/60000]
loss: 0.897734  [ 6464/60000]
loss: 0.964151  [12864/60000]
loss: 0.832320  [19264/60000]
loss: 0.885078  [25664/60000]
loss: 0.752632  [32064/60000]
loss: 0.629204  [38464/60000]
loss: 0.919603  [44864/60000]
loss: 0.766702  [51264/60000]
loss: 0.731564  [57664/60000]
Test Error: 
 Accuracy: 79.8%, Avg loss: 0.670018 

Epoch 7
-------------------------------
loss: 0.817475  [   64/60000]
loss: 0.596555  [ 6464/60000]
loss: 0.593916  [12864/60000]
loss: 0.590386  [19264/60000]
loss: 0.677329  [25664/60000]
loss: 0.536581  [32064/60000]
loss: 0.469469  [38464/60000]
loss: 0.676546  [44864/60000]
loss: 0.615217  [51264/60000]
loss: 0.625991  [57664/60000]
Test Error: 
 Accuracy: 84.1%, Avg loss: 0.528252 

Epoch 8
-------------------------------
loss: 0.682580  [   64/60000]
loss: 0.482965  [ 6464/60000]
loss: 0.444592  [12864/60000]
loss: 0.502215  [19264/60000]
loss: 0.565583  [25664/60000]
loss: 0.454061  [32064/60000]
loss: 0.375289  [38464/60000]
loss: 0.562539  [44864/60000]
loss: 0.538687  [51264/60000]
loss: 0.565120  [57664/60000]
Test Error: 
 Accuracy: 86.5%, Avg loss: 0.452470 

Epoch 9
-------------------------------
loss: 0.595269  [   64/60000]
loss: 0.420472  [ 6464/60000]
loss: 0.369769  [12864/60000]
loss: 0.452824  [19264/60000]
loss: 0.485208  [25664/60000]
loss: 0.413044  [32064/60000]
loss: 0.316019  [38464/60000]
loss: 0.505066  [44864/60000]
loss: 0.483565  [51264/60000]
loss: 0.524046  [57664/60000]
Test Error: 
 Accuracy: 87.9%, Avg loss: 0.403727 

Epoch 10
-------------------------------
loss: 0.527021  [   64/60000]
loss: 0.376387  [ 6464/60000]
loss: 0.322572  [12864/60000]
loss: 0.420881  [19264/60000]
loss: 0.419712  [25664/60000]
loss: 0.387603  [32064/60000]
loss: 0.278192  [38464/60000]
loss: 0.469517  [44864/60000]
loss: 0.439173  [51264/60000]
loss: 0.490684  [57664/60000]
Test Error: 
 Accuracy: 89.2%, Avg loss: 0.367917 

Epoch 11
-------------------------------
loss: 0.470405  [   64/60000]
loss: 0.342959  [ 6464/60000]
loss: 0.287529  [12864/60000]
loss: 0.399545  [19264/60000]
loss: 0.365872  [25664/60000]
loss: 0.368633  [32064/60000]
loss: 0.253037  [38464/60000]
loss: 0.446631  [44864/60000]
loss: 0.399979  [51264/60000]
loss: 0.463988  [57664/60000]
Test Error: 
 Accuracy: 90.0%, Avg loss: 0.340112 

Epoch 12
-------------------------------
loss: 0.425155  [   64/60000]
loss: 0.320384  [ 6464/60000]
loss: 0.260997  [12864/60000]
loss: 0.385833  [19264/60000]
loss: 0.325431  [25664/60000]
loss: 0.353819  [32064/60000]
loss: 0.232314  [38464/60000]
loss: 0.432795  [44864/60000]
loss: 0.367541  [51264/60000]
loss: 0.440752  [57664/60000]
Test Error: 
 Accuracy: 90.5%, Avg loss: 0.317997 

Epoch 13
-------------------------------
loss: 0.384871  [   64/60000]
loss: 0.303614  [ 6464/60000]
loss: 0.238483  [12864/60000]
loss: 0.375285  [19264/60000]
loss: 0.294723  [25664/60000]
loss: 0.339454  [32064/60000]
loss: 0.216060  [38464/60000]
loss: 0.420999  [44864/60000]
loss: 0.339942  [51264/60000]
loss: 0.419056  [57664/60000]
Test Error: 
 Accuracy: 90.9%, Avg loss: 0.299733 

Epoch 14
-------------------------------
loss: 0.352212  [   64/60000]
loss: 0.293606  [ 6464/60000]
loss: 0.219635  [12864/60000]
loss: 0.366890  [19264/60000]
loss: 0.269858  [25664/60000]
loss: 0.325129  [32064/60000]
loss: 0.201280  [38464/60000]
loss: 0.413026  [44864/60000]
loss: 0.315552  [51264/60000]
loss: 0.399812  [57664/60000]
Test Error: 
 Accuracy: 91.4%, Avg loss: 0.284142 

Epoch 15
-------------------------------
loss: 0.324788  [   64/60000]
loss: 0.286184  [ 6464/60000]
loss: 0.206210  [12864/60000]
loss: 0.359954  [19264/60000]
loss: 0.252210  [25664/60000]
loss: 0.310987  [32064/60000]
loss: 0.189725  [38464/60000]
loss: 0.404570  [44864/60000]
loss: 0.295089  [51264/60000]
loss: 0.381713  [57664/60000]
Test Error: 
 Accuracy: 91.8%, Avg loss: 0.270155 

Epoch 16
-------------------------------
loss: 0.299573  [   64/60000]
loss: 0.280513  [ 6464/60000]
loss: 0.195589  [12864/60000]
loss: 0.352637  [19264/60000]
loss: 0.236667  [25664/60000]
loss: 0.297527  [32064/60000]
loss: 0.178814  [38464/60000]
loss: 0.395350  [44864/60000]
loss: 0.276259  [51264/60000]
loss: 0.364834  [57664/60000]
Test Error: 
 Accuracy: 92.2%, Avg loss: 0.257590 

Epoch 17
-------------------------------
loss: 0.276401  [   64/60000]
loss: 0.274002  [ 6464/60000]
loss: 0.187289  [12864/60000]
loss: 0.345724  [19264/60000]
loss: 0.222929  [25664/60000]
loss: 0.285335  [32064/60000]
loss: 0.167951  [38464/60000]
loss: 0.385595  [44864/60000]
loss: 0.259681  [51264/60000]
loss: 0.348546  [57664/60000]
Test Error: 
 Accuracy: 92.5%, Avg loss: 0.246200 

Epoch 18
-------------------------------
loss: 0.254310  [   64/60000]
loss: 0.267883  [ 6464/60000]
loss: 0.180744  [12864/60000]
loss: 0.337245  [19264/60000]
loss: 0.210713  [25664/60000]
loss: 0.274544  [32064/60000]
loss: 0.157922  [38464/60000]
loss: 0.376551  [44864/60000]
loss: 0.244212  [51264/60000]
loss: 0.333477  [57664/60000]
Test Error: 
 Accuracy: 92.9%, Avg loss: 0.235646 

Epoch 19
-------------------------------
loss: 0.234088  [   64/60000]
loss: 0.262315  [ 6464/60000]
loss: 0.175059  [12864/60000]
loss: 0.328604  [19264/60000]
loss: 0.199100  [25664/60000]
loss: 0.265312  [32064/60000]
loss: 0.149548  [38464/60000]
loss: 0.366871  [44864/60000]
loss: 0.231003  [51264/60000]
loss: 0.320089  [57664/60000]
Test Error: 
 Accuracy: 93.3%, Avg loss: 0.225980 

Epoch 20
-------------------------------
loss: 0.215904  [   64/60000]
loss: 0.256772  [ 6464/60000]
loss: 0.169321  [12864/60000]
loss: 0.320245  [19264/60000]
loss: 0.187060  [25664/60000]
loss: 0.256086  [32064/60000]
loss: 0.142840  [38464/60000]
loss: 0.357499  [44864/60000]
loss: 0.218770  [51264/60000]
loss: 0.307769  [57664/60000]
Test Error: 
 Accuracy: 93.6%, Avg loss: 0.216964 

Epoch 21
-------------------------------
loss: 0.200837  [   64/60000]
loss: 0.252707  [ 6464/60000]
loss: 0.163498  [12864/60000]
loss: 0.312334  [19264/60000]
loss: 0.176415  [25664/60000]
loss: 0.248623  [32064/60000]
loss: 0.136781  [38464/60000]
loss: 0.347503  [44864/60000]
loss: 0.208322  [51264/60000]
loss: 0.297342  [57664/60000]
Test Error: 
 Accuracy: 93.7%, Avg loss: 0.208597 

Epoch 22
-------------------------------
loss: 0.187000  [   64/60000]
loss: 0.248547  [ 6464/60000]
loss: 0.158774  [12864/60000]
loss: 0.304266  [19264/60000]
loss: 0.165540  [25664/60000]
loss: 0.241203  [32064/60000]
loss: 0.132309  [38464/60000]
loss: 0.337949  [44864/60000]
loss: 0.199489  [51264/60000]
loss: 0.288586  [57664/60000]
Test Error: 
 Accuracy: 94.0%, Avg loss: 0.200504 

Epoch 23
-------------------------------
loss: 0.175113  [   64/60000]
loss: 0.244488  [ 6464/60000]
loss: 0.153888  [12864/60000]
loss: 0.296739  [19264/60000]
loss: 0.156701  [25664/60000]
loss: 0.233408  [32064/60000]
loss: 0.127569  [38464/60000]
loss: 0.327803  [44864/60000]
loss: 0.192073  [51264/60000]
loss: 0.280420  [57664/60000]
Test Error: 
 Accuracy: 94.2%, Avg loss: 0.193056 

Epoch 24
-------------------------------
loss: 0.164669  [   64/60000]
loss: 0.238974  [ 6464/60000]
loss: 0.149055  [12864/60000]
loss: 0.290120  [19264/60000]
loss: 0.146750  [25664/60000]
loss: 0.226248  [32064/60000]
loss: 0.123039  [38464/60000]
loss: 0.318942  [44864/60000]
loss: 0.186323  [51264/60000]
loss: 0.274440  [57664/60000]
Test Error: 
 Accuracy: 94.4%, Avg loss: 0.186250 

Epoch 25
-------------------------------
loss: 0.155306  [   64/60000]
loss: 0.234572  [ 6464/60000]
loss: 0.143423  [12864/60000]
loss: 0.284805  [19264/60000]
loss: 0.136558  [25664/60000]
loss: 0.218994  [32064/60000]
loss: 0.119276  [38464/60000]
loss: 0.311303  [44864/60000]
loss: 0.181491  [51264/60000]
loss: 0.268996  [57664/60000]
Test Error: 
 Accuracy: 94.6%, Avg loss: 0.179679 

Epoch 26
-------------------------------
loss: 0.147111  [   64/60000]
loss: 0.229742  [ 6464/60000]
loss: 0.138638  [12864/60000]
loss: 0.279645  [19264/60000]
loss: 0.127543  [25664/60000]
loss: 0.212860  [32064/60000]
loss: 0.115741  [38464/60000]
loss: 0.303846  [44864/60000]
loss: 0.177719  [51264/60000]
loss: 0.264207  [57664/60000]
Test Error: 
 Accuracy: 94.7%, Avg loss: 0.173709 

Epoch 27
-------------------------------
loss: 0.139985  [   64/60000]
loss: 0.224278  [ 6464/60000]
loss: 0.134438  [12864/60000]
loss: 0.275243  [19264/60000]
loss: 0.119117  [25664/60000]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_7180\2134615688.py in <module>
      3     if epochs%10 == 0:
      4         print(f"Epoch {t+1}\n-------------------------------")
----> 5         train(train_dataloader, model, loss_fn, optimizer)
      6         test(test_dataloader, model, loss_fn)
      7 print("Done!")

~\AppData\Local\Temp\ipykernel_7180\2395823617.py in train(data, model, loss_fn, optimizer)
      9         x, y = x.to(device), y.to(device)
     10 
---> 11         pred = model(x)
     12         loss = loss_fn(pred, y)
     13 

e:\Anaconda\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~\AppData\Local\Temp\ipykernel_7180\2835041471.py in forward(self, x)
     26         x = self.layer1(x)
     27         x = self.layer2(x)
---> 28         x = self.layer3(x)
     29         x = x.view(x.size(0), -1)
     30         x = self.fc1(x)

e:\Anaconda\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

e:\Anaconda\lib\site-packages\torch\nn\modules\container.py in forward(self, input)
    202     def forward(self, input):
    203         for module in self:
--> 204             input = module(input)
    205         return input
    206 

e:\Anaconda\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

e:\Anaconda\lib\site-packages\torch\nn\modules\conv.py in forward(self, input)
    461 
    462     def forward(self, input: Tensor) -> Tensor:
--> 463         return self._conv_forward(input, self.weight, self.bias)
    464 
    465 class Conv3d(_ConvNd):

e:\Anaconda\lib\site-packages\torch\nn\modules\conv.py in _conv_forward(self, input, weight, bias)
    457                             weight, bias, self.stride,
    458                             _pair(0), self.dilation, self.groups)
--> 459         return F.conv2d(input, weight, bias, self.stride,
    460                         self.padding, self.dilation, self.groups)
    461 

KeyboardInterrupt: 
In [31]:
def test(data, model):
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in data:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    correct /= len(data.dataset)
    return correct

epochs = 21
# 没有对数据进行优化,可以选择去对数据进行优化,比如正则化、添加droup out、添加resnet连接
for epoch in range(epochs):
    print('Epoch:{}\n'.format(epoch+1))
    for batch, (x,y) in enumerate(train_dataloader):
        size = len(train_dataloader.dataset)
        model.train()
        x, y = x.to(device), y.to(device)

        pred = model(x)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
        if batch %100 == 0:
            loss, current = loss.item(), (batch+ 1)* len(x) 
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
    correct = test(test_dataloader, model)
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}% \n")
Epoch:1

loss: 0.098646  [   64/60000]
loss: 0.157770  [ 6464/60000]
loss: 0.106539  [12864/60000]
loss: 0.215023  [19264/60000]
loss: 0.100605  [25664/60000]
loss: 0.150523  [32064/60000]
loss: 0.102860  [38464/60000]
loss: 0.228336  [44864/60000]
loss: 0.174099  [51264/60000]
loss: 0.196144  [57664/60000]
Test Error: 
 Accuracy: 95.9% 

Epoch:2

loss: 0.107013  [   64/60000]
loss: 0.154821  [ 6464/60000]
loss: 0.105823  [12864/60000]
loss: 0.208457  [19264/60000]
loss: 0.096202  [25664/60000]
loss: 0.145667  [32064/60000]
loss: 0.101881  [38464/60000]
loss: 0.221606  [44864/60000]
loss: 0.172308  [51264/60000]
loss: 0.194549  [57664/60000]
Test Error: 
 Accuracy: 96.0% 

Epoch:3

loss: 0.102882  [   64/60000]
loss: 0.150609  [ 6464/60000]
loss: 0.103246  [12864/60000]
loss: 0.202459  [19264/60000]
loss: 0.090362  [25664/60000]
loss: 0.141385  [32064/60000]
loss: 0.101247  [38464/60000]
loss: 0.215352  [44864/60000]
loss: 0.170169  [51264/60000]
loss: 0.192590  [57664/60000]
Test Error: 
 Accuracy: 96.1% 

Epoch:4

loss: 0.099215  [   64/60000]
loss: 0.146702  [ 6464/60000]
loss: 0.100990  [12864/60000]
loss: 0.196955  [19264/60000]
loss: 0.085184  [25664/60000]
loss: 0.137365  [32064/60000]
loss: 0.100910  [38464/60000]
loss: 0.210799  [44864/60000]
loss: 0.167222  [51264/60000]
loss: 0.190774  [57664/60000]
Test Error: 
 Accuracy: 96.2% 

Epoch:5

loss: 0.095846  [   64/60000]
loss: 0.143141  [ 6464/60000]
loss: 0.098452  [12864/60000]
loss: 0.191900  [19264/60000]
loss: 0.080353  [25664/60000]
loss: 0.133666  [32064/60000]
loss: 0.100417  [38464/60000]
loss: 0.206443  [44864/60000]
loss: 0.165251  [51264/60000]
loss: 0.189182  [57664/60000]
Test Error: 
 Accuracy: 96.4% 

Epoch:6

loss: 0.092713  [   64/60000]
loss: 0.139563  [ 6464/60000]
loss: 0.096374  [12864/60000]
loss: 0.187604  [19264/60000]
loss: 0.076391  [25664/60000]
loss: 0.130577  [32064/60000]
loss: 0.100084  [38464/60000]
loss: 0.201824  [44864/60000]
loss: 0.163594  [51264/60000]
loss: 0.187552  [57664/60000]
Test Error: 
 Accuracy: 96.5% 

Epoch:7

loss: 0.090275  [   64/60000]
loss: 0.136202  [ 6464/60000]
loss: 0.094576  [12864/60000]
loss: 0.183031  [19264/60000]
loss: 0.072579  [25664/60000]
loss: 0.127552  [32064/60000]
loss: 0.100191  [38464/60000]
loss: 0.197527  [44864/60000]
loss: 0.161685  [51264/60000]
loss: 0.185743  [57664/60000]
Test Error: 
 Accuracy: 96.5% 

Epoch:8

loss: 0.087929  [   64/60000]
loss: 0.132803  [ 6464/60000]
loss: 0.093051  [12864/60000]
loss: 0.178486  [19264/60000]
loss: 0.069132  [25664/60000]
loss: 0.125270  [32064/60000]
loss: 0.100164  [38464/60000]
loss: 0.193437  [44864/60000]
loss: 0.159913  [51264/60000]
loss: 0.183776  [57664/60000]
Test Error: 
 Accuracy: 96.6% 

Epoch:9

loss: 0.085961  [   64/60000]
loss: 0.129178  [ 6464/60000]
loss: 0.091663  [12864/60000]
loss: 0.174406  [19264/60000]
loss: 0.066205  [25664/60000]
loss: 0.122413  [32064/60000]
loss: 0.100105  [38464/60000]
loss: 0.190159  [44864/60000]
loss: 0.158249  [51264/60000]
loss: 0.182137  [57664/60000]
Test Error: 
 Accuracy: 96.6% 

Epoch:10

loss: 0.083964  [   64/60000]
loss: 0.125733  [ 6464/60000]
loss: 0.090878  [12864/60000]
loss: 0.169320  [19264/60000]
loss: 0.063556  [25664/60000]
loss: 0.120058  [32064/60000]
loss: 0.099898  [38464/60000]
loss: 0.186033  [44864/60000]
loss: 0.156695  [51264/60000]
loss: 0.180757  [57664/60000]
Test Error: 
 Accuracy: 96.7% 

Epoch:11

loss: 0.082257  [   64/60000]
loss: 0.122557  [ 6464/60000]
loss: 0.090394  [12864/60000]
loss: 0.165458  [19264/60000]
loss: 0.061349  [25664/60000]
loss: 0.117119  [32064/60000]
loss: 0.100087  [38464/60000]
loss: 0.182522  [44864/60000]
loss: 0.156071  [51264/60000]
loss: 0.179523  [57664/60000]
Test Error: 
 Accuracy: 96.8% 

Epoch:12

loss: 0.080549  [   64/60000]
loss: 0.119308  [ 6464/60000]
loss: 0.089058  [12864/60000]
loss: 0.161788  [19264/60000]
loss: 0.059212  [25664/60000]
loss: 0.114778  [32064/60000]
loss: 0.099970  [38464/60000]
loss: 0.179842  [44864/60000]
loss: 0.154772  [51264/60000]
loss: 0.178012  [57664/60000]
Test Error: 
 Accuracy: 96.9% 

Epoch:13

loss: 0.078873  [   64/60000]
loss: 0.116339  [ 6464/60000]
loss: 0.088101  [12864/60000]
loss: 0.157712  [19264/60000]
loss: 0.057071  [25664/60000]
loss: 0.112539  [32064/60000]
loss: 0.099875  [38464/60000]
loss: 0.177357  [44864/60000]
loss: 0.153125  [51264/60000]
loss: 0.176277  [57664/60000]
Test Error: 
 Accuracy: 96.9% 

Epoch:14

loss: 0.077309  [   64/60000]
loss: 0.113174  [ 6464/60000]
loss: 0.087097  [12864/60000]
loss: 0.154341  [19264/60000]
loss: 0.054654  [25664/60000]
loss: 0.110167  [32064/60000]
loss: 0.099950  [38464/60000]
loss: 0.174374  [44864/60000]
loss: 0.151898  [51264/60000]
loss: 0.174753  [57664/60000]
Test Error: 
 Accuracy: 96.9% 

Epoch:15

loss: 0.075497  [   64/60000]
loss: 0.110737  [ 6464/60000]
loss: 0.085666  [12864/60000]
loss: 0.151679  [19264/60000]
loss: 0.052687  [25664/60000]
loss: 0.107632  [32064/60000]
loss: 0.100248  [38464/60000]
loss: 0.171724  [44864/60000]
loss: 0.150230  [51264/60000]
loss: 0.173451  [57664/60000]
Test Error: 
 Accuracy: 97.0% 

Epoch:16

loss: 0.073662  [   64/60000]
loss: 0.107864  [ 6464/60000]
loss: 0.084569  [12864/60000]
loss: 0.148988  [19264/60000]
loss: 0.050488  [25664/60000]
loss: 0.105394  [32064/60000]
loss: 0.100237  [38464/60000]
loss: 0.168985  [44864/60000]
loss: 0.148702  [51264/60000]
loss: 0.172232  [57664/60000]
Test Error: 
 Accuracy: 97.0% 

Epoch:17

loss: 0.072101  [   64/60000]
loss: 0.106040  [ 6464/60000]
loss: 0.083778  [12864/60000]
loss: 0.146514  [19264/60000]
loss: 0.048645  [25664/60000]
loss: 0.103209  [32064/60000]
loss: 0.100347  [38464/60000]
loss: 0.166806  [44864/60000]
loss: 0.147567  [51264/60000]
loss: 0.171285  [57664/60000]
Test Error: 
 Accuracy: 97.0% 

Epoch:18

loss: 0.070447  [   64/60000]
loss: 0.104067  [ 6464/60000]
loss: 0.082554  [12864/60000]
loss: 0.143509  [19264/60000]
loss: 0.046810  [25664/60000]
loss: 0.101432  [32064/60000]
loss: 0.100408  [38464/60000]
loss: 0.164900  [44864/60000]
loss: 0.145843  [51264/60000]
loss: 0.170779  [57664/60000]
Test Error: 
 Accuracy: 97.1% 

Epoch:19

loss: 0.068664  [   64/60000]
loss: 0.101738  [ 6464/60000]
loss: 0.081756  [12864/60000]
loss: 0.141499  [19264/60000]
loss: 0.045196  [25664/60000]
loss: 0.099498  [32064/60000]
loss: 0.100201  [38464/60000]
loss: 0.162691  [44864/60000]
loss: 0.144729  [51264/60000]
loss: 0.169254  [57664/60000]
Test Error: 
 Accuracy: 97.2% 

Epoch:20

loss: 0.067405  [   64/60000]
loss: 0.100164  [ 6464/60000]
loss: 0.081304  [12864/60000]
loss: 0.139226  [19264/60000]
loss: 0.043736  [25664/60000]
loss: 0.097499  [32064/60000]
loss: 0.100292  [38464/60000]
loss: 0.160949  [44864/60000]
loss: 0.143649  [51264/60000]
loss: 0.168554  [57664/60000]
Test Error: 
 Accuracy: 97.2% 

Epoch:21

loss: 0.065803  [   64/60000]
loss: 0.097873  [ 6464/60000]
loss: 0.080800  [12864/60000]
loss: 0.137347  [19264/60000]
loss: 0.042340  [25664/60000]
loss: 0.095061  [32064/60000]
loss: 0.100351  [38464/60000]
loss: 0.158768  [44864/60000]
loss: 0.141980  [51264/60000]
loss: 0.166981  [57664/60000]
Test Error: 
 Accuracy: 97.2% 

In [33]:
torch.save(model.state_dict(), "./LeNet.pth")
model = LeNet().to(device)
model.load_state_dict(torch.load("./LeNet.pth"))
Out[33]:
<All keys matched successfully>
In [23]:
model.eval()
x, y = test_data[180][0].view(1,1,28,28), test_data[180][1]
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    predicted, actual = pred[0].argmax(0), y
    print(f'Predicted: "{predicted}", Actual: "{actual}"')
Predicted: "1", Actual: "1"
In [ ]: