diff --git a/codes/scripts/byol_spinenet_playground.py b/codes/scripts/byol_spinenet_playground.py index bb74a6bb..272e0c0d 100644 --- a/codes/scripts/byol_spinenet_playground.py +++ b/codes/scripts/byol_spinenet_playground.py @@ -64,7 +64,7 @@ def create_latent_database(model): latent_dict = {} all_paths = [] for batch in tqdm(dataloader): - hq = batch['hq'].to('cuda:1') + hq = batch['hq'].to('cuda') latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing. for b in range(latent.shape[0]): im_path = batch['HQ_path'][b] @@ -93,7 +93,7 @@ def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_si mins = [] min_offsets = [] for cpbl_chunk in tqdm(cpbl_chunked): - cpbl_chunk = cpbl_chunk.to('cuda:1') + cpbl_chunk = cpbl_chunk.to('cuda') dist = structural_euc_dist(clat, cpbl_chunk.unsqueeze(0)) _min = torch.min(dist, dim=-1) mins.append(_min[0]) @@ -116,7 +116,7 @@ def find_similar_latents(model): os.makedirs(output_path, exist_ok=True) img_bank_paths = torch.load(os.path.join(hq_img_repo, "all_paths.pth")) - img_t = ToTensor()(Image.open(img)).to('cuda:1').unsqueeze(0) + img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0) _, _, h, w = img_t.shape img_t = img_t[:, :, :128*(h//128), :128*(w//128)] @@ -182,7 +182,7 @@ def explore_latent_results(model): dataloader = get_image_folder_dataloader(batch_size, num_workers) id = 0 for batch in tqdm(dataloader): - hq = batch['hq'].to('cuda:1') + hq = batch['hq'].to('cuda') latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing. # This operation works by computing the distance of every structural index from the center and using that # as a "heatmap". @@ -198,7 +198,7 @@ def explore_latent_results(model): if __name__ == '__main__': pretrained_path = '../../experiments/spinenet49_imgset_byol.pth' - model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda:1') + model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda') model.load_state_dict(torch.load(pretrained_path), strict=True) model.eval()