
ReID基础 | 表征学习代码实践
发布日期:2021-05-07 00:09:56
浏览次数:27
分类:精选文章
本文共 10232 字,大约阅读时间需要 34 分钟。
文章目录
1. market1501数据集简介
(1)描述:
该数据在清华大学的开放式环境中由六个摄像头采集得到。 该数据集还包括来自DPM的2793个虚假警报,它们是干扰因素,可以模仿真实情况。
该数据集包括了1501个行人,751个行人用于训练,有750个人用于测试。共有3368个query图像,测试集中有19732张图像,训练集中有12936张图像。
(2)每张图像的命名及其包含的信息:
2. Market1501数据集预处理
# 解决python2和python3不同版本,print语法不一样的情况# 加入绝对引入这个新特性from __future__ import print_function, absolute_importimport os # 路径import os.path as ospimport numpy as np # 数据处理import glob # 查找符合特定规则的文件路径名import re # 正则表达式from IPython import embed # 调试用dataset_dir = 'Market1501' # 数据集路径class Market1501(object): """ Market1501 Reference: Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. URL: http://www.liangzheng.org/Project/project_reid.html Dataset statistics: # identities: 1501 (+1 for background) # images: 12936 (train) + 3368 (query) + 15913 (gallery) """ # 初始化 def __init__(self, root='data', **kwargs): # 不确定的参数 # 1.设置绝对路径 self.dataset_dir = osp.join(root, dataset_dir) self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') self.query_dir = osp.join(self.dataset_dir, 'query') self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') # 2.判断文件是否存在 self._check_before_run() # 3.提取数据集信息 train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True) # 训练集pid重排序 query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False) gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False) num_total_pids = num_train_pids + num_query_pids # 所有的ID数目,等于训练和测试的ID数目 num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs # 所有的图片数量 # 4.将上述信息打印到屏幕 print("=> Market1501 loaded") print("Dataset statistics:") print(" ------------------------------") print(" subset | # ids | # images") print(" ------------------------------") print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs)) print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs)) print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs)) print(" ------------------------------") print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs)) print(" ------------------------------") # 5.将变量放到market类中 self.train = train self.query = query self.gallery = gallery self.num_train_pids = num_train_pids self.num_query_pids = num_query_pids self.num_gallery_pids = num_gallery_pids # 判断文件夹的路径有无问题函数 def _check_before_run(self): if not os.path.exists(self.dataset_dir): raise RuntimeError("'{}'路径不存在".format(self.dataset_dir)) if not os.path.exists(self.train_dir): raise RuntimeError("'{}'路径不存在".format(self.train_dir)) if not os.path.exists(self.query_dir): raise RuntimeError("'{}'路径不存在".format(self.query_dir)) if not os.path.exists(self.gallery_dir): raise RuntimeError("'{}'路径不存在".format(self.gallery_dir)) # 读取文件标注信息(ID和CAMID)、图片数量函数 def _process_dir(self, dir_path, relabel=False): # 传入路径 img_paths = glob.glob(osp.join(dir_path, '*.jpg')) # 把dir_path中所有.jpg的文件拿出来 # embed() # 打一个中断,交互式调试 pattern = re.compile(r'([-\d]+)_c(\d)') # 定义一个搜索模式 # 因为训练集中的751个ID,在标签0-1500之间零散分布,为减少神经元数量,我们对label进行重新排序 pid_container = set() # 集合具有去重功能,把训练集中所有的ID存起来 for img_path in img_paths: pid, _ = map(int, pattern.search(img_path).groups()) # 按着这个模式对img_path进行搜索,pid是person id 的缩写 if pid == -1: # person为-1的图像是一些垃圾数据,起到产生难度的作用 continue pid_container.add(pid) # 假设ID已存在则不存 # 产生了一个映射,使用pid对应的label训练分类 pid2label = { pid: label for label, pid in enumerate(pid_container)} # label是从0开始的排序,pid是集合内的数值 ## 对ID进行relabel dataset = [] # 存储文件的一些属性 for img_path in img_paths: pid, camid = map(int, pattern.search(img_path).groups()) if pid == -1: continue # 两个工程性判断 assert 0 <= pid <= 1501 assert 1 <= camid <= 6 camid -= 1 # 把CAMID归一化到0-5 if relabel: pid = pid2label[pid] # pid->label dataset.append((img_path, pid, camid)) num_pids = len(pid_container) num_imgs = len(img_paths) return dataset, num_pids, num_imgsif __name__ == '__main__': data = Market1501(root='C:/Users/xiaobin/PycharmProjects/reid')
3. 重构dataset
from __future__ import print_function, absolute_importimport osfrom PIL import Image # 读入图片import numpy as npimport os.path as ospimport torchfrom torch.utils.data import Dataset# 1.读入图片def read_image(img_path): got_img = False # 判断是否读到图像 # 判断地址是否存在 if not osp.exists(img_path): raise IOError("{} dose not exist".format(img_path)) # 假设还没有读到图像,就一直读 while not got_img: try: img = Image.open(img_path).convert('RGB') # 读入一张图像转为RGB格式 got_img = True except IOError: print('No reading image') return img# 2. 重构datsetclass ImageDataset(Dataset): def __init__(self, dataset, transform=None): # transform数据增广 super(ImageDataset, self).__init__() self.dataset = dataset self.transform = transform def __len__(self): return len(self.dataset) def __getitem__(self, item): img_pth, pid, camid = self.dataset[item] img = read_image(img_pth) # 数据增广 if self.transform is not None: img = self.transform(img_pth) return img, pid, camidif __name__ == '__main__': import data_manager dataset = data_manager.Market1501(root='C:/Users/xiaobin/PycharmProjects/reid') train_loader = ImageDataset(dataset.train) for batch_id, (imgs, pid, camid) in enumerate(train_loader): print(imgs, pid, camid) # 发现imgs是图片,应该转变为tensor放入网络,所以需要进行transform imgs.save('aaa.jpg') break
4. 数据增广
"""通常情况下,该部分代码不需要自己写,直接使用torch的API(transforms)即可"""from __future__ import absolute_importfrom PIL import Imageimport randomclass Random2DTranslation(object): """ With a probability, first increase image size to (1 + 1/8), and then perform random crop. 给定一个概率,首先把图像扩大到9/8,然后随即裁剪一个固定区域的大小 Args: height (int): target height. width (int): target width. p (float): 随机裁剪的概率。probability of performing this transformation. Default: 0.5. """ def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR): self.height = height self.width = width self.p = p self.interpolation = interpolation # 插值的方式(双线性插值) # 裁剪一个固定区域大小的图像 def __call__(self, img): # 1.不做数据增广,直接对原图像缩放 if random.random() < self.p: return img.resize((self.width, self.height), self.interpolation) # 2.做数据增广,首先扩大图像,然后随即裁剪一个固定区域 new_width, new_height = int(round(self.width*1.125)), int(round(self.height*1.125)) # round(x, n) 方法返回浮点数x的四舍五入值。当参数n不存在时,round()函数的输出为整数 resize_img = img.resize((new_width, new_height), self.interpolation) # 将原图像放大 x_maxrange = new_width - self.width # 裁剪掉的最大x值 y_maxrange = new_height - self.height # 裁剪掉的最大y值 x1 = int(round(random.uniform(0, x_maxrange))) # 起点 y1 = int(round(random.uniform(0, y_maxrange))) croped_img = resize_img.crop((x1, y1, x1+self.width, y1+self.height)) # 裁剪 return croped_imgif __name__ == '__main__': import matplotlib.pyplot as plt img = Image.open('/Market1501/bounding_box_test/0000_c1s1_000151_01.jpg') # 1. 使用自己编写的函数 transform = Random2DTranslation(256, 128, 0.5) img_t = transform(img) plt.figure(12) plt.subplot(121) plt.imshow(img) plt.subplot(122) plt.imshow(img_t) plt.show() # 2.使用现成的函数 # import torchvision.transforms as transforms # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], # std=[0.229, 0.224, 0.225]) # transform = transforms.Compose([ # transforms.RandomResizedCrop((256, 128)), # transforms.RandomHorizontalFlip(), # # transforms.ToTensor(), # # normalize # ]) # img_t = transform(img) # print(img_t.size) # size图片大小,shape数组大小 # plt.imshow(img_t) # plt.show()
5. ResNet模型
from __future__ import absolute_importimport torchimport torchvisionimport torch.nn as nnfrom torch.nn import functional as Ffrom IPython import embed# 只使用ResNet网络提取特征,不改变其网络结构,不需要自己从头写class ResNet50(nn.Module): def __init__(self, num_classes, loss={ 'softmax, metric'}): super(ResNet50, self).__init__() resnet50 = torchvision.models.resnet50(pretrained=True) # 调用在ImageNet上预训练的ResNet50 # nn.Sequential(*list(resnet50.children())) # resnet50.children()是一个迭代器,使用*list包装成一个list指针,使用nn.Sequential包装为一个网络 # nn.Sequential(*list(resnet50.children())[:-2]) # avgpool和FC层需要根据实际情况修改,去掉最后两层 self.base = nn.Sequential(*list(resnet50.children()[:-2])) # 最后一层FC层 self.classifier = nn.Linear(2048, num_classes) def forward(self, x): x = self.base(x) x = F.avg_pool2d(x, x.size()[2:]) # 池化核大小为feature map的H和W,转化为(batch_size, channel, 1, 1) f = x.view(x.size(0), -1) # 展平为(batch_size, channel) # 一般不需要对特征归一化 # 乘以1.转化为浮点数,添加一个小数防止分母变成0 # f = 1.*f / (torch.norm(f, 2, dim=-1, keepdim=True).expand_as(f) + 1e-12) # 最后一层FC只在训练的时候使用 if not self.training: return f y = self.classifier(f) return y if __name__ == '__main__': model = ResNet50(num_classes=751)
发表评论
最新留言
哈哈,博客排版真的漂亮呢~
[***.90.31.176]2025年04月12日 03时54分05秒
关于作者

喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
【故障公告】10:30-10:45 左右 docker swarm 集群节点问题引发故障
2019-03-06
工作半年的思考
2019-03-06
不可思议的纯 CSS 滚动进度条效果
2019-03-06
【CSS进阶】伪元素的妙用--单标签之美
2019-03-06
开始CN的生活
2019-03-06
惊闻NBC在奥运后放弃使用Silverlight
2019-03-06
IE下尚未实现错误的原因
2019-03-06
创建自己的Docker基础镜像
2019-03-06
HTTP 协议图解
2019-03-06
Python 简明教程 --- 20,Python 类中的属性与方法
2019-03-06
Python 简明教程 --- 21,Python 继承与多态
2019-03-06
KNN 算法-理论篇-如何给电影进行分类
2019-03-06
Spring Cloud第九篇 | 分布式服务跟踪Sleuth
2019-03-06
CODING 敏捷实战系列课第三讲:可视化业务分析
2019-03-06
使用 CODING DevOps 全自动部署 Hexo 到 K8S 集群
2019-03-06
工作动态尽在掌握 - 使用 CODING 度量团队效能
2019-03-06
CODING DevOps 代码质量实战系列最后一课,周四发车
2019-03-06
CODING DevOps 深度解析系列第二课报名倒计时!
2019-03-06
CODING DevOps 线下沙龙回顾二:SDK 测试最佳实践
2019-03-06
翻译:《实用的Python编程》03_01_Script
2019-03-06