Add testing capabilities for segformer & contrastive feature

James Betker 2021-04-27 09:59:50 -06:00
parent eab8546f73
commit 119f17c808
6 changed files with 373 additions and 23 deletions

View File

@@ -208,21 +208,21 @@ class NetWrapper(nn.Module):
projector = MLP(dim, self.projection_size, self.projection_hidden_size)
return projector.to(hidden)
def get_representation(self, x, pt):
def get_representation(self, **kwargs):
if self.layer == -1:
return self.net(x, pt)
return self.net(**kwargs)
if not self.hook_registered:
self._register_hook()
unused = self.net(x, pt)
unused = self.net(**kwargs)
hidden = self.hidden
self.hidden = None
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden
def forward(self, x, pt):
representation = self.get_representation(x, pt)
def forward(self, **kwargs):
representation = self.get_representation(**kwargs)
projector = self._get_projector(representation)
projection = checkpoint(projector, representation)
return projection
@@ -239,6 +239,7 @@ class BYOL(nn.Module):
moving_average_decay=0.99,
use_momentum=True,
structural_mlp=False,
contrastive=False,
):
super().__init__()
@@ -247,6 +248,7 @@ class BYOL(nn.Module):
self.aug = PointwiseAugmentor(image_size)
self.use_momentum = use_momentum
self.contrastive = contrastive
self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay)
@@ -278,13 +280,17 @@ class BYOL(nn.Module):
def get_debug_values(self, step, __):
# In the BYOL paper, this is made to increase over time. Not yet implemented, but still logging the value.
return {'target_ema_beta': self.target_ema_updater.beta}
dbg = {'target_ema_beta': self.target_ema_updater.beta}
if self.contrastive and hasattr(self, 'logs_closs'):
dbg['contrastive_distance'] = self.logs_closs
dbg['byol_distance'] = self.logs_loss
return dbg
def visual_dbg(self, step, path):
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
def forward(self, image):
def get_predictions_and_projections(self, image):
_, _, h, w = image.shape
point = torch.randint(h//8, 7*h//8, (2,)).long().to(image.device)
@@ -297,16 +303,20 @@ class BYOL(nn.Module):
self.im2 = image_two.detach().clone()
self.im2[:,:,pt_two[0]-3:pt_two[0]+3,pt_two[1]-3:pt_two[1]+3] = 1
online_proj_one = self.online_encoder(image_one, pt_one)
online_proj_two = self.online_encoder(image_two, pt_two)
online_proj_one = self.online_encoder(img=image_one, pos=pt_one)
online_proj_two = self.online_encoder(img=image_two, pos=pt_two)
online_pred_one = self.online_predictor(online_proj_one)
online_pred_two = self.online_predictor(online_proj_two)
with torch.no_grad():
target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
target_proj_one = target_encoder(image_one, pt_one).detach()
target_proj_two = target_encoder(image_two, pt_two).detach()
target_proj_one = target_encoder(img=image_one, pos=pt_one).detach()
target_proj_two = target_encoder(img=image_two, pos=pt_two).detach()
return online_pred_one, online_pred_two, target_proj_one, target_proj_two
def forward_normal(self, image):
online_pred_one, online_pred_two, target_proj_one, target_proj_two = self.get_predictions_and_projections(image)
loss_one = loss_fn(online_pred_one, target_proj_two.detach())
loss_two = loss_fn(online_pred_two, target_proj_one.detach())
@@ -314,6 +324,35 @@ class BYOL(nn.Module):
loss = loss_one + loss_two
return loss.mean()
def forward_contrastive(self, image):
online_pred_one_1, online_pred_two_1, target_proj_one_1, target_proj_two_1 = self.get_predictions_and_projections(image)
loss_one = loss_fn(online_pred_one_1, target_proj_two_1.detach())
loss_two = loss_fn(online_pred_two_1, target_proj_one_1.detach())
loss = loss_one + loss_two
online_pred_one_2, online_pred_two_2, target_proj_one_2, target_proj_two_2 = self.get_predictions_and_projections(image)
loss_one = loss_fn(online_pred_one_2, target_proj_two_2.detach())
loss_two = loss_fn(online_pred_two_2, target_proj_one_2.detach())
loss = (loss + loss_one + loss_two).mean()
contrastive_loss = torch.cat([loss_fn(online_pred_one_1, target_proj_two_2),
loss_fn(online_pred_two_1, target_proj_one_2),
loss_fn(online_pred_one_2, target_proj_two_1),
loss_fn(online_pred_two_2, target_proj_one_1)], dim=0)
k = contrastive_loss.shape[0] // 2 # Keep only the largest half of the cross-pass distances.
contrastive_loss = torch.topk(contrastive_loss, k, dim=0).values.mean()
self.logs_loss = loss.detach()
self.logs_closs = contrastive_loss.detach()
return loss - contrastive_loss
def forward(self, image):
if self.contrastive:
return self.forward_contrastive(image)
else:
return self.forward_normal(image)
if __name__ == '__main__':
pa = PointwiseAugmentor(256)
@@ -331,4 +370,4 @@ if __name__ == '__main__':
@register_model
def register_pixel_local_byol(opt_net, opt):
subnet = create_model(opt, opt_net['subnet'])
return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'])
return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], contrastive=opt_net['contrastive'])
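The contrastive branch added above runs the augment/project pipeline twice per batch (each call to get_predictions_and_projections samples its own random point), keeps the standard BYOL loss within each pass, and subtracts the mean of the largest half of the cross-pass distances so that projections taken at different points end up farther apart. The following is a minimal, self-contained sketch of just that loss arithmetic, with random tensors standing in for the real predictions/projections and loss_fn assumed to be the usual BYOL regression loss; all names and shapes here are illustrative, not the repository's code:
import torch
import torch.nn.functional as F
def loss_fn(x, y):
    # Standard BYOL regression loss: 2 - 2 * cosine similarity.
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)
b, c = 8, 256  # illustrative batch size and projection width
pred_1a, pred_1b, proj_1a, proj_1b = (torch.randn(b, c) for _ in range(4))  # pass 1
pred_2a, pred_2b, proj_2a, proj_2b = (torch.randn(b, c) for _ in range(4))  # pass 2
# BYOL term: matched predictions/projections within each pass.
byol = (loss_fn(pred_1a, proj_1b) + loss_fn(pred_1b, proj_1a)
        + loss_fn(pred_2a, proj_2b) + loss_fn(pred_2b, proj_2a)).mean()
# Contrastive term: cross-pass distances, keeping only the largest half (top-k).
cross = torch.cat([loss_fn(pred_1a, proj_2b), loss_fn(pred_1b, proj_2a),
                   loss_fn(pred_2a, proj_1b), loss_fn(pred_2b, proj_1a)], dim=0)
contrastive = torch.topk(cross, cross.shape[0] // 2, dim=0).values.mean()
total = byol - contrastive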

View File

@@ -94,14 +94,21 @@ class Segformer(nn.Module):
self.transformer_layers = nn.Sequential(*[nn.TransformerEncoderLayer(final_latent_channels, nhead=4) for _ in range(layers)])
self.tail = Tail()
def forward(self, x, pos):
layers = self.backbone(x)
set = []
def forward(self, img=None, layers=None, pos=None, return_layers=False):
assert img is not None or layers is not None
if img is not None:
bs = img.shape[0]
layers = self.backbone(img)
else:
bs = layers[0].shape[0]
if return_layers:
return layers
# A single position may optionally be given, in which case it is broadcast across the batch.
if pos.shape == (2,):
pos = pos.unsqueeze(0).repeat(x.shape[0],1)
pos = pos.unsqueeze(0).repeat(bs, 1)
set = []
pos = pos // 4
for layer_out, dilator in zip(layers, self.dilators):
for subdilator in dilator:
@@ -124,4 +131,4 @@ if __name__ == '__main__':
model = Segformer().to('cuda')
for j in tqdm(range(1000)):
test_tensor = torch.randn(64,3,224,224).cuda()
print(model(test_tensor, torch.randint(0,224,(64,2)).cuda()).shape)
print(model(img=test_tensor, pos=torch.randint(0,224,(64,2)).cuda()).shape)
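For reference, a short usage sketch of the reworked Segformer.forward() signature: an image pass with return_layers=True returns the backbone's feature pyramid, which can then be fed back in via layers= to query several positions without re-running the backbone, and a bare (2,) position is broadcast across the batch. This is a sketch only, assuming the Segformer class above is importable and running on CPU for brevity (the in-repo test uses CUDA):
import torch
from models.segformer.segformer import Segformer
model = Segformer().eval()
img = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    # Run the backbone once and cache its feature pyramid.
    feats = model(img=img, return_layers=True)
    # Query several positions against the cached features; a single (2,)
    # position is broadcast across the batch inside forward().
    for pos in torch.randint(16, 208, (3, 2)):
        latent = model(layers=feats, pos=pos)
        print(latent.shape)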

View File

@@ -0,0 +1,240 @@
import os
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Resize
from tqdm import tqdm
import numpy as np
import utils
from data.image_folder_dataset import ImageFolderDataset
from models.resnet_with_checkpointing import resnet50
from models.segformer.segformer import Segformer
from models.spinenet_arch import SpineNet
from utils import util
from utils.kmeans import kmeans, kmeans_predict
from utils.options import dict_to_nonedict
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
# and the distance is computed across the channel dimension.
def structural_euc_dist(x, y):
diff = torch.square(x - y)
sum = torch.sum(diff, dim=-1)
return torch.sqrt(sum)
def cosine_similarity(x, y):
x = norm(x)
y = norm(y)
return -nn.CosineSimilarity()(x, y) # Probably better to use nn.CosineSimilarity directly; kept here as a reminder.
def key_value_difference(x, y):
x = F.normalize(x, dim=-1, p=2)
y = F.normalize(y, dim=-1, p=2)
return 2 - 2 * (x * y).sum(dim=-1)
def norm(x):
sh = x.shape
sh_r = tuple([sh[i] if i != len(sh)-1 else 1 for i in range(len(sh))])
return (x - torch.mean(x, dim=-1).reshape(sh_r)) / torch.std(x, dim=-1).reshape(sh_r)
def im_norm(x):
return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
def get_image_folder_dataloader(batch_size, num_workers, target_size=224, shuffle=True):
dataset_opt = dict_to_nonedict({
'name': 'amalgam',
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped2'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_tiled_filtered_flattened'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
'weights': [1],
'target_size': target_size,
'force_multiple': 32,
'normalize': 'imagenet',
'scale': 1
})
dataset = ImageFolderDataset(dataset_opt)
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
def _find_layer(net, layer_name):
if type(layer_name) == str:
modules = dict([*net.named_modules()])
return modules.get(layer_name, None)
elif type(layer_name) == int:
children = [*net.children()]
return children[layer_name]
return None
layer_hooked_value = None
def _hook(_, __, output):
global layer_hooked_value
layer_hooked_value = output
def register_hook(net, layer_name):
layer = _find_layer(net, layer_name)
assert layer is not None, f'hidden layer ({layer_name}) not found'
layer.register_forward_hook(_hook)
def get_latent_for_img(model, img):
img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0)
_, _, h, w = img_t.shape
# Center crop img_t and resize to 224.
d = min(h, w)
dh, dw = (h-d)//2, (w-d)//2
if dw != 0:
img_t = img_t[:, :, :, dw:-dw]
elif dh != 0:
img_t = img_t[:, :, dh:-dh, :]
img_t = img_t[:,:3,:,:]
img_t = torch.nn.functional.interpolate(img_t, size=(224, 224), mode="area")
model(img_t)
latent = layer_hooked_value
return latent
def produce_latent_dict(model):
batch_size = 32
num_workers = 4
dataloader = get_image_folder_dataloader(batch_size, num_workers)
id = 0
paths = []
latents = []
points = []
for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda')
# Pull several points from every image.
for k in range(10):
_, _, h, _ = hq.shape
point = torch.randint(h//4, 3*h//4, (2,)).long().to(hq.device)
model(img=hq, pos=point)
l = layer_hooked_value.cpu().split(1, dim=0)
latents.extend(l)
points.extend([point for p in range(batch_size)])
paths.extend(batch['HQ_path'])
id += batch_size
if id > 10000:
print("Saving checkpoint..")
torch.save((latents, points, paths), '../results_segformer.pth')
id = 0
def find_similar_latents(model, compare_fn=structural_euc_dist):
global layer_hooked_value
img = 'D:\\dlas\\results\\bobz.png'
#img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
output_path = '../../../results/byol_resnet_similars'
os.makedirs(output_path, exist_ok=True)
imglatent = get_latent_for_img(model, img).squeeze().unsqueeze(0)
_, c = imglatent.shape
batch_size = 512
num_workers = 8
dataloader = get_image_folder_dataloader(batch_size, num_workers)
id = 0
output_batch = 1
results = []
result_paths = []
for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda')
model(hq)
latent = layer_hooked_value.clone().squeeze()
compared = compare_fn(imglatent.repeat(latent.shape[0], 1), latent)
results.append(compared.cpu())
result_paths.extend(batch['HQ_path'])
id += batch_size
if id > 10000:
k = 200
results = torch.cat(results, dim=0)
vals, inds = torch.topk(results, k, largest=False)
for i in inds:
mag = int(results[i].item() * 1000)
shutil.copy(result_paths[i], os.path.join(output_path, f'{mag:05}_{output_batch}_{i}.jpg'))
results = []
result_paths = []
id = 0
def build_kmeans():
latents, _, _ = torch.load('../results_segformer.pth')
latents = torch.cat(latents, dim=0).squeeze().to('cuda')
cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=16, distance="euclidean", device=torch.device('cuda:0'))
torch.save((cluster_ids_x, cluster_centers), '../k_means_segformer.pth')
class UnNormalize(object):
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, tensor):
"""
Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
Returns:
Tensor: Normalized image.
"""
for t, m, s in zip(tensor, self.mean, self.std):
t.mul_(s).add_(m)
# The normalize code -> t.sub_(m).div_(s)
return tensor
def use_kmeans():
output = "../results/k_means_segformer/"
_, centers = torch.load('../k_means_segformer.pth')
centers = centers.to('cuda')
batch_size = 32
num_workers = 1
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=224, shuffle=True)
denorm = UnNormalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
for i, batch in enumerate(tqdm(dataloader)):
hq = batch['hq'].to('cuda')
_,_,h,w = hq.shape
point = torch.randint(h//4, 3*h//4, (2,)).long().to(hq.device)
model(img=hq, pos=point)
l = layer_hooked_value.clone().squeeze()
pred = kmeans_predict(l, centers)
hq = denorm(hq * .5)
hq[:,:,point[0]-5:point[0]+5,point[1]-5:point[1]+5] *= 2
for b in range(pred.shape[0]):
outpath = os.path.join(output, str(pred[b].item()))
os.makedirs(outpath, exist_ok=True)
torchvision.utils.save_image(hq[b], os.path.join(outpath, f'{i*batch_size+b}.png'))
if __name__ == '__main__':
pretrained_path = '../../../experiments/segformer_byol_only.pth'
model = Segformer().to('cuda')
sd = torch.load(pretrained_path)
resnet_sd = {}
for k, v in sd.items():
if 'target_encoder.net.' in k:
resnet_sd[k.replace('target_encoder.net.', '')] = v
model.load_state_dict(resnet_sd, strict=True)
model.eval()
register_hook(model, 'tail')
with torch.no_grad():
#find_similar_latents(model, structural_euc_dist)
#produce_latent_dict(model)
#build_kmeans()
use_kmeans()
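The script above reads latents through a module-level forward hook on the model's 'tail' submodule (stored in layer_hooked_value). A self-contained sketch of that hook pattern on a toy network, with illustrative names throughout:
import torch
import torch.nn as nn
captured = {}
def make_hook(name):
    def _hook(module, inputs, output):
        # Stash the submodule's output so it can be read after the forward pass.
        captured[name] = output.detach()
    return _hook
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 4, 3, padding=1))
dict(net.named_modules())['2'].register_forward_hook(make_hook('tail'))
with torch.no_grad():
    net(torch.randn(1, 3, 32, 32))
print(captured['tail'].shape)  # torch.Size([1, 4, 32, 32])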

View File

@@ -344,11 +344,75 @@ def plot_pixel_level_results_as_image_graph():
pyplot.savefig('tsne_pix.pdf')
def run_tsne_segformer():
print("Run Y = tsne.tsne(X, no_dims, perplexity) to perform t-SNE on your dataset.")
limit = 10000
X, points, files = torch.load('../results_segformer.pth')
zipped = list(zip(X, points, files))
shuffle(zipped)
X, points, files = zip(*zipped)
X = torch.cat(X, dim=0).squeeze()[:limit]
labels = np.zeros(X.shape[0]) # We don't have any labels..
# Confirm that X and labels contain the same number of points;
# otherwise the scatter call below will fail.
assert(len(X[:, 0])==len(X[:,1]))
assert(len(X)==len(labels))
with torch.no_grad():
Y = tsne(X, 2, 1024, 20.0)
if opt.cuda:
Y = Y.cpu().numpy()
# Optionally, write the result to two files:
# print("Save Y values in file")
# Y1 = open("y1.txt", 'w')
# Y2 = open('y2.txt', 'w')
# for i in range(Y.shape[0]):
# Y1.write(str(Y[i,0])+"\n")
# Y2.write(str(Y[i,1])+"\n")
pyplot.scatter(Y[:, 0], Y[:, 1], 20, labels)
pyplot.show()
torch.save((Y, points, files[:limit]), "../tsne_output.pth")
# Uses the results from the calculation above to create a **massive** pdf plot that places a 64x64 crop around each
# sampled point at its location on the t-SNE spectrum.
def plot_segformer_results_as_image_graph():
Y, points, files = torch.load('../tsne_output.pth')
fig, ax = pyplot.subplots()
fig.set_size_inches(200,200,forward=True)
ax.update_datalim(np.column_stack([Y[:, 0], Y[:, 1]]))
ax.autoscale()
margins = 32
for b in tqdm(range(Y.shape[0])):
imgfile = files[b]
baseim = pyplot.imread(imgfile)
ct, cl = points[b]
im = baseim[(ct-margins):(ct+margins),
(cl-margins):(cl+margins),:]
im = OffsetImage(im, zoom=1)
ab = AnnotationBbox(im, (Y[b, 0], Y[b, 1]), xycoords='data', frameon=False)
ax.add_artist(ab)
ax.scatter(Y[:, 0], Y[:, 1])
pyplot.savefig('tsne_segformer.pdf')
if __name__ == "__main__":
# For use with instance-level results (e.g. from byol_resnet_playground.py)
#run_tsne_instance_level()
plot_instance_level_results_as_image_graph()
#plot_instance_level_results_as_image_graph()
# For use with pixel-level results (e.g. from byol_uresnet_playground)
#run_tsne_pixel_level()
#plot_pixel_level_results_as_image_graph()
#plot_pixel_level_results_as_image_graph()
# For use with segformer results
#run_tsne_segformer()
plot_segformer_results_as_image_graph()
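plot_segformer_results_as_image_graph() pins a 64x64 crop around each sampled point to its t-SNE coordinate using matplotlib's OffsetImage/AnnotationBbox. A standalone sketch of that plotting pattern with random stand-in data (no real embeddings or image files are used here):
import numpy as np
from matplotlib import pyplot
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
Y = np.random.randn(5, 2)            # stand-in t-SNE coordinates
thumb = np.random.rand(64, 64, 3)    # stand-in crop around a sampled point
fig, ax = pyplot.subplots()
for x, y in Y:
    # Place the thumbnail at its embedding location instead of a scatter marker.
    ax.add_artist(AnnotationBbox(OffsetImage(thumb, zoom=1), (x, y), xycoords='data', frameon=False))
ax.update_datalim(Y)
ax.autoscale()
pyplot.savefig('tsne_demo.pdf')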

View File

@@ -295,7 +295,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_segformer_xx.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_segformer_contrastive_xx.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@@ -33,8 +33,8 @@ class SinglePointPairContrastiveEval(evaluator.Evaluator):
distances = []
l2 = MSELoss()
for i, data in tqdm(enumerate(dl)):
latent1 = self.model(data['img1'].to(dev), torch.stack(data['coords1'], dim=1).to(dev))
latent2 = self.model(data['img2'].to(dev), torch.stack(data['coords2'], dim=1).to(dev))
latent1 = self.model(img=data['img1'].to(dev), pos=torch.stack(data['coords1'], dim=1).to(dev))
latent2 = self.model(img=data['img2'].to(dev), pos=torch.stack(data['coords2'], dim=1).to(dev))
distances.append(l2(latent1, latent2))
if i * self.batch_sz >= self.eval_qty:
break
@@ -52,7 +52,7 @@ class SinglePointPairContrastiveEval(evaluator.Evaluator):
diff = dissimilars.item() - similars.item()
print(f"Eval done. val_similar_lq: {similars.item()}; val_dissimilar_l2: {dissimilars.item()}; val_diff: {diff}")
self.model.train()
return {"val_similar_l2": similars.item(), "val_dissimilar_l2": dissimilars.item(), "val_diff": diff.item()}
return {"val_similar_l2": similars.item(), "val_dissimilar_l2": dissimilars.item(), "val_diff": diff}
if __name__ == '__main__':