diff --git a/README.md b/README.md index 7f991b8..bb12cd9 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,6 @@ An unofficial (toy) implementation of VALL-E, based on the [encodec](https://git ## TODO - [x] AR model for the first quantizer. -- [ ] Built-in decode. +- [x] Audio decoding from tokens. - [ ] NAR model for the rest quantizers. - [ ] Trainers for both models. diff --git a/data/test/test.ar.recon.wav b/data/test/test.ar.recon.wav new file mode 100644 index 0000000..87cf1d8 Binary files /dev/null and b/data/test/test.ar.recon.wav differ diff --git a/vall_e/ar/__init__.py b/vall_e/ar/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vall_e/ar/model.py b/vall_e/ar/model.py index f017ba9..36dcbd1 100644 --- a/vall_e/ar/model.py +++ b/vall_e/ar/model.py @@ -279,13 +279,16 @@ class VALLEAR(nn.Module): return logits @staticmethod - def _prune(l: list[int], stop_token: int | None): + def _prune(l: Tensor, stop_token: int | None): if stop_token is None: return l - n = next((i for i, x in enumerate(l) if x == stop_token), None) - if n is not None: - l = l[:n] - return l + + indices = (l == stop_token).nonzero() + + if len(indices) == 0: + return l + + return l[: indices[0].item()] def generate( self, @@ -313,11 +316,13 @@ class VALLEAR(nn.Module): output_list[i] = torch.cat([output_list[i], oi[None]]) if all(stopped): break - pruned = [self._prune(o.tolist(), stop_token) for o in output_list] + pruned = [self._prune(o, stop_token) for o in output_list] return pruned def example_usage(): + import soundfile + device = "cuda" test_qnt = torch.load("data/test/test.qnt.pt")[0, 0].to(device) @@ -348,7 +353,7 @@ def example_usage(): stop_token=eoq, ) - print(test_qnt) + print(out) optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) @@ -372,6 +377,12 @@ def example_usage(): print(out) + from ..emb.qnt import decode + + codes = rearrange(out[1], "t -> 1 1 t") + wavs, sr = decode(codes) + soundfile.write("data/test/test.ar.recon.wav", wavs.cpu()[0, 0], sr) + if __name__ == "__main__": example_usage() diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index 128c8bf..1cba87e 100644 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -6,6 +6,7 @@ import torch import torchaudio from encodec import EncodecModel from encodec.utils import convert_audio +from torch import Tensor from tqdm import tqdm @@ -18,6 +19,17 @@ def _load_model(device="cuda"): return model +@torch.inference_mode() +def decode(codes: Tensor, device="cuda"): + """ + Args: + codes: (b k t) + """ + assert codes.dim() == 3 + model = _load_model(device) + return model.decode([(codes, None)]), model.sample_rate + + def replace_file_extension(path, suffix): return (path.parent / path.name.split(".")[0]).with_suffix(suffix)