diff --git a/data/config.yaml b/data/config.yaml
index 3c70cd0..bf08028 100755
--- a/data/config.yaml
+++ b/data/config.yaml
@@ -1,7 +1,14 @@
 dataset:
-  training: []
+  training: [
+    # "./training/valle/data/LibriTTS/994/",
+  ]
 
-  validation: []
+  validation: [
+    # "./training/valle/data/Validation/1188/",
+  ]
+  noise: [
+    # "./training/valle/data/Other/noise/",
+  ]
 
   speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
 
@@ -12,7 +19,7 @@ dataset:
   workers: 4
   cache: True
 
-  phones_range: [4, 256]
+  phones_range: [4, 512]
   duration_range: [1.0, 16.0]
 
   random_utterance: 1.0
@@ -20,9 +27,11 @@ dataset:
   prompt_duration: 3.0
 
   sample_type: speaker
-  tasks_list: ["tts"] # do NOT change this until you're ready to train for SpeechX tasks # ["tts", "tts", "ns", "sr", "tse", "tts", "tts"]
+
+  tasks_list: ["tts"] # ["tts", "ns", "sr", "tse", "cse", "nse"]
 
 models:
+  _max_levels: 8
   _models:
     - name: "ar"
       size: "full"
@@ -38,13 +47,14 @@ models:
       tasks: 8
       arch_type: "retnet"
 
+
 hyperparameters:
   batch_size: 16
-  gradient_accumulation_steps: 2
+  gradient_accumulation_steps: 4
   gradient_clipping: 100
 
-  optimizer: Adamw
-  learning_rate: 1.0e-5
+  optimizer: AdamW
+  learning_rate: 1.0e-4
 
   scheduler_type: ""
   #scheduler_type: OneCycle
@@ -66,13 +76,13 @@ hyperparameters:
 #    decay_mom_rate: 0.0
 
 evaluation:
-  batch_size: 32
+  batch_size: 16
   frequency: 500
-  size: 32
+  size: 16
 
   steps: 300
-  ar_temperature: 1.0
-  nar_temperature: 0.2
+  ar_temperature: 0.95
+  nar_temperature: 0.25
 
 trainer:
   iterations: 1_000_000
@@ -106,4 +116,7 @@ inference:
   normalize: False # do NOT change this unless you know exactly what you are doing.
 
 bitsandbytes:
-  enabled: false
+  enabled: False
+  injects: True
+  linear: True
+  embedding: True
diff --git a/vall_e/config.py b/vall_e/config.py
index b3039e3..f2cf356 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -196,6 +196,8 @@ class Model:
 
 @dataclass()
 class Models:
