Forgot to re-add in setting the weight's dtype on model load
This commit is contained in:
parent
9c5a33bfd2
commit
524d289c9c
|
@ -38,17 +38,19 @@ class TTS():
|
||||||
models = get_models(cfg.models.get())
|
models = get_models(cfg.models.get())
|
||||||
for name, model in models.items():
|
for name, model in models.items():
|
||||||
if name.startswith("ar"):
|
if name.startswith("ar"):
|
||||||
self.ar = model.to(self.device, dtype=torch.float32)
|
self.ar = model
|
||||||
state = torch.load(self.ar_ckpt)
|
state = torch.load(self.ar_ckpt)
|
||||||
if "module" in state:
|
if "module" in state:
|
||||||
state = state['module']
|
state = state['module']
|
||||||
self.ar.load_state_dict(state)
|
self.ar.load_state_dict(state)
|
||||||
|
self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype)
|
||||||
elif name.startswith("nar"):
|
elif name.startswith("nar"):
|
||||||
self.nar = model.to(self.device, dtype=torch.float32)
|
self.nar = model
|
||||||
state = torch.load(self.nar_ckpt)
|
state = torch.load(self.nar_ckpt)
|
||||||
if "module" in state:
|
if "module" in state:
|
||||||
state = state['module']
|
state = state['module']
|
||||||
self.nar.load_state_dict(state)
|
self.nar.load_state_dict(state)
|
||||||
|
self.nar = self.nar.to(self.device, dtype=cfg.inference.dtype)
|
||||||
else:
|
else:
|
||||||
self.load_models()
|
self.load_models()
|
||||||
|
|
||||||
|
@ -62,9 +64,9 @@ class TTS():
|
||||||
engines = load_engines()
|
engines = load_engines()
|
||||||
for name, engine in engines.items():
|
for name, engine in engines.items():
|
||||||
if name[:2] == "ar":
|
if name[:2] == "ar":
|
||||||
self.ar = engine.module.to(self.device)
|
self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype)
|
||||||
elif name[:3] == "nar":
|
elif name[:3] == "nar":
|
||||||
self.nar = engine.module.to(self.device)
|
self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype)
|
||||||
|
|
||||||
def encode_text( self, text, language="en" ):
|
def encode_text( self, text, language="en" ):
|
||||||
# already a tensor, return it
|
# already a tensor, return it
|
||||||
|
|
Loading…
Reference in New Issue
Block a user