This commit is contained in:
James Betker 2022-03-24 23:31:20 -06:00
parent 9c79fec734
commit d4218d8443
2 changed files with 5 additions and 2 deletions

View File

@ -182,7 +182,10 @@ class DiffusionTtsFlat(nn.Module):
def get_grad_norm_parameter_groups(self):
groups = {
'minicoder': list(self.contextual_embedder.parameters()),
'layers': list(self.layers),
'layers': list(self.layers.parameters()),
'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_converter.parameters()) + list(self.latent_converter.parameters()),
'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
'time_embed': list(self.time_embed.parameters()),
}
return groups

View File

@ -132,7 +132,7 @@ class DiscreteTokenInjector(Injector):
super().__init__(opt, env)
cfg = opt_get(opt, ['dvae_config'], "../experiments/train_diffusion_vocoder_22k_level.yml")
dvae_name = opt_get(opt, ['dvae_name'], 'dvae')
self.dvae = load_model_from_config(cfg, dvae_name).cuda().eval()
self.dvae = load_model_from_config(cfg, dvae_name, device=env['device']).eval()
def forward(self, state):
inp = state[self.input]