1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
class FashionMNISTModelV2(nn.Module):
    """Small convolutional classifier for square images.

    Architecture: two convolutional blocks, each
    ``Conv3x3 -> ReLU -> Conv3x3 -> ReLU -> MaxPool2x2``, followed by a
    flatten + linear classifier. The 3x3 convolutions use stride 1 and
    padding 1, so they keep the spatial size; each 2x2 max-pool halves it
    (rounding down). An ``img_size`` x ``img_size`` input therefore reaches
    the classifier as ``hid`` feature maps of side ``img_size // 4``.

    Args:
        in_ch: Number of input channels.
        out_ft: Number of output features (one score per class).
        hid: Number of channels produced by every convolution.
        img_size: Side length of the square input images (keyword-only).
            Defaults to 28, which reproduces the original hard-coded
            ``hid * 7 * 7`` classifier input size.

    Raises:
        ValueError: If ``img_size`` is too small to leave at least one
            pixel after the two 2x2 max-pools.
    """

    def __init__(self, in_ch, out_ft, hid, *, img_size=28):
        super().__init__()
        # Two floor-halvings (MaxPool2d(kernel_size=2), stride defaults to
        # the kernel size) leave floor(img_size / 4) pixels per side.
        side = img_size // 4
        if side < 1:
            raise ValueError(
                f"img_size must be at least 4 to survive two 2x2 max-pools, "
                f"got {img_size}"
            )
        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=hid,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=hid, out_channels=hid,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(in_channels=hid, out_channels=hid,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=hid, out_channels=hid,
                      kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.classifier = nn.Sequential(
            # Flatten keeps dim 0 (batch) and merges channels x height x width.
            nn.Flatten(),
            nn.Linear(in_features=hid * side * side, out_features=out_ft),
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Batched input of shape ``(N, in_ch, img_size, img_size)``.
                The batch dimension is required: ``nn.Flatten`` starts at
                dim 1, so an unbatched input would be flattened incorrectly.

        Returns:
            Tensor of shape ``(N, out_ft)`` holding the raw output of the
            final ``nn.Linear`` (no softmax is applied).
        """
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        x = self.classifier(x)
        return x


# NOTE(review): `device` is not defined in this snippet; it must already exist
# (presumably "cuda" or "cpu") before this line runs — confirm against caller.
model_2 = FashionMNISTModelV2(in_ch=1, out_ft=10, hid=10).to(device)
model_2  # bare expression: only displays the repr as the last line of a notebook cell
|