Forgot to re-add in setting the weight's dtype on model load

This commit is contained in:
mrq 2023-08-22 22:57:23 -05:00
parent 9c5a33bfd2
commit 524d289c9c

View File

@ -38,17 +38,19 @@ class TTS():
models = get_models(cfg.models.get()) models = get_models(cfg.models.get())
for name, model in models.items(): for name, model in models.items():
if name.startswith("ar"): if name.startswith("ar"):
self.ar = model.to(self.device, dtype=torch.float32) self.ar = model
state = torch.load(self.ar_ckpt) state = torch.load(self.ar_ckpt)
if "module" in state: if "module" in state:
state = state['module'] state = state['module']
self.ar.load_state_dict(state) self.ar.load_state_dict(state)
self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype)
elif name.startswith("nar"): elif name.startswith("nar"):
self.nar = model.to(self.device, dtype=torch.float32) self.nar = model
state = torch.load(self.nar_ckpt) state = torch.load(self.nar_ckpt)
if "module" in state: if "module" in state:
state = state['module'] state = state['module']
self.nar.load_state_dict(state) self.nar.load_state_dict(state)
self.nar = self.nar.to(self.device, dtype=cfg.inference.dtype)
else: else:
self.load_models() self.load_models()
@ -62,9 +64,9 @@ class TTS():
engines = load_engines() engines = load_engines()
for name, engine in engines.items(): for name, engine in engines.items():
if name[:2] == "ar": if name[:2] == "ar":
self.ar = engine.module.to(self.device) self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype)
elif name[:3] == "nar": elif name[:3] == "nar":
self.nar = engine.module.to(self.device) self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype)
def encode_text( self, text, language="en" ): def encode_text( self, text, language="en" ):
# already a tensor, return it # already a tensor, return it