From 04fef5dad5052bdd9c08066a44723acb7d1b234f Mon Sep 17 00:00:00 2001 From: mrq Date: Wed, 12 Feb 2025 00:18:24 -0600 Subject: [PATCH] agony --- docs/emb.md | 2 + scripts/process_emilia.py | 120 ++----- scripts/process_libritts.py | 129 ++----- vall_e/config.py | 3 + vall_e/models/ar_nar.py | 264 +++++++++++++- vall_e/models/arch/llama.py | 4 +- vall_e/models/base.py | 692 ++++++++++++++++++++++++------------ vall_e/samplers.py | 8 +- 8 files changed, 797 insertions(+), 425 deletions(-) diff --git a/docs/emb.md b/docs/emb.md index 39bc4ae..eafb301 100644 --- a/docs/emb.md +++ b/docs/emb.md @@ -90,9 +90,11 @@ However, because this codec relies on FSQ (Finite Scalar Quantization) rather th Proposed architectures may include: * independent NAR-demasking for *all* levels, rather than FSQ level 0. * little additional code is required, as existing NAR-demasking training/inference code can be repurposed for additional levels. + * this also has the best backwards compat with vall_e.cpp, as no extra model code is required. * parallel decoding for *all* levels in one pass, rather than separate passes for each level. * some extra code would be required for orchestrating the additional decoding heads in parallel. * the decoding heads may simply be a single `nn.Linear` classifier, or additional transformer blocks. + * the former yields bad results when overfitting, the latter without an output projection head allows for overfitting. ## `transcribe.py` diff --git a/scripts/process_emilia.py b/scripts/process_emilia.py index ae1eac8..a1b7e54 100644 --- a/scripts/process_emilia.py +++ b/scripts/process_emilia.py @@ -18,55 +18,32 @@ from vall_e.config import cfg from vall_e.emb.g2p import encode as phonemize from vall_e.emb.qnt import encode as quantize, _replace_file_extension, convert_audio -def pad(num, zeroes): - return str(num).zfill(zeroes+1) - -def process_items( items, stride=0, stride_offset=0 ): - items = sorted( items ) - return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ] - -def load_audio( path, device="cuda" ): - waveform, sample_rate = torchaudio.load(path) - if waveform.shape[0] > 1: - waveform = torch.mean(waveform, dim=0, keepdim=True) - waveform = convert_audio(waveform, sample_rate, cfg.sample_rate, 1) - return waveform.to(device=device), cfg.sample_rate +from vall_e.emb.process import pad, load_audio, process_items, process_jobs def process( - audio_backend="encodec", - input_audio="Emilia", - output_dataset="training", - raise_exceptions=False, - stride=0, - stride_offset=0, - slice="auto", - - device="cuda", - dtype="float16", - amp=False, - ): - # encodec / vocos - - if audio_backend in ["encodec", "vocos"]: - audio_extension = ".enc" - cfg.sample_rate = 24_000 - cfg.model.resp_levels = 8 - elif audio_backend == "dac": - audio_extension = ".dac" - cfg.sample_rate = 44_100 - cfg.model.resp_levels = 9 - elif cfg.audio_backend == "audiodec": - sample_rate = 48_000 - audio_extension = ".dec" - cfg.model.resp_levels = 8 # ? 
- else: - raise Exception(f"Unknown audio backend: {audio_backend}") + audio_backend="encodec", + input_audio="Emilia", + output_dataset="training", + raise_exceptions=False, + stride=0, + stride_offset=0, + slice="auto", + batch_size=1, + low_memory=False, + device="cuda", + dtype="float16", + amp=False, +): # prepare from args - cfg.audio_backend = audio_backend # "encodec" + cfg.device = device + cfg.set_audio_backend(audio_backend) + audio_extension = cfg.audio_backend_extension + cfg.inference.weight_dtype = dtype # "bfloat16" cfg.inference.amp = amp # False + dtype = cfg.inference.dtype if not amp else None output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training" @@ -145,58 +122,11 @@ def process( if waveform is None: waveform, sample_rate = load_audio(inpath) - wavs.append(( - outpath, - text, - language, - waveform, - sample_rate - )) + jobs.append(( outpath, waveform, sample_rate, text, language.lower() )) - if len(wavs) > 0: - for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"): - try: - outpath, text, language, waveform, sample_rate = job - - phones = phonemize(text, language=f'{language}'.lower()) - qnt = quantize(waveform, sr=sample_rate, device=device) - - - if cfg.audio_backend == "dac": - np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.codes.cpu().numpy().astype(np.uint16), - "metadata": { - "original_length": qnt.original_length, - "sample_rate": qnt.sample_rate, - - "input_db": qnt.input_db.cpu().numpy().astype(np.float32), - "chunk_length": qnt.chunk_length, - "channels": qnt.channels, - "padding": qnt.padding, - "dac_version": "1.0.0", - - "text": text.strip(), - "phonemes": "".join(phones), - "language": language, - }, - }) - else: - np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.cpu().numpy().astype(np.uint16), - "metadata": { - "original_length": waveform.shape[-1], - "sample_rate": sample_rate, - - "text": text.strip(), - "phonemes": "".join(phones), - "language": language, - }, - }) - except Exception as e: - print(f"Failed to quantize: {outpath}:", e) - if raise_exceptions: - raise e - continue + # processes audio files one at a time + process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None ) + jobs = [] open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing)) open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset)) @@ -214,6 +144,8 @@ def main(): parser.add_argument("--stride", type=int, default=0) parser.add_argument("--stride-offset", type=int, default=0) parser.add_argument("--slice", type=str, default="auto") + parser.add_argument("--low-memory", action="store_true") + parser.add_argument("--batch-size", type=int, default=0) args = parser.parse_args() @@ -232,6 +164,8 @@ def main(): stride=args.stride, stride_offset=args.stride_offset, slice=args.slice, + batch_size=args.batch_size, + low_memory=args.low_memory, device=args.device, dtype=args.dtype, diff --git a/scripts/process_libritts.py b/scripts/process_libritts.py index 7152ac9..dbb8813 100755 --- a/scripts/process_libritts.py +++ b/scripts/process_libritts.py @@ -15,52 +15,36 @@ from pathlib import Path from vall_e.config import cfg -def pad(num, zeroes): - return str(num).zfill(zeroes+1) +from vall_e.emb.g2p import encode as phonemize +from 
vall_e.emb.qnt import encode as quantize, _replace_file_extension, convert_audio + +from vall_e.emb.process import pad, load_audio, process_items, process_jobs -def process_items( items, stride=0, stride_offset=0 ): - items = sorted( items ) - return items if stride == 0 else [ item for i, item in enumerate( items ) if (i+stride_offset) % stride == 0 ] def process( - audio_backend="encodec", - input_audio="LibriTTS_R", - output_dataset="training", - raise_exceptions=False, - stride=0, - stride_offset=0, - slice="auto", - - device="cuda", - dtype="float16", - amp=False, - ): - # encodec / vocos - - if audio_backend in ["encodec", "vocos"]: - audio_extension = ".enc" - cfg.sample_rate = 24_000 - cfg.model.resp_levels = 8 - elif audio_backend == "dac": - audio_extension = ".dac" - cfg.sample_rate = 44_100 - cfg.model.resp_levels = 9 - elif cfg.audio_backend == "audiodec": - sample_rate = 48_000 - audio_extension = ".dec" - cfg.model.resp_levels = 8 # ? - else: - raise Exception(f"Unknown audio backend: {audio_backend}") + audio_backend="encodec", + input_audio="LibriTTS_R", + output_dataset="training", + raise_exceptions=False, + stride=0, + stride_offset=0, + slice="auto", + batch_size=1, + low_memory=False, + device="cuda", + dtype="float16", + amp=False, +): # prepare from args - cfg.audio_backend = audio_backend # "encodec" + cfg.device = device + cfg.set_audio_backend(audio_backend) + audio_extension = cfg.audio_backend_extension + cfg.inference.weight_dtype = dtype # "bfloat16" cfg.inference.amp = amp # False - # import after because we've overriden the config above - # need to validate if this is even necessary anymore - from vall_e.emb.g2p import encode as phonemize - from vall_e.emb.qnt import encode as quantize, _replace_file_extension + dtype = cfg.inference.dtype if not amp else None output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training" @@ -140,62 +124,19 @@ def process( continue if waveform is None: - waveform, sample_rate = torchaudio.load(inpath) - if waveform.shape[0] > 1: - waveform = torch.mean(waveform, dim=0, keepdim=True) + waveform, sample_rate = load_audio(inpath) - wavs.append(( - outpath, - text, - language, - waveform, - sample_rate - )) + jobs.append(( outpath, waveform, sample_rate, text, language )) - if len(wavs) > 0: - for job in tqdm(wavs, desc=f"Quantizing: {speaker_id}"): - try: - outpath, text, language, waveform, sample_rate = job - - phones = phonemize(text, language=language) - qnt = quantize(waveform, sr=sample_rate, device=device) - - - if cfg.audio_backend == "dac": - np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.codes.cpu().numpy().astype(np.uint16), - "metadata": { - "original_length": qnt.original_length, - "sample_rate": qnt.sample_rate, - - "input_db": qnt.input_db.cpu().numpy().astype(np.float32), - "chunk_length": qnt.chunk_length, - "channels": qnt.channels, - "padding": qnt.padding, - "dac_version": "1.0.0", - - "text": text.strip(), - "phonemes": "".join(phones), - "language": language, - }, - }) - else: - np.save(open(_replace_file_extension(outpath, audio_extension), "wb"), { - "codes": qnt.cpu().numpy().astype(np.uint16), - "metadata": { - "original_length": waveform.shape[-1], - "sample_rate": sample_rate, - - "text": text.strip(), - "phonemes": "".join(phones), - "language": language, - }, - }) - except Exception as e: - print(f"Failed to quantize: {outpath}:", e) - if raise_exceptions: - 
raise e - continue + # processes audio files one at a time + if low_memory: + process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None ) + jobs = [] + + # processes all audio files for a given speaker + if not low_memory: + process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None ) + jobs = [] open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing)) open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset)) @@ -213,6 +154,8 @@ def main(): parser.add_argument("--stride", type=int, default=0) parser.add_argument("--stride-offset", type=int, default=0) parser.add_argument("--slice", type=str, default="auto") + parser.add_argument("--low-memory", action="store_true") + parser.add_argument("--batch-size", type=int, default=0) args = parser.parse_args() @@ -231,6 +174,8 @@ def main(): stride=args.stride, stride_offset=args.stride_offset, slice=args.slice, + batch_size=args.batch_size, + low_memory=args.low_memory, device=args.device, dtype=args.dtype, diff --git a/vall_e/config.py b/vall_e/config.py index becd445..226db10 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -259,6 +259,9 @@ class ModelExperimentalSettings: # it just seems like a bitch to try and train something worthwhile with it, since there's crackles every other token # RetNet's chunked inferencing might be a better place for this + parallel_decoding: bool = False # enables some settings to decode ALL RVQ levels in one pass + # this is a bit of a pain to get working in the test trainer + masking_train_p: float = 0.0 # odds of training with masking masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 960c566..3d34819 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -75,6 +75,11 @@ class AR_NAR(Base): token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels # RVQ levels to apply masking training on masking_train_rvq_levels = self.config.experimental.masking_train_rvq_levels + + if self.version >= 7: + masking_train_rvq_levels = [0,self.n_resp_levels] + rvq_levels_p = [ i for i in range( quant_level_range[0], quant_level_range[1] + 1 ) ] + # CFG cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0 cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0 @@ -127,7 +132,10 @@ class AR_NAR(Base): timesteps[i] = (timesteps[i] * 0.6) + 0.2 # trim resps to only contain all levels below the target level - resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)] + if self.version < 7: + resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)] + elif not self.parallel_decoding: + resps_list = [r if t in text_task else r[..., l] for r, l, t in zip(resps_list, quant_levels, task_list)] # tensor to cat for RVQ level 0 text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16) @@ -229,6 +237,7 @@ class AR_NAR(Base): tone_list: list[Tensor] | None = None, len_list: list[Tensor] | None = None, raw_text_list: list[Tensor] | None = None, + quant_levels: list[int] | None = None, disable_tqdm=False, 
use_lora=None, @@ -237,7 +246,11 @@ class AR_NAR(Base): device = text_list[0].device batch_size = len(text_list) - level = 0 + if quant_levels is None: + level = 0 + else: + level = quant_levels[0] # ugh + if cfg.lora is not None: enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora ) @@ -306,11 +319,12 @@ class AR_NAR(Base): # fill scores scores = [ torch.ones((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ] - quant_levels = [ level for _ in range(batch_size) ] + if quant_levels is None: + quant_levels = [ level for _ in range(batch_size) ] null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ] null_prom = [ None for _ in range(batch_size) ] - iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm) + iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc=f"NAR Masked Level {level}", disable=disable_tqdm) for timestep in iterator: # update previous list of tokens prev_list = resps_list @@ -430,6 +444,183 @@ class AR_NAR(Base): return resps_list + # handles doing demasking inferencing in parallel to inference all tokens + # it works if the underlying model is trained properly (which is a pain) + def forward_nar_masked_parallel( + self, + + task_list: list[Tensor] | None = None, + + text_list: list[Tensor] | None = None, + proms_list: list[Tensor] | None = None, + resps_list: list[Tensor] | None = None, + + lang_list: list[Tensor] | None = None, + tone_list: list[Tensor] | None = None, + len_list: list[Tensor] | None = None, + raw_text_list: list[Tensor] | None = None, + + disable_tqdm=False, + use_lora=None, + **sampling_kwargs, + ): + device = text_list[0].device + batch_size = len(text_list) + + level = 0 + if cfg.lora is not None: + enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora ) + + # convert (N)AR specific args + sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" ) + + min_length = sampling_kwargs.pop("min_duration", 1) + max_length = sampling_kwargs.pop("max_duration", 500) + max_steps = sampling_kwargs.get("max_steps", 25) + refine_on_stop = sampling_kwargs.get("refine_on_stop", False) + entropix_sampling = sampling_kwargs.get("entropix_sampling", False) + annealed_sampling = sampling_kwargs.get("annealed_sampling", True) + + # greedy sampling is very, very much preferred, but using greedy logit scores later helps enough + temperature = sampling_kwargs.pop("temperature", 0.0) + minimum_cfg_strength = sampling_kwargs.get("minimum_cfg_strength", 2.5) + # this really helps keep audio coherent so far + cfg_strength = sampling_kwargs.get("cfg_strength", minimum_cfg_strength) + cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75) + start_noise = sampling_kwargs.get("denoise_start", 0.0) + end_noise = sampling_kwargs.get("denoise_end", 1.0) + remasking = sampling_kwargs.get("remasking", True) + max_steps = math.floor(max_steps * (end_noise - start_noise)) + + # to specify the initial mask used + vc_list = sampling_kwargs.pop("vc_list", None) + vc_threshold = sampling_kwargs.pop("vc_threshold", 0.25) + vc_mask_p = sampling_kwargs.pop("vc_mask_p", 0.25) + + len_list = [ clamp(l, min_length, max_length) for l in len_list ] + + # force set CFG because too low / no CFG causes issues + original_cfg_strength = cfg_strength + cfg_strength = max( cfg_strength, minimum_cfg_strength ) + + prefix_context = sampling_kwargs.get("prefix_context", None) + # fill with masked tokens (even though 
they get masked anyways) + resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ] + # fill scores + scores = [ torch.ones((seq_len), dtype=torch.float32, device=device) for seq_len in len_list ] + + quant_levels = [ level for _ in range(batch_size) ] + null_text = [ torch.tensor([1, 2], device=device, dtype=torch.int16) for _ in range(batch_size) ] + null_prom = [ None for _ in range(batch_size) ] + + iterator = tqdm(torch.linspace(start_noise, end_noise, max_steps), desc="NAR Masked", disable=disable_tqdm) + for timestep in iterator: + # update previous list of tokens + prev_list = resps_list + # ramp down over time + annealing = 1.0 - timestep + # get noise level, per cosine scheduling + noise_p = math.cos( timestep * math.pi * 0.5 ) + # proportion of tokens to remask + remask_p = 1.0 / (max_steps * 2) if remasking else 0 + # pick the worst scoring tokens to mask off + masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ] + # normal masking + # mask off inputs + resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.stop_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ] + # boolean mask + is_masked = [ resps == self.stop_token for resps in resps_list ] + # timestep inputs + time_list = [ timestep for _ in range(batch_size) ] + + sampling_temperature = temperature * annealing if annealed_sampling else temperature + sampling_cfg = cfg_strength * timestep if annealed_sampling else cfg_strength + + input_resps_list = resps_list + + # setup inputs + inputs = super().inputs( + text_list=text_list, + proms_list=proms_list, + resps_list=input_resps_list, + lang_list=lang_list, + tone_list=tone_list, + time_list=time_list, + quant_levels=quant_levels, + ) + output = super().forward( + inputs=inputs, + quant_levels=quant_levels, + ) + + logits = output.logits + if cfg_strength > 0: + null_inputs = super().inputs( + text_list=null_text, + proms_list=null_prom, + resps_list=input_resps_list, + lang_list=lang_list, + tone_list=tone_list, + time_list=time_list, + quant_levels=quant_levels, + ) + null_output = super().forward( + inputs=null_inputs, + quant_levels=quant_levels, + ) + + logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ l for l in len_list ] ) + + l_scores = [] + l_resps_list = [] + # cringe hack because we're able to sample multiple levels at once + for l in range(self.n_resp_levels): + # sample with sampler settings + filtered_sampled = super().sample( + logits=[ logit[l] for logit in logits ], + prev_list=[ resp[..., l] for resp in prev_list ], + quant_levels=quant_levels, + + temperature=sampling_temperature, + **sampling_kwargs, + ) + + # retrieves unfiltered logits + unfiltered_sampled = super().sample( + logits=[ logit[l] for logit in logits ], + prev_list=[ resp[..., l] for resp in prev_list ], + quant_levels=quant_levels, + + temperature=0.0, + **sampling_kwargs, + ) + + # get sampled tokens + sampled_ids = filtered_sampled.ids + # keep unmasked tokens + l_resps_list.append([ torch.where( masked[..., l], input_ids, resps[..., l] ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]) + # get probability scores + l_scores.append([ + # conjugate to have worse scoring tokens picked for topk + 1.0 - + # only keep scores of 
tokens we are predicting (and ignore the tokens previously finalized) + torch.where( masked[..., l], torch.tensor([score for index, score in enumerate(scores)], device=device), torch.ones(masked[..., l].shape, device=device) ) + # use unmodified logit scores for this, as it offers better stability + for scores, masked in zip( unfiltered_sampled.scores, is_masked ) + ]) + + resps_list = [] + scores = [] + + for batch_index in range(batch_size): + score = sum([ l_scores[level][batch_index] for level in range(self.n_resp_levels) ]) / self.n_resp_levels + resp = torch.stack([ l_resps_list[level][batch_index] for level in range(self.n_resp_levels) ], dim=-1) + + scores.append( score ) + resps_list.append( resp ) + + return resps_list + def forward_nar( self, task_list: list[Tensor] | None = None, @@ -911,6 +1102,56 @@ class AR_NAR(Base): # is NAR if (len_list is not None or resps_list is not None) and text_list is not None: + if self.version >= 7: + if self.parallel_decoding: + return self.forward_nar_masked_parallel( + task_list=task_list, + + text_list=text_list, + proms_list=proms_list, + resps_list=resps_list, + + lang_list=lang_list, + tone_list=tone_list, + len_list=len_list, + raw_text_list=raw_text_list, + + disable_tqdm=disable_tqdm, + use_lora=use_lora, + **sampling_kwargs, + ) + else: + resps_lists = [ None for _ in range(batch_size) ] + for level in range(self.n_resp_levels): + resp_list = self.forward_nar_masked( + task_list=task_list, + + text_list=text_list, + proms_list=proms_list, + resps_list=resps_list, + + lang_list=lang_list, + tone_list=tone_list, + len_list=len_list, + raw_text_list=raw_text_list, + + disable_tqdm=disable_tqdm, + use_lora=use_lora, + quant_levels=[ level for _ in range(batch_size) ], + **sampling_kwargs, + ) + + for batch_index, resp in enumerate(resp_list): + if resps_lists[batch_index] is None: + resps_lists[batch_index] = [] + + resps_lists[batch_index].append( resp ) + + for batch_index, resps in enumerate(resps_lists): + resps_lists[batch_index] = torch.stack( resps, dim=-1 ) + + return resps_lists + return self.forward_nar( task_list=task_list, @@ -988,8 +1229,8 @@ def example_usage(): resps_list = [ audio[:int(cfg.dataset.frames_per_second * 4), :] ] * batch_size kwargs = { - 'n_text_tokens': 256, - 'n_audio_tokens': 1024, + 'n_text_tokens': cfg.model.text_tokens, + 'n_audio_tokens': cfg.model.audio_tokens, 'd_model': 1024, # 256, # 1024, # 1536 'n_heads': 16, # 4, # 16, # 24 @@ -1004,7 +1245,9 @@ def example_usage(): } bos_id, space_id, eos_id = cfg.tokenizer.encode( " " ) - available_tasks = [] + (["tts-ar"] if "ar" in cfg.model.capabilities else []) + (["tts-nar"] if "len" in cfg.model.capabilities else []) + + #available_tasks = [] + (["tts-ar"] if "ar" in cfg.model.capabilities else []) + (["tts-nar"] if "len" in cfg.model.capabilities else []) + available_tasks = ["tts-nar"] model = AR_NAR(**kwargs).to(cfg.device) steps = 500 // batch_size @@ -1156,13 +1399,14 @@ def example_usage(): if task == "tts-nar": len_list = engine( text_list=text_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 ) - len_list = [ resp_list[0].shape[0] for l in len_list ] + len_list = [ r.shape[0] for r in resp_list ] resps_list = engine( text_list=text_list, proms_list=proms_list, len_list=len_list ) else: resps_list = engine( text_list=text_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 ) resps_list = engine( text_list=text_list, proms_list=proms_list, resps_list=resps_list, temperature=0.0 ) for i, o in 
enumerate(resps_list): + print( o.shape, o ) _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.{task}.wav", device=cfg.device) unload_model() @@ -1185,7 +1429,9 @@ def example_usage(): }, f"./data/{cfg.model.arch_type}.pth" ) """ - #sample("init", 5) + task = available_tasks[0] + #sample("init", task=task) + train() """ diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py index 3a3927e..48eef3b 100644 --- a/vall_e/models/arch/llama.py +++ b/vall_e/models/arch/llama.py @@ -435,7 +435,7 @@ class LlamaDecoderLayer_Adapted(LlamaDecoderLayer): class LlamaModel_Adapted(LlamaModel): def __init__(self, config, *args, **kwargs): - self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0.1) + self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0) self.early_exit_scale = kwargs.pop("early_exit_scale", 0.1) self.early_exit_r = kwargs.pop("early_exit_r", 2) @@ -459,7 +459,7 @@ class LlamaModel_Adapted(LlamaModel): self.post_init() def dropoff_layer( self, l ): - if not self.training: + if not self.training or self.layer_dropout_p <= 0: return False # this could probably a LUT but I'm not fiending for aggressive mal-optimizations diff --git a/vall_e/models/base.py b/vall_e/models/base.py index d224871..818f06a 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -37,7 +37,7 @@ from ..samplers import * from ..data import get_task_symmap # these seem more elegant than a dict -Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states', 'exited_layer']) +Logits = namedtuple('Logits', ['logits', 'state', 'inputs', 'loss', 'attentions', 'hidden_states']) Sampled = namedtuple('Sampled', ['ids', 'logits', 'scores', 'entropy']) LossStats = namedtuple('LossStats', ['loss', 'stats']) @@ -245,6 +245,21 @@ class AudioEmbedding(nn.Module): return x +class AudioEmbedding_Sums(nn.Module): + def __init__( + self, + n_tokens: int, + n_levels: int, + token_dim: int, + ): + super().__init__() + self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)]) + + def forward(self, xi: Tensor ) -> Tensor: + x = sum( [ emb( xi[:, l] ) for l, emb in enumerate(self.embeddings) ] ) + + return x + # time-step embedding # for the NAR-len, since it probably most likely requires encoding the timestep class TimeEmbedding(nn.Module): @@ -272,12 +287,12 @@ class Classifiers(nn.Module): def __init__( self, l_embedding_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token) - token_dim: int, # dimensionality of the embedding - l_embedding_names: list[str] | None = None, # list of names to map to each classifier, + l_embedding_names: list[str], # list of names to map to each classifier, + d_model: int, # dimensionality of the embedding bias: bool = True, ): super().__init__() - self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens, bias=bias) for n_tokens in l_embedding_tokens]) + self.proj = nn.ModuleList([nn.Linear(d_model, n_tokens, bias=bias) for n_tokens in l_embedding_tokens]) self.names = l_embedding_names def indices( @@ -288,19 +303,29 @@ class Classifiers(nn.Module): return names return [ self.names.index(name) for name in names ] - def forward(self, xi: Tensor, levels: list[int] | None = None, names: list[str] | None = None, stack = False ) -> Tensor: - dtype = xi.dtype - device = xi.device + def forward( + self, + xi: Tensor, + levels: list[int] | None = None, + names: list[str] | None = None, + stack = False, + ) -> Tensor: + dtype = 
xi[0].dtype + device = xi[0].device if levels and isinstance( levels[-1], str ): names = levels levels = [] # map names to levels + """ if names and not levels: - levels = [ self.names.index(name) for name in names ] + levels = [ None if name =="NAR" else self.names.index(name) for name in names ] + """ + if names and not levels: + levels = [ None if name not in self.names else self.names.index(name) for name in names ] - xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ] + xi = [ x if l == None else self.proj[l]( x ) for x, l in zip(xi, levels) ] if not stack: return xi @@ -316,6 +341,109 @@ class Classifiers(nn.Module): ] return torch.stack( xi ) +# Pseudo-MoE by doing additional decoding from the main transformer's last hidden output +# ironically, not using a classifier to hidden_dim => audio_tokens causes problems with fitment +class ParallelDecoder(nn.Module): + def __init__( + self, + levels, + d_model, + config_kwargs, + ): + super().__init__() + + training = config_kwargs.pop("training", False) + attention_backend = config_kwargs.pop("attention_backend", "default") + gradient_checkpointing = config_kwargs.pop("gradient_checkpointing", True) + + hidden_size = config_kwargs.get("hidden_size") + vocab_size = config_kwargs.get("vocab_size") + + #self.d_model = d_model + self.vocab_size = vocab_size + + downs = [] + modules = [] + ups = [] + for level in range(levels): + module = LlamaModel_Adapted(LlamaConfig(**config_kwargs)) + + module = ml.replace_attention( module, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend ) + + if hasattr( module, "embeddings" ): + del module.embeddings + + if gradient_checkpointing and not module.gradient_checkpointing: + module.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( + use_reentrant=False + )) + + modules.append(module) + """ + downs.append(nn.Linear(d_model, hidden_size, bias=False)) + ups.append(nn.Linear(hidden_size, vocab_size, bias=False)) + """ + + self.levels = levels + self.decoders = nn.ModuleList(modules) + """ + self.downs = nn.ModuleList(downs) + self.ups = nn.ModuleList(ups) + """ + + def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor: + # split into levels + if level == None: + x = [ self.forward( x, l, **kwargs ) for l in range(self.levels) ] + x = torch.stack( x ) + x = x.permute( 1, 0, 2, 3 ) # ( level, batch, token, classification => batch, level, token, classification ) + return x + + # do one level + + # attention + feedforward + x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"] + # this really hates an output head, so just treat the final output as one + x = x[..., :self.vocab_size] + + """ + # downscale to head's dimensionality + x = self.downs[level]( x ) + # attention + feed forward + x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"] + # upscale to vocab logits + x = self.ups[level]( x ) + """ + + return x +""" + +""" +# naively tries to extract multiple codebooks in parallel based on the last hidden state from the model +# this doesn't work well +""" +class ClassifiersParallel(nn.Module): + def __init__( + self, + n_levels: int, # codebook levels + n_tokens: int, # output token count + token_dim: int, # dimensionality of the embedding + bias: bool = False, + ): + super().__init__() + self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens, bias=bias) for _ in range(n_levels)]) + + def forward(self, xi: Tensor, stack: bool = True ) -> Tensor: + dtype = xi.dtype + device = xi.device + 
+ xi = [ proj( xi ) for l, proj in enumerate(self.proj) ] + xi = torch.stack( xi ) + xi = xi.permute( 1, 0, 2, 3 ) # ( level, batch, token, classification => batch, level, token, classification ) + + return xi +""" + class Metrics(nn.Module): def __init__( self, @@ -448,6 +576,7 @@ class Base(nn.Module): self.causal = "ar" in self.capabilities or "len" in self.capabilities self.version = self.config.version if self.config is not None else 5 self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0) + self.parallel_decoding = self.config.experimental.parallel_decoding if self.config is not None else False self.arch_type = self.config.arch_type if self.config is not None else "llama" @@ -469,7 +598,7 @@ class Base(nn.Module): tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else "" unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True - interleave = self.config.experimental.interleave if self.config is not None else False + #interleave = self.config.experimental.interleave if self.config is not None else False noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False classifiers_bias = self.config.experimental.classifiers_bias if self.config is not None else False max_position_embeddings = self.config.experimental.max_position_embeddings if self.config is not None else (75 * 60 * 5) @@ -477,40 +606,56 @@ class Base(nn.Module): masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False - layerskip = self.config.experimental.layerskip if self.config is not None else False - layerskip_r = self.config.experimental.layerskip_r if self.config is not None else 2 - layerskip_p_max = self.config.experimental.layerskip_p_max if self.config is not None else 0.1 - layerskip_e_scale = self.config.experimental.layerskip_e_scale if self.config is not None else 0.1 n_tasks = self.config.tasks if self.config is not None else 8 n_langs = self.config.langs if self.config is not None else 2 n_tones = self.config.tones if self.config is not None else 1 # pure AR - if "nar" not in self.capabilities: - n_resp_tokens = n_audio_tokens + 1 - l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels - l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )] - l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels - # NAR-len model - elif "len" in self.capabilities: - # +1 to include the stop or mask token - n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 ) - if "ar" in self.capabilities: - l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens] - l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1] - l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0'] + if self.version < 7: + if "nar" not in self.capabilities: + n_resp_tokens = n_audio_tokens + 1 + l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels + l_embedding_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )] + l_classifier_tokens = [n_resp_tokens] * self.n_resp_levels + # NAR-len model + elif "len" in 
self.capabilities: + # +1 to include the stop or mask token + n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 ) + if "ar" in self.capabilities: + l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens] + l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1] + l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0'] + else: + l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + # AR+NAR model else: + # +1 to include the stop or mask token + n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 ) l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) - l_embedding_names = ['NAR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] - # AR+NAR model else: - # +1 to include the stop or mask token - n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 ) - l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) - l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] - l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + if self.parallel_decoding: + n_resp_tokens = n_audio_tokens + 1 + l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels + l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )] + l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels + else: + """ + n_resp_tokens = n_audio_tokens + 1 + l_embedding_tokens = [n_resp_tokens * self.n_resp_levels] + l_embedding_names = ["NAR"] + l_classifier_tokens = [n_audio_tokens * self.n_resp_levels] + """ + n_resp_tokens = n_audio_tokens + 1 + l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels + l_classifier_tokens = [n_audio_tokens] * self.n_resp_levels + l_embedding_names = [ f'NAR:{i}:{i}' for i in range( self.n_resp_levels ) ] + + n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1 l_classifier_names = l_embedding_names @@ -528,23 +673,13 @@ class Base(nn.Module): l_classifier_tokens += [ n_raw_text_tokens ] l_classifier_names = l_embedding_names + [ "raw_text" ] - n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1 - self.n_vocab = n_vocab self.unified_position_ids = unified_position_ids - self.interleave = interleave - self.layerskip = layerskip self.inject_timestep_embedding = False # results in bad output self.masking_ratio = masking_ratio self.ignore_inputs_for_loss = ignore_inputs_for_loss self.noncausal_masks = noncausal_masks - # use internal attention mechanism for now because I dont have a better way to handle mixed causal/noncausal masks for other attention backends - """ - if noncausal_masks: - attention_backend = "default" - """ - self.text_emb = Embedding(n_text_tokens, d_model) self.raw_text_emb = None self.langs_emb = None @@ -572,7 +707,7 @@ class Base(nn.Module): l_embedding_tokens, d_model, levels=self.n_resp_levels if self.version > 3 else None, ) - else: + elif not self.parallel_decoding: self.proms_emb = 
AudioEmbedding( [n_audio_tokens] * self.n_resp_levels, d_model, sums=audio_embedding_sums == "prom" or audio_embedding_sums == True, @@ -582,6 +717,17 @@ class Base(nn.Module): sums=audio_embedding_sums == "resp" or audio_embedding_sums == True, l_embedding_names=l_embedding_names, ) + else: + self.proms_emb = AudioEmbedding_Sums( + n_tokens=n_audio_tokens, + n_levels=self.n_resp_levels, + token_dim=d_model, + ) + self.resps_emb = AudioEmbedding_Sums( + n_tokens=n_audio_tokens + 1, + n_levels=self.n_resp_levels, + token_dim=d_model, + ) if self.version >= 3: self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None @@ -597,12 +743,11 @@ class Base(nn.Module): # this ***might*** let me also unify the proms_emb and resps_embedding if self.version >= 5: # "len" RVQ level-0 gets an additional token - self.rvq_l_emb = Embedding(self.n_resp_levels, d_model) + if self.version < 7 or not self.parallel_decoding: + self.rvq_l_emb = Embedding(self.n_resp_levels, d_model) # experimental NAR-only mode self.len_emb = Embedding(11, d_model) - self.time_emb = None # TimeEmbedding(d_model) # if not masking_ratio else None - if self.version >= 6: self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model) @@ -640,10 +785,8 @@ class Base(nn.Module): n_levels=self.n_resp_levels, ) for _ in range(n_layers) ]) elif self.arch_type in ["llama", "mistral", "mixtral"]: - LlamaClass = LlamaModel_Adapted # if (self.layerskip or "len" in self.capabilities) else LlamaModel - if n_experts <= 1: - self.model = LlamaClass(LlamaConfig( + self.model = LlamaModel_Adapted(LlamaConfig( vocab_size=n_vocab, hidden_size=d_model, max_position_embeddings=max_position_embeddings, @@ -692,11 +835,6 @@ class Base(nn.Module): self.model = ml.replace_attention( self.model, klass=MixtralAttention_Adapted, target=MixtralAttention, mode=attention_backend ) """ - if self.layerskip: - self.model.layer_dropout_p = layerskip_p_max - self.model.early_exit_scale = layerskip_e_scale - self.model.early_exit_r = layerskip_r - if self.gradient_checkpointing and not self.model.gradient_checkpointing: self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( use_reentrant=False @@ -766,19 +904,38 @@ class Base(nn.Module): if not split_classifiers: self.classifier = nn.Linear(d_model, n_vocab, bias=classifiers_bias) self.classifiers = None - self.metrics = None else: self.classifier = None - self.classifiers = Classifiers( l_classifier_tokens, d_model, l_embedding_names=l_classifier_names, bias=classifiers_bias ) + self.classifiers = Classifiers( l_classifier_tokens, l_classifier_names, d_model, bias=classifiers_bias ) self.metrics = Metrics( l_classifier_tokens ) - """ - if tie_classifier_to_embedding: - for i, proj in enumerate( self.classifiers.proj ): - self.classifiers.proj[i].weight = self.resps_emb.embeddings[i].weight - """ + self.parallel_decoder = None + if self.parallel_decoding: + pd_model = d_model # // 2 + pd_ffn = pd_model * 2 + pd_heads = n_heads // 2 + pd_layers = 1 + config = dict( + vocab_size=n_audio_tokens, + hidden_size=pd_model, + max_position_embeddings=max_position_embeddings, + intermediate_size=pd_ffn, + num_hidden_layers=pd_layers, + num_attention_heads=pd_heads, + attention_dropout=p_dropout if training else 0.0, + num_key_value_heads=pd_heads, + hidden_act="gelu", + is_encoder_decoder=False, + is_decoder=True, + attn_implementation="eager", + + training=self.training, + attention_backend=attention_backend, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.parallel_decoder = 
ParallelDecoder( self.n_resp_levels, d_model, config ) def _forward( self, @@ -789,8 +946,6 @@ class Base(nn.Module): state = None, - layer_skip_lambda = None, - output_attentions = False, output_hidden_states = False, ): @@ -818,9 +973,6 @@ class Base(nn.Module): if self.n_experts > 1 and self.training: kwargs["output_router_logits"] = True - if self.layerskip and layer_skip_lambda is not None: - kwargs["layer_skip_lambda"] = layer_skip_lambda - output = self.model(**kwargs) x = output["last_hidden_state"] @@ -885,7 +1037,7 @@ class Base(nn.Module): # but skip the last state, as it already is normalized hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i, state in enumerate( hidden_states ) ] - return Logits(x, state, inputs, aux_loss, attentions, hidden_states, None) + return Logits(x, state, inputs, aux_loss, attentions, hidden_states) # takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation def inputs( @@ -943,7 +1095,7 @@ class Base(nn.Module): if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None: inputs[i].append( ( "lang", lang_list[i] ) ) # insert RVQ level guidance token if the model is versioned for it - if self.rvq_l_emb is not None and not self.interleave: + if self.rvq_l_emb is not None: inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) ) classifier_level = "AR:0:0" if quant_level == 0 else f'NAR:{quant_level-1}:{quant_level}' @@ -976,7 +1128,10 @@ class Base(nn.Module): # insert the current output response if resps_list is not None and resps_list[i] is not None: inputs[i].append( ( "resp", resps_list[i] ) ) - + + if self.version >= 7: + classifier_level = f"NAR:{quant_level}:{quant_level}" if not self.parallel_decoding else "NAR" + inputs[i].append( ("classifier_level", classifier_level) ) # Audio length prediction task # Sequence: @@ -1022,7 +1177,7 @@ class Base(nn.Module): if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None: inputs[i].append( ( "lang", lang_list[i] ) ) # insert RVQ level guidance token if the model is versioned for it - if self.rvq_l_emb is not None and not self.interleave: + if self.rvq_l_emb is not None: inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) ) # insert the output text prompt if text_list is not None and text_list[i] is not None: @@ -1038,7 +1193,7 @@ class Base(nn.Module): # insert lang token if we're trained for it if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None: inputs[i].append( ( "lang", lang_list[i] ) ) - if self.rvq_l_emb is not None and not self.interleave: + if self.rvq_l_emb is not None: inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) ) # insert the text prompt if text_list is not None and text_list[i] is not None: @@ -1054,7 +1209,7 @@ class Base(nn.Module): # insert lang token if we're trained for it if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None: inputs[i].append( ( "lang", lang_list[i] ) ) - if self.rvq_l_emb is not None and not self.interleave: + if self.rvq_l_emb is not None: inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) ) # insert the text prompt if raw_text_list is not None and raw_text_list[i] is not None: @@ -1117,12 +1272,23 @@ class Base(nn.Module): return self.proms_emb( 
input if quant_level == 0 else input[:, :quant_level] ) - - return self.proms_emb( - input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level], - quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level - offset = 0, - ) + + if self.version < 7 or not self.parallel_decoding: + return self.proms_emb( + input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level], + quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level + offset = 0, + ) + """ + if not self.parallel_decoding: + return self.proms_emb( + input if input.dim() == 1 else input[:, :quant_level+1], + quant_level = quant_level, + offset = 0, + ) + """ + + return self.proms_emb( input ) # yuck token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0 @@ -1188,28 +1354,23 @@ class Base(nn.Module): elif name == "tone" and self.tones_emb is not None: embedding = self.tones_emb( input ) elif name == "resp": - if self.interleave: - embeddings = [ self.resps_emb( - input[:, :l+1], - #offset = 0, - #quant_level = l, - name = 'AR:0:0' if l == 0 else f'NAR:{l-1}:{l}', - ) for l in range( input.shape[-1] ) ] - - embedding = _interleave_sequence_reshape( embeddings ) - + if self.parallel_decoding: + if dropout_mask is not None: + embedding = self.resps_emb( torch.where( dropout_mask, self.stop_token, input.t() ).t() ) + else: + embedding = self.resps_emb( input ) # if training NAR-len RVQ level 0 elif dropout_mask is not None: embedding = self.resps_emb( # if masked use masked token, else original token - torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ), + torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, quant_level] ), #quant_level = 0, name = classifier_level, ) # NAR-len - elif classifier_level == "NAR:0:0": + elif classifier_level == f"NAR:{quant_level}:{quant_level}": embedding = self.resps_emb( - input if input.dim() == 1 else input[:, 0], + input if input.dim() == 1 else input[:, quant_level], #quant_level = 0, name = classifier_level, ) @@ -1323,10 +1484,6 @@ class Base(nn.Module): if not isinstance(input, torch.Tensor): return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) - # interleaved model - if self.interleave and name == "resp": - return input.shape[0] * input.shape[1] - # ending input will not have a separator later return input.shape[0] @@ -1352,6 +1509,166 @@ class Base(nn.Module): return ids.to(device=device, dtype=torch.int32) + def calc_loss_parallel( + self, + inputs: list, + logits, + + compute_hard_loss = True, + compute_acc = True, + ): + loss = {} + stats = {} + + device = logits[0].device + batch_size = len(logits) + classifier_levels = self.get_input( inputs, "classifier_level" ) + + # handles tasks where the prompt has task tokens injected in the middle + def prompt_input_to_token( input ): + if isinstance(input, str): + return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16) + return input + + for batch_index, batch in enumerate(inputs): + target = [] + causal = False + task_type = "tts" + dropout_mask = None + classifier_level = None + output_len = 0 + + for name, input in batch: + if name == "task": + task_type = input + elif name == "dropout_mask": + dropout_mask = input + elif name == "classifier_level": + classifier_level = input + + it = 0 + for name, input in batch: + token = None + ignored = False + + # non-tokened 
tasks + if name in non_tokened_names: + continue + # prom can either be a tensor itself or a list of tensors and strings + if name == "prom": + # expand to list if not a list + proms = [ input ] if isinstance(input, torch.Tensor) else input + # iterate over the list to inject their tokens + token = torch.cat( [ prompt_input_to_token( input ) for input in proms if input is not None ] ) + elif name == "resp": + # mask found, apply it + if dropout_mask is not None: + token = torch.where( dropout_mask, input.t(), self.ignore_index ).t() + else: + token = input + # not a special input, inject as-is + else: + token = input + + if not isinstance(token, torch.Tensor): + continue + + if token.is_floating_point(): + ignored = True + + # grab range of our logits for later + seq_len = token.shape[0] + start, end = it, it+seq_len + it += seq_len + 1 # +1 to incorporate the separator + + # deduce if a name for a task is an input or output + if name != task_outputs.get(task_type, name): + if self.ignore_inputs_for_loss: + ignored = True + else: + output_len = seq_len + + if ignored: + # pruned + if self.config.loss_factors: + continue + # fill with ignored out tensor + token = torch.tensor( [ self.ignore_index ] * token.shape[0], device=device, dtype=torch.int16) + + # perform loss calculation on the individual piece + target.append( token ) + + if classifier_level != "NAR": + seq = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) + logit = logits[batch_index] + + # shift if causal + if causal: + l = self.causal_size + logit = logit[..., :-l, :] # shift the target so that token n... + seq = seq[..., l:] # ...predicts token n + 1 + + if compute_hard_loss: + nll = F.cross_entropy( logit, seq, ignore_index=self.ignore_index ) + if 'nll' not in loss: + loss['nll'] = [] + loss["nll"].append( nll ) + + if compute_acc and False: + if self.metrics is not None: + metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ "NAR:0" if classifier_level == "NAR" else classifier_level ]) ) + else: + accuracy_metric = MulticlassAccuracy( + logit.shape[-1], + top_k = 10, + average="micro", + multidim_average="global", + ignore_index = -100 + ).to(logit.device) + metrics = accuracy_metric( logit, seq ) + + if 'acc' not in stats: + stats['acc'] = [] + stats["acc"].append( metrics ) + else: + for level, logit in enumerate( logits[batch_index] ): + seq = _join( [ t if t.dim() <= 1 else t[:, level] for t in target ], torch.tensor(self.ignore_index, device=target[-1].device) ) + + # shift if causal + if causal: + l = self.causal_size + logit = logit[..., :-l, :] # shift the target so that token n... 
+ seq = seq[..., l:] # ...predicts token n + 1 + + if compute_hard_loss: + nll = F.cross_entropy( logit, seq, ignore_index=self.ignore_index ) + if 'nll' not in loss: + loss['nll'] = [] + loss["nll"].append( nll ) + + if compute_acc and False: + if self.metrics is not None: + metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ "NAR:0" if classifier_level == "NAR" else classifier_level ]) ) + else: + accuracy_metric = MulticlassAccuracy( + logit.shape[-1], + top_k = 10, + average="micro", + multidim_average="global", + ignore_index = -100 + ).to(logit.device) + metrics = accuracy_metric( logit, seq ) + + if 'acc' not in stats: + stats['acc'] = [] + stats["acc"].append( metrics ) + + # average + loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() } + stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() } + + return LossStats(loss, stats) + def calc_loss( self, inputs: list, @@ -1377,7 +1694,13 @@ class Base(nn.Module): if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums): return torch.full_like(input[..., 0], self.ignore_index) - return input if input.dim() == 1 else input[:, quant_level] + if self.version < 7: + return input if input.dim() == 1 else input[:, quant_level] + + if not self.parallel_decoding: + return input if input.dim() == 1 else input[:, quant_level] + + return input for batch_index, batch in enumerate(inputs): quant_level = quant_levels[batch_index] @@ -1402,6 +1725,8 @@ class Base(nn.Module): # nonautoregressive, parallel elif classifier_level.startswith("NAR:"): causal = False + elif classifier_level == "NAR": + causal = False it = 0 for name, input in batch: @@ -1422,9 +1747,6 @@ class Base(nn.Module): if dropout_mask is not None: # if mask use original token, else ignore token = torch.where( dropout_mask, input if input.dim() == 1 else input[:, 0], self.ignore_index ) - # flatten - elif self.interleave: - token = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) # use resps as-is else: token = input if input.dim() == 1 else input[:, quant_level] @@ -1435,28 +1757,6 @@ class Base(nn.Module): if not isinstance(token, torch.Tensor): continue - # offset to flattened vocab ranges - """ - if self.classifier is not None: - offsets = _get_offsets() - - k = name - if name == "stt": - k = "text" - if name == "prom": - k = f'prom|{quant_level}' - elif name == "resp": - k = f'resps|{classifier_level}' - - if k in offsets: - start, end = offsets[k] - - for i, t in enumerate( token ): - if t == self.ignore_index: - continue - token[i] += start - """ - if token.is_floating_point(): ignored = True @@ -1566,51 +1866,9 @@ class Base(nn.Module): quant_levels: list[int] | None = None, state: dict | list | None = None, - layer_skip_variables: dict | None = None, - output_attentions: bool = False, output_hidden_states: bool = False, - ): - # return early if it's "good" enough" - # lambda because we need to capture the classifier_levels and mask - exited_layer = self.n_layers - def layer_skip_lambda( layer, logits ): - nonlocal exited_layer - kwargs = { - "entropy_threshold": 0.05, - "varentropy_threshold": 0.05, - "min_layer": self.n_layers // 2, - "max_layer": self.n_layers, - } - - kwargs.update( layer_skip_variables ) - - # don't bother on early layers - if layer < kwargs["min_layer"]: - return False - # bail if we want to force early layers - if kwargs["max_layer"] < layer: - return True - - # hidden states aren't 
normalized - x = self.model.norm( logits ) - - # output projection layer with masking - if self.classifier is not None: - x = self.classifier(x) # * m - elif self.classifiers is not None: - logits = self.classifiers(logits, levels = classifier_levels) # * m - - # calculate metrics - metrics = calculate_entropix_metrics( logits ) - # exit early if "good enough"" - early = metrics["logits_entropy"] <= kwargs["entropy_threshold"] and metrics["logits_varentropy"] <= kwargs["varentropy_threshold"] - - if early: - exited_layer = layer - - return early - + ): # derive quant levels from inputs if not provided if quant_levels is None: quant_levels = [ x.item() for x in self.get_input( inputs, "quant_level" ) ] @@ -1628,10 +1886,6 @@ class Base(nn.Module): device = x.device batch_size = len(x_list) - # we only need hidden states if we're training with layerskip - if self.layerskip and training: - output_hidden_states = True - # pad our input and mask, but retain the original length by doing it after if self.l_padding and x.shape[1] % self.l_padding != 0: # pad input @@ -1663,55 +1917,55 @@ class Base(nn.Module): is_causal=is_causal, position_ids=position_ids, output_attentions = output_attentions, - output_hidden_states = output_hidden_states, - layer_skip_lambda = layer_skip_lambda if self.layerskip and layer_skip_variables else None, ) logits = output.logits hidden_states = output.hidden_states + + logits = [ logit for logit in logits ] + + if self.version >= 7 and self.parallel_decoding: + p_indices = [ batch_index for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ] + if p_indices: + p_logits = torch.stack([ logits[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0) + p_mask = torch.stack([ mask[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0) + p_ids = torch.stack([ position_ids[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0) + p_causal = [ is_causal[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ] + p_logits = self.parallel_decoder( p_logits, attention_mask=p_mask, position_ids=p_ids, use_cache=False, return_dict=True, is_causal=p_causal ) + + for i, logit in enumerate(p_logits): + logits[p_indices[i]] = logit + + """ + logits = [ self.parallel_decoder( logit.unsqueeze(0), attention_mask=mask, + position_ids=position_ids, + use_cache=False, + return_dict=True, + is_causal=is_causal )[0] if level == "NAR" else logit for logit, level in zip(logits, classifier_levels) ] + """ + # output projection layer # the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways if self.classifier is not None: logits = self.classifier(logits) # * m - - if output.hidden_states: - for i, state in enumerate( hidden_states ): - hidden_states[i] = self.classifier(hidden_states[i]) # * m # to-do: piece-wise classification, now that there's a head for text # although again, one single monolithic head would be preferable instead...... 
elif self.classifiers is not None: - logits = self.classifiers(logits, levels = classifier_levels) # * m + logits = self.classifiers(logits, levels = classifier_levels ) - if hidden_states is not None: - for i, state in enumerate( hidden_states ): - hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_levels) # * m + # Reshape + """ + if self.version >= 7 and not self.parallel_decoding: + for batch_index, logit in enumerate( logits ): + if classifier_levels[batch_index] != "NAR": + continue + logits[batch_index] = logit.reshape( logit.shape[0], 8, 1000 ).permute( 1, 0, 2 ) + """ # Remove padding - logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ] - - if hidden_states is not None: - for i, state in enumerate( hidden_states ): - hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ] + logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ] - # de-offset if needed - if self.classifier is not None: - offsets = _get_offsets() - for batch_index, classifier_level in enumerate( classifier_levels ): - if classifier_level == "stt": - k = "text" - elif classifier_level == "len": - k = "len" - else: - k = f'resps|{classifier_level}' - - if k not in offsets: - continue - - start, end = offsets[k] - - logits[batch_index] = logits[batch_index][:, start:end] - if not training: loss = None stats = None @@ -1719,30 +1973,18 @@ class Base(nn.Module): self.loss = None self.stats = None # compute loss if the target is given + elif self.version >= 7 and self.parallel_decoding: + loss, stats = self.calc_loss_parallel( inputs=inputs, logits=logits ) + + # include any additional losses (for example: MoE router) + if output.loss is not None: + loss["aux_loss"] = output.loss + + self.loss = loss + self.stats = stats else: loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels ) - # compute it as an aux-loss - if self.layerskip: - early_exit_loss = {} - if not hasattr( self, "training_steps" ): - self.training_steps = 0 - - for i, state in enumerate( hidden_states ): - loss, stats = self.calc_loss( inputs=inputs, logits=hidden_states[i], quant_levels=quant_levels ) - - for k, v in loss.items(): - K = f'early_exit.{k}' - if K not in early_exit_loss: - early_exit_loss[K] = [] - early_exit_loss[K].append( v ) - - for k, v in early_exit_loss.items(): - loss[k] = self.model.early_exit_loss( losses=v, t=self.training_steps ) - - # to-do: instead make the cirriculum rely on samples processed instead of steps - self.training_steps += 1 # batch_size - # include any additional losses (for example: MoE router) if output.loss is not None: loss["aux_loss"] = output.loss @@ -1751,7 +1993,7 @@ class Base(nn.Module): self.stats = stats # rewrap, because we're modifying the logits here - return Logits(logits, output.state, inputs, loss, output.attentions, hidden_states, exited_layer) + return Logits(logits, output.state, inputs, loss, output.attentions, hidden_states) def sample( self, diff --git a/vall_e/samplers.py b/vall_e/samplers.py index 567dead..3757612 100644 --- a/vall_e/samplers.py +++ b/vall_e/samplers.py @@ -176,17 +176,17 @@ def top_no_logits_processing( logits, n = 1.0 ): # (and because the null logits have a shorter input sequence compared to the positive logits) def cfg_logits( logits, null, strength, lens, rescale=0.0 ): for i, seq_len in enumerate( lens ): - pos = logits[i][-seq_len:] - neg = null[i][-seq_len:] + pos = logits[i][..., -seq_len:, :] + neg = null[i][..., -seq_len:, :] summed = neg + (pos - 
neg) * strength if rescale <= 0: - logits[i][-seq_len:] = summed + logits[i][..., -seq_len:, :] = summed else: dims = tuple(range(1, summed.ndim - 1)) factor = rescale * (pos.std(dims, keepdim=True) / summed.std(dims, keepdim=True)) + (1 - rescale) - logits[i][-seq_len:] = summed * factor + logits[i][..., -seq_len:, :] = summed * factor return logits
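
Notes on the mechanisms this patch touches follow. The processing scripts above drop their per-script backend `if/elif` chains in favor of `cfg.set_audio_backend()`; the sketch below only restates the mapping those removed chains encoded (extension, sample rate, codebook levels), and is an assumption about what the helper resolves, not its actual implementation.

```python
# Backend -> (extension, sample rate, resp levels), as encoded by the removed
# per-script branches. Illustrative only; cfg.set_audio_backend() may differ.
from dataclasses import dataclass

@dataclass
class BackendSettings:
    extension: str
    sample_rate: int
    resp_levels: int

AUDIO_BACKENDS = {
    "encodec":  BackendSettings(".enc", 24_000, 8),
    "vocos":    BackendSettings(".enc", 24_000, 8),
    "dac":      BackendSettings(".dac", 44_100, 9),
    "audiodec": BackendSettings(".dec", 48_000, 8),
}

def resolve_backend(name: str) -> BackendSettings:
    if name not in AUDIO_BACKENDS:
        raise ValueError(f"Unknown audio backend: {name}")
    return AUDIO_BACKENDS[name]
```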
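
The new `AudioEmbedding_Sums` module feeds the parallel-decoding path by summing one embedding table per codebook level into a single hidden vector per frame. A minimal standalone sketch, with illustrative dimensions rather than the model's configured ones:

```python
# Mirrors AudioEmbedding_Sums in vall_e/models/base.py: each RVQ/FSQ level has
# its own embedding table, and the per-level embeddings are summed per frame.
import torch
import torch.nn as nn

class SummedCodebookEmbedding(nn.Module):
    def __init__(self, n_tokens: int, n_levels: int, token_dim: int):
        super().__init__()
        # one embedding table per codebook level
        self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        # codes: (seq_len, n_levels) -> (seq_len, token_dim)
        return sum(emb(codes[:, l]) for l, emb in enumerate(self.embeddings))

emb = SummedCodebookEmbedding(n_tokens=1025, n_levels=8, token_dim=1024)  # +1 for the mask/stop token
codes = torch.randint(0, 1024, (75, 8))  # roughly one second of frames, illustrative
x = emb(codes)                           # (75, 1024)
```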
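
The `ParallelDecoder` "pseudo-MoE" runs one small non-causal transformer block per codebook level on top of the trunk's last hidden state, and slices that hidden state to `vocab_size` instead of using a separate output projection (the patch notes a plain `nn.Linear` head fits poorly). A rough sketch under those assumptions; `nn.TransformerEncoderLayer` is a stand-in for the `LlamaModel_Adapted` blocks actually used:

```python
# One lightweight per-level head over the shared trunk output; logits are taken
# by slicing the head's hidden state, so vocab_size must fit within d_model.
import torch
import torch.nn as nn

class ParallelLevelHeads(nn.Module):
    def __init__(self, n_levels: int, d_model: int, vocab_size: int, n_heads: int = 8):
        super().__init__()
        assert vocab_size <= d_model, "slicing the hidden state only works if vocab fits in d_model"
        self.vocab_size = vocab_size
        self.heads = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=d_model * 2, batch_first=True)
            for _ in range(n_levels)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model) -> (batch, n_levels, seq_len, vocab_size)
        outs = [head(x)[..., :self.vocab_size] for head in self.heads]
        return torch.stack(outs, dim=1)

heads = ParallelLevelHeads(n_levels=8, d_model=1024, vocab_size=1024)
hidden = torch.randn(1, 100, 1024)  # trunk output, illustrative
logits = heads(hidden)              # (1, 8, 100, 1024)
```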
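
`forward_nar_masked_parallel` follows the same cosine demasking schedule as the existing NAR-len path: at each step the worst-scoring positions (score is one minus the sampled token's confidence) are re-masked and re-predicted, with a small extra remask budget. A sketch of just the mask-selection step, with the surrounding model call elided:

```python
# Cosine-scheduled mask selection, as in forward_nar_masked_parallel.
import math
import torch

def select_mask(scores: torch.Tensor, timestep: float, max_steps: int, remasking: bool = True) -> torch.Tensor:
    """scores: (seq_len,) where higher means worse (1 - confidence)."""
    seq_len = scores.shape[0]
    noise_p = math.cos(timestep * math.pi * 0.5)            # fraction of tokens still treated as noise
    remask_p = 1.0 / (max_steps * 2) if remasking else 0.0  # small extra remask budget per step
    k = max(1, min(seq_len, int(noise_p * seq_len + remask_p * seq_len)))
    # the k worst-scoring positions get masked off and re-predicted this step
    return scores.topk(k, dim=-1).indices

scores = torch.rand(100)
for t in torch.linspace(0.0, 1.0, 25):
    masked = select_mask(scores, t.item(), max_steps=25)
    # ...mask those positions across all levels, run the model (optionally with
    # CFG), keep unmasked tokens, and refresh `scores` from the new confidences...
```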
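
When `parallel_decoding` is off, the new version-7 inference path in `ar_nar.py` instead runs the masked NAR routine once per codebook level and stacks the per-level outputs back into `(seq_len, n_levels)`. A sketch of that orchestration; `demask_level` here is a placeholder for `forward_nar_masked(..., quant_levels=[level] * batch_size)`:

```python
# Sequential per-level demasking, stacking levels on the last dim.
import torch

def demask_level(level: int, batch_size: int, seq_lens: list[int]) -> list[torch.Tensor]:
    # placeholder: would run masked-demasking inference for this level
    return [torch.randint(0, 1024, (seq_len,), dtype=torch.int16) for seq_len in seq_lens]

def decode_all_levels(n_resp_levels: int, batch_size: int, seq_lens: list[int]) -> list[torch.Tensor]:
    per_batch: list[list[torch.Tensor]] = [[] for _ in range(batch_size)]
    for level in range(n_resp_levels):
        for i, resp in enumerate(demask_level(level, batch_size, seq_lens)):
            per_batch[i].append(resp)
    # match the (seq_len, n_levels) layout used for resps elsewhere
    return [torch.stack(levels, dim=-1) for levels in per_batch]

resps = decode_all_levels(n_resp_levels=8, batch_size=2, seq_lens=[100, 120])
print([r.shape for r in resps])  # [torch.Size([100, 8]), torch.Size([120, 8])]
```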
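
Finally, the `samplers.py` change only generalizes the indexing in `cfg_logits` to `[..., -seq_len:, :]` so that stacked per-level logits pass through the same classifier-free-guidance blend. The blend itself is unchanged; a self-contained sketch of it, with illustrative shapes:

```python
# CFG blend with rescaling toward the positive branch's statistics, shape-agnostic
# over leading dims so both (seq, vocab) and (levels, seq, vocab) logits work.
import torch

def cfg_blend(pos: torch.Tensor, neg: torch.Tensor, strength: float, rescale: float = 0.75) -> torch.Tensor:
    summed = neg + (pos - neg) * strength
    if rescale <= 0:
        return summed
    # pull the guided logits' spread back toward the positive (conditional) branch
    dims = tuple(range(1, summed.ndim - 1))
    factor = rescale * (pos.std(dims, keepdim=True) / summed.std(dims, keepdim=True)) + (1 - rescale)
    return summed * factor

pos = torch.randn(8, 100, 1024)  # (levels, seq_len, vocab), illustrative
neg = torch.randn(8, 100, 1024)
guided = cfg_blend(pos, neg, strength=2.5)
```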