commit 6b76419123
parent ad7cfffc00

    ugh
@@ -1352,7 +1352,7 @@ class Base(nn.Module):
 			target = []
 			task_type = "tts"
 
-			causal = False
+			causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities)
 			dropout_mask = None
 			for name, input in batch:
 				if name == "dropout_mask":
@@ -1362,11 +1362,12 @@ class Base(nn.Module):
 				if name == "task":
 					task_type = input
 					task_list.append( input )
+					if task_type in ["len", "stt"]:
+						causal = True
 				elif name == "prom":
 					proms = [ input ] if isinstance(input, torch.Tensor) else input
 					target.append( torch.cat( [ prompt_input_to_token( input, quant_level ) for input in proms if input is not None ] ) )
 				elif name == "resp":
-					causal = (quant_level == 0 and "ar" in self.capabilities) or ("nar" not in self.capabilities) or (task_type in ["len", "stt"])
 					# mask found, apply it
 					if dropout_mask is not None:
 						# if mask use original token, else ignore
@@ -1438,6 +1439,8 @@ class Base(nn.Module):
 		# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
 		# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
 		"""
 
+		# to-do: use NAR-len training and better causal-awareness
+
 		info = {}
 		batch_size = len( inputs )
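Net effect of the change: the causal flag now defaults to the AR/NAR capability check up front, and "len"/"stt" tasks flip it to True as soon as the task entry is parsed, instead of recomputing the whole expression when the "resp" entry is reached. A minimal sketch of the resulting control flow (quant_level, capabilities, and the (name, input) batch layout are taken from the hunks above; the helper name and example values here are illustrative, not the repo's API):

	# illustrative sketch only: mirrors the post-commit causal-flag logic
	def is_causal(batch, quant_level, capabilities):
		# default is computed up front instead of hardcoded to False
		causal = (quant_level == 0 and "ar" in capabilities) or ("nar" not in capabilities)
		for name, input in batch:
			if name == "task" and input in ["len", "stt"]:
				# len/stt force causal decoding as soon as the task is seen,
				# no longer deferred to the "resp" branch
				causal = True
		return causal

	# e.g. a pure-NAR model still decodes an "stt" task causally:
	assert is_causal([("task", "stt")], quant_level=1, capabilities=["nar"]) is True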