ReID基础 | 表征学习代码实践
发布日期:2021-05-07 00:09:56 浏览次数:27 分类:精选文章

本文共 10232 字,大约阅读时间需要 34 分钟。

文章目录

1. market1501数据集简介

(1)描述:

该数据在清华大学的开放式环境中由六个摄像头采集得到。 该数据集还包括来自DPM的2793个虚假警报,它们是干扰因素,可以模仿真实情况。

该数据集包括了1501个行人,751个行人用于训练,有750个人用于测试。共有3368个query图像,测试集中有19732张图像,训练集中有12936张图像。

(2)每张图像的命名及其包含的信息:

在这里插入图片描述

其中,第一项为0000表示噪声。

2. Market1501数据集预处理

# 解决python2和python3不同版本,print语法不一样的情况# 加入绝对引入这个新特性from __future__ import print_function, absolute_importimport os                       # 路径import os.path as ospimport numpy as np              # 数据处理import glob						# 查找符合特定规则的文件路径名import re                       # 正则表达式from IPython import embed       # 调试用dataset_dir = 'Market1501'  # 数据集路径class Market1501(object):    """    Market1501    Reference:    Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.    URL: http://www.liangzheng.org/Project/project_reid.html    Dataset statistics:    # identities: 1501 (+1 for background)    # images: 12936 (train) + 3368 (query) + 15913 (gallery)    """    # 初始化    def __init__(self, root='data', **kwargs):  # 不确定的参数    	 # 1.设置绝对路径        self.dataset_dir = osp.join(root, dataset_dir)               self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')        self.query_dir = osp.join(self.dataset_dir, 'query')        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')        		# 2.判断文件是否存在        self._check_before_run()                # 3.提取数据集信息        train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)             # 训练集pid重排序        query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False)        gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False)        num_total_pids = num_train_pids + num_query_pids                        # 所有的ID数目,等于训练和测试的ID数目        num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs     # 所有的图片数量        # 4.将上述信息打印到屏幕        print("=> Market1501 loaded")        print("Dataset statistics:")        print("  ------------------------------")        print("  subset   | # ids | # images")        print("  ------------------------------")        print("  train    | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))        print("  query    | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))        print("  gallery  | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))        print("  ------------------------------")        print("  total    | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))        print("  ------------------------------")        # 5.将变量放到market类中        self.train = train        self.query = query        self.gallery = gallery        self.num_train_pids = num_train_pids        self.num_query_pids = num_query_pids        self.num_gallery_pids = num_gallery_pids            # 判断文件夹的路径有无问题函数    def _check_before_run(self):        if not os.path.exists(self.dataset_dir):            raise RuntimeError("'{}'路径不存在".format(self.dataset_dir))        if not os.path.exists(self.train_dir):            raise RuntimeError("'{}'路径不存在".format(self.train_dir))        if not os.path.exists(self.query_dir):            raise RuntimeError("'{}'路径不存在".format(self.query_dir))        if not os.path.exists(self.gallery_dir):            raise RuntimeError("'{}'路径不存在".format(self.gallery_dir))    # 读取文件标注信息(ID和CAMID)、图片数量函数    def _process_dir(self, dir_path, relabel=False):                # 传入路径        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))          # 把dir_path中所有.jpg的文件拿出来        # embed()                                                   # 打一个中断,交互式调试        pattern = re.compile(r'([-\d]+)_c(\d)')                     # 定义一个搜索模式        # 因为训练集中的751个ID,在标签0-1500之间零散分布,为减少神经元数量,我们对label进行重新排序        pid_container = set()                                       # 集合具有去重功能,把训练集中所有的ID存起来        for img_path in img_paths:            pid, _ = map(int, pattern.search(img_path).groups())    # 按着这个模式对img_path进行搜索,pid是person id 的缩写            if pid == -1:                                           # person为-1的图像是一些垃圾数据,起到产生难度的作用                continue            pid_container.add(pid)                                  # 假设ID已存在则不存        # 产生了一个映射,使用pid对应的label训练分类        pid2label = {   pid: label for label, pid in enumerate(pid_container)}     # label是从0开始的排序,pid是集合内的数值        ## 对ID进行relabel        dataset = []                       # 存储文件的一些属性        for img_path in img_paths:            pid, camid = map(int, pattern.search(img_path).groups())            if pid == -1:                continue            # 两个工程性判断            assert 0 <= pid <= 1501            assert 1 <= camid <= 6            camid -= 1                      # 把CAMID归一化到0-5            if relabel:                pid = pid2label[pid]        # pid->label            dataset.append((img_path, pid, camid))        num_pids = len(pid_container)        num_imgs = len(img_paths)        return dataset, num_pids, num_imgsif __name__ == '__main__':    data = Market1501(root='C:/Users/xiaobin/PycharmProjects/reid')

3. 重构dataset

