BYOL with structure!

This commit is contained in:
James Betker 2020-12-10 15:07:35 -07:00
parent 9c5e272a22
commit 26ceca68c0
4 changed files with 237 additions and 21 deletions

View File

@ -49,6 +49,8 @@ def create_dataset(dataset_opt):
from data.torch_dataset import TorchDataset as D
elif mode == 'byol_dataset':
from data.byol_attachment import ByolDatasetWrapper as D
elif mode == 'byol_structured_dataset':
from data.byol_attachment import StructuredCropDatasetWrapper as D
elif mode == 'random_dataset':
from data.random_dataset import RandomDataset as D
else:

View File

@ -1,4 +1,5 @@
import random
from time import time
import torch
import torchvision
@ -10,6 +11,8 @@ import torch.nn.functional as F
# Wrapper for a DLAS Dataset class that applies random augmentations from the BYOL paper to BOTH the 'lq' and 'hq'
# inputs. These are then outputted as 'aug1' and 'aug2'.
from tqdm import tqdm
from data import create_dataset
from models.archs.arch_util import PixelUnshuffle
from utils.util import opt_get
@ -66,6 +69,17 @@ def snap(ref, other):
return other - ref
# Pads a tensor with zeros so that it fits in a dxd square.
def pad_to(im, d):
if len(im.shape) == 3:
pd = torch.zeros((im.shape[0],d,d))
pd[:, :im.shape[1], :im.shape[2]] = im
else:
pd = torch.zeros((im.shape[0],im.shape[1],d,d), device=im.device)
pd[:, :, :im.shape[2], :im.shape[3]] = im
return pd
# Variation of RandomResizedCrop, which picks a region of the image that the two augments must share. The augments
# then propagate off random corners of the shared region, using the same scale.
#
@ -111,9 +125,17 @@ class RandomSharedRegionCrop(nn.Module):
# Step 6
m = self.multiple
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
jt = jt if base_t != 0 else abs(jt) # If the top of a patch is zero, a negative jitter will cause it to go negative.
jt = jt if (base_t+base_h)*m != i1.shape[1] else 0 # Likewise, jitter shouldn't allow the patch to go over-bounds.
jl = jl if base_l != 0 else abs(jl)
jl = jl if (base_l+base_w)*m != i1.shape[1] else 0
p1 = i1[:, base_t*m+jt:(base_t+base_h)*m+jt, base_l*m+jl:(base_l+base_w)*m+jl]
p1_resized = no_batch_interpolate(p1, size=(d*m, d*m), mode="bilinear")
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
jt = jt if im2_t != 0 else abs(jt)
jt = jt if (im2_t+im2_h)*m != i2.shape[1] else 0
jl = jl if im2_l != 0 else abs(jl)
jl = jl if (im2_l+im2_w)*m != i2.shape[1] else 0
p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, im2_l*m+jl:(im2_l+im2_w)*m+jl]
p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear")
@ -122,15 +144,15 @@ class RandomSharedRegionCrop(nn.Module):
i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l)
ix_h = min(base_b, im2_b) - max(base_t, im2_t)
ix_w = min(base_r, im2_r) - max(base_l, im2_l)
recompute_package = (base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, ix_h, ix_w)
recompute_package = torch.tensor([base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, ix_h, ix_w], dtype=torch.long)
# Step 8
mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5)
mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1
masked1 = p1 * mask1
masked1 = pad_to(p1 * mask1, d*m)
mask2 = torch.full((1, im2_h*m, im2_w*m), fill_value=.5)
mask2[:, i2_shared_t*m:(i2_shared_t+ix_h)*m, i2_shared_l*m:(i2_shared_l+ix_w)*m] = 1
masked2 = p2 * mask2
masked2 = pad_to(p2 * mask2, d*m)
mask = torch.full((1, d*m, d*m), fill_value=.33)
mask[:, base_t*m:(base_t+base_w)*m, base_l*m:(base_l+base_h)*m] += .33
mask[:, im2_t*m:(im2_t+im2_w)*m, im2_l*m:(im2_l+im2_h)*m] += .33
@ -141,14 +163,22 @@ class RandomSharedRegionCrop(nn.Module):
# Uses the recompute package returned from the above dataset to extract matched-size "similar regions" from two feature
# maps.
def reconstructed_shared_regions(fea1, fea2, recompute_package):
f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, s_h, s_w = recompute_package
# Resize the input features to match
f1s = F.interpolate(fea1, (f1_h, f1_w), mode="bilinear")
f2s = F.interpolate(fea2, (f2_h, f2_w), mode="bilinear")
f1sh = f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w]
f2sh = f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w]
return f1sh, f2sh
def reconstructed_shared_regions(fea1, fea2, recompute_package: torch.Tensor):
package = recompute_package.cpu()
res1 = []
res2 = []
pad_dim = torch.max(package[:, -2:]).item()
# It'd be real nice if we could do this at the batch level, but I don't see a really good way to do that outside
# of conforming the recompute_package across the entire batch.
for b in range(package.shape[0]):
f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, s_h, s_w = tuple(package[b].tolist())
# Resize the input features to match
f1s = F.interpolate(fea1[b].unsqueeze(0), (f1_h, f1_w), mode="bilinear")
f2s = F.interpolate(fea2[b].unsqueeze(0), (f2_h, f2_w), mode="bilinear")
# Outputs must be padded so they can "get along" with each other.
res1.append(pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim))
res2.append(pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim))
return torch.cat(res1, dim=0), torch.cat(res2, dim=0)
# Follows the general template of BYOL dataset, with the following changes:
@ -169,8 +199,8 @@ class StructuredCropDatasetWrapper(Dataset):
def __getitem__(self, item):
item = self.wrapped_dataset[item]
a1 = item['hq'] #self.aug(item['hq']).squeeze(dim=0)
a2 = item['hq'] #self.aug(item['lq']).squeeze(dim=0)
a1 = self.aug(item['hq']).squeeze(dim=0)
a2 = self.aug(item['lq']).squeeze(dim=0)
a1, a2, sr_dim, m1, m2, db = self.rrc(a1, a2)
item.update({'aug1': a1, 'aug2': a2, 'similar_region_dimensions': sr_dim,
'masked1': m1, 'masked2': m2, 'aug_shared_view': db})
@ -187,7 +217,7 @@ if __name__ == '__main__':
{
'mode': 'imagefolder',
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\images\\flickr\\flickr-scrape\\filtered\carrot'],
'paths': ['F:\\4k6k\\datasets\\ns_images\\512_unsupervised'],
'weights': [1],
'target_size': 256,
'force_multiple': 32,
@ -204,15 +234,15 @@ if __name__ == '__main__':
ds = StructuredCropDatasetWrapper(opt)
import os
os.makedirs("debug", exist_ok=True)
for i in range(0, len(ds)):
o = ds[random.randint(0, len(ds))]
for k, v in o.items():
for i in tqdm(range(0, len(ds))):
o = ds[random.randint(0, len(ds)-1)]
#for k, v in o.items():
# 'lq', 'hq', 'aug1', 'aug2',
if k in [ 'aug_shared_view', 'masked1', 'masked2']:
torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
#if k in [ 'aug_shared_view', 'masked1', 'masked2']:
#torchvision.utils.save_image(v.unsqueeze(0), "debug/%i_%s.png" % (i, k))
rcpkg = o['similar_region_dimensions']
pixun = PixelUnshuffle(8)
pixsh = nn.PixelShuffle(8)
rc1, rc2 = reconstructed_shared_regions(pixun(o['aug1'].unsqueeze(0)), pixun(o['aug2'].unsqueeze(0)), rcpkg)
torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))
#torchvision.utils.save_image(pixsh(rc1), "debug/%i_rc1.png" % (i,))
#torchvision.utils.save_image(pixsh(rc2), "debug/%i_rc2.png" % (i,))

