Forgot to re-add in setting the weight's dtype on model load
This commit is contained in:
parent
9c5a33bfd2
commit
524d289c9c
|
@ -38,17 +38,19 @@ class TTS():
|
|||
models = get_models(cfg.models.get())
|
||||
for name, model in models.items():
|
||||
if name.startswith("ar"):
|
||||
self.ar = model.to(self.device, dtype=torch.float32)
|
||||
self.ar = model
|
||||
state = torch.load(self.ar_ckpt)
|
||||
if "module" in state:
|
||||
state = state['module']
|
||||
self.ar.load_state_dict(state)
|
||||
self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype)
|
||||
elif name.startswith("nar"):
|
||||
self.nar = model.to(self.device, dtype=torch.float32)
|
||||
self.nar = model
|
||||
state = torch.load(self.nar_ckpt)
|
||||
if "module" in state:
|
||||
state = state['module']
|
||||
self.nar.load_state_dict(state)
|
||||
self.nar = self.nar.to(self.device, dtype=cfg.inference.dtype)
|
||||
else:
|
||||
self.load_models()
|
||||
|
||||
|
@ -62,9 +64,9 @@ class TTS():
|
|||
engines = load_engines()
|
||||
for name, engine in engines.items():
|
||||
if name[:2] == "ar":
|
||||
self.ar = engine.module.to(self.device)
|
||||
self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype)
|
||||
elif name[:3] == "nar":
|
||||
self.nar = engine.module.to(self.device)
|
||||
self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype)
|
||||
|
||||
def encode_text( self, text, language="en" ):
|
||||
# already a tensor, return it
|
||||
|
|
Loading…
Reference in New Issue
Block a user