diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py index 2b65e33..46d0661 100755 --- a/vall_e/models/ar.py +++ b/vall_e/models/ar.py @@ -121,12 +121,14 @@ class AR(Base): stopped = torch.zeros(batch_size, device=device).bool() state = {} if cfg.inference.recurrent_forward else None + sampling_beam_width_use_logs = True + scores = [ 1.0 ] * sampling_beam_width if self.interleave: max_steps *= self.n_prom_levels + # get next in sequence for n in trange(max_steps // max(1, self.recurrent_chunk_size)): - # get next in sequence logits = super().forward( text_list=text_list, @@ -149,12 +151,23 @@ class AR(Base): beam_width=sampling_beam_width, ) - # first step, expand batch # we do it here because the sampler will already expand our logits list - if sampling_beam_width > 0 and batch_size == 1: - text_list = text_list * sampling_beam_width - proms_list = proms_list * sampling_beam_width - resps_list = resps_list * sampling_beam_width + if sampling_beam_width > 0: + # expand tuple + r, s = r + # first step, expand batch + if batch_size == 1: + batch_size *= sampling_beam_width + text_list = text_list * sampling_beam_width + proms_list = proms_list * sampling_beam_width + sequence_list = sequence_list * sampling_beam_width + stopped = torch.zeros(batch_size, device=device).bool() + + # update scores + if sampling_beam_width_use_logs: + scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ] + else: + scores = [ scores[i] * score for i, score in enumerate(s) ] # append tokens for i, ri in enumerate(r): @@ -168,6 +181,16 @@ class AR(Base): if stopped.all().item(): break + # pick the best scoring candidate + # desu this is always going to be candidate 0 + if sampling_beam_width and len(scores) > 0: + best_idx, best_score = (0, 0) + for idx, score in enumerate(scores): + if best_score > score: + best_idx, best_score = idx, score + + sequence_list = [sequence_list[best_idx]] + res = [self._prune(r) for r in resps_list] if self.interleave: res = [self._deinterleave(r) for r in res] diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 730b3d8..8f3eba9 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -5,6 +5,7 @@ import torch from torch.nn.utils.rnn import pad_sequence import random +import math from einops import rearrange from torch import Tensor from tqdm import trange @@ -151,12 +152,14 @@ class AR_NAR(Base): state = {} if cfg.inference.recurrent_forward else None + sampling_beam_width_use_logs = True + scores = [ 1.0 ] * sampling_beam_width + if self.interleave: max_steps *= self.n_prom_levels + # get next in sequence for n in trange(max_steps // max(1, self.recurrent_chunk_size)): - # get next in sequence - resps_list = self._unsqueeze_list(sequence_list) logits = super().forward( text_list=text_list, @@ -179,14 +182,23 @@ class AR_NAR(Base): beam_width=sampling_beam_width, ) - # first step, expand batch # we do it here because the sampler will already expand our logits list - if sampling_beam_width > 0 and batch_size == 1: - batch_size *= sampling_beam_width - text_list = text_list * sampling_beam_width - proms_list = proms_list * sampling_beam_width - sequence_list = sequence_list * sampling_beam_width - stopped = torch.zeros(batch_size, device=device).bool() + if sampling_beam_width > 0: + # expand tuple + r, s = r + # first step, expand batch + if batch_size == 1: + batch_size *= sampling_beam_width + text_list = text_list * sampling_beam_width + proms_list = proms_list * sampling_beam_width + sequence_list = sequence_list * sampling_beam_width + stopped = torch.zeros(batch_size, device=device).bool() + + # update scores + if sampling_beam_width_use_logs: + scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ] + else: + scores = [ scores[i] * score for i, score in enumerate(s) ] # append tokens for i, ri in enumerate(r): @@ -199,9 +211,15 @@ class AR_NAR(Base): if stopped.all().item(): break - # pick the first candidate - if sampling_beam_width: - sequence_list = sequence_list[:1] + # pick the best scoring candidate + # desu this is always going to be candidate 0 + if sampling_beam_width and len(scores) > 0: + best_idx, best_score = (0, 0) + for idx, score in enumerate(scores): + if best_score > score: + best_idx, best_score = idx, score + + sequence_list = [sequence_list[best_idx]] return [self._prune(r) for r in sequence_list] diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 5f1022f..f8aacf6 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -119,6 +119,22 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf" return logits +# picks the top K tokens amongst a batch of logits +# logits: [Tensor] list of logits +# candidates: [(batch, token)] list, where batch indicates the index of the logits the given token is from +def top_k_logits_list( logits_list, k ): + # ( batch, tokens ) => ( batch x tokens ) + logits = torch.cat( logits_list ) + candidates = list(torch.topk(logits.flatten(), k).indices.tolist()) # perform top-k across all logits + for i, index in enumerate(candidates): + t = [] + N = np.prod(logits.size()) + for n in logits.size(): + N //= n + t.append(index // N) + index %= N + candidates[i] = tuple(t) + return candidates # automagically parses a batch-list and returns it as a list class Embedding(nn.Embedding): @@ -128,7 +144,7 @@ class Embedding(nn.Embedding): return super().forward(torch.cat(x_list)).split([*map(len, x_list)]) -class MultiEmbedding(nn.Embedding): +class MultiEmbedding(nn.Module): """ This embedding sums embeddings on different levels. """ @@ -468,21 +484,13 @@ class Base(nn.Module): # do beam search (naive implementation) # picks the top-k across all batches, and re-batches those resultant tokens - # this doesn't do any other mumbo with previous logits + # returns the logit scores as well to be P-concatted with the previous scores # to-do: not naively implement beam searching if beam_width > 1: - # ( batch, tokens ) => ( batch x tokens ) - flattened = torch.cat( logits ) - candidates = list(torch.topk(flattened.flatten(), beam_width).indices.tolist()) # perform top-k across all logits - for i, index in enumerate(candidates): - t = [] - N = np.prod(flattened.size()) - for n in flattened.size(): - N //= n - t.append(index // N) - index %= N - candidates[i] = tuple(t) - return [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ] #, [ logits[batch] for batch, token in candidates ] + candidates = top_k_logits_list( logits, beam_width ) + res = [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ] + scores = [ logits[batch].flatten()[token] for batch, token in candidates ] + return res, scores # and sample # the original implementation used this instead of argmax; it's probably placebo but it performs better than argmax diff --git a/vall_e/train.py b/vall_e/train.py index f9f503d..289b925 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -152,10 +152,10 @@ def run_eval(engines, disabled_engines, eval_name, dl): stats = {k: sum(v) / len(v) for k, v in stats.items()} - engines_stats.update({ f'{name}.{eval_name}': stats }) - - iteration = engines.global_step - engines_stats['it'] = iteration + engines_stats = { + f'{name}.{eval_name}': stats, + "it": engines.global_step, + } #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl) _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")