pytorch 训练数据以及测试 全部代码(9)---deeplab v3+ 对Cityscapes数据的处理
发布日期:2021-06-29 11:45:01 浏览次数:4 分类:技术文章

本文共 11390 字,大约阅读时间需要 37 分钟。

 下面是全部的代码:

import os
import torch
import numpy as np
import scipy.misc as m  # NOTE(review): not referenced anywhere in this snippet
from PIL import Image
from torch.utils import data
from dataloaders.utils import recursive_glob, decode_segmap
from mypath import Path


class CityscapesSegmentation(data.Dataset):
    """Cityscapes dataset yielding ``{'image': ..., 'label': ...}`` samples.

    Images are looked up under ``<root>/leftImg8bit/<split>`` and the matching
    ``*gtFine_labelIds.png`` annotations under ``<root>/gtFine/<split>``.
    Raw label ids are remapped so that the 19 valid classes become 0-18 and
    every void id becomes ``ignore_index`` (255).
    """

    def __init__(self, root=Path.db_root_dir('cityscapes'), split="train", transform=None):
        """Index the image files of one split.

        Args:
            root: dataset root directory.
            split: sub-directory name used under both ``leftImg8bit`` and
                ``gtFine``.
            transform: optional callable applied to each sample dict.

        Raises:
            Exception: if no ``.png`` image is found for ``split``.
        """
        self.root = root
        self.split = split
        self.transform = transform
        self.files = {}
        self.n_classes = 19

        self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
        self.annotations_base = os.path.join(self.root, 'gtFine', self.split)
        self.files[split] = recursive_glob(rootdir=self.images_base, suffix='.png')

        # Raw label ids that are not trained on (16 entries).
        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        # Raw label ids that are trained on (19 entries).
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
        # 20 names: 'unlabelled' followed by the 19 valid classes.
        self.class_names = ['unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence',
                            'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain',
                            'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
                            'motorcycle', 'bicycle']

        self.ignore_index = 255
        # raw label id -> train id in 0..18
        self.class_map = dict(zip(self.valid_classes, range(self.n_classes)))

        if not self.files[split]:
            raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))

        print("Found %d %s images" % (len(self.files[split]), split))

    def __len__(self):
        """Return the number of images indexed for this split."""
        return len(self.files[self.split])

    def __getitem__(self, index):
        """Load one image/label pair and apply the optional transform.

        Returns:
            The sample dict ``{'image': ..., 'label': ...}`` (both PIL images),
            or whatever ``self.transform`` returns for it.
        """
        img_path = self.files[self.split][index].rstrip()
        # The annotation lives in the same-named parent directory (the city);
        # the last 15 characters of the file name (the length of
        # 'leftImg8bit.png') are replaced by the label-id suffix.
        lbl_path = os.path.join(self.annotations_base,
                                img_path.split(os.sep)[-2],
                                os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png')

        _img = Image.open(img_path).convert('RGB')
        _tmp = np.array(Image.open(lbl_path), dtype=np.uint8)
        _tmp = self.encode_segmap(_tmp)
        _target = Image.fromarray(_tmp)

        sample = {'image': _img, 'label': _target}

        if self.transform:
            # Augmentation and/or conversion to tensors.
            sample = self.transform(sample)
        return sample

    def encode_segmap(self, mask):
        """Remap raw label ids to 0-18, with void ids set to ``ignore_index``.

        ``mask`` is modified in place and also returned.
        """
        # Void ids first: the ids 0-6 must be cleared before valid classes
        # are written into that same value range.
        for _voidc in self.void_classes:
            mask[mask == _voidc] = self.ignore_index
        # Valid ids in ascending order: each target id is <= its source id,
        # so an already remapped pixel is never remapped a second time.
        for _validc in self.valid_classes:
            mask[mask == _validc] = self.class_map[_validc]
        return mask


if __name__ == '__main__':
    from dataloaders import custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt  # to show images

    composed_transforms_tr = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScale((0.5, 0.75)),
        tr.RandomCrop((512, 1024)),
        tr.RandomRotate(5),
        tr.ToTensor()])

    cityscapes_train = CityscapesSegmentation(split='train',
                                              transform=composed_transforms_tr)

    dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2)

    for ii, sample in enumerate(dataloader):
        # Convert the batch once instead of once per displayed item.
        img = sample['image'].numpy()  # n x c x h x w
        gt = sample['label'].numpy()   # n x 1 x h x w
        for jj in range(sample["image"].size()[0]):
            tmp = np.array(gt[jj]).astype(np.uint8)  # 1 x h x w
            tmp = np.squeeze(tmp, axis=0)            # h x w
            segmap = decode_segmap(tmp, dataset='cityscapes')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0]).astype(np.uint8)  # h x w x c
            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            break
    plt.show(block=True)