+	_max_levels: int = 0
+
 	_models: list[Model] = field(default_factory=lambda: [
 		Model(name="ar", resp_levels=1, prom_levels=8, tasks=1),
 		Model(name="nar", resp_levels=7, prom_levels=8, tasks=1),
@@ -232,6 +234,10 @@ class Models:
 		for model in self._models:
 			tasks = max(tasks, model.tasks)
 		return tasks
+
+	@property
+	def max_levels(self):
+		return self._max_levels if self._max_levels > 0 else self.prom_levels
 
 @dataclass()
 class Hyperparameters:
@@ -261,7 +267,8 @@ class DeepSpeed:
 	use_compression_training: bool = False
 	compression_bits: int = 8
 
-	def get_ds_cfg(self, model):
+	@cached_property
+	def ds_cfg(self):
 		scheduler_params = {}
 		for k in cfg.hyperparameters.scheduler_params:
 			scheduler_params[k] = cfg.hyperparameters.scheduler_params[k]
@@ -277,7 +284,7 @@ class DeepSpeed:
 				"params": {
 					"lr": cfg.hyperparameters.learning_rate,
 				}
-			},
+			} if not cfg.hyperparameters.optimizer.endswith("-torch") else None,
 			"scheduler": {
 				"type": cfg.hyperparameters.scheduler_type,
 				"params": scheduler_params,
@@ -351,8 +358,8 @@ class DeepSpeed:
 		for k in null_keys:
 			del ds_cfg[k]
 
-		if os.path.exists("./config/ds_config.json"):
-			ds_cfg.update(json.load(open("./config/ds_config.json", "r", encoding="utf-8")))
+		if os.path.exists("./data/ds_config.json"):
+			ds_cfg.update(json.load(open("./data/ds_config.json", "r", encoding="utf-8")))
 
 		return ds_cfg
 
@@ -404,8 +411,8 @@ class BitsAndBytes:
 	enabled: bool = False
 	injects: bool = False
 
-	linear: bool = False
-	embedding: bool = False
+	linear: bool = True
+	embedding: bool = True
 
 @dataclass()
 class Config(_Config):
diff --git a/vall_e/data.py b/vall_e/data.py
index 9d97207..be81d6d 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -63,11 +63,8 @@ def _get_quant_path(path):
 
 def _get_phone_path(path):
 	return _replace_file_extension(path, ".phn.txt")
-
 def _load_quants(path) -> Tensor:
-	path = _get_quant_path(path)
-	return torch.load(path)[0][:cfg.models.prom_levels, :].t().to(torch.int16)
-
+	return torch.load(path)[0][:, :].t().to(torch.int16)
 
 @cache
 def _get_phones(path, lang_marker="en"):
@@ -215,12 +212,12 @@ class Dataset(_Dataset):
 	def _get_task_symmap(self):
 		return get_task_symmap()
 
-	def get_task_token( self, token ):
+	def get_task_token( self, token, levels=cfg.models.max_levels ):
 		if not hasattr(self, "task_symmap"):
 			self.task_symmap = self._get_task_symmap()
-		return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(cfg.models.prom_levels) ]]).to(dtype=torch.int16)
+		return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
 
-	def sample_noise(self):
+	def sample_noise(self):
 		paths = []
 		for data_dir in cfg.dataset.noise:
 			paths.extend(data_dir.rglob("*.qnt.pt"))
@@ -228,7 +225,7 @@ class Dataset(_Dataset):
 
 		if False and cfg.dataset.use_hdf5:
 			key = f'/noise/{_get_hdf5_path(path)}'
-			qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
+			qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
 		else:
 			qnt = _load_quants(path)
 		return qnt
@@ -260,7 +257,7 @@ class Dataset(_Dataset):
 		path = random.choice(choices)
 
 		if cfg.dataset.use_hdf5:
 			key = _get_hdf5_path(path)
-			qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
+			qnt = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
 		else:
 			qnt = _load_quants(path)
@@ -293,7 +290,7 @@ class Dataset(_Dataset):
 		if cfg.dataset.use_hdf5:
 			key = _get_hdf5_path(path)
 			text = torch.from_numpy(cfg.hdf5[key]["text"][:]).to(self.text_dtype)
-			resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :cfg.models.prom_levels]).to(torch.int16)
+			resps = torch.from_numpy(cfg.hdf5[key]["audio"][:, :]).to(torch.int16)
 		else:
 			text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
 			resps = _load_quants(path)
@@ -316,7 +313,7 @@ class Dataset(_Dataset):
 			# extend the noise to fill the target audio
 			noise = repeat_extend_audio(noise, resps.shape[0])
 			# create the input prompt by merging the target audio with the noise
-			proms = merge_audio(resps, noise, scale=[1, noise_scale], device="cpu")
+			proms = merge_audio( resps, noise, scale=[1, noise_scale], device="cpu" )
 			# set the target to just be the noise if
 			if task == "sr":
 				resps = noise
@@ -358,7 +355,7 @@ class Dataset(_Dataset):
 			if cfg.dataset.use_hdf5:
 				texts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["text"][:]).to(self.text_dtype) for path in sampled ]
-				qnts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["audio"][:, :cfg.models.prom_levels]).to(torch.int16) for path in sampled ]
+				qnts = [ torch.from_numpy(cfg.hdf5[_get_hdf5_path(path)]["audio"][:, :]).to(torch.int16) for path in sampled ]
 			else:
 				texts = [ torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype) for path in sampled ]
 				qnts = [ _load_quants(path) for path in sampled ]
 
