BYOL with structure!
This commit is contained in:
parent
9c5e272a22
commit
26ceca68c0
|
@ -49,6 +49,8 @@ def create_dataset(dataset_opt):
|
|||
from data.torch_dataset import TorchDataset as D
|
||||
elif mode == 'byol_dataset':
|
||||
from data.byol_attachment import ByolDatasetWrapper as D
|
||||
elif mode == 'byol_structured_dataset':
|
||||
from data.byol_attachment import StructuredCropDatasetWrapper as D
|
||||
elif mode == 'random_dataset':
|
||||
from data.random_dataset import RandomDataset as D
|
||||
else:
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import random
|
||||
from time import time
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
@ -10,6 +11,8 @@ import torch.nn.functional as F
|
|||
|
||||
# Wrapper for a DLAS Dataset class that applies random augmentations from the BYOL paper to BOTH the 'lq' and 'hq'
|
||||
# inputs. These are then outputted as 'aug1' and 'aug2'.
|
||||
from tqdm import tqdm
|
||||
|
||||
from data import create_dataset
|
||||
from models.archs.arch_util import PixelUnshuffle
|
||||
from utils.util import opt_get
|
||||
|
@ -66,6 +69,17 @@ def snap(ref, other):
|
|||
return other - ref
|
||||
|
||||
|
||||
# Pads a tensor with zeros so that it fits in a dxd square.
|
||||
def pad_to(im, d):
|
||||
if len(im.shape) == 3:
|
||||
pd = torch.zeros((im.shape[0],d,d))
|
||||
pd[:, :im.shape[1], :im.shape[2]] = im
|
||||
else:
|
||||
pd = torch.zeros((im.shape[0],im.shape[1],d,d), device=im.device)
|
||||
pd[:, :, :im.shape[2], :im.shape[3]] = im
|
||||
return pd
|
||||
|
||||
|
||||
# Variation of RandomResizedCrop, which picks a region of the image that the two augments must share. The augments
|
||||
# then propagate off random corners of the shared region, using the same scale.
|
||||
#
|
||||
|
@ -111,9 +125,17 @@ class RandomSharedRegionCrop(nn.Module):
|
|||
# Step 6
|
||||
m = self.multiple
|
||||
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
|
||||
jt = jt if base_t != 0 else abs(jt) # If the top of a patch is zero, a negative jitter will cause it to go negative.
|
||||
jt = jt if (base_t+base_h)*m != i1.shape[1] else 0 # Likewise, jitter shouldn't allow the patch to go over-bounds.
|
||||
jl = jl if base_l != 0 else abs(jl)
|
||||
jl = jl if (base_l+base_w)*m != i1.shape[1] else 0
|
||||
p1 = i1[:, base_t*m+jt:(base_t+base_h)*m+jt, base_l*m+jl:(base_l+base_w)*m+jl]
|
||||
p1_resized = no_batch_interpolate(p1, size=(d*m, d*m), mode="bilinear")
|
||||
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
|
||||
jt = jt if im2_t != 0 else abs(jt)
|
||||
jt = jt if (im2_t+im2_h)*m != i2.shape[1] else 0
|
||||
jl = jl if im2_l != 0 else abs(jl)
|
||||
jl = jl if (im2_l+im2_w)*m != i2.shape[1] else 0
|
||||
p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, im2_l*m+jl:(im2_l+im2_w)*m+jl]
|
||||
p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear")
|
||||
|
||||
|
@ -122,15 +144,15 @@ class RandomSharedRegionCrop(nn.Module):
|
|||
i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l)
|
||||
ix_h = min(base_b, im2_b) - max(base_t, im2_t)
|
||||
ix_w = min(base_r, im2_r) - max(base_l, im2_l)
|
||||
recompute_package = (base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, ix_h, ix_w)
|
||||
recompute_package = torch.tensor([base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, ix_h, ix_w], dtype=torch.long)
|
||||
|
||||
# Step 8
|
||||
mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5)
|
||||
mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1
|
||||
masked1 = p1 * mask1
|
||||
masked1 = pad_to(p1 * mask1, d*m)
|
||||
mask2 = torch.full((1, im2_h*m, im2_w*m), fill_value=.5)
|
||||
mask2[:, i2_shared_t*m:(i2_shared_t+ix_h)*m, i2_shared_l*m:(i2_shared_l+ix_w)*m] = 1
|
||||
masked2 = p2 * mask2
|
||||
masked2 = pad_to(p2 * mask2, d*m)
|
||||
mask = torch.full((1, d*m, d*m), fill_value=.33)
|
||||
mask[:, base_t*m:(base_t+base_w)*m, base_l*m:(base_l+base_h)*m] += .33
|
||||
mask[:, im2_t*m:(im2_t+im2_w)*m, im2_l*m:(im2_l+im2_h)*m] += .33
|
||||
|
@ -141,14 +163,22 @@ class RandomSharedRegionCrop(nn.Module):
|
|||
|
||||
# Uses the recompute package returned from the above dataset to extract matched-size "similar regions" from two feature
|
||||
# maps.
|
||||
def reconstructed_shared_regions(fea1, fea2, recompute_package):
|
||||
f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, s_h, s_w = recompute_package
|
||||
# Resize the input features to match
|
||||
f1s = F.interpolate(fea1, (f1_h, f1_w), mode="bilinear")
|
||||
f2s = F.interpolate(fea2, (f2_h, f2_w), mode="bilinear")
|
||||
f1sh = f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w]
|
||||
f2sh = f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w]
|
||||
return f1sh, f2sh
|
||||
def reconstructed_shared_regions(fea1, fea2, recompute_package: torch.Tensor):
|
||||
package = recompute_package.cpu()
|
||||
res1 = []
|
||||
res2 = []
|
||||
pad_dim = torch.max(package[:, -2:]).item()
|
||||
# It'd be real nice if we could do this at the batch level, but I don't see a really good way to do that outside
|
||||
# of conforming the recompute_package across the entire batch.
|
||||
for b in range(package.shape[0]):
|
||||
f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, s_h, s_w = tuple(package[b].tolist())
|
||||
# Resize the input features to match
|
||||
f1s = F.interpolate(fea1[b].unsqueeze(0), (f1_h, f1_w), mode="bilinear")
|
||||
f2s = F.interpolate(fea2[b].unsqueeze(0), (f2_h, f2_w), mode="bilinear")
|
||||
# Outputs must be padded so they can "get along" with each other.
|
||||
res1.append(pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim))
|
||||
res2.append(pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim))
|
||||
return torch.cat(res1, dim=0), torch.cat(res2, dim=0)
|
||||
|
||||
|
||||
# Follows the general template of BYOL dataset, with the following changes:
|
||||
|
@ -169,8 +199,8 @@ class StructuredCropDatasetWrapper(Dataset):
|
|||
|
||||
def __getitem__(self, item):
|
||||
item = self.wrapped_dataset[item]
|
||||
a1 = item['hq'] #self.aug(item['hq']).squeeze(dim=0)
|
||||
a2 = item['hq'] #self.aug(item['lq']).squeeze(dim=0)
|
||||
a1 = self.aug(item['hq']).squeeze(dim=0)
|
||||
a2 = self.aug(item['lq']).squeeze(dim=0)
|
||||
a1, a2, sr_dim, m1, m2, db = self.rrc(a1, a2)
|
||||
item.update({'aug1': a1, 'aug2': a2, 'similar_region_dimensions': sr_dim,
|
||||
'masked1': m1, 'masked2': m2, 'aug_shared_view': db})
|
||||
|
@ -187,7 +217,7 @@ if __name__ == '__main__':
|
|||
{
|
||||
'mode': 'imagefolder',
|
||||
'name': 'amalgam',
|
||||
'paths': ['F:\\4k6k\\datasets\\images\\flickr\\flickr-scrape\\filtered\carrot'],
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'],
|
||||
'weights': [1],
|
||||
'target_size': 256,
|
||||
'force_multiple': 32,
|
||||
|
@ -204,15 +234,15 @@ if __name__ == '__main__':
|
|||
ds = StructuredCropDatasetWrapper(opt)
|
||||
import os
|
||||
os.makedirs("debug", exist_ok=True)
|
||||
for i in range(0, len(ds)):
|
||||
o = ds[random.randint(0, len(ds))]
|
||||
for k, v in o.items():
|
||||
for i in tqdm(range(0, len(ds))):
|
||||
o = ds[random.randint(0, len(ds)-1)]
|
||||
#for k, v in o.items():
|
||||
# 'lq', 'hq', 'aug1', 'aug2',
|
||||
if k in [ 'aug_shared_view', 'masked1', 'masked2']:
|
||||
torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
|
||||
#if k in [ 'aug_shared_view', 'masked1', 'masked2']:
|
||||
#torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
|
||||
rcpkg = o['similar_region_dimensions']
|
||||
pixun = PixelUnshuffle(8)
|
||||
pixsh = nn.PixelShuffle(8)
|
||||
rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg)
|
||||
torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
|
||||
torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))
|
||||
#torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
|
||||
#torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))
|
||||
|
|
178
codes/models/byol/byol_structural.py
Normal file
178
codes/models/byol/byol_structural.py
Normal file
|
@ -0,0 +1,178 @@
|
|||
import copy
|
||||
import random
|
||||
from functools import wraps
|
||||
from time import time
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from data.byol_attachment import reconstructed_shared_regions
|
||||
from models.byol.byol_model_wrapper import singleton, EMA, MLP, get_module_device, set_requires_grad, \
|
||||
update_moving_average
|
||||
from utils.util import checkpoint
|
||||
|
||||
# loss function
|
||||
def structural_loss_fn(x, y):
|
||||
# Combine the structural dimensions into the batch dimension, then compute the "normal" BYOL loss.
|
||||
x = x.permute(0,2,3,1).flatten(0,2)
|
||||
y = y.permute(0,2,3,1).flatten(0,2)
|
||||
x = F.normalize(x, dim=-1, p=2)
|
||||
y = F.normalize(y, dim=-1, p=2)
|
||||
return 2 - 2 * (x * y).sum(dim=-1)
|
||||
|
||||
|
||||
class StructuralTail(nn.Module):
|
||||
def __init__(self, channels, projection_size, hidden_size=512):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(channels, hidden_size, kernel_size=1),
|
||||
nn.BatchNorm2d(hidden_size),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(hidden_size, projection_size, kernel_size=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
# a wrapper class for the base neural network
|
||||
# will manage the interception of the hidden layer output
|
||||
# and pipe it into the projecter and predictor nets
|
||||
class NetWrapper(nn.Module):
|
||||
def __init__(self, net, projection_size, projection_hidden_size, layer=-2):
|
||||
super().__init__()
|
||||
self.net = net
|
||||
self.layer = layer
|
||||
|
||||
self.projector = None
|
||||
self.projection_size = projection_size
|
||||
self.projection_hidden_size = projection_hidden_size
|
||||
|
||||
self.hidden = None
|
||||
self.hook_registered = False
|
||||
|
||||
def _find_layer(self):
|
||||
if type(self.layer) == str:
|
||||
modules = dict([*self.net.named_modules()])
|
||||
return modules.get(self.layer, None)
|
||||
elif type(self.layer) == int:
|
||||
children = [*self.net.children()]
|
||||
return children[self.layer]
|
||||
return None
|
||||
|
||||
def _hook(self, _, __, output):
|
||||
self.hidden = output
|
||||
|
||||
def _register_hook(self):
|
||||
layer = self._find_layer()
|
||||
assert layer is not None, f'hidden layer ({self.layer}) not found'
|
||||
handle = layer.register_forward_hook(self._hook)
|
||||
self.hook_registered = True
|
||||
|
||||
@singleton('projector')
|
||||
def _get_projector(self, hidden):
|
||||
projector = StructuralTail(hidden.shape[1], self.projection_size, self.projection_hidden_size)
|
||||
return projector.to(hidden)
|
||||
|
||||
def get_representation(self, x):
|
||||
if self.layer == -1:
|
||||
return self.net(x)
|
||||
|
||||
if not self.hook_registered:
|
||||
self._register_hook()
|
||||
|
||||
unused = self.net(x)
|
||||
hidden = self.hidden
|
||||
self.hidden = None
|
||||
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
|
||||
return hidden
|
||||
|
||||
def forward(self, x):
|
||||
representation = self.get_representation(x)
|
||||
projector = self._get_projector(representation)
|
||||
projection = checkpoint(projector, representation)
|
||||
return projection
|
||||
|
||||
|
||||
class StructuralBYOL(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
net,
|
||||
image_size,
|
||||
hidden_layer=-2,
|
||||
projection_size=256,
|
||||
projection_hidden_size=512,
|
||||
moving_average_decay=0.99,
|
||||
use_momentum=True,
|
||||
pretrained_state_dict=None,
|
||||
freeze_until=0
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if pretrained_state_dict:
|
||||
net.load_state_dict(torch.load(pretrained_state_dict), strict=True)
|
||||
self.freeze_until = freeze_until
|
||||
if self.freeze_until > 0:
|
||||
for p in net.parameters():
|
||||
p.DO_NOT_TRAIN = True
|
||||
self.frozen = True
|
||||
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer)
|
||||
|
||||
self.use_momentum = use_momentum
|
||||
self.target_encoder = None
|
||||
self.target_ema_updater = EMA(moving_average_decay)
|
||||
|
||||
self.online_predictor = StructuralTail(projection_size, projection_size, projection_hidden_size)
|
||||
|
||||
# get device of network and make wrapper same device
|
||||
device = get_module_device(net)
|
||||
self.to(device)
|
||||
|
||||
# send a mock image tensor to instantiate singleton parameters
|
||||
self.forward(torch.randn(2, 3, image_size, image_size, device=device),
|
||||
torch.randn(2, 3, image_size, image_size, device=device), None)
|
||||
|
||||
@singleton('target_encoder')
|
||||
def _get_target_encoder(self):
|
||||
target_encoder = copy.deepcopy(self.online_encoder)
|
||||
set_requires_grad(target_encoder, False)
|
||||
return target_encoder
|
||||
|
||||
def reset_moving_average(self):
|
||||
del self.target_encoder
|
||||
self.target_encoder = None
|
||||
|
||||
def update_for_step(self, step, __):
|
||||
assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
|
||||
assert self.target_encoder is not None, 'target encoder has not been created yet'
|
||||
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
|
||||
if self.frozen and self.freeze_until < step:
|
||||
print("Unfreezing model weights. Let the latent training commence..")
|
||||
for p in self.online_encoder.net.parameters():
|
||||
del p.DO_NOT_TRAIN
|
||||
self.frozen = False
|
||||
|
||||
def forward(self, image_one, image_two, similar_region_params):
|
||||
online_proj_one = self.online_encoder(image_one)
|
||||
online_proj_two = self.online_encoder(image_two)
|
||||
|
||||
online_pred_one = self.online_predictor(online_proj_one)
|
||||
online_pred_two = self.online_predictor(online_proj_two)
|
||||
|
||||
with torch.no_grad():
|
||||
target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
|
||||
target_proj_one = target_encoder(image_one).detach()
|
||||
target_proj_two = target_encoder(image_two).detach()
|
||||
|
||||
# In the structural BYOL, only the regions of the source image that are shared between the two augments are
|
||||
# compared. These regions can be extracted from the latents using `reconstruct_shared_regions`.
|
||||
if similar_region_params is not None:
|
||||
online_pred_one, target_proj_two = reconstructed_shared_regions(online_pred_one, target_proj_two, similar_region_params)
|
||||
loss_one = structural_loss_fn(online_pred_one, target_proj_two.detach())
|
||||
if similar_region_params is not None:
|
||||
online_pred_two, target_proj_one = reconstructed_shared_regions(online_pred_two, target_proj_one, similar_region_params)
|
||||
loss_two = structural_loss_fn(online_pred_two, target_proj_one.detach())
|
||||
|
||||
loss = loss_one + loss_two
|
||||
return loss.mean()
|
|
@ -153,6 +153,12 @@ def define_G(opt, opt_net, scale=None):
|
|||
subnet = define_G(opt, opt_net['subnet'])
|
||||
netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
||||
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
|
||||
elif which_model == 'structural_byol':
|
||||
from models.byol.byol_structural import StructuralBYOL
|
||||
subnet = define_G(opt, opt_net['subnet'])
|
||||
netG = StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
||||
pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]),
|
||||
freeze_until=opt_get(opt_net, ['freeze_until'], 0))
|
||||
elif which_model == 'spinenet':
|
||||
from models.archs.spinenet_arch import SpineNet
|
||||
netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])
|
||||
|
|
Loading…
Reference in New Issue
Block a user