diff --git a/vall_e/config.py b/vall_e/config.py
index bc7b3c9..02185c8 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -960,6 +960,19 @@ class NaiveTokenizer:
 		# tokenize
 		return [*map(symmap.get, phones)]
 
+	def decode( self, t ):
+		s = ""
+		symmap = self.get_vocab()
+		reverse_symmap = {}
+		for k, v in symmap.items():
+			reverse_symmap[v] = k
+
+		for i, token in enumerate( t ):
+			s += reverse_symmap[token]
+
+		return s
+
+
 _logger = logging.getLogger(__name__)
 
 cfg = Config.from_cli()
diff --git a/vall_e/data.py b/vall_e/data.py
index 8f0d3a4..db12759 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -428,6 +428,7 @@ def get_task_symmap():
 		"": 5,
 		"": 6,
 		"": 7,
+		"<stt>": 8,
 
 		"": 6, # fake
 		"": 6, # fake
@@ -1052,6 +1053,12 @@ class Dataset(_Dataset):
 				task,
 			]
 
+		# Base TTS ( => )
+		elif task == "stt":
+			# easier to just keep it instead of wrangling around trying to remove it
+			# it might also help to provide a guidance prompt but who knows right now
+			proms = self.sample_prompts(spkr_name, ignore=path)
+
 		# noise suppression (? => )
 		# speech removal (? => )
 		elif task == "ns" or task == "sr":
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index c600015..383cba9 100644
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -200,7 +200,7 @@ class AR(Base):
 
 			r = super().sample(
 				logits=logits,
-				resps_list=resps_list,
+				prev_list=resps_list,
 
 				temperature=sampling_temperature,
 				min_temperature=sampling_min_temperature,
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index d1661f7..09ff810 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -61,15 +61,23 @@ class AR_NAR(Base):
 
 		disable_tqdm=False,
 	):
-		device = text_list[0].device
-		batch_size = len(text_list)
-
+		text_task = [ "stt" ]
+
+		if text_list is not None:
+			default_task = "tts"
+			device = text_list[0].device
+			batch_size = len(text_list)
+		else:
+			default_task = "stt"
+			device = resps_list[0].device
+			batch_size = len(resps_list)
+
 		# generate task list if not provided
 		if task_list is None:
-			task_list = [ "tts" for _ in range(batch_size) ]
+			task_list = [ default_task for _ in range(batch_size) ]
 
 		# is training or NAR
-		if resps_list is not None:
+		if resps_list is not None and text_list is not None:
 			n_levels_set = {r.shape[-1] for r in resps_list}
 			n_levels = next(iter(n_levels_set))
 
@@ -102,12 +110,18 @@ class AR_NAR(Base):
 			# input RVQ levels
 			quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
 
+			for i, task in enumerate( task_list ):
+				if task in text_task:
+					quant_levels[i] = 0 # self.n_resp_levels - 1
+
 			# trim resps to only contain all levels below the target level
-			resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
+			resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
+
 			# tensor to cat for RVQ level 0
-			stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
+			text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16)
+			audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
 			# I hate python's value/reference semantics so much
-			for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
+			for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
 				# cap quant_level if it exceeds its corresponding resp/prom
 				if quant_level >= resps.shape[-1]:
 					quant_levels[i] = resps.shape[-1] - 1
@@ -139,7 +153,11 @@ class AR_NAR(Base):
 				# only apply stop token for RVQ level 0
 				if quant_level <= 0:
 					# append stop tokens for AR