from __future__ import print_function, absolute_importimport osfrom PIL import Image     # 读入图片import numpy as npimport os.path as ospimport torchfrom torch.utils.data import Dataset# 1.读入图片def read_image(img_path):    got_img = False                                 # 判断是否读到图像    # 判断地址是否存在    if not osp.exists(img_path):        raise IOError("{} dose not exist".format(img_path))    # 假设还没有读到图像,就一直读    while not got_img:        try:            img = Image.open(img_path).convert('RGB')   # 读入一张图像转为RGB格式            got_img = True        except IOError:            print('No reading image')    return img# 2. 重构datsetclass ImageDataset(Dataset):    def __init__(self, dataset, transform=None):     # transform数据增广        super(ImageDataset, self).__init__()        self.dataset = dataset        self.transform = transform    def __len__(self):        return len(self.dataset)    def __getitem__(self, item):        img_pth, pid, camid = self.dataset[item]        img = read_image(img_pth)        # 数据增广        if self.transform is not None:            img = self.transform(img_pth)        return img, pid, camidif __name__ == '__main__':    import data_manager    dataset = data_manager.Market1501(root='C:/Users/xiaobin/PycharmProjects/reid')    train_loader = ImageDataset(dataset.train)    for batch_id, (imgs, pid, camid) in enumerate(train_loader):        print(imgs, pid, camid)    # 发现imgs是图片,应该转变为tensor放入网络,所以需要进行transform        imgs.save('aaa.jpg')        break

4. 数据增广

"""通常情况下,该部分代码不需要自己写,直接使用torch的API(transforms)即可"""from __future__ import absolute_importfrom PIL import Imageimport randomclass Random2DTranslation(object):    """    With a probability, first increase image size to (1 + 1/8), and then perform random crop.    给定一个概率,首先把图像扩大到9/8,然后随即裁剪一个固定区域的大小    Args:        height (int): target height.        width (int): target width.        p (float): 随机裁剪的概率。probability of performing this transformation. Default: 0.5.    """    def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):        self.height = height        self.width = width        self.p = p        self.interpolation = interpolation          # 插值的方式(双线性插值)    # 裁剪一个固定区域大小的图像    def __call__(self, img):        # 1.不做数据增广,直接对原图像缩放        if random.random() < self.p:            return img.resize((self.width, self.height), self.interpolation)        # 2.做数据增广,首先扩大图像,然后随即裁剪一个固定区域        new_width, new_height = int(round(self.width*1.125)), int(round(self.height*1.125))        # round(x, n) 方法返回浮点数x的四舍五入值。当参数n不存在时,round()函数的输出为整数        resize_img = img.resize((new_width, new_height), self.interpolation)    # 将原图像放大        x_maxrange = new_width - self.width                                     # 裁剪掉的最大x值        y_maxrange = new_height - self.height                                   # 裁剪掉的最大y值        x1 = int(round(random.uniform(0, x_maxrange)))                          # 起点        y1 = int(round(random.uniform(0, y_maxrange)))        croped_img = resize_img.crop((x1, y1, x1+self.width, y1+self.height))   # 裁剪        return croped_imgif __name__ == '__main__':    import matplotlib.pyplot as plt    img = Image.open('/Market1501/bounding_box_test/0000_c1s1_000151_01.jpg')    # 1. 使用自己编写的函数    transform = Random2DTranslation(256, 128, 0.5)    img_t = transform(img)    plt.figure(12)    plt.subplot(121)    plt.imshow(img)    plt.subplot(122)    plt.imshow(img_t)    plt.show()    # 2.使用现成的函数    # import torchvision.transforms as transforms    # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],    #                                  std=[0.229, 0.224, 0.225])    # transform = transforms.Compose([    #     transforms.RandomResizedCrop((256, 128)),    #     transforms.RandomHorizontalFlip(),    #     # transforms.ToTensor(),    #     # normalize    # ])    # img_t = transform(img)    # print(img_t.size)       # size图片大小,shape数组大小    # plt.imshow(img_t)    # plt.show()

5. ResNet模型

from __future__ import absolute_importimport torchimport torchvisionimport torch.nn as nnfrom torch.nn import functional as Ffrom IPython import embed# 只使用ResNet网络提取特征,不改变其网络结构,不需要自己从头写class ResNet50(nn.Module):    def __init__(self, num_classes, loss={   'softmax, metric'}):        super(ResNet50, self).__init__()        resnet50 = torchvision.models.resnet50(pretrained=True)     # 调用在ImageNet上预训练的ResNet50        # nn.Sequential(*list(resnet50.children()))        # resnet50.children()是一个迭代器,使用*list包装成一个list指针,使用nn.Sequential包装为一个网络        # nn.Sequential(*list(resnet50.children())[:-2])        # avgpool和FC层需要根据实际情况修改,去掉最后两层        self.base = nn.Sequential(*list(resnet50.children()[:-2]))        # 最后一层FC层        self.classifier = nn.Linear(2048, num_classes)    def forward(self, x):        x = self.base(x)        x = F.avg_pool2d(x, x.size()[2:])   # 池化核大小为feature map的H和W,转化为(batch_size, channel, 1, 1)        f = x.view(x.size(0), -1)           # 展平为(batch_size, channel)        # 一般不需要对特征归一化        # 乘以1.转化为浮点数,添加一个小数防止分母变成0        # f = 1.*f / (torch.norm(f, 2, dim=-1, keepdim=True).expand_as(f) + 1e-12)        # 最后一层FC只在训练的时候使用        if not self.training:            return f        y = self.classifier(f)        return y        if __name__ == '__main__':    model = ResNet50(num_classes=751)
上一篇:论文阅读4 | SpCL
下一篇:ReID基础 | 基于GAN的方法

发表评论

最新留言

哈哈,博客排版真的漂亮呢~
[***.90.31.176]2025年04月12日 03时54分05秒