Clean up byol a bit
- Remove option to aug in dataset (there's really no reason for this now that kornia works on GPU on windows) - Other stufff
This commit is contained in:
parent
6649ef2dae
commit
f129eaa39e
|
@ -206,10 +206,6 @@ class BYOL(nn.Module):
|
|||
dbg['byol_distance'] = self.logs_loss
|
||||
return dbg
|
||||
|
||||
def visual_dbg(self, step, path):
|
||||
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
|
||||
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
|
||||
|
||||
def get_predictions_and_projections(self, image):
|
||||
_, _, h, w = image.shape
|
||||
point = torch.randint(h//8, 7*h//8, (2,)).long().to(image.device)
|
||||
|
@ -217,12 +213,6 @@ class BYOL(nn.Module):
|
|||
image_one, pt_one = self.aug(image, point)
|
||||
image_two, pt_two = self.aug(image, point)
|
||||
|
||||
# Keep copies on hand for visual_dbg.
|
||||
self.im1 = image_one.detach().clone()
|
||||
self.im1[:,:,pt_one[0]-3:pt_one[0]+3,pt_one[1]-3:pt_one[1]+3] = 1
|
||||
self.im2 = image_two.detach().clone()
|
||||
self.im2[:,:,pt_two[0]-3:pt_two[0]+3,pt_two[1]-3:pt_two[1]+3] = 1
|
||||
|
||||
online_proj_one = self.online_encoder(img=image_one, pos=pt_one)
|
||||
online_proj_two = self.online_encoder(img=image_two, pos=pt_two)
|
||||
|
||||
|
|
|
@ -188,24 +188,19 @@ class BYOL(nn.Module):
|
|||
moving_average_decay=0.99,
|
||||
use_momentum=True,
|
||||
structural_mlp=False,
|
||||
do_augmentation=False # In DLAS this was intended to be done at the dataset level. For massive batch sizes
|
||||
# this can overwhelm the CPU though, and it becomes desirable to do the augmentations
|
||||
# on the GPU again.
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer,
|
||||
use_structural_mlp=structural_mlp)
|
||||
|
||||
self.do_aug = do_augmentation
|
||||
if self.do_aug:
|
||||
augmentations = [ \
|
||||
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
|
||||
augs.RandomGrayscale(p=0.2),
|
||||
augs.RandomHorizontalFlip(),
|
||||
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
|
||||
augs.RandomResizedCrop((self.cropped_img_size, self.cropped_img_size))]
|
||||
self.aug = nn.Sequential(*augmentations)
|
||||
augmentations = [ \
|
||||
RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
|
||||
augs.RandomGrayscale(p=0.2),
|
||||
augs.RandomHorizontalFlip(),
|
||||
RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
|
||||
augs.RandomResizedCrop((image_size, image_size))]
|
||||
self.aug = nn.Sequential(*augmentations)
|
||||
self.use_momentum = use_momentum
|
||||
self.target_encoder = None
|
||||
self.target_ema_updater = EMA(moving_average_decay)
|
||||
|
@ -242,17 +237,16 @@ class BYOL(nn.Module):
|
|||
return {'target_ema_beta': self.target_ema_updater.beta}
|
||||
|
||||
def visual_dbg(self, step, path):
|
||||
if self.do_aug:
|
||||
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
|
||||
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
|
||||
torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
|
||||
torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))
|
||||
|
||||
def forward(self, image_one, image_two):
|
||||
if self.do_aug:
|
||||
image_one = self.aug(image_one)
|
||||
image_two = self.aug(image_two)
|
||||
# Keep copies on hand for visual_dbg.
|
||||
self.im1 = image_one.detach().copy()
|
||||
self.im2 = image_two.detach().copy()
|
||||
image_one = self.aug(image_one)
|
||||
image_two = self.aug(image_two)
|
||||
|
||||
# Keep copies on hand for visual_dbg.
|
||||
self.im1 = image_one.detach().clone()
|
||||
self.im2 = image_two.detach().clone()
|
||||
|
||||
online_proj_one = self.online_encoder(image_one)
|
||||
online_proj_two = self.online_encoder(image_two)
|
||||
|
@ -276,5 +270,4 @@ class BYOL(nn.Module):
|
|||
def register_byol(opt_net, opt):
|
||||
subnet = create_model(opt, opt_net['subnet'])
|
||||
return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
|
||||
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False),
|
||||
do_augmentation=opt_get(opt_net, ['gpu_augmentation'], False))
|
||||
structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
|
|
@ -122,11 +122,3 @@ def backbone152(pretrained=False, progress=True, **kwargs):
|
|||
return _backbone('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_resnet50(opt_net, opt):
|
||||
model = resnet50(pretrained=opt_net['pretrained'])
|
||||
if opt_net['custom_head_logits']:
|
||||
model.fc = nn.Linear(512 * 4, opt_net['custom_head_logits'])
|
||||
return model
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch.nn.functional as F
|
|||
import torchvision
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.transforms import ToTensor, Resize
|
||||
from torchvision.transforms import ToTensor, Resize, Normalize
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
|
@ -93,20 +93,25 @@ def register_hook(net, layer_name):
|
|||
layer.register_forward_hook(_hook)
|
||||
|
||||
|
||||
def get_latent_for_img(model, img):
|
||||
img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0)
|
||||
def get_latent_for_img(model, img, pos):
|
||||
img_t = ToTensor()(Image.open(img)).to('cuda')[:3]
|
||||
img_t = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)(img_t).unsqueeze(0)
|
||||
_, _, h, w = img_t.shape
|
||||
# Center crop img_t and resize to 224.
|
||||
d = min(h, w)
|
||||
dh, dw = (h-d)//2, (w-d)//2
|
||||
if dw != 0:
|
||||
img_t = img_t[:, :, :, dw:-dw]
|
||||
pos[1] = pos[1]-dw
|
||||
elif dh != 0:
|
||||
img_t = img_t[:, :, dh:-dh, :]
|
||||
pos[0] = pos[0]-dh
|
||||
scale = 224 / img_t.shape[-1]
|
||||
pos = (pos * scale).long()
|
||||
assert(pos.min() >= 0 and pos.max() < 224)
|
||||
img_t = img_t[:,:3,:,:]
|
||||
img_t = torch.nn.functional.interpolate(img_t, size=(224, 224), mode="area")
|
||||
model(img_t)
|
||||
latent = layer_hooked_value
|
||||
latent = model(img=img_t,pos=pos)
|
||||
return latent
|
||||
|
||||
|
||||
|
@ -124,7 +129,7 @@ def produce_latent_dict(model):
|
|||
for k in range(10):
|
||||
_, _, h, _ = hq.shape
|
||||
point = torch.randint(h//4, 3*h//4, (2,)).long().to(hq.device)
|
||||
model(hq, point)
|
||||
model(img=hq, pos=point)
|
||||
l = layer_hooked_value.cpu().split(1, dim=0)
|
||||
latents.extend(l)
|
||||
points.extend([point for p in range(batch_size)])
|
||||
|
@ -139,43 +144,54 @@ def produce_latent_dict(model):
|
|||
def find_similar_latents(model, compare_fn=structural_euc_dist):
|
||||
global layer_hooked_value
|
||||
|
||||
img = 'D:\\dlas\\results\\bobz.png'
|
||||
img = 'F:\\dlas\\results\\bobz.png'
|
||||
#img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
|
||||
point=torch.tensor([154,330], dtype=torch.long, device='cuda')
|
||||
|
||||
output_path = '../../../results/byol_resnet_similars'
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
imglatent = get_latent_for_img(model, img).squeeze().unsqueeze(0)
|
||||
imglatent = get_latent_for_img(model, img, point).squeeze().unsqueeze(0)
|
||||
_, c = imglatent.shape
|
||||
|
||||
batch_size = 512
|
||||
num_workers = 8
|
||||
num_workers = 1
|
||||
dataloader = get_image_folder_dataloader(batch_size, num_workers)
|
||||
id = 0
|
||||
output_batch = 1
|
||||
results = []
|
||||
result_paths = []
|
||||
results_points = []
|
||||
for batch in tqdm(dataloader):
|
||||
hq = batch['hq'].to('cuda')
|
||||
model(hq)
|
||||
latent = layer_hooked_value.clone().squeeze()
|
||||
_,_,h,w = hq.shape
|
||||
point = torch.randint(h//4, 3*h//4, (2,)).long().to(hq.device)
|
||||
latent = model(img=hq, pos=point)
|
||||
compared = compare_fn(imglatent.repeat(latent.shape[0], 1), latent)
|
||||
results.append(compared.cpu())
|
||||
result_paths.extend(batch['HQ_path'])
|
||||
results_points.append(point.unsqueeze(0).repeat(batch_size,1))
|
||||
id += batch_size
|
||||
if id > 10000:
|
||||
k = 200
|
||||
k = 10
|
||||
results = torch.cat(results, dim=0)
|
||||
results_points = torch.cat(results_points, dim=0)
|
||||
vals, inds = torch.topk(results, k, largest=False)
|
||||
for i in inds:
|
||||
mag = int(results[i].item() * 1000)
|
||||
shutil.copy(result_paths[i], os.path.join(output_path, f'{mag:05}_{output_batch}_{i}.jpg'))
|
||||
point = results_points[i]
|
||||
mag = int(results[i].item() * 100000000)
|
||||
hqr = ToTensor()(Image.open(result_paths[i])).to('cuda')
|
||||
hqr *= .5
|
||||
hqr[:,point[0]-3:point[0]+3,point[1]-3:point[1]+3] *= 2
|
||||
torchvision.utils.save_image(hqr, os.path.join(output_path, f'{mag:08}_{output_batch}_{i}.jpg'))
|
||||
results = []
|
||||
result_paths = []
|
||||
results_points = []
|
||||
id = 0
|
||||
|
||||
|
||||
def build_kmeans():
|
||||
latents, _, _ = torch.load('../results_segformer.pth')
|
||||
latents = torch.cat(latents, dim=0).squeeze().to('cuda')
|
||||
latents = torch.cat(latents, dim=0).squeeze().to('cuda')[50000:] * 10000
|
||||
cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=16, distance="euclidean", device=torch.device('cuda:0'))
|
||||
torch.save((cluster_ids_x, cluster_centers), '../k_means_segformer.pth')
|
||||
|
||||
|
@ -222,7 +238,7 @@ def use_kmeans():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pretrained_path = '../../../experiments/segformer_byol_only.pth'
|
||||
pretrained_path = '../../../experiments/segformer_contrastive.pth'
|
||||
model = Segformer().to('cuda')
|
||||
sd = torch.load(pretrained_path)
|
||||
resnet_sd = {}
|
||||
|
@ -234,7 +250,7 @@ if __name__ == '__main__':
|
|||
register_hook(model, 'tail')
|
||||
|
||||
with torch.no_grad():
|
||||
#find_similar_latents(model, structural_euc_dist)
|
||||
find_similar_latents(model, structural_euc_dist)
|
||||
#produce_latent_dict(model)
|
||||
#build_kmeans()
|
||||
use_kmeans()
|
||||
#use_kmeans()
|
||||
|
|
|
@ -414,5 +414,5 @@ if __name__ == "__main__":
|
|||
#plot_pixel_level_results_as_image_graph()
|
||||
|
||||
# For use with segformer results
|
||||
#run_tsne_segformer()
|
||||
plot_segformer_results_as_image_graph()
|
||||
run_tsne_segformer()
|
||||
#plot_segformer_results_as_image_graph()
|
|
@ -295,7 +295,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_segformer_contrastive_xx.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_resnet_imagenet_yt.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user