Clean up byol a bit

- Remove option to aug in dataset (there's really no reason for this now that kornia works on GPU on windows)
- Other stufff
This commit is contained in:
James Betker 2021-05-24 21:35:46 -06:00
parent 6649ef2dae
commit f129eaa39e
6 changed files with 53 additions and 62 deletions

View File

@ -206,10 +206,6 @@ class BYOL(nn.Module):
dbg['byol_distance'] = self.logs_loss
return dbg
def visual_dbg(self, step, path):
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
def get_predictions_and_projections(self, image):
_, _, h, w = image.shape
point = torch.randint(h//8, 7*h//8, (2,)).long().to(image.device)
@ -217,12 +213,6 @@ class BYOL(nn.Module):
image_one, pt_one = self.aug(image, point)
image_two, pt_two = self.aug(image, point)
# Keep copies on hand for visual_dbg.
self.im1 = image_one.detach().clone()
self.im1[:,:,pt_one[0]-3:pt_one[0]+3,pt_one[1]-3:pt_one[1]+3] = 1
self.im2 = image_two.detach().clone()
self.im2[:,:,pt_two[0]-3:pt_two[0]+3,pt_two[1]-3:pt_two[1]+3] = 1
online_proj_one = self.online_encoder(img=image_one, pos=pt_one)
online_proj_two = self.online_encoder(img=image_two, pos=pt_two)

View File

@ -188,23 +188,18 @@ class BYOL(nn.Module):
moving_average_decay=0.99,
use_momentum=True,
structural_mlp=False,
do_augmentation=False # In DLAS this was intended to be done at the dataset level. For massive batch sizes
# this can overwhelm the CPU though, and it becomes desirable to do the augmentations
# on the GPU again.
):
super().__init__()
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
use_structural_mlp=structural_mlp)
self.do_aug = do_augmentation
if self.do_aug:
augmentations = [ \
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
augs.RandomGrayscale(p=0.2),
augs.RandomHorizontalFlip(),
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
augs.RandomResizedCrop((image_size, image_size))]
self.aug = nn.Sequential(*augmentations)
self.use_momentum = use_momentum
self.target_encoder = None
@ -242,17 +237,16 @@ class BYOL(nn.Module):
return {'target_ema_beta': self.target_ema_updater.beta}
def visual_dbg(self, step, path):
if self.do_aug:
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
def forward(self, image_one, image_two):
if self.do_aug:
image_one = self.aug(image_one)
image_two = self.aug(image_two)
# Keep copies on hand for visual_dbg.
self.im1 = image_one.detach().copy()
self.im2 = image_two.detach().copy()
self.im1 = image_one.detach().clone()
self.im2 = image_two.detach().clone()
online_proj_one = self.online_encoder(image_one)
online_proj_two = self.online_encoder(image_two)
@ -276,5 +270,4 @@ class BYOL(nn.Module):
def register_byol(opt_net, opt):
subnet = create_model(opt, opt_net['subnet'])
return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))

View File

@ -122,11 +122,3 @@ def backbone152(pretrained=False, progress=True, **kwargs):
return _backbone('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
**kwargs)
@register_model
def register_resnet50(opt_net, opt):
model = resnet50(pretrained=opt_net['pretrained'])
if opt_net['custom_head_logits']:
model.fc = nn.Linear(512 * 4, opt_net['custom_head_logits'])
return model

View File

