I think I fixed a bug?

This commit is contained in:
mrq 2023-08-24 23:33:36 -05:00
parent f3fbed5ffd
commit 6455a2f9d7
2 changed files with 11 additions and 6 deletions

View File

@ -10,7 +10,7 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
* [`DeepSpeed`](https://github.com/microsoft/DeepSpeed#requirements):
- DeepSpeed training is Linux only. Installation under Windows should ignore trying to install DeepSpeed.
- If your config YAML has the training backend set to [`deepspeed`], you will need to have a GPU that DeepSpeed has developed and tested against, as well as a CUDA or ROCm compiler pre-installed to install this package.
- If your config YAML has the training backend set to `deepspeed`, you will need to have a GPU that DeepSpeed has developed and tested against, as well as a CUDA or ROCm compiler pre-installed to install this package.
* [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/):
- For phonemizing text, this repo requires `espeak`/`espeak-ng` installed.

View File

@ -13,7 +13,7 @@ from .utils import to_device
from .config import cfg
from .models import get_models
from .train import load_engines
from .data import get_phone_symmap
from .data import get_phone_symmap, _load_quants
class TTS():
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
@ -82,7 +82,12 @@ class TTS():
return text
content = g2p.encode(text, language=language)
#phones = ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
# ick
try:
phones = ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
return torch.tensor([*map(self.symmap.get, phones)])
except Exception as e:
pass
phones = [ " " if not p else p for p in content ]
return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
@ -96,10 +101,10 @@ class TTS():
paths = [ Path(p) for p in paths.split(";") ]
# merge inputs
res = torch.cat([qnt.encode_from_file( path )[0].t().to(torch.int16) for path in paths])
res = torch.cat([qnt.encode_from_file(path)[0][:, :].t().to(torch.int16) for path in paths])
if should_trim:
res = trim_random( res, int( 75 * cfg.dataset.prompt_duration ) )
res = trim( res, int( 75 * cfg.dataset.prompt_duration ) )
return res