diff --git a/data/config.yaml b/data/config.yaml
old mode 100755
new mode 100644
index c2ea978..a46cec6
--- a/data/config.yaml
+++ b/data/config.yaml
@@ -1,140 +1,143 @@
-models:
-- name: "ar+nar"
-  size: "full"
-  resp_levels: 8
-  prom_levels: 8
-  tasks: 8
-  langs: 2
-  tones: 1
-  arch_type: llama
-  training: True
-  version: 4
-  attention: flash_attention_2
-  dropout: 0.1
-
-  loss_factors:
-    text: 0.1
-    resp: 1.0
-
-hyperparameters:
-  autotune: False
-  autotune_params:
-    start_profile_step: 1
-    end_profile_step: 50
-    num_tuning_micro_batch_sizes: 8
-
-  batch_size: 16
-  gradient_accumulation_steps: 4
-  gradient_clipping: 1.0
-  warmup_steps: 100
-
-  optimizer: Prodigy
-  learning_rate: 1.0
-  torch_optimizer: True
-
-  scheduler: "" # ScheduleFree
-  torch_scheduler: True
-
-evaluation:
-  batch_size: 8
-  frequency: 5000
-  size: 8
-
-  steps: 500
-  ar_temperature: 0.95
-  nar_temperature: 0.25
-  load_disabled_engines: True
-
-trainer:
-  #no_logger: True
-  ddp: False
-  #check_for_oom: False
-  iterations: 1_000_000
-
-  save_tag: step
-  save_on_oom: True
-  save_on_quit: True
-  save_frequency: 250
-  export_on_save: True
-
-  keep_last_checkpoints: 4
-
-  aggressive_optimizations: False
-  load_disabled_engines: False
-
-  #load_state_dict: True
-  strict_loading: False
-  #load_tag: "9500"
-  #load_states: False
-  #restart_step_count: True
-
-  gc_mode: None # "global_step"
-
-  weight_dtype: float32 # float16 or bfloat16
-  amp: False
-
-  backend: deepspeed
-  deepspeed:
-    inferencing: True
-    zero_optimization_level: 0
-    use_compression_training: False
-
-    amp: False
-
-  activation_checkpointing: True
-
-  load_webui: False
-
-inference:
-  backend: deepspeed
-  audio_backend: "dac"
-  normalize: False
-
-  weight_dtype: float32 # float16 or bfloat16
-  amp: False
-
-optimizations:
-  injects: False
-  replace: True
-
-  linear: False
-  embedding: False
-  optimizers: True
-
-  bitsandbytes: False
-  dadaptation: False
-  bitnet: False
-  fp8: False
-
-experimental: True # practically required now it seems
-
-dataset:
-  speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
-  speaker_group_getter: "lambda p: f'{p.parts[-3]}'"
-  speaker_languages:
-    ja: []
-
-  use_hdf5: True
-  use_metadata: True
-  hdf5_flag: r
-  validate: True
-
-  workers: 2
-  cache: True
-
-  duration_range: [3.0, 5.0]
-
-  random_utterance: 1.0
-  max_prompts: 1
-  prompt_duration: 3.0
-
-  max_resps: 1
-  p_resp_append: 0.25
-
-  sample_type: path # speaker
-
-  tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"]
-
-  training: []
-  validation: []
-  noise: []
+sample_rate: 24_000 # 44_000 for dac
+audio_backend: "vocos" # or dac
+
+models:
+- name: "ar+nar"
+  size: "full"
+  resp_levels: 8
+  prom_levels: 8
+  tasks: 8
+  langs: 2
+  tones: 1
+  arch_type: llama
+  training: True
+  version: 4
+  attention: flash_attention_2
+  dropout: 0.1
+
+  loss_factors:
+    text: 0.1
+    prom: 0.0
+    resp: 1.0
+
+hyperparameters:
+  autotune: False
+  autotune_params:
+    start_profile_step: 1
+    end_profile_step: 50
+    num_tuning_micro_batch_sizes: 8
+
+  batch_size: 16
+  gradient_accumulation_steps: 4
+  gradient_clipping: 1.0
+  warmup_steps: 100
+
+  optimizer: Prodigy
+  learning_rate: 1.0
+  torch_optimizer: True
+
+  scheduler: "" # ScheduleFree
+  torch_scheduler: True
+
+evaluation:
+  batch_size: 8
+  frequency: 5000
+  size: 8
+
+  steps: 500
+  ar_temperature: 0.95
+  nar_temperature: 0.25
+  load_disabled_engines: True
+
+trainer:
+  #no_logger: True
+  ddp: False
+  #check_for_oom: False
+  iterations: 1_000_000
+
+  save_tag: step
+  save_on_oom: True
+  save_on_quit: True
+  save_frequency: 250
+  export_on_save: True
+
+  keep_last_checkpoints: 4
+
+  aggressive_optimizations: False
+  load_disabled_engines: False
+
+  #load_state_dict: True
+  strict_loading: False
+  #load_tag: "9500"
+  #load_states: False
+  #restart_step_count: True
+
+  gc_mode: None # "global_step"
+
+  weight_dtype: float32 # float16 or bfloat16
+  amp: False
+
+  backend: deepspeed
+  deepspeed:
+    inferencing: True
+    zero_optimization_level: 0
+    use_compression_training: False
+
+    amp: False
+
+  activation_checkpointing: True
+
+  load_webui: False
+
+inference:
+  backend: deepspeed
+  normalize: False
+
+  weight_dtype: float32 # float16 or bfloat16
+  amp: False
+
+optimizations:
+  injects: False
+  replace: True
+
+  linear: False
+  embedding: False
+  optimizers: True
+
+  bitsandbytes: False
+  dadaptation: False
+  bitnet: False
+  fp8: False
+
+experimental: True # practically required now it seems
+
+dataset:
+  speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
+  speaker_group_getter: "lambda p: f'{p.parts[-3]}'"
+  speaker_languages:
+    ja: []
+
+  use_hdf5: True
+  use_metadata: True
+  hdf5_flag: r
+  validate: True
+
+  workers: 2
+  cache: True
+
+  duration_range: [3.0, 5.0]
+
+  random_utterance: 1.0
+  max_prompts: 1
+  prompt_duration: 3.0
+
+  max_resps: 1
+  p_resp_append: 0.25
+
+  sample_type: path # speaker
+
+  tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"]
+
+  training: []
+  validation: []
+  noise: []
\ No newline at end of file
diff --git a/data/qnt.enc b/data/qnt.enc
new file mode 100644
index 0000000..8da3c31
Binary files /dev/null and b/data/qnt.enc differ
diff --git a/data/qnt.pt b/data/qnt.pt
deleted file mode 100755
index a99b73b..0000000
Binary files a/data/qnt.pt and /dev/null differ
diff --git a/scripts/process_libritts.py b/scripts/process_libritts.py
index 69b1711..53ce053 100755
--- a/scripts/process_libritts.py
+++ b/scripts/process_libritts.py
@@ -95,7 +95,7 @@ for paths in tqdm(txts, desc="Processing..."):
 		np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), {
 			"codes": qnt.cpu().numpy().astype(np.uint16),
 			"metadata": {
-				"original_length": qnt.shape[0] / 75.0,
+				"original_length": qnt.shape[-1] / 75.0,
 				"sample_rate": cfg.sample_rate,
 
 				"text": text.strip(),
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 770360e..57e5af1 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -340,12 +340,10 @@ def example_usage():
 		return torch.tensor( cfg.tokenizer.encode(content) )
 
 	def _load_quants(path) -> Tensor:
-		if cfg.audio_backend == "dac":
-			qnt = np.load(f'{path}.dac', allow_pickle=True)[()]
-			return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
-		return torch.load(f'{path}.pt')[0][:, :cfg.model.prom_levels].t().to(torch.int16)
+		qnt = np.load(path, allow_pickle=True)[()]
+		return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
 
-	qnt = _load_quants("./data/qnt")
+	qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
 
 
 	text_list = [
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index f0a019e..6a86ca3 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -944,8 +944,8 @@ class Base(nn.Module):
 					"logits": [],
 				}
 
-				info[name]["targets"].append( input )
-				info[name]["logits"].append( logit )
+				info[name]["targets"].append( input.contiguous() )
+				info[name]["logits"].append( logit.contiguous() )
 
 		for name, batch in info.items():
 			loss_factor = self.loss_factor(name)
diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py
new file mode 100644
index 0000000..606a839
--- /dev/null
+++ b/vall_e/models/experimental.py
@@ -0,0 +1,404 @@
+from ..config import cfg
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from torch import Tensor
+from torch.nn import CrossEntropyLoss
+
+import random
+import math
+
+from einops import rearrange
+from tqdm import trange
+
+AVAILABLE_ARCHES = []
+
+try:
+	from transformers import LlamaForCausalLM, LlamaConfig
+	AVAILABLE_ARCHES.append("llama")
+except Exception as e:
+	print("Error importing `llama` arch:", e)
+	pass
+
+try:
+	from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig
+	AVAILABLE_ARCHES.append("mamba")
+except Exception as e:
+	print("Error importing `mamba` arch:", e)
+	pass
+
+def _create_mask(l, device):
+	seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
+	stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
+	return (seq < stop).float() # (b t)
+
+def list_to_tensor(x_list: list[Tensor]):
+	l = list(map(len, x_list))
+	x = pad_sequence(x_list).t()
+
+	m = _create_mask(l, x_list[0].device)
+	m = m.to(x)
+	return x, m
+
+# fold into a typical LLM sequence (one embedding rather than split embeddings)
+def fold(
+	text_list = [],
+	proms_list = [],
+	resp_list = [],
+
+	ignore_index = None,
+
+	sep = 3,
+	stop = 3,
+
+	text_tokens = 256,
+	audio_tokens = 1024,
+	audio_rvq_levels = cfg.model.prom_levels
+):
+
+	device = text_list[0].device
+	batch_size = len(text_list)
+	input_ids = [ [] for _ in range(batch_size) ]
+
+	offset = 0
+
+	sep = torch.Tensor([ sep ])
+	stop = torch.Tensor([ stop ])
+
+	for i, text in enumerate(text_list):
+		seq = text.to("cpu", dtype=torch.int64)
+		input_ids[i].append( seq )
+		input_ids[i].append( sep )
+
+	offset = text_tokens
+	for i, prom in enumerate(proms_list):
+		if ignore_index is not None:
+			seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64)
+		else:
+			seq = prom.flatten().to("cpu", dtype=torch.int64)
+			for idx, token in enumerate( seq ):
+				token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
+
+		input_ids[i].append( seq )
+		input_ids[i].append( sep )
+
+	offset = text_tokens + (audio_tokens * audio_rvq_levels)
+	for i, resp in enumerate(resp_list):
+		seq = resp.flatten().to("cpu", dtype=torch.int64)
+		for idx, token in enumerate( seq ):
+			token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) )
+		input_ids[i].append( seq )
+		input_ids[i].append( stop )
+
+	for i, batch in enumerate(input_ids):
+		input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64)
+
+	return list_to_tensor(input_ids)
+
+# unfold from one unified token ID space to separate token spaces
+def unfold(
+	input_ids,
+
+	sep = 3,
+	stop = 3,
+
+	text_tokens = 256,
+	audio_tokens = 1024,
+	audio_rvq_levels = cfg.model.prom_levels
+):
+	device = input_ids.device
+	batch_size = input_ids.shape[0]
+
+	text_list = [ [] for _ in range(batch_size) ]
+	prom_list = [ [] for _ in range(batch_size) ]
+	resp_list = [ [] for _ in range(batch_size) ]
+
+	for i, batch in enumerate( input_ids ):
+		for idx, token in enumerate( batch ):
+			id = token.item()
+			if id == sep or id == stop:
+				continue
+
+			if 0 <= id and id < text_tokens:
+				text_list[i].append( id )
+			elif text_tokens <= id and id < text_tokens + (audio_tokens * audio_rvq_levels):
+				prom_list[i].append( (id - text_tokens) % audio_tokens )
+			elif text_tokens + (audio_tokens * audio_rvq_levels) <= id:
+				resp_list[i].append( (id - text_tokens) % audio_tokens )
+
+		prom_len = len(prom_list[i])
+		if prom_len % audio_rvq_levels == 0 and False:
+			prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t()
+		else:
+			bins = [ [] for _ in range(audio_rvq_levels) ]
+			for pos in range( prom_len ):
+				rvq = pos % audio_rvq_levels
+				bins[rvq].append( prom_list[i][pos] )
+			nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
+			bins = bins[:nearest]
+			prom_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
+
+
+		resp_len = len(resp_list[i])
+		if len(resp_list[i]) % audio_rvq_levels == 0 and False:
+			resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t()
+		else:
+			bins = [ [] for _ in range(audio_rvq_levels) ]
+			for pos in range( resp_len ):
+				rvq = pos % audio_rvq_levels
+				bins[rvq].append( resp_list[i][pos] )
+			nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels
+			bins = bins[:nearest]
+			resp_list[i] = torch.Tensor(bins).t().to(dtype=torch.int64)
+
+		text_list[i] = torch.Tensor( text_list[i] ).to(dtype=torch.int64)
+
+	return dict(
+		text_list=text_list,
+		prom_list=prom_list,
+		resp_list=resp_list
+	)
+
+
+SELECTED_ARCH = cfg.model.arch_type
+if SELECTED_ARCH not in AVAILABLE_ARCHES:
+	raise ValueError(f"Requesting arch `{SELECTED_ARCH}` but not available")
+
+if SELECTED_ARCH == "mamba":
+	LlmArchClass = MambaLMHeadModel
+elif SELECTED_ARCH == "llama":
+	LlmArchClass = LlamaForCausalLM
+else:
+	raise ValueError(f"Requesting arch `{SELECTED_ARCH}` but not available")
+
+class Model(LlmArchClass):
+	def __init__(
+		self,
+		d_model=1024,
+		n_layers=12,
+		n_heads=16,
+		p_dropout=0.1,
+
+		attention_backend=None,
+		activation_checkpointing=True,
+	):
+
+		if SELECTED_ARCH == "llama":
+			super().__init__(config=LlamaConfig(
+				vocab_size=256 + (1024 * cfg.model.prom_levels) + (1024 * cfg.model.prom_levels) + 1,
+				hidden_size=d_model,
+				max_position_embeddings=cfg.dataset.frames_per_second * cfg.model.prom_levels * 60, # max-length of 60 seconds
+				intermediate_size=d_model*4,
+				num_hidden_layers=n_layers,
+				num_attention_heads=n_heads,
+				attention_dropout=p_dropout,
+				num_key_value_heads=n_heads,
+				sliding_window=cfg.dataset.frames_per_second * cfg.model.prom_levels * 12,
+				hidden_act="gelu",
+				is_encoder_decoder=False,
+				is_decoder=True,
+				attn_implementation=attention_backend,
+			))
+
+			if activation_checkpointing:
+				self.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
+					use_reentrant=False
+				))
+		elif SELECTED_ARCH == "mamba":
+			super().__init__(config=MambaConfig(
+				vocab_size=256 + (1024 * cfg.model.prom_levels) + (1024 * cfg.model.prom_levels) + 1,
+				d_model=d_model,
+				n_layer=n_layers*2,
+				#ssm_cfg={"layer": "Mamba2"},
+			))
+
+
+	def forward(
+		self,
+		*args,
+		**kwargs,
+	):
+		output = super().forward(*args, **kwargs)
+
+		if SELECTED_ARCH == "llama":
+			if output.loss is not None:
+				self.loss = dict(
+					nll = output.loss,
+				)
+		elif SELECTED_ARCH == "mamba":
+			if "labels" in kwargs:
+				logits = output.logits
+				labels = kwargs.pop("labels")
+
+				# Shift so that tokens < n predict n
+				shift_logits = logits[..., :-1, :].contiguous()
+				shift_labels = labels[..., 1:].contiguous()
+				# Flatten the tokens
+				loss_fct = CrossEntropyLoss()
+				shift_logits = shift_logits.view(-1, shift_logits.size(-1))
+				shift_labels = shift_labels.view(-1)
+				# Enable model parallelism
+				shift_labels = shift_labels.to(shift_logits.device)
+				loss = loss_fct(shift_logits, shift_labels)
+
+				self.loss = dict(
+					nll = loss,
+				)
+
+		return output
+
+def example_usage():
+	cfg.trainer.backend = "local"
+	cfg.hyperparameters.gradient_accumulation_steps = 1
+	if cfg.audio_backend == "dac":
+		cfg.sample_rate = 44_000
+
+	from functools import partial
+	from einops import repeat
+	from tqdm import tqdm
+
+	from ..emb.qnt import decode_to_file, unload_model
+	from ..engines import Engine
+	from ..utils import wrapper as ml
+
+	import numpy as np
+	import re
+
+	device = "cuda"
+
+	def tokenize(content):
+		return torch.tensor( cfg.tokenizer.encode(content) )
+
+	def _load_quants(path) -> Tensor:
+		qnt = np.load(path, allow_pickle=True)[()]
+		return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
+
+	qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
+
+
+	text_list = [
+		tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
+		#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
+	]
+	proms_list = [
+		qnt[:cfg.dataset.frames_per_second, :].to(device),
+		#qnt[:cfg.dataset.frames_per_second, :].to(device),
+	]
+	resps_list = [
+		qnt[:, :].to(device),
+		#qnt[cfg.dataset.frames_per_second:, :].to(device),
+	]
+
+	text_list = text_list[:1]
+	proms_list = proms_list[:1]
+	resps_list = resps_list[:1]
+
+	input_ids, attention_mask = fold(text_list, proms_list, resps_list)
+	target_ids, target_attention_mask = fold(text_list, proms_list, resps_list, ignore_index=-100)
+	prefix_input_ids, prefix_attention_mask = fold(text_list, proms_list)
+
+	kwargs = {}
+	model = Model(**kwargs).to(device)
+	steps = 50
+
+	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
+	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
+	learning_rate = cfg.hyperparameters.learning_rate if cfg.cfg_path is not None else None
+
+	if cfg.optimizations.dadaptation:
+		# do not combine the two
+		if scheduler == "schedulefree":
+			scheduler = ""
+
+		learning_rate = 1.0
+
+	if optimizer == "prodigy":
+		if learning_rate is None:
+			learning_rate = 1.0
+
+		optimizer = ml.Prodigy
+	elif optimizer == "adagrad":
+		if learning_rate is None:
+			learning_rate = 1.0e-2
+
+		optimizer = ml.Adagrad
+	elif optimizer == "adamw":
+		if learning_rate is None:
+			learning_rate = 1.0e-4
+
+		optimizer = ml.AdamW
+	elif optimizer == "sdg":
+		if learning_rate is None:
+			learning_rate = 1.0e-4
+
+		optimizer = ml.SGD
+	else:
+		raise ValueError(f"Unrecognized optimizer: {optimizer}")
+
+	print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
+
+	optimizer = optimizer(model.parameters(), lr=learning_rate)
+
+	if scheduler == "schedulefree":
+		if isinstance(optimizer, ml.AdamW):
+			scheduler = ml.schedulefree.AdamWScheduleFree
+		elif isinstance(optimizer, ml.SGD):
+			scheduler = ml.schedulefree.SGDScheduleFree
+		else:
+			scheduler = None
+
+		if scheduler is not None:
+			print("Scheduler:", scheduler)
+			optimizer = scheduler( model.parameters(), lr = learning_rate )
+
+	if cfg.optimizations.replace and cfg.optimizations.linear:
+		model = ml.replace_linear( model )
+
+	if cfg.optimizations.replace and cfg.optimizations.embedding:
+		model = ml.replace_embedding( model )
+
+	engine = Engine(model=model, optimizer=optimizer)
+
+	torch.save( {
+		'module': model.state_dict()
+	}, f"./data/{SELECTED_ARCH}.pth" )
+
+	print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+
+	@torch.inference_mode()
+	def sample( name, steps=cfg.model.prom_levels*cfg.dataset.frames_per_second*60 ):
+		engine.eval()
+		if SELECTED_ARCH == "mamba":
+			output = model.generate(input_ids=prefix_input_ids, cg=True, max_length=steps, eos_token_id=3)
+		else:
+			output = model.generate(input_ids=prefix_input_ids, attention_mask=prefix_attention_mask, max_length=steps, eos_token_id=3, do_sample=False)
+
+		unfolded = unfold( output )
+		for i, batch in enumerate(unfolded["resp_list"]):
+			_ = decode_to_file(batch.to(device=device), f"data/{SELECTED_ARCH}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
+
+		unload_model()
+
+	def train():
+		engine.train()
+		t = trange(steps)
+		for i in t:
+			stats = {"step": i}
+			if SELECTED_ARCH == "mamba":
+				stats |= engine.traverse(input_ids=input_ids, labels=target_ids)
+			else:
+				stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask)
+			stats |= {"grad_norm": engine.get_global_grad_norm()}
+
+			tqdm.write(f"{stats}")
+
+		torch.save( {
+			'module': model.state_dict()
+		}, f"./data/{SELECTED_ARCH}.pth" )
+
+	#sample("init", 5)
+	train()
+	sample("final")
+
+if __name__ == "__main__":
+	example_usage()
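
For reference, a minimal sketch (not part of the patch) of the unified token-ID layout that fold()/unfold() in vall_e/models/experimental.py rely on, assuming the defaults visible in the diff: text_tokens = 256, audio_tokens = 1024, and audio_rvq_levels = cfg.model.prom_levels = 8 (matching prom_levels: 8 in data/config.yaml). Text IDs keep their range, prompt codes are shifted past the text range, response codes are shifted past the prompt range, and each RVQ level gets its own 1024-wide slice; the model's vocab_size of 256 + (1024 * 8) + (1024 * 8) + 1 in the diff covers exactly these ranges plus one extra slot. The helper names below are illustrative only.

    TEXT_TOKENS = 256     # text/phoneme vocabulary
    AUDIO_TOKENS = 1024   # codes per RVQ level
    RVQ_LEVELS = 8        # cfg.model.prom_levels in the config above

    def to_unified(code, level, kind):
        # text IDs stay in [0, 256); prom codes land in [256, 256 + 1024*8);
        # resp codes land in [256 + 1024*8, 256 + 2*1024*8); this mirrors the
        # `offset + (audio_tokens * (idx % audio_rvq_levels))` arithmetic in fold().
        if kind == "text":
            return code
        offset = TEXT_TOKENS if kind == "prom" else TEXT_TOKENS + AUDIO_TOKENS * RVQ_LEVELS
        return offset + AUDIO_TOKENS * level + code

    def from_unified(token):
        # inverse mapping back to (kind, code); the RVQ level is implied by
        # position, which is why unfold() only computes `(id - text_tokens) % audio_tokens`.
        if token < TEXT_TOKENS:
            return "text", token
        kind = "prom" if token < TEXT_TOKENS + AUDIO_TOKENS * RVQ_LEVELS else "resp"
        return kind, (token - TEXT_TOKENS) % AUDIO_TOKENS

    assert to_unified(17, 0, "prom") == 273
    assert from_unified(273) == ("prom", 17)
    assert to_unified(17, 2, "resp") == 256 + (1024 * 8) + (1024 * 2) + 17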