diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index e1cd269..fab6a28 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -17,11 +17,17 @@ def main():
 	parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
 	parser.add_argument("--ar-temp", type=float, default=1.0)
 	parser.add_argument("--nar-temp", type=float, default=1.0)
+
+	parser.add_argument("--top-p", type=float, default=1.0)
+	parser.add_argument("--top-k", type=int, default=0)
+	parser.add_argument("--repetition-penalty", type=float, default=1.0)
+	parser.add_argument("--length-penalty", type=float, default=0.0)
+
 	parser.add_argument("--device", default="cuda")
 	args = parser.parse_args()

 	tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
-	tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp )
+	tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, length_penalty=args.length_penalty )

 if __name__ == "__main__":
 	main()
diff --git a/vall_e/inference.py b/vall_e/inference.py
index fb10d8c..7d20ea1 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -125,9 +125,9 @@ class TTS():
 		return res

 	@torch.inference_mode()
-	def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path=None ):
+	def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=0.95, nar_temp=0.5, top_p=1.0, top_k=0, repetition_penalty=1.0, length_penalty=0.0, out_path=None ):
 		if out_path is None:
-			out_path = f"./data/{text}.wav"
+			out_path = f"./data/{cfg.start_time}.wav"

 		prom = self.encode_audio( references )
 		phns = self.encode_text( text )
@@ -136,9 +136,9 @@ class TTS():
 		phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)

 		with torch.autocast(self.device, dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
-			resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
+			resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_length_penalty=length_penalty)
 			resps_list = [r.unsqueeze(-1) for r in resps_list]
-			resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp)
+			resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_length_penalty=length_penalty)

 		wav, sr = qnt.decode_to_file(resps_list[0], out_path)

diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py
index 850909c..95adbea 100755
--- a/vall_e/models/ar.py
+++ b/vall_e/models/ar.py
@@ -86,6 +86,10 @@ class AR(Base):
 		resps_list: list[Tensor] | None = None,
 		max_steps: int = 1000,
 		sampling_temperature: float = 1.0,
+		sampling_top_k: int = -100,
+		sampling_top_p: float = 1.0,
+		sampling_repetition_penalty: float = 1.0,
+		sampling_length_penalty: float = 0.0,
 	):
 		if resps_list is not None:
 			if self.interleave:
@@ -121,6 +125,10 @@ class AR(Base):
 				resps_list=self._unsqueeze_list(resps_list),
 				quant_levels=None,
 				sampling_temperature=sampling_temperature,
+				sampling_top_p=sampling_top_p,
+				sampling_top_k=sampling_top_k,
+				sampling_repetition_penalty=sampling_repetition_penalty,
+				sampling_length_penalty=sampling_length_penalty,
 				state=state
 			)

@@ -193,9 +201,9 @@ def example_usage():
 	"""

 	model = AR(**kwargs).to(device)
+	steps = 500
 	optimizer = ml.Prodigy(model.parameters(), lr=1.0)
 	engine = Engine(model=model, optimizer=optimizer)
-	steps = 500

 	def sample( name, steps=600 ):
 		engine.eval()
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 8371950..3caa115 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -71,6 +71,10 @@ class AR_NAR(Base):
 		resps_list: list[Tensor] | None = None,
 		max_steps: int = 1000,
 		sampling_temperature: float = 0.0,
+		sampling_top_k: int = -100,
+		sampling_top_p: float = 1.0,
+		sampling_repetition_penalty: float = 1.0,
+		sampling_length_penalty: float = 0.0,
 	):
 		device = text_list[0].device
 		batch_size = len(text_list)
@@ -120,6 +124,10 @@ class AR_NAR(Base):
 				prev_list,
 				quant_levels=quant_levels,
 				sampling_temperature=sampling_temperature,
+				sampling_top_p=sampling_top_p,
+				sampling_top_k=sampling_top_k,
+				sampling_repetition_penalty=sampling_repetition_penalty,
+				sampling_length_penalty=sampling_length_penalty,
 			)

 			prev_list = [
@@ -146,6 +154,10 @@ class AR_NAR(Base):
 				proms_list,
 				self._unsqueeze_list(resps_list),
 				sampling_temperature=sampling_temperature,
+				sampling_top_p=sampling_top_p,
+				sampling_top_k=sampling_top_k,
+				sampling_repetition_penalty=sampling_repetition_penalty,
+				sampling_length_penalty=sampling_length_penalty,
 				state=state
 			)

@@ -221,6 +233,7 @@ def example_usage():

 	print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

+	@torch.inference_mode()
 	def sample( name, steps=600 ):
 		engine.eval()
 		resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
@@ -245,7 +258,7 @@ def example_usage():
 			tqdm.write(f"{stats}")

-	#sample("init", 75)
+	sample("init", 75)

 	train()
 	sample("final")
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 0d7c6df..b40f528 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -49,6 +49,63 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
 		m = m.to(x)
 	return x, m

+# Simple filter to modify a token's probability if it shows up in the past
+# To-do: have its effect decay based on distance
+def reptition_penalize( logits, previous, factor=1.0 ):
+	if factor == 1.0:
+		return logits
+
+	priors = set(previous.tolist())
+
+	for token in priors:
+		logits[:, token] /= factor
+
+	return logits
+
+# Simple "filter" that modifies the logit for the stop token, based on the sequence length
+def length_penalize( logits, length, factor=0.0, token=-1 ):
+	if factor == 0.0:
+		return logits
+
+	logits[:, token] /= (length ** factor)
+	return logits
+
+# Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ):
+	"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+	Args:
+		logits: logits distribution shape (batch size, vocabulary size)
+		if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+		if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+			Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+		Make sure we keep at least min_tokens per batch example in the output
+	"""
+	if top_k > 0:
+		top_k = min(max(top_k, min_tokens), logits.size(-1))  # Safety check
+		# Remove all tokens with a probability less than the last token of the top-k
+		indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+		logits[indices_to_remove] = filter_value
+
+	if top_p < 1.0:
+		sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+		cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+		# Remove tokens with cumulative probability above the threshold (token with 0 are kept)
+		sorted_indices_to_remove = cumulative_probs > top_p
+		if min_tokens > 1:
+			# Keep at least min_tokens (set to min_tokens-1 because we add the first one below)
+			sorted_indices_to_remove[..., :min_tokens] = 0
+		# Shift the indices to the right to keep also the first token above the threshold
+		sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+		sorted_indices_to_remove[..., 0] = 0
+
+		# scatter sorted tensors to original indexing
+		indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+		logits[indices_to_remove] = filter_value
+	return logits
+
+
+# automagically parses a batch-list and returns it as a list
 class Embedding(nn.Embedding):
 	def forward(self, x_list: list[Tensor]) -> list[Tensor]:
 		if len(x_list) == 0:
@@ -56,7 +113,6 @@ class Embedding(nn.Embedding):

 		return super().forward(torch.cat(x_list)).split([*map(len, x_list)])

-
 class MultiEmbedding(nn.Embedding):
 	"""
 	This embedding sums embeddings on different levels.
@@ -257,30 +313,6 @@ class Base(nn.Module):
 			ignore_index=self.ignore_index,
 		)

-	@overload
-	def forward(
-		self,
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
-		resps_list: list[Tensor],
-		targ_list: list[Tensor] | None = None,
-		quant_levels: Tensor | None = None,
-		sampling_temperature: float = 1.0,
-	) -> Tensor:
-		...
-
-	@overload
-	def forward(
-		self,
-		text_list: list[Tensor],
-		proms_list: list[Tensor],
-		resps_list: list[Tensor],
-		targ_list: list[Tensor] | None = None,
-		quant_levels: Tensor | None = None,
-		sampling_temperature: float = 1.0,
-	) -> list[Tensor]:
-		...
-
 	def forward(
 		self,
 		text_list: list[Tensor],
@@ -290,6 +322,10 @@ class Base(nn.Module):
 		quant_levels: Tensor | None = None,
 		sampling_temperature: float = 1.0,
+		sampling_top_k: int = -100,
+		sampling_top_p: float = 1.0,
+		sampling_repetition_penalty: float = 1.0,
+		sampling_length_penalty: float = 0.0,
 		state: dict | None = None,
 	):

@@ -330,13 +366,10 @@ class Base(nn.Module):
 		x = self.classifier(x) * m

 		# Remove padding
-		h_list = [hi[:li] for hi, li in zip(x, map(len, x_list))]
+		logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]

 		# compute loss if the target is given
 		if targ_list is not None:
-			if any([l == 0 for l in map(len, targ_list)]):
-				raise ValueError("Cannot compute loss given empty targ_list.")
-
 			ignore_sep = torch.tensor(self.ignore_index, device=device)

 			# create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not neeeded for computing the loss against
@@ -365,36 +398,49 @@ class Base(nn.Module):
 				targ_list[i][-1] = self.stop_token

 			# create the new target sequence to compute the loss against
-			y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )
+			target = torch.cat( self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep ) )
+			inputs = torch.cat( logits )

 			self.loss = dict(
 				# "nll" was in the original implementation and should actually just be called something else
-				nll=F.cross_entropy(
-					torch.cat(h_list), # input / predicted logits
-					torch.cat(y_list), # target / ground truth
-					ignore_index=self.ignore_index,
-				)
+				nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
 			)

 			self.stats = dict(
-				acc = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) ),
-				precision = self.precision_metric( torch.cat(h_list), torch.cat(y_list) ),
+				acc = self.accuracy_metric( inputs, target ),
+				precision = self.precision_metric( inputs, target ),
 			)
+
+			return logits
+

-		# return the entire generated token string
-		return_all = False
-		if return_all:
-			logits = [hi[:] for hi, li in zip(h_list, map(len, resps_list))]
-		# return the entire generated response
-		elif quant_levels is not None:
-			logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))]
-		# return the last chunkwise piece
-		elif self.causal and self.recurrent_chunk_size > 0:
-			logits = [hi[-self.recurrent_chunk_size:] for hi, li in zip(h_list, map(len, resps_list))]
-		# return just the last code
-		else:
-			logits = [ hi[-1:] for hi in h_list ]
-		return [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
+		# (NAR) return the entire generated response
+		if quant_levels is not None:
+			logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
+		# (AR chunkwise) return the last chunkwise piece
+		elif self.causal and self.recurrent_chunk_size > 0:
+			logits = [ logit[-l:] for logit, l in zip(logits, self.recurrent_chunk_size) ]
+		# (AR) return just the last code
+		else:
+			logits = [ logit[-1:] for logit in logits ]
+
+		# perform repetition penalizing
+		logits = [ reptition_penalize(logit, previous=resps[:, 0], factor=sampling_repetition_penalty) for logit, resps in zip( logits, resps_list ) ]
+
+		# perform length penalizing
+		if quant_levels is None and self.causal:
+			logits = [ length_penalize(logit, length=l + 1, factor=sampling_length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
+
+		# scale our logits by the temp
+		logits = [ logit / sampling_temperature for logit in logits ]
+
+		# perform top_k/top_p filtering of our logits
+		if sampling_top_k > 0:
+			logits = [ top_k_top_p_filtering(logit, top_k=sampling_top_k, top_p=sampling_top_p) for logit in logits ]
+
+		# and sample
+		# the original implementation used this instead of argmax; it's probably placebo but it performs better than argmax
+		return [ Categorical(logits=logit).sample() for logit in logits ]

 def example_usage():
 	from ..config import cfg
diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py
index 194d8cd..fde734d 100755
--- a/vall_e/models/nar.py
+++ b/vall_e/models/nar.py
@@ -57,6 +57,10 @@ class NAR(Base):
 		proms_list: list[Tensor],
 		resps_list: list[Tensor],
 		sampling_temperature: float = 0.2,
+		sampling_top_k: int = -100,
+		sampling_top_p: float = 1.0,
+		sampling_repetition_penalty: float = 1.0,
+		sampling_length_penalty: float = 0.0,
 	):
 		"""
 		Args:
@@ -112,6 +116,10 @@ class NAR(Base):
 				prev_list,
 				quant_levels=quant_levels,
 				sampling_temperature=sampling_temperature,
+				sampling_top_p=sampling_top_p,
+				sampling_top_k=sampling_top_k,
+				sampling_repetition_penalty=sampling_repetition_penalty,
+				sampling_length_penalty=sampling_length_penalty,
 			)

 			prev_list = [
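
For illustration only (not part of the diff): the sketch below mirrors the order in which the new sampling controls are applied at the end of Base.forward — repetition penalty, length penalty on the stop token, temperature scaling, top-k/top-p filtering, then a categorical sample. The helper name demo_sample, the tensor shapes, and the stop-token id are assumptions made for this example, not code from the repository. Note also that in the diff the top-k/top-p filter only runs when sampling_top_k > 0; the sketch applies each filter independently for clarity.

import torch
import torch.nn.functional as F
from torch.distributions import Categorical

def demo_sample(logits, previous, stop_token, temperature=0.95,
                top_k=0, top_p=1.0, repetition_penalty=1.0, length_penalty=0.0):
	# logits: (1, vocab) scores for the next token; previous: 1D tensor of prior tokens
	# repetition penalty: divide the logit of every token already emitted (mirrors reptition_penalize)
	if repetition_penalty != 1.0:
		for token in set(previous.tolist()):
			logits[:, token] /= repetition_penalty
	# length penalty: scale the stop-token logit by the sequence length (mirrors length_penalize)
	if length_penalty != 0.0:
		logits[:, stop_token] /= (len(previous) + 1) ** length_penalty
	# temperature scaling, then top-k / top-p filtering as in top_k_top_p_filtering
	logits = logits / temperature
	if top_k > 0:
		kth_best = torch.topk(logits, top_k)[0][..., -1, None]
		logits[logits < kth_best] = -float("Inf")
	if top_p < 1.0:
		sorted_logits, sorted_indices = torch.sort(logits, descending=True)
		cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
		remove = cumulative > top_p
		remove[..., 1:] = remove[..., :-1].clone()
		remove[..., 0] = False
		logits[remove.scatter(1, sorted_indices, remove)] = -float("Inf")
	# sample from the filtered distribution, as Base.forward does with Categorical
	return Categorical(logits=logits).sample()

# hypothetical example: 1024 audio codes plus a stop token, with a repeated prior token
example_logits = torch.randn(1, 1025)
example_previous = torch.tensor([5, 5, 17, 900])
print(demo_sample(example_logits, example_previous, stop_token=1024,
                  top_k=64, top_p=0.9, repetition_penalty=1.2, length_penalty=0.5))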
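Also hedged: a call-site sketch of the widened TTS.inference signature as it is exercised from vall_e.__main__. The config path, checkpoint arguments, reference list, and output path below are placeholders, not values taken from the diff.

from vall_e.inference import TTS

# placeholder paths and checkpoints; the keyword names follow the diff's __main__.py call
tts = TTS( config="./config.yaml", ar_ckpt=None, nar_ckpt=None, device="cuda" )
tts.inference(
	text="hello world",
	references=["./reference.wav"],
	out_path="./data/out.wav",
	max_ar_steps=6 * 75,
	ar_temp=0.95, nar_temp=0.5,
	top_p=0.9, top_k=0,
	repetition_penalty=1.0, length_penalty=0.0,
)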