From 524d289c9c44053779abbbacf1d0665ad170b7f2 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 22 Aug 2023 22:57:23 -0500 Subject: [PATCH] Forgot to re-add in setting the weight's dtype on model load --- vall_e/inference.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vall_e/inference.py b/vall_e/inference.py index e481f50..3b5d423 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -38,17 +38,19 @@ class TTS(): models = get_models(cfg.models.get()) for name, model in models.items(): if name.startswith("ar"): - self.ar = model.to(self.device, dtype=torch.float32) + self.ar = model state = torch.load(self.ar_ckpt) if "module" in state: state = state['module'] self.ar.load_state_dict(state) + self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype) elif name.startswith("nar"): - self.nar = model.to(self.device, dtype=torch.float32) + self.nar = model state = torch.load(self.nar_ckpt) if "module" in state: state = state['module'] self.nar.load_state_dict(state) + self.nar = self.nar.to(self.device, dtype=cfg.inference.dtype) else: self.load_models() @@ -62,9 +64,9 @@ class TTS(): engines = load_engines() for name, engine in engines.items(): if name[:2] == "ar": - self.ar = engine.module.to(self.device) + self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype) elif name[:3] == "nar": - self.nar = engine.module.to(self.device) + self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype) def encode_text( self, text, language="en" ): # already a tensor, return it