forked from mrq/DL-Art-School
a few fixes to multiresolution sr
This commit is contained in:
parent
2fb85526bc
commit
eecb534e66
|
@ -126,7 +126,8 @@ class TransformerDiffusion(nn.Module):
|
||||||
)
|
)
|
||||||
self.resolution_embed = nn.Embedding(resolution_steps, model_channels)
|
self.resolution_embed = nn.Embedding(resolution_steps, model_channels)
|
||||||
self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, resolution_steps, num_attn_heads=model_channels//64)
|
self.conditioning_encoder = ConditioningEncoder(in_channels, model_channels, resolution_steps, num_attn_heads=model_channels//64)
|
||||||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,1,model_channels))
|
self.unconditioned_embedding = nn.Parameter(torch.randn(1,model_channels,6))
|
||||||
|
self.unconditioned_prior = nn.Parameter(torch.zeros(1,in_channels,1))
|
||||||
|
|
||||||
self.inp_block = conv_nd(1, in_channels+input_vec_dim, model_channels, 3, 1, 1)
|
self.inp_block = conv_nd(1, in_channels+input_vec_dim, model_channels, 3, 1, 1)
|
||||||
self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, num_heads, dropout) for _ in range(num_layers)])
|
self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, num_heads, dropout) for _ in range(num_layers)])
|
||||||
|
@ -163,10 +164,17 @@ class TransformerDiffusion(nn.Module):
|
||||||
return groups
|
return groups
|
||||||
|
|
||||||
def input_to_random_resolution_and_window(self, x, x_prior):
|
def input_to_random_resolution_and_window(self, x, x_prior):
|
||||||
|
"""
|
||||||
|
This function MUST be applied to the target *before* noising. It returns the reduced, re-scoped target as well
|
||||||
|
as caches an internal prior for the rescoped target which will be useud in training.
|
||||||
|
Args:
|
||||||
|
x: Diffusion target
|
||||||
|
x_prior: Prior input, which is generally just {x}
|
||||||
|
"""
|
||||||
assert x.shape == x_prior.shape, f'{x.shape} {x_prior.shape}'
|
assert x.shape == x_prior.shape, f'{x.shape} {x_prior.shape}'
|
||||||
resolution = randrange(0, self.resolution_steps)
|
resolution = randrange(1, self.resolution_steps)
|
||||||
resolution_scale = 2 ** resolution
|
resolution_scale = 2 ** resolution
|
||||||
s = F.interpolate(x, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
|
s = F.interpolate(x, scale_factor=1/resolution_scale, mode='nearest', align_corners=True)
|
||||||
s_prior = F.interpolate(x_prior, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
|
s_prior = F.interpolate(x_prior, scale_factor=1/resolution_scale, mode='linear', align_corners=True)
|
||||||
s_diff = s.shape[-1] - self.max_window
|
s_diff = s.shape[-1] - self.max_window
|
||||||
if s_diff > 1:
|
if s_diff > 1:
|
||||||
|
@ -179,7 +187,6 @@ class TransformerDiffusion(nn.Module):
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def forward(self, x, timesteps, x_prior=None, resolution=None, conditioning_input=None, conditioning_free=False):
|
def forward(self, x, timesteps, x_prior=None, resolution=None, conditioning_input=None, conditioning_free=False):
|
||||||
unused_params = []
|
|
||||||
conditioning_input = x_prior if conditioning_input is None else conditioning_input
|
conditioning_input = x_prior if conditioning_input is None else conditioning_input
|
||||||
|
|
||||||
h = x
|
h = x
|
||||||
|
@ -202,9 +209,14 @@ class TransformerDiffusion(nn.Module):
|
||||||
gap = conditioning_input.shape[-1] - clen
|
gap = conditioning_input.shape[-1] - clen
|
||||||
cstart = randrange(0, gap)
|
cstart = randrange(0, gap)
|
||||||
conditioning_input = conditioning_input[:,:,cstart:cstart+clen]
|
conditioning_input = conditioning_input[:,:,cstart:cstart+clen]
|
||||||
|
|
||||||
code_emb = self.conditioning_encoder(conditioning_input, resolution)
|
code_emb = self.conditioning_encoder(conditioning_input, resolution)
|
||||||
unused_params.append(self.unconditioned_embedding)
|
|
||||||
|
# Mask out the conditioning input and x_prior inputs for whole batch elements, implementing something similar to classifier-free guidance.
|
||||||
|
if self.training and self.unconditioned_percentage > 0:
|
||||||
|
unconditioned_batches = torch.rand((h.shape[0], 1, 1),
|
||||||
|
device=h.device) < self.unconditioned_percentage
|
||||||
|
h_sub = torch.where(unconditioned_batches, self.unconditioned_prior.repeat(h_sub.shape[0], 1, h_sub.shape[-1]), h_sub)
|
||||||
|
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(code_emb.shape[0], 1, 1), code_emb)
|
||||||
|
|
||||||
with torch.autocast(x.device.type, enabled=self.enable_fp16):
|
with torch.autocast(x.device.type, enabled=self.enable_fp16):
|
||||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
|
time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
|
||||||
|
@ -219,7 +231,8 @@ class TransformerDiffusion(nn.Module):
|
||||||
h = h.float()
|
h = h.float()
|
||||||
out = self.out(h)
|
out = self.out(h)
|
||||||
|
|
||||||
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
|
# Defensively involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
|
||||||
|
unused_params = [self.unconditioned_prior, self.unconditioned_embedding]
|
||||||
extraneous_addition = 0
|
extraneous_addition = 0
|
||||||
for p in unused_params:
|
for p in unused_params:
|
||||||
extraneous_addition = extraneous_addition + p.mean()
|
extraneous_addition = extraneous_addition + p.mean()
|
||||||
|
@ -238,11 +251,19 @@ def test_tfd():
|
||||||
cond = torch.randn(2,256,10336)
|
cond = torch.randn(2,256,10336)
|
||||||
ts = torch.LongTensor([600, 600])
|
ts = torch.LongTensor([600, 600])
|
||||||
model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
|
model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
|
||||||
num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1)
|
num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1,
|
||||||
|
unconditioned_percentage=.6)
|
||||||
for k in range(100):
|
for k in range(100):
|
||||||
x = model.input_to_random_resolution_and_window(clip, x_prior=clip)
|
x = model.input_to_random_resolution_and_window(clip, x_prior=clip)
|
||||||
model(x, ts, clip)
|
model(x, ts, clip)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_conditioning(sd_path):
|
||||||
|
sd = torch.load(sd_path)
|
||||||
|
del sd['unconditioned_embedding']
|
||||||
|
torch.save(sd, sd_path.replace('.pth', '') + '_fixed.pth')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
remove_conditioning('X:\\dlas\\experiments\\train_music_diffusion_multilevel_sr\\models\\12000_generator.pth')
|
||||||
test_tfd()
|
test_tfd()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user