From 13b54953bd9768c3ded7ff94fa6786cefb023c22 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 8 Nov 2024 13:34:39 -0600
Subject: [PATCH] agony

---
 vall_e/models/base.py |  6 +--
 vall_e/models/nar.py  | 29 +++++++++++---
 vall_e/samplers.py    | 92 ++++++++++++++-----------------------------
 3 files changed, 56 insertions(+), 71 deletions(-)

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 88e6784..86b7dda 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -40,7 +40,7 @@ from ..data import get_task_symmap
 
 # these seem more elegant than a dict
 Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states', 'exited_layer'])
-Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy'])
+Sampled = namedtuple('Sampled', ['out', 'logits', 'scores', 'entropy'])
 LossStats = namedtuple('LossStats', ['loss', 'stats'])
 
 """
@@ -1681,7 +1681,7 @@ class Base(nn.Module):
 			) for batch, logit in enumerate(logits) ]
 
 		if res:
-			return Sampled([ r[0] for r in res ], scores, [ r[1] for r in res ])
+			return Sampled([ r[0] for r in res ], logits, scores, [ r[1] for r in res ])
 		"""
 	elif quant_levels is None:
 		seq_lens = [ logit.shape[0] for logit in logits ]
@@ -1772,4 +1772,4 @@ class Base(nn.Module):
 			for logit, tokens in zip(logits, res)
 		]
 
-		return Sampled(res, scores, entropy)
\ No newline at end of file
+		return Sampled(res, logits, scores, entropy)
\ No newline at end of file
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index ef26e3d..f65c5c1 100644
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -226,7 +226,7 @@ class NAR(Base):
 				sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
 
 			# initial condition
-			len_list = [ min(l, 500) for l in len_list ]
+			len_list = [ min(l, 75*3) for l in len_list ]
 			metrics = []
 
 			mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
@@ -240,9 +240,11 @@ class NAR(Base):
 			_super = super()
 			def forward_lambda( ids, step, temperature ):
 				quant_levels = [ level for _ in range(batch_size) ]
-				prev_list = [ ids[0] ]
+				prev_list = [ ids ]
 				seq_len = ids.shape[-1]
 
+				sampling_top_k = math.floor( seq_len * 0.9 )
+
 				inputs = _super.inputs(
 					text_list=text_list,
 					proms_list=proms_list,
@@ -260,6 +262,7 @@ class NAR(Base):
 				)
 				logits = output.logits
 
+				# sample with sampler settings
 				sampled = _super.sample(
 					logits=logits,
 					prev_list=prev_list,
@@ -277,14 +280,30 @@ class NAR(Base):
 					#mirostat=mirostat,
 				)
 
-				ids = sampled[0]
+				# greedy sample
+				greedy_sampled = _super.sample(
+					logits=logits,
+					prev_list=prev_list,
+					quant_levels=quant_levels,
 
-				return logits[0][-seq_len:].unsqueeze(0), ids[0].unsqueeze(0)
+					temperature=0.0,
+					#min_temperature=sampling_min_temperature,
+					#top_p=sampling_top_p,
+					#top_k=sampling_top_k,
+					#min_p=sampling_min_p,
+					#repetition_penalty=sampling_repetition_penalty,
+					#repetition_penalty_decay=sampling_repetition_penalty_decay,
+					#length_penalty=sampling_length_penalty,
+					#beam_width=sampling_beam_width,
+					#mirostat=mirostat,
+				)
+
+				return sampled, greedy_sampled
 
 			scheduler = SampleScheduler(
 				device=device,
 				mask_token=self.stop_token,
-				max_steps=30,
+				max_steps=5,
 				forward_lambda=forward_lambda,
 				sampling_temperature=sampling_temperature,
 			)
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index 697425b..16ca87e 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -537,13 +537,6 @@ def gumbel_noise(t):
 def gumbel_sample(t, temperature = 1., dim = -1):
 	return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
 
-def top_k(logits, thres = 0.9):
-	k = math.ceil((1 - thres) * logits.shape[-1])
-	val, ind = logits.topk(k, dim = -1)
-	probs = torch.full_like(logits, float('-inf'))
-	probs.scatter_(2, ind, val)
-	return probs
-
 # this provides mostly poor output, but it might just be a matter of how I'm naively training the model for """diffusion"""
 class SampleScheduler:
 	def __init__(
@@ -558,66 +551,39 @@ class SampleScheduler:
 		self.max_steps = max_steps
 		self.mask_token = mask_token
 		self.device = device
-
-		"""
-		self.ratios = (np.cos(np.linspace(0, math.pi / 2, self.max_steps + 1)))[1:-1]
-		self.annealed_temperatures = (1 - np.linspace(0, 1, self.max_steps + 1))[:-2]
-		self.sampling_temperatures = [sampling_temperature for _ in range(self.max_steps)]
-		"""
-
 	# lifted from https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/muse_maskgit_pytorch.py#L493
 	def sample( self, seq_len ):
-		ids = torch.full((1, seq_len), self.mask_token, dtype = torch.long, device = self.device)
-		scores = torch.zeros((1, seq_len), dtype = torch.float32, device = self.device)
+		starting_temperature = 0.2
 
-		for step in range( self.max_steps ):
-			t = step / self.max_steps
-			mask_ratio = math.cos(t * math.pi * 0.5)
-			sampling_temperature = 1.0
-			annealed_temperature = sampling_temperature * (1.0 - t)
+		input_ids = torch.ones((seq_len,), dtype=torch.long, device=self.device) * self.mask_token
+		scores = torch.zeros((seq_len,), dtype=torch.float32, device=self.device)
 
-			num_token_masked = max(int(mask_ratio * seq_len), 1)
-			masked_indices = scores.topk(num_token_masked, dim = -1).indices
+		for timestep, steps_until_x0 in zip(torch.linspace(0, 1, self.max_steps), reversed(range(self.max_steps))):
+			# anneal temperature
+			temperature = starting_temperature * (steps_until_x0 / self.max_steps)
+			# get noise level, per cosine scheduling
+			noise_p = math.cos( timestep * math.pi * 0.5 )
+			# number of tokens to mask off to "noise" the input sequence
+			masked_tokens_n = max(int( noise_p * seq_len ), 1)
+			# pick the worst scoring tokens to mask off
+			masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices
+			# mask off inputs
+			input_ids = input_ids.scatter(0, masked_indices, self.mask_token)
+			# boolean mask
+			is_masked = input_ids == self.mask_token
+			# sample
+			sampled, greedy_sampled = self.forward_lambda( input_ids, step=timestep, temperature=temperature )
+			# extract logits
+			logits = greedy_sampled.logits[0]
+			filtered_logits = sampled.logits[0]
 
-			ids = ids.scatter(1, masked_indices, self.mask_token)
+			# sample with gumbel noise
+			sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
+			# keep unmasked tokens
+			input_ids = torch.where( is_masked, sampled_ids, input_ids )
+			# update scores (conjugated to put the worst scores at the top)
+			scores = 1.0 - torch.concat([ F.softmax(logits[i, :], dim=0)[token, None] for i, token in enumerate(input_ids) ])
 
-			logits, _ = self.forward_lambda( ids, step=step, temperature=annealed_temperature )
-			filtered_logits = top_k( logits )
-			sampled_ids = gumbel_sample( filtered_logits, temperature=annealed_temperature, dim=-1 )
+			# print( timestep, steps_until_x0, noise_p, masked_tokens_n, temperature, input_ids, scores )
 
-			is_masked = ids == self.mask_token
-			ids = torch.where( is_masked, sampled_ids, ids )
-
-			probs_without_temperature = logits.softmax(dim = -1)
-
-			scores = 1 - probs_without_temperature.gather(2, sampled_ids[..., None])
-			scores = rearrange(scores, '... 1 -> ...')
-			#scores = scores.to(dtype=torch.float64).masked_fill(~is_masked, -1e5)
-
-			"""
-			if step + 1 == self.max_steps:
-				break
-
-			# lifted from https://github.com/LeapLabTHU/ImprovedNAT/blob/main/libs/nat_misc.py#L39
-			# create next input sequence
-			mask = (ids == self.mask_token)
-			mask_len = torch.Tensor([np.floor(seq_len * mask_ratio)]).to(self.device)
-			mask_len = torch.maximum(
-				torch.Tensor([1]).to(self.device),
-				torch.minimum( torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len )
-			)[0].squeeze()
-
-			logits = torch.log_softmax(logits, dim=-1)
-			sampled_logits = torch.squeeze(torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
-			sampled_ids = torch.where(mask, sampled_ids, ids)
-			sampled_logits = torch.where(mask, sampled_logits, +np.inf).float()
-
-			confidence = add_gumbel_noise(sampled_logits, annealed_temperature, self.device)
-			sorted_confidence, _ = torch.sort(confidence, axis=-1)
-			cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
-			masking = (confidence <= cut_off)
-
-			ids = torch.where(masking, self.mask_token, sampled_ids)
-			"""
-
-		return sampled_ids[0]
\ No newline at end of file
+		return input_ids
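
Notes:

The Sampled namedtuple in base.py gains a logits field so that callers (here, the
new SampleScheduler loop) can reuse the per-step logits instead of re-running the
model. Positional consumers shift by one: sampled[0] is still out, but anything
indexing scores or entropy numerically needs the new order, which attribute
access sidesteps. A sketch (variable names illustrative):

	out, logits, scores, entropy = sampled  # new positional order
	ids = sampled.out                       # attribute access is index-proof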
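The samplers.py rewrite turns SampleScheduler.sample into a MaskGIT-style
iterative demasking loop: start fully masked, then at each step re-mask the
least-confident tokens per a cosine schedule, re-draw them with gumbel noise,
and keep the rest. Below is a minimal, self-contained sketch of that loop under
simplifying assumptions: a stand-in forward() returns raw logits, and the
commit's two logit streams (sampler-processed logits for the draw, greedy
logits for confidence) are collapsed into one; demask_sample and vocab_size are
illustrative names, not repo API:

	import math
	import torch
	import torch.nn.functional as F

	def gumbel_noise(t):
		# mirrors vall_e.samplers.gumbel_noise (eps clamping omitted)
		return -torch.log(-torch.log(torch.zeros_like(t).uniform_(0, 1)))

	def demask_sample(forward, seq_len, mask_token, max_steps=5, starting_temperature=0.2, device="cpu"):
		# start fully masked, with all confidences tied at zero
		input_ids = torch.full((seq_len,), mask_token, dtype=torch.long, device=device)
		scores = torch.zeros((seq_len,), dtype=torch.float32, device=device)

		for timestep, steps_until_x0 in zip(torch.linspace(0, 1, max_steps), reversed(range(max_steps))):
			# temperature anneals linearly, reaching zero (greedy) on the last step
			temperature = starting_temperature * (steps_until_x0 / max_steps)
			# cosine noise schedule: mask everything at t=0, almost nothing at t=1
			noise_p = math.cos(timestep * math.pi * 0.5)
			masked_tokens_n = max(int(noise_p * seq_len), 1)
			# re-mask the least confident positions (highest scores)
			masked_indices = scores.topk(masked_tokens_n, dim=-1).indices
			input_ids = input_ids.scatter(0, masked_indices, mask_token)
			is_masked = input_ids == mask_token

			logits = forward(input_ids)  # [seq_len, vocab_size]
			# gumbel draw, as in vall_e.samplers.gumbel_sample
			sampled_ids = ((logits / max(temperature, 1e-10)) + gumbel_noise(logits)).argmax(dim=-1)
			# only masked positions get overwritten
			input_ids = torch.where(is_masked, sampled_ids, input_ids)
			# low probability => high score => first to be re-masked next step
			scores = 1.0 - F.softmax(logits, dim=-1).gather(-1, input_ids[:, None]).squeeze(-1)

		return input_ids

	# toy usage: an untrained "model" that just emits random logits
	vocab_size, mask_token = 1025, 1024
	out = demask_sample(lambda ids: torch.randn(ids.shape[-1], vocab_size), seq_len=225, mask_token=mask_token)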
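With max_steps=5 and the new 75*3 = 225 length cap in nar.py, the cosine
schedule works out as below (seq_len=225 is an assumption; actual lengths vary
per utterance):

	import math
	seq_len, max_steps, starting_temperature = 225, 5, 0.2
	for step in range(max_steps):
		timestep = step / (max_steps - 1)  # matches torch.linspace(0, 1, 5)
		noise_p = math.cos(timestep * math.pi * 0.5)
		masked_n = max(int(noise_p * seq_len), 1)
		temperature = starting_temperature * ((max_steps - 1 - step) / max_steps)
		print(step, round(noise_p, 4), masked_n, round(temperature, 3))
	# 0 1.0    225 0.16   -> everything masked, warmest draw
	# 1 0.9239 207 0.12
	# 2 0.7071 159 0.08
	# 3 0.3827  86 0.04
	# 4 0.0      1 0.0    -> one token re-drawn, effectively greedy

Most of the demasking happens in the middle steps; the final pass is a
near-greedy touch-up of a single low-confidence token.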
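Finally, the per-token list comprehension in the new scores update admits a
vectorized equivalent via gather, assuming logits has shape [seq_len, vocab] as
greedy_sampled.logits[0] suggests; a quick check under those assumed shapes:

	import torch
	import torch.nn.functional as F
	seq_len, vocab = 225, 1025
	logits = torch.randn(seq_len, vocab)
	input_ids = torch.randint(0, vocab, (seq_len,))
	# the commit's form: one softmax per position, probability of the kept token
	a = 1.0 - torch.concat([ F.softmax(logits[i, :], dim=0)[token, None] for i, token in enumerate(input_ids) ])
	# vectorized form: one softmax, then gather the kept token's probability
	b = 1.0 - F.softmax(logits, dim=-1).gather(-1, input_ids[:, None]).squeeze(-1)
	assert torch.allclose(a, b)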