From 19410a919e0520c2b19e476a70f8e7e7966beb3c Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 15 Jun 2024 12:29:03 -0500 Subject: [PATCH] ugh --- vall_e/data.py | 6 +++--- vall_e/engines/__init__.py | 3 ++- vall_e/models/ar_nar.py | 11 ++++++++++- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/vall_e/data.py b/vall_e/data.py index 9f414df..4dbe6a7 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -602,17 +602,17 @@ class Dataset(_Dataset): state_dict = torch.load(path) if self.sampler_type == "path": - state_dict = self.sampler.load_state(state_dict) + state_dict = self.sampler.set_state(state_dict) else: for name, sampler in state_dict["samplers"].items(): if name not in self.samplers: continue - self.samplers[name].load_state( sampler ) + self.samplers[name].set_state( sampler ) for name, sampler in state_dict["spkr_samplers"].items(): if name not in self.spkr_samplers: continue - self.spkr_samplers[name].load_state( sampler ) + self.spkr_samplers[name].set_state( sampler ) def _get_phone_symmap(self): return get_phone_symmap() diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 6142b67..b8c9d1b 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -111,10 +111,11 @@ def load_engines(training=True): # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present load_path = cfg.ckpt_dir / name / "fp32.pth" + if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists(): print("DeepSpeed checkpoint missing, but weights found.") loads_state_dict = True - + stats = None if loads_state_dict: state = torch.load(load_path, map_location=torch.device(cfg.device)) diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index e2ae82c..81781f5 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -207,7 +207,7 @@ class AR_NAR(Base): prev_list = resps_list - for n in trange( max_levels, desc="NAR" ): + for n in trange( max_levels, desc="NAR" ): level = prev_list[0].shape[-1] if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels break @@ -244,6 +244,15 @@ class AR_NAR(Base): #mirostat=mirostat, ) + # filter + """ + if self.arch_type in ["mamba2-hf"]: + for batch_index, resp in enumerate(resps_list): + for i, token in enumerate(resp): + if token >= 1024: + resps_list[batch_index][i] = 1023 + """ + prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ] return prev_list