Pytorch实现基于U-net的医学图像分割
发布日期:2021-05-14 14:58:15 浏览次数:18 分类:精选文章

本文共 9248 字,大约阅读时间需要 30 分钟。

PyTorch implementation of U-Net-based medical image segmentation

Contents

  • Training code (Train.py)
  • Inference code (Inference.py)
  • Dataset definition (my_dataset.py)
  • Random-seed utility (set_seed.py)
  • U-Net model definition (unet.py)

Training code

Train.py

The full training script is shown below:

"""Training setup: data loaders, U-Net model, loss, optimizer, LR schedule."""
import os
import time

import numpy as np  # fixed: was imported twice in the original
import torch
import torch.nn as nn  # fixed: original `from torch.nn import nn` is not importable
import torch.optim as optim
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader

from tools.my_dataset import MyDataset
from tools.set_seed import set_seed
from tools.unet import UNet

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset layout: <repo>/data/blood/{train,valid}
train_dir = os.path.join(BASE_DIR, "..", "data", "blood", "train")
valid_dir = os.path.join(BASE_DIR, "..", "data", "blood", "valid")
train_set = MyDataset(train_dir)
valid_set = MyDataset(valid_dir)
# drop_last=True avoids a tiny final batch, which keeps BatchNorm stats stable.
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid_set, batch_size=1, shuffle=True, drop_last=False)

net = UNet(in_channels=3, out_channels=1, init_features=32)
net.to(device)
# NOTE(review): MSE on sigmoid outputs trains, but BCE or a Dice loss is the
# usual choice for binary segmentation -- kept as-is to preserve behavior.
loss_fn = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.1)
def compute_dice(y_pred, y_true):
    """Dice coefficient between a predicted and a ground-truth binary mask.

    Both inputs are binarized by rounding, so soft (sigmoid) predictions are
    accepted directly. Returns a float in [0, 1]; 1.0 means perfect overlap.

    Fixes two defects in the original snippet: the missing ``def`` keyword,
    and the check ``sum(pred) == sum(true)`` which wrongly returned 0.0 for
    any equal-area masks -- including perfect predictions.
    """
    y_pred = np.round(np.array(y_pred)).astype(int)
    y_true = np.round(np.array(y_true)).astype(int)
    denom = np.sum(y_pred) + np.sum(y_true)
    if denom == 0:
        # Both masks empty: define as perfect agreement rather than 0/0.
        return 1.0
    return np.sum(y_pred[y_true == 1]) * 2.0 / denom
# ---------------------------------------------------------------------------
# Training loop.
# NOTE(review): `max_epoch` was never defined in the original snippet; 400
# matches the `checkpoint_399_epoch.pkl` name in the inference script -- confirm.
# ---------------------------------------------------------------------------
max_epoch = 400
train_curve = []
valid_curve = []
train_dice_curve = []
valid_dice_curve = []

for epoch in range(1, max_epoch + 1):
    # get_last_lr() replaces the deprecated get_lr(), which returns a
    # distorted value when queried mid-schedule.
    current_lr = scheduler.get_last_lr()[0]
    print(f"Epoch({epoch}/{max_epoch})")  # fixed: original `end`)` was a syntax error

    net.train()
    epoch_train_loss = 0.0
    # fixed: original used an undefined `iter`; enumerate supplies the index.
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        # .to(device) is already a no-op on CPU, so no cuda-availability guard.
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = net(inputs)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # (removed: dead `if epoch == 0 and iter == 0` weight snapshot --
        #  epoch starts at 1, so that branch could never run)

        # Dice on raw sigmoid outputs; compute_dice rounds internally.
        train_dice = compute_dice(outputs.cpu().data, labels.cpu())
        train_curve.append(loss.item())
        train_dice_curve.append(train_dice)
        epoch_train_loss += loss.item()
        print(
            f"Training: Epoch({epoch}) Iter({batch_idx + 1}/{len(train_loader)}) "
            f"Loss: {loss.item():.4f} DICE: {train_dice:.4f} LR: {current_lr}"
        )

    scheduler.step()

    # Validate every 3rd epoch.
    if (epoch + 1) % 3 == 0:
        valid_loss = 0.0
        valid_dice = 0.0
        net.eval()
        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = net(inputs)
                valid_loss += loss_fn(outputs, labels).item()
                valid_dice += compute_dice(outputs.cpu().data, labels.cpu())
        valid_loss_mean = valid_loss / len(valid_loader)
        valid_dice_mean = valid_dice / len(valid_loader)
        valid_curve.append(valid_loss_mean)
        valid_dice_curve.append(valid_dice_mean)
        print(f"Valid: Epoch({epoch}) Mean Loss: {valid_loss_mean:.4f} DICE: {valid_dice_mean:.4f}")
        torch.cuda.empty_cache()

Inference.py

The inference script loads the trained model as follows:

# Inference setup: rebuild the network with the same hyper-parameters used in
# training, then restore the saved weights on CPU.
model_path = "checkpoint_399_epoch.pkl"
model = UNet(in_channels=3, out_channels=1, init_features=32)
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# checkpoints from a trusted source (consider weights_only=True on newer torch).
model.load_state_dict(torch.load(model_path, map_location="cpu"))
# NOTE(review): this path points at PortraitDataset while training used
# data/blood -- verify which dataset inference is meant to run on.
img_dir = os.path.join(BASE_DIR, "..", "..", "data", "PortraitDataset", "valid")

my_dataset.py

The dataset class loads paired images and segmentation masks:

"""Dataset utilities for paired image / mask loading."""
import os  # fixed: used by MyDataset._get_img_path but missing from the original imports
import random

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset

