Add audio decoding
This commit is contained in:
parent
2296e2ea3c
commit
43483bb394
@ -5,6 +5,6 @@ An unofficial (toy) implementation of VALL-E, based on the [encodec](https://git
|
||||
## TODO
|
||||
|
||||
- [x] AR model for the first quantizer.
|
||||
- [ ] Built-in decode.
|
||||
- [x] Audio decoding from tokens.
|
||||
- [ ] NAR model for the remaining quantizers.
|
||||
- [ ] Trainers for both models.
|
||||
|
||||
BIN
data/test/test.ar.recon.wav
Normal file
BIN
data/test/test.ar.recon.wav
Normal file
Binary file not shown.
0
vall_e/ar/__init__.py
Normal file
0
vall_e/ar/__init__.py
Normal file
@ -279,13 +279,16 @@ class VALLEAR(nn.Module):
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _prune(l: list[int], stop_token: int | None):
|
||||
def _prune(l: Tensor, stop_token: int | None):
|
||||
if stop_token is None:
|
||||
return l
|
||||
n = next((i for i, x in enumerate(l) if x == stop_token), None)
|
||||
if n is not None:
|
||||
l = l[:n]
|
||||
return l
|
||||
|
||||
indices = (l == stop_token).nonzero()
|
||||
|
||||
if len(indices) == 0:
|
||||
return l
|
||||
|
||||
return l[: indices[0].item()]
|
||||
|
||||
def generate(
|
||||
self,
|
||||
@ -313,11 +316,13 @@ class VALLEAR(nn.Module):
|
||||
output_list[i] = torch.cat([output_list[i], oi[None]])
|
||||
if all(stopped):
|
||||
break
|
||||
pruned = [self._prune(o.tolist(), stop_token) for o in output_list]
|
||||
pruned = [self._prune(o, stop_token) for o in output_list]
|
||||
return pruned
|
||||
|
||||
|
||||
def example_usage():
|
||||
import soundfile
|
||||
|
||||
device = "cuda"
|
||||
|
||||
test_qnt = torch.load("data/test/test.qnt.pt")[0, 0].to(device)
|
||||
@ -348,7 +353,7 @@ def example_usage():
|
||||
stop_token=eoq,
|
||||
)
|
||||
|
||||
print(test_qnt)
|
||||
print(out)
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
||||
|
||||
@ -372,6 +377,12 @@ def example_usage():
|
||||
|
||||
print(out)
|
||||
|
||||
from ..emb.qnt import decode
|
||||
|
||||
codes = rearrange(out[1], "t -> 1 1 t")
|
||||
wavs, sr = decode(codes)
|
||||
soundfile.write("data/test/test.ar.recon.wav", wavs.cpu()[0, 0], sr)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
example_usage()
|
||||
|
||||
@ -6,6 +6,7 @@ import torch
|
||||
import torchaudio
|
||||
from encodec import EncodecModel
|
||||
from encodec.utils import convert_audio
|
||||
from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
@ -18,6 +19,17 @@ def _load_model(device="cuda"):
|
||||
return model
|
||||
|
||||
|
||||
@torch.inference_mode()
def decode(codes: Tensor, device="cuda"):
    """Decode quantized codes back into waveforms.

    Args:
        codes: code tensor of shape (b k t).
        device: device handed to ``_load_model``; ``codes`` is moved to the
            same device before decoding so the two cannot mismatch.

    Returns:
        A tuple ``(wavs, sample_rate)``: the output of ``model.decode`` and
        the model's ``sample_rate`` attribute.

    Raises:
        ValueError: if ``codes`` is not a 3-dimensional tensor.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if codes.dim() != 3:
        raise ValueError(
            f"codes must have 3 dimensions (b k t), got {codes.dim()}"
        )
    model = _load_model(device)
    codes = codes.to(device)
    # The model takes a list of (codes, scale) frames; a single frame with
    # no scale is passed here.
    return model.decode([(codes, None)]), model.sample_rate
|
||||
|
||||
|
||||
def replace_file_extension(path, suffix):
    """Replace everything after the first extension dot in the file name.

    All extensions are dropped, so ``c.qnt.pt`` with suffix ``.wav``
    becomes ``c.wav``. Leading dots are treated as part of the stem rather
    than as an extension separator, so ``.cache.pt`` becomes ``.cache.wav``
    instead of collapsing onto the parent directory's name.

    Args:
        path: a ``pathlib`` path object.
        suffix: new suffix, in the form accepted by ``with_suffix``
            (e.g. ``".wav"``).

    Returns:
        A path in the same directory with the new suffix.
    """
    name = path.name
    # Split only the part after any leading dots; a dot-prefixed name would
    # otherwise yield an empty stem.
    body = name.lstrip(".")
    leading_dots = name[: len(name) - len(body)]
    stem = leading_dots + body.split(".", 1)[0]
    return (path.parent / stem).with_suffix(suffix)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user