tweaks
This commit is contained in:
parent
277c759ab1
commit
1e3e1d9315
11
README.md
11
README.md
|
@ -126,6 +126,17 @@ Some additional flags you can pass are:
|
||||||
* `--nar-temp`: sampling temperature to use for the NAR pass.
|
* `--nar-temp`: sampling temperature to use for the NAR pass.
|
||||||
* `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)
|
* `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)
|
||||||
|
|
||||||
|
|
||||||
|
## To-Do
|
||||||
|
|
||||||
|
* properly pass in `modules` names to `weight_quantization` and `activation_quantization`.
|
||||||
|
|
||||||
|
* fix `quit` hanging when using distributed training.
|
||||||
|
|
||||||
|
* train and release a model.
|
||||||
|
|
||||||
|
* extend to multiple languages (VALL-E X) and extend to SpeechX features.
|
||||||
|
|
||||||
## Notice
|
## Notice
|
||||||
|
|
||||||
- [EnCodec](https://github.com/facebookresearch/encodec) is licensed under CC-BY-NC 4.0. If you use the code to generate audio quantization or perform decoding, it is important to adhere to the terms of their license.
|
- [EnCodec](https://github.com/facebookresearch/encodec) is licensed under CC-BY-NC 4.0. If you use the code to generate audio quantization or perform decoding, it is important to adhere to the terms of their license.
|
||||||
|
|
|
@ -336,7 +336,10 @@ class DeepSpeed:
|
||||||
"offload_param": {
|
"offload_param": {
|
||||||
"device": "cpu",
|
"device": "cpu",
|
||||||
"pin_memory": True
|
"pin_memory": True
|
||||||
}
|
},
|
||||||
|
"zero_quantized_weights": self.use_compression_training,
|
||||||
|
"zero_hpz_partition_size": world_size(),
|
||||||
|
"zero_quantized_gradients": self.use_compression_training,
|
||||||
} if self.zero_optimization_level > 0 else None,
|
} if self.zero_optimization_level > 0 else None,
|
||||||
"comms_logger": {
|
"comms_logger": {
|
||||||
"enabled": False
|
"enabled": False
|
||||||
|
@ -439,21 +442,24 @@ class Config(_Config):
|
||||||
tmp = Config.from_yaml( config_path )
|
tmp = Config.from_yaml( config_path )
|
||||||
self.__dict__.update(tmp.__dict__)
|
self.__dict__.update(tmp.__dict__)
|
||||||
|
|
||||||
|
def format( self ):
|
||||||
|
self.dataset = Dataset(**self.dataset)
|
||||||
|
self.models = Models(**self.models)
|
||||||
|
self.hyperparameters = Hyperparameters(**self.hyperparameters)
|
||||||
|
self.evaluation = Evaluation(**self.evaluation)
|
||||||
|
self.trainer = Trainer(**self.trainer)
|
||||||
|
self.inference = Inference(**self.inference)
|
||||||
|
self.bitsandbytes = BitsAndBytes(**self.bitsandbytes)
|
||||||
|
|
||||||
|
self.trainer.deepspeed = DeepSpeed(**self.trainer.deepspeed)
|
||||||
|
|
||||||
|
|
||||||
cfg = Config.from_cli()
|
cfg = Config.from_cli()
|
||||||
|
|
||||||
# OmegaConf might not coerce the dicts into the @dataclass decorated classes, so we (try to) coerce them ourselves
|
# OmegaConf might not coerce the dicts into the @dataclass decorated classes, so we (try to) coerce them ourselves
|
||||||
try:
|
try:
|
||||||
cfg.dataset = Dataset(**cfg.dataset)
|
cfg.format()
|
||||||
cfg.models = Models(**cfg.models)
|
|
||||||
cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
|
|
||||||
cfg.evaluation = Evaluation(**cfg.evaluation)
|
|
||||||
cfg.trainer = Trainer(**cfg.trainer)
|
|
||||||
cfg.inference = Inference(**cfg.inference)
|
|
||||||
cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
|
|
||||||
|
|
||||||
cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
|
|
||||||
|
|
||||||
# cached_property stopped working...
|
# cached_property stopped working...
|
||||||
if cfg.dataset.use_hdf5:
|
if cfg.dataset.use_hdf5:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -8,10 +8,21 @@ from .emb import g2p, qnt
|
||||||
from .utils import to_device
|
from .utils import to_device
|
||||||
|
|
||||||
from .config import cfg
|
from .config import cfg
|
||||||
from .export import load_models
|
|
||||||
from .models import get_models
|
from .models import get_models
|
||||||
|
from .train import load_engines
|
||||||
from .data import get_phone_symmap
|
from .data import get_phone_symmap
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
def trim( qnt, trim_length ):
|
||||||
|
length = qnt.shape[0]
|
||||||
|
start = int(length * random.random())
|
||||||
|
end = start + trim_length
|
||||||
|
if end >= length:
|
||||||
|
start = length - trim_length
|
||||||
|
end = length
|
||||||
|
return qnt[start:end]
|
||||||
|
|
||||||
class TTS():
|
class TTS():
|
||||||
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
|
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
|
||||||
self.loading = True
|
self.loading = True
|
||||||
|
@ -22,6 +33,8 @@ class TTS():
|
||||||
|
|
||||||
if config:
|
if config:
|
||||||
cfg.load_yaml( config )
|
cfg.load_yaml( config )
|
||||||
|
|
||||||
|
cfg.format()
|
||||||
|
|
||||||
if ar_ckpt and nar_ckpt:
|
if ar_ckpt and nar_ckpt:
|
||||||
self.ar_ckpt = ar_ckpt
|
self.ar_ckpt = ar_ckpt
|
||||||
|
@ -45,18 +58,12 @@ class TTS():
|
||||||
self.loading = False
|
self.loading = False
|
||||||
|
|
||||||
def load_models( self ):
|
def load_models( self ):
|
||||||
print("Loading models...")
|
engines = load_engines()
|
||||||
models = load_models()
|
for name, engine in engines.items():
|
||||||
print("Loaded models")
|
|
||||||
for name in models:
|
|
||||||
model = models[name]
|
|
||||||
if name[:2] == "ar":
|
if name[:2] == "ar":
|
||||||
self.ar = model.to(self.device, dtype=torch.float32)
|
self.ar = engine.module.to(self.device)
|
||||||
self.symmap = self.ar.phone_symmap
|
|
||||||
elif name[:3] == "nar":
|
elif name[:3] == "nar":
|
||||||
self.nar = model.to(self.device, dtype=torch.float32)
|
self.nar = engine.module.to(self.device)
|
||||||
else:
|
|
||||||
print("Unknown:", name)
|
|
||||||
|
|
||||||
def encode_text( self, text, lang_marker="en" ):
|
def encode_text( self, text, lang_marker="en" ):
|
||||||
text = g2p.encode(text)
|
text = g2p.encode(text)
|
||||||
|
@ -66,7 +73,10 @@ class TTS():
|
||||||
|
|
||||||
def encode_audio( self, path ):
|
def encode_audio( self, path ):
|
||||||
enc = qnt.encode_from_file( path )
|
enc = qnt.encode_from_file( path )
|
||||||
return enc[0].t().to(torch.int16)
|
res = enc[0].t().to(torch.int16)
|
||||||
|
if trim:
|
||||||
|
res = trim( res, int( 75 * cfg.dataset.duration_range[1] ) )
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def inference( self, text, reference, mode="both", max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path="./.tmp.wav" ):
|
def inference( self, text, reference, mode="both", max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path="./.tmp.wav" ):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user