diff --git a/README.md b/README.md
index 5bd7be7..8ed2ed6 100755
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
 
 * [`DeepSpeed`](https://github.com/microsoft/DeepSpeed#requirements):
   - DeepSpeed training is Linux only. Installation under Windows should ignore trying to install DeepSpeed.
-  - If your config YAML has the training backend set to [`deepspeed`], you will need to have a GPU that DeepSpeed has developed and tested against, as well as a CUDA or ROCm compiler pre-installed to install this package.
+  - If your config YAML has the training backend set to `deepspeed`, you will need to have a GPU that DeepSpeed has developed and tested against, as well as a CUDA or ROCm compiler pre-installed to install this package.
 
 * [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/):
   - For phonemizing text, this repo requires `espeak`/`espeak-ng` installed.
diff --git a/vall_e/inference.py b/vall_e/inference.py
index dc754cc..5dbd2d2 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -13,7 +13,7 @@ from .utils import to_device
 from .config import cfg
 from .models import get_models
 from .train import load_engines
-from .data import get_phone_symmap
+from .data import get_phone_symmap, _load_quants
 
 class TTS():
 	def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
@@ -82,7 +82,12 @@ class TTS():
 			return text
 
 		content = g2p.encode(text, language=language)
-		#phones = [""] + [ " " if not p else p for p in content ] + [""]
+		# ick
+		try:
+			phones = [""] + [ " " if not p else p for p in content ] + [""]
+			return torch.tensor([*map(self.symmap.get, phones)])
+		except Exception as e:
+			pass
 		phones = [ " " if not p else p for p in content ]
 		return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
 
@@ -96,10 +101,10 @@ class TTS():
 			paths = [ Path(p) for p in paths.split(";") ]
 
 		# merge inputs
-		res = torch.cat([qnt.encode_from_file( path )[0].t().to(torch.int16) for path in paths])
-
+		res = torch.cat([qnt.encode_from_file(path)[0][:, :].t().to(torch.int16) for path in paths])
+
 		if should_trim:
-			res = trim_random( res, int( 75 * cfg.dataset.prompt_duration ) )
+			res = trim( res, int( 75 * cfg.dataset.prompt_duration ) )
 
 		return res
 
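
The second `inference.py` hunk wraps the phoneme-to-ID mapping in a `try`/`except`: it first attempts the newer encoding that brackets the sequence with the `""` boundary symbol, and only falls back to the hard-coded BOS/EOS IDs `1` and `2` when that mapping fails (a missing symbol makes `symmap.get` return `None`, which `torch.tensor` rejects). Below is a minimal, self-contained sketch of that fallback; the `symmap` contents and the `encode_phones` helper are hypothetical stand-ins for what `get_phone_symmap()` and `TTS.encode_text()` provide in the repo.

```python
import torch

# Hypothetical symbol map; the real one comes from vall_e.data.get_phone_symmap().
symmap = { "": 0, " ": 3, "h": 4, "ɛ": 5, "l": 6, "oʊ": 7 }

def encode_phones( content, symmap ):
	try:
		# Newer symbol maps include the "" boundary token, so wrap the sequence
		# with it and map every phone directly.
		phones = [""] + [ " " if not p else p for p in content ] + [""]
		return torch.tensor([*map(symmap.get, phones)])
	except Exception:
		# If "" (or any symbol) is missing, the lookup above yields None and
		# torch.tensor() raises, so fall back to hard-coded BOS (1) / EOS (2) IDs.
		phones = [ " " if not p else p for p in content ]
		return torch.tensor([ 1 ] + [*map(symmap.get, phones)] + [ 2 ])

print( encode_phones([ "h", "ɛ", "l", "oʊ" ], symmap) )  # tensor([0, 4, 5, 6, 7, 0])
```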
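
The last hunk changes how reference audio becomes the acoustic prompt: each clip is encoded, the codebook frames are concatenated, and the prompt is cut to `75 * cfg.dataset.prompt_duration` frames (75 being the EnCodec frame rate at 24 kHz), now via `trim` rather than `trim_random`. The sketch below mirrors that flow with hypothetical stand-ins; `trim` here simply keeps the leading window, whereas the repo's helper in `vall_e.emb.qnt` may pick the window differently.

```python
import torch

FRAMES_PER_SECOND = 75  # EnCodec frame rate at 24 kHz

def trim( qnt, target_frames ):
	# Stand-in for the repo's trim helper: keep only the first `target_frames`
	# time steps of a (frames, codebooks) tensor.
	return qnt[:target_frames, :]

def merge_prompts( encoded_clips, prompt_duration ):
	# Each clip is a (frames, codebooks) tensor of int16 EnCodec codes,
	# analogous to qnt.encode_from_file(path)[0].t().to(torch.int16).
	res = torch.cat( encoded_clips )
	return trim( res, int( FRAMES_PER_SECOND * prompt_duration ) )

clips = [ torch.zeros(150, 8, dtype=torch.int16), torch.zeros(300, 8, dtype=torch.int16) ]
print( merge_prompts( clips, prompt_duration=3.0 ).shape )  # torch.Size([225, 8])
```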