diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 50035cf..70aaa38 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -276,7 +276,7 @@ class AR_NAR(Base):
 
 		# get next in sequence
 		for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
-			#
+			# it would technically be faster to just append the new token's embedding to the inputs, but there's a VERY small performance gain from doing it, so it's not worth it
 			text_list = [ sequence_list[i] if task in text_task else text_list[i] for i, task in enumerate(task_list) ]
 			resps_list = [ sequence_list[i] if task not in text_task else resps_list[i] for i, task in enumerate(task_list) ]
 
@@ -291,6 +291,7 @@ class AR_NAR(Base):
 				quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
 			)
 
+			# to-do: find an elegant way to write this
 			if state is not None:
 				logits, state = super().forward(
 					inputs=inputs,
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 6149800..ee4a0c8 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -15,6 +15,7 @@ import torch.nn.functional as F
 import random
 import numpy as np
 import re
+from time import perf_counter
 
 from typing import Literal, overload, Optional, Tuple
 from functools import partial
@@ -1479,8 +1480,8 @@ class Base(nn.Module):
 		elif self.causal:
 			logits = [ logit[-self.causal_size:] for logit in logits ]
 
-		devices = [ logit.device for logit in logits ]
-		logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
+		# this might actually slow things down a bit slightly-er?
+		#logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
 
 		# (NAR) disable stop token
 		if quant_levels is not None and "ar" in self.capabilities:
@@ -1494,12 +1495,12 @@ class Base(nn.Module):
 			return [ logit.argmax(dim=1) for logit in logits ]
 
 		# perform repetition penalizing
-		if "len" not in self.capabilities and prev_list is not None:
+		if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
 			# to-do: figure out a faster way to handle tolist()
 			logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
 
 		# (AR) perform length penalizing
-		if quant_levels is None and self.causal and prev_list is not None:
+		if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
 			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
 
 		# perform top_k/top_p filtering of our logits
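
Note on the sampling guards above: a repetition_penalty of 1.0 and a length_penalty of 0.0 are identity operations, so skipping them keeps the slow per-token tolist() pass off the hot path. Below is a minimal standalone sketch of the same skip-when-no-op pattern; apply_sampling_penalties is a hypothetical helper for illustration only, not the repo's reptition_penalize / length_penalize, and the penalty math is simplified.

	import torch

	def apply_sampling_penalties(
		logits: torch.Tensor,            # [vocab] logits for the next token
		prev_tokens: list[int],          # previously sampled token ids
		repetition_penalty: float = 1.0,
		length_penalty: float = 0.0,
		stop_token: int = 0,
	) -> torch.Tensor:
		# a repetition penalty of 1.0 is the identity, so skip the per-token pass outright
		if repetition_penalty != 1.0:
			for token in set(prev_tokens):
				value = logits[token]
				# divide positive logits, multiply negative ones (as in the CTRL-style repetition penalty)
				logits[token] = value / repetition_penalty if value > 0 else value * repetition_penalty
		# a length penalty of 0.0 leaves the stop token's logit untouched, so skip it as well
		if length_penalty != 0.0:
			logits[stop_token] = logits[stop_token] + length_penalty * len(prev_tokens)
		return logits

With guards like these, the default sampler settings (repetition_penalty=1.0, length_penalty=0.0) pay nothing for penalties they do not use.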