From ce8bb1e4f7623e491cd7aad271a7f5f7b503585f Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 27 Jul 2024 15:36:05 -0500 Subject: [PATCH] sanity cleanups with weird off-by-one-ness, cleaned up and validated vall_e.models.experimental works again --- vall_e/data.py | 276 ++++++++++++++++++++++++---------- vall_e/models/ar_nar.py | 56 +++---- vall_e/models/experimental.py | 174 +++++++++++++-------- vall_e/models/nar.py | 82 ++++++---- vall_e/train.py | 112 +++----------- vall_e/utils/utils.py | 3 + 6 files changed, 398 insertions(+), 305 deletions(-) diff --git a/vall_e/data.py b/vall_e/data.py index fb6d1c1..f4b1db4 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -34,6 +34,9 @@ _logger = logging.getLogger(__name__) # fold into a typical LLM sequence (one embedding rather than split embeddings) def fold_inputs( text_list = [], + lang_list = [], + task_list = [], + tone_list = [], prom_list = [], resp_list = [], targ_list = [], @@ -42,12 +45,13 @@ def fold_inputs( sep = 3, stop = 3, + config = None, - text_tokens = 256, - audio_tokens = 1024, - audio_rvq_levels = cfg.model.max_levels, quant_levels = None, ): + if config is None: + config = cfg.model + def _create_mask(l, device): seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t) stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1) @@ -61,108 +65,176 @@ def fold_inputs( m = m.to(x) return x, m + def process_prom_or_task(i, prom): + if prom is None: + return + + if isinstance(prom, str): + task = get_task_symmap()[f'<{input}>'] + seq = torch.Tensor([task_start + task]).to(device=device, dtype=dtype) + + input_ids[i].append( seq ) + input_ids[i].append( sep ) + return + + # deinterleaved + if quant_levels is not None: + quant_level = quant_levels[i] + if ignore_index is not None: + seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to(device=device, dtype=dtype) + else: + seq = prom[:, quant_level].to(device=device, dtype=dtype).clone() + for idx, token in enumerate( seq ): + token += prom_start + ( config.audio_tokens * quant_level ) + # interleaved + else: + if ignore_index is not None: + seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to(device=device, dtype=dtype) + else: + seq = prom.flatten().to(device=device, dtype=dtype) + for idx, token in enumerate( seq ): + token += prom_start + ( config.audio_tokens * ( idx % config.resp_levels ) ) + + input_ids[i].append( seq ) + input_ids[i].append( sep ) + + """ + if quant_levels is not None: + resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ] + """ + device = text_list[0].device + dtype = torch.int64 + batch_size = len(text_list) input_ids = [ [] for _ in range(batch_size) ] offset = 0 - sep = torch.Tensor([ sep ]) - stop = torch.Tensor([ stop ]) + sep = torch.Tensor([ sep ]).to(device=device, dtype=dtype) + stop = torch.Tensor([ stop ]).to(device=device, dtype=dtype) + text_start = 0 + text_end = text_start + config.text_tokens + + lang_start = text_end + lang_end = lang_start + config.langs + + rvq_start = lang_end + rvq_end = rvq_start + config.resp_levels + + prom_start = rvq_end + prom_end = prom_start + config.audio_tokens * config.resp_levels + + task_start = prom_end + task_end = task_start + config.tasks + + tone_start = task_end + tone_end = tone_start + config.tones + + resp_start = tone_end + resp_end = resp_start + config.audio_tokens * config.resp_levels + + # text tokens for i, text in enumerate(text_list): - seq = text.to("cpu", dtype=torch.int64) + if 
isinstance(text, torch.Tensor): + seq = text + text_start + else: + seq = torch.Tensor([text_start + text]).to(device=device, dtype=dtype) + input_ids[i].append( seq ) + input_ids[i].append( sep ) + + # lang tokens + for i, lang in enumerate(lang_list): + if isinstance(lang, torch.Tensor): + seq = lang + lang_start + else: + seq = torch.Tensor([lang_start + lang]).to(device=device, dtype=dtype) input_ids[i].append( seq ) input_ids[i].append( sep ) - offset = text_tokens # inject target quant_level if quant_levels is not None: for i, rvq in enumerate( quant_levels ): - seq = torch.Tensor([offset + rvq]).to("cpu", dtype=torch.int64) + if isinstance(rvq, torch.Tensor): + seq = rvq + rvq_start + else: + seq = torch.Tensor([rvq_start + rvq]).to(device=device, dtype=dtype) input_ids[i].append( seq ) input_ids[i].append( sep ) - offset = text_tokens + audio_rvq_levels + # prom / task tokens for i, prom in enumerate(prom_list): - # deinterleaved - if quant_levels is not None: - quant_level = quant_levels[i] - if ignore_index is not None: - seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] ) ] ).to("cpu", dtype=torch.int64) - else: - seq = prom[:, quant_level].to("cpu", dtype=torch.int64) - for idx, token in enumerate( seq ): - token += offset + ( audio_tokens * quant_level ) - # interleaved + # list of proms with a possible task token + if isinstance(prom, list): + for p in prom: + process_prom_or_task(i, p) + # raw tensor else: - if ignore_index is not None: - seq = torch.Tensor( [ ignore_index for _ in range( prom.shape[0] * prom.shape[1] ) ] ).to("cpu", dtype=torch.int64) - else: - seq = prom.flatten().to("cpu", dtype=torch.int64) - for idx, token in enumerate( seq ): - token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) ) + process_prom_or_task(i, prom) + # tone tokens + for i, tone in enumerate(tone_list): + if isinstance(tone, torch.Tensor): + seq = tone + tone_start + else: + seq = torch.Tensor([tone_start + tone]).to(device=device, dtype=dtype) input_ids[i].append( seq ) input_ids[i].append( sep ) - - offset = text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels) + # resp tokens for i, resp in enumerate(resp_list): # deinterleaved if quant_levels is not None: # grab the previous rvq level quant_level = quant_levels[i] - 1 # way to signal we want to inference for rvq level 0 - # without it, it's a random chance for any level to be selected again - + # without it, it's a random chance for any level to be selected again if quant_level < 0: continue - - seq = sep else: # my shitcode keeps things as lists of tensors for each level, so this handles it because lists can't index by tuples if isinstance(resp, list): - seq = resp[quant_level].to("cpu", dtype=torch.int64) + seq = resp[quant_level].to(device=device, dtype=dtype).clone() else: - seq = resp[:, quant_level].to("cpu", dtype=torch.int64) + seq = resp[:, quant_level].to(device=device, dtype=dtype).clone() for idx, token in enumerate( seq ): - token += offset + ( audio_tokens * quant_level ) - + token += resp_start + ( config.audio_tokens * quant_level ) input_ids[i].append( seq ) input_ids[i].append( stop ) # interleaved else: - seq = resp.flatten().to("cpu", dtype=torch.int64) + seq = resp.flatten().to(device=device, dtype=dtype) for idx, token in enumerate( seq ): - token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) ) + token += resp_start + ( config.audio_tokens * ( idx % config.resp_levels ) ) input_ids[i].append( seq ) input_ids[i].append( stop ) + # targ list for i, resp in 
enumerate(targ_list): # deinterleaved if quant_levels is not None: quant_level = quant_levels[i] - seq = resp[:, quant_level].to("cpu", dtype=torch.int64) + seq = resp[:, quant_level].to(device=device, dtype=dtype) for idx, token in enumerate( seq ): - token += offset + ( audio_tokens * quant_level ) + token += resp_start + ( config.audio_tokens * quant_level ) input_ids[i].append( seq ) input_ids[i].append( stop ) # interleaved else: - seq = resp.flatten().to("cpu", dtype=torch.int64) + seq = resp.flatten().to(device=device, dtype=dtype) for idx, token in enumerate( seq ): - token += offset + ( audio_tokens * ( idx % audio_rvq_levels ) ) + token += resp_start + ( config.audio_tokens * ( idx % config.resp_levels ) ) input_ids[i].append( seq ) input_ids[i].append( stop ) for i, batch in enumerate(input_ids): - input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=torch.int64) + input_ids[i] = torch.concat(input_ids[i], dim=-1).to(device=device, dtype=dtype) return list_to_tensor(input_ids) @@ -174,20 +246,62 @@ def unfold_outputs( sep = 3, stop = 3, - text_tokens = 256, - audio_tokens = 1024, - audio_rvq_levels = cfg.model.max_levels, + config = None, quant_levels = None, ): + def bin_to_rvqs( tokens ): + length = len(tokens) + """ + if length % config.resp_levels == 0: + tokens = torch.Tensor(tokens).reshape( config.resp_levels, length // config.resp_levels ).t() + """ + bins = [ [] for _ in range(config.resp_levels) ] + for pos in range( length ): + rvq = pos % config.resp_levels + bins[rvq].append( tokens[pos] ) + nearest = ( len(bins) // config.resp_levels ) * config.resp_levels + bins = bins[:nearest] + return torch.Tensor(bins).t().to(device=device, dtype=dtype) + + if config is None: + config = cfg.model + device = output_ids.device + dtype = torch.int64 + batch_size = output_ids.shape[0] text_list = [ [] for _ in range(batch_size) ] + rvq_list = [ [] for _ in range(batch_size) ] + lang_list = [ [] for _ in range(batch_size) ] + task_list = [ [] for _ in range(batch_size) ] + tone_list = [ [] for _ in range(batch_size) ] prom_list = [ [] for _ in range(batch_size) ] resp_list = [ [] for _ in range(batch_size) ] + text_start = 0 + text_end = text_start + config.text_tokens + + lang_start = text_end + lang_end = lang_start + config.langs + + rvq_start = lang_end + rvq_end = rvq_start + config.resp_levels + + prom_start = rvq_end + prom_end = prom_start + config.audio_tokens * config.resp_levels + + task_start = prom_end + task_end = task_start + config.tasks + + tone_start = task_end + tone_end = tone_start + config.tones + + resp_start = tone_end + resp_end = resp_start + config.audio_tokens * config.resp_levels + for i, batch in enumerate( output_ids ): - # crigne logic to handle prefix resp for rvq levels > 0 + # cringe logic to handle prefix resp for rvq levels > 0 # a better way is to observe if the rvq level increased should_flush = False flushed = False @@ -201,49 +315,51 @@ def unfold_outputs( continue - if 0 <= id and id < text_tokens: - text_list[i].append( id ) - elif text_tokens + audio_rvq_levels <= id and id < text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels): - prom_list[i].append( (id - text_tokens - audio_rvq_levels) % audio_tokens ) - elif text_tokens + audio_rvq_levels + (audio_tokens * audio_rvq_levels) <= id: - resp_list[i].append( (id - text_tokens - audio_rvq_levels) % audio_tokens ) + # text tokens + if text_start <= id and id < text_end: + text_list[i].append( (id - text_start) % config.text_tokens ) + # lang tokens + elif 
lang_start <= id and id < lang_end: + lang_list[i].append( (id - lang_start) % config.langs ) + # rvq levels + elif rvq_start <= id and id < rvq_end: + rvq_list[i].append( (id - rvq_start) % config.resp_levels ) + # prom tokens + elif prom_start <= id and id < prom_end: + prom_list[i].append( (id - prom_start) % config.audio_tokens ) + # task tokens + elif task_start <= id and id < task_end: + task_list[i].append( (id - task_start) % config.tasks ) + # lang tokens + elif tone_start <= id and id < tone_end: + tone_list[i].append( (id - tone_start) % config.tones ) + # resp tokens + elif resp_start <= id and id < resp_end: + resp_list[i].append( (id - resp_start) % config.audio_tokens ) + if not flushed: should_flush = True if quant_levels is not None: - prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=torch.int64) - resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=torch.int64) + prom_list[i] = torch.Tensor(prom_list[i]).t().to(device=device, dtype=dtype) + resp_list[i] = torch.Tensor(resp_list[i]).t().to(device=device, dtype=dtype) else: - prom_len = len(prom_list[i]) - if prom_len % audio_rvq_levels == 0 and False: - prom_list[i] = torch.Tensor(prom_list[i]).reshape( audio_rvq_levels, prom_len // audio_rvq_levels ).t() - else: - bins = [ [] for _ in range(audio_rvq_levels) ] - for pos in range( prom_len ): - rvq = pos % audio_rvq_levels - bins[rvq].append( prom_list[i][pos] ) - nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels - bins = bins[:nearest] - prom_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64) + prom_list[i] = bin_to_rvqs( prom_list[i] ) + resp_list[i] = bin_to_rvqs( resp_list[i] ) - resp_len = len(resp_list[i]) - if len(resp_list[i]) % audio_rvq_levels == 0 and False: - resp_list[i] = torch.Tensor(resp_list[i]).reshape( audio_rvq_levels, resp_len // audio_rvq_levels ).t() - else: - bins = [ [] for _ in range(audio_rvq_levels) ] - for pos in range( resp_len ): - rvq = pos % audio_rvq_levels - bins[rvq].append( resp_list[i][pos] ) - nearest = ( len(bins) // audio_rvq_levels ) * audio_rvq_levels - bins = bins[:nearest] - resp_list[i] = torch.Tensor(bins).t().to(device=device, dtype=torch.int64) - - text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=torch.int64) + text_list[i] = torch.Tensor( text_list[i] ).to(device=device, dtype=dtype) + task_list[i] = torch.Tensor( task_list[i] ).to(device=device, dtype=dtype) + lang_list[i] = torch.Tensor( lang_list[i] ).to(device=device, dtype=dtype) + tone_list[i] = torch.Tensor( tone_list[i] ).to(device=device, dtype=dtype) return dict( text_list=text_list, prom_list=prom_list, - resp_list=resp_list + resp_list=resp_list, + + task_list=task_list, + lang_list=lang_list, + tone_list=tone_list, ) # to-do: clean up this symmap mess @@ -1072,7 +1188,7 @@ def _create_dataloader(dataset, training): shuffle=False, batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size, drop_last=training, - sampler=dataset.sampler, + sampler=dataset.sampler if training else None, ) if not isinstance(dataset.sampler, BatchedOrderedSampler) else dict( batch_sampler=dataset.sampler, ) diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 68d9282..e3020dd 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -137,50 +137,35 @@ class AR_NAR(Base): # is training if training: + # specifies how to sample probabilities of which RVQ levels to train against p_rvq_levels = self.config.experimental.p_rvq_levels if self.config 
is not None else "equal" - # determines which RVQ level to target per batch - quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ] - + quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ] + # rate to perform token dropout errors token_dropout_error = self.config.experimental.token_dropout_error + # RVQ levels to apply token dropout on token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels + # implicitly set it to all levels if not token_dropout_rvq_levels: - token_dropout_rvq_levels = [0, self.resp_levels] - - if p_rvq_levels == "equal": + token_dropout_rvq_levels = [0, self.resp_levels - 1] + # allow passing a specific distribution of RVQ levels + p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else [] + if not p_rvq_levels: + lo, hi = quant_level_range[0], quant_level_range[1] + 1 # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) - quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ] - else: # if p_rvq_levels == "auto": - # makes higher levels less likely - """ - def generate( lo=0, hi=8 ): - index = lo - p = random.random() - for i in range(lo, hi): - if p < 1.0 / (2 ** i): - index = i - return int(index) - """ - - # allow passing a specific distribution of RVQ levels - pool = p_rvq_levels if isinstance(p_rvq_levels, list) else [] - if not pool: - lo, hi = quant_level_range[0], quant_level_range[1] - for i in range( lo, hi ): - rep = hi - i - pool += [i] * rep - - quant_levels = [ random.choice( pool ) for i in range(batch_size) ] - - # these two are techinically equivalent if the audio embeddings handle things properly - """ - resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)] - stop_sequence = torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) - """ + if p_rvq_levels == "equal": + p_rvq_levels = [ i for i in range( lo, hi ) ] + else: + # yuck + p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], []) + # input RVQ levels + quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ] + # trim resps to only contain all levels below the target level resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)] + # tensor to cat for RVQ level 0 stop_sequence = torch.Tensor([[self.stop_token] * 1]).to(device=device, dtype=torch.int16) - + # I hate python's value/reference semantics so much for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list): # cap quant_level if it exceeds its corresponding resp/prom if quant_level >= resps.shape[-1]: @@ -213,7 +198,6 @@ class AR_NAR(Base): # only apply stop token for RVQ level 0 if quant_level <= 0: # append stop tokens for AR - # could technically do it in the .inputs call resps_list[i] = torch.cat([ resps, stop_sequence ]) diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py index ada1753..cabd7ca 100644 --- a/vall_e/models/experimental.py +++ b/vall_e/models/experimental.py @@ -57,7 +57,30 @@ class Model(LlmArchClass): hf_attention = config.attention if config is not None else None gradient_checkpointing = config.gradient_checkpointing if config is not None else True # text_tokens + rvq levels + [audio tokens 
* codebooks] (prom) + [audio tokens * codebooks] (resp) + stop - vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1 + # vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1 + + text_start = 0 + text_end = text_start + config.text_tokens + + lang_start = text_end + lang_end = lang_start + config.langs + + rvq_start = lang_end + rvq_end = rvq_start + config.resp_levels + + prom_start = rvq_end + prom_end = prom_start + config.audio_tokens * config.resp_levels + + task_start = prom_end + task_end = task_start + config.tasks + + tone_start = task_end + tone_end = tone_start + config.tones + + resp_start = tone_end + resp_end = resp_start + config.audio_tokens * config.resp_levels + + vocab_size = resp_end if cfg.model.arch_type == "llama": super().__init__(config=LlamaConfig( @@ -148,11 +171,94 @@ class Model(LlmArchClass): *args, **kwargs, ): - if cfg.model.arch_type in ["mamba","mamba2"]: + config = self.hyper_config + + if "text_list" in kwargs: + text_list = kwargs.pop("text_list", None) + proms_list = kwargs.pop("proms_list", None) + resps_list = kwargs.pop("resps_list", None) + lang_list = kwargs.pop("lang_list", None) + tone_list = kwargs.pop("tone_list", None) + + training = kwargs.pop("training", False) + steps = kwargs.pop("steps", 500) + + batch_size = len(text_list) + + if training: + quant_levels = None if config.experimental.interleave else [ random.randint( 0 if "ar" in config.capabilities else 1, config.max_levels - 1) for _ in range(batch_size) ] + + input_ids, attention_mask = fold_inputs( + text_list=text_list, + prom_list=proms_list, + resp_list=resps_list, + targ_list=resps_list, + quant_levels=quant_levels, + ) + target_ids, target_attention_mask = fold_inputs( + text_list=text_list, + prom_list=proms_list, + resp_list=resps_list, + targ_list=resps_list, + quant_levels=quant_levels, + ignore_index=-100 + ) + return self.forward( + input_ids=input_ids, + labels=target_ids, + ) + + if config.experimental.interleave: + input_ids, attention_mask = fold_inputs( text_list=text_list, prom_list=proms_list ) + output = self.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps*config.max_levels, eos_token_id=3, do_sample=False) + return unfold_outputs( output )["resp_list"] + + resps_list = [ [] for _ in range(batch_size) ] + for l in range(config.max_levels): + quant_levels = [ l for _ in range(batch_size) ] + + input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=proms_list, resp_list=resps_list, quant_levels=quant_levels) + min_length = 1 + for batch in input_ids: + min_length = max( min_length, batch.shape[0] + 1 ) + + output = self.generate( + input_ids=input_ids, + attention_mask=attention_mask, + min_length=min_length, + max_length=min_length+steps*2, + eos_token_id=3, + do_sample=False + ) + + unfolded = unfold_outputs( output, quant_levels=quant_levels ) + + if l == 0: + steps = 0 + + for batch, resp in enumerate(unfolded["resp_list"]): + length = resp.shape[-1] + + # store length + if l == 0: + steps = max( steps, length ) + # pad + else: + resp = resp[:steps] + if length < steps: + resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ]) + resps_list[batch].append( resp ) + + for i, resp in enumerate( resps_list ): + resps_list[i] = torch.stack( resp ).t() + + return resps_list + + if config.arch_type in ["mamba","mamba2"]: if 
"attention_mask" in kwargs: kwargs.pop("attention_mask") - labels = kwargs.pop("labels") if "labels" in kwargs else None + labels = kwargs.pop("labels", None) output = super().forward(*args, **kwargs) logits = output.logits @@ -322,53 +428,8 @@ def example_usage(): @torch.inference_mode() def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ): engine.eval() - batch_size = len(text_list) - resp_list = None - if cfg.model.experimental.interleave: - input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list) - output = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=steps, eos_token_id=3, do_sample=False) - - unfolded = unfold_outputs( output ) - resp_list = unfolded["resp_list"] - else: - resp_list = [ [] for _ in range(batch_size) ] - for l in range(cfg.model.max_levels): - quant_levels = [ l for _ in range(batch_size) ] - - input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, quant_levels=quant_levels) - min_length = 1 - for batch in input_ids: - min_length = max( min_length, batch.shape[0] + 1 ) - - output = model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - min_length=min_length, - max_length=min_length+steps*2, - eos_token_id=3, - do_sample=False - ) - - unfolded = unfold_outputs( output, quant_levels=quant_levels ) - - if l == 0: - steps = 0 - - for batch, resp in enumerate(unfolded["resp_list"]): - length = resp.shape[-1] - - # store length - if l == 0: - steps = max( steps, length ) - # pad - else: - resp = resp[:steps] - if length < steps: - resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ]) - resp_list[batch].append( resp ) - - for i, resp in enumerate( resp_list ): - resp_list[i] = torch.stack( resp ).t() + + resp_list = model( text_list=text_list, proms_list=prom_list ) for i, batch in enumerate(resp_list): _ = decode_to_file(batch.to(device=device), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device) @@ -380,19 +441,8 @@ def example_usage(): t = trange(steps) for i in t: stats = {"step": i} - - batch_size = len(text_list) - quant_levels = None if cfg.model.experimental.interleave else torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,)) - if quant_levels is not None: - resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ] - else: - resps_list = [ resp for resp in resp_list ] - - input_ids, attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resps_list, targ_list=resp_list, quant_levels=quant_levels) - target_ids, target_attention_mask = fold_inputs(text_list=text_list, prom_list=prom_list, resp_list=resp_list, targ_list=resp_list, ignore_index=-100, quant_levels=quant_levels) - - stats |= engine.traverse(input_ids=input_ids, labels=target_ids, attention_mask=attention_mask) + stats |= engine.traverse(text_list=text_list, proms_list=prom_list, resps_list=resp_list, training=True) stats |= engine.gather_attribute("stats") stats |= {"grad_norm": engine.get_global_grad_norm()} diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index 571ccdd..c231c20 100644 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -133,46 +133,62 @@ class NAR(Base): # generate task list to train against task_list = [ sample_task() for _ in range(batch_size) ] - # determines which RVQ level to target per batch - quant_level_range = self.config.experimental.rvq_level_range if self.config 
is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ] - + # specifies how to sample probabilities of which RVQ levels to train against p_rvq_levels = self.config.experimental.p_rvq_levels if self.config is not None else "equal" - - if p_rvq_levels == "equal": + # determines which RVQ level to target per batch + quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels - 1 ] + # rate to perform token dropout errors + token_dropout_error = self.config.experimental.token_dropout_error + # RVQ levels to apply token dropout on + token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels + # implicitly set it to all levels + if not token_dropout_rvq_levels: + token_dropout_rvq_levels = [0, self.resp_levels - 1] + # allow passing a specific distribution of RVQ levels + p_rvq_levels = p_rvq_levels if isinstance(p_rvq_levels, list) else [] + if not p_rvq_levels: + lo, hi = quant_level_range[0], quant_level_range[1] + 1 # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) - quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ] - else: # if p_rvq_levels == "auto": - # makes higher levels less likely - """ - def generate( lo=0, hi=8 ): - index = lo - p = random.random() - for i in range(lo, hi): - if p < 1.0 / (2 ** i): - index = i - return int(index) - """ + if p_rvq_levels == "equal": + p_rvq_levels = [ i for i in range( lo, hi ) ] + else: + # yuck + p_rvq_levels = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], []) - # allow passing a specific distribution of RVQ levels - pool = p_rvq_levels if isinstance(p_rvq_levels, list) else [] - if not pool: - lo, hi = quant_level_range[0], quant_level_range[1] - for i in range( lo, hi ): - rep = hi - i - pool += [i] * rep + # input RVQ levels + quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ] + # trim resps to only contain all levels below the target level + resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)] - quant_levels = [ random.choice( pool ) for i in range(batch_size) ] - - # clamp quant_levels because some of my audio was saved for only 8 out of 9 RVQ levels for DAC... 
- for i in range(batch_size): + # I hate python's value/reference semantics so much + for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list): # cap quant_level if it exceeds its corresponding resp/prom - if quant_levels[i] >= resps_list[i].shape[-1]: - quant_levels[i] = resps_list[i].shape[-1] - 1 + if quant_level >= resps.shape[-1]: + quant_levels[i] = resps.shape[-1] - 1 - if quant_levels[i] >= proms_list[i].shape[-1]: - quant_levels[i] = proms_list[i].shape[-1] - 1 + # proms could be a Tensor, list[Tensor], or None + if isinstance( proms, torch.Tensor ): + if quant_level >= proms.shape[-1]: + quant_levels[i] = proms.shape[-1] - 1 - resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)] + elif isinstance( proms, list ): + for j, prom in enumerate( proms ): + if not isinstance( prom, torch.Tensor ): + continue + + if quant_level >= prom.shape[-1]: + quant_levels[i] = prom.shape[-1] - 1 + + # apply token dropout error compensation + if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]): + steps = resps.shape[0] + for l in range( quant_level ): + for t in range( steps ): + token = resps[t, l].item() + + if random.random() < token_dropout_error: + offset = 1 * ( 1 if random.random() < 0.5 else -1 ) + resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1 inputs = self.inputs( text_list=text_list, diff --git a/vall_e/train.py b/vall_e/train.py index 230b2d7..e075f7b 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -31,44 +31,16 @@ def train_feeder(engine, batch): batch_size = len(batch["text"]) engine.current_batch_size = batch_size - if engine.hyper_config.experimental.hf: - if engine.hyper_config.experimental.interleave: - quant_levels = 0 - resps_list = [ resp for resp in batch["resps"] ] - else: - quant_levels = [ random.randint( 0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels) for _ in range(batch_size) ] - resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ] + engine( + text_list=batch["text"], + proms_list=batch["proms"], + resps_list=batch["resps"], + lang_list=batch["lang"], + tone_list=batch["tone"], + task_list=batch["task"], - input_ids, attention_mask = fold_inputs( - text_list=batch["text"], - prom_list=batch["proms"], - resp_list=resps_list, - targ_list=batch["resps"], - quant_levels=quant_levels, - ) - target_ids, target_attention_mask = fold_inputs( - text_list=batch["text"], - prom_list=batch["proms"], - resp_list=resps_list, - targ_list=batch["resps"], - quant_levels=quant_levels, - ignore_index=-100 - ) - engine( - input_ids=input_ids, - labels=target_ids, - ) - else: - engine( - text_list=batch["text"], - proms_list=batch["proms"], - resps_list=batch["resps"], - lang_list=batch["lang"], - tone_list=batch["tone"], - task_list=batch["task"], - - training=True, - ) + training=True, + ) losses = engine.gather_attribute("loss") stat = engine.gather_attribute("stats") @@ -137,66 +109,18 @@ def run_eval(engines, eval_name, dl): engine = engines[name] if engine.hyper_config.experimental.hf: - if engine.hyper_config.experimental.interleave: - input_ids, attention_mask = fold_inputs( - text_list=batch["text"], - prom_list=batch["proms"], - ) - output = engine.module.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=cfg.evaluation.steps, eos_token_id=3, do_sample=False) - resps_list = unfold_outputs( output )["resp_list"] - else: - steps = 
cfg.evaluation.steps - resps_list = [ [] for _ in range(len(text_list)) ] - for l in range(cfg.model.max_levels): - quant_levels = [ [ l ] for _ in range(len(text_list)) ] - - input_ids, attention_mask = fold_inputs(text_list=batch["text"], prom_list=batch["proms"], resp_list=resps_list, quant_levels=quant_levels, experimental=True) - min_length = 1 - for batch in input_ids: - min_length = max( min_length, batch.shape[0] + 1 ) - - output = model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - min_length=min_length, - max_length=min_length+steps*(2 if l > 0 else 1), - eos_token_id=3, - do_sample=False - ) - - unfolded = unfold_outputs( output, quant_levels=quant_levels ) - - if l == 0: - steps = 0 - - for batch, resp in enumerate(unfolded["resp_list"]): - length = resp.shape[-1] - - # store length - if l == 0: - steps = max( steps, length ) - # pad - else: - resp = resp[:steps] - if length < steps: - resp = torch.cat([ resp, torch.Tensor([ 0 for _ in range(steps-length) ]).to(resp) ]) - - resps_list[batch].append( resp ) - - for i, resp in enumerate( resps_list ): - resps_list[i] = torch.stack( resp ).t() + resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"] ) + elif "len" in engine.hyper_config.capabilities: + len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=10 ) # don't need more than that + resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list, max_levels=cfg.evaluation.nar_levels ) else: - if "len" in engine.hyper_config.capabilities: - len_list = engine(text_list=batch["text"], proms_list=batch["proms"], max_steps=10 ) # don't need more than that - resps_list = engine( text_list=batch["text"], proms_list=batch["proms"], len_list=len_list, max_levels=cfg.evaluation.nar_levels ) + if "ar" in engine.hyper_config.capabilities: + resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature) else: - if "ar" in engine.hyper_config.capabilities: - resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature) - else: - resps_list = [ resp[:, 0] for resp in batch["resps"] ] + resps_list = [ resp[:, 0] for resp in batch["resps"] ] - if "nar" in engine.hyper_config.capabilities: - resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels ) + if "nar" in engine.hyper_config.capabilities: + resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature, max_levels=cfg.evaluation.nar_levels ) process( name, batch, resps_list ) diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py index 37d8ac4..d8199f3 100755 --- a/vall_e/utils/utils.py +++ b/vall_e/utils/utils.py @@ -178,3 +178,6 @@ def to_device(x: T | None, *args, **kwargs) -> T: return return tree_map(lambda t: t.to(*args, **kwargs), x) + +def coalese( *arg, return_last=True ): + return [ x for x in arg if x is not None ][-1 if return_last else 0] \ No newline at end of file
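
For reference, fold_inputs()/unfold_outputs() now pack every modality into contiguous id ranges of one flat vocabulary, in the order text, lang, RVQ-level, prom, task, tone, resp, with vocab_size = resp_end; models/experimental.py sizes its embedding table from the same layout. What follows is a minimal, self-contained sketch of that layout (it is not part of the patch); DummyConfig and its sizes are assumed placeholders, the real values come from cfg.model.

    from dataclasses import dataclass

    @dataclass
    class DummyConfig:
        # Hypothetical stand-in for cfg.model; sizes here are assumptions for illustration.
        text_tokens: int = 256
        langs: int = 4
        resp_levels: int = 8
        audio_tokens: int = 1024
        tasks: int = 8
        tones: int = 1

    def vocab_ranges(config):
        # Mirrors the start/end offsets computed in fold_inputs()/unfold_outputs().
        sizes = [
            ("text", config.text_tokens),
            ("lang", config.langs),
            ("rvq",  config.resp_levels),
            ("prom", config.audio_tokens * config.resp_levels),
            ("task", config.tasks),
            ("tone", config.tones),
            ("resp", config.audio_tokens * config.resp_levels),
        ]
        ranges, offset = {}, 0
        for name, size in sizes:
            ranges[name] = (offset, offset + size)
            offset += size
        return ranges, offset  # offset == resp_end == vocab_size

    def fold_prom_token(token, quant_level, ranges, config):
        # Deinterleaved path: each quant level gets its own block of audio_tokens ids.
        return ranges["prom"][0] + config.audio_tokens * quant_level + token

    def unfold_token(folded_id, ranges, config):
        # Mirrors unfold_outputs(): find the section, then take the id modulo its base size.
        base = {"text": config.text_tokens, "lang": config.langs, "rvq": config.resp_levels,
                "prom": config.audio_tokens, "task": config.tasks, "tone": config.tones,
                "resp": config.audio_tokens}
        for name, (start, end) in ranges.items():
            if start <= folded_id < end:
                return name, (folded_id - start) % base[name]
        raise ValueError("id outside the folded vocabulary")

    config = DummyConfig()
    ranges, vocab_size = vocab_ranges(config)
    folded = fold_prom_token(token=17, quant_level=2, ranges=ranges, config=config)
    assert unfold_token(folded, ranges, config) == ("prom", 17)

The AR_NAR and NAR trainers also share the same RVQ-level sampling after this change. A minimal sketch, assuming a free-standing build_level_pool helper (in the patch the pool is built inline in the forward pass):

    import random

    def build_level_pool(quant_level_range, p_rvq_levels="equal"):
        # A literal list is passed through untouched, "equal" is uniform over the range,
        # anything else repeats level i (hi - i) times so lower levels train more often.
        lo, hi = quant_level_range[0], quant_level_range[1] + 1
        if isinstance(p_rvq_levels, list) and p_rvq_levels:
            return p_rvq_levels
        if p_rvq_levels == "equal":
            return [i for i in range(lo, hi)]
        return sum([[i for _ in range(hi - i)] for i in range(lo, hi)], [])

    quant_levels = [random.choice(build_level_pool([0, 7], "auto")) for _ in range(4)]

The non-"equal" path keeps the earlier behaviour of making higher levels less likely, just via an explicit pool instead of the commented-out generate() helper.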