[ML Notes] Transfer Learning in PyTorch

Transfer learning means taking a model that has already been trained and re-tuning its parameters on a small amount of new data, so that the resulting model can be applied to a specific task; it is a bit like an entry-level version of fine-tuning. Since it is still parameter tuning, the workflow is largely the same as training from scratch, roughly:

  • Preprocess the data (transform)
  • Build the datasets (dataloader)
  • Prepare and adjust the model (model)
  • Train
  • Get the results
  • Predict

Incidentally, I'd also like to recommend a PyTorch learning resource that I find very useful.


The first step, of course, is importing all the assorted libraries; anything missing can be installed with pip.
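
For reference, the packages used below can presumably be installed with something like this (package names assumed from the import list, not stated in the original):

pip install torch torchvision torchinfo tqdm matplotlib pillow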

import torch
import torchvision
import os
import matplotlib.pyplot as plt

from torchinfo import summary
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.auto import tqdm
from typing import Dict, List, Tuple
from pathlib import Path
from PIL import Image

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")
print(f"device: {device}")

The output should look something like this; I'm on a MacBook Pro, so the device is mps.

torch version: 2.6.0
torchvision version: 0.21.0
device: mps

Preprocess the data & build the datasets

Let's handle this in a function called createDataLoader. Inside it we build the DataLoaders for the train and test datasets and also export the class names.

NUM_WORKERS = os.cpu_count()

def createDataLoader(
        trainDir: str,
        testDir: str,
        transform: transforms.Compose,
        batchSize: int,
        numWorkers: int = NUM_WORKERS):
    trainData = datasets.ImageFolder(trainDir, transform=transform)
    testData = datasets.ImageFolder(testDir, transform=transform)
    classNames = trainData.classes

    trainDataLoader = DataLoader(
        trainData,
        batch_size=batchSize,
        shuffle=True,
        num_workers=numWorkers,
        pin_memory=True)  # pin batches in page-locked memory to reduce host-to-device transfer overhead
    testDataLoader = DataLoader(
        testData,
        batch_size=batchSize,
        shuffle=False,  # no need to shuffle the test set
        num_workers=numWorkers,
        pin_memory=True)

    return trainDataLoader, testDataLoader, classNames

Next we need to create the transform and point it at the data. The images are split into two folders, train and test, with one subfolder per class under each.
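
Based on the class names that appear later (diana, mashiro, miku, paimon), the directory layout is presumably something like this:

datasets/
├── train/
│   ├── diana/
│   ├── mashiro/
│   ├── miku/
│   └── paimon/
└── test/
    ├── diana/
    ├── mashiro/
    ├── miku/
    └── paimon/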

The rest is in the code:

IMAGE_PATH = Path("datasets")

trainDir = IMAGE_PATH / "train"
testDir = IMAGE_PATH / "test"

manualTransform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet channel means
        std=[0.229, 0.224, 0.225])   # ImageNet channel stds
])

trainDataLoader, testDataLoader, className = createDataLoader(
    trainDir=trainDir,
    testDir=testDir,
    transform=manualTransform,
    batchSize=32)

trainDataLoader, testDataLoader, className

The output should look like this (given the folder layout above):

(<torch.utils.data.dataloader.DataLoader at 0x129dea140>,
<torch.utils.data.dataloader.DataLoader at 0x129deb790>,
['diana', 'mashiro', 'miku', 'paimon'])

The datasets now live in trainDataLoader and testDataLoader.
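
As a quick sanity check (my own addition, not part of the original post), you can pull a single batch and confirm that its shape matches the transform and batch size set above:

imgBatch, labelBatch = next(iter(trainDataLoader))
print(imgBatch.shape, labelBatch.shape)  # expect torch.Size([32, 3, 224, 224]) and torch.Size([32])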

Prepare and adjust the model

The model needs the pretrained weights along with their matching transform, plus any necessary structural tweaks, such as replacing the output linear layer so it matches the number of classes. See the code:

weights = torchvision.models.ResNet50_Weights.DEFAULT
autoTransform = weights.transforms()  # the transform that matches the pretrained weights
model = torchvision.models.resnet50(weights=weights).to(device)

torch.manual_seed(2233)

outputShape = len(className)

# ResNet50's classification head is the `fc` layer (in_features=2048),
# so that is the layer we replace to match our number of classes
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(p=0.2, inplace=True),
    torch.nn.Linear(in_features=2048,
                    out_features=outputShape,
                    bias=True)).to(device)

summary(model=model,
        input_size=(32, 3, 224, 224),
        verbose=0,
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

I went with resnet50; the summary output looks roughly like this:

========================================================================================================================
Layer (type (var_name)) Input Shape Output Shape Param # Trainable
========================================================================================================================
ResNet (ResNet) [32, 3, 224, 224] [32, 1000] 5,124 True
├─Conv2d (conv1) [32, 3, 224, 224] [32, 64, 112, 112] 9,408 True
├─BatchNorm2d (bn1) [32, 64, 112, 112] [32, 64, 112, 112] 128 True
├─ReLU (relu) [32, 64, 112, 112] [32, 64, 112, 112] -- --
├─MaxPool2d (maxpool) [32, 64, 112, 112] [32, 64, 56, 56] -- --
├─Sequential (layer1) [32, 64, 56, 56] [32, 256, 56, 56] -- True
│ └─Bottleneck (0) [32, 64, 56, 56] [32, 256, 56, 56] -- True
│ │ └─Conv2d (conv1) [32, 64, 56, 56] [32, 64, 56, 56] 4,096 True
│ │ └─BatchNorm2d (bn1) [32, 64, 56, 56] [32, 64, 56, 56] 128 True
│ │ └─ReLU (relu) [32, 64, 56, 56] [32, 64, 56, 56] -- --
│ │ └─Conv2d (conv2) [32, 64, 56, 56] [32, 64, 56, 56] 36,864 True
│ │ └─BatchNorm2d (bn2) [32, 64, 56, 56] [32, 64, 56, 56] 128 True
│ │ └─ReLU (relu) [32, 64, 56, 56] [32, 64, 56, 56] -- --
│ │ └─Conv2d (conv3) [32, 64, 56, 56] [32, 256, 56, 56] 16,384 True
│ │ └─BatchNorm2d (bn3) [32, 256, 56, 56] [32, 256, 56, 56] 512 True
│ │ └─Sequential (downsample) [32, 64, 56, 56] [32, 256, 56, 56] 16,896 True
│ │ └─ReLU (relu) [32, 256, 56, 56] [32, 256, 56, 56] -- --
│ └─Bottleneck (1) [32, 256, 56, 56] [32, 256, 56, 56] -- True
│ │ └─Conv2d (conv1) [32, 256, 56, 56] [32, 64, 56, 56] 16,384 True
│ │ └─BatchNorm2d (bn1) [32, 64, 56, 56] [32, 64, 56, 56] 128 True
│ │ └─ReLU (relu) [32, 64, 56, 56] [32, 64, 56, 56] -- --
│ │ └─Conv2d (conv2) [32, 64, 56, 56] [32, 64, 56, 56] 36,864 True
│ │ └─BatchNorm2d (bn2) [32, 64, 56, 56] [32, 64, 56, 56] 128 True
│ │ └─ReLU (relu) [32, 64, 56, 56] [32, 64, 56, 56] -- --
│ │ └─Conv2d (conv3) [32, 64, 56, 56] [32, 256, 56, 56] 16,384 True
│ │ └─BatchNorm2d (bn3) [32, 256, 56, 56] [32, 256, 56, 56] 512 True
│ │ └─ReLU (relu) [32, 256, 56, 56] [32, 256, 56, 56] -- --
│ └─Bottleneck (2) [32, 256, 56, 56] [32, 256, 56, 56] -- True
│ │ └─Conv2d (conv1) [32, 256, 56, 56] [32, 64, 56, 56] 16,384 True
│ │ └─BatchNorm2d (bn1) [32, 64, 56, 56] [32, 64, 56, 56] 128 True
│ │ └─ReLU (relu) [32, 64, 56, 56] [32, 64, 56, 56] -- --
│ │ └─Conv2d (conv2) [32, 64, 56, 56] [32, 64, 56, 56] 36,864 True
│ │ └─BatchNorm2d (bn2) [32, 64, 56, 56] [32, 64, 56, 56] 128 True
│ │ └─ReLU (relu) [32, 64, 56, 56] [32, 64, 56, 56] -- --
│ │ └─Conv2d (conv3) [32, 64, 56, 56] [32, 256, 56, 56] 16,384 True
│ │ └─BatchNorm2d (bn3) [32, 256, 56, 56] [32, 256, 56, 56] 512 True
│ │ └─ReLU (relu) [32, 256, 56, 56] [32, 256, 56, 56] -- --
├─Sequential (layer2) [32, 256, 56, 56] [32, 512, 28, 28] -- True
│ └─Bottleneck (0) [32, 256, 56, 56] [32, 512, 28, 28] -- True
│ │ └─Conv2d (conv1) [32, 256, 56, 56] [32, 128, 56, 56] 32,768 True
│ │ └─BatchNorm2d (bn1) [32, 128, 56, 56] [32, 128, 56, 56] 256 True
│ │ └─ReLU (relu) [32, 128, 56, 56] [32, 128, 56, 56] -- --
│ │ └─Conv2d (conv2) [32, 128, 56, 56] [32, 128, 28, 28] 147,456 True
│ │ └─BatchNorm2d (bn2) [32, 128, 28, 28] [32, 128, 28, 28] 256 True
│ │ └─ReLU (relu) [32, 128, 28, 28] [32, 128, 28, 28] -- --
│ │ └─Conv2d (conv3) [32, 128, 28, 28] [32, 512, 28, 28] 65,536 True
│ │ └─BatchNorm2d (bn3) [32, 512, 28, 28] [32, 512, 28, 28] 1,024 True
│ │ └─Sequential (downsample) [32, 256, 56, 56] [32, 512, 28, 28] 132,096 True
│ │ └─ReLU (relu) [32, 512, 28, 28] [32, 512, 28, 28] -- --
│ └─Bottleneck (1) [32, 512, 28, 28] [32, 512, 28, 28] -- True
│ │ └─Conv2d (conv1) [32, 512, 28, 28] [32, 128, 28, 28] 65,536 True
│ │ └─BatchNorm2d (bn1) [32, 128, 28, 28] [32, 128, 28, 28] 256 True
│ │ └─ReLU (relu) [32, 128, 28, 28] [32, 128, 28, 28] -- --
│ │ └─Conv2d (conv2) [32, 128, 28, 28] [32, 128, 28, 28] 147,456 True
│ │ └─BatchNorm2d (bn2) [32, 128, 28, 28] [32, 128, 28, 28] 256 True
│ │ └─ReLU (relu) [32, 128, 28, 28] [32, 128, 28, 28] -- --
│ │ └─Conv2d (conv3) [32, 128, 28, 28] [32, 512, 28, 28] 65,536 True
│ │ └─BatchNorm2d (bn3) [32, 512, 28, 28] [32, 512, 28, 28] 1,024 True
│ │ └─ReLU (relu) [32, 512, 28, 28] [32, 512, 28, 28] -- --
│ └─Bottleneck (2) [32, 512, 28, 28] [32, 512, 28, 28] -- True
│ │ └─Conv2d (conv1) [32, 512, 28, 28] [32, 128, 28, 28] 65,536 True
│ │ └─BatchNorm2d (bn1) [32, 128, 28, 28] [32, 128, 28, 28] 256 True
│ │ └─ReLU (relu) [32, 128, 28, 28] [32, 128, 28, 28] -- --
│ │ └─Conv2d (conv2) [32, 128, 28, 28] [32, 128, 28, 28] 147,456 True
│ │ └─BatchNorm2d (bn2) [32, 128, 28, 28] [32, 128, 28, 28] 256 True
│ │ └─ReLU (relu) [32, 128, 28, 28] [32, 128, 28, 28] -- --
│ │ └─Conv2d (conv3) [32, 128, 28, 28] [32, 512, 28, 28] 65,536 True
│ │ └─BatchNorm2d (bn3) [32, 512, 28, 28] [32, 512, 28, 28] 1,024 True
│ │ └─ReLU (relu) [32, 512, 28, 28] [32, 512, 28, 28] -- --
│ └─Bottleneck (3) [32, 512, 28, 28] [32, 512, 28, 28] -- True
│ │ └─Conv2d (conv1) [32, 512, 28, 28] [32, 128, 28, 28] 65,536 True
│ │ └─BatchNorm2d (bn1) [32, 128, 28, 28] [32, 128, 28, 28] 256 True
│ │ └─ReLU (relu) [32, 128, 28, 28] [32, 128, 28, 28] -- --
│ │ └─Conv2d (conv2) [32, 128, 28, 28] [32, 128, 28, 28] 147,456 True
│ │ └─BatchNorm2d (bn2) [32, 128, 28, 28] [32, 128, 28, 28] 256 True
│ │ └─ReLU (relu) [32, 128, 28, 28] [32, 128, 28, 28] -- --
│ │ └─Conv2d (conv3) [32, 128, 28, 28] [32, 512, 28, 28] 65,536 True
│ │ └─BatchNorm2d (bn3) [32, 512, 28, 28] [32, 512, 28, 28] 1,024 True
│ │ └─ReLU (relu) [32, 512, 28, 28] [32, 512, 28, 28] -- --
├─Sequential (layer3) [32, 512, 28, 28] [32, 1024, 14, 14] -- True
│ └─Bottleneck (0) [32, 512, 28, 28] [32, 1024, 14, 14] -- True
│ │ └─Conv2d (conv1) [32, 512, 28, 28] [32, 256, 28, 28] 131,072 True
│ │ └─BatchNorm2d (bn1) [32, 256, 28, 28] [32, 256, 28, 28] 512 True
│ │ └─ReLU (relu) [32, 256, 28, 28] [32, 256, 28, 28] -- --
│ │ └─Conv2d (conv2) [32, 256, 28, 28] [32, 256, 14, 14] 589,824 True
│ │ └─BatchNorm2d (bn2) [32, 256, 14, 14] [32, 256, 14, 14] 512 True
│ │ └─ReLU (relu) [32, 256, 14, 14] [32, 256, 14, 14] -- --
│ │ └─Conv2d (conv3) [32, 256, 14, 14] [32, 1024, 14, 14] 262,144 True
│ │ └─BatchNorm2d (bn3) [32, 1024, 14, 14] [32, 1024, 14, 14] 2,048 True
│ │ └─Sequential (downsample) [32, 512, 28, 28] [32, 1024, 14, 14] 526,336 True
│ │ └─ReLU (relu) [32, 1024, 14, 14] [32, 1024, 14, 14] -- --
│ └─Bottleneck (1) [32, 1024, 14, 14] [32, 1024, 14, 14] -- True
│ │ └─Conv2d (conv1) [32, 1024, 14, 14] [32, 256, 14, 14] 262,144 True
│ │ └─BatchNorm2d (bn1) [32, 256, 14, 14] [32, 256, 14, 14] 512 True
│ │ └─ReLU (relu) [32, 256, 14, 14] [32, 256, 14, 14] -- --
│ │ └─Conv2d (conv2) [32, 256, 14, 14] [32, 256, 14, 14] 589,824 True
│ │ └─BatchNorm2d (bn2) [32, 256, 14, 14] [32, 256, 14, 14] 512 True
│ │ └─ReLU (relu) [32, 256, 14, 14] [32, 256, 14, 14] -- --
│ │ └─Conv2d (conv3) [32, 256, 14, 14] [32, 1024, 14, 14] 262,144 True
│ │ └─BatchNorm2d (bn3) [32, 1024, 14, 14] [32, 1024, 14, 14] 2,048 True
│ │ └─ReLU (relu) [32, 1024, 14, 14] [32, 1024, 14, 14] -- --
│ └─Bottleneck (2) [32, 1024, 14, 14] [32, 1024, 14, 14] -- True
│ │ └─Conv2d (conv1) [32, 1024, 14, 14] [32, 256, 14, 14] 262,144 True
│ │ └─BatchNorm2d (bn1) [32, 256, 14, 14] [32, 256, 14, 14] 512 True
│ │ └─ReLU (relu) [32, 256, 14, 14] [32, 256, 14, 14] -- --
│ │ └─Conv2d (conv2) [32, 256, 14, 14] [32, 256, 14, 14] 589,824 True
│ │ └─BatchNorm2d (bn2) [32, 256, 14, 14] [32, 256, 14, 14] 512 True
│ │ └─ReLU (relu) [32, 256, 14, 14] [32, 256, 14, 14] -- --
│ │ └─Conv2d (conv3) [32, 256, 14, 14] [32, 1024, 14, 14] 262,144 True
│ │ └─BatchNorm2d (bn3) [32, 1024, 14, 14] [32, 1024, 14, 14] 2,048 True
│ │ └─ReLU (relu) [32, 1024, 14, 14] [32, 1024, 14, 14] -- --
│ └─Bottleneck (3) [32, 1024, 14, 14] [32, 1024, 14, 14] -- True
│ │ └─Conv2d (conv1) [32, 1024, 14, 14] [32, 256, 14, 14] 262,144 True
│ │ └─BatchNorm2d (bn1) [32, 256, 14, 14] [32, 256, 14, 14] 512 True
│ │ └─ReLU (relu) [32, 256, 14, 14] [32, 256, 14, 14] -- --
│ │ └─Conv2d (conv2) [32, 256, 14, 14] [32, 256, 14, 14] 589,824 True
│ │ └─BatchNorm2d (bn2) [32, 256, 14, 14] [32, 256, 14, 14] 512 True
│ │ └─ReLU (relu) [32, 256, 14, 14] [32, 256, 14, 14] -- --
│ │ └─Conv2d (conv3) [32, 256, 14, 14] [32, 1024, 14, 14] 262,144 True
│ │ └─BatchNorm2d (bn3) [32, 1024, 14, 14] [32, 1024, 14, 14] 2,048 True
│ │ └─ReLU (relu) [32, 1024, 14, 14] [32, 1024, 14, 14] -- --
│ └─Bottleneck (4) [32, 1024, 14, 14] [32, 1024, 14, 14] -- True
│ │ └─Conv2d (conv1) [32, 1024, 14, 14] [32, 256, 14, 14] 262,144 True
│ │ └─BatchNorm2d (bn1) [32, 256, 14, 14] [32, 256, 14, 14] 512 True
│ │ └─ReLU (relu) [32, 256, 14, 14] [32, 256, 14, 14] -- --
│ │ └─Conv2d (conv2) [32, 256, 14, 14] [32, 256, 14, 14] 589,824 True
│ │ └─BatchNorm2d (bn2) [32, 256, 14, 14] [32, 256, 14, 14] 512 True
│ │ └─ReLU (relu) [32, 256, 14, 14] [32, 256, 14, 14] -- --
│ │ └─Conv2d (conv3) [32, 256, 14, 14] [32, 1024, 14, 14] 262,144 True
│ │ └─BatchNorm2d (bn3) [32, 1024, 14, 14] [32, 1024, 14, 14] 2,048 True
│ │ └─ReLU (relu) [32, 1024, 14, 14] [32, 1024, 14, 14] -- --
│ └─Bottleneck (5) [32, 1024, 14, 14] [32, 1024, 14, 14] -- True
│ │ └─Conv2d (conv1) [32, 1024, 14, 14] [32, 256, 14, 14] 262,144 True
│ │ └─BatchNorm2d (bn1) [32, 256, 14, 14] [32, 256, 14, 14] 512 True
│ │ └─ReLU (relu) [32, 256, 14, 14] [32, 256, 14, 14] -- --
│ │ └─Conv2d (conv2) [32, 256, 14, 14] [32, 256, 14, 14] 589,824 True
│ │ └─BatchNorm2d (bn2) [32, 256, 14, 14] [32, 256, 14, 14] 512 True
│ │ └─ReLU (relu) [32, 256, 14, 14] [32, 256, 14, 14] -- --
│ │ └─Conv2d (conv3) [32, 256, 14, 14] [32, 1024, 14, 14] 262,144 True
│ │ └─BatchNorm2d (bn3) [32, 1024, 14, 14] [32, 1024, 14, 14] 2,048 True
│ │ └─ReLU (relu) [32, 1024, 14, 14] [32, 1024, 14, 14] -- --
├─Sequential (layer4) [32, 1024, 14, 14] [32, 2048, 7, 7] -- True
│ └─Bottleneck (0) [32, 1024, 14, 14] [32, 2048, 7, 7] -- True
│ │ └─Conv2d (conv1) [32, 1024, 14, 14] [32, 512, 14, 14] 524,288 True
│ │ └─BatchNorm2d (bn1) [32, 512, 14, 14] [32, 512, 14, 14] 1,024 True
│ │ └─ReLU (relu) [32, 512, 14, 14] [32, 512, 14, 14] -- --
│ │ └─Conv2d (conv2) [32, 512, 14, 14] [32, 512, 7, 7] 2,359,296 True
│ │ └─BatchNorm2d (bn2) [32, 512, 7, 7] [32, 512, 7, 7] 1,024 True
│ │ └─ReLU (relu) [32, 512, 7, 7] [32, 512, 7, 7] -- --
│ │ └─Conv2d (conv3) [32, 512, 7, 7] [32, 2048, 7, 7] 1,048,576 True
│ │ └─BatchNorm2d (bn3) [32, 2048, 7, 7] [32, 2048, 7, 7] 4,096 True
│ │ └─Sequential (downsample) [32, 1024, 14, 14] [32, 2048, 7, 7] 2,101,248 True
│ │ └─ReLU (relu) [32, 2048, 7, 7] [32, 2048, 7, 7] -- --
│ └─Bottleneck (1) [32, 2048, 7, 7] [32, 2048, 7, 7] -- True
│ │ └─Conv2d (conv1) [32, 2048, 7, 7] [32, 512, 7, 7] 1,048,576 True
│ │ └─BatchNorm2d (bn1) [32, 512, 7, 7] [32, 512, 7, 7] 1,024 True
│ │ └─ReLU (relu) [32, 512, 7, 7] [32, 512, 7, 7] -- --
│ │ └─Conv2d (conv2) [32, 512, 7, 7] [32, 512, 7, 7] 2,359,296 True
│ │ └─BatchNorm2d (bn2) [32, 512, 7, 7] [32, 512, 7, 7] 1,024 True
│ │ └─ReLU (relu) [32, 512, 7, 7] [32, 512, 7, 7] -- --
│ │ └─Conv2d (conv3) [32, 512, 7, 7] [32, 2048, 7, 7] 1,048,576 True
│ │ └─BatchNorm2d (bn3) [32, 2048, 7, 7] [32, 2048, 7, 7] 4,096 True
│ │ └─ReLU (relu) [32, 2048, 7, 7] [32, 2048, 7, 7] -- --
│ └─Bottleneck (2) [32, 2048, 7, 7] [32, 2048, 7, 7] -- True
│ │ └─Conv2d (conv1) [32, 2048, 7, 7] [32, 512, 7, 7] 1,048,576 True
│ │ └─BatchNorm2d (bn1) [32, 512, 7, 7] [32, 512, 7, 7] 1,024 True
│ │ └─ReLU (relu) [32, 512, 7, 7] [32, 512, 7, 7] -- --
│ │ └─Conv2d (conv2) [32, 512, 7, 7] [32, 512, 7, 7] 2,359,296 True
│ │ └─BatchNorm2d (bn2) [32, 512, 7, 7] [32, 512, 7, 7] 1,024 True
│ │ └─ReLU (relu) [32, 512, 7, 7] [32, 512, 7, 7] -- --
│ │ └─Conv2d (conv3) [32, 512, 7, 7] [32, 2048, 7, 7] 1,048,576 True
│ │ └─BatchNorm2d (bn3) [32, 2048, 7, 7] [32, 2048, 7, 7] 4,096 True
│ │ └─ReLU (relu) [32, 2048, 7, 7] [32, 2048, 7, 7] -- --
├─AdaptiveAvgPool2d (avgpool) [32, 2048, 7, 7] [32, 2048, 1, 1] -- --
├─Linear (fc) [32, 2048] [32, 1000] 2,049,000 True
========================================================================================================================
Total params: 25,562,156
Trainable params: 25,562,156
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 130.86
========================================================================================================================
Input size (MB): 19.27
Forward/backward pass size (MB): 5690.62
Params size (MB): 102.23
Estimated Total Size (MB): 5812.11
========================================================================================================================
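
In the summary every layer is reported as trainable, so this run fine-tunes the whole network. A common transfer-learning variant, not used in this post and shown here only as a sketch, is to freeze the pretrained backbone and train just the replaced head:

# Freeze everything except the new classification head (model.fc)
for name, param in model.named_parameters():
    if not name.startswith("fc"):
        param.requires_grad = False

# If you do this, hand the optimizer only the parameters that still require gradients, e.g.
# torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=0.001)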

Start training

The basic steps for training a model should be second nature by now:

  • forward pass
  • calculate the loss
  • zero the gradients
  • backpropagation
  • gradient descent (optimizer step)

Then it's the usual routine: write the functions first, then call them.

from tqdm.auto import tqdm
from typing import Dict, List, Tuple

def trainModel(model: torch.nn.Module,
               dataLoader: torch.utils.data.DataLoader,
               lossFunc: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[float, float]:
    model.train()
    trainLoss, trainAcc = 0, 0
    for batch, (X, Y) in enumerate(dataLoader):
        X, Y = X.to(device), Y.to(device)
        yPred = model(X)            # forward pass
        loss = lossFunc(yPred, Y)   # calculate the loss
        trainLoss += loss.item()
        optimizer.zero_grad()       # zero the gradients
        loss.backward()             # backpropagation
        optimizer.step()            # gradient descent
        yPredClass = torch.argmax(torch.softmax(yPred, dim=1), dim=1)
        trainAcc += (yPredClass == Y).sum().item() / len(yPred)
    trainLoss = trainLoss / len(dataLoader)
    trainAcc = trainAcc / len(dataLoader)
    return trainLoss, trainAcc

def testModel(model: torch.nn.Module,
              dataLoader: torch.utils.data.DataLoader,
              lossFunc: torch.nn.Module,
              device: torch.device) -> Tuple[float, float]:
    model.eval()
    testLoss, testAcc = 0, 0
    with torch.inference_mode():
        for batch, (X, Y) in enumerate(dataLoader):
            X, Y = X.to(device), Y.to(device)
            testPredLogits = model(X)
            loss = lossFunc(testPredLogits, Y)
            testLoss += loss.item()
            testPred = testPredLogits.argmax(dim=1)
            testAcc += (testPred == Y).sum().item() / len(testPred)
    testLoss = testLoss / len(dataLoader)
    testAcc = testAcc / len(dataLoader)
    return testLoss, testAcc

def run(model: torch.nn.Module,
        trainDataLoader: torch.utils.data.DataLoader,
        testDataLoader: torch.utils.data.DataLoader,
        optimizer: torch.optim.Optimizer,
        lossFunc: torch.nn.Module,
        epochs: int,
        device: torch.device) -> Dict[str, list]:
    results = {
        "trainLoss": [],
        "trainAcc": [],
        "testLoss": [],
        "testAcc": []
    }
    model.to(device)
    for epoch in tqdm(range(epochs)):
        trainLoss, trainAcc = trainModel(
            model=model,
            dataLoader=trainDataLoader,
            lossFunc=lossFunc,
            optimizer=optimizer,
            device=device)
        testLoss, testAcc = testModel(
            model=model,
            dataLoader=testDataLoader,
            lossFunc=lossFunc,
            device=device)
        print(
            f"Epoch: {epoch+1} | "
            f"trainLoss: {trainLoss:.4f} | "
            f"trainAcc: {trainAcc:.4f} | "
            f"testLoss: {testLoss:.4f} | "
            f"testAcc: {testAcc:.4f}")
        results["trainLoss"].append(trainLoss)
        results["trainAcc"].append(trainAcc)
        results["testLoss"].append(testLoss)
        results["testAcc"].append(testAcc)
    return results

Call it:

torch.manual_seed(2233)

from timeit import default_timer as timer
startTime = timer()

result = run(model=model,
             trainDataLoader=trainDataLoader,
             testDataLoader=testDataLoader,
             optimizer=torch.optim.Adam(model.parameters(), lr=0.001),
             lossFunc=torch.nn.CrossEntropyLoss(),
             epochs=5,
             device=device)

endTime = timer()
print(f"[INFO] Total Training time: {endTime - startTime:.3f} seconds")

Then we just wait for the training to finish.

Epoch: 1 | trainLoss: 2.4154 | trainAcc: 0.5917 | testLoss: 1.4108 | testAcc: 0.7266
Epoch: 2 | trainLoss: 0.1883 | trainAcc: 0.9469 | testLoss: 1.5478 | testAcc: 0.7812
Epoch: 3 | trainLoss: 0.1803 | trainAcc: 0.9625 | testLoss: 0.2871 | testAcc: 0.9297
Epoch: 4 | trainLoss: 0.0867 | trainAcc: 0.9667 | testLoss: 0.3806 | testAcc: 0.9062
Epoch: 5 | trainLoss: 0.1593 | trainAcc: 0.9531 | testLoss: 0.7680 | testAcc: 0.8281
[INFO] Total Training time: 409.650 seconds

As you can see, the model isn't tuned particularly well here and ends up overfitting: the training accuracy keeps climbing while the test metrics fluctuate. The next step would be to keep adjusting the hyperparameters and the training data to get a better-performing model.
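
One way to adjust the training data, offered here only as a suggestion and not something done in the original run, is to add light augmentation to the manual transform so the training images are more varied:

augmentedTrainTransform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.TrivialAugmentWide(),  # light random augmentation, applied before converting to a tensor
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
])
# This would replace manualTransform for the training DataLoader only.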

Get the results & predict

Let's start by plotting the training curves:

def plotGraph(results):
    trainLoss = results["trainLoss"]
    trainAcc = results["trainAcc"]
    testLoss = results["testLoss"]
    testAcc = results["testAcc"]
    epochs = range(len(trainLoss))
    plt.figure(figsize=(15, 7))

    # Plot Loss
    plt.subplot(1, 2, 1)
    plt.plot(epochs, trainLoss, label="trainLoss")
    plt.plot(epochs, testLoss, label="testLoss")
    plt.title("Loss")
    plt.xlabel("Epochs")
    plt.legend()

    # Plot Accuracy
    plt.subplot(1, 2, 2)
    plt.plot(epochs, trainAcc, label="trainAcc")
    plt.plot(epochs, testAcc, label="testAcc")
    plt.title("Accuracy")
    plt.xlabel("Epochs")
    plt.legend()

plotGraph(results=result)

Then comes prediction:

model.to(device)
model.eval()
transform = torchvision.models.ResNet50_Weights.DEFAULT.transforms()
img = Image.open("custom.jpg").convert("RGB")  # make sure the image has 3 channels
with torch.inference_mode():
    pred = model(transform(img).unsqueeze(dim=0).to(device))
    prob = torch.softmax(pred, dim=1)
    label = torch.argmax(prob, dim=1)
plt.figure()
plt.imshow(Image.open("custom.jpg"))
plt.title(f"Pred: {className[label]} | Prob: {prob.max():.3f}")
plt.axis(False)


The test image itself is pretty cursed, so this one is just for fun.

All done! To wrap things up, here's a joke from Deepseek (which, honestly, isn't funny at all 😅).

