More improvements
parent 54202aa099
commit 8b376e63d9
@@ -388,6 +388,23 @@ class DiffusionTts(nn.Module):
         }
         return groups
 
+    def fix_alignment(self, x, aligned_conditioning):
+        """
+        The UNet requires that the input <x> is a certain multiple of 2, defined by the UNet depth. Enforce this by
+        padding both <x> and <aligned_conditioning> before forward propagation and removing the padding before returning.
+        """
+        cm = ceil_multiple(x.shape[-1], self.alignment_size)
+        if cm != 0:
+            pc = (cm-x.shape[-1])/x.shape[-1]
+            x = F.pad(x, (0,cm-x.shape[-1]))
+            # Also fix aligned_latent, which is aligned to x.
+            if is_latent(aligned_conditioning):
+                aligned_conditioning = torch.cat([aligned_conditioning,
+                                                  self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
+            else:
+                aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1])))
+        return x, aligned_conditioning
+
     def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
         """
         Apply the model to an input batch.
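A note on the helpers referenced above: `ceil_multiple` and `is_latent` are utilities defined elsewhere in the repo and are not part of this diff (`is_latent` presumably distinguishes float latents from integer code sequences). A minimal sketch of the padding math, using a hypothetical stand-in for `ceil_multiple` (an assumption, not the repo's actual definition):

    import torch
    import torch.nn.functional as F

    def ceil_multiple(base, multiple):
        # Hypothetical stand-in: return 0 when base is already aligned,
        # otherwise the next multiple of `multiple` above base.
        res = base % multiple
        return 0 if res == 0 else base + (multiple - res)

    # Example: a length of 100 with alignment_size 16 pads up to 112.
    x = torch.randn(2, 80, 100)
    cm = ceil_multiple(x.shape[-1], 16)    # 112
    pc = (cm - x.shape[-1]) / x.shape[-1]  # 0.12 -- fraction of padding added
    x = F.pad(x, (0, cm - x.shape[-1]))    # right-pad only the last dimension
    print(x.shape)                         # torch.Size([2, 80, 112])

Returning 0 for already-aligned lengths is what lets the `if cm != 0` guard skip padding entirely in that case.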
@@ -415,16 +432,7 @@ class DiffusionTts(nn.Module):
 
         # Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net.
         orig_x_shape = x.shape[-1]
-        cm = ceil_multiple(x.shape[-1], self.alignment_size)
-        if cm != 0:
-            pc = (cm-x.shape[-1])/x.shape[-1]
-            x = F.pad(x, (0,cm-x.shape[-1]))
-            # Also fix aligned_latent, which is aligned to x.
-            if is_latent(aligned_conditioning):
-                aligned_conditioning = torch.cat([aligned_conditioning,
-                                                  self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1)
-            else:
-                aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1])))
+        x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning)
 
         with autocast(x.device.type, enabled=self.enable_fp16):
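With the padding factored into `fix_alignment`, `forward` keeps `orig_x_shape` so the pad can be stripped again before the result is handed back, per the docstring above. The trimming side is not part of this hunk; it presumably looks something like this sketch (assumed, not the file's actual code):

    # At the end of forward() -- a sketch, assuming `out` is the final network output:
    out = out[:, :, :orig_x_shape]  # remove the alignment padding added up front
    return out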
@@ -314,6 +314,11 @@ class ExtensibleTrainer(BaseModel):
                     if hasattr(net.module, "before_step"):
                         net.module.before_step(it)
 
+            # Unscale gradients within the step. (This is admittedly pretty messy but the API contract between step & ET is pretty much broken at this point)
+            # This is needed to accurately log the grad norms.
+            for opt in step.optimizers:
+                step.scaler.unscale_(opt)
+
             if return_grad_norms and train_step:
                 for name in nets_to_train:
                     model = self.networks[name]
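For readers unfamiliar with the AMP contract here: `GradScaler.unscale_` divides an optimizer's gradients by the current scale factor in place, so any norms read afterwards reflect true gradient magnitudes, and a subsequent `scaler.step(opt)` knows not to unscale a second time. A self-contained sketch of the pattern (illustrative names, requires a CUDA device; not ExtensibleTrainer's actual API):

    import torch
    from torch.cuda.amp import GradScaler, autocast

    model = torch.nn.Linear(8, 8).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = GradScaler()

    with autocast():
        loss = model(torch.randn(4, 8, device='cuda')).sum()
    scaler.scale(loss).backward()  # gradients are scaled at this point

    scaler.unscale_(opt)           # gradients back to true magnitude
    grad_norm = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
    print(grad_norm)               # accurate norm, as the comment above intends

    scaler.step(opt)               # skips the internal unscale; unscale_ already ran
    scaler.update()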
@@ -125,7 +125,7 @@ class AudioDiffusionFid(evaluator.Evaluator):
         mel = wav_to_mel(audio)
         mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel)
         real_resampled = torchaudio.functional.resample(audio, 22050, SAMPLE_RATE).unsqueeze(0)
-        univnet_mel = wav_to_univnet_mel(audio, mel_norms_file=None) # to be used for a conditioning input
+        univnet_mel = wav_to_univnet_mel(real_resampled, do_normalization=False) # to be used for a conditioning input, but also guides output shape.
 
         output_size = univnet_mel.shape[-1]
         aligned_codes_compression_factor = output_size // mel_codes.shape[-1]
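The changed line feeds the resampled waveform (rather than the raw 22.05 kHz audio) into the univnet mel, so `output_size` is measured at the target rate. A quick illustration of the resampling arithmetic; `SAMPLE_RATE` is set to 24000 here purely for the example, the evaluator defines its own value:

    import torch
    import torchaudio

    SAMPLE_RATE = 24000            # assumption for this example only
    audio = torch.randn(1, 22050)  # one second of audio at 22.05 kHz

    resampled = torchaudio.functional.resample(audio, 22050, SAMPLE_RATE)
    print(resampled.shape)         # torch.Size([1, 24000]) -- length scales by 24000/22050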
@@ -304,7 +304,7 @@ class ConfigurableStep(Module):
             return
         self.grads_generated = False
         for opt in self.optimizers:
-            self.scaler.unscale_(opt)
+            # self.scaler.unscale_(opt) It would be important to do this here, but ExtensibleTrainer currently does it.
 
             # Optimizers can be opted out in the early stages of training.
             after = opt._config['after'] if 'after' in opt._config.keys() else 0
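This removal pairs with the ExtensibleTrainer hunk above: the step must not unscale an optimizer that the trainer has already unscaled, because PyTorch's GradScaler refuses to unscale the same optimizer twice between updates. A minimal reproduction of that constraint (illustrative, requires CUDA):

    import torch
    from torch.cuda.amp import GradScaler

    model = torch.nn.Linear(2, 2).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = GradScaler()

    loss = model(torch.randn(1, 2, device='cuda')).sum()
    scaler.scale(loss).backward()
    scaler.unscale_(opt)
    scaler.unscale_(opt)  # RuntimeError: unscale_() has already been called on this optimizer since the last update().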