Forgot to re-add in setting the weight's dtype on model load

This commit is contained in:
mrq 2023-08-22 22:57:23 -05:00
parent 9c5a33bfd2
commit 524d289c9c

View File

@ -38,17 +38,19 @@ class TTS():
models = get_models(cfg.models.get())
for name, model in models.items():
if name.startswith("ar"):
self.ar = model.to(self.device, dtype=torch.float32)
self.ar = model
state = torch.load(self.ar_ckpt)
if "module" in state:
state = state['module']
self.ar.load_state_dict(state)
self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype)
elif name.startswith("nar"):
self.nar = model.to(self.device, dtype=torch.float32)
self.nar = model
state = torch.load(self.nar_ckpt)
if "module" in state:
state = state['module']
self.nar.load_state_dict(state)
self.nar = self.nar.to(self.device, dtype=cfg.inference.dtype)
else:
self.load_models()
@ -62,9 +64,9 @@ class TTS():
engines = load_engines()
for name, engine in engines.items():
if name[:2] == "ar":
self.ar = engine.module.to(self.device)
self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype)
elif name[:3] == "nar":
self.nar = engine.module.to(self.device)
self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype)
def encode_text( self, text, language="en" ):
# already a tensor, return it