few things

James Betker 2022-07-26 11:52:03 -06:00
parent f8108cfdb2
commit a1bbde8a43
4 changed files with 91 additions and 1 deletion


@@ -4,6 +4,7 @@ import random
 import cv2
 import numpy as np
 import torch
+import os
 import torchvision
 from torch.utils.data import DataLoader


@@ -0,0 +1,87 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from codes.models.arch_util import ResBlock
from codes.models.lucidrains.x_transformers import Encoder
from codes.trainer.networks import register_model


class VitLatent(nn.Module):
    def __init__(self, top_dim, hidden_dim, depth, dropout=.1):
        super().__init__()
        # Convolutional stem: 16x spatial downsampling, 3 -> hidden_dim channels.
        self.upper = nn.Sequential(nn.Conv2d(3, top_dim, kernel_size=7, padding=3, stride=2),
                                   ResBlock(top_dim, use_conv=True, dropout=dropout),
                                   ResBlock(top_dim, out_channels=top_dim*2, down=True, use_conv=True, dropout=dropout),
                                   ResBlock(top_dim*2, use_conv=True, dropout=dropout),
                                   ResBlock(top_dim*2, out_channels=top_dim*4, down=True, use_conv=True, dropout=dropout),
                                   ResBlock(top_dim*4, use_conv=True, dropout=dropout),
                                   ResBlock(top_dim*4, out_channels=hidden_dim, down=True, use_conv=True, dropout=dropout),
                                   nn.GroupNorm(8, hidden_dim))
        self.encoder = Encoder(
            dim=hidden_dim,
            depth=depth,
            heads=hidden_dim//64,
            ff_dropout=dropout,
            attn_dropout=dropout,
            use_rmsnorm=True,
            ff_glu=True,
            rotary_pos_emb=True,
            ff_mult=2,
            do_checkpointing=True
        )
        # Projection head; currently unused in forward() (see commented-out lines there).
        self.mlp = nn.Sequential(nn.Linear(hidden_dim, hidden_dim*2),
                                 nn.BatchNorm1d(hidden_dim*2),
                                 nn.ReLU(inplace=True),
                                 nn.Linear(hidden_dim*2, hidden_dim))

    def provide_ema(self, ema):
        # An EMA copy of this network serves as the target encoder in forward().
        self.ema = ema

    def project(self, x):
        h = self.upper(x)
        # (B, C, H, W) -> (B, H*W, C) token sequence for the transformer.
        h = torch.flatten(h, 2).permute(0,2,1)
        # Use the first token's output as the image latent.
        h = self.encoder(h)[:,0]
        h_norm = F.normalize(h)
        return h_norm

    def forward(self, x1, x2):
        h1 = self.project(x1)
        #p1 = self.mlp(h1)
        h2 = self.project(x2)
        #p2 = self.mlp(h2)

        # Target latents come from the EMA network and carry no gradient.
        with torch.no_grad():
            he1 = self.ema.project(x1)
            he2 = self.ema.project(x2)

        def csim(h1, h2):
            # All-pairs cosine similarity: diagonal entries are matched (positive)
            # pairs, off-diagonal entries are in-batch negatives.
            b = x1.shape[0]
            sim = F.cosine_similarity(h1.unsqueeze(0), h2.unsqueeze(1).detach(), 2)
            eye = torch.eye(b, device=x1.device)
            neye = eye != 1
            return -(sim*eye).sum()/b, (sim*neye).sum()/(b**2-b)

        # Symmetrize: compare each online latent against the EMA latent of the other view.
        pos, neg = csim(h1, he2)
        pos2, neg2 = csim(h2, he1)
        return (pos+pos2)/2, (neg+neg2)/2

    def get_grad_norm_parameter_groups(self):
        return {
            'upper': list(self.upper.parameters()),
            'encoder': list(self.encoder.parameters()),
            'mlp': list(self.mlp.parameters()),
        }


@register_model
def register_vit_latent(opt_net, opt):
    return VitLatent(**opt_net['kwargs'])


if __name__ == '__main__':
    net = VitLatent(128, 1024, 8)
    net.provide_ema(net)
    x1 = torch.randn(2,3,244,244)
    x2 = torch.randn(2,3,244,244)
    net(x1,x2)
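For context: forward() implements a BYOL-style symmetric objective with in-batch negatives. Both returned terms are meant to be minimized: the first pulls each view toward the EMA target's latent of the other view, the second pushes mismatched pairs apart. Below is a minimal standalone training-step sketch using the class above; the optimizer choice, equal term weighting, and EMA decay value are assumptions here, since DLAS's trainer drives all of this through its own config.

import copy
import torch

net = VitLatent(top_dim=64, hidden_dim=512, depth=4)
ema = copy.deepcopy(net).requires_grad_(False)  # frozen target network
net.provide_ema(ema)
opt = torch.optim.AdamW(net.parameters(), lr=1e-4)

x1 = torch.randn(4, 3, 224, 224)  # stand-ins for two augmented views of one batch
x2 = torch.randn(4, 3, 224, 224)
pos, neg = net(x1, x2)
loss = pos + neg  # equal weighting is an assumption
loss.backward()
opt.step()
opt.zero_grad()

# Polyak-average the online weights into the target (0.999 decay is an assumption).
with torch.no_grad():
    for p_t, p_o in zip(ema.parameters(), net.parameters()):
        p_t.mul_(0.999).add_(p_o, alpha=0.001)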


@@ -340,7 +340,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_multilevel_sr.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
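The new default option file targets this model; wiring-wise, the @register_model hook lets the trainer resolve the network by name, and register_vit_latent passes opt_net['kwargs'] straight into the constructor. A minimal illustration of that path (the dict layout beyond the 'kwargs' key is hypothetical):

# Hypothetical network sub-dict as register_vit_latent would receive it.
opt_net = {'kwargs': {'top_dim': 128, 'hidden_dim': 1024, 'depth': 8}}
net = VitLatent(**opt_net['kwargs'])  # what register_vit_latent(opt_net, opt) returns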


@@ -174,6 +174,8 @@ class ExtensibleTrainer(BaseModel):
                     self.emas[k] = copy.deepcopy(v)
                     if self.ema_on_cpu:
                         self.emas[k] = self.emas[k].cpu()
+                    if hasattr(v, 'provide_ema'):
+                        v.provide_ema(self.emas[k])
                     found += 1
         assert found == len(self.netsG) + len(self.netsD)
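This hook ties the two changes together: after ExtensibleTrainer deep-copies each network into its EMA table, any network exposing provide_ema is handed its own EMA copy, which is how VitLatent.forward() can call self.ema.project() without managing the target itself. A stripped-down sketch of the pattern (the function name and dict shapes are hypothetical; only the provide_ema protocol comes from the diff):

import copy

def setup_emas(nets):
    """Deep-copy each network; hand the copy back to EMA-aware models."""
    emas = {}
    for k, v in nets.items():
        emas[k] = copy.deepcopy(v)
        if hasattr(v, 'provide_ema'):  # duck-typed opt-in, as in the diff above
            v.provide_ema(emas[k])
    return emas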