Clean up byol a bit

- Remove option to aug in dataset (there's really no reason for this now that kornia works on GPU on windows)
- Other stufff
This commit is contained in:
James Betker 2021-05-24 21:35:46 -06:00
parent 6649ef2dae
commit f129eaa39e
6 changed files with 53 additions and 62 deletions

View File

@ -206,10 +206,6 @@ class BYOL(nn.Module):
dbg['byol_distance'] = self.logs_loss dbg['byol_distance'] = self.logs_loss
return dbg return dbg
def visual_dbg(self, step, path):
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
def get_predictions_and_projections(self, image): def get_predictions_and_projections(self, image):
_, _, h, w = image.shape _, _, h, w = image.shape
point = torch.randint(h//8, 7*h//8, (2,)).long().to(image.device) point = torch.randint(h//8, 7*h//8, (2,)).long().to(image.device)
@ -217,12 +213,6 @@ class BYOL(nn.Module):
image_one, pt_one = self.aug(image, point) image_one, pt_one = self.aug(image, point)
image_two, pt_two = self.aug(image, point) image_two, pt_two = self.aug(image, point)
# Keep copies on hand for visual_dbg.
self.im1 = image_one.detach().clone()
self.im1[:,:,pt_one[0]-3:pt_one[0]+3,pt_one[1]-3:pt_one[1]+3] = 1
self.im2 = image_two.detach().clone()
self.im2[:,:,pt_two[0]-3:pt_two[0]+3,pt_two[1]-3:pt_two[1]+3] = 1
online_proj_one = self.online_encoder(img=image_one, pos=pt_one) online_proj_one = self.online_encoder(img=image_one, pos=pt_one)
online_proj_two = self.online_encoder(img=image_two, pos=pt_two) online_proj_two = self.online_encoder(img=image_two, pos=pt_two)

View File

@ -188,24 +188,19 @@ class BYOL(nn.Module):
moving_average_decay=0.99, moving_average_decay=0.99,
use_momentum=True, use_momentum=True,
structural_mlp=False, structural_mlp=False,
do_augmentation=False # In DLAS this was intended to be done at the dataset level. For massive batch sizes
# this can overwhelm the CPU though, and it becomes desirable to do the augmentations
# on the GPU again.
): ):
super().__init__() super().__init__()
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer, self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
use_structural_mlp=structural_mlp) use_structural_mlp=structural_mlp)
self.do_aug = do_augmentation augmentations = [ \
if self.do_aug: RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
augmentations = [ \ augs.RandomGrayscale(p=0.2),
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), augs.RandomHorizontalFlip(),
augs.RandomGrayscale(p=0.2), RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
augs.RandomHorizontalFlip(), augs.RandomResizedCrop((image_size, image_size))]
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1), self.aug = nn.Sequential(*augmentations)
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
self.aug = nn.Sequential(*augmentations)
self.use_momentum = use_momentum self.use_momentum = use_momentum
self.target_encoder = None self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay) self.target_ema_updater = EMA(moving_average_decay)
@ -242,17 +237,16 @@ class BYOL(nn.Module):
return {'target_ema_beta': self.target_ema_updater.beta} return {'target_ema_beta': self.target_ema_updater.beta}
def visual_dbg(self, step, path): def visual_dbg(self, step, path):
if self.do_aug: torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,))) torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
def forward(self, image_one, image_two): def forward(self, image_one, image_two):
if self.do_aug: image_one = self.aug(image_one)
image_one = self.aug(image_one) image_two = self.aug(image_two)
image_two = self.aug(image_two)
# Keep copies on hand for visual_dbg. # Keep copies on hand for visual_dbg.
self.im1 = image_one.detach().copy() self.im1 = image_one.detach().clone()
self.im2 = image_two.detach().copy() self.im2 = image_two.detach().clone()
online_proj_one = self.online_encoder(image_one) online_proj_one = self.online_encoder(image_one)
online_proj_two = self.online_encoder(image_two) online_proj_two = self.online_encoder(image_two)
@ -276,5 +270,4 @@ class BYOL(nn.Module):
def register_byol(opt_net, opt): def register_byol(opt_net, opt):
subnet = create_model(opt, opt_net['subnet']) subnet = create_model(opt, opt_net['subnet'])
return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False), structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))