读取图片文件列表的方式可以参考下面这行代码:

# Build the file list for this split: recursive_glob is called on the split's
# image directory with suffix '.png' (presumably a recursive search for PNG
# files; TODO confirm against dataloaders.utils).
self.files[split] = recursive_glob(rootdir=self.images_base, suffix='.png')

所使用的数据变换为:

# Augmentation pipeline; every step takes and returns a {'image', 'label'} dict.
composed_transforms_tr = transforms.Compose(
    [
        tr.RandomHorizontalFlip(),    # mirror image and label together
        tr.RandomScale((0.5, 0.75)),  # (lower, upper) limits of the scale factor
        tr.RandomCrop((512, 1024)),   # crop size
        tr.RandomRotate(5),           # rotation limit in degrees
        tr.ToTensor(),                # final conversion step
    ]
)

上面关于图像变换或者说增强的实现代码如下:

上面的前四个变换都保持原图和标签为PIL图像对象(经过convert、fromarray以及各个变换之后,类型是PIL.Image.Image),像素的数据类型(uint8)不发生改变,结构也没有变化(原图为h x w x 3,标签图为h x w);标签图在缩放和旋转时始终使用最近邻插值(NEAREST),因此不会插值出新的类别值。

# NOTE: this excerpt relies on names presumably imported at the top of the
# original module (not shown here): random, numbers, numpy as np, torch,
# and PIL's Image / ImageOps.


class RandomHorizontalFlip(object):
    """Mirror image and label left-to-right with probability 0.5."""

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        return {'image': img,
                'label': mask}


class RandomScale(object):
    """Resize image and label by one factor drawn uniformly from ``limit``."""

    def __init__(self, limit):
        # limit: (lower, upper) bounds of the scale factor.
        self.limit = limit

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        assert img.size == mask.size

        scale = random.uniform(self.limit[0], self.limit[1])
        w = int(scale * img.size[0])
        h = int(scale * img.size[1])

        # Bilinear for the image; nearest for the label so that no new
        # class values are interpolated.
        img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST)

        return {'image': img, 'label': mask}


class RandomCrop(object):
    """Crop image and label to ``size`` at a random, shared position.

    Inputs smaller than the target in either dimension are resized to the
    target size instead of being cropped.
    """

    def __init__(self, size, padding=0):
        # size: a single number (square crop) or an (h, w) pair.
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size  # h, w
        self.padding = padding

    def __call__(self, sample):
        img, mask = sample['image'], sample['label']

        if self.padding > 0:
            img = ImageOps.expand(img, border=self.padding, fill=0)
            mask = ImageOps.expand(mask, border=self.padding, fill=0)

        assert img.size == mask.size
        w, h = img.size      # PIL reports (width, height)
        th, tw = self.size   # target size is stored as (height, width)
        if w == tw and h == th:
            return {'image': img,
                    'label': mask}
        if w < tw or h < th:
            img = img.resize((tw, th), Image.BILINEAR)
            mask = mask.resize((tw, th), Image.NEAREST)
            return {'image': img,
                    'label': mask}

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        img = img.crop((x1, y1, x1 + tw, y1 + th))
        mask = mask.crop((x1, y1, x1 + tw, y1 + th))

        return {'image': img,
                'label': mask}


