forked from mrq/DL-Art-School
Stage
This commit is contained in:
parent
bca59ed98a
commit
5369cba8ed
|
@ -64,7 +64,7 @@ def create_latent_database(model):
|
|||
latent_dict = {}
|
||||
all_paths = []
|
||||
for batch in tqdm(dataloader):
|
||||
hq = batch['hq'].to('cuda:1')
|
||||
hq = batch['hq'].to('cuda')
|
||||
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
|
||||
for b in range(latent.shape[0]):
|
||||
im_path = batch['HQ_path'][b]
|
||||
|
@ -93,7 +93,7 @@ def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_si
|
|||
mins = []
|
||||
min_offsets = []
|
||||
for cpbl_chunk in tqdm(cpbl_chunked):
|
||||
cpbl_chunk = cpbl_chunk.to('cuda:1')
|
||||
cpbl_chunk = cpbl_chunk.to('cuda')
|
||||
dist = structural_euc_dist(clat, cpbl_chunk.unsqueeze(0))
|
||||
_min = torch.min(dist, dim=-1)
|
||||
mins.append(_min[0])
|
||||
|
@ -116,7 +116,7 @@ def find_similar_latents(model):
|
|||
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
img_bank_paths = torch.load(os.path.join(hq_img_repo, "all_paths.pth"))
|
||||
img_t = ToTensor()(Image.open(img)).to('cuda:1').unsqueeze(0)
|
||||
img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0)
|
||||
_, _, h, w = img_t.shape
|
||||
img_t = img_t[:, :, :128*(h//128), :128*(w//128)]
|
||||
|
||||
|
@ -182,7 +182,7 @@ def explore_latent_results(model):
|
|||
dataloader = get_image_folder_dataloader(batch_size, num_workers)
|
||||
id = 0
|
||||
for batch in tqdm(dataloader):
|
||||
hq = batch['hq'].to('cuda:1')
|
||||
hq = batch['hq'].to('cuda')
|
||||
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
|
||||
# This operation works by computing the distance of every structural index from the center and using that
|
||||
# as a "heatmap".
|
||||
|
@ -198,7 +198,7 @@ def explore_latent_results(model):
|
|||
if __name__ == '__main__':
|
||||
pretrained_path = '../../experiments/spinenet49_imgset_byol.pth'
|
||||
|
||||
model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda:1')
|
||||
model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda')
|
||||
model.load_state_dict(torch.load(pretrained_path), strict=True)
|
||||
model.eval()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user