@@ -394,15 +391,15 @@ class Dataset(_Dataset):
 			# it might be better to extend the noise to the sum of the pre+mid+post or pre+edit+post to keep the noise truly coherent
 			# but it's noise, it's supposed to be random
-			def noise_proms( proms ):
+			def noise_proms( p ):
 				# ignore if we turned it off
-				if proms is None:
+				if p is None:
 					return None
 				# extend the noise to fill the target audio
-				n = repeat_extend_audio(noise, proms.shape[0])
+				n = repeat_extend_audio(noise, p.shape[0])
 				# merge the noise over the utterance
-				return merge_audio(proms, n, scale=[1, noise_scale], device="cpu")
+				return merge_audio(p, n, scale=[1, noise_scale], device="cpu")
 
 			# apply noise to all pieces
 			pre_prom = noise_proms( pre_prom )
@@ -426,6 +423,8 @@ class Dataset(_Dataset):
 				[ edit_prom ] +
 				([ post_prom ] if post_prom is not None else [])
 			)
+		else:
+			raise f'Undefined task: {task}'
 
 		"""
 		# emulate SVC
@@ -450,6 +449,10 @@ class Dataset(_Dataset):
 		text = torch.tensor([1, 2]).to(self.text_dtype)
 		"""
 
+		# trim to fit to requested prom/resps levels
+		proms = proms[:, :cfg.models.prom_levels]
+		resps = resps[:, :cfg.models.prom_levels]
+
 		return dict(
 			index=index,
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 40e115b..90164ab 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -21,16 +21,16 @@ except Exception as e:
 	cfg.inference.use_vocos = False
 
 @cache
-def _load_encodec_model(device="cuda"):
+def _load_encodec_model(device="cuda", levels=cfg.models.max_levels):
 	# Instantiate a pretrained EnCodec model
 	assert cfg.sample_rate == 24_000
 
 	# too lazy to un-if ladder this shit
-	if cfg.models.prom_levels == 2:
+	if levels == 2:
 		bandwidth_id = 1.5
-	elif cfg.models.prom_levels == 4:
+	elif levels == 4:
 		bandwidth_id = 3.0
-	elif cfg.models.prom_levels == 8:
+	elif levels == 8:
 		bandwidth_id = 6.0
 
 	model = EncodecModel.encodec_model_24khz().to(device)
@@ -43,18 +43,18 @@ def _load_encodec_model(device="cuda"):
 	return model
 
 @cache
-def _load_vocos_model(device="cuda"):
+def _load_vocos_model(device="cuda", levels=cfg.models.max_levels):
 	assert cfg.sample_rate == 24_000
 
 	model = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
 	model = model.to(device)
 
 	# too lazy to un-if ladder this shit
-	if cfg.models.prom_levels == 2:
+	if levels == 2:
 		bandwidth_id = 0
-	elif cfg.models.prom_levels == 4:
+	elif levels == 4:
 		bandwidth_id = 1
-	elif cfg.models.prom_levels == 8:
+	elif levels == 8:
 		bandwidth_id = 2
 
 	model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
@@ -64,11 +64,11 @@ def _load_vocos_model(device="cuda"):
 	return model
 
 @cache
-def _load_model(device="cuda", vocos=cfg.inference.use_vocos):
+def _load_model(device="cuda", vocos=cfg.inference.use_vocos, levels=cfg.models.max_levels):
 	if vocos:
-		model = _load_vocos_model(device)
+		model = _load_vocos_model(device, levels=levels)
 	else:
-		model = _load_encodec_model(device)
+		model = _load_encodec_model(device, levels=levels)
 
 	return model
@@ -78,7 +78,7 @@ def unload_model():
 
 @torch.inference_mode()
-def decode(codes: Tensor, device="cuda"):
+def decode(codes: Tensor, device="cuda", levels=cfg.models.max_levels):
 	"""
 	Args:
 		codes: (b q t)
@@ -94,7 +94,7 @@ def decode(codes: Tensor, device="cuda"):
 		codes = rearrange(codes, "t q -> 1 q t")
 	assert codes.dim() == 3, f'Requires shape (b q t) but got {codes.shape}'
 
-	model = _load_model(device)
+	model = _load_model(device, levels=levels)
 
 	# upcast so it won't whine
 	if codes.dtype == torch.int8 or codes.dtype == torch.int16 or codes.dtype == torch.uint8:
@@ -115,8 +115,8 @@
 	return wav, model.sample_rate
 
 # huh
-def decode_to_wave(resps: Tensor, device="cuda"):
-	return decode(resps, device=device)
+def decode_to_wave(resps: Tensor, device="cuda", levels=cfg.models.max_levels):
+	return decode(resps, device=device, levels=levels)
 
 def decode_to_file(resps: Tensor, path: Path, device="cuda"):
 	wavs, sr = decode(resps, device=device)
@@ -129,14 +129,14 @@ def _replace_file_extension(path, suffix):
 
 
 @torch.inference_mode()
-def encode(wav: Tensor, sr: int = 24_000, device="cuda"):
+def encode(wav: Tensor, sr: int = 24_000, device="cuda", levels=cfg.models.max_levels):
 	"""
 	Args:
 		wav: (t)
 		sr: int
 	"""
 
-	model = _load_encodec_model(device)
+	model = _load_encodec_model(device, levels=levels)
 	wav = wav.unsqueeze(0)
 	wav = convert_audio(wav, sr, model.sample_rate, model.channels)
 	wav = wav.to(device)
@@ -203,16 +203,16 @@ def repeat_extend_audio( qnt, target ):
 
 # merges two quantized audios together
 # I don't know if this works
-def merge_audio( *args, device="cpu", scale=[] ):
+def merge_audio( *args, device="cpu", scale=[], levels=cfg.models.max_levels ):
 	qnts = [*args]
-	decoded = [ decode_to_wave(qnt, device=device)[0] for qnt in qnts ]
+	decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
 
 	if len(scale) == len(decoded):
 		for i in range(len(scale)):
 			decoded[i] = decoded[i] * scale[i]
 
 	combined = sum(decoded) / len(decoded)
-	return encode(combined, 24_000, device="cpu")[0].t()
+	return encode(combined, 24_000, device="cpu", levels=levels)[0].t()
 
 def main():
 	parser = argparse.ArgumentParser()
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 048cc7d..5e5ea42 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -39,6 +39,7 @@ import os
 from torch import Tensor
 from torch.distributed import all_reduce
 from typing import Any, Protocol
+from functools import cached_property
 
 from .base import TrainFeeder
 
@@ -50,6 +51,10 @@ if not distributed_initialized() and cfg.trainer.backend == "local":
 # A very naive engine implementation using barebones PyTorch
 class Engine():
 	def __init__(self, *args, **kwargs):
+		if '_cfg' in kwargs:
+			self._cfg = kwargs['_cfg']
+			kwargs.pop("_cfg")
+
 		self.module = kwargs['model'].to(cfg.device).to(cfg.trainer.dtype)
 		self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
 		self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
@@ -137,6 +142,10 @@ class Engine():
 	def __call__(self, *args, **kwargs):
 		return self.forward(*args, **kwargs)
 
+	@cached_property
+	def device(self):
+		return next(self.module.parameters()).device
+
 	def forward(self, *args, **kwargs):
 		return self.module.forward(*args, **kwargs)
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index 8458807..0ca287c 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -31,7 +31,11 @@ if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
 
 class Engine(DeepSpeedEngine):
 	def __init__(self, *args, **kwargs):
-		kwargs['config'] = cfg.trainer.deepspeed.get_ds_cfg(model=kwargs['model'])
+		if '_cfg' in kwargs:
+			self._cfg = kwargs['_cfg']
+			kwargs.pop("_cfg")
+
+		kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
 		kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
 
 		super().__init__(None, *args, **kwargs)
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 696e9c8..46975a1 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -8,10 +8,6 @@ from torch import Tensor
 from tqdm import trange
 
 class AR(Base):
-	@property
-	def n_resp_levels(self) -> int:
-		return cfg.models.ar.resp_levels
-
 	@property
 	def causal(self):
 		return True
@@ -32,6 +28,14 @@ class AR(Base):
 	def n_prom_levels(self) -> int:
 		return cfg.models.prom_levels
 
+	@property
+	def n_resp_levels(self) -> int:
+		return cfg.models.ar.resp_levels
+
+	@property
+	def n_max_levels(self) -> int:
+		return cfg.models.max_levels
+
 	@property
 	def n_tasks(self) -> int:
 		return cfg.models.tasks
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index beed6f8..ebb4a73 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -113,6 +113,10 @@ class Base(nn.Module):
 	@property
 	def n_prom_levels(self) -> int:
 		raise NotImplementedError
+
+	@property
+	def n_max_levels(self) -> int:
+		raise NotImplementedError
 
 	@property
 	def n_tasks(self) -> int:
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 757017c..d60a49b 100755
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -7,10 +7,6 @@ from torch import Tensor
 from tqdm import trange
 
 class NAR(Base):
-	@property
-	def n_resp_levels(self) -> int:
-		return cfg.models.nar.resp_levels
-
 	@property
 	def causal(self):
 		return False
@@ -31,6 +27,14 @@ class NAR(Base):
 	def n_prom_levels(self) -> int:
 		return cfg.models.prom_levels
 
+	@property
+	def n_resp_levels(self) -> int:
+		return cfg.models.nar.resp_levels
+
+	@property
+	def n_max_levels(self) -> int:
+		return cfg.models.max_levels
+
 	@property
 	def n_tasks(self) -> int:
 		return cfg.models.tasks
diff --git a/vall_e/train.py b/vall_e/train.py
index 052986a..b5ff43e 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -23,7 +23,11 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
 _logger = logging.getLogger(__name__)
 
 def train_feeder(engine, batch):
-	engine( text_list=batch["text"], proms_list=batch["proms"], resps_list=batch["resps"] )
+	engine(
+		text_list=batch["text"],
+		proms_list=[prom[:, :engine._cfg.prom_levels] for prom in batch["proms"]], # reduce the input prompt to the target prom level
+		resps_list=batch["resps"]
+	)
 
 	losses = engine.gather_attribute("loss")
 	stat = engine.gather_attribute("stats")
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index c0c1e75..4292c61 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -88,19 +88,33 @@ def load_engines():
 
 			# extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
 			if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
-				n_prom_levels, n_prom_tokens, d_model = state['proms_emb.weight'].shape
+				o_prom_levels, o_prom_tokens, d_model = state['proms_emb.weight'].shape
 
 				# copy weights from the dict into the old portion
-				model.proms_emb.weight.data[:n_prom_levels, :n_prom_tokens, :] = state['proms_emb.weight'].data[:n_prom_levels, :n_prom_tokens, :]
+				model.proms_emb.weight.data[:o_prom_levels, :o_prom_tokens, :] = state['proms_emb.weight'].data[:o_prom_levels, :o_prom_tokens, :]
 				# copy the full tensors back
 				state['proms_emb.weight'] = model.proms_emb.weight
 
+			# extend the resps_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
+			if model.resps_emb.weight.shape[0] > state['resps_emb.weight'].shape[0] or model.resps_emb.weight.shape[1] > state['resps_emb.weight'].shape[1]:
+				o_resp_levels, o_resp_tokens, d_model = state['resps_emb.weight'].shape
+				n_resp_levels, n_resp_tokens, d_model = model.resps_emb.weight.shape
+
+				# copy weights from the dict into the old portion
+				model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :]
+				# reuse additional levels, probably bad
+				for n in range(o_resp_tokens, n_resp_tokens):
+					model.resps_emb.weight.data[n] = model.resps_emb.weight.data[o_resp_tokens-1]
+				# copy the full tensors back
+				state['resps_emb.weight'] = model.resps_emb.weight
+
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
 		engines[name] = Engine(
 			model=model,
 			optimizer=optimizer,
 			lr_scheduler=lr_scheduler,
+			_cfg=model._cfg,
 		)
 
 	engines = Engines(engines)