forked from mrq/DL-Art-School

test uresnet playground mods

parent 1b8a26db93
commit dac7d768fa
@@ -1,5 +1,6 @@
 import os
 import shutil
+from random import shuffle
 
 import matplotlib.cm as cm
 import torch
@@ -16,6 +17,8 @@ import numpy as np
 import utils
 from data.image_folder_dataset import ImageFolderDataset
 from models.pixel_level_contrastive_learning.resnet_unet import UResNet50
+from models.pixel_level_contrastive_learning.resnet_unet_2 import UResNet50_2
+from models.pixel_level_contrastive_learning.resnet_unet_3 import UResNet50_3
 from models.resnet_with_checkpointing import resnet50
 from models.spinenet_arch import SpineNet
 
@@ -59,9 +62,9 @@ def im_norm(x):
 def get_image_folder_dataloader(batch_size, num_workers, target_size=256):
     dataset_opt = dict_to_nonedict({
         'name': 'amalgam',
-        'paths': ['F:\\4k6k\\datasets\\images\\imagenet_2017\\train'],
+        #'paths': ['F:\\4k6k\\datasets\\images\\imagenet_2017\\train'],
         #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
-        #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
+        'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
         #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
         'weights': [1],
         'target_size': target_size,
@@ -72,8 +75,8 @@ def get_image_folder_dataloader(batch_size, num_workers, target_size=256):
     return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
 
 
-def produce_latent_dict(model):
-    batch_size = 32
+def produce_latent_dict(model, basename):
+    batch_size = 64
     num_workers = 4
     dataloader = get_image_folder_dataloader(batch_size, num_workers)
     id = 0
@@ -89,30 +92,33 @@ def produce_latent_dict(model):
         # extract a random set of 10 latents from each image
         if prob is None:
             prob = torch.full((dim,), 1/(dim))
-        l = l[prob.multinomial(num_samples=10, replacement=False)].split(1, dim=0)
+        l = l[prob.multinomial(num_samples=100, replacement=False)].split(1, dim=0)
         latents.extend(l)
         paths.extend(batch['HQ_path'])
         id += batch_size
-        if id > 1000:
+        if id > 5000:
             print("Saving checkpoint..")
-            torch.save((latents, paths), '../imagenet_latent_dict.pth')
+            torch.save((latents, paths), f'../{basename}_latent_dict.pth')
             id = 0
 
 
-def build_kmeans():
-    latents, _ = torch.load('../imagenet_latent_dict.pth')
+def build_kmeans(basename):
+    latents, _ = torch.load(f'../{basename}_latent_dict.pth')
+    shuffle(latents)
     latents = torch.cat(latents, dim=0).to('cuda')
-    cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=4, distance="euclidean", device=torch.device('cuda:0'))
-    torch.save((cluster_ids_x, cluster_centers), '../k_means_imagenet.pth')
+    cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=8, distance="euclidean", device=torch.device('cuda:0'), tol=0, iter_limit=5000, gravity_limit_per_iter=1000)
+    torch.save((cluster_ids_x, cluster_centers), f'../{basename}_k_means_centroids.pth')
 
 
-def use_kmeans():
-    _, centers = torch.load('../k_means_imagenet.pth')
+def use_kmeans(basename):
+    output_path = f'../results/{basename}_kmeans_viz'
+    _, centers = torch.load(f'../{basename}_k_means_centroids.pth')
     centers = centers.to('cuda')
     batch_size = 8
     num_workers = 0
     dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=256)
     colormap = cm.get_cmap('viridis', 8)
+    os.makedirs(output_path, exist_ok=True)
     for i, batch in enumerate(tqdm(dataloader)):
         hq = batch['hq'].to('cuda')
         l = model(hq)
@@ -122,13 +128,16 @@ def use_kmeans():
         pred = kmeans_predict(l, centers)
         pred = pred.reshape(b,h,w)
         img = torch.tensor(colormap(pred[:, :, :].detach().cpu().numpy()))
-        torchvision.utils.save_image(torch.nn.functional.interpolate(img.permute(0,3,1,2), scale_factor=8, mode="nearest"), f"{i}_categories.png")
-        torchvision.utils.save_image(hq, f"{i}_hq.png")
+        scale = hq.shape[-2] / h
+        torchvision.utils.save_image(torch.nn.functional.interpolate(img.permute(0,3,1,2), scale_factor=scale, mode="nearest"),
+                                     f"{output_path}/{i}_categories.png")
+        torchvision.utils.save_image(hq, f"{output_path}/{i}_hq.png")
 
 
 if __name__ == '__main__':
-    pretrained_path = '../experiments/train_imagenet_pixpro_resnet/models/66500_generator.pth'
-    model = UResNet50(Bottleneck, [3,4,6,3], out_dim=256).to('cuda')
+    pretrained_path = '../experiments/uresnet_pixpro4_imgset.pth'
+    basename = 'uresnet_pixpro4'
+    model = UResNet50_3(Bottleneck, [3,4,6,3], out_dim=64).to('cuda')
     sd = torch.load(pretrained_path)
     resnet_sd = {}
     for k, v in sd.items():
@@ -140,6 +149,6 @@ if __name__ == '__main__':
     with torch.no_grad():
         #find_similar_latents(model, 0, 8, structural_euc_dist)
         #create_latent_database(model, batch_size=32)
-        #produce_latent_dict(model)
-        build_kmeans()
-        #use_kmeans()
+        #produce_latent_dict(model, basename)
+        #build_kmeans(basename)
+        use_kmeans(basename)
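A note on the visualization path changed above: use_kmeans() flattens the model's [b, c, h, w] latent map into per-pixel vectors, assigns each pixel to its nearest k-means centroid, then colormaps and upscales the assignment image. The flattening itself happens outside these hunks, so below is a minimal self-contained sketch of the assumed flow; visualize_clusters, the plain nearest-centroid lookup standing in for kmeans_predict, and the toy shapes are all illustrative, not code from this repo.

import torch
import torchvision
import matplotlib.cm as cm

def visualize_clusters(latent, centers, out_file, upscale=8):
    # latent: [b, c, h, w] feature map; centers: [k, c] k-means centroids
    b, c, h, w = latent.shape
    flat = latent.permute(0, 2, 3, 1).reshape(-1, c)  # one c-dim vector per pixel
    # nearest-centroid assignment; this is what kmeans_predict computes
    # in the euclidean-distance case
    pred = torch.cdist(flat, centers).argmin(dim=1).reshape(b, h, w)
    colormap = cm.get_cmap('viridis', centers.shape[0])
    img = torch.tensor(colormap(pred.cpu().numpy()))  # [b, h, w, 4] RGBA
    img = torch.nn.functional.interpolate(img.permute(0, 3, 1, 2).float(),
                                          scale_factor=upscale, mode="nearest")
    torchvision.utils.save_image(img, out_file)

# toy usage: 2 random "latents" against 8 random centroids
visualize_clusters(torch.randn(2, 64, 32, 32), torch.randn(8, 64), 'clusters.png')

The old code hardcoded scale_factor=8 for a fixed 8x input-to-latent downsampling; the commit's scale = hq.shape[-2] / h derives it from the tensors instead, which is what the upscale argument stands in for here.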
@@ -1,5 +1,6 @@
 # From: https://github.com/subhadarship/kmeans_pytorch
 # License: https://github.com/subhadarship/kmeans_pytorch/blob/master/LICENSE
+import random
 
 import numpy as np
 import torch
@@ -30,6 +31,7 @@ def kmeans(
         tol=1e-4,
         tqdm_flag=True,
         iter_limit=0,
+        gravity_limit_per_iter=None,
         device=torch.device('cpu')
 ):
     """
@@ -83,9 +85,10 @@ def kmeans(
 
         for index in range(num_clusters):
             selected = torch.nonzero(choice_cluster == index).squeeze().to(device)
 
             selected = torch.index_select(X, 0, selected)
+            if gravity_limit_per_iter and len(selected) > gravity_limit_per_iter:
+                ch = random.randint(0, len(selected)-gravity_limit_per_iter)
+                selected = selected[ch:ch+gravity_limit_per_iter]
             initial_state[index] = selected.mean(dim=0)
 
         center_shift = torch.sum(
@@ -97,14 +100,16 @@ def kmeans(
         iteration = iteration + 1
 
         # update tqdm meter
+        bins = torch.bincount(choice_cluster)
         if tqdm_flag:
             tqdm_meter.set_postfix(
                 iteration=f'{iteration}',
-                center_shift=f'{center_shift ** 2:0.6f}',
-                tol=f'{tol:0.6f}'
+                center_shift=f'{center_shift ** 2}',
+                tol=f'{tol}',
+                bins=f'{bins}',
             )
             tqdm_meter.update()
-        if center_shift ** 2 < tol:
+        if tol > 0 and center_shift ** 2 < tol:
             break
         if iter_limit != 0 and iteration >= iter_limit:
             break
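The hunks above add a gravity_limit_per_iter cap to the centroid update: when a cluster has more members than the cap, only a random contiguous slice of them pulls ("gravitates") the centroid that iteration, bounding per-iteration cost on very large clusters. A contiguous slice is far cheaper than a true random permutation, and since build_kmeans now calls shuffle(latents) before clustering, such a slice behaves roughly like a random subsample. Here is a toy, self-contained sketch of that update rule; update_centroids and the random data are invented for illustration.

import random
import torch

def update_centroids(X, choice_cluster, num_clusters, gravity_limit_per_iter=None):
    # X: [n, d] samples; choice_cluster: [n] cluster id per sample
    centroids = []
    for index in range(num_clusters):
        members = X[choice_cluster == index]
        if gravity_limit_per_iter and len(members) > gravity_limit_per_iter:
            # cap how many points pull this centroid: take a random
            # contiguous window instead of a costly random permutation
            ch = random.randint(0, len(members) - gravity_limit_per_iter)
            members = members[ch:ch + gravity_limit_per_iter]
        centroids.append(members.mean(dim=0))
    return torch.stack(centroids)

# toy usage: 10k 2-D points in 3 clusters, at most 500 points move each centroid
X = torch.randn(10000, 2)
assignments = torch.randint(0, 3, (10000,))
centers = update_centroids(X, assignments, 3, gravity_limit_per_iter=500)

Note that build_kmeans passes tol=0 together with iter_limit=5000; this relies on the other change above, where the convergence check only fires when tol > 0, so the run is bounded by iter_limit alone.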