diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 0205199..5be47a6 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1352,7 +1352,7 @@ class Base(nn.Module):
 			target = []
 			task_type = "tts"
 
-			causal = False
+			causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
 			dropout_mask = None
 			for name, input in batch:
 				if name == "dropout_mask":
@@ -1362,11 +1362,12 @@ class Base(nn.Module):
 				if name == "task":
 					task_type = input
 					task_list.append( input )
+					if task_type in ["len", "stt"]:
+						causal = True
 				elif name == "prom":
 					proms = [ input ] if isinstance(input, torch.Tensor) else input
 					target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) )
 				elif name == "resp":
-					causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_type in ["len", "stt"])
 					# mask found, apply it
 					if dropout_mask is not None:
 						# if mask use original token, else ignore
@@ -1438,6 +1439,8 @@ class Base(nn.Module):
 		# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
 		# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
 		"""
+
+		# to-do: use NAR-len training and better causal-awareness
 
 		info = {}
 		batch_size = len( inputs )
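
For context, a minimal standalone sketch of the behavioral change above: the causal flag is now derived up front from the quant level and model capabilities, and forced on as soon as a "len" or "stt" task entry is seen, rather than being computed lazily in the "resp" branch. This is not the repository's code; the helper name resolve_causal, the set-based capabilities argument, and the (name, input) tuple batch layout are simplifications for illustration.

def resolve_causal(batch, quant_level, capabilities):
	# New behavior: compute the default before walking the batch entries,
	# mirroring the hoisted assignment in the diff.
	causal = (quant_level == 0 and "ar" in capabilities) or ("nar" not in capabilities)

	for name, input in batch:
		if name == "task":
			# "len" and "stt" tasks are always causal; this now takes effect
			# when the task entry is processed instead of in the "resp" branch.
			if input in ["len", "stt"]:
				causal = True
	return causal

# Example: a NAR-capable model at quant level 2 is non-causal for TTS...
assert resolve_causal([("task", "tts")], 2, {"ar", "nar"}) is False
# ...but the "stt" task forces causal attention regardless of quant level.
assert resolve_causal([("task", "stt")], 2, {"ar", "nar"}) is True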