forked from mrq/DL-Art-School

tfdpcv5 updates
commit f46d6645da (parent ce82eb6022)
@@ -114,9 +114,6 @@ class ConditioningEncoder(nn.Module):
 
 
 class TransformerDiffusionWithPointConditioning(nn.Module):
-    """
-    A diffusion model composed entirely of stacks of transformer layers. Why would you do it any other way?
-    """
     def __init__(
             self,
             in_channels=256,
@@ -129,9 +126,11 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
             input_cond_dim=1024,
             num_heads=8,
             dropout=0,
-            time_proj=False,
+            time_proj=True,
+            new_cond=False,
             use_fp16=False,
             checkpoint_conditioning=True,  # This will need to be false for DDP training. :(
+            regularization=False,
             # Parameters for regularization.
             unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
     ):
@@ -144,6 +143,8 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         self.dropout = dropout
         self.unconditioned_percentage = unconditioned_percentage
         self.enable_fp16 = use_fp16
+        self.regularization = regularization
+        self.new_cond = new_cond
 
         self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
         self.conditioning_encoder = ConditioningEncoder(256, model_channels, time_embed_dim, do_checkpointing=checkpoint_conditioning, time_proj=time_proj)
@@ -166,13 +167,11 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
                 cond_projection=(k % 3 == 0),
                 use_conv=(k % 3 != 0),
             ) for k in range(num_layers)])
-
         self.out = nn.Sequential(
             normalization(model_channels),
             nn.SiLU(),
             zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
         )
-
         self.debug_codes = {}
 
     def get_grad_norm_parameter_groups(self):
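A note on the `zero_module` call in the output head above: zeroing the parameters of the final convolution makes the model's prediction start at exactly zero, the initialization idiom DL-Art-School inherits from OpenAI's guided-diffusion code. A minimal sketch, with `nn.GroupNorm` and `nn.Conv1d` as stand-ins for the repo's own `normalization` and `conv_nd` helpers:

import torch.nn as nn

def zero_module(module: nn.Module) -> nn.Module:
    # Zero all parameters in place and return the module; the head then
    # contributes nothing at step 0, which stabilizes early training.
    for p in module.parameters():
        p.detach().zero_()
    return module

# Usage sketch mirroring the head above (1d conv, kernel 3, padding 1):
head = nn.Sequential(
    nn.GroupNorm(32, 1024),  # stand-in for normalization(model_channels)
    nn.SiLU(),
    zero_module(nn.Conv1d(1024, 512, 3, padding=1)),
)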
@@ -199,10 +198,24 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         }
         return groups
 
-    def process_conditioning(self, conditioning_input, time_emb, N, cond_start, custom_conditioning_fetcher):
-        if custom_conditioning_fetcher is not None:
-            cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
-        else:
+    def process_conditioning(self, conditioning_input, time_emb, N, cond_start, cond_left, cond_right):
+        if self.training and self.regularization:
+            # frequency regularization
+            fstart = random.randint(0, conditioning_input.shape[1] - 1)
+            fclip = random.randint(1, min(conditioning_input.shape[1]-fstart, 16))
+            conditioning_input[:,fstart:fstart+fclip] = 0
+            # time regularization
+            for k in range(1, random.randint(2, 4)):
+                tstart = random.randint(0, conditioning_input.shape[-1] - 1)
+                tclip = random.randint(1, min(conditioning_input.shape[-1]-tstart, 10))
+                conditioning_input[:,:,tstart:tstart+tclip] = 0
+
+        if cond_left is None and self.new_cond:
+            cond_left = conditioning_input[:,:,:max(cond_start, 20)]
+            left_pt = cond_start
+            cond_right = conditioning_input[:,:,min(N+cond_start, conditioning_input.shape[-1]-20):]
+            right_pt = cond_right.shape[-1] - (conditioning_input.shape[-1] - (N+cond_start))
+        elif cond_left is None:
             assert conditioning_input.shape[-1] - cond_start - N >= 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
             cond_pre = conditioning_input[:,:,:cond_start]
             cond_aligned = conditioning_input[:,:,cond_start:N+cond_start]
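The new `regularization` branch above is SpecAugment-style dropout on the conditioning signal: one random band of channels and one to three random time spans are zeroed during training. A self-contained sketch of the same masking, assuming a (batch, channels, time) conditioning tensor:

import random
import torch

def mask_conditioning(cond: torch.Tensor) -> torch.Tensor:
    # cond: (batch, channels, time), e.g. a mel-like conditioning signal.
    # Frequency regularization: zero one random channel band of width <= 16.
    fstart = random.randint(0, cond.shape[1] - 1)
    fclip = random.randint(1, min(cond.shape[1] - fstart, 16))
    cond[:, fstart:fstart + fclip] = 0
    # Time regularization: zero 1-3 random spans of width <= 10.
    for _ in range(1, random.randint(2, 4)):
        tstart = random.randint(0, cond.shape[-1] - 1)
        tclip = random.randint(1, min(cond.shape[-1] - tstart, 10))
        cond[:, :, tstart:tstart + tclip] = 0
    return cond

cond = mask_conditioning(torch.randn(2, 256, 400))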
@@ -223,19 +236,25 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
                 cond_right = cond_right[:,:,to_remove_right:]
 
             # Concatenate the _pre and _post back on.
-            cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
-            cond_right_full = torch.cat([cond_right, cond_post], dim=-1)
+            left_pt = cond_start
+            right_pt = cond_right.shape[-1]
+            cond_left = torch.cat([cond_pre, cond_left], dim=-1)
+            cond_right = torch.cat([cond_right, cond_post], dim=-1)
+        else:
+            left_pt = -1
+            right_pt = 0
+
         # Propagate through the encoder.
-        cond_left_enc = self.conditioning_encoder(cond_left_full, time_emb)
-        cs = cond_left_enc[:,:,cond_start]
-        cond_right_enc = self.conditioning_encoder(cond_right_full, time_emb)
-        ce = cond_right_enc[:,:,cond_right.shape[-1]-1]
+        cond_left_enc = self.conditioning_encoder(cond_left, time_emb)
+        cs = cond_left_enc[:,:,left_pt]
+        cond_right_enc = self.conditioning_encoder(cond_right, time_emb)
+        ce = cond_right_enc[:,:,right_pt]
 
         cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
         cond = F.interpolate(cond_enc, size=(N,), mode='linear', align_corners=True).permute(0,2,1)
         return cond
 
-    def forward(self, x, timesteps, conditioning_input=None, conditioning_free=False, cond_start=0, custom_conditioning_fetcher=None):
+    def forward(self, x, timesteps, conditioning_input=None, cond_left=None, cond_right=None, conditioning_free=False, cond_start=0):
         unused_params = []
 
         time_emb = self.time_embed(timestep_embedding(timesteps, self.time_embed_dim))
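The point-conditioning scheme above reduces each side of the target window to a single encoded vector, `cs` at the left anchor point and `ce` at the right, then linearly interpolates between the two across the N positions of the diffusion target. A minimal sketch of that endpoint interpolation with dummy vectors:

import torch
import torch.nn.functional as F

N = 8                      # length of the diffusion target
cs = torch.randn(2, 1024)  # encoded conditioning at the left anchor
ce = torch.randn(2, 1024)  # encoded conditioning at the right anchor

# Stack the two endpoints along a length-2 "time" axis, then interpolate to
# length N; align_corners=True pins cond[:, 0] to cs and cond[:, -1] to ce.
cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)  # (2, 1024, 2)
cond = F.interpolate(cond_enc, size=(N,), mode='linear', align_corners=True)
cond = cond.permute(0, 2, 1)                                        # (2, N, 1024)
assert torch.allclose(cond[:, 0], cs) and torch.allclose(cond[:, -1], ce)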
@@ -244,7 +263,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
             cond = self.unconditioned_embedding
             cond = cond.repeat(1,x.shape[-1],1)
         else:
-            cond = self.process_conditioning(conditioning_input, time_emb, x.shape[-1], cond_start, custom_conditioning_fetcher)
+            cond = self.process_conditioning(conditioning_input, time_emb, x.shape[-1], cond_start, cond_left, cond_right)
             # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
             if self.training and self.unconditioned_percentage > 0:
                 unconditioned_batches = torch.rand((cond.shape[0], 1, 1),
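For context on the comment above: the masking body (just past this hunk in the original file) replaces the conditioning of randomly chosen batch elements with the learned `unconditioned_embedding`, which is what later allows `conditioning_free` sampling. A hedged sketch of the mechanism; the standalone parameter here is a stand-in for the module attribute:

import torch
import torch.nn as nn

unconditioned_percentage = .1
unconditioned_embedding = nn.Parameter(torch.randn(1, 1, 1024))  # stand-in

cond = torch.randn(4, 100, 1024)  # (batch, N, channels) conditioning
# One uniform draw per batch element; elements under the threshold have their
# conditioning replaced wholesale by the learned unconditioned embedding.
unconditioned_batches = torch.rand((cond.shape[0], 1, 1)) < unconditioned_percentage
cond = torch.where(unconditioned_batches,
                   unconditioned_embedding.repeat(cond.shape[0], cond.shape[1], 1),
                   cond)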
@@ -276,6 +295,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         # Scale back the gradients of the blkout and prenorm layers by a constant factor. These get two orders of magnitudes
         # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
         # directly fiddling with the gradients.
-        for p in scaled_grad_parameters:
-            if hasattr(p, 'grad') and p.grad is not None:
-                p.grad *= .2
+        if not self.new_cond:  # Not really related, I just don't want to add a new config.
+            for p in scaled_grad_parameters:
+                if hasattr(p, 'grad') and p.grad is not None:
+                    p.grad *= .2
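The rescaling now guarded by `if not self.new_cond:` runs between `backward()` and the optimizer step: the selected parameters simply have `.grad` multiplied by a constant, sidestepping the per-group learning rates that ZeroRedundancyOptimizer makes awkward. A toy sketch of the pattern:

import torch
import torch.nn as nn

model = nn.Linear(16, 16)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
scaled_grad_parameters = list(model.parameters())  # stand-in for blkout/prenorm params

loss = model(torch.randn(8, 16)).sum()
loss.backward()
# Scale gradients of the selected parameters before stepping, in lieu of a
# separate optimizer parameter group.
for p in scaled_grad_parameters:
    if hasattr(p, 'grad') and p.grad is not None:
        p.grad *= .2
opt.step()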
@@ -294,10 +314,11 @@ def test_cheater_model():
     # For music:
     model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
                                                       contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
-                                                      unconditioned_percentage=.4, checkpoint_conditioning=False)
+                                                      unconditioned_percentage=.4, checkpoint_conditioning=False,
+                                                      regularization=True, new_cond=True)
     print_network(model)
-    for k in range(100):
-        o = model(clip, ts, cl)
+    for cs in range(276,cl.shape[-1]-clip.shape[-1]):
+        o = model(clip, ts, cl, cond_start=cs)
         pg = model.get_grad_norm_parameter_groups()
     def prmsz(lp):
         sz = 0
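`clip`, `ts`, and `cl` in the rewritten test loop are defined just above this hunk in the original file; the loop now sweeps `cond_start` over every valid offset rather than repeating one fixed call. Hypothetical stand-in tensors that satisfy the model's shape expectations (the sizes here are guesses, not the file's actual values):

import torch

clip = torch.randn(2, 256, 350)    # diffusion target: (batch, in_channels, N); guessed shape
ts = torch.LongTensor([600, 600])  # diffusion timesteps; guessed values
cl = torch.randn(2, 256, 646)      # conditioning clip, long enough to sweep cond_start

for cs in range(276, cl.shape[-1] - clip.shape[-1]):
    o = model(clip, ts, cl, cond_start=cs)  # model as constructed in the hunk above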
@@ -617,8 +617,6 @@ class GaussianDiffusion:
         mask,
         noise=None,
         clip_denoised=True,
-        causal=False,
-        causal_slope=1,
         denoised_fn=None,
         cond_fn=None,
         model_kwargs=None,
@@ -640,8 +638,6 @@ class GaussianDiffusion:
             img,
             t,
             clip_denoised=clip_denoised,
-            causal=causal,
-            causal_slope=causal_slope,
             denoised_fn=denoised_fn,
             cond_fn=cond_fn,
             model_kwargs=model_kwargs,
@@ -436,18 +436,18 @@ class MusicDiffusionFid(evaluator.Evaluator):
 if __name__ == '__main__':
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_cheater_gen.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\53000_generator_ema.pth'
+                                       load_path='X:\\dlas\\experiments\\train_music_cheater_gen_v5_causal_retrain\\models\\80500_generator_ema.pth'
                                        ).cuda()
     opt_eval = {'path': 'Y:\\split\\yt-music-eval',  # eval music, mostly electronica. :)
                 #'path': 'E:\\music_eval',  # this is music from the training dataset, including a lot more variety.
-                'diffusion_steps': 220,  # basis: 192
+                'diffusion_steps': 256,  # basis: 192
                 'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': False,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'cheater_gen',
                 # Slope 1: 1.03x, 2: 1.06, 4: 1.135, 8: 1.27, 16: 1.54
-                'causal': True, 'causal_slope': 3,  # DONT FORGET TO INCREMENT THE STEP!
+                'causal': True, 'causal_slope': 4,  # DONT FORGET TO INCREMENT THE STEP!
                 #'partial_low': 128, 'partial_high': 192
                 }
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 3, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 104, 'device': 'cuda', 'opt': {}}
     eval = MusicDiffusionFid(diffusion, opt_eval, env)
     fds = []
     for i in range(2):