diff --git a/vall_e/data.py b/vall_e/data.py index 5257abf..8f7a229 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -1034,13 +1034,29 @@ if __name__ == "__main__": "validation": [ next(iter(val_dl)), next(iter(val_dl)) ], } + Path("./data/sample-test/").mkdir(parents=True, exist_ok=True) + for k, v in samples.items(): for i in range(len(v)): - del v[i]['proms'] - del v[i]['resps'] - print(f'{k}:', v) + for j in tqdm(range(len(v[i]['proms'])), desc="Decoding..."): + """ + try: + decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" ) + except Exception as e: + print(f"Error while decoding prom {k}.{i}.{j}.wav:", str(e)) + try: + decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" ) + except Exception as e: + print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e)) + """ + v[i]['proms'][j] = v[i]['proms'][j].shape + v[i]['resps'][j] = v[i]['resps'][j].shape + + for k, v in samples.items(): + for i in range(len(v)): + print(f'{k}[{i}]:', v[i]) - train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt") + #train_dl.dataset.save_state_dict(cfg.relpath / "train_dataset.pt") elif args.action == "tasks": index = 0 diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index 8229a51..ea4531b 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -212,12 +212,12 @@ def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels, metadata=N dummy = False if metadata is None: metadata = dict( - chunk_length=120, + chunk_length= codes.shape[-1], original_length=0, input_db=-12, channels=1, sample_rate=model.sample_rate, - padding=False, + padding=True, dac_version='1.0.0', ) dummy = True diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 55a84e0..0d587d6 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -658,8 +658,11 @@ class Base(nn.Module): else: raise RuntimeError(f'Unknown arch specified: {self.arch_type}') + # Disabling for now, it might be broken + """ if self.config.attention in ["xformers", "auto", "mem_efficient", "math", "flash"]: self.model = ml.replace_attention( self.model, klass=Llama_Attention, target=LlamaAttention, mode=self.config.attention ) + """ self.classifier = nn.Linear(d_model, n_resp_tokens)