forked from mrq/DL-Art-School
test uresnet playground mods
This commit is contained in:
parent
1b8a26db93
commit
dac7d768fa
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import shutil
|
||||
from random import shuffle
|
||||
|
||||
import matplotlib.cm as cm
|
||||
import torch
|
||||
|
@ -16,6 +17,8 @@ import numpy as np
|
|||
import utils
|
||||
from data.image_folder_dataset import ImageFolderDataset
|
||||
from models.pixel_level_contrastive_learning.resnet_unet import UResNet50
|
||||
from models.pixel_level_contrastive_learning.resnet_unet_2 import UResNet50_2
|
||||
from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
|
||||
from models.resnet_with_checkpointing import resnet50
|
||||
from models.spinenet_arch import SpineNet
|
||||
|
||||
|
@ -59,9 +62,9 @@ def im_norm(x):
|
|||
def get_image_folder_dataloader(batch_size, num_workers, target_size=256):
|
||||
dataset_opt = dict_to_nonedict({
|
||||
'name': 'amalgam',
|
||||
'paths': ['F:\\4k6k\\datasets\\images\\imagenet_2017\\train'],
|
||||
#'paths': ['F:\\4k6k\\datasets\\images\\imagenet_2017\\train'],
|
||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
|
||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
|
||||
'weights': [1],
|
||||
'target_size': target_size,
|
||||
|
@ -72,8 +75,8 @@ def get_image_folder_dataloader(batch_size, num_workers, target_size=256):
|
|||
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
|
||||
|
||||
|
||||
def produce_latent_dict(model):
|
||||
batch_size = 32
|
||||
def produce_latent_dict(model, basename):
|
||||
batch_size = 64
|
||||
num_workers = 4
|
||||
dataloader = get_image_folder_dataloader(batch_size, num_workers)
|
||||
id = 0
|
||||
|
@ -89,30 +92,33 @@ def produce_latent_dict(model):
|
|||
# extract a random set of 10 latents from each image
|
||||
if prob is None:
|
||||
prob = torch.full((dim,), 1/(dim))
|
||||
l = l[prob.multinomial(num_samples=10, replacement=False)].split(1, dim=0)
|
||||
l = l[prob.multinomial(num_samples=100, replacement=False)].split(1, dim=0)
|
||||
latents.extend(l)
|
||||
paths.extend(batch['HQ_path'])
|
||||
id += batch_size
|
||||
if id > 1000:
|
||||
if id > 5000:
|
||||
print("Saving checkpoint..")
|
||||
torch.save((latents, paths), '../imagenet_latent_dict.pth')
|
||||
torch.save((latents, paths), f'../{basename}_latent_dict.pth')
|
||||
id = 0
|
||||
|
||||
|
||||
def build_kmeans():
|
||||
latents, _ = torch.load('../imagenet_latent_dict.pth')
|
||||
def build_kmeans(basename):
|
||||
latents, _ = torch.load(f'../{basename}_latent_dict.pth')
|
||||
shuffle(latents)
|
||||
latents = torch.cat(latents, dim=0).to('cuda')
|
||||
cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=4, distance="euclidean", device=torch.device('cuda:0'))
|
||||
torch.save((cluster_ids_x, cluster_centers), '../k_means_imagenet.pth')
|
||||
cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=8, distance="euclidean", device=torch.device('cuda:0'), tol=0, iter_limit=5000, gravity_limit_per_iter=1000)
|
||||
torch.save((cluster_ids_x, cluster_centers), f'../{basename}_k_means_centroids.pth')
|
||||
|
||||
|
||||
def use_kmeans():
|
||||
_, centers = torch.load('../k_means_imagenet.pth')
|
||||
def use_kmeans(basename):
|
||||
output_path = f'../results/{basename}_kmeans_viz'
|
||||
_, centers = torch.load(f'../{basename}_k_means_centroids.pth')
|
||||
centers = centers.to('cuda')
|
||||
batch_size = 8
|
||||
num_workers = 0
|
||||
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=256)
|
||||
colormap = cm.get_cmap('viridis', 8)
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
for i, batch in enumerate(tqdm(dataloader)):
|
||||
hq = batch['hq'].to('cuda')
|
||||
l = model(hq)
|
||||
|
@ -122,13 +128,16 @@ def use_kmeans():
|
|||
pred = kmeans_predict(l, centers)
|
||||
pred = pred.reshape(b,h,w)
|
||||
img = torch.tensor(colormap(pred[:, :, :].detach().cpu().numpy()))
|
||||
torchvision.utils.save_image(torch.nn.functional.interpolate(img.permute(0,3,1,2), scale_factor=8, mode="nearest"), f"{i}_categories.png")
|
||||
torchvision.utils.save_image(hq, f"{i}_hq.png")
|
||||
scale = hq.shape[-2] / h
|
||||
torchvision.utils.save_image(torch.nn.functional.interpolate(img.permute(0,3,1,2), scale_factor=scale, mode="nearest"),
|
||||
f"{output_path}/{i}_categories.png")
|
||||
torchvision.utils.save_image(hq, f"{output_path}/{i}_hq.png")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pretrained_path = '../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth'
|
||||
model = UResNet50(Bottleneck, [3,4,6,3], out_dim=256).to('cuda')
|
||||
pretrained_path = '../experiments/uresnet_pixpro4_imgset.pth'
|
||||
basename = 'uresnet_pixpro4'
|
||||
model = UResNet50_3(Bottleneck, [3,4,6,3], out_dim=64).to('cuda')
|
||||
sd = torch.load(pretrained_path)
|
||||
resnet_sd = {}
|
||||
for k, v in sd.items():
|
||||
|
@ -140,6 +149,6 @@ if __name__ == '__main__':
|
|||
with torch.no_grad():
|
||||
#find_similar_latents(model, 0, 8, structural_euc_dist)
|
||||
#create_latent_database(model, batch_size=32)
|
||||
#produce_latent_dict(model)
|
||||
build_kmeans()
|
||||
#use_kmeans()
|
||||
#produce_latent_dict(model, basename)
|
||||
#uild_kmeans(basename)
|
||||
use_kmeans(basename)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# From: https://github.com/subhadarship/kmeans_pytorch
|
||||
# License: https://github.com/subhadarship/kmeans_pytorch/blob/master/LICENSE
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -30,6 +31,7 @@ def kmeans(
|
|||
tol=1e-4,
|
||||
tqdm_flag=True,
|
||||
iter_limit=0,
|
||||
gravity_limit_per_iter=None,
|
||||
device=torch.device('cpu')
|
||||
):
|
||||
"""
|
||||
|
@ -83,9 +85,10 @@ def kmeans(
|
|||
|
||||
for index in range(num_clusters):
|
||||
selected = torch.nonzero(choice_cluster == index).squeeze().to(device)
|
||||
|
||||
selected = torch.index_select(X, 0, selected)
|
||||
|
||||
if gravity_limit_per_iter and len(selected) > gravity_limit_per_iter:
|
||||
ch = random.randint(0, len(selected)-gravity_limit_per_iter)
|
||||
selected=selected[ch:ch+gravity_limit_per_iter]
|
||||
initial_state[index] = selected.mean(dim=0)
|
||||
|
||||
center_shift = torch.sum(
|
||||
|
@ -97,14 +100,16 @@ def kmeans(
|
|||
iteration = iteration + 1
|
||||
|
||||
# update tqdm meter
|
||||
bins = torch.bincount(choice_cluster)
|
||||
if tqdm_flag:
|
||||
tqdm_meter.set_postfix(
|
||||
iteration=f'{iteration}',
|
||||
center_shift=f'{center_shift ** 2:0.6f}',
|
||||
tol=f'{tol:0.6f}'
|
||||
center_shift=f'{center_shift ** 2}',
|
||||
tol=f'{tol}',
|
||||
bins=f'{bins}',
|
||||
)
|
||||
tqdm_meter.update()
|
||||
if center_shift ** 2 < tol:
|
||||
if tol > 0 and center_shift ** 2 < tol:
|
||||
break
|
||||
if iter_limit != 0 and iteration >= iter_limit:
|
||||
break
|
||||
|
|
Loading…
Reference in New Issue
Block a user