From 458b95d196a00bfb0b3270b2e40be1e0c8ae16a2 Mon Sep 17 00:00:00 2001
From: mrq
Date: Sun, 19 May 2024 11:23:56 -0500
Subject: [PATCH] added option to split between text loss and audio loss (to-do:
 document this better), because it may or may not be a problem with LLaMA-backed
 models: my loss hovers around 3.9 / 56% accuracy despite sounding decent at the
 moment

---
 data/config.yaml        |  96 +++++++++++++++-----------------
 vall_e/config.py        |   9 +++
 vall_e/models/ar_nar.py |   5 +-
 vall_e/models/base.py   | 118 +++++++++++++++++++++++++++++++++-------
 4 files changed, 157 insertions(+), 71 deletions(-)

diff --git a/data/config.yaml b/data/config.yaml
index 5f106a3..c2ea978 100755
--- a/data/config.yaml
+++ b/data/config.yaml
@@ -6,41 +6,38 @@ models:
   tasks: 8
   langs: 2
   tones: 1
-  arch_type: "retnet"
+  arch_type: llama
   training: True
-  version: 3
+  version: 4
+  attention: flash_attention_2
+  dropout: 0.1
+
+  loss_factors:
+    text: 0.1
+    resp: 1.0
 
 hyperparameters:
-  batch_size: 4
+  autotune: False
+  autotune_params:
+    start_profile_step: 1
+    end_profile_step: 50
+    num_tuning_micro_batch_sizes: 8
+
+  batch_size: 16
   gradient_accumulation_steps: 4
-  gradient_clipping: 10
-
-  optimizer: Adagrad
+  gradient_clipping: 1.0
+  warmup_steps: 100
+
+  optimizer: Prodigy
+  learning_rate: 1.0
   torch_optimizer: True
-  learning_rate: 1.0e-2
 
-  scheduler_type: ""
-  #scheduler_type: OneCycle
-  #scheduler_params:
-  #  cycle_first_step_size: 10_000
-  #  cycle_first_stair_count: 10_000
-
-  #  cycle_second_step_size: 15_000
-  #  cycle_second_stair_count: 15_000
-
-  #  decay_step_size: 5_000
-
-  #  cycle_min_lr: 2.5e-4 # 1.0e-5
-  #  cycle_max_lr: 2.5e-4 # 1.0e-4
-  #  decay_lr_rate: 0.0
-
-  #  cycle_min_mom: 0.90
-  #  cycle_max_mom: 0.99
-  #  decay_mom_rate: 0.0
+  scheduler: "" # ScheduleFree
+  torch_scheduler: True
 
 evaluation:
   batch_size: 8
-  frequency: 10000
+  frequency: 5000
   size: 8
 
   steps: 500
@@ -49,8 +46,9 @@ evaluation:
   load_disabled_engines: True
 
 trainer:
-  no_logger: True
-
+  #no_logger: True
+  ddp: False
+  #check_for_oom: False
   iterations: 1_000_000
 
   save_tag: step
@@ -72,7 +70,7 @@ trainer:
   gc_mode: None # "global_step"
 
-  weight_dtype: float32
+  weight_dtype: float32 # float16 or bfloat16
   amp: False
 
   backend: deepspeed
@@ -81,34 +79,34 @@
     zero_optimization_level: 0
     use_compression_training: False
 
+    amp: False
+
   activation_checkpointing: True
 
-  load_webui: True
+  load_webui: False
 
 inference:
   backend: deepspeed
   audio_backend: "dac"
   normalize: False
 
-  weight_dtype: float32
+  weight_dtype: float32 # float16 or bfloat16
   amp: False
 
-bitsandbytes:
-  enabled: False
-
+optimizations:
   injects: False
-  replace: False
+  replace: True
 
   linear: False
   embedding: False
-
-  bitnet: False
+  optimizers: True
 
-fp8:
-  enabled: False
-  backend: "te"
-
-experimental: True
+  bitsandbytes: False
+  dadaptation: False
+  bitnet: False
+  fp8: False
+
+experimental: True # practically required now it seems
 
 dataset:
   speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
@@ -121,23 +119,19 @@ dataset:
   hdf5_flag: r
   validate: True
 
-  workers: 8
+  workers: 2
   cache: True
-
-  #phones_range: [4, 512]
-  #duration_range: [1.0, 32.0]
-  phones_range: [0, 512]
-  duration_range: [0.0, 64.0]
+  duration_range: [3.0, 5.0]
 
   random_utterance: 1.0
-  max_prompts: 3
-  prompt_duration: 6.0
+  max_prompts: 1
+  prompt_duration: 3.0
 
   max_resps: 1
   p_resp_append: 0.25
 
-  sample_type: speaker
+  sample_type: path # speaker
 
   tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"]
diff --git a/vall_e/config.py b/vall_e/config.py
index f1263e2..7cd39d8 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -213,9 +213,13 @@ class Model:
 	attention: str = "auto"
 	audio_embedding_sums: bool = True
 	dropout: float = 0.1 # adjustable dropout value
+	loss_factors: dict = field(default_factory=lambda: {})
 
 	def get(self, name=None):
 		return [ self ] if not name or self.name == name else []
+
+	def loss_factor(self, k):
+		return self.loss_factors[k] if k in self.loss_factors else 1.0
 
 	@property
 	def max_levels(self):
@@ -491,6 +495,11 @@ class DeepSpeed:
 
 		return ds_cfg
 
+@dataclass()
+class LossFactor:
+	text: float = 1.0
+	resp: float = 1.0
+
 @dataclass()
 class Trainer:
 	iterations: int = 100_000
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 848100d..02c8bc9 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -348,12 +348,15 @@ def example_usage():
 
 	text_list = [
 		tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
+		#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
 	]
 	proms_list = [
 		qnt[:cfg.dataset.frames_per_second, :].to(device),
+		#qnt[:cfg.dataset.frames_per_second, :].to(device),
 	]
 	resps_list = [
-		qnt.to(device),
+		qnt[:, :].to(device),
+		#qnt[:cfg.dataset.frames_per_second, :].to(device),
 	]
 
 	text_list = text_list[:1]
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index ae79f61..97051a0 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -426,6 +426,11 @@ class Base(nn.Module):
 	def ignore_index(self):
 		return -100
 
+	def loss_factor(self, k):
+		if self.config is None:
+			return 1.0
+		return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0
+
 	def __init__(
 		self,
 		n_tokens: int = 1024,
@@ -880,6 +885,31 @@ class Base(nn.Module):
 
 		return x_list
 
+	def training_targets_split(
+		self,
+		inputs: list,
+	):
+		text_lists = []
+		resp_lists = []
+
+		for bi in range(len(inputs)):
+			text_batch = []
+			resp_batch = []
+
+			for i in range(len(inputs[bi])):
+				name, input = inputs[bi][i]
+				device = input.device
+
+				if name in ["text", "lang" ]:
+					text_batch.append( input )
+				elif name == "targ":
+					resp_batch.append( input )
+
+			text_lists.append( _join( text_batch, torch.tensor(self.ignore_index, device=device) ) )
+			resp_lists.append( _join( resp_batch, torch.tensor(self.ignore_index, device=device) ) )
+
+		return text_lists, resp_lists
+
 	def forward(
 		self,
 		inputs: list,
@@ -929,30 +959,80 @@
 		# compute loss if the target is given
 		if training:
-			target_list = self.training_targets( inputs )
-
-			# modify only for the AR so it can properly behave like a transformer
-			for i in range(len(target_list)):
-				if quant_levels is not None and quant_levels[i] > 0:
-					continue
+			if not self.config.loss_factors:
+				target_list = self.training_targets( inputs )
+
+				# modify only for the AR so it can properly behave like a transformer
+				for i in range(len(target_list)):
+					if quant_levels is not None and quant_levels[i] > 0:
+						continue
 
-				logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
-				target_list[i] = target_list[i][..., 1:] # predicts token n + 1
+					logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
+					target_list[i] = target_list[i][..., 1:] # predicts token n + 1
 
-			target = torch.cat( target_list )
-			inputs = torch.cat( logits )
+				target = torch.cat( target_list )
+				inputs = torch.cat( logits )
 
-			self.loss = dict(
-				# "nll" was in the original implementation and should actually just be called something else
-				nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
-			)
-			self.stats = dict(
-				acc = self.accuracy_metric( inputs, target ),
-				# precision = self.precision_metric( inputs, target ),
-			)
+				self.loss = dict(
+					# "nll" was in the original implementation and should actually just be called something else
+					nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
+				)
+				self.stats = dict(
+					acc = self.accuracy_metric( inputs, target ),
+					# precision = self.precision_metric( inputs, target ),
+				)
+			# split our loss
+			else:
+				target_text_list, target_resp_list = self.training_targets_split( inputs )
+				# grab respective slice of logits
+				logits_text = [ hi[:li.shape[0]] for hi, li in zip(logits, target_text_list) ]
+				logits_resp = [ hi[-li.shape[0]:] for hi, li in zip(logits, target_resp_list) ]
+
+				# modify only for the AR so it can properly behave like a transformer
+				for i in range(len(target_text_list)):
+					if quant_levels is not None and quant_levels[i] > 0:
+						continue
+
+					# shift the target so that token n...
+					logits_text[i] = logits_text[i][..., :-1, :]
+					logits_resp[i] = logits_resp[i][..., :-1, :]
+
+					# predicts token n + 1
+					target_text_list[i] = target_text_list[i][..., 1:]
+					target_resp_list[i] = target_resp_list[i][..., 1:]
+
+				target_text = torch.cat( target_text_list ).long()
+				target_resp = torch.cat( target_resp_list ).long()
+
+				inputs_text = torch.cat( logits_text )
+				inputs_resp = torch.cat( logits_resp )
+
+				self.loss = dict(
+					text = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ),
+					resp = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ),
+				)
+
+				for k in self.loss:
+					self.loss[k] *= self.loss_factor(k)
+
+				# to-do: compute loss per individual batch to scale per RVQ level
+				"""
+				rvq_loss_factor = self.loss_factor("quant")
+				if isinstance( rvq_loss_factor, list ):
+					...
+				"""
+
+				self.stats = dict(
+					acc = dict(
+						text = self.accuracy_metric( inputs_text, target_text ),
+						resp = self.accuracy_metric( inputs_resp, target_resp ),
+					),
+				)
+
+		# include any additional losses (for example: MoE router)
 		if aux_loss is not None:
-			self.loss["nll"] += aux_loss
+			self.loss["aux_loss"] = aux_loss
 
 		return (logits, state) if state is not None else logits
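
Note (not part of the patch): below is a minimal, standalone sketch of the split-loss weighting this change introduces, for anyone skimming the diff. Each sequence's logits are sliced into the text span (from the front) and the resp/audio span (from the back), each span gets its own cross-entropy term, and each term is scaled by its entry in the model's loss_factors (missing keys default to 1.0). The function name split_loss, the IGNORE_INDEX constant, and the example shapes are illustrative assumptions, not code from the repository.

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # same sentinel value the model uses for padded/ignored targets

def split_loss(logits, text_targets, resp_targets, loss_factors):
	# logits:       [T, vocab] for one sequence, text tokens first, resp tokens last
	# text_targets: [T_text] token ids for the text portion
	# resp_targets: [T_resp] token ids for the resp (audio) portion
	# loss_factors: e.g. {"text": 0.1, "resp": 1.0}; missing keys default to 1.0

	# slice the logits the same way the patch does: text span from the front, resp span from the back
	logits_text = logits[:text_targets.shape[0]]
	logits_resp = logits[-resp_targets.shape[0]:]

	# shift for next-token prediction (AR case): logits at position n predict token n + 1
	losses = dict(
		text = F.cross_entropy( logits_text[:-1], text_targets[1:], ignore_index=IGNORE_INDEX ),
		resp = F.cross_entropy( logits_resp[:-1], resp_targets[1:], ignore_index=IGNORE_INDEX ),
	)
	# apply the per-component weights from the config
	return { k: v * loss_factors.get(k, 1.0) for k, v in losses.items() }

if __name__ == "__main__":
	vocab = 1024
	text = torch.randint(0, vocab, (6,))
	resp = torch.randint(0, vocab, (75,))
	logits = torch.randn(text.shape[0] + resp.shape[0], vocab)
	print( split_loss(logits, text, resp, {"text": 0.1, "resp": 1.0}) )

With the factors from the example config (text: 0.1, resp: 1.0), the text term contributes only a tenth of its raw value, so the combined loss and accuracy mostly reflect the audio tokens rather than the comparatively easy text tokens.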