Some additional context regularization in tfd

This commit is contained in:
James Betker 2022-07-14 21:49:47 -06:00
parent 7170ccdfa9
commit 51291ab070
2 changed files with 13 additions and 3 deletions

View File

@ -215,6 +215,15 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
left_pt = -1
cond_right = conditioning_input[:,:,cond_start+N:]
right_pt = 0
if self.training:
# Arbitrarily restrict the context given. We should support short contexts, and without this restriction they are never encountered during training.
arb_context_cap = random.randint(50, 100)
if cond_left.shape[-1] > arb_context_cap and random() > .5:
cond_left = cond_left[:,:,-arb_context_cap:]
if cond_right.shape[-1] > arb_context_cap and random() > .5:
cond_right = cond_right[:,:,:arb_context_cap]
elif cond_left is None:
assert conditioning_input.shape[-1] - cond_start - N >= 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
cond_pre = conditioning_input[:,:,:cond_start]

View File

@ -230,10 +230,11 @@ class MusicDiffusionFid(evaluator.Evaluator):
# 1. Generate the cheater latent using the input as a reference.
sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
output_shape = (1, 256, cheater.shape[-1]-80)
gen_cheater = sampler(self.model, output_shape, progress=True,
# Center-pad the conditioning input (the center isn't actually used). This is a hack for giving tfdpc5 a bigger working context.
cheater_padded = torch.cat([cheater[:,:,cheater.shape[-1]//2:], torch.zeros(1,256,160, device=cheater.device), cheater[:,:,:cheater.shape[-1]//2]], dim=-1)
gen_cheater = sampler(self.model, cheater.shape, progress=True,
causal=self.causal, causal_slope=self.causal_slope,
model_kwargs={'conditioning_input': cheater, 'cond_start': 40})
model_kwargs={'conditioning_input': cheater_padded, 'cond_start': 80})
# 2. Decode the cheater into a MEL
gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device), (1,256,gen_cheater.shape[-1]*16), progress=True,