diff --git a/vall_e/data.py b/vall_e/data.py
index 5e25857..cd16a70 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -615,8 +615,6 @@ class Dataset(_Dataset):
 		prom_length = 0
 		trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
 
-		print(trim_length / cfg.dataset.frames_per_second)
-
 		for _ in range(cfg.dataset.max_prompts):
 			path = random.choice(choices)
 			if cfg.dataset.use_hdf5:
diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py
index 0979275..aab5201 100755
--- a/vall_e/models/__init__.py
+++ b/vall_e/models/__init__.py
@@ -2,24 +2,7 @@
 def get_model(config, training=True):
 	name = config.name
 
-	if not config.experimental:
-		from .ar_nar import AR_NAR
-		model = AR_NAR(
-			n_text_tokens=config.text_tokens,
-			n_audio_tokens=config.audio_tokens,
-			d_model=config.dim,
-			n_heads=config.heads,
-			n_layers=config.layers,
-			n_experts=config.experts,
-
-			p_dropout=config.dropout,
-
-			l_padding = config.input_alignment,
-
-			training = training,
-			config = config,
-		)
-	elif "len" in config.capabilities:
+	if "len" in config.capabilities:
 		from .nar import NAR
 		model = NAR(
 			n_text_tokens=config.text_tokens,
@@ -36,7 +19,7 @@ def get_model(config, training=True):
 			training = training,
 			config = config,
 		)
-	else:
+	elif config.experimental:
 		from .experimental import Model as Experimental
 		model = Experimental(
 			n_text_tokens=config.text_tokens,
@@ -49,6 +32,23 @@ def get_model(config, training=True):
 			config = config,
 		)
 
+	else:
+		from .ar_nar import AR_NAR
+		model = AR_NAR(
+			n_text_tokens=config.text_tokens,
+			n_audio_tokens=config.audio_tokens,
+			d_model=config.dim,
+			n_heads=config.heads,
+			n_layers=config.layers,
+			n_experts=config.experts,
+
+			p_dropout=config.dropout,
+
+			l_padding = config.input_alignment,
+
+			training = training,
+			config = config,
+		)
 
 	print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index c30ad0e..9f381d5 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -279,7 +279,9 @@ class Base(nn.Module):
 
 	@property
 	def stop_token(self):
-		if not self.causal and "len" not in self.capabilities:
+		if "len" in self.capabilities:
+			return 0
+		if not self.causal:
 			raise ValueError("Not using stop token!")
 		return self.n_audio_tokens
 
@@ -325,9 +327,15 @@ class Base(nn.Module):
 
 		self.l_padding = l_padding
 
-		# +1 to include the stop token
 		n_prom_tokens = n_audio_tokens
-		n_resp_tokens = n_audio_tokens + self.causal_size
+
+		if "len" not in self.capabilities:
+			# +1 to include the stop token
+			n_resp_tokens = n_audio_tokens + self.causal_size
+			l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+		else:
+			n_resp_tokens = n_audio_tokens
+			l_tokens = [n_resp_tokens] * self.n_resp_levels
 
 		audio_embedding_sums = self.config.audio_embedding_sums if self.config is not None else True
 		split_classifiers = self.config.split_classifiers if self.config is not None else True
@@ -351,7 +359,7 @@ class Base(nn.Module):
 			)
 			# [1024 + STOP] + [1024] * 8
 			self.resps_emb = AudioEmbedding_Old(
-				[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
+				l_tokens, d_model,
 				levels=self.n_resp_levels if self.version > 3 else None,
 			)
 		else:
@@ -360,7 +368,7 @@ class Base(nn.Module):
 				sums=audio_embedding_sums,
 			)
 			self.resps_emb = AudioEmbedding(
-				[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
+				l_tokens, d_model,
 				sums=audio_embedding_sums,
 			)
@@ -634,13 +642,11 @@ class Base(nn.Module):
 			self.metrics = None
 		else:
-			levels = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
-
 			self.classifier = None
-			self.classifiers = AudioClassifier( levels, d_model )
+			self.classifiers = AudioClassifier( l_tokens, d_model )
 			self.accuracy_metric = None
 			self.precision_metric = None
-			self.metrics = Metrics( levels )
+			self.metrics = Metrics( l_tokens )
 
 	def _forward(
@@ -905,7 +911,7 @@ class Base(nn.Module):
 			self.loss = dict(
 				nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
 			)
-			self.stats = self.metrics( inputs, targets, quant_levels ) if self.metrics is not None else dict(
+			self.stats = self.metrics( logits, target_list, quant_levels ) if self.metrics is not None else dict(
 				acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
 			)
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 0f20baf..7e0ecc5 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -145,8 +145,8 @@ class NAR(Base):
 		n_levels_set = {r.shape[-1] for r in resps_list}
 		n_levels = next(iter(n_levels_set))
 
-		# is training
-		assert n_levels == self.n_resp_levels
+		# assert n_levels == self.n_resp_levels
+		# to-do: make this YAML configurable
 		def sample_task():
 			return "len" if random.random() < p_len_task else "tts"
 
@@ -170,7 +170,12 @@ class NAR(Base):
 			quant_levels = [ 0 if task_list[i] == "len" else generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
 		else:
 			# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-			quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
+			quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
+
+		# clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC...
+		for i, resp in enumerate(resps_list):
+			if quant_levels[i] >= resp.shape[-1]:
+				quant_levels[i] = resp.shape[-1] - 1
 
 		resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
 
@@ -355,7 +360,7 @@ def example_usage():
 	"""
 
 	model = NAR(**kwargs).to(device)
-	steps = 200
+	steps = 500
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
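
A minimal, self-contained sketch (not part of the patch) of the per-level token-count logic this change centralizes into `l_tokens` in `vall_e/models/base.py`; the concrete values for `n_audio_tokens`, `causal_size`, `n_resp_levels`, and `capabilities` below are assumed for illustration only:

```python
# Sketch of the l_tokens construction added in vall_e/models/base.py.
# All values here are illustrative assumptions, not real config values.
n_audio_tokens = 1024   # codebook entries per RVQ level
causal_size = 1         # slots reserved for the stop token on the AR level
n_resp_levels = 8       # number of RVQ levels the model handles
capabilities = ["len"]  # drop "len" to exercise the stop-token branch

if "len" not in capabilities:
    # +causal_size to include the stop token on the first (causal) level;
    # the NAR levels never emit a stop token, hence one fewer class each
    n_resp_tokens = n_audio_tokens + causal_size
    l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1)
else:
    # "len" models predict the output duration up front instead of emitting
    # a stop token, so every level shares the same class count
    n_resp_tokens = n_audio_tokens
    l_tokens = [n_resp_tokens] * n_resp_levels

print(l_tokens)  # with "len": [1024] * 8; without: [1025] + [1024] * 7
```

The same list is then handed to `resps_emb`, `AudioClassifier`, and `Metrics`, so the per-level vocabulary sizes are defined in exactly one place.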