Get segformer to a trainable state

James Betker 2021-04-25 11:45:20 -06:00
parent 23e01314d4
commit 9bbe6fc81e
6 changed files with 173 additions and 102 deletions

View File

@@ -37,6 +37,8 @@ class ImageFolderDataset:
         if 'normalize' in opt.keys():
             if opt['normalize'] == 'stylegan2_norm':
                 self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+            elif opt['normalize'] == 'imagenet':
+                self.normalize = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True)
             else:
                 raise Exception('Unsupported normalize')
         else:
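
The new 'imagenet' branch uses the standard ImageNet channel statistics. A minimal standalone sketch of the same transform outside the dataset class (torchvision assumed, as in the file above):

    from torchvision.transforms import Normalize
    import torch

    # Standard ImageNet mean/std, matching the values added in the hunk above.
    imagenet_norm = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True)
    img = torch.rand(3, 224, 224)       # CHW float image in [0, 1]
    normalized = imagenet_norm(img)     # per-channel (x - mean) / std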

View File

@@ -51,6 +51,54 @@ def set_requires_grad(model, val):
         p.requires_grad = val
 
+
+# Specialized augmentor class that applies a set of image transformations on points as well, allowing one to track
+# where a point in the src image is located in the dest image. Restricts transformations such that this is possible.
+class PointwiseAugmentor(nn.Module):
+    def __init__(self, img_size=224):
+        super().__init__()
+        self.jitter = RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8)
+        self.gray = augs.RandomGrayscale(p=0.2)
+        self.blur = RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)
+        self.rrc = augs.RandomResizedCrop((img_size, img_size), same_on_batch=True)
+
+    # Given a point in the source image, returns the same point in the augmented (destination) image, given the kornia RRC params.
+    def rrc_on_point(self, src_point, params):
+        dh, dw = params['dst'][:,2,1]-params['dst'][:,0,1], params['dst'][:,2,0] - params['dst'][:,0,0]
+        sh, sw = params['src'][:,2,1]-params['src'][:,0,1], params['src'][:,2,0] - params['src'][:,0,0]
+        scale_h, scale_w = sh.float() / dh.float(), sw.float() / dw.float()
+        t, l = src_point[0] - params['src'][0,0,1], src_point[1] - params['src'][0,0,0]
+        t = (t.float() / scale_h[0]).long()
+        l = (l.float() / scale_w[0]).long()
+        return torch.stack([t, l])
+
+    def flip_on_point(self, pt, input):
+        t, l = pt[0], pt[1]
+        center = input.shape[-1] // 2
+        return t, 2 * center - l
+
+    def forward(self, x, point):
+        d = self.jitter(x)
+        d = self.gray(d)
+        will_flip = random.random() > .5
+        if will_flip:
+            d = apply_hflip(d)
+            point = self.flip_on_point(point, x)
+        d = self.blur(d)
+        invalid = True
+        while invalid:
+            params = self.rrc.generate_parameters(d.shape)
+            potential = self.rrc_on_point(point, params)
+            # '10' is an arbitrary margin: making predictions at the very edge of an image is not very useful.
+            if potential[0] <= 10 or potential[1] <= 10 or potential[0] > x.shape[-2]-10 or potential[1] > x.shape[-1]-10:
+                continue
+            d = self.rrc(d, params=params)
+            point = potential
+            invalid = False
+        return d, point
+
+
 # loss fn
 def loss_fn(x, y):
     x = F.normalize(x, dim=-1, p=2)
@@ -160,21 +208,21 @@ class NetWrapper(nn.Module):
         projector = MLP(dim, self.projection_size, self.projection_hidden_size)
         return projector.to(hidden)
 
-    def get_representation(self, x):
+    def get_representation(self, x, pt):
         if self.layer == -1:
-            return self.net(x)
+            return self.net(x, pt)
 
         if not self.hook_registered:
             self._register_hook()
 
-        unused = self.net(x)
+        unused = self.net(x, pt)
         hidden = self.hidden
         self.hidden = None
         assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
         return hidden
 
-    def forward(self, x):
-        representation = self.get_representation(x)
+    def forward(self, x, pt):
+        representation = self.get_representation(x, pt)
         projector = self._get_projector(representation)
         projection = checkpoint(projector, representation)
         return projection
@@ -191,24 +239,13 @@ class BYOL(nn.Module):
         moving_average_decay=0.99,
         use_momentum=True,
         structural_mlp=False,
-        do_augmentation=False  # In DLAS this was intended to be done at the dataset level. For massive batch sizes
-                               # this can overwhelm the CPU though, and it becomes desirable to do the augmentations
-                               # on the GPU again.
     ):
         super().__init__()
 
         self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
                                          use_structural_mlp=structural_mlp)
-        self.do_aug = do_augmentation
-        if self.do_aug:
-            augmentations = [ \
-                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
-                augs.RandomGrayscale(p=0.2),
-                augs.RandomHorizontalFlip(),
-                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
-                augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
-            self.aug = nn.Sequential(*augmentations)
+        self.aug = PointwiseAugmentor(image_size)
 
         self.use_momentum = use_momentum
         self.target_encoder = None
         self.target_ema_updater = EMA(moving_average_decay)
@@ -220,8 +257,7 @@ class BYOL(nn.Module):
         self.to(device)
 
         # send a mock image tensor to instantiate singleton parameters
-        self.forward(torch.randn(2, 3, image_size, image_size, device=device),
-                     torch.randn(2, 3, image_size, image_size, device=device))
+        self.forward(torch.randn(2, 3, image_size, image_size, device=device))
 
     @singleton('target_encoder')
     def _get_target_encoder(self):
@@ -245,28 +281,32 @@ class BYOL(nn.Module):
         return {'target_ema_beta': self.target_ema_updater.beta}
 
     def visual_dbg(self, step, path):
-        if self.do_aug:
-            torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
-            torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
+        torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
+        torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
 
-    def forward(self, image_one, image_two):
-        if self.do_aug:
-            image_one = self.aug(image_one)
-            image_two = self.aug(image_two)
-            # Keep copies on hand for visual_dbg.
-            self.im1 = image_one.detach().copy()
-            self.im2 = image_two.detach().copy()
+    def forward(self, image):
+        _, _, h, w = image.shape
+        point = torch.randint(h//8, 7*h//8, (2,)).long().to(image.device)
 
-        online_proj_one = self.online_encoder(image_one)
-        online_proj_two = self.online_encoder(image_two)
+        image_one, pt_one = self.aug(image, point)
+        image_two, pt_two = self.aug(image, point)
+
+        # Keep copies on hand for visual_dbg.
+        self.im1 = image_one.detach().clone()
+        self.im1[:,:,pt_one[0]-3:pt_one[0]+3,pt_one[1]-3:pt_one[1]+3] = 1
+        self.im2 = image_two.detach().clone()
+        self.im2[:,:,pt_two[0]-3:pt_two[0]+3,pt_two[1]-3:pt_two[1]+3] = 1
+
+        online_proj_one = self.online_encoder(image_one, pt_one)
+        online_proj_two = self.online_encoder(image_two, pt_two)
 
         online_pred_one = self.online_predictor(online_proj_one)
         online_pred_two = self.online_predictor(online_proj_two)
 
         with torch.no_grad():
             target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
-            target_proj_one = target_encoder(image_one).detach()
-            target_proj_two = target_encoder(image_two).detach()
+            target_proj_one = target_encoder(image_one, pt_one).detach()
+            target_proj_two = target_encoder(image_two, pt_two).detach()
 
         loss_one = loss_fn(online_pred_one, target_proj_two.detach())
         loss_two = loss_fn(online_pred_two, target_proj_one.detach())
@@ -275,53 +315,20 @@ class BYOL(nn.Module):
         return loss.mean()
 
-class PointwiseAugmentor(nn.Module):
-    def __init__(self, img_size=224):
-        super().__init__()
-        self.jitter = RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8)
-        self.gray = augs.RandomGrayscale(p=0.2)
-        self.blur = RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)
-        self.rrc = augs.RandomResizedCrop((img_size, img_size))
-
-    # Given a point in the *destination* image, returns the same point in the source image, given the kornia RRC params.
-    def reverse_rrc(self, dest_point, params):
-        dh, dw = params['dst'][:,2,1]-params['dst'][:,0,1], params['dst'][:,2,0] - params['dst'][:,0,0]
-        sh, sw = params['src'][:,2,1]-params['src'][:,0,1], params['src'][:,2,0] - params['src'][:,0,0]
-        scale_h, scale_w = sh.float() / dh.float(), sw.float() / dw.float()
-        t, l = dest_point
-        t = (t.float() * scale_h).int()
-        l = (l.float() * scale_w).int()
-        return t + params['src'][:,0,1], l + params['src'][:,0,0]
-
-    def reverse_horizontal_flip(self, pt, input):
-        t, l = pt
-        center = input.shape[-1] // 2
-        return t, 2 * center - l
-
-    def forward(self, x, points):
-        d = self.jitter(x)
-        d = self.gray(d)
-        will_flip = random.random() > .5
-        if will_flip:
-            d = apply_hflip(d)
-        d = self.blur(d)
-        params = self.rrc.generate_parameters(d.shape)
-        d = self.rrc(d, params=params)
-        rev = self.reverse_rrc(points, params)
-        if will_flip:
-            rev = self.reverse_horizontal_flip(rev, x)
-
 if __name__ == '__main__':
-    p = PointwiseAugmentor(256)
-    t = ToTensor()(Image.open('E:\\4k6k\\datasets\\ns_images\\imagesets\\000001_152761.jpg')).unsqueeze(0).repeat(8,1,1,1)
-    points = (torch.randint(0,224,(t.shape[0],)),torch.randint(0,224,(t.shape[0],)))
-    p(t, points)
+    pa = PointwiseAugmentor(256)
+    for j in range(100):
+        t = ToTensor()(Image.open('E:\\4k6k\\datasets\\ns_images\\imagesets\\000001_152761.jpg')).unsqueeze(0).repeat(8,1,1,1)
+        p = torch.randint(50,180,(2,))
+        augmented, dp = pa(t, p)
+        t, p = pa(t, p)
+        t[:,:,p[0]-3:p[0]+3,p[1]-3:p[1]+3] = 0
+        torchvision.utils.save_image(t, f"{j}_src.png")
+        augmented[:,:,dp[0]-3:dp[0]+3,dp[1]-3:dp[1]+3] = 0
+        torchvision.utils.save_image(augmented, f"{j}_dst.png")
 
 @register_model
-def register_byol(opt_net, opt):
+def register_pixel_local_byol(opt_net, opt):
     subnet = create_model(opt, opt_net['subnet'])
-    return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
-                structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
-                do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))
+    return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'])
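
The crop mapping in rrc_on_point above is a translate-and-rescale: subtract the crop's top-left corner in the source image, then divide by the source-to-output scale. A worked sketch of that arithmetic with made-up crop numbers (the kornia RRC params are assumed to expose src/dst corner boxes exactly as they are indexed above):

    import torch

    # Suppose the sampled crop covers source rows 40..168 and cols 20..148 (128x128),
    # and is resized to a 224x224 output.
    src_top, src_left = 40, 20
    crop_h, crop_w = 128, 128
    out_h, out_w = 224, 224
    scale_h, scale_w = crop_h / out_h, crop_w / out_w   # source pixels per output pixel

    point = torch.tensor([100, 60])                     # (row, col) in the source image
    t = int((point[0] - src_top) / scale_h)             # row in the cropped/resized output
    l = int((point[1] - src_left) / scale_w)            # col in the cropped/resized output
    print(t, l)                                         # 105 70: same content, new coordinates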

View File

View File

@@ -2,11 +2,30 @@ import math
 import torch
 import torch.nn as nn
+import torchvision
 from tqdm import tqdm
 
 from models.segformer.backbone import backbone50
+from trainer.networks import register_model
+
+
+# torch.gather() which operates as it always fucking should have: pulling indexes from the input.
+def gather_2d(input, index):
+    b, c, h, w = input.shape
+    nodim = input.view(b, c, h * w)
+    ind_nd = index[:, 0]*w + index[:, 1]
+    ind_nd = ind_nd.unsqueeze(1)
+    ind_nd = ind_nd.repeat((1, c))
+    ind_nd = ind_nd.unsqueeze(2)
+    result = torch.gather(nodim, dim=2, index=ind_nd)
+    result = result.squeeze()
+    if b == 1:
+        result = result.unsqueeze(0)
+    return result
+
+
 class DilatorModule(nn.Module):
     def __init__(self, input_channels, output_channels, max_dilation):
         super().__init__()
@@ -15,7 +34,7 @@ class DilatorModule(nn.Module):
         if max_dilation > 1:
             self.bn = nn.BatchNorm2d(input_channels)
             self.relu = nn.ReLU()
-            self.conv2 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, dilation=max_dilation, bias=True)
+            self.conv2 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=max_dilation, dilation=max_dilation, bias=True)
         self.dense = nn.Linear(input_channels, output_channels, bias=True)
 
     def forward(self, inp, loc):
@@ -24,9 +43,8 @@ class DilatorModule(nn.Module):
             x = self.bn(self.relu(x))
             x = self.conv2(x)
 
-        # This can be made (possibly substantially) more efficient by only computing these convolutions across a subset of the image. Possibly.
-        i, j = loc
-        x = x[:,:,i,j]
+        # This can be made more efficient by only computing these convolutions across a subset of the image. Possibly.
+        x = gather_2d(x, loc).contiguous()
         return self.dense(x)
@@ -48,13 +66,22 @@ class PositionalEncoding(nn.Module):
         return x
 
+
+# Simple mean() layer encoded into a class so that BYOL can grab it.
+class Tail(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x.mean(dim=0)
+
+
 class Segformer(nn.Module):
-    def __init__(self):
+    def __init__(self, latent_channels=1024, layers=8):
         super().__init__()
         self.backbone = backbone50()
         backbone_channels = [256, 512, 1024, 2048]
         dilations = [[1,2,3,4],[1,2,3],[1,2],[1]]
-        final_latent_channels = 2048
+        final_latent_channels = latent_channels
         dilators = []
         for ic, dis in zip(backbone_channels, dilations):
             layer_dilators = []
@@ -64,26 +91,37 @@ class Segformer(nn.Module):
         self.dilators = nn.ModuleList(dilators)
 
         self.token_position_encoder = PositionalEncoding(final_latent_channels, max_len=10)
-        self.transformer_layers = nn.Sequential(*[nn.TransformerEncoderLayer(final_latent_channels, nhead=4) for _ in range(16)])
+        self.transformer_layers = nn.Sequential(*[nn.TransformerEncoderLayer(final_latent_channels, nhead=4) for _ in range(layers)])
+        self.tail = Tail()
 
     def forward(self, x, pos):
         layers = self.backbone(x)
         set = []
-        i, j = pos[0] // 4, pos[1] // 4
+        # A single position can be optionally given, in which case we need to expand it to represent the entire input.
+        if pos.shape == (2,):
+            pos = pos.unsqueeze(0).repeat(x.shape[0], 1)
+        pos = pos // 4
         for layer_out, dilator in zip(layers, self.dilators):
             for subdilator in dilator:
-                set.append(subdilator(layer_out, (i, j)))
-            i, j = i // 2, j // 2
+                set.append(subdilator(layer_out, pos))
+            pos = pos // 2
 
         # The torch transformer expects the set dimension to be 0.
         set = torch.stack(set, dim=0)
         set = self.token_position_encoder(set)
         set = self.transformer_layers(set)
-        return set
+        return self.tail(set)
+
+
+@register_model
+def register_segformer(opt_net, opt):
+    return Segformer()
+
 
 if __name__ == '__main__':
     model = Segformer().to('cuda')
     for j in tqdm(range(1000)):
         test_tensor = torch.randn(64,3,224,224).cuda()
-        model(test_tensor, (43, 73))
+        print(model(test_tensor, torch.randint(0,224,(64,2)).cuda()).shape)
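
gather_2d above pulls one feature vector per batch element at a given (row, col) location, which is what lets DilatorModule evaluate its dense head at a single pixel. A small sketch of what it computes, checked against naive per-sample indexing (the import path assumes this file is models/segformer/segformer.py, as the evaluator import further down suggests):

    import torch
    from models.segformer.segformer import gather_2d   # assumed module path

    b, c, h, w = 4, 8, 16, 16
    feat = torch.randn(b, c, h, w)
    loc = torch.randint(0, h, (b, 2))                   # per-sample (row, col)

    # Reference: explicit per-sample indexing.
    ref = torch.stack([feat[i, :, loc[i, 0], loc[i, 1]] for i in range(b)])   # (b, c)
    assert torch.equal(gather_2d(feat, loc), ref)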

View File

@@ -295,7 +295,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cats_stylegan2_rosinality.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_segformer_xx.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()

View File

@@ -11,6 +11,7 @@ import trainer.eval.evaluator as evaluator
 from pytorch_fid import fid_score
 
 from data.image_pair_with_corresponding_points_dataset import ImagePairWithCorrespondingPointsDataset
+from models.segformer.segformer import Segformer
 from utils.util import opt_get
 
 
 # Uses two datasets: a "similar" and "dissimilar" dataset, each of which contains pairs of images and similar/dissimilar
@@ -23,15 +24,17 @@ class SinglePointPairContrastiveEval(evaluator.Evaluator):
         self.batch_sz = opt_eval['batch_size']
         self.eval_qty = opt_eval['quantity']
         assert self.eval_qty % self.batch_sz == 0
-        self.similar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(**opt_eval['similar_set_args']), shuffle=False, batch_size=self.batch_sz)
-        self.dissimilar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(**opt_eval['dissimilar_set_args']), shuffle=False, batch_size=self.batch_sz)
+        self.similar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(opt_eval['similar_set_args']), shuffle=False, batch_size=self.batch_sz)
+        self.dissimilar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(opt_eval['dissimilar_set_args']), shuffle=False, batch_size=self.batch_sz)
+        # Hack to make this work with the BYOL generator. TODO: fix
+        self.model = self.model.online_encoder.net
 
-    def get_l2_score(self, dl):
+    def get_l2_score(self, dl, dev):
         distances = []
         l2 = MSELoss()
         for i, data in tqdm(enumerate(dl)):
-            latent1 = self.model(data['img1'], data['coords1'])
-            latent2 = self.model(data['img2'], data['coords2'])
+            latent1 = self.model(data['img1'].to(dev), torch.stack(data['coords1'], dim=1).to(dev))
+            latent2 = self.model(data['img2'].to(dev), torch.stack(data['coords2'], dim=1).to(dev))
             distances.append(l2(latent1, latent2))
             if i * self.batch_sz >= self.eval_qty:
                 break
@@ -40,9 +43,30 @@ class SinglePointPairContrastiveEval(evaluator.Evaluator):
     def perform_eval(self):
         self.model.eval()
-        print("Computing contrastive eval on similar set")
-        similars = self.get_l2_score(self.similar_set)
-        print("Computing contrastive eval on dissimilar set")
-        dissimilars = self.get_l2_score(self.dissimilar_set)
-        print(f"Eval done. val_similar_lq: {similars.item()}; val_dissimilar_l2: {dissimilars.item()}")
-        return {"val_similar_l2": similars.item(), "val_dissimilar_l2": dissimilars.item()}
+        with torch.no_grad():
+            dev = next(self.model.parameters()).device
+            print("Computing contrastive eval on similar set")
+            similars = self.get_l2_score(self.similar_set, dev)
+            print("Computing contrastive eval on dissimilar set")
+            dissimilars = self.get_l2_score(self.dissimilar_set, dev)
+            diff = dissimilars.item() - similars.item()
+            print(f"Eval done. val_similar_l2: {similars.item()}; val_dissimilar_l2: {dissimilars.item()}; val_diff: {diff}")
+        self.model.train()
+        return {"val_similar_l2": similars.item(), "val_dissimilar_l2": dissimilars.item(), "val_diff": diff}
+
+
+if __name__ == '__main__':
+    model = Segformer(1024, 4).cuda()
+    eval = SinglePointPairContrastiveEval(model, {
+        'batch_size': 8,
+        'quantity': 32,
+        'similar_set_args': {
+            'path': 'E:\\4k6k\\datasets\\ns_images\\segformer_validation\\similar',
+            'size': 256
+        },
+        'dissimilar_set_args': {
+            'path': 'E:\\4k6k\\datasets\\ns_images\\segformer_validation\\dissimilar',
+            'size': 256
+        },
+    }, {})
+    eval.perform_eval()
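
The torch.stack(..., dim=1) calls in get_l2_score exist because the default DataLoader collation turns a per-sample (y, x) coordinate tuple into a list of two (B,) tensors; stacking on dim=1 rebuilds the (B, 2) positions that Segformer.forward expects. A minimal sketch of that behaviour (the tuple-per-sample layout of ImagePairWithCorrespondingPointsDataset is assumed from the stacking above):

    import torch
    from torch.utils.data import DataLoader

    samples = [{'coords1': (10, 20)}, {'coords1': (30, 40)}]
    batch = next(iter(DataLoader(samples, batch_size=2)))
    print(batch['coords1'])                        # [tensor([10, 30]), tensor([20, 40])]
    print(torch.stack(batch['coords1'], dim=1))    # tensor([[10, 20], [30, 40]])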