forked from mrq/DL-Art-School
Clean up byol a bit
- Remove option to aug in dataset (there's really no reason for this now that kornia works on GPU on windows) - Other stufff
This commit is contained in:
parent
6649ef2dae
commit
f129eaa39e
|
@ -206,10 +206,6 @@ class BYOL(nn.Module):
|
||||||
dbg['byol_distance'] = self.logs_loss
|
dbg['byol_distance'] = self.logs_loss
|
||||||
return dbg
|
return dbg
|
||||||
|
|
||||||
def visual_dbg(self, step, path):
|
|
||||||
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
|
|
||||||
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
|
|
||||||
|
|
||||||
def get_predictions_and_projections(self, image):
|
def get_predictions_and_projections(self, image):
|
||||||
_, _, h, w = image.shape
|
_, _, h, w = image.shape
|
||||||
point = torch.randint(h//8, 7*h//8, (2,)).long().to(image.device)
|
point = torch.randint(h//8, 7*h//8, (2,)).long().to(image.device)
|
||||||
|
@ -217,12 +213,6 @@ class BYOL(nn.Module):
|
||||||
image_one, pt_one = self.aug(image, point)
|
image_one, pt_one = self.aug(image, point)
|
||||||
image_two, pt_two = self.aug(image, point)
|
image_two, pt_two = self.aug(image, point)
|
||||||
|
|
||||||
# Keep copies on hand for visual_dbg.
|
|
||||||
self.im1 = image_one.detach().clone()
|
|
||||||
self.im1[:,:,pt_one[0]-3:pt_one[0]+3,pt_one[1]-3:pt_one[1]+3] = 1
|
|
||||||
self.im2 = image_two.detach().clone()
|
|
||||||
self.im2[:,:,pt_two[0]-3:pt_two[0]+3,pt_two[1]-3:pt_two[1]+3] = 1
|
|
||||||
|
|
||||||
online_proj_one = self.online_encoder(img=image_one, pos=pt_one)
|
online_proj_one = self.online_encoder(img=image_one, pos=pt_one)
|
||||||
online_proj_two = self.online_encoder(img=image_two, pos=pt_two)
|
online_proj_two = self.online_encoder(img=image_two, pos=pt_two)
|
||||||
|
|
||||||
|
|
|
@ -188,24 +188,19 @@ class BYOL(nn.Module):
|
||||||
moving_average_decay=0.99,
|
moving_average_decay=0.99,
|
||||||
use_momentum=True,
|
use_momentum=True,
|
||||||
structural_mlp=False,
|
structural_mlp=False,
|
||||||
do_augmentation=False # In DLAS this was intended to be done at the dataset level. For massive batch sizes
|
|
||||||
# this can overwhelm the CPU though, and it becomes desirable to do the augmentations
|
|
||||||
# on the GPU again.
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
|
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
|
||||||
use_structural_mlp=structural_mlp)
|
use_structural_mlp=structural_mlp)
|
||||||
|
|
||||||
self.do_aug = do_augmentation
|
augmentations = [ \
|
||||||
if self.do_aug:
|
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
|
||||||
augmentations = [ \
|
augs.RandomGrayscale(p=0.2),
|
||||||
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
|
augs.RandomHorizontalFlip(),
|
||||||
augs.RandomGrayscale(p=0.2),
|
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
|
||||||
augs.RandomHorizontalFlip(),
|
augs.RandomResizedCrop((image_size, image_size))]
|
||||||
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
|
self.aug = nn.Sequential(*augmentations)
|
||||||
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
|
|
||||||
self.aug = nn.Sequential(*augmentations)
|
|
||||||
self.use_momentum = use_momentum
|
self.use_momentum = use_momentum
|
||||||
self.target_encoder = None
|
self.target_encoder = None
|
||||||
self.target_ema_updater = EMA(moving_average_decay)
|
self.target_ema_updater = EMA(moving_average_decay)
|
||||||
|
@ -242,17 +237,16 @@ class BYOL(nn.Module):
|
||||||
return {'target_ema_beta': self.target_ema_updater.beta}
|
return {'target_ema_beta': self.target_ema_updater.beta}
|
||||||
|
|
||||||
def visual_dbg(self, step, path):
|
def visual_dbg(self, step, path):
|
||||||
if self.do_aug:
|
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
|
||||||
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
|
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
|
||||||
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
|
|
||||||
|
|
||||||
def forward(self, image_one, image_two):
|
def forward(self, image_one, image_two):
|
||||||
if self.do_aug:
|
image_one = self.aug(image_one)
|
||||||
image_one = self.aug(image_one)
|
image_two = self.aug(image_two)
|
||||||
image_two = self.aug(image_two)
|
|
||||||
# Keep copies on hand for visual_dbg.
|
# Keep copies on hand for visual_dbg.
|
||||||
self.im1 = image_one.detach().copy()
|
self.im1 = image_one.detach().clone()
|
||||||
self.im2 = image_two.detach().copy()
|
self.im2 = image_two.detach().clone()
|
||||||
|
|
||||||
online_proj_one = self.online_encoder(image_one)
|
online_proj_one = self.online_encoder(image_one)
|
||||||
online_proj_two = self.online_encoder(image_two)
|
online_proj_two = self.online_encoder(image_two)
|
||||||
|
@ -276,5 +270,4 @@ class BYOL(nn.Module):
|
||||||
def register_byol(opt_net, opt):
|
def register_byol(opt_net, opt):
|
||||||
subnet = create_model(opt, opt_net['subnet'])
|
subnet = create_model(opt, opt_net['subnet'])
|
||||||
return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
||||||
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
|
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
|
||||||
do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))
|
|
|
@ -122,11 +122,3 @@ def backbone152(pretrained=False, progress=True, **kwargs):
|
||||||
return _backbone('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
return _backbone('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
|
||||||
def register_resnet50(opt_net, opt):
|
|
||||||
model = resnet50(pretrained=opt_net['pretrained'])
|
|
||||||
if opt_net['custom_head_logits']:
|
|
||||||
model.fc = nn.Linear(512 * 4, opt_net['custom_head_logits'])
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@ import torch.nn.functional as F
|
||||||
import torchvision
|
import torchvision
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision.transforms import ToTensor, Resize
|
from torchvision.transforms import ToTensor, Resize, Normalize
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -93,20 +93,25 @@ def register_hook(net, layer_name):
|
||||||
layer.register_forward_hook(_hook)
|
layer.register_forward_hook(_hook)
|
||||||
|
|
||||||
|
|
||||||
def get_latent_for_img(model, img):
|
def get_latent_for_img(model, img, pos):
|
||||||
img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0)
|
img_t = ToTensor()(Image.open(img)).to('cuda')[:3]
|
||||||
|
img_t = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)(img_t).unsqueeze(0)
|
||||||
_, _, h, w = img_t.shape
|
_, _, h, w = img_t.shape
|
||||||
# Center crop img_t and resize to 224.
|
# Center crop img_t and resize to 224.
|
||||||
d = min(h, w)
|
d = min(h, w)
|
||||||
dh, dw = (h-d)//2, (w-d)//2
|
dh, dw = (h-d)//2, (w-d)//2
|
||||||
if dw != 0:
|
if dw != 0:
|
||||||
img_t = img_t[:, :, :, dw:-dw]
|
img_t = img_t[:, :, :, dw:-dw]
|
||||||
|
pos[1] = pos[1]-dw
|
||||||
elif dh != 0:
|
elif dh != 0:
|
||||||
img_t = img_t[:, :, dh:-dh, :]
|
img_t = img_t[:, :, dh:-dh, :]
|
||||||
|
pos[0] = pos[0]-dh
|
||||||
|
scale = 224 / img_t.shape[-1]
|
||||||
|
pos = (pos * scale).long()
|
||||||
|
assert(pos.min() >= 0 and pos.max() < 224)
|
||||||
img_t = img_t[:,:3,:,:]
|
img_t = img_t[:,:3,:,:]
|
||||||
img_t = torch.nn.functional.interpolate(img_t, size=(224, 224), mode="area")
|
img_t = torch.nn.functional.interpolate(img_t, size=(224, 224), mode="area")
|
||||||
model(img_t)
|
latent = model(img=img_t,pos=pos)
|
||||||
latent = layer_hooked_value
|
|
||||||
return latent
|
return latent
|
||||||
|
|
||||||
|
|
||||||
|
@ -124,7 +129,7 @@ def produce_latent_dict(model):
|
||||||
for k in range(10):
|
for k in range(10):
|
||||||
_, _, h, _ = hq.shape
|
_, _, h, _ = hq.shape
|
||||||
point = torch.randint(h//4, 3*h//4, (2,)).long().to(hq.device)
|
point = torch.randint(h//4, 3*h//4, (2,)).long().to(hq.device)
|
||||||
model(hq, point)
|
model(img=hq, pos=point)
|
||||||
l = layer_hooked_value.cpu().split(1, dim=0)
|
l = layer_hooked_value.cpu().split(1, dim=0)
|
||||||
latents.extend(l)
|
latents.extend(l)
|
||||||
points.extend([point for p in range(batch_size)])
|
points.extend([point for p in range(batch_size)])
|
||||||
|
@ -139,43 +144,54 @@ def produce_latent_dict(model):
|
||||||
def find_similar_latents(model, compare_fn=structural_euc_dist):
|
def find_similar_latents(model, compare_fn=structural_euc_dist):
|
||||||
global layer_hooked_value
|
global layer_hooked_value
|
||||||
|
|
||||||
img = 'D:\\dlas\\results\\bobz.png'
|
img = 'F:\\dlas\\results\\bobz.png'
|
||||||
#img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
|
#img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
|
||||||
|
point=torch.tensor([154,330], dtype=torch.long, device='cuda')
|
||||||
|
|
||||||
output_path = '../../../results/byol_resnet_similars'
|
output_path = '../../../results/byol_resnet_similars'
|
||||||
os.makedirs(output_path, exist_ok=True)
|
os.makedirs(output_path, exist_ok=True)
|
||||||
imglatent = get_latent_for_img(model, img).squeeze().unsqueeze(0)
|
imglatent = get_latent_for_img(model, img, point).squeeze().unsqueeze(0)
|
||||||
_, c = imglatent.shape
|
_, c = imglatent.shape
|
||||||
|
|
||||||
batch_size = 512
|
batch_size = 512
|
||||||
num_workers = 8
|
num_workers = 1
|
||||||
dataloader = get_image_folder_dataloader(batch_size, num_workers)
|
dataloader = get_image_folder_dataloader(batch_size, num_workers)
|
||||||
id = 0
|
id = 0
|
||||||
output_batch = 1
|
output_batch = 1
|
||||||
results = []
|
results = []
|
||||||
result_paths = []
|
result_paths = []
|
||||||
|
results_points = []
|
||||||
for batch in tqdm(dataloader):
|
for batch in tqdm(dataloader):
|
||||||
hq = batch['hq'].to('cuda')
|
hq = batch['hq'].to('cuda')
|
||||||
model(hq)
|
_,_,h,w = hq.shape
|
||||||
latent = layer_hooked_value.clone().squeeze()
|
point = torch.randint(h//4, 3*h//4, (2,)).long().to(hq.device)
|
||||||
|
latent = model(img=hq, pos=point)
|
||||||
compared = compare_fn(imglatent.repeat(latent.shape[0], 1), latent)
|
compared = compare_fn(imglatent.repeat(latent.shape[0], 1), latent)
|
||||||
results.append(compared.cpu())
|
results.append(compared.cpu())
|
||||||
result_paths.extend(batch['HQ_path'])
|
result_paths.extend(batch['HQ_path'])
|
||||||
|
results_points.append(point.unsqueeze(0).repeat(batch_size,1))
|
||||||
id += batch_size
|
id += batch_size
|
||||||
if id > 10000:
|
if id > 10000:
|
||||||
k = 200
|
k = 10
|
||||||
results = torch.cat(results, dim=0)
|
results = torch.cat(results, dim=0)
|
||||||
|
results_points = torch.cat(results_points, dim=0)
|
||||||
vals, inds = torch.topk(results, k, largest=False)
|
vals, inds = torch.topk(results, k, largest=False)
|
||||||
for i in inds:
|
for i in inds:
|
||||||
mag = int(results[i].item() * 1000)
|
point = results_points[i]
|
||||||
shutil.copy(result_paths[i], os.path.join(output_path, f'{mag:05}_{output_batch}_{i}.jpg'))
|
mag = int(results[i].item() * 100000000)
|
||||||
|
hqr = ToTensor()(Image.open(result_paths[i])).to('cuda')
|
||||||
|
hqr *= .5
|
||||||
|
hqr[:,point[0]-3:point[0]+3,point[1]-3:point[1]+3] *= 2
|
||||||
|
torchvision.utils.save_image(hqr, os.path.join(output_path, f'{mag:08}_{output_batch}_{i}.jpg'))
|
||||||
results = []
|
results = []
|
||||||
result_paths = []
|
result_paths = []
|
||||||
|
results_points = []
|
||||||
id = 0
|
id = 0
|
||||||
|
|
||||||
|
|
||||||
def build_kmeans():
|
def build_kmeans():
|
||||||
latents, _, _ = torch.load('../results_segformer.pth')
|
latents, _, _ = torch.load('../results_segformer.pth')
|
||||||
latents = torch.cat(latents, dim=0).squeeze().to('cuda')
|
latents = torch.cat(latents, dim=0).squeeze().to('cuda')[50000:] * 10000
|
||||||
cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=16, distance="euclidean", device=torch.device('cuda:0'))
|
cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=16, distance="euclidean", device=torch.device('cuda:0'))
|
||||||
torch.save((cluster_ids_x, cluster_centers), '../k_means_segformer.pth')
|
torch.save((cluster_ids_x, cluster_centers), '../k_means_segformer.pth')
|
||||||
|
|
||||||
|
@ -222,7 +238,7 @@ def use_kmeans():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pretrained_path = '../../../experiments/segformer_byol_only.pth'
|
pretrained_path = '../../../experiments/segformer_contrastive.pth'
|
||||||
model = Segformer().to('cuda')
|
model = Segformer().to('cuda')
|
||||||
sd = torch.load(pretrained_path)
|
sd = torch.load(pretrained_path)
|
||||||
resnet_sd = {}
|
resnet_sd = {}
|
||||||
|
@ -234,7 +250,7 @@ if __name__ == '__main__':
|
||||||
register_hook(model, 'tail')
|
register_hook(model, 'tail')
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
#find_similar_latents(model, structural_euc_dist)
|
find_similar_latents(model, structural_euc_dist)
|
||||||
#produce_latent_dict(model)
|
#produce_latent_dict(model)
|
||||||
#build_kmeans()
|
#build_kmeans()
|
||||||
use_kmeans()
|
#use_kmeans()
|
||||||
|
|
|
@ -414,5 +414,5 @@ if __name__ == "__main__":
|
||||||
#plot_pixel_level_results_as_image_graph()
|
#plot_pixel_level_results_as_image_graph()
|
||||||
|
|
||||||
# For use with segformer results
|
# For use with segformer results
|
||||||
#run_tsne_segformer()
|
run_tsne_segformer()
|
||||||
plot_segformer_results_as_image_graph()
|
#plot_segformer_results_as_image_graph()
|
|
@ -295,7 +295,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_segformer_contrastive_xx.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_imagenet_yt.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user