PyTorch Semantic Segmentation (2): Building the Model


The classic model is still UNet. You could also train with a ready-made UNet implementation (a brief sketch of that option follows below), but to get a better understanding of the architecture, we build it ourselves here.
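If you would rather use a ready-made implementation, third-party libraries such as segmentation_models_pytorch ship a UNet. The snippet below is only a sketch of that option; the library, encoder choice, and arguments are assumptions and not part of this post's code:

import segmentation_models_pytorch as smp

# hypothetical off-the-shelf alternative: UNet with a ResNet-34 encoder
model = smp.Unet(
    encoder_name="resnet34",       # encoder backbone
    encoder_weights="imagenet",    # or None to train from scratch
    in_channels=3,                 # RGB input
    classes=3,                     # number of segmentation classes
    activation="sigmoid",          # mirrors the Sigmoid head of the hand-built model below
)

The hand-built version used in this series follows.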

unet.py:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Up(nn.Module):
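    # Note: this Up block is not actually used by MainNet below;
    # the decoder upsampling is handled by UpConcat instead.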
    def __init__(self, in_channel, out_channel):
        super(Up, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.block(x)
        out = F.interpolate(x, scale_factor=2)
        return out


class Down(nn.Module):
    def __init__(self, in_channel, out_channel, stride=2):
        super(Down, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, stride, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        return self.block(x)


class UpConcat(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(UpConcat, self).__init__()
        self.up = nn.Upsample(scale_factor=2)
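        # the upsampled deep feature keeps its in_channel channels; concatenated with the
        # skip feature (out_channel channels), conv2 therefore sees in_channel + out_channel inputs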
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel+out_channel, out_channel, kernel_size=3, padding=1),
            nn.ReLU6(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            nn.ReLU6(inplace=True),
        )

    def forward(self, in_map1, in_map2):
        in_map2 = self.up(in_map2)
        out = torch.cat([in_map1, in_map2], dim=1)
        return self.conv2(out)


class MainNet(nn.Module):
    def __init__(self, num_classes):
        super(MainNet, self).__init__()
        self.down1 = Down(3, 64, stride=1)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 512)
        self.down5 = Down(512, 1024)

        # self.conv = nn.Conv2d(1024, 512, 3, 1, 1)

        self.up5concat = UpConcat(1024, 512)
        self.up4concat = UpConcat(512, 256)
        self.up3concat = UpConcat(256, 128)
        self.up2concat = UpConcat(128, 64)

        self.head = nn.Sequential(
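            # 1x1 conv maps 64 channels to num_classes; Sigmoid turns scores into per-pixel probabilities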
            nn.Conv2d(64, num_classes, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        feat1 = self.down1(x)       # 3, 512, 512 ---->64, 512, 512
        feat2 = self.down2(feat1)   # 64, 512, 512 ---->128, 256, 256
        feat3 = self.down3(feat2)   # 128, 256, 256 ---->256,128,128
        feat4 = self.down4(feat3)   # 256,128,128 ---> 512,64,64
        feat5 = self.down5(feat4)   # 512,64,64 ----> 1024,32,32
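        # the prints below only verify feature-map shapes; they can be removed for actual training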
        print("feat5:", feat5.shape)
        # feat5 = self.conv(feat5)

        feat4_up = self.up5concat(feat4, feat5)
        print("feat4_up:", feat4_up.shape)
        feat3_up = self.up4concat(feat3, feat4_up)
        feat2_up = self.up3concat(feat2, feat3_up)
        feat1_up = self.up2concat(feat1, feat2_up)
        print("feat1_up:", feat1_up.shape)

        print(feat1_up.shape, feat2_up.shape, feat3_up.shape, feat4_up.shape)
        return self.head(feat1_up)


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tensor = torch.zeros((1, 3, 512, 512)).to(device)
    model = MainNet(num_classes=3).to(device)

    # print(model)
    # model.apply(inplace_relu)

    out = model(tensor)
    # print(out.shape)
    #
    from torchsummary import torchsummary
    torchsummary.summary(model, (3, 512, 512))
    # # from torchstat import stat
    # # stat(model, (3, 512, 512))
    # from thop import profile
    #
    # flops, params = profile(model, inputs=(tensor,))
    #
    # print("FLOPs=", str(flops / 1e9) + '{}'.format("G"))
    # print("params=", str(params / 1e6) + '{}'.format("M"))
    #
    # #FLOPs= 63.406604288G
    # # params= 14.127683M
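
Since the head ends with Sigmoid, the network outputs per-pixel probabilities with num_classes channels, so nn.BCELoss is a natural fit (alternatively, drop the Sigmoid and use nn.CrossEntropyLoss with integer label masks). The following is only a minimal training-step sketch: the file name, random tensors, and hyperparameters are placeholders standing in for a real DataLoader and training loop.

train_demo.py:

import torch
import torch.nn as nn
from unet import MainNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MainNet(num_classes=3).to(device)
criterion = nn.BCELoss()  # expects probabilities, which the Sigmoid head already provides
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# dummy batch: images (N, 3, H, W); masks one-hot encoded as floats (N, num_classes, H, W)
images = torch.rand(2, 3, 512, 512, device=device)
masks = torch.randint(0, 2, (2, 3, 512, 512), device=device).float()

model.train()
optimizer.zero_grad()
preds = model(images)           # (2, 3, 512, 512), values in [0, 1]
loss = criterion(preds, masks)
loss.backward()
optimizer.step()
print("loss:", loss.item())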