class RandomRotate(object):
    """Rotate image and label by one random angle within +/- ``degree``."""

    def __init__(self, degree):
        self.degree = degree

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        # Uniform in [-degree, degree).
        rotate_degree = random.random() * 2 * self.degree - self.degree
        img = img.rotate(rotate_degree, Image.BILINEAR)
        mask = mask.rotate(rotate_degree, Image.NEAREST)

        return {'image': img,
                'label': mask}


class ToTensor(object):
    """Convert the image and label of a sample to float tensors."""

    def __call__(self, sample):
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
        # The label gains a trailing axis and is moved to 1 x H x W.
        mask = np.expand_dims(np.array(sample['label']).astype(np.float32), -1).transpose((2, 0, 1))
        # NOTE(review): the value 255 is folded into 0 here, so pixels marked
        # 255 upstream become indistinguishable from class 0.
        mask[mask == 255] = 0

        img = torch.from_numpy(img).float()
        mask = torch.from_numpy(mask).float()

        return {'image': img,
                'label': mask}

直到第五个也就是最后一个变换(ToTensor类),原图才首先从PIL图像变为numpy数组,同时数据类型从uint8变为float32,然后维度从(h x w x c)调整为(c x h x w),最后从numpy数组变为torch的tensor,并强制将数据类型转换为torch.FloatTensor。这样,就将原图转变为一个可以输入后面深度学习网络的tensor了。

与此相对,标签图也是先从PIL图像变为numpy数组,同时数据类型从uint8变为float32,然后维度从(h x w)增加一维得到(h x w x 1),接着调整维度到(1 x h x w);然后对mask里面的数值进行处理:值为255的像素全部被重置为0,所以mask里面的值现在只有0-18这些数字了(注意:这样一来,原本被忽略的像素会与第0类混在一起);最后从numpy数组变为torch的tensor,并强制将数据类型转换为torch.FloatTensor。这样,就将标签图转变为一个可以输入后面深度学习网络的tensor了。

对上面的两个tensor的重新变成图像的代码如下:

# Show the image and its decoded label map for every item of the first two batches.
for ii, sample in enumerate(dataloader):
    # Convert the batch once instead of once per displayed item.
    img = sample['image'].numpy()  # n x 3 x h x w
    gt = sample['label'].numpy()   # n x 1 x h x w
    for jj in range(sample["image"].size()[0]):
        tmp = np.array(gt[jj]).astype(np.uint8)  # 1 x h x w
        tmp = np.squeeze(tmp, axis=0)            # drop the channel axis -> h x w
        segmap = decode_segmap(tmp, dataset='cityscapes')
        img_tmp = np.transpose(img[jj], axes=[1, 2, 0]).astype(np.uint8)  # h x w x 3
        plt.figure()
        plt.title('display')
        plt.subplot(211)
        plt.imshow(img_tmp)
        plt.subplot(212)
        plt.imshow(segmap)

    if ii == 1:
        break

plt.show(block=True)

里面的标签图(h x w)解码代码如下:

同一类别的像素被赋予相同的RGB数值,然后把R、G、B三个通道合并成一张彩色图:

# Usage (tmp.shape = h x w):
#     segmap = decode_segmap(tmp, dataset='cityscapes')