View File

@ -122,11 +122,3 @@ def backbone152(pretrained=False, progress=True, **kwargs):
return _backbone('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, return _backbone('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
**kwargs) **kwargs)
@register_model
def register_resnet50(opt_net, opt):
model = resnet50(pretrained=opt_net['pretrained'])
if opt_net['custom_head_logits']:
model.fc = nn.Linear(512 * 4, opt_net['custom_head_logits'])
return model

View File

@ -7,7 +7,7 @@ import torch.nn.functional as F
import torchvision import torchvision
from PIL import Image from PIL import Image
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Resize from torchvision.transforms import ToTensor, Resize, Normalize
from tqdm import tqdm from tqdm import tqdm
import numpy as np import numpy as np
@ -93,20 +93,25 @@ def register_hook(net, layer_name):
layer.register_forward_hook(_hook) layer.register_forward_hook(_hook)
def get_latent_for_img(model, img): def get_latent_for_img(model, img, pos):
img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0) img_t = ToTensor()(Image.open(img)).to('cuda')[:3]
img_t = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)(img_t).unsqueeze(0)
_, _, h, w = img_t.shape _, _, h, w = img_t.shape
# Center crop img_t and resize to 224. # Center crop img_t and resize to 224.
d = min(h, w) d = min(h, w)
dh, dw = (h-d)//2, (w-d)//2 dh, dw = (h-d)//2, (w-d)//2
if dw != 0: if dw != 0:
img_t = img_t[:, :, :, dw:-dw] img_t = img_t[:, :, :, dw:-dw]
pos[1] = pos[1]-dw
elif dh != 0: elif dh != 0:
img_t = img_t[:, :, dh:-dh, :] img_t = img_t[:, :, dh:-dh, :]
pos[0] = pos[0]-dh
scale = 224 / img_t.shape[-1]
pos = (pos * scale).long()
assert(pos.min() >= 0 and pos.max() < 224)
img_t = img_t[:,:3,:,:] img_t = img_t[:,:3,:,:]
img_t = torch.nn.functional.interpolate(img_t, size=(224, 224), mode="area") img_t = torch.nn.functional.interpolate(img_t, size=(224, 224), mode="area")
model(img_t) latent = model(img=img_t,pos=pos)
latent = layer_hooked_value
return latent return latent
@ -124,7 +129,7 @@ def produce_latent_dict(model):
for k in range(10): for k in range(10):
_, _, h, _ = hq.shape _, _, h, _ = hq.shape
point = torch.randint(h//4, 3*h//4, (2,)).long().to(hq.device) point = torch.randint(h//4, 3*h//4, (2,)).long().to(hq.device)
model(hq, point) model(img=hq, pos=point)
l = layer_hooked_value.cpu().split(1, dim=0) l = layer_hooked_value.cpu().split(1, dim=0)
latents.extend(l) latents.extend(l)
points.extend([point for p in range(batch_size)]) points.extend([point for p in range(batch_size)])
@ -139,43 +144,54 @@ def produce_latent_dict(model):
def find_similar_latents(model, compare_fn=structural_euc_dist): def find_similar_latents(model, compare_fn=structural_euc_dist):
global layer_hooked_value global layer_hooked_value
img = 'D:\\dlas\\results\\bobz.png' img = 'F:\\dlas\\results\\bobz.png'
#img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg' #img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
point=torch.tensor([154,330], dtype=torch.long, device='cuda')
output_path = '../../../results/byol_resnet_similars' output_path = '../../../results/byol_resnet_similars'
os.makedirs(output_path, exist_ok=True) os.makedirs(output_path, exist_ok=True)
imglatent = get_latent_for_img(model, img).squeeze().unsqueeze(0) imglatent = get_latent_for_img(model, img, point).squeeze().unsqueeze(0)
_, c = imglatent.shape _, c = imglatent.shape
batch_size = 512 batch_size = 512
num_workers = 8 num_workers = 1
dataloader = get_image_folder_dataloader(batch_size, num_workers) dataloader = get_image_folder_dataloader(batch_size, num_workers)
id = 0 id = 0
output_batch = 1 output_batch = 1
results = [] results = []
result_paths = [] result_paths = []
results_points = []
for batch in tqdm(dataloader): for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda') hq = batch['hq'].to('cuda')
model(hq) _,_,h,w = hq.shape
latent = layer_hooked_value.clone().squeeze() point = torch.randint(h//4, 3*h//4, (2,)).long().to(hq.device)
latent = model(img=hq, pos=point)
compared = compare_fn(imglatent.repeat(latent.shape[0], 1), latent) compared = compare_fn(imglatent.repeat(latent.shape[0], 1), latent)
results.append(compared.cpu()) results.append(compared.cpu())
result_paths.extend(batch['HQ_path']) result_paths.extend(batch['HQ_path'])
results_points.append(point.unsqueeze(0).repeat(batch_size,1))
id += batch_size id += batch_size
if id > 10000: if id > 10000:
k = 200 k = 10
results = torch.cat(results, dim=0) results = torch.cat(results, dim=0)
results_points = torch.cat(results_points, dim=0)
vals, inds = torch.topk(results, k, largest=False) vals, inds = torch.topk(results, k, largest=False)
for i in inds: for i in inds:
mag = int(results[i].item() * 1000) point = results_points[i]
shutil.copy(result_paths[i], os.path.join(output_path, f'{mag:05}_{output_batch}_{i}.jpg')) mag = int(results[i].item() * 100000000)
hqr = ToTensor()(Image.open(result_paths[i])).to('cuda')
hqr *= .5
hqr[:,point[0]-3:point[0]+3,point[1]-3:point[1]+3] *= 2
torchvision.utils.save_image(hqr, os.path.join(output_path, f'{mag:08}_{output_batch}_{i}.jpg'))
results = [] results = []
result_paths = [] result_paths = []
results_points = []
id = 0 id = 0
def build_kmeans(): def build_kmeans():
latents, _, _ = torch.load('../results_segformer.pth') latents, _, _ = torch.load('../results_segformer.pth')
latents = torch.cat(latents, dim=0).squeeze().to('cuda') latents = torch.cat(latents, dim=0).squeeze().to('cuda')[50000:] * 10000
cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=16, distance="euclidean", device=torch.device('cuda:0')) cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=16, distance="euclidean", device=torch.device('cuda:0'))
torch.save((cluster_ids_x, cluster_centers), '../k_means_segformer.pth') torch.save((cluster_ids_x, cluster_centers), '../k_means_segformer.pth')
@ -222,7 +238,7 @@ def use_kmeans():
if __name__ == '__main__': if __name__ == '__main__':
pretrained_path = '../../../experiments/segformer_byol_only.pth' pretrained_path = '../../../experiments/segformer_contrastive.pth'
model = Segformer().to('cuda') model = Segformer().to('cuda')
sd = torch.load(pretrained_path) sd = torch.load(pretrained_path)
resnet_sd = {} resnet_sd = {}
@ -234,7 +250,7 @@ if __name__ == '__main__':
register_hook(model, 'tail') register_hook(model, 'tail')
with torch.no_grad(): with torch.no_grad():
#find_similar_latents(model, structural_euc_dist) find_similar_latents(model, structural_euc_dist)
#produce_latent_dict(model) #produce_latent_dict(model)
#build_kmeans() #build_kmeans()
use_kmeans() #use_kmeans()

View File

@ -414,5 +414,5 @@ if __name__ == "__main__":
#plot_pixel_level_results_as_image_graph() #plot_pixel_level_results_as_image_graph()
# For use with segformer results # For use with segformer results
#run_tsne_segformer() run_tsne_segformer()
plot_segformer_results_as_image_graph() #plot_segformer_results_as_image_graph()

View File

@ -295,7 +295,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_segformer_contrastive_xx.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_imagenet_yt.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()