ugh
This commit is contained in:
parent
ad7cfffc00
commit
6b76419123
|
@ -1352,7 +1352,7 @@ class Base(nn.Module):
|
|||
target = []
|
||||
task_type = "tts"
|
||||
|
||||
causal = False
|
||||
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
|
||||
dropout_mask = None
|
||||
for name, input in batch:
|
||||
if name == "dropout_mask":
|
||||
|
@ -1362,11 +1362,12 @@ class Base(nn.Module):
|
|||
if name == "task":
|
||||
task_type = input
|
||||
task_list.append( input )
|
||||
if task_type in ["len", "stt"]:
|
||||
causal = True
|
||||
elif name == "prom":
|
||||
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
||||
target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) )
|
||||
elif name == "resp":
|
||||
causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_type in ["len", "stt"])
|
||||
# mask found, apply it
|
||||
if dropout_mask is not None:
|
||||
# if mask use original token, else ignore
|
||||
|
@ -1438,6 +1439,8 @@ class Base(nn.Module):
|
|||
# + extra logic might be required to instead offset from the end for the resp, rather than fit snugly
|
||||
# + this might just be a spook, since the odds of the very first token of the AR mattering are slim (although I swear I hear a very brief audio pop sometimes)
|
||||
"""
|
||||
|
||||
# to-do: use NAR-len training and better causal-awareness
|
||||
info = {}
|
||||
batch_size = len( inputs )
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user