diff --git a/data/noise.dac b/data/noise.dac
deleted file mode 100644
index 1a4d02e..0000000
Binary files a/data/noise.dac and /dev/null differ
diff --git a/data/qnt.dac b/data/qnt.dac
deleted file mode 100644
index de7ff68..0000000
Binary files a/data/qnt.dac and /dev/null differ
diff --git a/data/qnt.enc b/data/qnt.enc
index 8da3c31..eede679 100644
Binary files a/data/qnt.enc and b/data/qnt.enc differ
diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index 0a3a81f..96d7508 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -52,6 +52,10 @@ def main():
 	parser.add_argument("--layer-skip-entropy-threshold", type=int, default=0.1)
 	parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=0.1)
 	parser.add_argument("--refine-on-stop", action="store_true")
+
+	# experimental settings
+	parser.add_argument("--load-from-artifact", type=Path, default=None)
+	parser.add_argument("--denoise-start", type=float, default=0.0)
 
 	parser.add_argument("--seed", type=int, default=None)
 
@@ -92,6 +96,10 @@ def main():
 		layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
 		layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
 		refine_on_stop=args.refine_on_stop,
+
+		load_from_artifact=args.load_from_artifact,
+		denoise_start=args.denoise_start,
+
 		seed=args.seed,
 	)
diff --git a/vall_e/inference.py b/vall_e/inference.py
index a24d304..8effdb0 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -3,6 +3,7 @@ import torchaudio
 import soundfile
 import time
 import logging
+import numpy as np
 
 _logger = logging.getLogger(__name__)
 
@@ -12,7 +13,7 @@ from pathlib import Path
 
 from .emb import g2p, qnt
 from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
-from .utils import to_device, set_seed, wrapper as ml
+from .utils import to_device, set_seed, clamp, wrapper as ml
 
 from .config import cfg, Config
 from .models import get_models
@@ -229,6 +230,9 @@ class TTS():
 		refine_on_stop=False,
 		#
 		seed = None,
+		#
+		load_from_artifact = None,
+		denoise_start = 0.0,
 
 		out_path=None,
@@ -355,17 +359,34 @@ class TTS():
 				use_lora=use_lora,
 			)
 		elif model_len is not None:
-			len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10, disable_tqdm=not tqdm ) # don't need more than that
-			len_list = [ min(l, max_ar_steps) for l in len_list ]
+			len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=5, disable_tqdm=not tqdm ) # don't need more than that
+			len_list = [ clamp(1, max_ar_steps, l) for l in len_list ]
+
+			kwargs = {}
+
+			# nasty hardcode to load a reference file and have that as the input target
+			if load_from_artifact and load_from_artifact.exists():
+				artifact = np.load(load_from_artifact, allow_pickle=True)[()]
+
+				phns = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=self.device)
+				resp = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=self.device)
+				prom = resp[:75*3, :]
+				len_list = [ resp.shape[0] ]
+
+				kwargs["resps_list"] = [ resp[:, :1] ]
+
 			resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list,
+				max_steps=max_ar_steps,
 				max_levels=max_nar_levels,
 				sampling_temperature=nar_temp,
 				sampling_min_temperature=min_nar_temp,
 				sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
 				sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
+				denoise_start=denoise_start,
 				disable_tqdm=not tqdm,
 				use_lora=use_lora,
+				**kwargs,
 			)
		else:
			raise Exception("!")
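
Note on the `--load-from-artifact` path above: the only structure the loader relies on is a pickled dict with `metadata.phonemes` and a `(1, levels, frames)` `codes` array. A minimal standalone sketch of that contract (the helper name and the stand-in shapes are illustrative, not part of the patch):

```python
import numpy as np
import torch

def load_artifact(path, device="cpu"):
	# np.save'd dict: np.load returns a 0-d object array, unwrapped with [()]
	artifact = np.load(path, allow_pickle=True)[()]
	phonemes = artifact["metadata"]["phonemes"]   # phoneme string of the reference utterance
	codes = artifact["codes"].astype(np.int16)    # (1, levels, frames) quantized codes
	resp = torch.from_numpy(codes)[0].t()         # -> (frames, levels), frames-major
	return phonemes, resp.to(device)
```

The `prom = resp[:75*3, :]` slice takes roughly the first three seconds as the acoustic prompt, assuming the 75-frames-per-second rate of the EnCodec backend.
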
diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index d25df2e..3e7be5f 100644
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -19,13 +19,10 @@ import logging
 
 _logger = logging.getLogger(__name__)
 
+from ..utils import clamp
 from ..emb.qnt import trim, encode_as_embedding
-
 from .lora import enable_lora
 
-def clamp(n, lo, hi):
-	return max(lo, min(n, hi))
-
 class AR(Base):
 	def forward(
 		self,
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 4b9be09..8d1cc2e 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -24,13 +24,10 @@ import logging
 _logger = logging.getLogger(__name__)
 
 from ..emb.qnt import trim, encode_as_embedding, get_silence
-from ..utils import get_devices, setup_logging, timer
+from ..utils import get_devices, setup_logging, timer, clamp
 
 from .lora import enable_lora
 
-def clamp(n, lo, hi):
-	return max(lo, min(n, hi))
-
 class AR_NAR(Base):
 	def forward(
 		self,
@@ -490,32 +487,19 @@ def example_usage():
 	"""
 	# cfg.model.loss_factors = {}
 
-	def tokenize(content):
-		return torch.tensor( cfg.tokenizer.encode(content) )
+	def load_artifact( path ):
+		artifact = np.load(path, allow_pickle=True)[()]
 
-	def _load_quants(path) -> Tensor:
-		qnt = np.load(path, allow_pickle=True)[()]
-		return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
+		text = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device)
+		audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device)
 
-	qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
-	noise = _load_quants(f"./data/noise.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
+		return text, audio
 
-	text_list = [
-		tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
-		#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
-	]
-	proms_list = [
-		qnt[:cfg.dataset.frames_per_second, :].to(device),
-		#qnt[:cfg.dataset.frames_per_second, :].to(device),
-	]
-	resps_list = [
-		qnt[:, :].to(device),
-		#qnt[:cfg.dataset.frames_per_second, :].to(device),
-	]
+	text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
 
-	text_list = text_list[:1]
-	proms_list = proms_list[:1]
-	resps_list = resps_list[:1]
+	text_list = [ text ]
+	proms_list = [ audio[:cfg.dataset.frames_per_second, :] ]
+	resps_list = [ audio ]
 
 	batch_size = len(text_list)
@@ -721,7 +705,7 @@ def example_usage():
 			resps = engine( texts, proms, resps, task_list=tasks, sampling_temperature=0.2 )
 
 		for i, o in enumerate(resps):
-			_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
+			_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
 
 	unload_model()
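
The `load_artifact` refactor above is also why the binary `data/noise.dac` and `data/qnt.dac` fixtures are deleted: the example no longer loads a separate noise prompt, and the phonemes now come from the artifact's own metadata instead of a hardcoded IPA string. One behavioral difference worth noting: `_load_quants` truncated to `cfg.model.resp_levels`, while the new loader keeps every codebook level and leaves truncation to the caller. A shape walkthrough under a stand-in artifact (the zero array is a placeholder, not real codes):

```python
import numpy as np
import torch

codes = np.zeros((1, 8, 150), dtype=np.int64)  # stand-in for artifact["codes"]: (1, levels, frames)
audio = torch.from_numpy(codes.astype(np.int16))[0, :, :].t()
assert audio.shape == (150, 8)                 # (frames, levels), frames-major
prom = audio[:75, :]                           # e.g. a one-second prompt at 75 frames/second
```
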
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 337c7dd..d37c22e 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -30,9 +30,8 @@ from torch.utils.checkpoint import checkpoint
 from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
 
 from .arch import *
-from ..utils import wrapper as ml
+from ..utils import wrapper as ml, clamp
 from ..samplers import *
-
 from ..emb.qnt import encode_as_embedding
 
 # yuck, kind of needed
@@ -57,9 +56,6 @@ def _dropout_mask( input, p=None ):
 	mask = torch.tensor( seq, dtype=torch.bool, device=input.device )
 	return mask
 
-def clamp(n, lo, hi):
-	return max(lo, min(n, hi))
-
 def _create_mask(l, device):
 	"""1 is valid region and 0 is invalid."""
 	seq = torch.arange(max(l), device=device).unsqueeze(0)  # (1 t)
@@ -1853,19 +1849,19 @@ class Base(nn.Module):
 		else:
 			# argmax instead
 			if temperature <= 0.0:
-				res = [ logit.argmax(dim=1) for logit in logits ]
+				res = [ logit.argmax(dim=-1) for logit in logits ]
 			else:
 				res = [ Categorical(logits=logit).sample() for logit in logits ]
 
 		# calculate token probabilities
 		if "len" in self.capabilities:
 			scores = [
-				[ F.softmax(logit[i, :], dim=0)[token].item() for i, token in enumerate(tokens) ]
+				[ F.softmax(logit[i, :], dim=-1)[token].item() for i, token in enumerate(tokens) ]
 				for logit, tokens in zip(logits, res)
 			]
 		else:
 			scores = [
-				[ F.softmax(logit[-1, :], dim=0)[token].item() for token in tokens ]
+				[ F.softmax(logit[-1, :], dim=-1)[token].item() for token in tokens ]
 				for logit, tokens in zip(logits, res)
 			]
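
The `dim` changes in `Base.sample` are consistency fixes rather than behavioral ones: for 2-D `(seq_len, vocab)` logits, `argmax(dim=1)` and `argmax(dim=-1)` coincide, and on the 1-D slices `softmax(dim=0)` equals `softmax(dim=-1)`. Using `dim=-1` everywhere makes the "normalize over the vocabulary" intent explicit and keeps working if a batch dimension is ever prepended. A quick illustration:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 32)        # (seq_len, vocab)
probs = F.softmax(logits, dim=-1)  # normalize over the vocabulary axis
assert torch.allclose(probs.sum(dim=-1), torch.ones(4))
assert torch.equal(logits.argmax(dim=1), logits.argmax(dim=-1))  # identical for 2-D inputs
```
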
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 8b014a9..a327c26 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -21,9 +21,7 @@ from tqdm import trange
 from .base import Base, list_to_tensor, Categorical, _dropout_mask
 from ..config import cfg
 from ..emb.qnt import trim, repeat_extend_audio
-
-def clamp(n, lo, hi):
-	return max(lo, min(n, hi))
+from ..utils import clamp
 
 _logger = logging.getLogger(__name__)
 
@@ -46,6 +44,7 @@ class NAR(Base):
 		input_prompt_prefix: bool = False,
 		prefix_silence: float = 1.0,
+		denoise_start: float = 0.0,
 
 		sampling_temperature: float = 1.0,
 		sampling_min_temperature: float = -1.0,
@@ -245,7 +244,11 @@ class NAR(Base):
 			sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
 
 		# initial condition
-		len_list = [ clamp(1, 75*3, l) for l in len_list ]
+		"""
+		print( len_list )
+		len_list = [ clamp(1, max_steps, l) for l in len_list ]
+		print( len_list )
+		"""
 		metrics = []
 
 		mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
@@ -262,50 +265,39 @@ class NAR(Base):
 		def gumbel_sample(x, temperature = 1., dim = -1):
 			return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
 
-		test_artifact = None
-
-		# nasty hardcode to load a reference file and have that as the input target
-		# to-do: expose a way to provide the initial sequence instead through CLI
-		"""
-		if False:
-			path = "./data/00_part2_success-1.enc"
-			test_artifact = np.load(path, allow_pickle=True)[()]
-			text_list = [ torch.tensor( cfg.tokenizer.encode( test_artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device) ]
-			resps_list = [ torch.from_numpy(test_artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device) ]
-			proms_list = [ resps[:75*3, :] for resps in resps_list ]
-			#proms_list = [ resps for resps in resps_list ]
-			len_list = [ resps.shape[0] for resps in resps_list ]
-		"""
-
 		_super = super()
-		def demask_sampling( seq_len, max_steps=5, temperature=1.0 ):
-			starting_temperature = temperature
+		def demask_sampling( batch_index, seq_len ):
+			# overrides
+			max_steps = 10
+			temperature = 0.3
+			cfg_strength = 1.0
+			sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied......
+			sampling_top_p = 0.9 # a lot of demasking samplers use a top-k of seq_len * 0.9
 
-			input_ids = torch.ones((seq_len,), dtype=torch.long, device=device) * self.stop_token
+			# if we're denoising from an existing sequence
+			if denoise_start > 0.0 and resps_list is not None:
+				start_noise = denoise_start
+				noise_p = math.cos( start_noise * math.pi * 0.5 )
+				mask = torch.tensor( [ random.random() < noise_p for _ in range( seq_len ) ], dtype=torch.bool, device=device )
+				input_ids = torch.where( mask, self.stop_token, resps_list[batch_index][:, 0] )
+			else:
+				input_ids = torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token
+
 			scores = torch.zeros((seq_len,), dtype=torch.float32, device=device)
 
 			quant_levels = [ level for _ in range(batch_size) ]
 			prev_list = [ input_ids ]
+			start_temperature = temperature
 			start_noise = 0.0
 			end_noise = 1.0
-			sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied......
-
-			# use hardcoded reference file to test inference capabilities
-			if test_artifact is not None:
-				# because we "set" it later on, it's not implicitly captured
-				nonlocal resps_list
-				start_noise = 0.5
-				noise_p = math.cos( start_noise * math.pi * 0.5 )
-				input_ids = torch.tensor( [ self.stop_token if random.random() < noise_p else token for _, token in enumerate( resps_list[0][:, 0] ) ], dtype=torch.int16, device=device )
 
 			null_text = torch.tensor([1, 2], device=device, dtype=torch.int16)
 			null_prom = None
-			cfg_strength = 1.0
 
 			for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
 				# anneal temperature
-				temperature = starting_temperature * (steps_until_x0 / max_steps)
+				temperature = start_temperature * (steps_until_x0 / max_steps)
 				# get noise level, per cosine scheduling
 				noise_p = math.cos( timestep * math.pi * 0.5 )
 				# number of tokens to mask off to "noise" the input sequence
@@ -317,12 +309,11 @@ class NAR(Base):
 				# boolean mask
 				is_masked = input_ids == self.stop_token
 				# setup inputs
-				resps_list = [ input_ids ]
 				inputs = _super.inputs(
 					text_list=text_list,
 					proms_list=proms_list,
-					resps_list=resps_list,
+					resps_list=[ input_ids ],
 					lang_list=lang_list,
 					tone_list=tone_list,
 					time_list=[ timestep ],
 					quant_levels=quant_levels,
 				)
 				output = _super.forward(
 					inputs=inputs,
 					quant_levels=quant_levels,
 					layer_skip_variables=sampling_layer_skip_variables,
 				)
+
+				logits = output.logits
+
 				if cfg_strength > 0:
 					null_inputs = _super.inputs(
 						text_list=[ null_text ],
 						proms_list=[ null_prom ],
-						resps_list=resps_list,
+						resps_list=[ input_ids ],
 						lang_list=lang_list,
 						tone_list=tone_list,
 						time_list=[ timestep ],
 						quant_levels=quant_levels,
 					)
 					null_output = _super.forward(
 						inputs=null_inputs,
 						quant_levels=quant_levels,
 						layer_skip_variables=sampling_layer_skip_variables,
 					)
-					logits = [ logits + ( logits - null_logits ) * cfg_strength for logits, null_logits in zip(output.logits, null_output.logits) ]
-				else:
-					logits = output.logits
+					for logit, null_logits in zip(output.logits, null_output.logits):
+						logit[-seq_len:] = logit[-seq_len:] + ( logit[-seq_len:] - null_logits[-seq_len:] ) * cfg_strength
 
 				# sample with sampler settings
-				sampling_top_p = 0.9
 				filtered_sampled = _super.sample(
 					logits=logits,
 					prev_list=prev_list,
@@ -401,12 +393,13 @@
 				# update scores (conjugated to put the worst scores at the top)
 				scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
 
+				if cfg.experimental: print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
+
 			return input_ids
 
 		# perform demasked sampling (mock diffusion)
-		prev_list = [ demask_sampling( seq_len=l ) for l in len_list ]
+		prev_list = [ demask_sampling( batch_index=i, seq_len=l ) for i, l in enumerate( len_list ) ]
 
 		# expand if given a raw 1D tensor
 		for i, resp in enumerate(prev_list):
@@ -530,39 +523,20 @@ def example_usage():
 	import re
 
 	device = "cuda"
-
-	# mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
-	"""
-	if "mamba" in cfg.model.arch_type:
-		cfg.model.resp_levels = 1
-	"""
-
 	# cfg.model.loss_factors = {}
 
-	def tokenize(content):
-		return torch.tensor( cfg.tokenizer.encode(content) )
+	def load_artifact( path ):
+		artifact = np.load(path, allow_pickle=True)[()]
 
-	def _load_quants(path) -> Tensor:
-		qnt = np.load(path, allow_pickle=True)[()]
-		return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
+		text = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device)
+		audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device)
 
-	qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
+		return text, audio
 
-	text_list = [
-		tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
-		#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
-	]
-	proms_list = [
-		qnt[:cfg.dataset.frames_per_second, :].to(device),
-		#qnt[:cfg.dataset.frames_per_second, :].to(device),
-	]
-	resps_list = [
-		qnt[:, :].to(device),
-		#qnt[:cfg.dataset.frames_per_second, :].to(device),
-	]
+	text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
 
-	text_list = text_list[:1]
-	proms_list = proms_list[:1]
-	resps_list = resps_list[:1]
+	text_list = [ text ]
+	proms_list = [ audio[:cfg.dataset.frames_per_second, :] ]
+	resps_list = [ audio ]
 
 	# rentet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
 	kwargs = {
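
The `denoise_start` path above is the inference-time counterpart of the new CLI flag: instead of starting from an all-masked sequence, it keeps a random subset of a reference sequence's first-level tokens and masks the rest, using the same cosine schedule the demasking loop itself runs on. A standalone sketch of just that initialization (the helper name is illustrative, not from the patch):

```python
import math
import torch

def noised_input_ids(reference, denoise_start, mask_token):
	# reference: (seq_len,) tokens of the first codebook level
	# larger denoise_start -> smaller noise_p -> fewer tokens masked
	noise_p = math.cos(denoise_start * math.pi * 0.5)
	mask = torch.rand(reference.shape[0], device=reference.device) < noise_p
	return torch.where(mask, torch.full_like(reference, mask_token), reference)

# denoise_start=0.5 masks ~71% of tokens; denoise_start=0.9 masks ~16%
```

Also note the classifier-free guidance arithmetic changed from building a fresh logits list to an in-place update over the last `seq_len` positions, `logit + (logit - null_logit) * cfg_strength`, which pushes predictions away from the null-conditioned (text-free) output while leaving any prefix positions untouched.
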
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index 6927f0c..baff3ab 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -8,8 +8,7 @@ from torch import Tensor, einsum, nn
 from einops import rearrange
 from dataclasses import asdict, dataclass, field
 
-def clamp(n, lo, hi):
-	return max(lo, min(n, hi))
+from .utils import clamp
 
 # Simple filter to modify a token's probability if it shows up in the past
 # `one_time` will only apply the penalty once
diff --git a/vall_e/train.py b/vall_e/train.py
index 89a7683..57cdf42 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -147,9 +147,13 @@ def run_eval(engines, eval_name, dl, args=None):
 		elif "len" in engine.hyper_config.capabilities:
 			kwargs = base_kwargs | cfg.evaluation.ar_kwargs
 			max_steps = kwargs.pop("max_steps", 500)
-			kwargs["max_steps"] = 10
-			len_list = engine( **kwargs ) # don't need more than that
+			len_list = engine( max_steps=5, **kwargs )
 			len_list = [ min( l, max_steps ) for l in len_list ]
+
+			if True:
+				len_list = [ resp.shape[0] for resp in batch["resps"] ]
+				kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
+				kwargs["denoise_start"] = 0.5
 
 			kwargs = base_kwargs | cfg.evaluation.nar_kwargs
 			resps_list = engine( **kwargs, len_list=len_list )
diff --git a/vall_e/utils/__init__.py b/vall_e/utils/__init__.py
index bcc8c44..4c1273b 100755
--- a/vall_e/utils/__init__.py
+++ b/vall_e/utils/__init__.py
@@ -12,5 +12,6 @@ from .utils import (
 	get_devices,
 	truncate_json,
 	timer,
-	prune_missing
+	prune_missing,
+	clamp
 )
\ No newline at end of file
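
The `if True:` block in `run_eval` unconditionally switches evaluation to denoise-from-reference mode: lengths and first-level codes come from the ground-truth batch and `denoise_start` is pinned at 0.5, so this eval measures reconstruction of half-noised references rather than free synthesis. If the toggle is meant to stay, a config-driven guard would make that intent explicit; a sketch, assuming a hypothetical `cfg.evaluation` flag not present in this patch:

```python
# "denoise_from_reference" is a hypothetical flag, not part of this patch
if getattr(cfg.evaluation, "denoise_from_reference", False):
	len_list = [ resp.shape[0] for resp in batch["resps"] ]
	kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
	kwargs["denoise_start"] = 0.5
```
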
diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py
index 4831a91..db81285 100755
--- a/vall_e/utils/utils.py
+++ b/vall_e/utils/utils.py
@@ -49,6 +49,9 @@ def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, retu
 		missing += m
 	return (keep, missing) if return_missing else keep
 
+def clamp(n, lo, hi):
+	return max(lo, min(n, hi))
+
 class timer:
 	def __init__(self, msg="Elapsed time:", callback=None):
 		self.msg = msg
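
This single definition replaces the four per-module copies deleted above. One thing worth double-checking at the new call sites: with the `(n, lo, hi)` signature, calls written as `clamp(1, max_ar_steps, l)` (in `inference.py`, and in the commented-out `nar.py` line) bind `n=1, lo=max_ar_steps, hi=l`, which evaluates to `max_ar_steps` for any `l >= 1` instead of clamping `l`. A quick sanity check:

```python
def clamp(n, lo, hi):
	return max(lo, min(n, hi))

assert clamp(10, 1, 5) == 5      # clamps from above
assert clamp(-3, 1, 5) == 1      # clamps from below
assert clamp(1, 500, 42) == 500  # the clamp(1, max_ar_steps, l) pattern always yields max_ar_steps
# clamp(l, 1, max_ar_steps) is likely the intended argument order
```
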