forked from mrq/DL-Art-School
Some additional context regularization in tfd
This commit is contained in:
parent
7170ccdfa9
commit
51291ab070
|
@ -215,6 +215,15 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
||||||
left_pt = -1
|
left_pt = -1
|
||||||
cond_right = conditioning_input[:,:,cond_start+N:]
|
cond_right = conditioning_input[:,:,cond_start+N:]
|
||||||
right_pt = 0
|
right_pt = 0
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
# Arbitrarily restrict the context given. We should support short contexts, and without this restriction they are never encountered during training.
|
||||||
|
arb_context_cap = random.randint(50, 100)
|
||||||
|
if cond_left.shape[-1] > arb_context_cap and random() > .5:
|
||||||
|
cond_left = cond_left[:,:,-arb_context_cap:]
|
||||||
|
if cond_right.shape[-1] > arb_context_cap and random() > .5:
|
||||||
|
cond_right = cond_right[:,:,:arb_context_cap]
|
||||||
|
|
||||||
elif cond_left is None:
|
elif cond_left is None:
|
||||||
assert conditioning_input.shape[-1] - cond_start - N >= 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
|
assert conditioning_input.shape[-1] - cond_start - N >= 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
|
||||||
cond_pre = conditioning_input[:,:,:cond_start]
|
cond_pre = conditioning_input[:,:,:cond_start]
|
||||||
|
|
|
@ -230,10 +230,11 @@ class MusicDiffusionFid(evaluator.Evaluator):
|
||||||
|
|
||||||
# 1. Generate the cheater latent using the input as a reference.
|
# 1. Generate the cheater latent using the input as a reference.
|
||||||
sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
|
sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
|
||||||
output_shape = (1, 256, cheater.shape[-1]-80)
|
# Center-pad the conditioning input (the center isn't actually used). This is a hack for giving tfdpc5 a bigger working context.
|
||||||
gen_cheater = sampler(self.model, output_shape, progress=True,
|
cheater_padded = torch.cat([cheater[:,:,cheater.shape[-1]//2:], torch.zeros(1,256,160, device=cheater.device), cheater[:,:,:cheater.shape[-1]//2]], dim=-1)
|
||||||
|
gen_cheater = sampler(self.model, cheater.shape, progress=True,
|
||||||
causal=self.causal, causal_slope=self.causal_slope,
|
causal=self.causal, causal_slope=self.causal_slope,
|
||||||
model_kwargs={'conditioning_input': cheater, 'cond_start': 40})
|
model_kwargs={'conditioning_input': cheater_padded, 'cond_start': 80})
|
||||||
|
|
||||||
# 2. Decode the cheater into a MEL
|
# 2. Decode the cheater into a MEL
|
||||||
gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device), (1,256,gen_cheater.shape[-1]*16), progress=True,
|
gen_mel = self.cheater_decoder_diffuser.ddim_sample_loop(self.local_modules['cheater_decoder'].diff.to(audio.device), (1,256,gen_cheater.shape[-1]*16), progress=True,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user