tweaks
This commit is contained in:
parent
277c759ab1
commit
1e3e1d9315
11
README.md
11
README.md
|
@ -126,6 +126,17 @@ Some additional flags you can pass are:
|
|||
* `--nar-temp`: sampling temperature to use for the NAR pass.
|
||||
* `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)
|
||||
|
||||
|
||||
## To-Do
|
||||
|
||||
* properly pass in `modules` names to `weight_quantization` and `activation_quantization`.
|
||||
|
||||
* fix `quit` hanging when using distributed training.
|
||||
|
||||
* train and release a model.
|
||||
|
||||
* extend to multiple languages (VALL-E X) and extend to SpeechX features.
|
||||
|
||||
## Notice
|
||||
|
||||
- [EnCodec](https://github.com/facebookresearch/encodec) is licensed under CC-BY-NC 4.0. If you use the code to generate audio quantization or perform decoding, it is important to adhere to the terms of their license.
|
||||
|
|
|
@ -336,7 +336,10 @@ class DeepSpeed:
|
|||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": True
|
||||
}
|
||||
},
|
||||
"zero_quantized_weights": self.use_compression_training,
|
||||
"zero_hpz_partition_size": world_size(),
|
||||
"zero_quantized_gradients": self.use_compression_training,
|
||||
} if self.zero_optimization_level > 0 else None,
|
||||
"comms_logger": {
|
||||
"enabled": False
|
||||
|
@ -439,21 +442,24 @@ class Config(_Config):
|
|||
tmp = Config.from_yaml( config_path )
|
||||
self.__dict__.update(tmp.__dict__)
|
||||
|
||||
def format( self ):
    """Coerce each top-level config section into its section class.

    Every attribute listed below is unpacked as keyword arguments into the
    class paired with it, and the resulting instance is stored back on
    `self` under the same attribute name.
    """
    section_classes = (
        ("dataset", Dataset),
        ("models", Models),
        ("hyperparameters", Hyperparameters),
        ("evaluation", Evaluation),
        ("trainer", Trainer),
        ("inference", Inference),
        ("bitsandbytes", BitsAndBytes),
    )
    for attr_name, section_cls in section_classes:
        raw_section = getattr(self, attr_name)
        setattr(self, attr_name, section_cls(**raw_section))

    # The nested DeepSpeed section is coerced last: it is read from
    # `self.trainer`, which has only just been replaced by a Trainer instance.
    deepspeed_kwargs = self.trainer.deepspeed
    self.trainer.deepspeed = DeepSpeed(**deepspeed_kwargs)
|
||||
|
||||
|
||||
cfg = Config.from_cli()
|
||||
|
||||
# OmegaConf might not coerce the dicts into the @dataclass decorated classes, so we (try to) coerce them ourselves
|
||||
try:
|
||||
cfg.dataset = Dataset(**cfg.dataset)
|
||||
cfg.models = Models(**cfg.models)
|
||||
cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
|
||||
cfg.evaluation = Evaluation(**cfg.evaluation)
|
||||
cfg.trainer = Trainer(**cfg.trainer)
|
||||
cfg.inference = Inference(**cfg.inference)
|
||||
cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
|
||||
cfg.format()
|
||||
|
||||
cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
|
||||
|
||||
# cached_property stopped working...
|
||||
if cfg.dataset.use_hdf5:
|
||||
try:
|
||||
|
|
|
@ -8,10 +8,21 @@ from .emb import g2p, qnt
|
|||
from .utils import to_device
|
||||
|
||||
from .config import cfg
|
||||
from .export import load_models
|
||||
from .models import get_models
|
||||
from .train import load_engines
|
||||
from .data import get_phone_symmap
|
||||
|
||||
import random
|
||||
|
||||
def trim( qnt, trim_length ):
    """Return a random contiguous window of at most `trim_length` rows of `qnt`.

    Args:
        qnt: object exposing `.shape[0]` and slicing along its first axis.
        trim_length: maximum number of leading-axis entries to keep.

    Returns:
        `qnt[start:end]` for a randomly chosen `start`. When the window would
        run past the end it is shifted back so that it ends at the last row.
        When `trim_length` is at least the input length, the whole input is
        returned.
    """
    length = qnt.shape[0]
    # Pick a random starting row; this may land too close to the end.
    start = int(length * random.random())
    end = start + trim_length
    if end >= length:
        # Slide the window back so it ends on the final row. Clamp at 0:
        # without the clamp, a `trim_length` larger than `length` produces a
        # negative start, which slicing treats as an offset from the end and
        # silently returns fewer rows than are available.
        start = max(0, length - trim_length)
        end = length
    return qnt[start:end]
|
||||
|
||||
class TTS():
|
||||
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
|
||||
self.loading = True
|
||||
|
@ -22,6 +33,8 @@ class TTS():
|
|||
|
||||
if config:
|
||||
cfg.load_yaml( config )
|
||||
|
||||
cfg.format()
|
||||
|
||||
if ar_ckpt and nar_ckpt:
|
||||
self.ar_ckpt = ar_ckpt
|
||||
|
@ -45,18 +58,12 @@ class TTS():
|
|||
self.loading = False
|
||||
|
||||
def load_models( self ):
|
||||
print("Loading models...")
|
||||
models = load_models()
|
||||
print("Loaded models")
|
||||
for name in models:
|
||||
model = models[name]
|
||||
engines = load_engines()
|
||||
for name, engine in engines.items():
|
||||
if name[:2] == "ar":
|
||||
self.ar = model.to(self.device, dtype=torch.float32)
|
||||
self.symmap = self.ar.phone_symmap
|
||||
self.ar = engine.module.to(self.device)
|
||||
elif name[:3] == "nar":
|
||||
self.nar = model.to(self.device, dtype=torch.float32)
|
||||
else:
|
||||
print("Unknown:", name)
|
||||
self.nar = engine.module.to(self.device)
|
||||
|
||||
def encode_text( self, text, lang_marker="en" ):
|
||||
text = g2p.encode(text)
|
||||
|
@ -66,7 +73,10 @@ class TTS():
|
|||
|
||||
def encode_audio( self, path ):
|
||||
enc = qnt.encode_from_file( path )
|
||||
return enc[0].t().to(torch.int16)
|
||||
res = enc[0].t().to(torch.int16)
|
||||
if trim:
|
||||
res = trim( res, int( 75 * cfg.dataset.duration_range[1] ) )
|
||||
return res
|
||||
|
||||
|
||||
def inference( self, text, reference, mode="both", max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path="./.tmp.wav" ):
|
||||
|
|
Loading…
Reference in New Issue
Block a user