test uresnet playground mods

This commit is contained in:
James Betker 2021-01-23 13:46:43 -07:00
parent 1b8a26db93
commit dac7d768fa
2 changed files with 39 additions and 25 deletions

View File

@ -1,5 +1,6 @@
import os import os
import shutil import shutil
from random import shuffle
import matplotlib.cm as cm import matplotlib.cm as cm
import torch import torch
@ -16,6 +17,8 @@ import numpy as np
import utils import utils
from data.image_folder_dataset import ImageFolderDataset from data.image_folder_dataset import ImageFolderDataset
from models.pixel_level_contrastive_learning.resnet_unet import UResNet50 from models.pixel_level_contrastive_learning.resnet_unet import UResNet50
from models.pixel_level_contrastive_learning.resnet_unet_2 import UResNet50_2
from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
from models.resnet_with_checkpointing import resnet50 from models.resnet_with_checkpointing import resnet50
from models.spinenet_arch import SpineNet from models.spinenet_arch import SpineNet
@ -59,9 +62,9 @@ def im_norm(x):
def get_image_folder_dataloader(batch_size, num_workers, target_size=256): def get_image_folder_dataloader(batch_size, num_workers, target_size=256):
dataset_opt = dict_to_nonedict({ dataset_opt = dict_to_nonedict({
'name': 'amalgam', 'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\images\\imagenet_2017\\train'], #'paths': ['F:\\4k6k\\datasets\\images\\imagenet_2017\\train'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'], #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'], 'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'], #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
'weights': [1], 'weights': [1],
'target_size': target_size, 'target_size': target_size,
@ -72,8 +75,8 @@ def get_image_folder_dataloader(batch_size, num_workers, target_size=256):
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True) return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
def produce_latent_dict(model): def produce_latent_dict(model, basename):
batch_size = 32 batch_size = 64
num_workers = 4 num_workers = 4
dataloader = get_image_folder_dataloader(batch_size, num_workers) dataloader = get_image_folder_dataloader(batch_size, num_workers)
id = 0 id = 0
@ -89,30 +92,33 @@ def produce_latent_dict(model):
# extract a random set of 10 latents from each image # extract a random set of 10 latents from each image
if prob is None: if prob is None:
prob = torch.full((dim,), 1/(dim)) prob = torch.full((dim,), 1/(dim))
l = l[prob.multinomial(num_samples=10, replacement=False)].split(1, dim=0) l = l[prob.multinomial(num_samples=100, replacement=False)].split(1, dim=0)
latents.extend(l) latents.extend(l)
paths.extend(batch['HQ_path']) paths.extend(batch['HQ_path'])
id += batch_size id += batch_size
if id > 1000: if id > 5000:
print("Saving checkpoint..") print("Saving checkpoint..")
torch.save((latents, paths), '../imagenet_latent_dict.pth') torch.save((latents, paths), f'../{basename}_latent_dict.pth')
id = 0 id = 0
def build_kmeans(): def build_kmeans(basename):
latents, _ = torch.load('../imagenet_latent_dict.pth') latents, _ = torch.load(f'../{basename}_latent_dict.pth')
shuffle(latents)
latents = torch.cat(latents, dim=0).to('cuda') latents = torch.cat(latents, dim=0).to('cuda')
cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=4, distance="euclidean", device=torch.device('cuda:0')) cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=8, distance="euclidean", device=torch.device('cuda:0'), tol=0, iter_limit=5000, gravity_limit_per_iter=1000)
torch.save((cluster_ids_x, cluster_centers), '../k_means_imagenet.pth') torch.save((cluster_ids_x, cluster_centers), f'../{basename}_k_means_centroids.pth')
def use_kmeans(): def use_kmeans(basename):
_, centers = torch.load('../k_means_imagenet.pth') output_path = f'../results/{basename}_kmeans_viz'
_, centers = torch.load(f'../{basename}_k_means_centroids.pth')
centers = centers.to('cuda') centers = centers.to('cuda')
batch_size = 8 batch_size = 8
num_workers = 0 num_workers = 0
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=256) dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=256)
colormap = cm.get_cmap('viridis', 8) colormap = cm.get_cmap('viridis', 8)
os.makedirs(output_path, exist_ok=True)
for i, batch in enumerate(tqdm(dataloader)): for i, batch in enumerate(tqdm(dataloader)):
hq = batch['hq'].to('cuda') hq = batch['hq'].to('cuda')
l = model(hq) l = model(hq)
@ -122,13 +128,16 @@ def use_kmeans():
pred = kmeans_predict(l, centers) pred = kmeans_predict(l, centers)
pred = pred.reshape(b,h,w) pred = pred.reshape(b,h,w)
img = torch.tensor(colormap(pred[:, :, :].detach().cpu().numpy())) img = torch.tensor(colormap(pred[:, :, :].detach().cpu().numpy()))
torchvision.utils.save_image(torch.nn.functional.interpolate(img.permute(0,3,1,2), scale_factor=8, mode="nearest"), f"{i}_categories.png") scale = hq.shape[-2] / h
torchvision.utils.save_image(hq, f"{i}_hq.png") torchvision.utils.save_image(torch.nn.functional.interpolate(img.permute(0,3,1,2), scale_factor=scale, mode="nearest"),
f"{output_path}/{i}_categories.png")
torchvision.utils.save_image(hq, f"{output_path}/{i}_hq.png")
if __name__ == '__main__': if __name__ == '__main__':
pretrained_path = '../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth' pretrained_path = '../experiments/uresnet_pixpro4_imgset.pth'
model = UResNet50(Bottleneck, [3,4,6,3], out_dim=256).to('cuda') basename = 'uresnet_pixpro4'
model = UResNet50_3(Bottleneck, [3,4,6,3], out_dim=64).to('cuda')
sd = torch.load(pretrained_path) sd = torch.load(pretrained_path)
resnet_sd = {} resnet_sd = {}
for k, v in sd.items(): for k, v in sd.items():
@ -140,6 +149,6 @@ if __name__ == '__main__':
with torch.no_grad(): with torch.no_grad():
#find_similar_latents(model, 0, 8, structural_euc_dist) #find_similar_latents(model, 0, 8, structural_euc_dist)
#create_latent_database(model, batch_size=32) #create_latent_database(model, batch_size=32)
#produce_latent_dict(model) #produce_latent_dict(model, basename)
build_kmeans() #uild_kmeans(basename)
#use_kmeans() use_kmeans(basename)

