This commit is contained in:
mrq 2024-11-13 09:54:20 -06:00
parent ad7cfffc00
commit 6b76419123

View File

@ -1352,7 +1352,7 @@ class Base(nn.Module):
target = []
task_type = "tts"
causal = False
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
dropout_mask = None
for name, input in batch:
if name == "dropout_mask":
@ -1362,11 +1362,12 @@ class Base(nn.Module):
if name == "task":
task_type = input
task_list.append( input )
if task_type in ["len", "stt"]:
causal = True
elif name == "prom":
proms = [ input ] if isinstance(input, torch.Tensor) else input
target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) )
elif name == "resp":
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_type in ["len", "stt"])
# mask found, apply it
if dropout_mask is not None:
# if mask use original token, else ignore
@ -1438,6 +1439,8 @@ class Base(nn.Module):
# + extra logic might be required to instead offset from the end for the resp, rather than fit snugly
# + this might just be a spook since the odds of the very first token of the AR mattering are slim (although I swear I hear a very brief audio pop sometimes)
"""
# to-do: use NAR-len training and better causal-awareness
info = {}
batch_size = len( inputs )