update flat0 to break out timestep-independent inference steps

James Betker 2022-04-01 14:38:53 -06:00
parent a6181a489b
commit 3e97abc8a9
2 changed files with 62 additions and 31 deletions

View File

@@ -187,7 +187,46 @@ class DiffusionTtsFlat(nn.Module):
        }
        return groups

    def forward(self, x, timesteps, aligned_conditioning, conditioning_input, conditioning_free=False, return_code_pred=False):
    def timestep_independent(self, aligned_conditioning, conditioning_input, return_code_pred):
        # Shuffle aligned_latent to BxCxS format
        if is_latent(aligned_conditioning):
            aligned_conditioning = aligned_conditioning.permute(0, 2, 1)

        # Note: this block does not need to be repeated at inference, since it is not timestep-dependent or x-dependent.
        speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
            conditioning_input.shape) == 3 else conditioning_input
        conds = []
        for j in range(speech_conditioning_input.shape[1]):
            conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
        conds = torch.cat(conds, dim=-1)
        cond_emb = conds.mean(dim=-1)
        cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
        if is_latent(aligned_conditioning):
            code_emb = self.latent_converter(aligned_conditioning)
        else:
            code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
            code_emb = self.code_converter(code_emb)
        code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)

        unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
        # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
        if self.training and self.unconditioned_percentage > 0:
            unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
                                               device=code_emb.device) < self.unconditioned_percentage
            code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
                                   code_emb)
        expanded_code_emb = F.interpolate(code_emb, size=aligned_conditioning.shape[-1]*4, mode='nearest')

        if not return_code_pred:
            return expanded_code_emb
        else:
            mel_pred = self.mel_head(expanded_code_emb)
            # Multiply mel_pred by !unconditioned_batches, which drops the gradient on unconditioned batches. This is because we don't want that gradient being used to train parameters through code_embedding, as it unbalances contributions to that network from the MSE loss.
            mel_pred = mel_pred * unconditioned_batches.logical_not()
            return expanded_code_emb, mel_pred

    def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
        """
        Apply the model to an input batch.
@@ -195,49 +234,32 @@ class DiffusionTtsFlat(nn.Module):
        :param timesteps: a 1-D batch of timesteps.
        :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
        :param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
        :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
        :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
        :return: an [N x C x ...] Tensor of outputs.
        """
        # Shuffle aligned_latent to BxCxS format
        if is_latent(aligned_conditioning):
            aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
        assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_input is not None)
        assert not (return_code_pred and precomputed_aligned_embeddings is not None)  # These two are mutually exclusive.

        # Note: this block does not need to be repeated at inference, since it is not timestep-dependent or x-dependent.
        unused_params = []
        if conditioning_free:
            code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
            code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
            unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
            unused_params.extend(list(self.latent_converter.parameters()))
        else:
            if precomputed_aligned_embeddings is not None:
                code_emb = precomputed_aligned_embeddings
            else:
                code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, True)
            unused_params.append(self.unconditioned_embedding)
            speech_conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
            conds = []
            for j in range(speech_conditioning_input.shape[1]):
                conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
            conds = torch.cat(conds, dim=-1)
            cond_emb = conds.mean(dim=-1)
            cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
            if is_latent(aligned_conditioning):
                code_emb = self.latent_converter(aligned_conditioning)
                unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
            else:
                code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
                code_emb = self.code_converter(code_emb)
                unused_params.extend(list(self.latent_converter.parameters()))
            code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)

        unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
        # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
        if self.training and self.unconditioned_percentage > 0:
            unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
                                               device=code_emb.device) < self.unconditioned_percentage
            code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1),
                                   code_emb)
        expanded_code_emb = F.interpolate(code_emb, size=x.shape[-1], mode='nearest')
        mel_pred = self.mel_head(expanded_code_emb)
        # Multiply mel_pred by !unconditioned_batches, which drops the gradient on unconditioned batches. This is because we don't want that gradient being used to train parameters through code_embedding, as it unbalances contributions to that network from the MSE loss.
        mel_pred = mel_pred * unconditioned_batches.logical_not()

        # Everything after this comment is timestep-dependent.
        time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        code_emb = self.conditioning_timestep_integrator(expanded_code_emb, time_emb)
        code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
        x = self.inp_block(x)
        x = torch.cat([x, code_emb], dim=1)
        x = self.integrating_conv(x)
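The unconditioned-embedding machinery above is what enables classifier-free guidance at sampling time: one conditioned pass and one pass with `conditioning_free=True`, blended by the caller. A minimal sketch of that blend, assuming a hypothetical `cfg_scale` knob (the combination formula is the standard CFG recipe, not something this commit adds):

# Hypothetical classifier-free guidance blend at a single denoising step.
# `emb` is assumed to come from model.timestep_independent(...).
cond_out = model(x, ts, precomputed_aligned_embeddings=emb)
uncond_out = model(x, ts, precomputed_aligned_embeddings=emb, conditioning_free=True)
guided = uncond_out + cfg_scale * (cond_out - uncond_out)  # cfg_scale > 1 strengthens conditioning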
@@ -272,7 +294,7 @@ def register_diffusion_tts_flat0(opt_net, opt):
if __name__ == '__main__':
    clip = torch.randn(2, 100, 400)
    aligned_latent = torch.randn(2,388,512)
    aligned_sequence = torch.randint(0,8192,(2,388))
    aligned_sequence = torch.randint(0,8192,(2,100))
    cond = torch.randn(2, 100, 400)
    ts = torch.LongTensor([600, 600])
    model = DiffusionTtsFlat(512, layer_drop=.3, unconditioned_percentage=.5)
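The test stub above exercises the new split end to end. The inference pattern this commit is aimed at looks roughly like the following: call `timestep_independent` once per utterance, then reuse its output on every denoising step via `precomputed_aligned_embeddings`. A minimal sketch (the loop structure and `num_steps` are illustrative, not part of this commit):

with torch.no_grad():
    # Timestep-independent work runs once per utterance...
    emb = model.timestep_independent(aligned_sequence, cond, return_code_pred=False)
    # ...and is reused across every step of the denoising loop.
    num_steps = 50
    for t in reversed(range(num_steps)):
        ts = torch.LongTensor([t, t])
        out = model(clip, ts, precomputed_aligned_embeddings=emb)
        # A real diffusion sampler would update `clip` from `out` here.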

View File

@@ -591,6 +591,15 @@ def load_audio(audiopath, sampling_rate, raw_data=None):
    return audio


def pad_or_truncate(t, length):
    if t.shape[-1] == length:
        return t
    elif t.shape[-1] < length:
        return F.pad(t, (0, length-t.shape[-1]))
    else:
        return t[..., :length]
def load_wav_to_torch(full_path):
    import scipy.io.wavfile
    sampling_rate, data = scipy.io.wavfile.read(full_path)
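For reference, the new `pad_or_truncate` helper forces the last dimension of a tensor to an exact length: short inputs are zero-padded on the right, long ones are sliced. A quick illustrative check (shapes are arbitrary):

t = torch.randn(1, 80, 350)
assert pad_or_truncate(t, 400).shape == (1, 80, 400)  # right-padded with zeros
assert pad_or_truncate(t, 300).shape == (1, 80, 300)  # truncated
assert pad_or_truncate(t, 350) is t                   # exact length passes through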