I think I fixed a bug?
This commit is contained in:
parent
f3fbed5ffd
commit
6455a2f9d7
|
@ -10,7 +10,7 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
|
||||||
|
|
||||||
* [`DeepSpeed`](https://github.com/microsoft/DeepSpeed#requirements):
|
* [`DeepSpeed`](https://github.com/microsoft/DeepSpeed#requirements):
|
||||||
- DeepSpeed training is Linux only. Installation under Windows should ignore trying to install DeepSpeed.
|
- DeepSpeed training is Linux only. Installation under Windows should ignore trying to install DeepSpeed.
|
||||||
- If your config YAML has the training backend set to [`deepspeed`], you will need to have a GPU that DeepSpeed has developed and tested against, as well as a CUDA or ROCm compiler pre-installed to install this package.
|
- If your config YAML has the training backend set to `deepspeed`, you will need to have a GPU that DeepSpeed has developed and tested against, as well as a CUDA or ROCm compiler pre-installed to install this package.
|
||||||
|
|
||||||
* [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/):
|
* [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/):
|
||||||
- For phonemizing text, this repo requires `espeak`/`espeak-ng` installed.
|
- For phonemizing text, this repo requires `espeak`/`espeak-ng` installed.
|
||||||
|
|
|
@ -13,7 +13,7 @@ from .utils import to_device
|
||||||
from .config import cfg
|
from .config import cfg
|
||||||
from .models import get_models
|
from .models import get_models
|
||||||
from .train import load_engines
|
from .train import load_engines
|
||||||
from .data import get_phone_symmap
|
from .data import get_phone_symmap, _load_quants
|
||||||
|
|
||||||
class TTS():
|
class TTS():
|
||||||
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
|
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
|
||||||
|
@ -82,7 +82,12 @@ class TTS():
|
||||||
return text
|
return text
|
||||||
|
|
||||||
content = g2p.encode(text, language=language)
|
content = g2p.encode(text, language=language)
|
||||||
#phones = ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
|
# ick
|
||||||
|
try:
|
||||||
|
phones = ["<s>"] + [ " " if not p else p for p in content ] + ["</s>"]
|
||||||
|
return torch.tensor([*map(self.symmap.get, phones)])
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
phones = [ " " if not p else p for p in content ]
|
phones = [ " " if not p else p for p in content ]
|
||||||
return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
|
return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
|
||||||
|
|
||||||
|
@ -96,10 +101,10 @@ class TTS():
|
||||||
paths = [ Path(p) for p in paths.split(";") ]
|
paths = [ Path(p) for p in paths.split(";") ]
|
||||||
|
|
||||||
# merge inputs
|
# merge inputs
|
||||||
res = torch.cat([qnt.encode_from_file( path )[0].t().to(torch.int16) for path in paths])
|
res = torch.cat([qnt.encode_from_file(path)[0][:, :].t().to(torch.int16) for path in paths])
|
||||||
|
|
||||||
if should_trim:
|
if should_trim:
|
||||||
res = trim_random( res, int( 75 * cfg.dataset.prompt_duration ) )
|
res = trim( res, int( 75 * cfg.dataset.prompt_duration ) )
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user