
本文共 10227 字,大约阅读时间需要 34 分钟。
������
������������������������,������������������Bootstrap Your Onw Latent������������������������������������
https://juejin.cn/post/6922347006144970760���������������������������pytorch������������������������������������������������������������������������������
github���https://github.com/lucidrains/byol-pytorch������������������������������������������������������������������������������GPU���������������
������������������
class BYOL(nn.Module): def __init__( self, net, image_size, hidden_layer = -2, projection_size = 256, projection_hidden_size = 4096, augment_fn = None, augment_fn2 = None, moving_average_decay = 0.99, use_momentum = True ): super().__init__() self.net = net # default SimCLR augmentation DEFAULT_AUG = torch.nn.Sequential( RandomApply( T.ColorJitter(0.8, 0.8, 0.8, 0.2), p = 0.3 ), T.RandomGrayscale(p=0.2), T.RandomHorizontalFlip(), RandomApply( T.GaussianBlur((3, 3), (1.0, 2.0)), p = 0.2 ), T.RandomResizedCrop((image_size, image_size)), T.Normalize( mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])), ) self.augment1 = default(augment_fn, DEFAULT_AUG) self.augment2 = default(augment_fn2, self.augment1) self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer) self.use_momentum = use_momentum self.target_encoder = None self.target_ema_updater = EMA(moving_average_decay) self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size) # get device of network and make wrapper same device device = get_module_device(net) self.to(device) # send a mock image tensor to instantiate singleton parameters self.forward(torch.randn(2, 3, image_size, image_size, device=device)) @singleton('target_encoder') def _get_target_encoder(self): target_encoder = copy.deepcopy(self.online_encoder) set_requires_grad(target_encoder, False) return target_encoder def reset_moving_average(self): del self.target_encoder self.target_encoder = None def update_moving_average(self): assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder' assert self.target_encoder is not None, 'target encoder has not been created yet' update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder) def forward(self, x, return_embedding = False): if return_embedding: return self.online_encoder(x) image_one, image_two = self.augment1(x), self.augment2(x) online_proj_one, _ = self.online_encoder(image_one) online_proj_two, _ = self.online_encoder(image_two) online_pred_one = self.online_predictor(online_proj_one) online_pred_two = self.online_predictor(online_proj_two) with torch.no_grad(): target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder target_proj_one, _ = target_encoder(image_one) target_proj_two, _ = target_encoder(image_two) target_proj_one.detach_() target_proj_two.detach_() loss_one = loss_fn(online_pred_one, target_proj_two.detach()) loss_two = loss_fn(online_pred_two, target_proj_one.detach()) loss = loss_one + loss_two return loss.mean()
- ������
forward()
������������������������������������������������������������������������������������loss - ������������������������������
return_embedding=True
���������������������������online network������encoder������������������������������������������������predictor���������������������������������encoder���������������������encoder+projector��� - ������������self.augment1���self.augment2������������������������������������������������������������������view���
- ���������������������online-encoder������������������������������������������������������������online network���������������������target network���������������������������������online-encoder���������������������������������������������������symmetric loss���������������������������������������������������������������online network���target network���
- ���target network������������������������������������������������������target network���������online network������������������
- ������
self.use_momentum=False
,������������������������������������target network���������������������������online network���������target network���������������������������github���������������600���stars������������������������������self.use_momentum=True,���������������online network������������target network��������������������������������������������� - ������������������
loss_fn
���������������������return loss.mean()
������������������������������������������BYOL���������������������������������������������������������4������
- online_encoder���������������
- predictor���������������
- ���������������������������������
- loss_fn���������������������������
augment
���������������������������������������������
# default SimCLR augmentation DEFAULT_AUG = torch.nn.Sequential( RandomApply( T.ColorJitter(0.8, 0.8, 0.8, 0.2), p = 0.3 ), T.RandomGrayscale(p=0.2), T.RandomHorizontalFlip(), RandomApply( T.GaussianBlur((3, 3), (1.0, 2.0)), p = 0.2 ), T.RandomResizedCrop((image_size, image_size)), T.Normalize( mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225])), ) self.augment1 = default(augment_fn, DEFAULT_AUG) self.augment2 = default(augment_fn2, self.augment1)
���������������
- ���������������������������pipeline������augment1���augment2������������������������������������augment1���augment2���������������DEFAULT_AUG���
from torchvision import transforms as T
���������������������������torchvision.transforms.ColorJitter������
������������������
���������API���������������������������������������������������������������������������������������������������������
encoder+projector
class NetWrapper(nn.Module): def __init__(self, net, projection_size, projection_hidden_size, layer = -2): super().__init__() self.net = net self.layer = layer self.projector = None self.projection_size = projection_size self.projection_hidden_size = projection_hidden_size self.hidden = None self.hook_registered = False def _find_layer(self): if type(self.layer) == str: modules = dict([*self.net.named_modules()]) return modules.get(self.layer, None) elif type(self.layer) == int: children = [*self.net.children()] return children[self.layer] return None def _hook(self, _, __, output): self.hidden = flatten(output) def _register_hook(self): layer = self._find_layer() assert layer is not None, f'hidden layer ({self.layer}) not found' handle = layer.register_forward_hook(self._hook) self.hook_registered = True @singleton('projector') def _get_projector(self, hidden): _, dim = hidden.shape projector = MLP(dim, self.projection_size, self.projection_hidden_size) return projector.to(hidden) def get_representation(self, x): if self.layer == -1: return self.net(x) if not self.hook_registered: self._register_hook() _ = self.net(x) hidden = self.hidden self.hidden = None assert hidden is not None, f'hidden layer {self.layer} never emitted an output' return hidden def forward(self, x, return_embedding = False): representation = self.get_representation(x) if return_embedding: return representation projector = self._get_projector(representation) projection = projector(representation) return projection, representation
���������������������encoder+projector���������������encoder���projector���
encoder
������������������NetWrapper������������������������������������������������������������������������������������������������
from torchvision import models, transformsresnet = models.resnet50(pretrained=True)
������encoder������������������������������������resnet50���������������������������������resnet������������������(batch_size,1000)������������tensor���
projector
������������MLP���������������
class MLP(nn.Module): def __init__(self, dim, projection_size, hidden_size = 4096): super().__init__() self.net = nn.Sequential( nn.Linear(dim, hidden_size), nn.BatchNorm1d(hidden_size), nn.ReLU(inplace=True), nn.Linear(hidden_size, projection_size) ) def forward(self, x): return self.net(x)
���������������+BN+���������������������������������������������������������������������������������������������������BN+relu���������������MLP���������������������(batch_size,projection_size)���������������tensor���
predictor
self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
������predictor������������������projector������������������������������������predictor���������������������������������������projection_size
���
������������������������������������������������������������������������������������������BYOL������������������������������predictor���������������������������������������������������online network���target network������������������������������������������������������������������������������������������������������������������loss=0������������������
loss_fn
def loss_fn(x, y): x = F.normalize(x, dim=-1, p=2) y = F.normalize(y, dim=-1, p=2) return 2 - 2 * (x * y).sum(dim=-1)
������������������������������
���������������������BYOL������������������������������������������������������
������
发表评论
最新留言
关于作者
