update flat0 to break out timestep-independent inference steps
This commit is contained in:
parent
a6181a489b
commit
3e97abc8a9
|
@ -187,7 +187,46 @@ class DiffusionTtsFlat(nn.Module):
|
|||
}
|
||||
return groups
|
||||
|
||||
def forward(self, x, timesteps, aligned_conditioning, conditioning_input, conditioning_free=False, return_code_pred=False):
|
||||
def timestep_independent(self, aligned_conditioning, conditioning_input, return_code_pred):
|
||||
# Shuffle aligned_latent to BxCxS format
|
||||
if is_latent(aligned_conditioning):
|
||||
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
|
||||
|
||||
# Note: this block does not need to repeated on inference, since it is not timestep-dependent or x-dependent.
|
||||
speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
|
||||
conditioning_input.shape) == 3 else conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
|
||||
conds = torch.cat(conds, dim=-1)
|
||||
cond_emb = conds.mean(dim=-1)
|
||||
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
|
||||
if is_latent(aligned_conditioning):
|
||||
code_emb = self.latent_converter(aligned_conditioning)
|
||||
else:
|
||||
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
|
||||
code_emb = self.code_converter(code_emb)
|
||||
code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
|
||||
|
||||
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
||||
device=code_emb.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
|
||||
code_emb)
|
||||
expanded_code_emb = F.interpolate(code_emb, size=aligned_conditioning.shape[-1]*4, mode='nearest')
|
||||
|
||||
if not return_code_pred:
|
||||
return expanded_code_emb
|
||||
else:
|
||||
mel_pred = self.mel_head(expanded_code_emb)
|
||||
# Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
|
||||
mel_pred = mel_pred * unconditioned_batches.logical_not()
|
||||
return expanded_code_emb, mel_pred
|
||||
|
||||
|
||||
def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
|
@ -195,49 +234,32 @@ class DiffusionTtsFlat(nn.Module):
|
|||
:param timesteps: a 1-D batch of timesteps.
|
||||
:param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
|
||||
:param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
|
||||
:param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
|
||||
:param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
# Shuffle aligned_latent to BxCxS format
|
||||
if is_latent(aligned_conditioning):
|
||||
aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
|
||||
assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_input is not None)
|
||||
assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
|
||||
|
||||
# Note: this block does not need to repeated on inference, since it is not timestep-dependent or x-dependent.
|
||||
unused_params = []
|
||||
if conditioning_free:
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
|
||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
unused_params.extend(list(self.latent_converter.parameters()))
|
||||
else:
|
||||
if precomputed_aligned_embeddings is not None:
|
||||
code_emb = precomputed_aligned_embeddings
|
||||
else:
|
||||
code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, True)
|
||||
|
||||
unused_params.append(self.unconditioned_embedding)
|
||||
speech_conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
|
||||
conds = []
|
||||
for j in range(speech_conditioning_input.shape[1]):
|
||||
conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
|
||||
conds = torch.cat(conds, dim=-1)
|
||||
cond_emb = conds.mean(dim=-1)
|
||||
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
|
||||
if is_latent(aligned_conditioning):
|
||||
code_emb = self.latent_converter(aligned_conditioning)
|
||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||
else:
|
||||
code_emb = self.code_embedding(aligned_conditioning).permute(0,2,1)
|
||||
code_emb = self.code_converter(code_emb)
|
||||
unused_params.extend(list(self.latent_converter.parameters()))
|
||||
code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
|
||||
unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
|
||||
# Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
|
||||
if self.training and self.unconditioned_percentage > 0:
|
||||
unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
|
||||
device=code_emb.device) < self.unconditioned_percentage
|
||||
code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1),
|
||||
code_emb)
|
||||
expanded_code_emb = F.interpolate(code_emb, size=x.shape[-1], mode='nearest')
|
||||
mel_pred = self.mel_head(expanded_code_emb)
|
||||
# Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
|
||||
mel_pred = mel_pred * unconditioned_batches.logical_not()
|
||||
|
||||
# Everything after this comment is timestep dependent.
|
||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
code_emb = self.conditioning_timestep_integrator(expanded_code_emb, time_emb)
|
||||
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
||||
x = self.inp_block(x)
|
||||
x = torch.cat([x, code_emb], dim=1)
|
||||
x = self.integrating_conv(x)
|
||||
|
@ -272,7 +294,7 @@ def register_diffusion_tts_flat0(opt_net, opt):
|
|||
if __name__ == '__main__':
|
||||
clip = torch.randn(2, 100, 400)
|
||||
aligned_latent = torch.randn(2,388,512)
|
||||
aligned_sequence = torch.randint(0,8192,(2,388))
|
||||
aligned_sequence = torch.randint(0,8192,(2,100))
|
||||
cond = torch.randn(2, 100, 400)
|
||||
ts = torch.LongTensor([600, 600])
|
||||
model = DiffusionTtsFlat(512, layer_drop=.3, unconditioned_percentage=.5)
|
||||
|
|
|
@ -591,6 +591,15 @@ def load_audio(audiopath, sampling_rate, raw_data=None):
|
|||
return audio
|
||||
|
||||
|
||||
def pad_or_truncate(t, length):
|
||||
if t.shape[-1] == length:
|
||||
return t
|
||||
elif t.shape[-1] < length:
|
||||
return F.pad(t, (0, length-t.shape[-1]))
|
||||
else:
|
||||
return t[..., :length]
|
||||
|
||||
|
||||
def load_wav_to_torch(full_path):
|
||||
import scipy.io.wavfile
|
||||
sampling_rate, data = scipy.io.wavfile.read(full_path)
|
||||
|
|
Loading…
Reference in New Issue
Block a user