I think I fixed a bug?
This commit is contained in:
parent
f3fbed5ffd
commit
6455a2f9d7
|
@ -10,7 +10,7 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
|
|||
|
||||
* [`DeepSpeed`](https://github.com/microsoft/DeepSpeed#requirements):
|
||||
- DeepSpeed training is Linux only. Installation under Windows should ignore trying to install DeepSpeed.
|
||||
- If your config YAML has the training backend set to [`deepspeed`], you will need to have a GPU that DeepSpeed has developed and tested against, as well as a CUDA or ROCm compiler pre-installed to install this package.
|
||||
- If your config YAML has the training backend set to `deepspeed`, you will need to have a GPU that DeepSpeed has developed and tested against, as well as a CUDA or ROCm compiler pre-installed to install this package.
|
||||
|
||||
* [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/):
|
||||
- For phonemizing text, this repo requires `espeak`/`espeak-ng` installed.
|
||||
|
|
|
@ -13,7 +13,7 @@ from .utils import to_device
|
|||
from .config import cfg
|
||||
from .models import get_models
|
||||
from .train import load_engines
|
||||
from .data import get_phone_symmap
|
||||
from .data import get_phone_symmap, _load_quants
|
||||
|
||||
class TTS():
|
||||
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
|
||||
|
@ -82,7 +82,12 @@ class TTS():
|
|||
return text
|
||||
|
||||
content = g2p.encode(text, language=language)
|
||||
#phones = ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
|
||||
# ick
|
||||
try:
|
||||
phones = ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
|
||||
return torch.tensor([*map(self.symmap.get, phones)])
|
||||
except Exception as e:
|
||||
pass
|
||||
phones = [ " " if not p else p for p in content ]
|
||||
return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
|
||||
|
||||
|
@ -96,10 +101,10 @@ class TTS():
|
|||
paths = [ Path(p) for p in paths.split(";") ]
|
||||
|
||||
# merge inputs
|
||||
res = torch.cat([qnt.encode_from_file( path )[0].t().to(torch.int16) for path in paths])
|
||||
|
||||
res = torch.cat([qnt.encode_from_file(path)[0][:, :].t().to(torch.int16) for path in paths])
|
||||
|
||||
if should_trim:
|
||||
res = trim_random( res, int( 75 * cfg.dataset.prompt_duration ) )
|
||||
res = trim( res, int( 75 * cfg.dataset.prompt_duration ) )
|
||||
|
||||
return res
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user