Get segformer to a trainable state
This commit is contained in:
parent
23e01314d4
commit
9bbe6fc81e
|
@ -37,6 +37,8 @@ class ImageFolderDataset:
|
|||
if 'normalize' in opt.keys():
|
||||
if opt['normalize'] == 'stylegan2_norm':
|
||||
self.normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
elif opt['normalize'] == 'imagenet':
|
||||
self.normalize = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), inplace=True)
|
||||
else:
|
||||
raise Exception('Unsupported normalize')
|
||||
else:
|
||||
|
|
|
@ -51,6 +51,54 @@ def set_requires_grad(model, val):
|
|||
p.requires_grad = val
|
||||
|
||||
|
||||
# Specialized augmentor class that applies a set of image transformations on points as well, allowing one to track
|
||||
# where a point in the src image is located in the dest image. Restricts transformation such that this is possible.
|
||||
class PointwiseAugmentor(nn.Module):
|
||||
def __init__(self, img_size=224):
|
||||
super().__init__()
|
||||
self.jitter = RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8)
|
||||
self.gray = augs.RandomGrayscale(p=0.2)
|
||||
self.blur = RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)
|
||||
self.rrc = augs.RandomResizedCrop((img_size, img_size), same_on_batch=True)
|
||||
|
||||
# Given a point in the source image, returns the same point in the source image, given the kornia RRC params.
|
||||
def rrc_on_point(self, src_point, params):
|
||||
dh, dw = params['dst'][:,2,1]-params['dst'][:,0,1], params['dst'][:,2,0] - params['dst'][:,0,0]
|
||||
sh, sw = params['src'][:,2,1]-params['src'][:,0,1], params['src'][:,2,0] - params['src'][:,0,0]
|
||||
scale_h, scale_w = sh.float() / dh.float(), sw.float() / dw.float()
|
||||
t, l = src_point[0] - params['src'][0,0,1], src_point[1] - params['src'][0,0,0]
|
||||
t = (t.float() / scale_h[0]).long()
|
||||
l = (l.float() / scale_w[0]).long()
|
||||
return torch.stack([t,l])
|
||||
|
||||
def flip_on_point(self, pt, input):
|
||||
t, l = pt[0], pt[1]
|
||||
center = input.shape[-1] // 2
|
||||
return t, 2 * center - l
|
||||
|
||||
def forward(self, x, point):
|
||||
d = self.jitter(x)
|
||||
d = self.gray(d)
|
||||
will_flip = random.random() > .5
|
||||
if will_flip:
|
||||
d = apply_hflip(d)
|
||||
point = self.flip_on_point(point, x)
|
||||
d = self.blur(d)
|
||||
|
||||
invalid = True
|
||||
while invalid:
|
||||
params = self.rrc.generate_parameters(d.shape)
|
||||
potential = self.rrc_on_point(point, params)
|
||||
# '10' is an arbitrary number: we want to provide some margin. Making predictions at the very edge of an image is not very useful.
|
||||
if potential[0] <= 10 or potential[1] <= 10 or potential[0] > x.shape[-2]-10 or potential[1] > x.shape[-1]-10:
|
||||
continue
|
||||
d = self.rrc(d, params=params)
|
||||
point = potential
|
||||
invalid = False
|
||||
|
||||
return d, point
|
||||
|
||||
|
||||
# loss fn
|
||||
def loss_fn(x, y):
|
||||
x = F.normalize(x, dim=-1, p=2)
|
||||
|
@ -160,21 +208,21 @@ class NetWrapper(nn.Module):
|
|||
projector = MLP(dim, self.projection_size, self.projection_hidden_size)
|
||||
return projector.to(hidden)
|
||||
|
||||
def get_representation(self, x):
|
||||
def get_representation(self, x, pt):
|
||||
if self.layer == -1:
|
||||
return self.net(x)
|
||||
return self.net(x, pt)
|
||||
|
||||
if not self.hook_registered:
|
||||
self._register_hook()
|
||||
|
||||
unused = self.net(x)
|
||||
unused = self.net(x, pt)
|
||||
hidden = self.hidden
|
||||
self.hidden = None
|
||||
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
|
||||
return hidden
|
||||
|
||||
def forward(self, x):
|
||||
representation = self.get_representation(x)
|
||||
def forward(self, x, pt):
|
||||
representation = self.get_representation(x, pt)
|
||||
projector = self._get_projector(representation)
|
||||
projection = checkpoint(projector, representation)
|
||||
return projection
|
||||
|
@ -191,24 +239,13 @@ class BYOL(nn.Module):
|
|||
moving_average_decay=0.99,
|
||||
use_momentum=True,
|
||||
structural_mlp=False,
|
||||
do_augmentation=False # In DLAS this was intended to be done at the dataset level. For massive batch sizes
|
||||
# this can overwhelm the CPU though, and it becomes desirable to do the augmentations
|
||||
# on the GPU again.
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
|
||||
use_structural_mlp=structural_mlp)
|
||||
|
||||
self.do_aug = do_augmentation
|
||||
if self.do_aug:
|
||||
augmentations = [ \
|
||||
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
|
||||
augs.RandomGrayscale(p=0.2),
|
||||
augs.RandomHorizontalFlip(),
|
||||
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
|
||||
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
|
||||
self.aug = nn.Sequential(*augmentations)
|
||||
self.aug = PointwiseAugmentor(image_size)
|
||||
self.use_momentum = use_momentum
|
||||
self.target_encoder = None
|
||||
self.target_ema_updater = EMA(moving_average_decay)
|
||||
|
@ -220,8 +257,7 @@ class BYOL(nn.Module):
|
|||
self.to(device)
|
||||
|
||||
# send a mock image tensor to instantiate singleton parameters
|
||||
self.forward(torch.randn(2, 3, image_size, image_size, device=device),
|
||||
torch.randn(2, 3, image_size, image_size, device=device))
|
||||
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
|
||||
|
||||
@singleton('target_encoder')
|
||||
def _get_target_encoder(self):
|
||||
|
@ -245,28 +281,32 @@ class BYOL(nn.Module):
|
|||
return {'target_ema_beta': self.target_ema_updater.beta}
|
||||
|
||||
def visual_dbg(self, step, path):
|
||||
if self.do_aug:
|
||||
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
|
||||
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
|
||||
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
|
||||
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
|
||||
|
||||
def forward(self, image_one, image_two):
|
||||
if self.do_aug:
|
||||
image_one = self.aug(image_one)
|
||||
image_two = self.aug(image_two)
|
||||
# Keep copies on hand for visual_dbg.
|
||||
self.im1 = image_one.detach().copy()
|
||||
self.im2 = image_two.detach().copy()
|
||||
def forward(self, image):
|
||||
_, _, h, w = image.shape
|
||||
point = torch.randint(h//8, 7*h//8, (2,)).long().to(image.device)
|
||||
|
||||
online_proj_one = self.online_encoder(image_one)
|
||||
online_proj_two = self.online_encoder(image_two)
|
||||
image_one, pt_one = self.aug(image, point)
|
||||
image_two, pt_two = self.aug(image, point)
|
||||
|
||||
# Keep copies on hand for visual_dbg.
|
||||
self.im1 = image_one.detach().clone()
|
||||
self.im1[:,:,pt_one[0]-3:pt_one[0]+3,pt_one[1]-3:pt_one[1]+3] = 1
|
||||
self.im2 = image_two.detach().clone()
|
||||
self.im2[:,:,pt_two[0]-3:pt_two[0]+3,pt_two[1]-3:pt_two[1]+3] = 1
|
||||
|
||||
online_proj_one = self.online_encoder(image_one, pt_one)
|
||||
online_proj_two = self.online_encoder(image_two, pt_two)
|
||||
|
||||
online_pred_one = self.online_predictor(online_proj_one)
|
||||
online_pred_two = self.online_predictor(online_proj_two)
|
||||
|
||||
with torch.no_grad():
|
||||
target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
|
||||
target_proj_one = target_encoder(image_one).detach()
|
||||
target_proj_two = target_encoder(image_two).detach()
|
||||
target_proj_one = target_encoder(image_one, pt_one).detach()
|
||||
target_proj_two = target_encoder(image_two, pt_two).detach()
|
||||
|
||||
loss_one = loss_fn(online_pred_one, target_proj_two.detach())
|
||||
loss_two = loss_fn(online_pred_two, target_proj_one.detach())
|
||||
|
@ -275,53 +315,20 @@ class BYOL(nn.Module):
|
|||
return loss.mean()
|
||||
|
||||
|
||||
class PointwiseAugmentor(nn.Module):
|
||||
def __init__(self, img_size=224):
|
||||
super().__init__()
|
||||
self.jitter = RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8)
|
||||
self.gray = augs.RandomGrayscale(p=0.2)
|
||||
self.blur = RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1)
|
||||
self.rrc = augs.RandomResizedCrop((img_size, img_size))
|
||||
|
||||
# Given a point in the *destination* image, returns the same point in the source image, given the kornia RRC params.
|
||||
def reverse_rrc(self, dest_point, params):
|
||||
dh, dw = params['dst'][:,2,1]-params['dst'][:,0,1], params['dst'][:,2,0] - params['dst'][:,0,0]
|
||||
sh, sw = params['src'][:,2,1]-params['src'][:,0,1], params['src'][:,2,0] - params['src'][:,0,0]
|
||||
scale_h, scale_w = sh.float() / dh.float(), sw.float() / dw.float()
|
||||
t, l = dest_point
|
||||
t = (t.float() * scale_h).int()
|
||||
l = (l.float() * scale_w).int()
|
||||
return t + params['src'][:,0,1], l + params['src'][:,0,0]
|
||||
|
||||
def reverse_horizontal_flip(self, pt, input):
|
||||
t, l = pt
|
||||
center = input.shape[-1] // 2
|
||||
return t, 2 * center - l
|
||||
|
||||
def forward(self, x, points):
|
||||
d = self.jitter(x)
|
||||
d = self.gray(d)
|
||||
will_flip = random.random() > .5
|
||||
if will_flip:
|
||||
d = apply_hflip(d)
|
||||
d = self.blur(d)
|
||||
params = self.rrc.generate_parameters(d.shape)
|
||||
d = self.rrc(d, params=params)
|
||||
|
||||
rev = self.reverse_rrc(points, params)
|
||||
if will_flip:
|
||||
rev = self.reverse_horizontal_flip(rev, x)
|
||||
|
||||
if __name__ == '__main__':
|
||||
p = PointwiseAugmentor(256)
|
||||
t = ToTensor()(Image.open('E:\\4k6k\\datasets\\ns_images\\imagesets\\000001_152761.jpg')).unsqueeze(0).repeat(8,1,1,1)
|
||||
points = (torch.randint(0,224,(t.shape[0],)),torch.randint(0,224,(t.shape[0],)))
|
||||
p(t, points)
|
||||
pa = PointwiseAugmentor(256)
|
||||
for j in range(100):
|
||||
t = ToTensor()(Image.open('E:\\4k6k\\datasets\\ns_images\\imagesets\\000001_152761.jpg')).unsqueeze(0).repeat(8,1,1,1)
|
||||
p = torch.randint(50,180,(2,))
|
||||
augmented, dp = pa(t, p)
|
||||
t, p = pa(t, p)
|
||||
t[:,:,p[0]-3:p[0]+3,p[1]-3:p[1]+3] = 0
|
||||
torchvision.utils.save_image(t, f"{j}_src.png")
|
||||
augmented[:,:,dp[0]-3:dp[0]+3,dp[1]-3:dp[1]+3] = 0
|
||||
torchvision.utils.save_image(augmented, f"{j}_dst.png")
|
||||
|
||||
|
||||
@register_model
|
||||
def register_byol(opt_net, opt):
|
||||
def register_pixel_local_byol(opt_net, opt):
|
||||
subnet = create_model(opt, opt_net['subnet'])
|
||||
return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
||||
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
|
||||
do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))
|
||||
return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'])
|
0
codes/models/segformer/__init__.py
Normal file
0
codes/models/segformer/__init__.py
Normal file
|
@ -2,11 +2,30 @@ import math
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
from tqdm import tqdm
|
||||
|
||||
from models.segformer.backbone import backbone50
|
||||
|
||||
|
||||
# torch.gather() which operates as it always fucking should have: pulling indexes from the input.
|
||||
from trainer.networks import register_model
|
||||
|
||||
|
||||
def gather_2d(input, index):
|
||||
b, c, h, w = input.shape
|
||||
nodim = input.view(b, c, h * w)
|
||||
ind_nd = index[:, 0]*w + index[:, 1]
|
||||
ind_nd = ind_nd.unsqueeze(1)
|
||||
ind_nd = ind_nd.repeat((1, c))
|
||||
ind_nd = ind_nd.unsqueeze(2)
|
||||
result = torch.gather(nodim, dim=2, index=ind_nd)
|
||||
result = result.squeeze()
|
||||
if b == 1:
|
||||
result = result.unsqueeze(0)
|
||||
return result
|
||||
|
||||
|
||||
class DilatorModule(nn.Module):
|
||||
def __init__(self, input_channels, output_channels, max_dilation):
|
||||
super().__init__()
|
||||
|
@ -15,7 +34,7 @@ class DilatorModule(nn.Module):
|
|||
if max_dilation > 1:
|
||||
self.bn = nn.BatchNorm2d(input_channels)
|
||||
self.relu = nn.ReLU()
|
||||
self.conv2 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, dilation=max_dilation, bias=True)
|
||||
self.conv2 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=max_dilation, dilation=max_dilation, bias=True)
|
||||
self.dense = nn.Linear(input_channels, output_channels, bias=True)
|
||||
|
||||
def forward(self, inp, loc):
|
||||
|
@ -24,9 +43,8 @@ class DilatorModule(nn.Module):
|
|||
x = self.bn(self.relu(x))
|
||||
x = self.conv2(x)
|
||||
|
||||
# This can be made (possibly substantially) more efficient by only computing these convolutions across a subset of the image. Possibly.
|
||||
i, j = loc
|
||||
x = x[:,:,i,j]
|
||||
# This can be made more efficient by only computing these convolutions across a subset of the image. Possibly.
|
||||
x = gather_2d(x, loc).contiguous()
|
||||
return self.dense(x)
|
||||
|
||||
|
||||
|
@ -48,13 +66,22 @@ class PositionalEncoding(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class Segformer(nn.Module):
|
||||
# Simple mean() layer encoded into a class so that BYOL can grab it.
|
||||
class Tail(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x.mean(dim=0)
|
||||
|
||||
|
||||
class Segformer(nn.Module):
|
||||
def __init__(self, latent_channels=1024, layers=8):
|
||||
super().__init__()
|
||||
self.backbone = backbone50()
|
||||
backbone_channels = [256, 512, 1024, 2048]
|
||||
dilations = [[1,2,3,4],[1,2,3],[1,2],[1]]
|
||||
final_latent_channels = 2048
|
||||
final_latent_channels = latent_channels
|
||||
dilators = []
|
||||
for ic, dis in zip(backbone_channels, dilations):
|
||||
layer_dilators = []
|
||||
|
@ -64,26 +91,37 @@ class Segformer(nn.Module):
|
|||
self.dilators = nn.ModuleList(dilators)
|
||||
|
||||
self.token_position_encoder = PositionalEncoding(final_latent_channels, max_len=10)
|
||||
self.transformer_layers = nn.Sequential(*[nn.TransformerEncoderLayer(final_latent_channels, nhead=4) for _ in range(16)])
|
||||
self.transformer_layers = nn.Sequential(*[nn.TransformerEncoderLayer(final_latent_channels, nhead=4) for _ in range(layers)])
|
||||
self.tail = Tail()
|
||||
|
||||
def forward(self, x, pos):
|
||||
layers = self.backbone(x)
|
||||
set = []
|
||||
i, j = pos[0] // 4, pos[1] // 4
|
||||
|
||||
# A single position can be optionally given, in which case we need to expand it to represent the entire input.
|
||||
if pos.shape == (2,):
|
||||
pos = pos.unsqueeze(0).repeat(x.shape[0],1)
|
||||
|
||||
pos = pos // 4
|
||||
for layer_out, dilator in zip(layers, self.dilators):
|
||||
for subdilator in dilator:
|
||||
set.append(subdilator(layer_out, (i, j)))
|
||||
i, j = i // 2, j // 2
|
||||
set.append(subdilator(layer_out, pos))
|
||||
pos = pos // 2
|
||||
|
||||
# The torch transformer expects the set dimension to be 0.
|
||||
set = torch.stack(set, dim=0)
|
||||
set = self.token_position_encoder(set)
|
||||
set = self.transformer_layers(set)
|
||||
return set
|
||||
return self.tail(set)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_segformer(opt_net, opt):
|
||||
return Segformer()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = Segformer().to('cuda')
|
||||
for j in tqdm(range(1000)):
|
||||
test_tensor = torch.randn(64,3,224,224).cuda()
|
||||
model(test_tensor, (43, 73))
|
||||
print(model(test_tensor, torch.randint(0,224,(64,2)).cuda()).shape)
|
|
@ -295,7 +295,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_cats_stylegan2_rosinality.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_segformer_xx.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -11,6 +11,7 @@ import trainer.eval.evaluator as evaluator
|
|||
from pytorch_fid import fid_score
|
||||
|
||||
from data.image_pair_with_corresponding_points_dataset import ImagePairWithCorrespondingPointsDataset
|
||||
from models.segformer.segformer import Segformer
|
||||
from utils.util import opt_get
|
||||
|
||||
# Uses two datasets: a "similar" and "dissimilar" dataset, each of which contains pairs of images and similar/dissimilar
|
||||
|
@ -23,15 +24,17 @@ class SinglePointPairContrastiveEval(evaluator.Evaluator):
|
|||
self.batch_sz = opt_eval['batch_size']
|
||||
self.eval_qty = opt_eval['quantity']
|
||||
assert self.eval_qty % self.batch_sz == 0
|
||||
self.similar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(**opt_eval['similar_set_args']), shuffle=False, batch_size=self.batch_sz)
|
||||
self.dissimilar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(**opt_eval['dissimilar_set_args']), shuffle=False, batch_size=self.batch_sz)
|
||||
self.similar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(opt_eval['similar_set_args']), shuffle=False, batch_size=self.batch_sz)
|
||||
self.dissimilar_set = DataLoader(ImagePairWithCorrespondingPointsDataset(opt_eval['dissimilar_set_args']), shuffle=False, batch_size=self.batch_sz)
|
||||
# Hack to make this work with the BYOL generator. TODO: fix
|
||||
self.model = self.model.online_encoder.net
|
||||
|
||||
def get_l2_score(self, dl):
|
||||
def get_l2_score(self, dl, dev):
|
||||
distances = []
|
||||
l2 = MSELoss()
|
||||
for i, data in tqdm(enumerate(dl)):
|
||||
latent1 = self.model(data['img1'], data['coords1'])
|
||||
latent2 = self.model(data['img2'], data['coords2'])
|
||||
latent1 = self.model(data['img1'].to(dev), torch.stack(data['coords1'], dim=1).to(dev))
|
||||
latent2 = self.model(data['img2'].to(dev), torch.stack(data['coords2'], dim=1).to(dev))
|
||||
distances.append(l2(latent1, latent2))
|
||||
if i * self.batch_sz >= self.eval_qty:
|
||||
break
|
||||
|
@ -40,9 +43,30 @@ class SinglePointPairContrastiveEval(evaluator.Evaluator):
|
|||
|
||||
def perform_eval(self):
|
||||
self.model.eval()
|
||||
print("Computing contrastive eval on similar set")
|
||||
similars = self.get_l2_score(self.similar_set)
|
||||
print("Computing contrastive eval on dissimilar set")
|
||||
dissimilars = self.get_l2_score(self.dissimilar_set)
|
||||
print(f"Eval done. val_similar_lq: {similars.item()}; val_dissimilar_l2: {dissimilars.item()}")
|
||||
return {"val_similar_l2": similars.item(), "val_dissimilar_l2": dissimilars.item()}
|
||||
with torch.no_grad():
|
||||
dev = next(self.model.parameters()).device
|
||||
print("Computing contrastive eval on similar set")
|
||||
similars = self.get_l2_score(self.similar_set, dev)
|
||||
print("Computing contrastive eval on dissimilar set")
|
||||
dissimilars = self.get_l2_score(self.dissimilar_set, dev)
|
||||
diff = dissimilars.item() - similars.item()
|
||||
print(f"Eval done. val_similar_lq: {similars.item()}; val_dissimilar_l2: {dissimilars.item()}; val_diff: {diff}")
|
||||
self.model.train()
|
||||
return {"val_similar_l2": similars.item(), "val_dissimilar_l2": dissimilars.item(), "val_diff": diff.item()}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = Segformer(1024, 4).cuda()
|
||||
eval = SinglePointPairContrastiveEval(model, {
|
||||
'batch_size': 8,
|
||||
'quantity': 32,
|
||||
'similar_set_args': {
|
||||
'path': 'E:\\4k6k\\datasets\\ns_images\\segformer_validation\\similar',
|
||||
'size': 256
|
||||
},
|
||||
'dissimilar_set_args': {
|
||||
'path': 'E:\\4k6k\\datasets\\ns_images\\segformer_validation\\dissimilar',
|
||||
'size': 256
|
||||
},
|
||||
}, {})
|
||||
eval.perform_eval()
|
||||
|
|
Loading…
Reference in New Issue
Block a user