def decode_segmap(label_mask, dataset, plot=False):
    """Decode segmentation class labels into a color image

    Args:
        label_mask (np.ndarray): an (M,N) array of integer values denoting
          the class label at each spatial location.
        dataset (str): 'pascal' or 'cityscapes'; selects the color table and
          the number of classes.
        plot (bool, optional): whether to show the resulting color image
          in a figure.

    Returns:
        (np.ndarray, optional): the resulting decoded color image, an
          (M,N,3) array scaled by 1/255. Nothing is returned when ``plot``
          is true.

    Raises:
        NotImplementedError: if ``dataset`` is not one of the names above.
    """
    if dataset == 'pascal':
        n_classes = 21
        label_colours = get_pascal_labels()
    elif dataset == 'cityscapes':
        n_classes = 19
        label_colours = get_cityscapes_labels()
    else:
        raise NotImplementedError

    # One channel per copy; positions whose label is outside
    # 0..n_classes-1 keep their raw label value in all three channels.
    r = label_mask.copy()  # h x w
    g = label_mask.copy()  # h x w
    b = label_mask.copy()  # h x w
    for ll in range(0, n_classes):
        r[label_mask == ll] = label_colours[ll, 0]
        g[label_mask == ll] = label_colours[ll, 1]
        b[label_mask == ll] = label_colours[ll, 2]

    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))  # h x w x 3
    rgb[:, :, 0] = r / 255.0
    rgb[:, :, 1] = g / 255.0
    rgb[:, :, 2] = b / 255.0
    if plot:
        plt.imshow(rgb)
        plt.show()
    else:
        return rgb

下面是label_colours中各个类别对应颜色的代码,详情可以参考Cityscapes的标签颜色对照表:

def get_cityscapes_labels():
    """Return the colors of the 19 Cityscapes classes.

    Returns:
        np.ndarray with dimensions (19, 3); row i is the RGB color of
        class id i.
    """
    return np.array([
        # [  0,   0,   0],
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [70, 130, 180],  # sky: corrected from [0, 130, 180] to the official Cityscapes color
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32]])


def get_pascal_labels():
    """Load the mapping that associates pascal classes with label colors

    Returns:
        np.ndarray with dimensions (21, 3)
    """
    return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                       [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                       [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                       [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                       [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                       [0, 64, 128]])

 

转载地址:https://blog.csdn.net/zz2230633069/article/details/84668984 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:PASCAL VOC 2012数据集介绍
下一篇:PIL包里面的Image模块里面的函数讲解,不能直接对numpy存储成图像,要进行转化再存取

发表评论

最新留言

初次前来,多多关照!
[***.217.46.12]2024年04月12日 19时43分36秒

关于作者

    喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!

推荐文章

在电网工作,有多高大上? 2021-07-02
「2020年大学生电子设计竞赛分享」电源题,省一等奖! 2021-07-02
又一国产开源微内核操作系统上线!源代码已开放下载 2021-07-02
10年老兵!从大学毕业生到嵌入式系统工程师的修炼之道…… 2021-07-02
如何才能学好单片机? 2021-07-02
一根网线有这么多“花样”,你知道吗? 2021-07-02
雷军1994年写的诗一样的代码,我把它运行起来了! 2021-07-02
2020年大学生电子设计竞赛,B题,单相在线式不间断电源,详细技术方案! 2021-07-02
大佬终于把鸿蒙OS讲明白了,收藏了! 2021-07-02
C语言指针,这可能是史上最干最全的讲解啦(附代码)!!! 2021-07-02
国内大陆有哪些芯片公司处于世界前10?一起看看! 2021-07-02
单精度、双精度、多精度和混合精度计算的区别是什么? 2021-07-02
中国35位“大国工匠”榜单出炉!西工大、西电合计占半壁江山!清华仅1人!... 2021-07-02
知乎热议:嵌入式开发中C++好用吗? 2021-07-02
2020,Python 已死? 2021-07-02
漫画:程序员相亲?哈哈哈哈哈哈 2021-07-02
30种EMC标准电路分享,再不收藏就晚了! 2021-07-02
这100道Linux常见面试题,看看你会多少? 2019-04-29
十年硬件老司机,结合实际案例,带你探索单片机低功耗设计! 2019-04-29
“2020年嵌入式软件秋招经验和对嵌入式软件未来的一点思考” 2019-04-29