Add audio decoding

This commit is contained in:
enhuiz 2023-01-11 23:51:56 +08:00
parent 2296e2ea3c
commit 43483bb394
5 changed files with 31 additions and 8 deletions

View File

@ -5,6 +5,6 @@ An unofficial (toy) implementation of VALL-E, based on the [encodec](https://git
## TODO
- [x] AR model for the first quantizer.
- [ ] Built-in decode.
- [x] Audio decoding from tokens.
- [ ] NAR model for the remaining quantizers.
- [ ] Trainers for both models.

BIN
data/test/test.ar.recon.wav Normal file

Binary file not shown.

0
vall_e/ar/__init__.py Normal file
View File

View File

@ -279,13 +279,16 @@ class VALLEAR(nn.Module):
return logits
@staticmethod
def _prune(l: list[int], stop_token: int | None):
def _prune(l: Tensor, stop_token: int | None):
if stop_token is None:
return l
n = next((i for i, x in enumerate(l) if x == stop_token), None)
if n is not None:
l = l[:n]
return l
indices = (l == stop_token).nonzero()
if len(indices) == 0:
return l
return l[: indices[0].item()]
def generate(
self,
@ -313,11 +316,13 @@ class VALLEAR(nn.Module):
output_list[i] = torch.cat([output_list[i], oi[None]])
if all(stopped):
break
pruned = [self._prune(o.tolist(), stop_token) for o in output_list]
pruned = [self._prune(o, stop_token) for o in output_list]
return pruned
def example_usage():
import soundfile
device = "cuda"
test_qnt = torch.load("data/test/test.qnt.pt")[0, 0].to(device)
@ -348,7 +353,7 @@ def example_usage():
stop_token=eoq,
)
print(test_qnt)
print(out)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
@ -372,6 +377,12 @@ def example_usage():
print(out)
from ..emb.qnt import decode
codes = rearrange(out[1], "t -> 1 1 t")
wavs, sr = decode(codes)
soundfile.write("data/test/test.ar.recon.wav", wavs.cpu()[0, 0], sr)
if __name__ == "__main__":
example_usage()

View File

@ -6,6 +6,7 @@ import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio
from torch import Tensor
from tqdm import tqdm
@ -18,6 +19,17 @@ def _load_model(device="cuda"):
return model
@torch.inference_mode()
def decode(codes: Tensor, device="cuda"):
    """Decode quantized codes with the model returned by ``_load_model``.

    Args:
        codes: code tensor laid out as (b k t); must be 3-dimensional.
        device: device passed to ``_load_model``. ``codes`` is moved to this
            device before decoding, so callers may pass codes that live on
            another device (e.g. CPU).

    Returns:
        A ``(wavs, sample_rate)`` tuple: whatever ``model.decode`` returns
        for the given codes, and the model's ``sample_rate`` attribute.

    Raises:
        AssertionError: if ``codes`` is not 3-dimensional.
    """
    assert codes.dim() == 3, (
        f"expected codes with 3 dims (b k t), got shape {tuple(codes.shape)}"
    )
    model = _load_model(device)
    # NOTE(review): assumes _load_model places the model on `device` (its body
    # is not visible here). Moving the codes to the same device avoids a
    # device-mismatch error when the caller hands in codes from elsewhere;
    # it is a no-op when they are already there.
    codes = codes.to(device)
    # model.decode is called with a one-element list holding a (codes, None)
    # pair, exactly as before.
    return model.decode([(codes, None)]), model.sample_rate
def replace_file_extension(path, suffix):
    """Return ``path`` with everything after the first dot in its name replaced.

    The file name is cut at its first ``"."`` (so ``test.qnt.pt`` becomes
    ``test``), re-attached to the original parent directory, and ``suffix``
    is applied with ``with_suffix``.

    Args:
        path: a path object exposing ``parent`` and ``name``
            (e.g. ``pathlib.Path``).
        suffix: the new suffix, including its leading dot.

    Returns:
        The new path object in the same directory as ``path``.
    """
    base_name, _, _ = path.name.partition(".")
    stripped = path.parent / base_name
    return stripped.with_suffix(suffix)