This commit is contained in:
James Betker 2020-12-08 00:33:07 -07:00
parent bca59ed98a
commit 5369cba8ed

View File

@ -64,7 +64,7 @@ def create_latent_database(model):
latent_dict = {} latent_dict = {}
all_paths = [] all_paths = []
for batch in tqdm(dataloader): for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda:1') hq = batch['hq'].to('cuda')
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing. latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
for b in range(latent.shape[0]): for b in range(latent.shape[0]):
im_path = batch['HQ_path'][b] im_path = batch['HQ_path'][b]
@ -93,7 +93,7 @@ def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_si
mins = [] mins = []
min_offsets = [] min_offsets = []
for cpbl_chunk in tqdm(cpbl_chunked): for cpbl_chunk in tqdm(cpbl_chunked):
cpbl_chunk = cpbl_chunk.to('cuda:1') cpbl_chunk = cpbl_chunk.to('cuda')
dist = structural_euc_dist(clat, cpbl_chunk.unsqueeze(0)) dist = structural_euc_dist(clat, cpbl_chunk.unsqueeze(0))
_min = torch.min(dist, dim=-1) _min = torch.min(dist, dim=-1)
mins.append(_min[0]) mins.append(_min[0])
@ -116,7 +116,7 @@ def find_similar_latents(model):
os.makedirs(output_path, exist_ok=True) os.makedirs(output_path, exist_ok=True)
img_bank_paths = torch.load(os.path.join(hq_img_repo, "all_paths.pth")) img_bank_paths = torch.load(os.path.join(hq_img_repo, "all_paths.pth"))
img_t = ToTensor()(Image.open(img)).to('cuda:1').unsqueeze(0) img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0)
_, _, h, w = img_t.shape _, _, h, w = img_t.shape
img_t = img_t[:, :, :128*(h//128), :128*(w//128)] img_t = img_t[:, :, :128*(h//128), :128*(w//128)]
@ -182,7 +182,7 @@ def explore_latent_results(model):
dataloader = get_image_folder_dataloader(batch_size, num_workers) dataloader = get_image_folder_dataloader(batch_size, num_workers)
id = 0 id = 0
for batch in tqdm(dataloader): for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda:1') hq = batch['hq'].to('cuda')
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing. latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
# This operation works by computing the distance of every structural index from the center and using that # This operation works by computing the distance of every structural index from the center and using that
# as a "heatmap". # as a "heatmap".
@ -198,7 +198,7 @@ def explore_latent_results(model):
if __name__ == '__main__': if __name__ == '__main__':
pretrained_path = '../../experiments/spinenet49_imgset_byol.pth' pretrained_path = '../../experiments/spinenet49_imgset_byol.pth'
model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda:1') model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda')
model.load_state_dict(torch.load(pretrained_path), strict=True) model.load_state_dict(torch.load(pretrained_path), strict=True)
model.eval() model.eval()