View File

@ -0,0 +1,178 @@
import copy
import random
from functools import wraps
from time import time
import torch
import torch.nn.functional as F
from torch import nn
from data.byol_attachment import reconstructed_shared_regions
from models.byol.byol_model_wrapper import singleton, EMA, MLP, get_module_device, set_requires_grad, \
update_moving_average
from utils.util import checkpoint
# loss function
def structural_loss_fn(x, y):
# Combine the structural dimensions into the batch dimension, then compute the "normal" BYOL loss.
x = x.permute(0,2,3,1).flatten(0,2)
y = y.permute(0,2,3,1).flatten(0,2)
x = F.normalize(x, dim=-1, p=2)
y = F.normalize(y, dim=-1, p=2)
return 2 - 2 * (x * y).sum(dim=-1)
class StructuralTail(nn.Module):
def __init__(self, channels, projection_size, hidden_size=512):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(channels, hidden_size, kernel_size=1),
nn.BatchNorm2d(hidden_size),
nn.ReLU(inplace=True),
nn.Conv2d(hidden_size, projection_size, kernel_size=1),
)
def forward(self, x):
return self.net(x)
# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projecter and predictor nets
class NetWrapper(nn.Module):
def __init__(self, net, projection_size, projection_hidden_size, layer=-2):
super().__init__()
self.net = net
self.layer = layer
self.projector = None
self.projection_size = projection_size
self.projection_hidden_size = projection_hidden_size
self.hidden = None
self.hook_registered = False
def _find_layer(self):
if type(self.layer) == str:
modules = dict([*self.net.named_modules()])
return modules.get(self.layer, None)
elif type(self.layer) == int:
children = [*self.net.children()]
return children[self.layer]
return None
def _hook(self, _, __, output):
self.hidden = output
def _register_hook(self):
layer = self._find_layer()
assert layer is not None, f'hidden layer ({self.layer}) not found'
handle = layer.register_forward_hook(self._hook)
self.hook_registered = True
@singleton('projector')
def _get_projector(self, hidden):
projector = StructuralTail(hidden.shape[1], self.projection_size, self.projection_hidden_size)
return projector.to(hidden)
def get_representation(self, x):
if self.layer == -1:
return self.net(x)
if not self.hook_registered:
self._register_hook()
unused = self.net(x)
hidden = self.hidden
self.hidden = None
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
def forward(self, x):
representation = self.get_representation(x)
projector = self._get_projector(representation)
projection = checkpoint(projector, representation)
return projection
class StructuralBYOL(nn.Module):
def __init__(
self,
net,
image_size,
hidden_layer=-2,
projection_size=256,
projection_hidden_size=512,
moving_average_decay=0.99,
use_momentum=True,
pretrained_state_dict=None,
freeze_until=0
):
super().__init__()
if pretrained_state_dict:
net.load_state_dict(torch.load(pretrained_state_dict), strict=True)
self.freeze_until = freeze_until
if self.freeze_until > 0:
for p in net.parameters():
p.DO_NOT_TRAIN = True
self.frozen = True
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer)
self.use_momentum = use_momentum
self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay)
self.online_predictor = StructuralTail(projection_size, projection_size, projection_hidden_size)
# get device of network and make wrapper same device
device = get_module_device(net)
self.to(device)
# send a mock image tensor to instantiate singleton parameters
self.forward(torch.randn(2, 3, image_size, image_size, device=device),
torch.randn(2, 3, image_size, image_size, device=device), None)
@singleton('target_encoder')
def _get_target_encoder(self):
target_encoder = copy.deepcopy(self.online_encoder)
set_requires_grad(target_encoder, False)
return target_encoder
def reset_moving_average(self):
del self.target_encoder
self.target_encoder = None
def update_for_step(self, step, __):
assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
assert self.target_encoder is not None, 'target encoder has not been created yet'
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
if self.frozen and self.freeze_until < step:
print("Unfreezing model weights. Let the latent training commence..")
for p in self.online_encoder.net.parameters():
del p.DO_NOT_TRAIN
self.frozen = False
def forward(self, image_one, image_two, similar_region_params):
online_proj_one = self.online_encoder(image_one)
online_proj_two = self.online_encoder(image_two)
online_pred_one = self.online_predictor(online_proj_one)
online_pred_two = self.online_predictor(online_proj_two)
with torch.no_grad():
target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
target_proj_one = target_encoder(image_one).detach()
target_proj_two = target_encoder(image_two).detach()
# In the structural BYOL, only the regions of the source image that are shared between the two augments are
# compared. These regions can be extracted from the latents using `reconstruct_shared_regions`.
if similar_region_params is not None:
online_pred_one, target_proj_two = reconstructed_shared_regions(online_pred_one, target_proj_two, similar_region_params)
loss_one = structural_loss_fn(online_pred_one, target_proj_two.detach())
if similar_region_params is not None:
online_pred_two, target_proj_one = reconstructed_shared_regions(online_pred_two, target_proj_one, similar_region_params)
loss_two = structural_loss_fn(online_pred_two, target_proj_one.detach())
loss = loss_one + loss_two
return loss.mean()

View File

@ -153,6 +153,12 @@ def define_G(opt, opt_net, scale=None):
subnet = define_G(opt, opt_net['subnet'])
netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
elif which_model == 'structural_byol':
from models.byol.byol_structural import StructuralBYOL
subnet = define_G(opt, opt_net['subnet'])
netG = StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]),
freeze_until=opt_get(opt_net, ['freeze_until'], 0))
elif which_model == 'spinenet':
from models.archs.spinenet_arch import SpineNet
netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])