《动手学深度学习》(PyTorch版)代码注释 - 50 【Semantic_segmentation】
发布日期:2021-05-19 18:03:18 浏览次数:19 分类:精选文章

本文共 2204 字,大约阅读时间需要 7 分钟。

语义分割与数据集配置说明

开源代码说明

本博客代码来源于开源项目,专业学习过程中对代码进行了丰富的注释,便于理解各模块功能和实现原理。

环境配置

  • 代码版本:Python 3.8
  • 运行平台:Windows 10
  • 开发工具:PyCharm

功能说明

此模块主要负责语义分割任务和数据集的构建 мор 公司相对复杂,因此注释较多且功能实现较为繁琐。

代码示例

from matplotlib import.pyplot as plt
import time
import torch
import torch.nn.functional as F
import torchvision
import numpy as np
from PIL import Image
from tqdm import tqdm
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l

随机裁剪功能代码

def voc_rand_crop(feature, label, height, width):
i, j, h, w = torchvision.transforms.RandomCrop.get_params(
feature, output_size=(height, width))
feature = torchvision.transforms.functional.crop(feature, i, j, h, w)
label = torchvision.transforms.functional.crop(label, i, j, h, w)
return feature, label

数据集定义类

class VOCSegDataset(torch.utils.data.Dataset):
def __init__(self, is_train, crop_size, voc_dir, colormap2label, max_num=None):
self.rgb_mean = np.array([0.485, 0.456, 0.406])
self.rgb_std = np.array([0.229, 0.224, 0.225])
self.tsf = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(mean=self.rgb_mean,
std=self.rgb_std)
])
features, labels = read_voc_images(root=voc_dir, is_train=is_train, max_num=max_num)
self.features = self.filter(features)
self.labels = self.filter(labels)
self.colormap2label = colormap2label

数据集筛选方法

def filter(self, imgs):
return [img for img in imgs if
img.size()[1] >= self.crop_size[0] and
img.size()[0] >= self.crop_size[1]]

这个代码注释详细解释了各个部分的功能,帮助开发者理解代码的实现细节。如果需要获取完整版本,请访问原始代码仓库链接。

模型训练准备

crop_size = (320, 480)
max_num = 100
voc_train = VOCSegDataset(True, crop_size, voc_dir, colormap2label, max_num)
voc_test = VOCSegDataset(False, crop_size, voc_dir, colormap2label, max_num)
batch_size = 64
num_workers = 4 # 多线程加速
train_iter = DataLoader(voc_train, batch_size, shuffle=True,
drop_last=True, num_workers=num_workers)
test_iter = DataLoader(voc_test, batch_size, drop_last=True,
num_workers=num_workers)

代码清单中涵盖了数据集的创建、随机裁剪、特征提取和标签转换等核心功能,适合在语义分割任务中使用。

上一篇:《动手学深度学习》(PyTorch版)代码注释 - 51 【Style_transfer】
下一篇:《动手学深度学习》(PyTorch版)代码注释 - 48 【Multi-scale_target_detection】

发表评论

最新留言

哈哈,博客排版真的漂亮呢~
[***.90.31.176]2025年04月24日 11时40分03秒