byol playground updates

This commit is contained in:
James Betker 2021-01-09 20:54:21 -07:00
parent 7c6c7a8014
commit 14a868e8e6
2 changed files with 14 additions and 12 deletions

View File

@ -52,10 +52,11 @@ def im_norm(x):
return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
def get_image_folder_dataloader(batch_size, num_workers, target_size=224):
def get_image_folder_dataloader(batch_size, num_workers, target_size=224, shuffle=True):
dataset_opt = dict_to_nonedict({
'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped2'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
'weights': [1],
@ -64,7 +65,7 @@ def get_image_folder_dataloader(batch_size, num_workers, target_size=224):
'scale': 1
})
dataset = ImageFolderDataset(dataset_opt)
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
def _find_layer(net, layer_name):
@ -119,7 +120,7 @@ def produce_latent_dict(model):
latents.extend(l)
paths.extend(batch['HQ_path'])
id += batch_size
if id > 1000:
if id > 10000:
print("Saving checkpoint..")
torch.save((latents, paths), '../results_instance_resnet.pth')
id = 0
@ -172,18 +173,19 @@ def build_kmeans():
def use_kmeans():
output = "../results/k_means_instance_resnet/"
_, centers = torch.load('../k_means_instance_resnet.pth')
batch_size = 8
num_workers = 0
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=224)
batch_size = 32
num_workers = 1
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=224, shuffle=False)
for i, batch in enumerate(tqdm(dataloader)):
hq = batch['hq'].to('cuda')
model(hq)
l = layer_hooked_value.clone().squeeze()
pred = kmeans_predict(l, centers, device=l.device)
for b in range(pred.shape[0]):
cat = str(pred[b].item())
os.makedirs(os.path.join(output, cat), exist_ok=True)
torchvision.utils.save_image(hq[b], os.path.join(output, cat, f'{i}.png'))
if pred[b] == 3:
outpath = os.path.dirname(batch['HQ_path'][b]).replace('\\pn_coven\\cropped', '\\pn_coven\\modeling')
os.makedirs(outpath, exist_ok=True)
shutil.move(batch['HQ_path'][b], outpath)
if __name__ == '__main__':

View File

@ -106,7 +106,7 @@ def build_kmeans():
def use_kmeans():
_, centers = torch.load('../k_means.pth')
_, centers = torch.load('../experiments/k_means_uresnet_512.pth')
batch_size = 8
num_workers = 0
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=512)
@ -125,7 +125,7 @@ def use_kmeans():
if __name__ == '__main__':
pretrained_path = '../experiments/uresnet_pixpro_attempt2.pth'
pretrained_path = '../experiments/uresnet_pixpro_512.pth'
model = UResNet50(Bottleneck, [3,4,6,3], out_dim=512).to('cuda')
sd = torch.load(pretrained_path)
resnet_sd = {}