byol playground updates

This commit is contained in:
James Betker 2021-01-09 20:54:21 -07:00
parent 7c6c7a8014
commit 14a868e8e6
2 changed files with 14 additions and 12 deletions

View File

@ -52,10 +52,11 @@ def im_norm(x):
return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5 return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
def get_image_folder_dataloader(batch_size, num_workers, target_size=224): def get_image_folder_dataloader(batch_size, num_workers, target_size=224, shuffle=True):
dataset_opt = dict_to_nonedict({ dataset_opt = dict_to_nonedict({
'name': 'amalgam', 'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'], 'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped2'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'], #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'], #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
'weights': [1], 'weights': [1],
@ -64,7 +65,7 @@ def get_image_folder_dataloader(batch_size, num_workers, target_size=224):
'scale': 1 'scale': 1
}) })
dataset = ImageFolderDataset(dataset_opt) dataset = ImageFolderDataset(dataset_opt)
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True) return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
def _find_layer(net, layer_name): def _find_layer(net, layer_name):
@ -119,7 +120,7 @@ def produce_latent_dict(model):
latents.extend(l) latents.extend(l)
paths.extend(batch['HQ_path']) paths.extend(batch['HQ_path'])
id += batch_size id += batch_size
if id > 1000: if id > 10000:
print("Saving checkpoint..") print("Saving checkpoint..")
torch.save((latents, paths), '../results_instance_resnet.pth') torch.save((latents, paths), '../results_instance_resnet.pth')
id = 0 id = 0
@ -172,18 +173,19 @@ def build_kmeans():
def use_kmeans(): def use_kmeans():
output = "../results/k_means_instance_resnet/" output = "../results/k_means_instance_resnet/"
_, centers = torch.load('../k_means_instance_resnet.pth') _, centers = torch.load('../k_means_instance_resnet.pth')
batch_size = 8 batch_size = 32
num_workers = 0 num_workers = 1
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=224) dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=224, shuffle=False)
for i, batch in enumerate(tqdm(dataloader)): for i, batch in enumerate(tqdm(dataloader)):
hq = batch['hq'].to('cuda') hq = batch['hq'].to('cuda')
model(hq) model(hq)
l = layer_hooked_value.clone().squeeze() l = layer_hooked_value.clone().squeeze()
pred = kmeans_predict(l, centers, device=l.device) pred = kmeans_predict(l, centers, device=l.device)
for b in range(pred.shape[0]): for b in range(pred.shape[0]):
cat = str(pred[b].item()) if pred[b] == 3:
os.makedirs(os.path.join(output, cat), exist_ok=True) outpath = os.path.dirname(batch['HQ_path'][b]).replace('\\pn_coven\\cropped', '\\pn_coven\\modeling')
torchvision.utils.save_image(hq[b], os.path.join(output, cat, f'{i}.png')) os.makedirs(outpath, exist_ok=True)
shutil.move(batch['HQ_path'][b], outpath)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -106,7 +106,7 @@ def build_kmeans():
def use_kmeans(): def use_kmeans():
_, centers = torch.load('../k_means.pth') _, centers = torch.load('../experiments/k_means_uresnet_512.pth')
batch_size = 8 batch_size = 8
num_workers = 0 num_workers = 0
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=512) dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=512)
@ -125,7 +125,7 @@ def use_kmeans():
if __name__ == '__main__': if __name__ == '__main__':
pretrained_path = '../experiments/uresnet_pixpro_attempt2.pth' pretrained_path = '../experiments/uresnet_pixpro_512.pth'
model = UResNet50(Bottleneck, [3,4,6,3], out_dim=512).to('cuda') model = UResNet50(Bottleneck, [3,4,6,3], out_dim=512).to('cuda')
sd = torch.load(pretrained_path) sd = torch.load(pretrained_path)
resnet_sd = {} resnet_sd = {}