forked from mrq/DL-Art-School
mods
This commit is contained in:
parent
9c79fec734
commit
d4218d8443
|
@ -182,7 +182,10 @@ class DiffusionTtsFlat(nn.Module):
|
|||
def get_grad_norm_parameter_groups(self):
|
||||
groups = {
|
||||
'minicoder': list(self.contextual_embedder.parameters()),
|
||||
'layers': list(self.layers),
|
||||
'layers': list(self.layers.parameters()),
|
||||
'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_converter.parameters()) + list(self.latent_converter.parameters()),
|
||||
'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
|
||||
'time_embed': list(self.time_embed.parameters()),
|
||||
}
|
||||
return groups
|
||||
|
||||
|
|
|
@ -132,7 +132,7 @@ class DiscreteTokenInjector(Injector):
|
|||
super().__init__(opt, env)
|
||||
cfg = opt_get(opt, ['dvae_config'], "../experiments/train_diffusion_vocoder_22k_level.yml")
|
||||
dvae_name = opt_get(opt, ['dvae_name'], 'dvae')
|
||||
self.dvae = load_model_from_config(cfg, dvae_name).cuda().eval()
|
||||
self.dvae = load_model_from_config(cfg, dvae_name, device=env['device']).eval()
|
||||
|
||||
def forward(self, state):
|
||||
inp = state[self.input]
|
||||
|
|
Loading…
Reference in New Issue
Block a user