forked from mrq/DL-Art-School
Structural latents checkpoint
This commit is contained in:
parent
26ceca68c0
commit
ec0ee25f4b
|
@ -4,7 +4,7 @@ from time import time
|
||||||
import torch
|
import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from kornia import augmentation as augs
|
from kornia import augmentation as augs, kornia
|
||||||
from kornia import filters
|
from kornia import filters
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
@ -98,10 +98,11 @@ class RandomSharedRegionCrop(nn.Module):
|
||||||
# 2. Pick a random width, height and top corner location for the first patch.
|
# 2. Pick a random width, height and top corner location for the first patch.
|
||||||
# 3. Pick a random width, height and top corner location for the second patch.
|
# 3. Pick a random width, height and top corner location for the second patch.
|
||||||
# Note: All dims from (2) and (3) must contain at least half of the image, guaranteeing overlap.
|
# Note: All dims from (2) and (3) must contain at least half of the image, guaranteeing overlap.
|
||||||
# 6. Build patches from input images. Resize them appropriately. Apply translational jitter.
|
# 4. Build patches from input images. Resize them appropriately. Apply translational jitter.\
|
||||||
# 7. Compute the metrics needed to extract overlapping regions from the resized patches: top, left,
|
# 5. Randomly flip image 2 if needed.
|
||||||
|
# 5. Compute the metrics needed to extract overlapping regions from the resized patches: top, left,
|
||||||
# original_height, original_width.
|
# original_height, original_width.
|
||||||
# 8. Compute the "shared_view" from the above data.
|
# 6. Compute the "shared_view" from the above data.
|
||||||
|
|
||||||
# Step 1
|
# Step 1
|
||||||
c, d, _ = i1.shape
|
c, d, _ = i1.shape
|
||||||
|
@ -122,7 +123,7 @@ class RandomSharedRegionCrop(nn.Module):
|
||||||
im2_t = random.randint(0, d-im2_h)
|
im2_t = random.randint(0, d-im2_h)
|
||||||
im2_r, im2_b = im2_l+im2_w, im2_t+im2_h
|
im2_r, im2_b = im2_l+im2_w, im2_t+im2_h
|
||||||
|
|
||||||
# Step 6
|
# Step 4
|
||||||
m = self.multiple
|
m = self.multiple
|
||||||
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
|
jl, jt = random.randint(-self.jitter_range, self.jitter_range), random.randint(-self.jitter_range, self.jitter_range)
|
||||||
jt = jt if base_t != 0 else abs(jt) # If the top of a patch is zero, a negative jitter will cause it to go negative.
|
jt = jt if base_t != 0 else abs(jt) # If the top of a patch is zero, a negative jitter will cause it to go negative.
|
||||||
|
@ -139,14 +140,22 @@ class RandomSharedRegionCrop(nn.Module):
|
||||||
p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, im2_l*m+jl:(im2_l+im2_w)*m+jl]
|
p2 = i2[:, im2_t*m+jt:(im2_t+im2_h)*m+jt, im2_l*m+jl:(im2_l+im2_w)*m+jl]
|
||||||
p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear")
|
p2_resized = no_batch_interpolate(p2, size=(d*m, d*m), mode="bilinear")
|
||||||
|
|
||||||
# Step 7
|
# Step 5
|
||||||
|
should_flip = random.random() < .5
|
||||||
|
if should_flip:
|
||||||
|
should_flip = 1
|
||||||
|
p2_resized = kornia.geometry.transform.hflip(p2_resized)
|
||||||
|
else:
|
||||||
|
should_flip = 0
|
||||||
|
|
||||||
|
# Step 6
|
||||||
i1_shared_t, i1_shared_l = snap(base_t, im2_t), snap(base_l, im2_l)
|
i1_shared_t, i1_shared_l = snap(base_t, im2_t), snap(base_l, im2_l)
|
||||||
i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l)
|
i2_shared_t, i2_shared_l = snap(im2_t, base_t), snap(im2_l, base_l)
|
||||||
ix_h = min(base_b, im2_b) - max(base_t, im2_t)
|
ix_h = min(base_b, im2_b) - max(base_t, im2_t)
|
||||||
ix_w = min(base_r, im2_r) - max(base_l, im2_l)
|
ix_w = min(base_r, im2_r) - max(base_l, im2_l)
|
||||||
recompute_package = torch.tensor([base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, ix_h, ix_w], dtype=torch.long)
|
recompute_package = torch.tensor([base_h, base_w, i1_shared_t, i1_shared_l, im2_h, im2_w, i2_shared_t, i2_shared_l, should_flip, ix_h, ix_w], dtype=torch.long)
|
||||||
|
|
||||||
# Step 8
|
# Step 7
|
||||||
mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5)
|
mask1 = torch.full((1, base_h*m, base_w*m), fill_value=.5)
|
||||||
mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1
|
mask1[:, i1_shared_t*m:(i1_shared_t+ix_h)*m, i1_shared_l*m:(i1_shared_l+ix_w)*m] = 1
|
||||||
masked1 = pad_to(p1 * mask1, d*m)
|
masked1 = pad_to(p1 * mask1, d*m)
|
||||||
|
@ -171,10 +180,14 @@ def reconstructed_shared_regions(fea1, fea2, recompute_package: torch.Tensor):
|
||||||
# It'd be real nice if we could do this at the batch level, but I don't see a really good way to do that outside
|
# It'd be real nice if we could do this at the batch level, but I don't see a really good way to do that outside
|
||||||
# of conforming the recompute_package across the entire batch.
|
# of conforming the recompute_package across the entire batch.
|
||||||
for b in range(package.shape[0]):
|
for b in range(package.shape[0]):
|
||||||
f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, s_h, s_w = tuple(package[b].tolist())
|
f1_h, f1_w, f1s_t, f1s_l, f2_h, f2_w, f2s_t, f2s_l, should_flip, s_h, s_w = tuple(package[b].tolist())
|
||||||
|
# Unflip 2 if needed.
|
||||||
|
f2 = fea2[b]
|
||||||
|
if should_flip == 1:
|
||||||
|
f2 = kornia.geometry.transform.hflip(f2)
|
||||||
# Resize the input features to match
|
# Resize the input features to match
|
||||||
f1s = F.interpolate(fea1[b].unsqueeze(0), (f1_h, f1_w), mode="bilinear")
|
f1s = F.interpolate(fea1[b].unsqueeze(0), (f1_h, f1_w), mode="bilinear")
|
||||||
f2s = F.interpolate(fea2[b].unsqueeze(0), (f2_h, f2_w), mode="bilinear")
|
f2s = F.interpolate(f2.unsqueeze(0), (f2_h, f2_w), mode="bilinear")
|
||||||
# Outputs must be padded so they can "get along" with each other.
|
# Outputs must be padded so they can "get along" with each other.
|
||||||
res1.append(pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim))
|
res1.append(pad_to(f1s[:, :, f1s_t:f1s_t+s_h, f1s_l:f1s_l+s_w], pad_dim))
|
||||||
res2.append(pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim))
|
res2.append(pad_to(f2s[:, :, f2s_t:f2s_t+s_h, f2s_l:f2s_l+s_w], pad_dim))
|
||||||
|
|
|
@ -176,3 +176,8 @@ class StructuralBYOL(nn.Module):
|
||||||
|
|
||||||
loss = loss_one + loss_two
|
loss = loss_one + loss_two
|
||||||
return loss.mean()
|
return loss.mean()
|
||||||
|
|
||||||
|
def get_projection(self, image):
|
||||||
|
enc = self.online_encoder(image)
|
||||||
|
proj = self.online_predictor(enc)
|
||||||
|
return enc, proj
|
|
@ -3,8 +3,8 @@ import torch
|
||||||
from models.archs.spinenet_arch import SpineNet
|
from models.archs.spinenet_arch import SpineNet
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pretrained_path = '../../experiments/train_byol_512unsupervised/models/117000_generator.pth'
|
pretrained_path = '../../experiments/train_sbyol_512unsupervised/models/35000_generator.pth'
|
||||||
output_path = '../../experiments/spinenet49_imgset_byol.pth'
|
output_path = '../../experiments/spinenet49_imgset_sbyol.pth'
|
||||||
|
|
||||||
wrap_key = 'online_encoder.net.'
|
wrap_key = 'online_encoder.net.'
|
||||||
sd = torch.load(pretrained_path)
|
sd = torch.load(pretrained_path)
|
||||||
|
|
|
@ -3,6 +3,7 @@ import shutil
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
import torchvision
|
import torchvision
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
@ -10,12 +11,16 @@ from torchvision.transforms import ToTensor, Resize
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
import utils
|
||||||
from data.image_folder_dataset import ImageFolderDataset
|
from data.image_folder_dataset import ImageFolderDataset
|
||||||
from models.archs.spinenet_arch import SpineNet
|
from models.archs.spinenet_arch import SpineNet
|
||||||
|
|
||||||
|
|
||||||
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
|
||||||
# and the distance is computed across the channel dimension.
|
# and the distance is computed across the channel dimension.
|
||||||
|
from utils import util
|
||||||
|
|
||||||
|
|
||||||
def structural_euc_dist(x, y):
|
def structural_euc_dist(x, y):
|
||||||
diff = torch.square(x - y)
|
diff = torch.square(x - y)
|
||||||
sum = torch.sum(diff, dim=-1)
|
sum = torch.sum(diff, dim=-1)
|
||||||
|
@ -28,6 +33,12 @@ def cosine_similarity(x, y):
|
||||||
return -nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself.
|
return -nn.CosineSimilarity()(x, y) # probably better to just use this class to perform the calc. Just left this here to remind myself.
|
||||||
|
|
||||||
|
|
||||||
|
def key_value_difference(x, y):
|
||||||
|
x = F.normalize(x, dim=-1, p=2)
|
||||||
|
y = F.normalize(y, dim=-1, p=2)
|
||||||
|
return 2 - 2 * (x * y).sum(dim=-1)
|
||||||
|
|
||||||
|
|
||||||
def norm(x):
|
def norm(x):
|
||||||
sh = x.shape
|
sh = x.shape
|
||||||
sh_r = tuple([sh[i] if i != len(sh)-1 else 1 for i in range(len(sh))])
|
sh_r = tuple([sh[i] if i != len(sh)-1 else 1 for i in range(len(sh))])
|
||||||
|
@ -41,8 +52,8 @@ def im_norm(x):
|
||||||
def get_image_folder_dataloader(batch_size, num_workers):
|
def get_image_folder_dataloader(batch_size, num_workers):
|
||||||
dataset_opt = {
|
dataset_opt = {
|
||||||
'name': 'amalgam',
|
'name': 'amalgam',
|
||||||
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
|
||||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
|
#'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
|
||||||
'weights': [1],
|
'weights': [1],
|
||||||
'target_size': 512,
|
'target_size': 512,
|
||||||
'force_multiple': 32,
|
'force_multiple': 32,
|
||||||
|
@ -52,7 +63,7 @@ def get_image_folder_dataloader(batch_size, num_workers):
|
||||||
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
|
return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)
|
||||||
|
|
||||||
|
|
||||||
def create_latent_database(model):
|
def create_latent_database(model, model_index=0):
|
||||||
batch_size = 8
|
batch_size = 8
|
||||||
num_workers = 1
|
num_workers = 1
|
||||||
output_path = '../../results/byol_spinenet_latents/'
|
output_path = '../../results/byol_spinenet_latents/'
|
||||||
|
@ -65,7 +76,7 @@ def create_latent_database(model):
|
||||||
all_paths = []
|
all_paths = []
|
||||||
for batch in tqdm(dataloader):
|
for batch in tqdm(dataloader):
|
||||||
hq = batch['hq'].to('cuda')
|
hq = batch['hq'].to('cuda')
|
||||||
latent = model(hq)[1] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
|
latent = model(hq)[model_index] # BYOL trainer only trains the '4' output, which is indexed at [1]. Confusing.
|
||||||
for b in range(latent.shape[0]):
|
for b in range(latent.shape[0]):
|
||||||
im_path = batch['HQ_path'][b]
|
im_path = batch['HQ_path'][b]
|
||||||
all_paths.append(im_path)
|
all_paths.append(im_path)
|
||||||
|
@ -79,14 +90,8 @@ def create_latent_database(model):
|
||||||
id += 1
|
id += 1
|
||||||
|
|
||||||
|
|
||||||
def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_size):
|
def _get_mins_from_comparables(latent, comparables, batch_size, compare_fn):
|
||||||
_, c, h, w = latent.shape
|
_, c, h, w = latent.shape
|
||||||
lat_dict = torch.load(os.path.join(hq_img_repo, ld_file_name))
|
|
||||||
comparables = torch.stack(list(lat_dict.values()), dim=0).permute(0,2,3,1)
|
|
||||||
cbl_shape = comparables.shape[:3]
|
|
||||||
assert cbl_shape[1] == 32
|
|
||||||
comparables = comparables.reshape(-1, c)
|
|
||||||
|
|
||||||
clat = latent.reshape(1,-1,h*w).permute(2,0,1)
|
clat = latent.reshape(1,-1,h*w).permute(2,0,1)
|
||||||
cpbl_chunked = torch.chunk(comparables, len(comparables) // batch_size)
|
cpbl_chunked = torch.chunk(comparables, len(comparables) // batch_size)
|
||||||
assert len(comparables) % batch_size == 0 # The reconstruction logic doesn't work if this is not the case.
|
assert len(comparables) % batch_size == 0 # The reconstruction logic doesn't work if this is not the case.
|
||||||
|
@ -94,11 +99,12 @@ def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_si
|
||||||
min_offsets = []
|
min_offsets = []
|
||||||
for cpbl_chunk in tqdm(cpbl_chunked):
|
for cpbl_chunk in tqdm(cpbl_chunked):
|
||||||
cpbl_chunk = cpbl_chunk.to('cuda')
|
cpbl_chunk = cpbl_chunk.to('cuda')
|
||||||
dist = structural_euc_dist(clat, cpbl_chunk.unsqueeze(0))
|
dist = compare_fn(clat, cpbl_chunk.unsqueeze(0))
|
||||||
_min = torch.min(dist, dim=-1)
|
_min = torch.min(dist, dim=-1)
|
||||||
mins.append(_min[0])
|
mins.append(_min[0])
|
||||||
min_offsets.append(_min[1])
|
min_offsets.append(_min[1])
|
||||||
mins = torch.min(torch.stack(mins, dim=-1), dim=-1)
|
mins = torch.min(torch.stack(mins, dim=-1), dim=-1)
|
||||||
|
|
||||||
# There's some way to do this in torch, I just can't figure it out..
|
# There's some way to do this in torch, I just can't figure it out..
|
||||||
for i in range(len(mins[1])):
|
for i in range(len(mins[1])):
|
||||||
mins[1][i] = mins[1][i] * batch_size + min_offsets[mins[1][i]][i]
|
mins[1][i] = mins[1][i] * batch_size + min_offsets[mins[1][i]][i]
|
||||||
|
@ -106,26 +112,36 @@ def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_si
|
||||||
return mins[0].cpu(), mins[1].cpu(), len(comparables)
|
return mins[0].cpu(), mins[1].cpu(), len(comparables)
|
||||||
|
|
||||||
|
|
||||||
def find_similar_latents(model):
|
def _get_mins_from_latent_dictionary(latent, hq_img_repo, ld_file_name, batch_size, compare_fn):
|
||||||
|
_, c, h, w = latent.shape
|
||||||
|
lat_dict = torch.load(os.path.join(hq_img_repo, ld_file_name))
|
||||||
|
comparables = torch.stack(list(lat_dict.values()), dim=0).permute(0,2,3,1)
|
||||||
|
cbl_shape = comparables.shape[:3]
|
||||||
|
comparables = comparables.reshape(-1, c)
|
||||||
|
return _get_mins_from_comparables(latent, comparables, batch_size, compare_fn)
|
||||||
|
|
||||||
|
|
||||||
|
def find_similar_latents(model, model_index=0, lat_patch_size=16, compare_fn=structural_euc_dist):
|
||||||
img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\adrianna_xx.jpg'
|
img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\adrianna_xx.jpg'
|
||||||
#img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
|
#img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
|
||||||
hq_img_repo = '../../results/byol_spinenet_latents'
|
hq_img_repo = '../../results/byol_spinenet_latents'
|
||||||
output_path = '../../results/byol_spinenet_similars'
|
output_path = '../../results/byol_spinenet_similars'
|
||||||
batch_size = 1024
|
batch_size = 2048
|
||||||
num_maps = 8
|
num_maps = 4
|
||||||
|
lat_patch_mult = 512 // lat_patch_size
|
||||||
|
|
||||||
os.makedirs(output_path, exist_ok=True)
|
os.makedirs(output_path, exist_ok=True)
|
||||||
img_bank_paths = torch.load(os.path.join(hq_img_repo, "all_paths.pth"))
|
img_bank_paths = torch.load(os.path.join(hq_img_repo, "all_paths.pth"))
|
||||||
img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0)
|
img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0)
|
||||||
_, _, h, w = img_t.shape
|
_, _, h, w = img_t.shape
|
||||||
img_t = img_t[:, :, :128*(h//128), :128*(w//128)]
|
img_t = img_t[:, :, :128*(h//128), :128*(w//128)]
|
||||||
|
latent = model(img_t)[model_index]
|
||||||
latent = model(img_t)[1]
|
|
||||||
_, c, h, w = latent.shape
|
_, c, h, w = latent.shape
|
||||||
|
|
||||||
mins, min_offsets = [], []
|
mins, min_offsets = [], []
|
||||||
total_latents = -1
|
total_latents = -1
|
||||||
for d_id in range(1,num_maps+1):
|
for d_id in range(1,num_maps+1):
|
||||||
mn, of, tl = _get_mins_from_latent_dictionary(latent, hq_img_repo, "latent_dict_%i.pth" % (d_id), batch_size)
|
mn, of, tl = _get_mins_from_latent_dictionary(latent, hq_img_repo, "latent_dict_%i.pth" % (d_id), batch_size, compare_fn)
|
||||||
if total_latents != -1:
|
if total_latents != -1:
|
||||||
assert total_latents == tl
|
assert total_latents == tl
|
||||||
else:
|
else:
|
||||||
|
@ -140,32 +156,37 @@ def find_similar_latents(model):
|
||||||
|
|
||||||
print("Constructing image map..")
|
print("Constructing image map..")
|
||||||
doc_out = '''
|
doc_out = '''
|
||||||
<html><body><img id="imgmap" src="source.png" usemap="#map">
|
<html><body><img id="imgmap" src="output.png" usemap="#map">
|
||||||
<map name="map">%s</map><br>
|
<map name="map">%s</map><br>
|
||||||
<button onclick="if(imgmap.src.includes('output.png')){imgmap.src='source.png';}else{imgmap.src='output.png';}">Swap Images</button>
|
<button onclick="if(imgmap.src.includes('output.png')){imgmap.src='source.png';}else{imgmap.src='output.png';}">Swap Images</button>
|
||||||
</body></html>
|
</body></html>
|
||||||
'''
|
'''
|
||||||
img_map_areas = []
|
img_map_areas = []
|
||||||
img_out = torch.zeros((1,3,h*16,w*16))
|
img_out = torch.zeros((1, 3, h * lat_patch_size, w * lat_patch_size))
|
||||||
for i, ind in enumerate(tqdm(min_ids)):
|
for i, ind in enumerate(tqdm(min_ids)):
|
||||||
u = np.unravel_index(ind.item(), (num_maps*total_latents//(32*32),32,32))
|
u = np.unravel_index(ind.item(), (num_maps * total_latents // (lat_patch_mult ** 2), lat_patch_mult, lat_patch_mult))
|
||||||
h_, w_ = np.unravel_index(i, (h, w))
|
h_, w_ = np.unravel_index(i, (h, w))
|
||||||
|
|
||||||
img = ToTensor()(Resize((512, 512))(Image.open(img_bank_paths[u[0]])))
|
img = ToTensor()(Resize((512, 512))(Image.open(img_bank_paths[u[0]])))
|
||||||
t = 16 * u[1]
|
t = lat_patch_size * u[1]
|
||||||
l = 16 * u[2]
|
l = lat_patch_size * u[2]
|
||||||
patch = img[:, t:t+16, l:l+16]
|
patch = img[:, t:t + lat_patch_size, l:l + lat_patch_size]
|
||||||
img_out[:,:,h_*16:h_*16+16,w_*16:w_*16+16] = patch
|
img_out[:,:, h_ * lat_patch_size:h_ * lat_patch_size + lat_patch_size,
|
||||||
|
w_ * lat_patch_size:w_ * lat_patch_size + lat_patch_size] = patch
|
||||||
|
|
||||||
# Also save the image with a masked map
|
# Also save the image with a masked map
|
||||||
mask = torch.full_like(img, fill_value=.3)
|
mask = torch.full_like(img, fill_value=.3)
|
||||||
mask[:, t:t+16, l:l+16] = 1
|
mask[:, t:t + lat_patch_size, l:l + lat_patch_size] = 1
|
||||||
masked_img = img * mask
|
masked_img = img * mask
|
||||||
masked_src_img_output_file = os.path.join(output_path, "%i_%i__%i.png" % (t, l, u[0]))
|
masked_src_img_output_file = os.path.join(output_path, "%i_%i__%i.png" % (t, l, u[0]))
|
||||||
torchvision.utils.save_image(masked_img, masked_src_img_output_file)
|
torchvision.utils.save_image(masked_img, masked_src_img_output_file)
|
||||||
|
|
||||||
# Update the image map areas.
|
# Update the image map areas.
|
||||||
img_map_areas.append('<area shape="rect" coords="%i,%i,%i,%i" href="%s">' % (w_*16,h_*16,w_*16+16,h_*16+16,masked_src_img_output_file))
|
img_map_areas.append('<area shape="rect" coords="%i,%i,%i,%i" href="%s">' % (w_ * lat_patch_size,
|
||||||
|
h_ * lat_patch_size,
|
||||||
|
w_ * lat_patch_size + lat_patch_size,
|
||||||
|
h_ * lat_patch_size + lat_patch_size,
|
||||||
|
masked_src_img_output_file))
|
||||||
torchvision.utils.save_image(img_out, os.path.join(output_path, "output.png"))
|
torchvision.utils.save_image(img_out, os.path.join(output_path, "output.png"))
|
||||||
torchvision.utils.save_image(img_t, os.path.join(output_path, "source.png"))
|
torchvision.utils.save_image(img_t, os.path.join(output_path, "source.png"))
|
||||||
doc_out = doc_out % ('\n'.join(img_map_areas))
|
doc_out = doc_out % ('\n'.join(img_map_areas))
|
||||||
|
@ -195,12 +216,30 @@ def explore_latent_results(model):
|
||||||
id += 1
|
id += 1
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
class BYOLModelWrapper(nn.Module):
|
||||||
pretrained_path = '../../experiments/spinenet49_imgset_byol.pth'
|
def __init__(self, wrap):
|
||||||
|
super().__init__()
|
||||||
|
self.wrap = wrap
|
||||||
|
|
||||||
|
def forward(self, img):
|
||||||
|
return self.wrap.get_projection(img)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
util.loaded_options = {'checkpointing_enabled': True}
|
||||||
|
pretrained_path = '../../experiments/spinenet49_imgset_sbyol.pth'
|
||||||
model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda')
|
model = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda')
|
||||||
model.load_state_dict(torch.load(pretrained_path), strict=True)
|
model.load_state_dict(torch.load(pretrained_path), strict=True)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
#pretrained_path = '../../experiments/train_sbyol_512unsupervised/models/35000_generator.pth'
|
||||||
|
#from models.byol.byol_structural import StructuralBYOL
|
||||||
|
#subnet = SpineNet('49', in_channels=3, use_input_norm=True).to('cuda')
|
||||||
|
#model = StructuralBYOL(subnet, image_size=256, hidden_layer='endpoint_convs.3.conv')
|
||||||
|
#model.load_state_dict(torch.load(pretrained_path), strict=True)
|
||||||
|
#model = BYOLModelWrapper(model)
|
||||||
|
#model.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
find_similar_latents(model)
|
#create_latent_database(model, 0) # 0 = model output dimension to use for latent storage
|
||||||
|
find_similar_latents(model, 0, 8, structural_euc_dist) # 1 = model output dimension to use for latent predictor.
|
||||||
|
|
|
@ -292,7 +292,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_512unsupervised.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_sbyol_512unsupervised.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user