自监督图像论文复现 | BYOL(pytorch)| 2020
发布日期:2021-05-09 19:23:28 浏览次数:18 分类:博客文章

本文共 10227 字,大约阅读时间需要 34 分钟。

������

������������������������,������������������Bootstrap Your Onw Latent������������������������������������

https://juejin.cn/post/6922347006144970760

���������������������������pytorch������������������������������������������������������������������������������

github���https://github.com/lucidrains/byol-pytorch

������������������������������������������������������������������������������GPU���������������

������������������

class BYOL(nn.Module):    def __init__(        self,        net,        image_size,        hidden_layer = -2,        projection_size = 256,        projection_hidden_size = 4096,        augment_fn = None,        augment_fn2 = None,        moving_average_decay = 0.99,        use_momentum = True    ):        super().__init__()        self.net = net        # default SimCLR augmentation        DEFAULT_AUG = torch.nn.Sequential(            RandomApply(                T.ColorJitter(0.8, 0.8, 0.8, 0.2),                p = 0.3            ),            T.RandomGrayscale(p=0.2),            T.RandomHorizontalFlip(),            RandomApply(                T.GaussianBlur((3, 3), (1.0, 2.0)),                p = 0.2            ),            T.RandomResizedCrop((image_size, image_size)),            T.Normalize(                mean=torch.tensor([0.485, 0.456, 0.406]),                std=torch.tensor([0.229, 0.224, 0.225])),        )        self.augment1 = default(augment_fn, DEFAULT_AUG)        self.augment2 = default(augment_fn2, self.augment1)        self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer)        self.use_momentum = use_momentum        self.target_encoder = None        self.target_ema_updater = EMA(moving_average_decay)        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)        # get device of network and make wrapper same device        device = get_module_device(net)        self.to(device)        # send a mock image tensor to instantiate singleton parameters        self.forward(torch.randn(2, 3, image_size, image_size, device=device))    @singleton('target_encoder')    def _get_target_encoder(self):        target_encoder = copy.deepcopy(self.online_encoder)        set_requires_grad(target_encoder, False)        return target_encoder    def reset_moving_average(self):        del self.target_encoder        self.target_encoder = None    def update_moving_average(self):        assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'        assert self.target_encoder is not None, 'target encoder has not been created yet'        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)    def forward(self, x, return_embedding = False):        if return_embedding:            return self.online_encoder(x)        image_one, image_two = self.augment1(x), self.augment2(x)        online_proj_one, _ = self.online_encoder(image_one)        online_proj_two, _ = self.online_encoder(image_two)        online_pred_one = self.online_predictor(online_proj_one)        online_pred_two = self.online_predictor(online_proj_two)        with torch.no_grad():            target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder            target_proj_one, _ = target_encoder(image_one)            target_proj_two, _ = target_encoder(image_two)            target_proj_one.detach_()            target_proj_two.detach_()        loss_one = loss_fn(online_pred_one, target_proj_two.detach())        loss_two = loss_fn(online_pred_two, target_proj_one.detach())        loss = loss_one + loss_two        return loss.mean()
  • ������forward()������������������������������������������������������������������������������������loss
  • ������������������������������return_embedding=True���������������������������online network������encoder������������������������������������������������predictor���������������������������������encoder���������������������encoder+projector���
  • ������������self.augment1���self.augment2������������������������������������������������������������������view���
  • ���������������������online-encoder������������������������������������������������������������online network���������������������target network���������������������������������online-encoder���������������������������������������������������symmetric loss���������������������������������������������������������������online network���target network���
  • ���target network������������������������������������������������������target network���������online network������������������
  • ������self.use_momentum=False,������������������������������������target network���������������������������online network���������target network���������������������������github���������������600���stars������������������������������self.use_momentum=True,���������������online network������������target network���������������������������������������������
  • ������������������loss_fn���������������������return loss.mean()

������������������������������������������BYOL���������������������������������������������������������4������

  • online_encoder���������������
  • predictor���������������
  • ���������������������������������
  • loss_fn���������������������������

augment

���������������������������������������������

