From 811b15d2806ae0bca7cdf235663b0cfd96ad2123 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 8 Nov 2024 22:05:41 -0600
Subject: [PATCH] I suppose I just have a shit training method, since the
 sampler is as solid as I can get it

---
 vall_e/models/base.py |   2 +-
 vall_e/models/nar.py  | 165 +++++++++++++++++++++++++++---------------
 vall_e/samplers.py    |  68 +----------------
 3 files changed, 107 insertions(+), 128 deletions(-)

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 86b7dda..fbf6426 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1711,7 +1711,7 @@ class Base(nn.Module):
 		"""
 
 		# perform repetition penalizing
-		if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
+		if prev_list is not None and repetition_penalty != 1.0:
 			# to-do: figure out a faster way to handle tolist()
 			logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
 
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index f65c5c1..30a3702 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -21,7 +21,6 @@ from tqdm import trange
 from .base import Base, list_to_tensor, Categorical, _dropout_mask
 from ..config import cfg
 from ..emb.qnt import trim, repeat_extend_audio
-from ..samplers import SampleScheduler
 
 def clamp(n, lo, hi):
 	return max(lo, min(n, hi))
@@ -237,77 +236,123 @@ class NAR(Base):
 		if cfg.lora is not None:
 			enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
 
+		def log(x, eps = 1e-20):
+			return torch.log(x.clamp(min = eps))
+
+		def gumbel_sample(x, temperature = 1., dim = -1):
+			return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
+
+		test_artifact = None
+
+		"""
+		if False:
+			path = "./data/237_134500_000036_000004.enc"
+			test_artifact = np.load(path, allow_pickle=True)[()]
+			text_list = [ torch.tensor( cfg.tokenizer.encode( test_artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device) ]
+			resps_list = [ torch.from_numpy(test_artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device) ]
+			proms_list = [ resps for resps in resps_list ]
+			len_list = [ resps.shape[0] for resps in resps_list ]
+		"""
+
 		_super = super()
-		def forward_lambda( ids, step, temperature ):
+		def demask_sampling( seq_len, max_steps=10, temperature=0.3 ):
+			starting_temperature = temperature
+
+			input_ids = torch.ones((seq_len,), dtype=torch.long, device=device) * self.stop_token
+			scores = torch.zeros((seq_len,), dtype=torch.float32, device=device)
+
 			quant_levels = [ level for _ in range(batch_size) ]
-			prev_list = [ ids ]
-			seq_len = ids.shape[-1]
+			prev_list = [ input_ids ]
 
-			sampling_top_k = math.floor( seq_len * 0.9 )
+			noise_scale = 1.0
 
-			inputs = _super.inputs(
-				text_list=text_list,
-				proms_list=proms_list,
-				resps_list=prev_list,
-				lang_list=lang_list,
-				tone_list=tone_list,
-				quant_levels=quant_levels,
-			)
+			"""
+			if test_artifact is not None:
+				nonlocal resps_list
+				input = resps_list[0][:, 0]
+				noise_scale = 1.0
+				input_ids = torch.tensor( [ self.stop_token if random.random() < noise_scale else token for _, token in enumerate( input ) ], dtype=torch.int16, device=device )
+				print( input )
+				print( input_ids )
+			"""
 
-			output = _super.forward(
-				inputs=inputs,
-				quant_levels=quant_levels,
+			for timestep, steps_until_x0 in zip(torch.linspace(0, 1, max_steps), reversed(range(max_steps))):
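+				# a worked example of the schedule below, assuming seq_len=100 and max_steps=10:
+				#   first step (timestep=0.0): noise_p=1.0,  all 100 positions start masked
+				#   mid step  (timestep~0.5): noise_p~0.71, the ~70 worst-scoring positions are re-masked
+				#   last step (timestep=1.0): noise_p~0.0,  the max(..., 1) floor keeps one position masked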
+				# anneal temperature
+				temperature = starting_temperature * (steps_until_x0 / max_steps)
+				# get noise level, per cosine scheduling
+				noise_p = math.cos( timestep * math.pi * 0.5 ) * noise_scale
+				# number of tokens to mask off to "noise" the input sequence
+				masked_tokens_n = max(int( noise_p * seq_len ), 1)
+				# pick the worst scoring tokens to mask off
+				masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices
+				# mask off inputs
+				input_ids = input_ids.scatter(0, masked_indices, self.stop_token)
+				# boolean mask
+				is_masked = input_ids == self.stop_token
+				# sample
+				sampling_top_k = math.floor( seq_len * 0.9 )
+				resps_list = [ input_ids ]
+				inputs = _super.inputs(
+					text_list=text_list,
+					proms_list=proms_list,
+					resps_list=resps_list,
+					lang_list=lang_list,
+					tone_list=tone_list,
+					quant_levels=quant_levels,
+				)
+				output = _super.forward(
+					inputs=inputs,
+					quant_levels=quant_levels,
+					layer_skip_variables=sampling_layer_skip_variables,
+				)
 
-				layer_skip_variables=sampling_layer_skip_variables,
-			)
-			logits = output.logits
+				# sample with sampler settings
+				filtered_sampled = _super.sample(
+					logits=output.logits,
+					prev_list=prev_list,
+					quant_levels=quant_levels,
 
-			# sample with sampler settings
-			sampled = _super.sample(
-				logits=logits,
-				prev_list=prev_list,
-				quant_levels=quant_levels,
+					temperature=temperature,
+					min_temperature=sampling_min_temperature,
+					top_p=sampling_top_p,
+					top_k=sampling_top_k,
+					min_p=sampling_min_p,
+					repetition_penalty=sampling_repetition_penalty,
+					repetition_penalty_decay=sampling_repetition_penalty_decay,
+					length_penalty=sampling_length_penalty,
+				)
 
-				temperature=temperature,
-				min_temperature=sampling_min_temperature,
-				top_p=sampling_top_p,
-				top_k=sampling_top_k,
-				min_p=sampling_min_p,
-				repetition_penalty=sampling_repetition_penalty,
-				repetition_penalty_decay=sampling_repetition_penalty_decay,
-				length_penalty=sampling_length_penalty,
-				#beam_width=sampling_beam_width,
-				#mirostat=mirostat,
-			)
+				# retrieve the unfiltered logits with a greedy (temperature 0) pass
+				unfiltered_sampled = _super.sample(
+					logits=output.logits,
+					prev_list=prev_list,
+					quant_levels=quant_levels,
+					temperature=0.0,
+				)
+				# update previous list of tokens
+				prev_list = [ input_ids ]
 
-			# greedy sample
-			greedy_sampled = _super.sample(
-				logits=logits,
-				prev_list=prev_list,
-				quant_levels=quant_levels,
+				# extract logits
+				filtered_logits = filtered_sampled.logits[0]
+				unfiltered_logits = unfiltered_sampled.logits[0]
 
-				temperature=0.0,
-				#min_temperature=sampling_min_temperature,
-				#top_p=sampling_top_p,
-				#top_k=sampling_top_k,
-				#min_p=sampling_min_p,
-				#repetition_penalty=sampling_repetition_penalty,
-				#repetition_penalty_decay=sampling_repetition_penalty_decay,
-				#length_penalty=sampling_length_penalty,
-				#beam_width=sampling_beam_width,
-				#mirostat=mirostat,
-			)
+				# extract scores
+				filtered_scores = filtered_sampled.scores[0]
+				unfiltered_scores = unfiltered_sampled.scores[0]
 
-			return sampled, greedy_sampled
+				# sample with gumbel noise
+				sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
+				# keep unmasked tokens
+				input_ids = torch.where( is_masked, sampled_ids, input_ids )
+				# update scores (inverted, so the worst scores sort to the top)
+				scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
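+				# the greedy pass's scores act as per-token confidences, so flipping them to
+				# 1 - score makes the topk() above re-mask the least confident positions next step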
 
-		scheduler = SampleScheduler(
-			device=device,
-			mask_token=self.stop_token,
-			max_steps=5,
-			forward_lambda=forward_lambda,
-			sampling_temperature=sampling_temperature,
-		)
-		prev_list = [ scheduler.sample( seq_len=len_list[0] ) ]
+				# print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
+
+			return input_ids
+
+		# perform demasked sampling (mock diffusion)
+		prev_list = [ demask_sampling( seq_len=l ) for l in len_list ]
 
 		# expand if given a raw 1D tensor
 		for i, resp in enumerate(prev_list):
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index 16ca87e..be3539f 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -520,70 +520,4 @@ def sample_entropix(
 		metrics["min_p"] = min_p
 	"""
 
-	return res, metrics
-
-"""
-def add_gumbel_noise(t, temperature, device):
-	return (t + torch.Tensor(temperature * np.random.gumbel(size=t.shape)).to(device))
-"""
-
-def log(t, eps = 1e-20):
-	return torch.log(t.clamp(min = eps))
-
-def gumbel_noise(t):
-	noise = torch.zeros_like(t).uniform_(0, 1)
-	return -log(-log(noise))
-
-def gumbel_sample(t, temperature = 1., dim = -1):
-	return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
-
-# this provides mostly poor output, but it might just be a matter of how I'm naively training the model for """diffusion"""
-class SampleScheduler:
-	def __init__(
-		self,
-		forward_lambda = None,
-		mask_token = -1,
-		max_steps = 25,
-		device = "cuda",
-		sampling_temperature=1.0,
-	):
-		self.forward_lambda = forward_lambda
-		self.max_steps = max_steps
-		self.mask_token = mask_token
-		self.device = device
-
-	def sample( self, seq_len ):
-		starting_temperature = 0.2
-
-		input_ids = torch.ones((seq_len,), dtype=torch.long, device=self.device) * self.mask_token
-		scores = torch.zeros((seq_len,), dtype=torch.float32, device=self.device)
-
-		for timestep, steps_until_x0 in zip(torch.linspace(0, 1, self.max_steps), reversed(range(self.max_steps))):
-			# anneal temperature
-			temperature = starting_temperature * (steps_until_x0 / self.max_steps)
-			# get noise level, per cosine scheduling
-			noise_p = math.cos( timestep * math.pi * 0.5 )
-			# number of tokens to mask off to "noise" the input sequence
-			masked_tokens_n = max(int( noise_p * seq_len ), 1)
-			# pick the worst scoring tokens to mask off
-			masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices
-			# mask off inputs
-			input_ids = input_ids.scatter(0, masked_indices, self.mask_token)
-			# boolean mask
-			is_masked = input_ids == self.mask_token
-			# sample
-			sampled, greedy_sampled = self.forward_lambda( input_ids, step=timestep, temperature=temperature )
-			# extract logits
-			logits = greedy_sampled.logits[0]
-			filtered_logits = sampled.logits[0]
-
-			# sample with gumbelnoise
-			sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
-			# keep unmasked tokens
-			input_ids = torch.where( is_masked, sampled_ids, input_ids )
-			# update scores (conjugated to put the worst scores at the top)
-			scores = 1.0 - torch.concat([ F.softmax(logits[i, :], dim=0)[token, None] for i, token in enumerate(input_ids) ])
-
-			# print( timestep, steps_until_x0, noise_p, masked_tokens_n, temperature, input_ids, scores )
-
-		return input_ids
+	return res, metrics
\ No newline at end of file
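
For reference, the demasking loop this patch inlines into nar.py reduces to roughly the following standalone sketch. Here forward_fn is a hypothetical stand-in for the model forward pass (it is not part of this repo's API) and the schedule constants mirror the defaults above; treat this as a minimal sketch under those assumptions, not the implementation:

	import math
	import torch

	def _log(x, eps=1e-20):
		return torch.log(x.clamp(min=eps))

	def demask_sample(forward_fn, seq_len, mask_token, max_steps=10, starting_temperature=0.3, device="cpu"):
		# start fully masked, with zeroed confidences
		input_ids = torch.full((seq_len,), mask_token, dtype=torch.long, device=device)
		scores = torch.zeros((seq_len,), dtype=torch.float32, device=device)

		for timestep, steps_until_x0 in zip(torch.linspace(0, 1, max_steps), reversed(range(max_steps))):
			# anneal temperature towards 0 as the sequence resolves
			temperature = starting_temperature * (steps_until_x0 / max_steps)
			# cosine noise schedule: mask everything at t=0, almost nothing at t=1
			noise_p = math.cos(timestep * math.pi * 0.5)
			masked_tokens_n = max(int(noise_p * seq_len), 1)
			# re-mask the least confident positions (scores hold 1 - confidence)
			masked_indices = scores.topk(masked_tokens_n, dim=-1).indices
			input_ids = input_ids.scatter(0, masked_indices, mask_token)
			is_masked = input_ids == mask_token

			logits = forward_fn(input_ids)	# (seq_len, vocab_size)
			# gumbel sampling: argmax over logits/T plus gumbel noise
			gumbel = -_log(-_log(torch.rand_like(logits)))
			sampled_ids = (logits / max(temperature, 1e-10) + gumbel).argmax(dim=-1)
			# only overwrite the positions that are currently masked
			input_ids = torch.where(is_masked, sampled_ids, input_ids)
			# flipped softmax confidence of each kept token, so topk() finds the worst
			probs = logits.softmax(dim=-1)
			scores = 1.0 - probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)

		return input_ids

	# e.g. with a dummy "model":
	# out = demask_sample(lambda ids: torch.randn(ids.shape[0], 1024), seq_len=75, mask_token=1023)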