forked from mrq/DL-Art-School
byol playground updates
This commit is contained in:
parent
7c6c7a8014
commit
14a868e8e6
|
@ -52,10 +52,11 @@ def im_norm(x):
|
|||
return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5
|
||||
|
||||
|
||||
def get_image_folder_dataloader(batch_size, num_workers, target_size=224):
|
||||
def get_image_folder_dataloader(batch_size, num_workers, target_size=224, shuffle=True):
|
||||
dataset_opt = dict_to_nonedict({
|
||||
'name': 'amalgam',
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped2'],
|
||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
|
||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
|
||||
'weights': [1],
|
||||
|
@ -64,7 +65,7 @@ def get_image_folder_dataloader(batch_size, num_workers, target_size=224):
|
|||
'scale': 1
|
||||
})
|
||||
dataset = ImageFolderDataset(dataset_opt)
|
||||
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
|
||||
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)
|
||||
|
||||
|
||||
def _find_layer(net, layer_name):
|
||||
|
@ -119,7 +120,7 @@ def produce_latent_dict(model):
|
|||
latents.extend(l)
|
||||
paths.extend(batch['HQ_path'])
|
||||
id += batch_size
|
||||
if id > 1000:
|
||||
if id > 10000:
|
||||
print("Saving checkpoint..")
|
||||
torch.save((latents, paths), '../results_instance_resnet.pth')
|
||||
id = 0
|
||||
|
@ -172,18 +173,19 @@ def build_kmeans():
|
|||
def use_kmeans():
|
||||
output = "../results/k_means_instance_resnet/"
|
||||
_, centers = torch.load('../k_means_instance_resnet.pth')
|
||||
batch_size = 8
|
||||
num_workers = 0
|
||||
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=224)
|
||||
batch_size = 32
|
||||
num_workers = 1
|
||||
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=224, shuffle=False)
|
||||
for i, batch in enumerate(tqdm(dataloader)):
|
||||
hq = batch['hq'].to('cuda')
|
||||
model(hq)
|
||||
l = layer_hooked_value.clone().squeeze()
|
||||
pred = kmeans_predict(l, centers, device=l.device)
|
||||
for b in range(pred.shape[0]):
|
||||
cat = str(pred[b].item())
|
||||
os.makedirs(os.path.join(output, cat), exist_ok=True)
|
||||
torchvision.utils.save_image(hq[b], os.path.join(output, cat, f'{i}.png'))
|
||||
if pred[b] == 3:
|
||||
outpath = os.path.dirname(batch['HQ_path'][b]).replace('\\pn_coven\\cropped', '\\pn_coven\\modeling')
|
||||
os.makedirs(outpath, exist_ok=True)
|
||||
shutil.move(batch['HQ_path'][b], outpath)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -106,7 +106,7 @@ def build_kmeans():
|
|||
|
||||
|
||||
def use_kmeans():
|
||||
_, centers = torch.load('../k_means.pth')
|
||||
_, centers = torch.load('../experiments/k_means_uresnet_512.pth')
|
||||
batch_size = 8
|
||||
num_workers = 0
|
||||
dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=512)
|
||||
|
@ -125,7 +125,7 @@ def use_kmeans():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pretrained_path = '../experiments/uresnet_pixpro_attempt2.pth'
|
||||
pretrained_path = '../experiments/uresnet_pixpro_512.pth'
|
||||
model = UResNet50(Bottleneck, [3,4,6,3], out_dim=512).to('cuda')
|
||||
sd = torch.load(pretrained_path)
|
||||
resnet_sd = {}
|
||||
|
|
Loading…
Reference in New Issue
Block a user