# default SimCLR augmentation        DEFAULT_AUG = torch.nn.Sequential(            RandomApply(                T.ColorJitter(0.8, 0.8, 0.8, 0.2),                p = 0.3            ),            T.RandomGrayscale(p=0.2),            T.RandomHorizontalFlip(),            RandomApply(                T.GaussianBlur((3, 3), (1.0, 2.0)),                p = 0.2            ),            T.RandomResizedCrop((image_size, image_size)),            T.Normalize(                mean=torch.tensor([0.485, 0.456, 0.406]),                std=torch.tensor([0.229, 0.224, 0.225])),        )        self.augment1 = default(augment_fn, DEFAULT_AUG)        self.augment2 = default(augment_fn2, self.augment1)

���������������

  • ���������������������������pipeline������augment1���augment2������������������������������������augment1���augment2���������������DEFAULT_AUG���
  • from torchvision import transforms as T

���������������������������torchvision.transforms.ColorJitter������������������������

���������API���������������������������������������������������������������������������������������������������������

encoder+projector

class NetWrapper(nn.Module):    def __init__(self, net, projection_size, projection_hidden_size, layer = -2):        super().__init__()        self.net = net        self.layer = layer        self.projector = None        self.projection_size = projection_size        self.projection_hidden_size = projection_hidden_size        self.hidden = None        self.hook_registered = False    def _find_layer(self):        if type(self.layer) == str:            modules = dict([*self.net.named_modules()])            return modules.get(self.layer, None)        elif type(self.layer) == int:            children = [*self.net.children()]            return children[self.layer]        return None    def _hook(self, _, __, output):        self.hidden = flatten(output)    def _register_hook(self):        layer = self._find_layer()        assert layer is not None, f'hidden layer ({self.layer}) not found'        handle = layer.register_forward_hook(self._hook)        self.hook_registered = True    @singleton('projector')    def _get_projector(self, hidden):        _, dim = hidden.shape        projector = MLP(dim, self.projection_size, self.projection_hidden_size)        return projector.to(hidden)    def get_representation(self, x):        if self.layer == -1:            return self.net(x)        if not self.hook_registered:            self._register_hook()        _ = self.net(x)        hidden = self.hidden        self.hidden = None        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'        return hidden    def forward(self, x, return_embedding = False):        representation = self.get_representation(x)        if return_embedding:            return representation        projector = self._get_projector(representation)        projection = projector(representation)        return projection, representation

���������������������encoder+projector���������������encoder���projector���

encoder

������������������NetWrapper������������������������������������������������������������������������������������������������

from torchvision import models, transformsresnet = models.resnet50(pretrained=True)

������encoder������������������������������������resnet50���������������������������������resnet������������������(batch_size,1000)������������tensor���

projector

������������MLP���������������

class MLP(nn.Module):    def __init__(self, dim, projection_size, hidden_size = 4096):        super().__init__()        self.net = nn.Sequential(            nn.Linear(dim, hidden_size),            nn.BatchNorm1d(hidden_size),            nn.ReLU(inplace=True),            nn.Linear(hidden_size, projection_size)        )    def forward(self, x):        return self.net(x)

���������������+BN+���������������������������������������������������������������������������������������������������BN+relu���������������MLP���������������������(batch_size,projection_size)���������������tensor���

predictor

self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)

������predictor������������������projector������������������������������������predictor���������������������������������������projection_size���

������������������������������������������������������������������������������������������BYOL������������������������������predictor���������������������������������������������������online network���target network������������������������������������������������������������������������������������������������������������������loss=0������������������

loss_fn

def loss_fn(x, y):    x = F.normalize(x, dim=-1, p=2)    y = F.normalize(y, dim=-1, p=2)    return 2 - 2 * (x * y).sum(dim=-1)

������������������������������

���������������������BYOL������������������������������������������������������

������

上一篇:MyBatis入门十一:Mybatis数据插入、修改、删除三:更新数据,删除数据;
下一篇:学习进度笔记

发表评论

最新留言

留言是一种美德,欢迎回访!
[***.207.175.100]2025年04月07日 08时20分43秒