diff --git a/README.md b/README.md
index 5bd7be7..8ed2ed6 100755
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
* [`DeepSpeed`](https://github.com/microsoft/DeepSpeed#requirements):
- - DeepSpeed training is Linux only. Installation under Windows should ignore trying to install DeepSpeed.
+ - DeepSpeed training is Linux only. When installing under Windows, skip installing DeepSpeed.
- - If your config YAML has the training backend set to [`deepspeed`], you will need to have a GPU that DeepSpeed has developed and tested against, as well as a CUDA or ROCm compiler pre-installed to install this package.
+ - If your config YAML has the training backend set to `deepspeed`, you will need to have a GPU that DeepSpeed has developed and tested against, as well as a CUDA or ROCm compiler pre-installed to install this package.
* [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/):
- For phonemizing text, this repo requires `espeak`/`espeak-ng` installed.
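+   - e.g. on Debian/Ubuntu: `sudo apt install espeak-ng`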
diff --git a/vall_e/inference.py b/vall_e/inference.py
index dc754cc..5dbd2d2 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -13,7 +13,7 @@ from .utils import to_device
from .config import cfg
from .models import get_models
from .train import load_engines
-from .data import get_phone_symmap
+from .data import get_phone_symmap, _load_quants
class TTS():
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
@@ -82,7 +82,12 @@ class TTS():
return text
content = g2p.encode(text, language=language)
- #phones = [""] + [ " " if not p else p for p in content ] + [""]
+ # first try the previously commented-out construction that pads the phoneme list with empty-string boundary tokens
+ try:
+     phones = [""] + [ " " if not p else p for p in content ] + [""]
+     return torch.tensor([*map(self.symmap.get, phones)])
+ except Exception:
+     # the loaded symmap can't map those tokens; fall back to the explicit start/stop ids below
+     pass
phones = [ " " if not p else p for p in content ]
return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
@@ -96,10 +101,10 @@ class TTS():
paths = [ Path(p) for p in paths.split(";") ]
# merge inputs
- res = torch.cat([qnt.encode_from_file( path )[0].t().to(torch.int16) for path in paths])
-
+ res = torch.cat([qnt.encode_from_file(path)[0][:, :].t().to(torch.int16) for path in paths])
+
if should_trim:
- res = trim_random( res, int( 75 * cfg.dataset.prompt_duration ) )
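+ # 75 should be the EnCodec frame rate (frames per second), i.e. trim the prompt to cfg.dataset.prompt_duration seconds of audio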
+ res = trim( res, int( 75 * cfg.dataset.prompt_duration ) )
return res