Mods to byol_resnet_playground for large batches
This commit is contained in:
parent
e992e18767
commit
aae65e6ed8
|
@ -26,7 +26,7 @@ from utils.options import dict_to_nonedict
|
||||||
def structural_euc_dist(x, y):
|
def structural_euc_dist(x, y):
|
||||||
diff = torch.square(x - y)
|
diff = torch.square(x - y)
|
||||||
sum = torch.sum(diff, dim=-1)
|
sum = torch.sum(diff, dim=-1)
|
||||||
return torch.mean(torch.sqrt(sum))
|
return torch.sqrt(sum)
|
||||||
|
|
||||||
|
|
||||||
def cosine_similarity(x, y):
|
def cosine_similarity(x, y):
|
||||||
|
@ -87,43 +87,15 @@ def register_hook(net, layer_name):
|
||||||
layer.register_forward_hook(_hook)
|
layer.register_forward_hook(_hook)
|
||||||
|
|
||||||
|
|
||||||
def create_latent_database(model, model_index=0):
|
|
||||||
batch_size = 32
|
|
||||||
num_workers = 1
|
|
||||||
output_path = '../../results/byol_resnet_latents/'
|
|
||||||
|
|
||||||
os.makedirs(output_path, exist_ok=True)
|
|
||||||
dataloader = get_image_folder_dataloader(batch_size, num_workers)
|
|
||||||
id = 0
|
|
||||||
dict_count = 1
|
|
||||||
latent_dict = {}
|
|
||||||
all_paths = []
|
|
||||||
for batch in tqdm(dataloader):
|
|
||||||
hq = batch['hq'].to('cuda')
|
|
||||||
latent = model(hq)[model_index] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
|
|
||||||
for b in range(latent.shape[0]):
|
|
||||||
im_path = batch['HQ_path'][b]
|
|
||||||
all_paths.append(im_path)
|
|
||||||
latent_dict[id] = latent[b].detach().cpu()
|
|
||||||
if (id+1) % 1000 == 0:
|
|
||||||
print("Saving checkpoint..")
|
|
||||||
torch.save(latent_dict, os.path.join(output_path, "latent_dict_%i.pth" % (dict_count,)))
|
|
||||||
latent_dict = {}
|
|
||||||
torch.save(all_paths, os.path.join(output_path, "all_paths.pth"))
|
|
||||||
dict_count += 1
|
|
||||||
id += 1
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_latent_for_img(model, img):
|
def get_latent_for_img(model, img):
|
||||||
img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0)
|
img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0)
|
||||||
_, _, h, w = img_t.shape
|
_, _, h, w = img_t.shape
|
||||||
# Center crop img_t and resize to 224.
|
# Center crop img_t and resize to 224.
|
||||||
d = min(h, w)
|
d = min(h, w)
|
||||||
dh, dw = (h-d)//2, (w-d)//2
|
dh, dw = (h-d)//2, (w-d)//2
|
||||||
if dh == 0:
|
if dw != 0:
|
||||||
img_t = img_t[:, :, :, dw:-dw]
|
img_t = img_t[:, :, :, dw:-dw]
|
||||||
else:
|
elif dh != 0:
|
||||||
img_t = img_t[:, :, dh:-dh, :]
|
img_t = img_t[:, :, dh:-dh, :]
|
||||||
img_t = torch.nn.functional.interpolate(img_t, size=(224, 224), mode="area")
|
img_t = torch.nn.functional.interpolate(img_t, size=(224, 224), mode="area")
|
||||||
model(img_t)
|
model(img_t)
|
||||||
|
@ -134,36 +106,42 @@ def get_latent_for_img(model, img):
|
||||||
def find_similar_latents(model, compare_fn=structural_euc_dist):
|
def find_similar_latents(model, compare_fn=structural_euc_dist):
|
||||||
global layer_hooked_value
|
global layer_hooked_value
|
||||||
|
|
||||||
img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\yui_xx.jpg'
|
img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\poon.jpg'
|
||||||
#img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
|
#img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
|
||||||
output_path = '../../results/byol_resnet_similars'
|
output_path = '../../results/byol_resnet_similars'
|
||||||
os.makedirs(output_path, exist_ok=True)
|
os.makedirs(output_path, exist_ok=True)
|
||||||
imglatent = get_latent_for_img(model, img)
|
imglatent = get_latent_for_img(model, img).squeeze().unsqueeze(0)
|
||||||
_, c, h, w = imglatent.shape
|
_, c = imglatent.shape
|
||||||
|
|
||||||
batch_size = 32
|
batch_size = 128
|
||||||
num_workers = 1
|
num_workers = 8
|
||||||
dataloader = get_image_folder_dataloader(batch_size, num_workers)
|
dataloader = get_image_folder_dataloader(batch_size, num_workers)
|
||||||
id = 0
|
id = 0
|
||||||
|
output_batch = 1
|
||||||
results = []
|
results = []
|
||||||
|
result_paths = []
|
||||||
for batch in tqdm(dataloader):
|
for batch in tqdm(dataloader):
|
||||||
hq = batch['hq'].to('cuda')
|
hq = batch['hq'].to('cuda')
|
||||||
model(hq)
|
model(hq)
|
||||||
latent = layer_hooked_value
|
latent = layer_hooked_value.clone().squeeze()
|
||||||
for b in range(latent.shape[0]):
|
compared = compare_fn(imglatent.repeat(latent.shape[0], 1), latent)
|
||||||
im_path = batch['HQ_path'][b]
|
results.append(compared.cpu())
|
||||||
results.append((im_path, compare_fn(imglatent, latent[b].unsqueeze(0)).item()))
|
result_paths.extend(batch['HQ_path'])
|
||||||
id += 1
|
id += batch_size
|
||||||
if id > 2000:
|
if id > 10000:
|
||||||
break
|
k = 500
|
||||||
results.sort(key=lambda x: x[1])
|
results = torch.cat(results, dim=0)
|
||||||
for i in range(50):
|
vals, inds = torch.topk(results, k, largest=False)
|
||||||
mag = results[i][1]
|
for i in inds:
|
||||||
shutil.copy(results[i][0], os.path.join(output_path, f'{i}_{mag}.jpg'))
|
mag = int(results[i].item() * 1000)
|
||||||
|
shutil.copy(result_paths[i], os.path.join(output_path, f'{mag:05}_{output_batch}_{i}.jpg'))
|
||||||
|
results = []
|
||||||
|
result_paths = []
|
||||||
|
id = 0
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pretrained_path = '../../experiments/resnet_byol_diffframe_69k.pth'
|
pretrained_path = '../../experiments/resnet_byol_diffframe_85k.pth'
|
||||||
model = resnet50(pretrained=False).to('cuda')
|
model = resnet50(pretrained=False).to('cuda')
|
||||||
sd = torch.load(pretrained_path)
|
sd = torch.load(pretrained_path)
|
||||||
resnet_sd = {}
|
resnet_sd = {}
|
||||||
|
|
|
@ -19,9 +19,9 @@ def main():
|
||||||
# compression time. If read raw images during training, use 0 for faster IO speed.
|
# compression time. If read raw images during training, use 0 for faster IO speed.
|
||||||
|
|
||||||
opt['dest'] = 'file'
|
opt['dest'] = 'file'
|
||||||
opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imgset4']
|
opt['input_folder'] = ['F:\\4k6k\\datasets\\ns_images\\vixen\\vix_cropped']
|
||||||
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\256_unsupervised_new'
|
opt['save_folder'] = 'F:\\4k6k\\datasets\\ns_images\\video_512_cropped'
|
||||||
opt['imgsize'] = 256
|
opt['imgsize'] = 512
|
||||||
#opt['bottom_crop'] = 120
|
#opt['bottom_crop'] = 120
|
||||||
|
|
||||||
save_folder = opt['save_folder']
|
save_folder = opt['save_folder']
|
||||||
|
@ -45,7 +45,7 @@ class TiledDataset(data.Dataset):
|
||||||
def get(self, index):
|
def get(self, index):
|
||||||
path = self.images[index]
|
path = self.images[index]
|
||||||
basename = osp.basename(path)
|
basename = osp.basename(path)
|
||||||
img = data_util.read_img(None, path)
|
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
# Greyscale not supported.
|
# Greyscale not supported.
|
||||||
if img is None:
|
if img is None:
|
||||||
|
@ -62,7 +62,7 @@ class TiledDataset(data.Dataset):
|
||||||
|
|
||||||
h, w, c = img.shape
|
h, w, c = img.shape
|
||||||
# Uncomment to filter any image that doesnt meet a threshold size.
|
# Uncomment to filter any image that doesnt meet a threshold size.
|
||||||
if min(h,w) < 256:
|
if min(h,w) < 512:
|
||||||
print("Skipping due to threshold")
|
print("Skipping due to threshold")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -71,7 +71,6 @@ class TiledDataset(data.Dataset):
|
||||||
# Crop the image so that only the center is left, since this is often the most salient part of the image.
|
# Crop the image so that only the center is left, since this is often the most salient part of the image.
|
||||||
img = img[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
|
img = img[(h - dim) // 2:dim + (h - dim) // 2, (w - dim) // 2:dim + (w - dim) // 2, :]
|
||||||
img = cv2.resize(img, (self.opt['imgsize'], self.opt['imgsize']), interpolation=cv2.INTER_AREA)
|
img = cv2.resize(img, (self.opt['imgsize'], self.opt['imgsize']), interpolation=cv2.INTER_AREA)
|
||||||
|
|
||||||
cv2.imwrite(osp.join(self.opt['save_folder'], basename + ".jpg"), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
|
cv2.imwrite(osp.join(self.opt['save_folder'], basename + ".jpg"), img, [cv2.IMWRITE_JPEG_QUALITY, self.opt['compression_level']])
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user