-					resps_list[i] = torch.cat([ resps, stop_sequence ])
+					if task in text_task:
+						#text_list[i] = torch.cat([ resps, text_stop_sequence ])
+						...
+					else:
+						resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
 
 
 			inputs = self.inputs(
@@ -195,7 +213,7 @@ class AR_NAR(Base):
 
 			resps_list = super().sample(
 				logits=logits,
-				resps_list=prev_list,
+				prev_list=prev_list,
 				quant_levels=quant_levels,
 
 				temperature=sampling_temperature,
@@ -220,11 +238,11 @@ class AR_NAR(Base):
 
 		if cfg.lora is not None:
 			enable_lora( self, cfg.lora.active_level( 0 ) )
 
+		# STT
 		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
 		stopped = torch.zeros(batch_size, device=device).bool()
 
-		stop_token = self.stop_token
-
+		stop_token = self.stop_token if task_list[0] != "stt" else 2 # to-do: derive from tokenizer
 		state = None
 		mirostat = [
@@ -233,9 +251,17 @@ class AR_NAR(Base):
 
 		scores = [ 1.0 ] * sampling_beam_width
 
+		# add to text for STT
+		for i, sequence in enumerate( sequence_list ):
+			if task_list[i] in text_task:
+				sequence_list[i] = torch.cat([sequence_list[i], torch.tensor([1], dtype=torch.int16, device=device)])
+
 		# get next in sequence
 		for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
-			resps_list = [x.unsqueeze(dim=-1) for x in sequence_list]
+			if task_list[0] in text_task:
+				text_list = [x for x in sequence_list]
+			else:
+				resps_list = [x.unsqueeze(dim=-1) for x in sequence_list]
 
 			inputs = self.inputs(
 				text_list=text_list,
@@ -261,7 +287,7 @@ class AR_NAR(Base):
 
 			r = super().sample(
 				logits=logits,
-				resps_list=resps_list,
+				prev_list=resps_list,
 
 				temperature=sampling_temperature,
 				min_temperature=sampling_min_temperature,
@@ -398,10 +424,10 @@ def example_usage():
 	"""
 	bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
 
-	tasks = cfg.dataset.tasks_list
+	available_tasks = ["tts", "stt"] # cfg.dataset.tasks_list
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 150 * len(tasks) # * cfg.model.experimental.causal_size
+	steps = 150 * len(available_tasks) # * cfg.model.experimental.causal_size
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -486,14 +512,14 @@ def example_usage():
 	_logger.info(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
 	@torch.no_grad()
-	def sample_data(task=None):
+	def sample_data(t=None):
 		texts = []
 		proms = []
 		resps = []
+		tasks = []
 
 		for i in range(batch_size):
-			if task is None:
-				task = random.choice(tasks)
+			task = random.choice(available_tasks) if t is None else t
 
 			text = text_list[i]
 			prom = proms_list[i]
@@ -502,6 +528,8 @@ def example_usage():
 
 			# do nothing
 			if task == "tts":
 				...
+			elif task == "stt":
+				...
 			elif task == "tts-c":
 				trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
@@ -523,25 +551,35 @@ def example_usage():
 			texts.append( text.to(device) )
 			proms.append( prom.to(device) )
 			resps.append( resp.to(device) )
+			tasks.append( task )
 
-		return texts, proms, resps
+		return texts, proms, resps, tasks
 
 	@torch.inference_mode()
-	def sample( name, steps=1000, task=None ):
+	def sample( name, steps=500, task=None ):
 		engine.eval()
 
-		texts, proms, resps = sample_data( task )
+		texts, proms, resps, tasks = sample_data( task )
 
-		if "ar" in cfg.model.capabilities:
-			resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
+		if tasks[0] == "stt":
+			text = engine( None, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
+			"""
+			# to-do: STT for NAR
+			text = engine( text, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
+			"""
+			text = [ cfg.tokenizer.decode( t ) for t in text ]
+
+			print( text )
 		else:
-			resps = [ resp[:, 0] for resp in resps ]
+			if "ar" in cfg.model.capabilities:
+				resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
+			else:
+				resps = [ resp[:, 0] for resp in resps ]
 
-		if "nar" in cfg.model.capabilities:
-			resps = engine( texts, proms, resps, sampling_temperature=0.2 )
-
-		for i, o in enumerate(resps):
-			_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
+			if "nar" in cfg.model.capabilities:
+				resps = engine( texts, proms, resps, sampling_temperature=0.2 )
+			for i, o in enumerate(resps):
+				_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
 
 
 	unload_model()
@@ -549,10 +587,10 @@ def example_usage():
 		engine.train()
 		t = trange(steps)
 		for i in t:
-			texts, proms, resps = sample_data()
+			texts, proms, resps, tasks = sample_data()
 
 			stats = {"step": i}
-			stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps)
+			stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps, task_list=tasks)
 			stats |= {"grad_norm": engine.get_global_grad_norm()}
 
 			tqdm.write(f"{stats}")
@@ -571,7 +609,7 @@ def example_usage():
 		model = ml.compile_model(model, backend=cfg.optimizations.compile)
 	"""
 
-	for task in tasks:
+	for task in available_tasks:
 		sample("final", task=task)
 
 	engines.quit()
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 1ee8026..05625d6 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -259,7 +259,8 @@ class AudioEmbedding(nn.Module):
 		return x
 
 # per-level classification
-class AudioClassifier(nn.Module):
+# it might actually be "better" in the long run to only have one output head like a traditional LM, and just de-stitch it here instead of doing modulus math and whatever like the HF/experimental impl
+class Classifiers(nn.Module):
 	def __init__(
 		self,
 		l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
@@ -783,10 +784,10 @@ class Base(nn.Module):
 			self.metrics = None
 		else:
 			self.classifier = None
-			self.classifiers = AudioClassifier( l_tokens, d_model )
+			self.classifiers = Classifiers( l_tokens + [ n_text_tokens ], d_model )
 			self.accuracy_metric = None
 			self.precision_metric = None
-			self.metrics = Metrics( l_tokens )
+			self.metrics = Metrics( l_tokens + [ n_text_tokens ] )
 
 		"""
 		if tie_classifier_to_embedding:
@@ -907,6 +908,8 @@ class Base(nn.Module):
 		device = text_list[0].device
 		batch_size = len(text_list)
 
+		special_tasks = ["stt", "len"]
+
 		inputs = [ [] for _ in range(batch_size) ]
 		for i in range(batch_size):
 			quant_level = quant_levels[i] if quant_levels is not None else 0
@@ -921,7 +924,7 @@ class Base(nn.Module):
 			# Base-line TTS task
 			# Sequence:
 			# prom /may/ include  tokens inside to help guide things, per SpeechX
-			if f'<{task_type}>' in get_task_symmap():
+			if f'<{task_type}>' in get_task_symmap() and task_type not in special_tasks:
 				# insert the text prompt
 				if text_list is not None and text_list[i] is not None:
 					inputs[i].append( ( "text", text_list[i] ) )
@@ -973,6 +976,21 @@ class Base(nn.Module):
 				elif resps_list is not None and resps_list[i] is not None:
 					# yes this could be encoded better
 					inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) )
+			# Speech-to-Text prediction task
+			# Sequence:
+			elif task_type == "stt":
+				# insert the input response
+				if resps_list is not None and resps_list[i] is not None:
+					inputs[i].append( ( "resp", resps_list[i] ) )
+				# insert lang token if we're trained for it
+				if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
+					inputs[i].append( ( "lang", lang_list[i] ) )
+				# insert RVQ level guidance token if the model is versioned for it
+				if self.rvq_l_emb is not None and not self.interleave:
+					inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
+				# insert the output text prompt
+				if text_list is not None and text_list[i] is not None:
+					inputs[i].append( ( "text", text_list[i] ) )
 			else:
 				raise Exception(f'Unrecognized task: {task_type}')
 
@@ -1010,6 +1028,8 @@ class Base(nn.Module):
 		if not token_dropout_rvq_levels:
 			token_dropout_rvq_levels = [1, self.resp_levels]
 
+		summed_embeddings_task = [ "stt" ]
+
 		x_list = []
 		for batch_index, batch_input in enumerate(inputs):
 			batch = []
@@ -1071,7 +1091,13 @@ class Base(nn.Module):
							offset = 0,
							quant_level = 0,
						)
-
+					# cheat-y way to handle performing STT across all levels
+					elif task_type in summed_embeddings_task:
+						embedding = sum([ self.resps_emb(
+							input[:, :l+1],
+							offset = 0 if l == 0 else 1, # or maybe set to 1
+							quant_level = l
+						) for l in range( input.shape[-1] - 1 ) ])
 					else:
 						# get RVQ level 0, or up to targetted RVQ level inference
 						if self.version <= 4:
@@ -1171,7 +1197,9 @@ class Base(nn.Module):
 		quant_levels: int | list[int] | Tensor | None = None,
 	):
 		device = logits[0].device
-		classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels ) ]
+		special_tasks = [ "len", "stt" ]
+		summed_embeddings_task = [ "stt" ]
+		classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
 
 		# handles tasks where the prompt has task tokens injected in the middle
 		def prompt_input_to_token( input, quant_level ):
@@ -1192,8 +1220,10 @@ class Base(nn.Module):
 		for batch_index, batch in enumerate(inputs):
 			quant_level = quant_levels[batch_index]
 			target = []
+			task_type = "tts"
 			for name, input in batch:
 				if name == "task":
+					task_type = input
 					task_list.append( input )
 				elif name == "prom":
 					proms = [ input ] if isinstance(input, torch.Tensor) else input
@@ -1201,6 +1231,8 @@ class Base(nn.Module):
 				elif name == "resp":
 					if self.interleave:
 						target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
+					elif task_type in summed_embeddings_task:
+						target.append( torch.full_like(input[..., 0], self.ignore_index) )
 					else:
 						target.append( input if input.dim() == 1 else input[:, quant_level] )
 				elif name in ["text", "quant_level", "lang", "tone", "len"]:
@@ -1273,7 +1305,12 @@ class Base(nn.Module):
 			for name, input in batch:
 				# do not use resp
 				if name == "resp":
-					input = input if input.dim() == 1 else input[:, quant_level]
+					if self.interleave:
+						input = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
+					elif task_type in summed_embeddings_task:
+						input = torch.full_like(input[..., 0], self.ignore_index)
+					else:
+						input = input if input.dim() == 1 else input[:, quant_level]
 				# select prom level
 				elif name == "prom":
 					proms = [ input ] if isinstance(input, torch.Tensor) else input
@@ -1383,7 +1420,8 @@ class Base(nn.Module):
 		)
 
 		if self.classifiers is not None:
-			classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels ) ]
+			special_tasks = [ "len", "stt" ]
+			classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
 			x = self.classifiers(x, levels = classifier_quant_levels) * m
 
 		# Remove padding
@@ -1402,7 +1440,7 @@ class Base(nn.Module):
 	def sample(
 		self,
 		logits: list[Tensor], # logit scores
-		resps_list: list[Tensor], # previous tokens
+		prev_list: list[Tensor], # previous tokens
 		quant_levels: int | list[int] | Tensor | None = None,
 		# base sampling parameters
 		temperature: float = 1.0,
@@ -1429,7 +1467,7 @@ class Base(nn.Module):
 		# (NAR) return the entire generated response
 		# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
 		if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
-			logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
+			logits = [ logit[-l:] for logit, l in zip(logits, map(len, prev_list)) ]
 		# (AR chunkwise) return the last chunkwise piece
 		elif self.causal:
 			logits = [ logit[-self.causal_size:] for logit in logits ]
@@ -1439,22 +1477,22 @@ class Base(nn.Module):
 
 		# (NAR) disable stop token
 		if quant_levels is not None and "ar" in self.capabilities:
-			logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, resps_list) ) ]
+			logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, prev_list) ) ]
 		# (AR-len) disable extraneous tokens
 		if quant_levels is None and "len" in self.capabilities:
-			logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, resps_list) ) ]
+			logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, prev_list) ) ]
 
 		# argmax instead
 		if temperature <= 0.0:
 			return [ logit.argmax(dim=1) for logit in logits ]
 
 		# perform repetition penalizing
-		if "len" not in self.capabilities:
-			logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
+		if "len" not in self.capabilities and repetition_penalty != 1.0:
+			logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, prev_list ) ]
 
 		# (AR) perform length penalizing
 		if quant_levels is None and self.causal:
-			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
+			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
 
 		# perform top_k/top_p filtering of our logits
 		if top_k > 0 or top_p < 1.0:
@@ -1469,7 +1507,7 @@ class Base(nn.Module):
 
 		# do DRY sampling
 		if dry_multiplier > 0.0:
-			logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
+			logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, prev_list ) ]
 
 		# do mirostat sampling
 		# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 87f0fb4..5ba87ce 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -183,7 +183,7 @@ class NAR(Base):
 
 			resps_list = super().sample(
 				logits=logits,
-				resps_list=prev_list,
+				prev_list=prev_list,
 				quant_levels=quant_levels,
 
 				temperature=1.0 if n == 0 else sampling_temperature,