More improvements

James Betker 2022-03-16 10:16:34 -06:00
parent 54202aa099
commit 8b376e63d9
4 changed files with 25 additions and 12 deletions

View File

@@ -388,6 +388,23 @@ class DiffusionTts(nn.Module):
         }
         return groups
 
+    def fix_alignment(self, x, aligned_conditioning):
+        """
+        The UNet requires that the length of <x> be a multiple of a power of 2 determined by the UNet depth. Enforce
+        this by padding both <x> and <aligned_conditioning> before forward propagation and removing the padding before returning.
+        """
+        cm = ceil_multiple(x.shape[-1], self.alignment_size)
+        if cm != 0:
+            pc = (cm - x.shape[-1]) / x.shape[-1]
+            x = F.pad(x, (0, cm - x.shape[-1]))
+            # Also fix aligned_latent, which is aligned to x.
+            if is_latent(aligned_conditioning):
+                aligned_conditioning = torch.cat([aligned_conditioning,
+                                                  self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
+            else:
+                aligned_conditioning = F.pad(aligned_conditioning, (0, int(pc * aligned_conditioning.shape[-1])))
+        return x, aligned_conditioning
+
     def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
         """
         Apply the model to an input batch.
@@ -415,16 +432,7 @@ class DiffusionTts(nn.Module):
         # Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
         orig_x_shape = x.shape[-1]
-        cm = ceil_multiple(x.shape[-1], self.alignment_size)
-        if cm != 0:
-            pc = (cm - x.shape[-1]) / x.shape[-1]
-            x = F.pad(x, (0, cm - x.shape[-1]))
-            # Also fix aligned_latent, which is aligned to x.
-            if is_latent(aligned_conditioning):
-                aligned_conditioning = torch.cat([aligned_conditioning,
-                                                  self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
-            else:
-                aligned_conditioning = F.pad(aligned_conditioning, (0, int(pc * aligned_conditioning.shape[-1])))
+        x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning)
 
         with autocast(x.device.type, enabled=self.enable_fp16):
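
The `if cm != 0` guard above only makes sense if ceil_multiple returns 0 for an already-aligned length. A plausible sketch of that helper, consistent with its usage here (an assumption, not necessarily the repo's exact implementation):

    def ceil_multiple(base, multiple):
        # Return 0 when base is already a multiple of `multiple` (no padding
        # needed), otherwise the next multiple above base.
        res = base % multiple
        if res == 0:
            return 0
        return base + (multiple - res)

    # With alignment_size = 8:
    #   ceil_multiple(16, 8) -> 0   (fix_alignment is a no-op)
    #   ceil_multiple(17, 8) -> 24  (x is padded by 24 - 17 = 7 samples)

The padding is presumably trimmed again at the end of forward using the saved orig_x_shape (e.g. out[:, :, :orig_x_shape]), per the fix_alignment docstring.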

View File

@@ -314,6 +314,11 @@ class ExtensibleTrainer(BaseModel):
                 if hasattr(net.module, "before_step"):
                     net.module.before_step(it)
 
+            # Unscale gradients within the step. (This is admittedly pretty messy but the API contract between step & ET is pretty much broken at this point)
+            # This is needed to accurately log the grad norms.
+            for opt in step.optimizers:
+                step.scaler.unscale_(opt)
+
             if return_grad_norms and train_step:
                 for name in nets_to_train:
                     model = self.networks[name]
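
Hoisting unscale_() up here follows the standard torch.cuda.amp recipe: gradients must be unscaled before their norms mean anything, and GradScaler.step() detects a prior unscale_() for that optimizer and skips its own. A minimal, self-contained sketch of the pattern (generic model and optimizer, not ExtensibleTrainer's; needs a CUDA device):

    import torch
    from torch.cuda.amp import GradScaler, autocast

    model = torch.nn.Linear(8, 8).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = GradScaler()

    with autocast():
        loss = model(torch.randn(4, 8).cuda()).sum()
    scaler.scale(loss).backward()

    scaler.unscale_(opt)  # gradients now hold their true, unscaled values
    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1e9)  # norm is now meaningful to log
    scaler.step(opt)      # safe: step() sees the earlier unscale_ and does not unscale twice
    scaler.update()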

View File

@@ -125,7 +125,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
         mel = wav_to_mel(audio)
         mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel)
         real_resampled = torchaudio.functional.resample(audio, 22050, SAMPLE_RATE).unsqueeze(0)
-        univnet_mel = wav_to_univnet_mel(audio, mel_norms_file=None)  # to be used for a conditioning input
+        univnet_mel = wav_to_univnet_mel(real_resampled, do_normalization=False)  # to be used for a conditioning input, but also guides output shape.
         output_size = univnet_mel.shape[-1]
         aligned_codes_compression_factor = output_size // mel_codes.shape[-1]
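
With this change the UnivNet conditioning mel is computed from the resampled waveform, so its frame count (and therefore output_size) tracks SAMPLE_RATE rather than the 22050 Hz source. A rough shape sketch (SAMPLE_RATE = 24000 is an assumption here, as is the hop length inside the repo's wav_to_univnet_mel):

    import torch
    import torchaudio

    SAMPLE_RATE = 24000                      # assumed UnivNet target rate
    audio = torch.randn(1, 22050)            # one second at the source rate
    resampled = torchaudio.functional.resample(audio, 22050, SAMPLE_RATE)
    print(resampled.shape)                   # torch.Size([1, 24000])
    # wav_to_univnet_mel(resampled, ...) then yields roughly
    # 24000 / hop_length frames, so output_size is derived from the
    # resampled audio rather than the original 22kHz clip.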

View File

@@ -304,7 +304,7 @@ class ConfigurableStep(Module):
             return
         self.grads_generated = False
         for opt in self.optimizers:
-            self.scaler.unscale_(opt)
+            # self.scaler.unscale_(opt)  # It would be important to do this here, but ExtensibleTrainer currently does it.
             # Optimizers can be opted out in the early stages of training.
             after = opt._config['after'] if 'after' in opt._config.keys() else 0
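
Commenting the call out avoids a double unscale: GradScaler permits exactly one unscale_() per optimizer between update() calls, and ExtensibleTrainer now performs it (see the change above), so a second call here would raise. A small sketch of the failure mode (generic model, needs a CUDA device):

    import torch
    from torch.cuda.amp import GradScaler

    model = torch.nn.Linear(4, 4).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = GradScaler()

    loss = scaler.scale(model(torch.randn(2, 4).cuda()).sum())
    loss.backward()
    scaler.unscale_(opt)  # first unscale this step: fine
    scaler.unscale_(opt)  # RuntimeError: unscale_() has already been called
                          # on this optimizer since the last update().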