@ -7,7 +7,7 @@ import torch.nn.functional as F
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Resize
from torchvision.transforms import ToTensor, Resize, Normalize
from tqdm import tqdm
import numpy as np
@ -93,20 +93,25 @@ def register_hook(net, layer_name):
layer.register_forward_hook(_hook)
def get_latent_for_img(model, img):
img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0)
def get_latent_for_img(model, img, pos):
img_t = ToTensor()(Image.open(img)).to('cuda')[:3]
img_t = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)(img_t).unsqueeze(0)
_, _, h, w = img_t.shape
# Center crop img_t and resize to 224.
d = min(h, w)
dh, dw = (h-d)//2, (w-d)//2
if dw != 0:
img_t = img_t[:, :, :, dw:-dw]
pos[1] = pos[1]-dw
elif dh != 0:
img_t = img_t[:, :, dh:-dh, :]
pos[0] = pos[0]-dh
scale = 224 / img_t.shape[-1]
pos = (pos * scale).long()
assert(pos.min() >= 0 and pos.max() < 224)
img_t = img_t[:,:3,:,:]
img_t = torch.nn.functional.interpolate(img_t, size=(224, 224), mode="area")
model(img_t)
latent = layer_hooked_value
latent = model(img=img_t,pos=pos)
return latent
@ -124,7 +129,7 @@ def produce_latent_dict(model):
for k in range(10):
_, _, h, _ = hq.shape
point = torch.randint(h//4, 3*h//4, (2,)).long().to(hq.device)
model(hq, point)
model(img=hq, pos=point)
l = layer_hooked_value.cpu().split(1, dim=0)
latents.extend(l)
points.extend([point for p in range(batch_size)])
@ -139,43 +144,54 @@ def produce_latent_dict(model):
def find_similar_latents(model, compare_fn=structural_euc_dist):
global layer_hooked_value
img = 'D:\\dlas\\results\\bobz.png'
img = 'F:\\dlas\\results\\bobz.png'
#img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
point=torch.tensor([154,330], dtype=torch.long, device='cuda')
output_path = '../../../results/byol_resnet_similars'
os.makedirs(output_path, exist_ok=True)
imglatent = get_latent_for_img(model, img).squeeze().unsqueeze(0)
imglatent = get_latent_for_img(model, img, point).squeeze().unsqueeze(0)
_, c = imglatent.shape
batch_size = 512
num_workers = 8
num_workers = 1
dataloader = get_image_folder_dataloader(batch_size, num_workers)
id = 0
output_batch = 1
results = []
result_paths = []
results_points = []
for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda')
model(hq)
latent = layer_hooked_value.clone().squeeze()
_,_,h,w = hq.shape
point = torch.randint(h//4, 3*h//4, (2,)).long().to(hq.device)
latent = model(img=hq, pos=point)
compared = compare_fn(imglatent.repeat(latent.shape[0], 1), latent)
results.append(compared.cpu())
result_paths.extend(batch['HQ_path'])
results_points.append(point.unsqueeze(0).repeat(batch_size,1))
id += batch_size
if id > 10000:
k = 200
k = 10
results = torch.cat(results, dim=0)
results_points = torch.cat(results_points, dim=0)
vals, inds = torch.topk(results, k, largest=False)
for i in inds:
mag = int(results[i].item() * 1000)
shutil.copy(result_paths[i], os.path.join(output_path, f'{mag:05}_{output_batch}_{i}.jpg'))
point = results_points[i]
mag = int(results[i].item() * 100000000)
hqr = ToTensor()(Image.open(result_paths[i])).to('cuda')
hqr *= .5
hqr[:,point[0]-3:point[0]+3,point[1]-3:point[1]+3] *= 2
torchvision.utils.save_image(hqr, os.path.join(output_path, f'{mag:08}_{output_batch}_{i}.jpg'))
results = []
result_paths = []
results_points = []
id = 0
def build_kmeans():
latents, _, _ = torch.load('../results_segformer.pth')
latents = torch.cat(latents, dim=0).squeeze().to('cuda')
latents = torch.cat(latents, dim=0).squeeze().to('cuda')[50000:] * 10000
cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=16, distance="euclidean", device=torch.device('cuda:0'))
torch.save((cluster_ids_x, cluster_centers), '../k_means_segformer.pth')
@ -222,7 +238,7 @@ def use_kmeans():
if __name__ == '__main__':
pretrained_path = '../../../experiments/segformer_byol_only.pth'
pretrained_path = '../../../experiments/segformer_contrastive.pth'
model = Segformer().to('cuda')
sd = torch.load(pretrained_path)
resnet_sd = {}
@ -234,7 +250,7 @@ if __name__ == '__main__':
register_hook(model, 'tail')
with torch.no_grad():
#find_similar_latents(model, structural_euc_dist)
find_similar_latents(model, structural_euc_dist)
#produce_latent_dict(model)
#build_kmeans()
use_kmeans()
#use_kmeans()

View File

@ -414,5 +414,5 @@ if __name__ == "__main__":
#plot_pixel_level_results_as_image_graph()
# For use with segformer results
#run_tsne_segformer()
plot_segformer_results_as_image_graph()
run_tsne_segformer()
#plot_segformer_results_as_image_graph()

View File

@ -295,7 +295,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_segformer_contrastive_xx.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_imagenet_yt.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()