View File

@ -1,5 +1,6 @@
# From: https://github.com/subhadarship/kmeans_pytorch # From: https://github.com/subhadarship/kmeans_pytorch
# License: https://github.com/subhadarship/kmeans_pytorch/blob/master/LICENSE # License: https://github.com/subhadarship/kmeans_pytorch/blob/master/LICENSE
import random
import numpy as np import numpy as np
import torch import torch
@ -30,6 +31,7 @@ def kmeans(
tol=1e-4, tol=1e-4,
tqdm_flag=True, tqdm_flag=True,
iter_limit=0, iter_limit=0,
gravity_limit_per_iter=None,
device=torch.device('cpu') device=torch.device('cpu')
): ):
""" """
@ -83,9 +85,10 @@ def kmeans(
for index in range(num_clusters): for index in range(num_clusters):
selected = torch.nonzero(choice_cluster == index).squeeze().to(device) selected = torch.nonzero(choice_cluster == index).squeeze().to(device)
selected = torch.index_select(X, 0, selected) selected = torch.index_select(X, 0, selected)
if gravity_limit_per_iter and len(selected) > gravity_limit_per_iter:
ch = random.randint(0, len(selected)-gravity_limit_per_iter)
selected=selected[ch:ch+gravity_limit_per_iter]
initial_state[index] = selected.mean(dim=0) initial_state[index] = selected.mean(dim=0)
center_shift = torch.sum( center_shift = torch.sum(
@ -97,14 +100,16 @@ def kmeans(
iteration = iteration + 1 iteration = iteration + 1
# update tqdm meter # update tqdm meter
bins = torch.bincount(choice_cluster)
if tqdm_flag: if tqdm_flag:
tqdm_meter.set_postfix( tqdm_meter.set_postfix(
iteration=f'{iteration}', iteration=f'{iteration}',
center_shift=f'{center_shift ** 2:0.6f}', center_shift=f'{center_shift ** 2}',
tol=f'{tol:0.6f}' tol=f'{tol}',
bins=f'{bins}',
) )
tqdm_meter.update() tqdm_meter.update()
if center_shift ** 2 < tol: if tol > 0 and center_shift ** 2 < tol:
break break
if iter_limit != 0 and iteration >= iter_limit: if iter_limit != 0 and iteration >= iter_limit:
break break