forked from mrq/DL-Art-School
Add testing capabilities for segformer & contrastive feature
parent eab8546f73
commit 119f17c808
@@ -208,21 +208,21 @@ class NetWrapper(nn.Module):
         projector = MLP(dim, self.projection_size, self.projection_hidden_size)
         return projector.to(hidden)

-    def get_representation(self, x, pt):
+    def get_representation(self, **kwargs):
         if self.layer == -1:
-            return self.net(x, pt)
+            return self.net(**kwargs)

         if not self.hook_registered:
             self._register_hook()

-        unused = self.net(x, pt)
+        unused = self.net(**kwargs)
         hidden = self.hidden
         self.hidden = None
         assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
         return hidden

-    def forward(self, x, pt):
-        representation = self.get_representation(x, pt)
+    def forward(self, **kwargs):
+        representation = self.get_representation(**kwargs)
         projector = self._get_projector(representation)
         projection = checkpoint(projector, representation)
         return projection
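The switch to **kwargs above lets NetWrapper pass arbitrary named inputs (here, img and pos) straight through to whatever network it wraps, while the representation is still captured by a forward hook rather than returned directly. A minimal, self-contained toy of that pattern, using made-up TinyNet/TinyWrapper names (this is a sketch, not code from the commit):

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Linear(16, 32)
        self.tail = nn.Linear(32, 8)

    def forward(self, img=None, pos=None):
        # pos is accepted but unused; it only demonstrates the keyword plumbing.
        return self.tail(self.body(img))

class TinyWrapper(nn.Module):
    def __init__(self, net, layer='body'):
        super().__init__()
        self.net = net
        self.hidden = None
        # Capture the hidden layer's output via a forward hook.
        dict(net.named_modules())[layer].register_forward_hook(
            lambda module, inputs, output: setattr(self, 'hidden', output))

    def forward(self, **kwargs):       # mirrors NetWrapper.forward(self, **kwargs)
        _ = self.net(**kwargs)         # keyword args are forwarded untouched
        hidden, self.hidden = self.hidden, None
        return hidden

w = TinyWrapper(TinyNet())
print(w(img=torch.randn(4, 16), pos=torch.zeros(4, 2)).shape)  # torch.Size([4, 32])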
@@ -239,6 +239,7 @@ class BYOL(nn.Module):
         moving_average_decay=0.99,
         use_momentum=True,
         structural_mlp=False,
+        contrastive=False,
     ):
         super().__init__()

@@ -247,6 +248,7 @@ class BYOL(nn.Module):

         self.aug = PointwiseAugmentor(image_size)
         self.use_momentum = use_momentum
+        self.contrastive = contrastive
         self.target_encoder = None
         self.target_ema_updater = EMA(moving_average_decay)

@@ -278,13 +280,17 @@ class BYOL(nn.Module):

     def get_debug_values(self, step, __):
         # In the BYOL paper, this is made to increase over time. Not yet implemented, but still logging the value.
-        return {'target_ema_beta': self.target_ema_updater.beta}
+        dbg = {'target_ema_beta': self.target_ema_updater.beta}
+        if self.contrastive and hasattr(self, 'logs_closs'):
+            dbg['contrastive_distance'] = self.logs_closs
+            dbg['byol_distance'] = self.logs_loss
+        return dbg

     def visual_dbg(self, step, path):
         torchvision.utils.save_image(self.im1.cpu().float(), os.path.join(path, "%i_image1.png" % (step,)))
         torchvision.utils.save_image(self.im2.cpu().float(), os.path.join(path, "%i_image2.png" % (step,)))

-    def forward(self, image):
+    def get_predictions_and_projections(self, image):
         _, _, h, w = image.shape
         point = torch.randint(h//8, 7*h//8, (2,)).long().to(image.device)

@@ -297,16 +303,20 @@ class BYOL(nn.Module):
         self.im2 = image_two.detach().clone()
         self.im2[:,:,pt_two[0]-3:pt_two[0]+3,pt_two[1]-3:pt_two[1]+3] = 1

-        online_proj_one = self.online_encoder(image_one, pt_one)
-        online_proj_two = self.online_encoder(image_two, pt_two)
+        online_proj_one = self.online_encoder(img=image_one, pos=pt_one)
+        online_proj_two = self.online_encoder(img=image_two, pos=pt_two)

         online_pred_one = self.online_predictor(online_proj_one)
         online_pred_two = self.online_predictor(online_proj_two)

         with torch.no_grad():
             target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
-            target_proj_one = target_encoder(image_one, pt_one).detach()
-            target_proj_two = target_encoder(image_two, pt_two).detach()
+            target_proj_one = target_encoder(img=image_one, pos=pt_one).detach()
+            target_proj_two = target_encoder(img=image_two, pos=pt_two).detach()
+        return online_pred_one, online_pred_two, target_proj_one, target_proj_two
+
+    def forward_normal(self, image):
+        online_pred_one, online_pred_two, target_proj_one, target_proj_two = self.get_predictions_and_projections(image)

         loss_one = loss_fn(online_pred_one, target_proj_two.detach())
         loss_two = loss_fn(online_pred_two, target_proj_one.detach())
@@ -314,6 +324,35 @@ class BYOL(nn.Module):
         loss = loss_one + loss_two
         return loss.mean()

+    def forward_contrastive(self, image):
+        online_pred_one_1, online_pred_two_1, target_proj_one_1, target_proj_two_1 = self.get_predictions_and_projections(image)
+        loss_one = loss_fn(online_pred_one_1, target_proj_two_1.detach())
+        loss_two = loss_fn(online_pred_two_1, target_proj_one_1.detach())
+        loss = loss_one + loss_two
+
+        online_pred_one_2, online_pred_two_2, target_proj_one_2, target_proj_two_2 = self.get_predictions_and_projections(image)
+        loss_one = loss_fn(online_pred_one_2, target_proj_two_2.detach())
+        loss_two = loss_fn(online_pred_two_2, target_proj_one_2.detach())
+        loss = (loss + loss_one + loss_two).mean()
+
+        contrastive_loss = torch.cat([loss_fn(online_pred_one_1, target_proj_two_2),
+                                      loss_fn(online_pred_two_1, target_proj_one_2),
+                                      loss_fn(online_pred_one_2, target_proj_two_1),
+                                      loss_fn(online_pred_two_2, target_proj_one_1)], dim=0)
+        k = contrastive_loss.shape[0] // 2  # Take half of the total contrastive loss predictions.
+        contrastive_loss = torch.topk(contrastive_loss, k, dim=0).values.mean()
+
+        self.logs_loss = loss.detach()
+        self.logs_closs = contrastive_loss.detach()
+
+        return loss - contrastive_loss
+
+    def forward(self, image):
+        if self.contrastive:
+            return self.forward_contrastive(image)
+        else:
+            return self.forward_normal(image)
+

 if __name__ == '__main__':
     pa = PointwiseAugmentor(256)
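For reference, forward_contrastive combines the usual BYOL regression term over two independently sampled points per image with a repulsive term built from the cross-point distances, of which only the largest half is kept via topk. A compact, self-contained sketch of that arithmetic, assuming loss_fn is the standard BYOL loss (2 - 2 * cosine similarity) and using random tensors as stand-ins for the predictions and projections (not code from this commit):

import torch
import torch.nn.functional as F

def loss_fn(x, y):
    # Assumed BYOL regression loss: 2 - 2 * cosine similarity.
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)

b, d = 8, 256
pred_1, targ_1 = torch.randn(b, d), torch.randn(b, d)   # online preds / target projs, point draw 1
pred_2, targ_2 = torch.randn(b, d), torch.randn(b, d)   # online preds / target projs, point draw 2

byol = (loss_fn(pred_1, targ_1) + loss_fn(pred_2, targ_2)).mean()

# Cross-draw distances (simplified to two of the four pairings used above). Minimizing
# `byol - contrastive` pushes these apart; topk keeps only the half with the largest values.
cross = torch.cat([loss_fn(pred_1, targ_2), loss_fn(pred_2, targ_1)], dim=0)
contrastive = torch.topk(cross, cross.shape[0] // 2, dim=0).values.mean()
print((byol - contrastive).item())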
@@ -331,4 +370,4 @@ if __name__ == '__main__':
 @register_model
 def register_pixel_local_byol(opt_net, opt):
     subnet = create_model(opt, opt_net['subnet'])
-    return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'])
+    return BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'], contrastive=opt_net['contrastive'])
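Since register_pixel_local_byol now reads opt_net['contrastive'] with bracket indexing, option files for this model need the key present even when the feature is off. A hypothetical illustration of the relevant fields (values are placeholders, not taken from the repository's option files):

# Hypothetical option fields read by register_pixel_local_byol(); values are illustrative only.
opt_net = {
    'subnet': {},            # sub-network definition handed to create_model(); contents omitted here
    'image_size': 224,
    'hidden_layer': 'tail',
    'contrastive': True,     # must be present after this change; set False for plain BYOL behaviour
}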
@@ -94,14 +94,21 @@ class Segformer(nn.Module):
         self.transformer_layers = nn.Sequential(*[nn.TransformerEncoderLayer(final_latent_channels, nhead=4) for _ in range(layers)])
         self.tail = Tail()

-    def forward(self, x, pos):
-        layers = self.backbone(x)
-        set = []
+    def forward(self, img=None, layers=None, pos=None, return_layers=False):
+        assert img is not None or layers is not None
+        if img is not None:
+            bs = img.shape[0]
+            layers = self.backbone(img)
+        else:
+            bs = layers[0].shape[0]
+        if return_layers:
+            return layers

         # A single position can be optionally given, in which case we need to expand it to represent the entire input.
         if pos.shape == (2,):
-            pos = pos.unsqueeze(0).repeat(x.shape[0],1)
+            pos = pos.unsqueeze(0).repeat(bs, 1)

+        set = []
         pos = pos // 4
         for layer_out, dilator in zip(layers, self.dilators):
             for subdilator in dilator:
@@ -124,4 +131,4 @@ if __name__ == '__main__':
     model = Segformer().to('cuda')
     for j in tqdm(range(1000)):
         test_tensor = torch.randn(64,3,224,224).cuda()
-        print(model(test_tensor, torch.randint(0,224,(64,2)).cuda()).shape)
+        print(model(img=test_tensor, pos=torch.randint(0,224,(64,2)).cuda()).shape)
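The reworked signature also allows the backbone features to be computed once and then re-queried at several positions, which appears to be the purpose of the layers/return_layers arguments. A hedged usage sketch, assuming a CUDA device and the repository on the import path (not code from this commit):

import torch
from models.segformer.segformer import Segformer

model = Segformer().to('cuda').eval()
img = torch.randn(8, 3, 224, 224).cuda()

with torch.no_grad():
    feats = model(img=img, return_layers=True)        # run the backbone once
    for _ in range(4):
        pos = torch.randint(0, 224, (8, 2)).cuda()    # a different query point per image
        latent = model(layers=feats, pos=pos)         # reuse the cached backbone layers
        print(latent.shape)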
codes/scripts/byol/byol_segformer_playground.py (new file, 240 lines)
@@ -0,0 +1,240 @@
import os
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Resize
from tqdm import tqdm
import numpy as np

import utils
from data.image_folder_dataset import ImageFolderDataset
from models.resnet_with_checkpointing import resnet50
from models.segformer.segformer import Segformer
from models.spinenet_arch import SpineNet


# Computes the structural euclidean distance between [x,y]. "Structural" here means the [h,w] dimensions are preserved
# and the distance is computed across the channel dimension.
from utils import util
from utils.kmeans import kmeans, kmeans_predict
from utils.options import dict_to_nonedict


def structural_euc_dist(x, y):
    diff = torch.square(x - y)
    sum = torch.sum(diff, dim=-1)
    return torch.sqrt(sum)


def cosine_similarity(x, y):
    x = norm(x)
    y = norm(y)
    return -nn.CosineSimilarity()(x, y)   # probably better to just use this class to perform the calc. Just left this here to remind myself.


def key_value_difference(x, y):
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)


def norm(x):
    sh = x.shape
    sh_r = tuple([sh[i] if i != len(sh)-1 else 1 for i in range(len(sh))])
    return (x - torch.mean(x, dim=-1).reshape(sh_r)) / torch.std(x, dim=-1).reshape(sh_r)


def im_norm(x):
    return (((x - torch.mean(x, dim=(2,3)).reshape(-1,1,1,1)) / torch.std(x, dim=(2,3)).reshape(-1,1,1,1)) * .5) + .5


def get_image_folder_dataloader(batch_size, num_workers, target_size=224, shuffle=True):
    dataset_opt = dict_to_nonedict({
        'name': 'amalgam',
        #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\pn_coven\\cropped2'],
        #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_1024_square_with_new'],
        #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_tiled_filtered_flattened'],
        #'paths': ['F:\\4k6k\\datasets\\ns_images\\imagesets\\1024_test'],
        'paths': ['E:\\4k6k\\datasets\\ns_images\\imagesets\\imageset_256_full'],
        'weights': [1],
        'target_size': target_size,
        'force_multiple': 32,
        'normalize': 'imagenet',
        'scale': 1
    })
    dataset = ImageFolderDataset(dataset_opt)
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)


def _find_layer(net, layer_name):
    if type(layer_name) == str:
        modules = dict([*net.named_modules()])
        return modules.get(layer_name, None)
    elif type(layer_name) == int:
        children = [*net.children()]
        return children[layer_name]
    return None


layer_hooked_value = None
def _hook(_, __, output):
    global layer_hooked_value
    layer_hooked_value = output


def register_hook(net, layer_name):
    layer = _find_layer(net, layer_name)
    assert layer is not None, f'hidden layer ({layer_name}) not found'
    layer.register_forward_hook(_hook)


def get_latent_for_img(model, img):
    img_t = ToTensor()(Image.open(img)).to('cuda').unsqueeze(0)
    _, _, h, w = img_t.shape
    # Center crop img_t and resize to 224.
    d = min(h, w)
    dh, dw = (h-d)//2, (w-d)//2
    if dw != 0:
        img_t = img_t[:, :, :, dw:-dw]
    elif dh != 0:
        img_t = img_t[:, :, dh:-dh, :]
    img_t = img_t[:,:3,:,:]
    img_t = torch.nn.functional.interpolate(img_t, size=(224, 224), mode="area")
    model(img_t)
    latent = layer_hooked_value
    return latent


def produce_latent_dict(model):
    batch_size = 32
    num_workers = 4
    dataloader = get_image_folder_dataloader(batch_size, num_workers)
    id = 0
    paths = []
    latents = []
    points = []
    for batch in tqdm(dataloader):
        hq = batch['hq'].to('cuda')
        # Pull several points from every image.
        for k in range(10):
            _, _, h, _ = hq.shape
            point = torch.randint(h//4, 3*h//4, (2,)).long().to(hq.device)
            model(img=hq, pos=point)
            l = layer_hooked_value.cpu().split(1, dim=0)
            latents.extend(l)
            points.extend([point for p in range(batch_size)])
        paths.extend(batch['HQ_path'])
        id += batch_size
        if id > 10000:
            print("Saving checkpoint..")
            torch.save((latents, points, paths), '../results_segformer.pth')
            id = 0


def find_similar_latents(model, compare_fn=structural_euc_dist):
    global layer_hooked_value

    img = 'D:\\dlas\\results\\bobz.png'
    #img = 'F:\\4k6k\\datasets\\ns_images\\adrianna\\analyze\\analyze_xx\\nicky_xx.jpg'
    output_path = '../../../results/byol_resnet_similars'
    os.makedirs(output_path, exist_ok=True)
    imglatent = get_latent_for_img(model, img).squeeze().unsqueeze(0)
    _, c = imglatent.shape

    batch_size = 512
    num_workers = 8
    dataloader = get_image_folder_dataloader(batch_size, num_workers)
    id = 0
    output_batch = 1
    results = []
    result_paths = []
    for batch in tqdm(dataloader):
        hq = batch['hq'].to('cuda')
        model(hq)
        latent = layer_hooked_value.clone().squeeze()
        compared = compare_fn(imglatent.repeat(latent.shape[0], 1), latent)
        results.append(compared.cpu())
        result_paths.extend(batch['HQ_path'])
        id += batch_size
        if id > 10000:
            k = 200
            results = torch.cat(results, dim=0)
            vals, inds = torch.topk(results, k, largest=False)
            for i in inds:
                mag = int(results[i].item() * 1000)
                shutil.copy(result_paths[i], os.path.join(output_path, f'{mag:05}_{output_batch}_{i}.jpg'))
            results = []
            result_paths = []
            id = 0


def build_kmeans():
    latents, _, _ = torch.load('../results_segformer.pth')
    latents = torch.cat(latents, dim=0).squeeze().to('cuda')
    cluster_ids_x, cluster_centers = kmeans(latents, num_clusters=16, distance="euclidean", device=torch.device('cuda:0'))
    torch.save((cluster_ids_x, cluster_centers), '../k_means_segformer.pth')


class UnNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        Returns:
            Tensor: Normalized image.
        """
        for t, m, s in zip(tensor, self.mean, self.std):
            t.mul_(s).add_(m)
            # The normalize code -> t.sub_(m).div_(s)
        return tensor


def use_kmeans():
    output = "../results/k_means_segformer/"
    _, centers = torch.load('../k_means_segformer.pth')
    centers = centers.to('cuda')
    batch_size = 32
    num_workers = 1
    dataloader = get_image_folder_dataloader(batch_size, num_workers, target_size=224, shuffle=True)
    denorm = UnNormalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    for i, batch in enumerate(tqdm(dataloader)):
        hq = batch['hq'].to('cuda')
        _,_,h,w = hq.shape
        point = torch.randint(h//4, 3*h//4, (2,)).long().to(hq.device)
        model(img=hq, pos=point)
        l = layer_hooked_value.clone().squeeze()
        pred = kmeans_predict(l, centers)
        hq = denorm(hq * .5)
        hq[:,:,point[0]-5:point[0]+5,point[1]-5:point[1]+5] *= 2
        for b in range(pred.shape[0]):
            outpath = os.path.join(output, str(pred[b].item()))
            os.makedirs(outpath, exist_ok=True)
            torchvision.utils.save_image(hq[b], os.path.join(outpath, f'{i*batch_size+b}.png'))


if __name__ == '__main__':
    pretrained_path = '../../../experiments/segformer_byol_only.pth'
    model = Segformer().to('cuda')
    sd = torch.load(pretrained_path)
    resnet_sd = {}
    for k, v in sd.items():
        if 'target_encoder.net.' in k:
            resnet_sd[k.replace('target_encoder.net.', '')] = v
    model.load_state_dict(resnet_sd, strict=True)
    model.eval()
    register_hook(model, 'tail')

    with torch.no_grad():
        #find_similar_latents(model, structural_euc_dist)
        #produce_latent_dict(model)
        #build_kmeans()
        use_kmeans()
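The playground's k-means path is a three-stage pipeline driven by the files each stage writes: produce_latent_dict dumps per-point latents to ../results_segformer.pth, build_kmeans fits 16 centroids from that dump into ../k_means_segformer.pth, and use_kmeans sorts fresh crops into per-cluster folders under ../results/k_means_segformer/. A sketch of how the __main__ block would be toggled across three runs (one stage enabled per run, since each stage reads the previous one's output):

    with torch.no_grad():
        produce_latent_dict(model)   # run 1: writes ../results_segformer.pth
        #build_kmeans()              # run 2: reads it, writes ../k_means_segformer.pth
        #use_kmeans()                # run 3: reads the centroids, writes one folder per cluster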
@@ -344,11 +344,75 @@ def plot_pixel_level_results_as_image_graph():
     pyplot.savefig('tsne_pix.pdf')


+def run_tsne_segformer():
+    print("Run Y = tsne.tsne(X, no_dims, perplexity) to perform t-SNE on your dataset.")
+
+    limit = 10000
+    X, points, files = torch.load('../results_segformer.pth')
+    zipped = list(zip(X, points, files))
+    shuffle(zipped)
+    X, points, files = zip(*zipped)
+    X = torch.cat(X, dim=0).squeeze()[:limit]
+    labels = np.zeros(X.shape[0])  # We don't have any labels..
+
+    # Confirm that X and the labels cover the same number of points; a mismatch would break the scatter plot below.
+    assert(len(X[:, 0]) == len(X[:, 1]))
+    assert(len(X) == len(labels))
+
+    with torch.no_grad():
+        Y = tsne(X, 2, 1024, 20.0)
+
+    if opt.cuda:
+        Y = Y.cpu().numpy()
+
+    # You may write result in two files
+    # print("Save Y values in file")
+    # Y1 = open("y1.txt", 'w')
+    # Y2 = open('y2.txt', 'w')
+    # for i in range(Y.shape[0]):
+    #     Y1.write(str(Y[i,0])+"\n")
+    #     Y2.write(str(Y[i,1])+"\n")
+
+    pyplot.scatter(Y[:, 0], Y[:, 1], 20, labels)
+    pyplot.show()
+    torch.save((Y, points, files[:limit]), "../tsne_output.pth")
+
+
+# Uses the results from the calculation above to create a **massive** pdf plot that shows 1/8 size images on the tsne
+# spectrum.
+def plot_segformer_results_as_image_graph():
+    Y, points, files = torch.load('../tsne_output.pth')
+    fig, ax = pyplot.subplots()
+    fig.set_size_inches(200,200,forward=True)
+    ax.update_datalim(np.column_stack([Y[:, 0], Y[:, 1]]))
+    ax.autoscale()
+
+    margins = 32
+    for b in tqdm(range(Y.shape[0])):
+        imgfile = files[b]
+        baseim = pyplot.imread(imgfile)
+        ct, cl = points[b]
+
+        im = baseim[(ct-margins):(ct+margins),
+                    (cl-margins):(cl+margins),:]
+        im = OffsetImage(im, zoom=1)
+        ab = AnnotationBbox(im, (Y[b, 0], Y[b, 1]), xycoords='data', frameon=False)
+        ax.add_artist(ab)
+    ax.scatter(Y[:, 0], Y[:, 1])
+
+    pyplot.savefig('tsne_segformer.pdf')
+
+
 if __name__ == "__main__":
     # For use with instance-level results (e.g. from byol_resnet_playground.py)
     #run_tsne_instance_level()
-    plot_instance_level_results_as_image_graph()
+    #plot_instance_level_results_as_image_graph()

     # For use with pixel-level results (e.g. from byol_uresnet_playground)
     #run_tsne_pixel_level()
     #plot_pixel_level_results_as_image_graph()
+
+    # For use with segformer results
+    #run_tsne_segformer()
+    plot_segformer_results_as_image_graph()
@@ -295,7 +295,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_segformer_xx.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_segformer_contrastive_xx.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
@@ -33,8 +33,8 @@ class SinglePointPairContrastiveEval(evaluator.Evaluator):
         distances = []
         l2 = MSELoss()
         for i, data in tqdm(enumerate(dl)):
-            latent1 = self.model(data['img1'].to(dev), torch.stack(data['coords1'], dim=1).to(dev))
-            latent2 = self.model(data['img2'].to(dev), torch.stack(data['coords2'], dim=1).to(dev))
+            latent1 = self.model(img=data['img1'].to(dev), pos=torch.stack(data['coords1'], dim=1).to(dev))
+            latent2 = self.model(img=data['img2'].to(dev), pos=torch.stack(data['coords2'], dim=1).to(dev))
             distances.append(l2(latent1, latent2))
             if i * self.batch_sz >= self.eval_qty:
                 break
@@ -52,7 +52,7 @@ class SinglePointPairContrastiveEval(evaluator.Evaluator):
         diff = dissimilars.item() - similars.item()
         print(f"Eval done. val_similar_lq: {similars.item()}; val_dissimilar_l2: {dissimilars.item()}; val_diff: {diff}")
         self.model.train()
-        return {"val_similar_l2": similars.item(), "val_dissimilar_l2": dissimilars.item(), "val_diff": diff.item()}
+        return {"val_similar_l2": similars.item(), "val_dissimilar_l2": dissimilars.item(), "val_diff": diff}


 if __name__ == '__main__':