This commit is contained in:
James Betker 2020-12-08 00:33:07 -07:00
parent bca59ed98a
commit 5369cba8ed

View File

@ -64,7 +64,7 @@ def create_latent_database(model):
latent_dict = {}
all_paths = []
for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda:1')
hq = batch['hq'].to('cuda')
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
for b in range(latent.shape[0]):
im_path = batch['HQ_path'][b]
@ -93,7 +93,7 @@ def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_si
mins = []
min_offsets = []
for cpbl_chunk in tqdm(cpbl_chunked):
cpbl_chunk = cpbl_chunk.to('cuda:1')
cpbl_chunk = cpbl_chunk.to('cuda')
dist = structural_euc_dist(clat, cpbl_chunk.unsqueeze(0))
_min = torch.min(dist, dim=-1)
mins.append(_min[0])
@ -116,7 +116,7 @@ def find_similar_latents(model):
os.makedirs(output_path, exist_ok=True)
img_bank_paths = torch.load(os.path.join(hq_img_repo, "all_paths.pth"))
img_t = ToTensor()(Image.open(img)).to('cuda:1').unsqueeze(0)
img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0)
_, _, h, w = img_t.shape
img_t = img_t[:, :, :128*(h//128), :128*(w//128)]
@ -182,7 +182,7 @@ def explore_latent_results(model):
dataloader = get_image_folder_dataloader(batch_size, num_workers)
id = 0
for batch in tqdm(dataloader):
hq = batch['hq'].to('cuda:1')
hq = batch['hq'].to('cuda')
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
# This operation works by computing the distance of every structural index from the center and using that
# as a "heatmap".
@ -198,7 +198,7 @@ def explore_latent_results(model):
if __name__ == '__main__':
pretrained_path = '../../experiments/spinenet49_imgset_byol.pth'
model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda:1')
model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda')
model.load_state_dict(torch.load(pretrained_path), strict=True)
model.eval()