forked from mrq/DL-Art-School
byol playground updates
This commit is contained in:
parent
7c6c7a8014
commit
14a868e8e6
|
@ -52,10 +52,11 @@ def im_norm(x):
|
||||||
return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
|
return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
|
||||||
|
|
||||||
|
|
||||||
def get_image_folder_dataloader(batch_size, num_workers, target_size=224):
|
def get_image_folder_dataloader(batch_size, num_workers, target_size=224, shuffle=True):
|
||||||
dataset_opt = dict_to_nonedict({
|
dataset_opt = dict_to_nonedict({
|
||||||
'name': 'amalgam',
|
'name': 'amalgam',
|
||||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped2'],
|
||||||
|
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
||||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
|
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
|
||||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
|
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
|
||||||
'weights': [1],
|
'weights': [1],
|
||||||
|
@ -64,7 +65,7 @@ def get_image_folder_dataloader(batch_size, num_workers, target_size=224):
|
||||||
'scale': 1
|
'scale': 1
|
||||||
})
|
})
|
||||||
dataset = ImageFolderDataset(dataset_opt)
|
dataset = ImageFolderDataset(dataset_opt)
|
||||||
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
|
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
|
||||||
|
|
||||||
|
|
||||||
def _find_layer(net, layer_name):
|
def _find_layer(net, layer_name):
|
||||||
|
@ -119,7 +120,7 @@ def produce_latent_dict(model):
|
||||||
latents.extend(l)
|
latents.extend(l)
|
||||||
paths.extend(batch['HQ_path'])
|
paths.extend(batch['HQ_path'])
|
||||||
id += batch_size
|
id += batch_size
|
||||||
if id > 1000:
|
if id > 10000:
|
||||||
print("Saving checkpoint..")
|
print("Saving checkpoint..")
|
||||||
torch.save((latents, paths), '../results_instance_resnet.pth')
|
torch.save((latents, paths), '../results_instance_resnet.pth')
|
||||||
id = 0
|
id = 0
|
||||||
|
@ -172,18 +173,19 @@ def build_kmeans():
|
||||||
def use_kmeans():
|
def use_kmeans():
|
||||||
output = "../results/k_means_instance_resnet/"
|
output = "../results/k_means_instance_resnet/"
|
||||||
_, centers = torch.load('../k_means_instance_resnet.pth')
|
_, centers = torch.load('../k_means_instance_resnet.pth')
|
||||||
batch_size = 8
|
batch_size = 32
|
||||||
num_workers = 0
|
num_workers = 1
|
||||||
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=224)
|
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=224, shuffle=False)
|
||||||
for i, batch in enumerate(tqdm(dataloader)):
|
for i, batch in enumerate(tqdm(dataloader)):
|
||||||
hq = batch['hq'].to('cuda')
|
hq = batch['hq'].to('cuda')
|
||||||
model(hq)
|
model(hq)
|
||||||
l = layer_hooked_value.clone().squeeze()
|
l = layer_hooked_value.clone().squeeze()
|
||||||
pred = kmeans_predict(l, centers, device=l.device)
|
pred = kmeans_predict(l, centers, device=l.device)
|
||||||
for b in range(pred.shape[0]):
|
for b in range(pred.shape[0]):
|
||||||
cat = str(pred[b].item())
|
if pred[b] == 3:
|
||||||
os.makedirs(os.path.join(output, cat), exist_ok=True)
|
outpath = os.path.dirname(batch['HQ_path'][b]).replace('\\pn_coven\\cropped', '\\pn_coven\\modeling')
|
||||||
torchvision.utils.save_image(hq[b], os.path.join(output, cat, f'{i}.png'))
|
os.makedirs(outpath, exist_ok=True)
|
||||||
|
shutil.move(batch['HQ_path'][b], outpath)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -106,7 +106,7 @@ def build_kmeans():
|
||||||
|
|
||||||
|
|
||||||
def use_kmeans():
|
def use_kmeans():
|
||||||
_, centers = torch.load('../k_means.pth')
|
_, centers = torch.load('../experiments/k_means_uresnet_512.pth')
|
||||||
batch_size = 8
|
batch_size = 8
|
||||||
num_workers = 0
|
num_workers = 0
|
||||||
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=512)
|
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=512)
|
||||||
|
@ -125,7 +125,7 @@ def use_kmeans():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pretrained_path = '../experiments/uresnet_pixpro_attempt2.pth'
|
pretrained_path = '../experiments/uresnet_pixpro_512.pth'
|
||||||
model = UResNet50(Bottleneck, [3,4,6,3], out_dim=512).to('cuda')
|
model = UResNet50(Bottleneck, [3,4,6,3], out_dim=512).to('cuda')
|
||||||
sd = torch.load(pretrained_path)
|
sd = torch.load(pretrained_path)
|
||||||
resnet_sd = {}
|
resnet_sd = {}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user