diff --git a/vall_e/config.py b/vall_e/config.py
index b24a028..33148bd 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -206,8 +206,8 @@ class Model:
 	#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
 	loss_factors: dict = field(default_factory=lambda: {})
 	capabilities: list = field(default_factory=lambda: ["ar", "nar"])
-	experimental: bool = False # for now it sets things to be HF compatible
-	kv_heads: int = 0
+	experimental: str | None = None # for now it sets things to be HF compatible
+	kv_heads: int = 0 # MHA or GQA (for supported backends)
 
 	def get(self, name=None):
 		return [ self ] if not name or self.name == name else []
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 81f1ecd..a7e3720 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -14,10 +14,14 @@ from ..emb.qnt import trim
 
 class AR_NAR(Base):
 	@property
-	def causal(self):
+	def capabilities(self) -> list[str]:
 		if hasattr(self, "config") and self.config:
-			return "ar" in self.config.capabilities
-		return True
+			return self.config.capabilities
+		return cfg.model.capabilities
+
+	@property
+	def causal(self):
+		return "ar" in self.capabilities or "len" in self.capabilities
 
 	@property
 	def norm_type(self):
@@ -86,8 +90,10 @@ class AR_NAR(Base):
 			return self.config.version
 		return cfg.model.version
 
-	def _prune(self, l: Tensor):
-		indices = (l == self.stop_token).nonzero()
+	def _prune(self, l: Tensor, stop = None):
+		if stop is None:
+			stop = self.stop_token
+		indices = (l == stop).nonzero()
 		if len(indices) == 0:
 			return l
 		return l[: indices.min().item()]
@@ -104,6 +110,7 @@ class AR_NAR(Base):
 
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
+		len_list: list[Tensor] | None = None,
 
 		max_steps: int = 1000,
 		max_levels: int = 0,
@@ -130,7 +137,16 @@
 
 		# is training
 		if n_levels == self.n_resp_levels:
-			# might be better to have this decided on the dataloader level
+			# to-do: make this YAML configurable
+			def sample_task():
+				p_len_task = 0.25 if "len" in self.capabilities else 0
+				return "len" if random.random() < p_len_task else "tts"
+
+			# generate task list to train against
+			task_list = [ sample_task() for _ in range(batch_size) ]
+
+			# determines which RVQ level to target per batch
+			quant_level_range = [ 0 if self.causal else 1, self.n_resp_levels ]
 
 			if cfg.experimental:
 				# makes higher levels less likely
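The training branch above now draws a per-item task alongside a per-item target RVQ level. A standalone sketch of that sampling under the same rules (the helper name sample_batch_plan and the 8-level default are hypothetical; note that random.randint is inclusive on both ends, exactly as in the diff):

	import random

	def sample_batch_plan(batch_size: int, n_resp_levels: int = 8, causal: bool = True, has_len: bool = True):
		# roughly a quarter of the batch trains the duration ("len") task, the rest plain TTS
		p_len_task = 0.25 if has_len else 0
		task_list = [ "len" if random.random() < p_len_task else "tts" for _ in range(batch_size) ]
		# level 0 doubles as the AR target; purely non-causal models only train levels 1+
		lo, hi = (0 if causal else 1), n_resp_levels
		quant_levels = [ random.randint(lo, hi) for _ in range(batch_size) ]
		return task_list, quant_levels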
@@ -142,20 +158,21 @@
 					index = i
 				return int(index)
 
-			#quant_levels = torch.Tensor([ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
-			quant_levels = [ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]
+			quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for _ in range(batch_size) ]
 		else:
-			#quant_levels = torch.randint(0 if self.causal else 1, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-			quant_levels = [ random.randint(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ] # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+			# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+			quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1]) for _ in range(batch_size) ]
 
-		resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)] # r if l == 0 is technically correct since only r[:, 0] is passed through the embedding, but this should save some VRAM
+		resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
 
 		# append stop tokens for AR
-		for i in range(batch_size):
-			if quant_levels[i] > 0:
-				continue
-
-			resps_list[i] = torch.cat([resps_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
+		# could technically do it in the .inputs call
+		if "len" not in self.capabilities:
+			for i in range(batch_size):
+				# only apply stop token for RVQ level 0
+				if quant_levels[i] > 0:
+					continue
+				resps_list[i] = torch.cat([resps_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
 
 		inputs = self.inputs(
 			text_list=text_list,
@@ -163,6 +180,7 @@
 			resps_list=resps_list,
 			lang_list=lang_list,
 			tone_list=tone_list,
+			task_list=task_list,
 
 			quant_levels=quant_levels,
 		)
@@ -171,6 +189,7 @@
 			inputs=inputs,
 			quant_levels=quant_levels,
 		)
+
 		# is NAR
 		if max_levels == 0:
 			max_levels = self.n_resp_levels - 1
@@ -187,7 +206,7 @@
 			if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
 				break
 
-			quant_levels = torch.full((len(text_list),), level)
+			quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
 
 			inputs = self.inputs(
 				text_list=text_list,
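For reference, the two training shapes this produces: a level-0 (AR) item is a 1D code sequence with the stop token appended (unless the model is NAR-only via "len"), while a level-l (NAR) item keeps codebooks 0..l with no stop token. A small illustration, with hypothetical sizes and a stop token of 1024:

	import torch

	stop_token = 1024                     # hypothetical: one past the codebook indices 0..1023
	r = torch.randint(0, 1024, (75, 8))   # one response: 75 frames x 8 RVQ levels

	# quant level 0 (AR): keep level 0 only, then append the stop token
	ar = torch.cat([r[..., 0], torch.tensor([stop_token], dtype=r.dtype)])   # shape [76]

	# quant level 3 (NAR): keep levels 0..3, predicted in place, no stop token
	nar = r[..., :3+1]                                                       # shape [75, 4]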
@@ -223,9 +242,88 @@
 
 			return prev_list
 
+		# other NAR
+		if len_list is not None:
+			# is NAR
+			if max_levels == 0:
+				max_levels = self.n_resp_levels
+
+			# fill with mock tokens
+			prev_list = [ torch.Tensor([ self.stop_token for _ in range(resp_len) ]).to(device=device, dtype=torch.int16) for resp_len in len_list ]
+
+			start = True
+			for n in trange( max_levels, desc="NAR" ):
+				level = 0 if n == 0 else prev_list[0].shape[-1]
+				if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
+					break
+
+				quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
+
+				inputs = self.inputs(
+					text_list=text_list,
+					proms_list=proms_list,
+					resps_list=prev_list,
+					lang_list=lang_list,
+					tone_list=tone_list,
+					quant_levels=quant_levels,
+				)
+
+				logits = super().forward(
+					inputs=inputs,
+					quant_levels=quant_levels,
+				)
+
+				resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ]
+
+				if n == 0:
+					prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ]
+				else:
+					prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
+
+			return prev_list
+
 		# is AR
-		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
+		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
 		stopped = torch.zeros(batch_size, device=device).bool()
+
+		stop_token = 10 if "len" in self.capabilities else self.stop_token
+		task_list = [ "len" if "len" in self.capabilities else "tts" for _ in range(batch_size) ]
+
+		if "len" in self.capabilities:
+			for n in trange(10, desc="AR"):
+				len_list = sequence_list
+
+				inputs = self.inputs(
+					text_list=text_list,
+					proms_list=proms_list,
+					resps_list=resps_list,
+					lang_list=lang_list,
+					tone_list=tone_list,
+					len_list=len_list,
+					task_list=task_list,
+					quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
+				)
+
+				logits = super().forward(
+					inputs=inputs,
+				)
+
+				r = [ logit[-1:].argmax(dim=1) for logit in logits ]
+
+				# append tokens
+				for i, ri in enumerate(r):
+					if stop_token in ri:
+						stopped[i] = True
+					sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
+
+				# stop token found
+				stopped |= r == stop_token
+				if stopped.all().item():
+					break
+
+			# convert tokens into int
+			return [ int("".join([ str(token.item()) for token in r if token != stop_token ])) for r in sequence_list ]
+
 		recurrent_state = [] if cfg.inference.recurrent_forward else None
 		mirostat = [
@@ -252,7 +350,8 @@
 			resps_list=resps_list,
 			lang_list=lang_list,
 			tone_list=tone_list,
-
+			len_list=len_list,
+			task_list=task_list,
 			quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
 		)
 
@@ -304,12 +403,12 @@
 
 			# append tokens
 			for i, ri in enumerate(r):
-				if self.stop_token in ri:
+				if stop_token in ri:
					stopped[i] = True
 				sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
 
 			# stop token found
-			stopped |= r == self.stop_token
+			stopped |= r == stop_token
 			if stopped.all().item():
 				break
 
@@ -318,7 +417,8 @@
 		if sampling_beam_width:
 			sequence_list = [ sequence_list[0] ]
 
-		return [self._prune(r) for r in sequence_list]
+		sequence_list = [self._prune(r, stop_token) for r in sequence_list]
+		return sequence_list
 
 
 def example_usage():
@@ -474,13 +574,17 @@
 			return
 
 	engine.eval()
 
-	if "ar" in cfg.model.capabilities:
-		resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
+	if "len" in cfg.model.capabilities:
+		len_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
+		resps_list = engine( text_list, proms_list, len_list=len_list, sampling_temperature=0.2 )
 	else:
-		resps_list = [ qnt[:, 0].to( device ) ]
+		if "ar" in cfg.model.capabilities:
+			resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
+		else:
+			resps_list = [ qnt[:, 0].to( device ) ]
 
-	if "nar" in cfg.model.capabilities:
-		resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
+		if "nar" in cfg.model.capabilities:
+			resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
 
 	for i, o in enumerate(resps_list):
 		_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
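Taken together, inference for a "len"-capable model becomes a two-pass affair: a short AR loop spells the duration out as digit tokens (0-9, with 10 as the stop token, hence at most 10 steps), and the NAR pass then fills a mock, stop-token-filled response of that length level by level. A sketch of the digit round-trip (decode_len is a hypothetical helper; the `or "0"` guard against an empty digit string is an addition, not in the diff):

	import torch

	def decode_len(tokens: torch.Tensor, stop_token: int = 10) -> int:
		# mirror of the comprehension in the AR "len" loop: drop stop tokens, join digits
		digits = "".join(str(t.item()) for t in tokens if t.item() != stop_token)
		return int(digits or "0")

	emitted = torch.tensor([3, 4, 7, 10])   # the model spelled out "347", then stopped
	assert decode_len(emitted) == 347       # 347 frames to hand to the NAR pass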
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index cc9e48d..de9aad1 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -115,13 +115,13 @@ class AudioEmbedding_Old(nn.Module):
 		# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
 		self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
 
-	def forward(self, xi: Tensor, quant_level: int | Tensor | None = None ) -> Tensor:
+	def forward(self, xi: Tensor, offset: int | None = 0 ) -> Tensor:
 		# prom
-		if quant_level is None and xi.shape[-1] > 1:
+		if offset == 0 and xi.shape[-1] > 1:
 			x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
 		# AR resp
-		elif quant_level is None or quant_level == 0:
-			x = self.embeddings[0]( xi if len(xi.shape) == 1 else xi[:, 0] )
+		elif offset == 0:
+			x = self.embeddings[0]( xi if xi.dim() == 1 else xi[:, 0] )
 		# NAR resp
 		else:
 			x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
@@ -147,19 +147,13 @@ class AudioEmbedding(nn.Module):
 		self.sums = sums
 
 	# maintaining compat is hard
-	def forward(self, xi: Tensor, quant_level: int | Tensor | None = None ) -> Tensor:
-		if quant_level is None:
-			quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
-
-		# jank but needed
-		if xi.dim() == 1:
-			return self.embeddings[quant_level]( xi )
+	def forward(self, xi: Tensor, quant_level: int | Tensor | None = None, offset: int = 0 ) -> Tensor:
+		quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
 
-		offset = 0 if self.mode == "prom" else 1
 		if self.sums and quant_level > 0:
 			x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
 		else:
-			k = quant_level - 1
+			k = quant_level
 			x = self.embeddings[k + offset]( xi if xi.dim() == 1 else xi[:, k] )
 
 		return x
@@ -217,6 +211,10 @@
 	def version(self) -> int:
 		return 1
 
+	@property
+	def capabilities(self) -> list[str]:
+		raise NotImplementedError
+
 	@property
 	def stop_token(self):
 		if not self.causal:
@@ -273,7 +271,8 @@
 		self.langs_emb = None
 		self.tones_emb = None
 		self.tasks_emb = None
-		self.rvq_level_emb = None
+		self.rvq_l_emb = None
+		self.len_emb = None
 
 		if self.version == 1: # legacy
 			n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
@@ -298,7 +297,7 @@
 			)
 			self.resps_emb = AudioEmbedding(
 				[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
-				"resp",
+				"resp:len" if "len" in self.capabilities else "resp",
 				sums=self.config.audio_embedding_sums if self.config is not None else True
 			)
@@ -314,7 +313,10 @@
 		# this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
 		# this ***might*** let me also unify the proms_emb and resps_embedding
 		if self.version >= 5:
-			self.rvq_level_emb = Embedding(self.n_resp_levels, d_model)
+			self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)
+
+			# experimental NAR-only mode
+			self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
 
 		# this would be nicer to be a stop token or live inside an embedding
 		self.sep = nn.Parameter(torch.randn(d_model))
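The offset argument replaces the old mode-based lookup: the caller now states which slice of the per-level embedding table to use — offset 0 for proms (and the level-0/AR resp embedding), offset 1 for NAR resp levels, whose tables start at index 1 because index 0 holds the AR embedding with its stop token. A self-contained sketch of the scheme (all sizes hypothetical, not the repo's class):

	import torch
	import torch.nn as nn

	n_levels, n_tokens, d_model = 8, 1025, 512
	embeddings = nn.ModuleList([ nn.Embedding(n_tokens, d_model) for _ in range(n_levels) ])

	def embed(xi: torch.Tensor, offset: int, sums: bool = True) -> torch.Tensor:
		# same shape convention as AudioEmbedding.forward: 1D input is a lone level-0 sequence
		quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
		if sums and quant_level > 0:
			# sum the embeddings of the levels below the target, shifted into the table by `offset`
			return sum( embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) )
		k = quant_level
		return embeddings[k + offset]( xi if xi.dim() == 1 else xi[:, k] )

	prom = torch.randint(0, 1024, (75, 8))
	resp = torch.randint(0, 1024, (75, 4))
	prom_emb = embed(prom, offset=0)   # prom: levels 0..6 from tables 0..6
	resp_emb = embed(resp, offset=1)   # NAR resp: levels 0..2 from tables 1..3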
@@ -623,6 +625,8 @@
 
 		lang_list: list[Tensor] | None = None,
 		tone_list: list[Tensor] | None = None,
+		len_list: list[Tensor] | None = None,
+		task_list: list[str] | None = None,
 
 		quant_levels: int | list[int] | Tensor | None = None
 	):
@@ -632,17 +636,41 @@
 		inputs = [ [] for _ in range(batch_size) ]
 		for i in range(batch_size):
 			quant_level = quant_levels[i] if quant_levels is not None else 0
+			task_type = task_list[i] if task_list is not None else "tts"
 
-			if text_list is not None:
-				inputs[i].append( ( "text", text_list[i] ) )
+			inputs[i].append( ( "task", task_type ) )
 
-			if self.rvq_level_emb is not None:
-				inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
+			#
+			if task_type == "tts":
+				if text_list is not None:
+					inputs[i].append( ( "text", text_list[i] ) )
+				if self.rvq_l_emb is not None:
+					inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
+				if proms_list is not None:
+					inputs[i].append( ( "prom", proms_list[i] ) )
+				if resps_list is not None:
+					inputs[i].append( ( "resp", resps_list[i] ) )
+			#
+			elif task_type == "len":
+				# throw an error so we don't silently train without this
+				if self.len_emb is None:
+					raise Exception(f"Requesting task `{task_type}` but corresponding embedding is not defined.")
+				if text_list is not None:
+					inputs[i].append( ( "text", text_list[i] ) )
+				# technically will always be level 0 but for the sake of keeping the input formatting coherent...
+				if self.rvq_l_emb is not None:
+					# override to 0; quant_levels is a list, so this mutation is visible to the caller
+					quant_levels[i] = 0
+					# inputs[i].append( ( "quant_level", torch.Tensor([ 0 ]).to(device=device, dtype=torch.int16) ) )
+				if proms_list is not None:
+					inputs[i].append( ( "prom", proms_list[i] ) )
 
-			if proms_list is not None:
-				inputs[i].append( ( "prom", proms_list[i] ) )
-			if resps_list is not None:
-				inputs[i].append( ( "resp", resps_list[i] ) )
+				if len_list is not None:
+					inputs[i].append( ( "len", len_list[i] ) )
+				# "encode" length to tokens for 0-9 + stop
+				elif resps_list is not None:
+					# yes this could be encoded better
+					inputs[i].append( ( "len", torch.Tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ]).to(device=device, dtype=torch.int16) ) )
 
 		return inputs
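So a target length is itself spelled out as a token sequence: digits 0-9 plus 10 as the stop token (matching the 11-entry len_emb), bracketed by a leading 0 as written above. A quick illustration of the encoding as written (hypothetical helper name):

	import torch

	def encode_len(n_frames: int) -> torch.Tensor:
		# [0, digit, digit, ..., 10] — same expression as in inputs() above
		return torch.Tensor([ 0 ] + [ int(d) for d in str(n_frames) ] + [ 10 ]).to(dtype=torch.int16)

	print(encode_len(347))   # tensor([ 0,  3,  4,  7, 10], dtype=torch.int16)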
@@ -656,26 +684,40 @@
 			batch = []
 			quant_level = quant_levels[batch_index] if quant_levels is not None else 0
 			for name, input in batch_input:
+				# technically can provide a map for input_name => embedding, but some embeddings require additional processing
 				embedding = None
-				if name == "text":
+
+				if name == "task":
+					# noop
+					# *maybe* inject a token for specifying task type
+					...
+					continue
+				elif name == "text":
 					embedding = self.text_emb( input )
-				elif name == "quant_level" and self.rvq_level_emb is not None:
-					embedding = self.rvq_level_emb( input )
+				elif name == "quant_level" and self.rvq_l_emb is not None:
+					embedding = self.rvq_l_emb( input )
 				elif name == "lang" and self.langs_emb is not None:
 					embedding = self.langs_emb( input )
 				elif name == "prom":
 					# get RVQ level 0, or up to targetted RVQ level inference
-					embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level if self.version >= 5 else None )
+					embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], offset = 0 )
 				elif name == "tone" and self.tones_emb is not None:
 					embedding = self.tones_emb( input )
 				elif name == "resp":
-					# get RVQ level 0, or up to targetted RVQ level inference
-					embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level )
+					if "len" in self.capabilities and quant_level == 0:
+						# fill with "stop" tokens for NAR-only model
+						embedding = self.resps_emb( torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token), offset = 0 )
+					else:
+						# get RVQ level 0, or up to targeted RVQ level inference
+						embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], offset = 0 if quant_level == 0 else 1 )
+				elif name == "len" and self.len_emb is not None:
+					embedding = self.len_emb( input )
 				else:
+					# should probably raise an exception so things aren't processed silently
 					continue
 
 				batch.append(embedding)
-			
+
 			x_list.append( _join( batch, self.sep ) )
 
 		return x_list
@@ -690,22 +732,24 @@
 		# old, "naive" way, no loss factoring
 		if not self.config.loss_factors:
 			target_list = []
+			task_list = []
 			for batch_index, batch in enumerate(inputs):
 				quant_level = quant_levels[batch_index]
-				prev_quant_level = 0 if quant_level == 0 else quant_level - 1
 				target = []
 				for name, input in batch:
-					if name == "prom":
+					if name == "task":
+						task_list.append( input )
+					elif name == "prom":
 						# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
 						if self.version < 4 or (self.version >= 5 and self.config.audio_embedding_sums):
 							target.append( torch.full_like(input[..., 0], self.ignore_index) )
 						# we *CAN* directly map to proms
 						else:
-							target.append( input if input.dim() == 1 else input[:, prev_quant_level] )
+							target.append( input if input.dim() == 1 else input[:, quant_level] )
 					elif name == "resp":
 						target.append( input if input.dim() == 1 else input[:, quant_level] )
-					elif name in ["text", "quant_level", "lang", "tone"]:
+					elif name in ["text", "quant_level", "lang", "tone", "len"]:
 						target.append( input )
 
 				target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
@@ -713,8 +757,12 @@
 			batch_size = len(target_list)
 			# modify only for the AR so it can properly behave like a transformer
 			for i in range(batch_size):
-				if quant_levels is not None and quant_levels[i] > 0:
-					continue
+				if "len" in self.capabilities:
+					if task_list[i] != "len":
+						continue
+				else:
+					if quant_levels is not None and quant_levels[i] > 0:
+						continue
 
 				l = self.causal_size
 				logits[i] = logits[i][..., :-l, :]
 				# shift the target so that token n...
@@ -758,7 +806,6 @@
 
 		for i, batch in enumerate( inputs ):
 			quant_level = quant_levels[i]
-			prev_quant_level = 0 if quant_level == 0 else quant_level - 1
 			it = 0
 			for name, input in batch:
@@ -767,7 +814,10 @@
 					input = input if input.dim() == 1 else input[:, quant_level]
 				# select prom level
 				elif name == "prom":
-					input = input[:, prev_quant_level]
+					input = input[:, quant_level]
+				# meta-input, no corresponding token at the moment
+				elif name == "task":
+					continue
 
 				seq_len = input.shape[0]
@@ -776,7 +826,7 @@
 
 				# for the AR, shift sequence so that it predicts the next token
 				# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
-				if quant_level == 0:
+				if quant_level == 0 and seq_len > 1:
 					l = self.causal_size
 					logit = logit[..., :-l, :]
 					input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
@@ -793,7 +843,7 @@
 			for name, batch in info.items():
 				loss_factor = self.loss_factor(name)
 
-				if name not in ["text", "prom", "resp"]:
+				if name not in ["text", "prom", "resp", "len"]:
 					continue
 
 				if loss_factor == 0.0:
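The net effect on the loss: only sequences that are actually decoded autoregressively — "len" items on a NAR-only model, level-0 items otherwise — get the causal shift, where the last causal_size logits and the first causal_size targets are dropped so position n is scored against token n+1. A minimal sketch of that shift (tensor shapes hypothetical):

	import torch
	import torch.nn.functional as F

	causal_size = 1
	logits = torch.randn(76, 1025)            # [seq_len, vocab], e.g. 75 codes + stop
	target = torch.randint(0, 1025, (76,))    # aligned token targets

	# AR shift: score position n against token n+1; NAR sequences skip this entirely
	loss = F.cross_entropy(logits[:-causal_size, :], target[causal_size:])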