random.seed(1)  # fix the shuffle order so the file ordering is reproducible
class MyDataset(Dataset):
    """Loads (image, mask) pairs found in ``data_dir``.

    Masks are ``*_mask.gif`` files; the matching image shares the stem with a
    ``.tif`` extension. Both are resized to ``in_size`` x ``in_size`` and
    returned as float tensors: image as CHW RGB, mask as 1xHxW in {0, 1}.
    """

    def __init__(self, data_dir, transform=None, in_size=224):
        super(MyDataset, self).__init__()
        self.data_dir = data_dir
        self.transform = transform
        self.in_size = in_size
        self._get_img_path()

    def __getitem__(self, index):
        mask_path = self.label_path_list[index]
        # "<stem>_mask.gif" -> "<stem>.tif"
        img_path = mask_path[: -len("_mask.gif")] + ".tif"

        image = Image.open(img_path).convert("RGB")
        image = image.resize((self.in_size, self.in_size), Image.BILINEAR)
        # HWC uint8 -> CHW for the network.
        img_chw = np.array(image).transpose((2, 0, 1))

        mask = Image.open(mask_path).convert("L")
        # NEAREST keeps the mask values discrete; BILINEAR would blur edges.
        mask = mask.resize((self.in_size, self.in_size), Image.NEAREST)
        label_chw = np.array(mask)[np.newaxis, :, :]
        label_chw[label_chw != 0] = 1  # binarize to {0, 1}

        if self.transform is None:
            return (
                torch.from_numpy(img_chw).float(),
                torch.from_numpy(label_chw).float(),
            )
        return self.transform(img_chw), self.transform(label_chw)

    def __len__(self):
        return len(self.label_path_list)

    def _get_img_path(self):
        """Collect and shuffle all mask paths under ``data_dir``."""
        names = [n for n in os.listdir(self.data_dir) if n.endswith("_mask.gif")]
        paths = [os.path.join(self.data_dir, n) for n in names]
        random.shuffle(paths)
        if not paths:
            raise Exception("data_dir exists but is empty. Please check your path to images.")
        self.label_path_list = paths

set_seed.py

Utility for fixing random seeds so runs are reproducible:

import random

import numpy as np
import torch


def set_seed(seed=1):
    """Seed every RNG the project uses: python, numpy, torch CPU and CUDA.

    Args:
        seed: base seed value (default 1).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also seed every visible GPU so multi-GPU runs stay reproducible;
    # this is a no-op on CPU-only machines.
    torch.cuda.manual_seed_all(seed)

unet.py

U-Net network structure definition:

from collections import OrderedDict

import torch  # fixed: `torch` and `nn` are used by the class below but were not imported
import torch.nn as nn
class UNet(nn.Module):
    """Standard 4-level U-Net with BatchNorm double-conv blocks.

    Input:  (N, in_channels, H, W) with H and W divisible by 16.
    Output: (N, out_channels, H, W) passed through a sigmoid, so values lie
    in (0, 1) -- suitable for binary segmentation masks.

    Attribute names and per-layer name prefixes are kept exactly as in the
    original so existing checkpoint state_dict keys still match.
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()
        f = init_features

        # Contracting path: channel width doubles at each level.
        self.encoder1 = UNet._block(in_channels, f, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(f, f * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(f * 2, f * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(f * 4, f * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(f * 8, f * 16, name="bottleneck")

        # Expanding path: transpose-conv upsample, concat skip, double-conv.
        # Each decoder's input width is doubled by the skip concatenation.
        self.upconv4 = nn.ConvTranspose2d(f * 16, f * 8, kernel_size=2, stride=2)
        self.decoder4 = UNet._block(f * 16, f * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(f * 8, f * 4, kernel_size=2, stride=2)
        self.decoder3 = UNet._block(f * 8, f * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(f * 4, f * 2, kernel_size=2, stride=2)
        self.decoder2 = UNet._block(f * 4, f * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(f * 2, f, kernel_size=2, stride=2)
        self.decoder1 = UNet._block(f * 2, f, name="dec1")

        # 1x1 conv maps back to the requested number of output channels.
        self.conv = nn.Conv2d(in_channels=f, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder, keeping each level's features for the skip connections.
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool1(skip1))
        skip3 = self.encoder3(self.pool2(skip2))
        skip4 = self.encoder4(self.pool3(skip3))

        out = self.bottleneck(self.pool4(skip4))

        # Decoder: upsample, fuse with the matching skip, refine.
        out = self.decoder4(torch.cat((self.upconv4(out), skip4), dim=1))
        out = self.decoder3(torch.cat((self.upconv3(out), skip3), dim=1))
        out = self.decoder2(torch.cat((self.upconv2(out), skip2), dim=1))
        out = self.decoder1(torch.cat((self.upconv1(out), skip1), dim=1))
        return torch.sigmoid(self.conv(out))

    @staticmethod
    def _block(in_channels, features, name):
        """conv3x3 -> BN -> ReLU, applied twice. Layer names carry the
        block's ``name`` prefix so state_dict keys are stable."""
        layers = [
            (name + "conv1", nn.Conv2d(in_channels, features, kernel_size=3, padding=1, bias=False)),
            (name + "norm1", nn.BatchNorm2d(features)),
            (name + "relu1", nn.ReLU(inplace=True)),
            (name + "conv2", nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False)),
            (name + "norm2", nn.BatchNorm2d(features)),
            (name + "relu2", nn.ReLU(inplace=True)),
        ]
        return nn.Sequential(OrderedDict(layers))
上一篇:Pytorch实现Faster-RCNN
下一篇:RNN

发表评论

最新留言

很好
[***.229.124.182]2025年04月25日 17时31分35秒

关于作者

    喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!

推荐文章