uh
This commit is contained in:
parent
197d517181
commit
c0ac84c795
|
@ -905,9 +905,9 @@ class Base(nn.Module):
|
|||
|
||||
quant_level = quant_levels[bi] if quant_levels is not None else None
|
||||
|
||||
if name in ["text" ]:
|
||||
if name == "text":
|
||||
text_batch.append( input )
|
||||
elif name == "prom": # and (quant_level is None or quant_level == 0) and not self.config.audio_embedding_sums:
|
||||
elif name == "prom":
|
||||
prom_batch.append( input[:, quant_level] if quant_level is not None else input )
|
||||
elif name == "targ":
|
||||
resp_batch.append( input )
|
||||
|
@ -1001,23 +1001,15 @@ class Base(nn.Module):
|
|||
logits_prom = []
|
||||
logits_resp = []
|
||||
|
||||
# trim logits to each section
|
||||
for i, logit in enumerate(logits):
|
||||
text_len = target_text_list[i].shape[0] if target_text_list else 0
|
||||
prom_len = target_prom_list[i].shape[0] if target_prom_list else 0
|
||||
resp_len = target_resp_list[i].shape[0] if target_resp_list else 0
|
||||
text_len = target_text_list[i].shape[0]
|
||||
prom_len = target_prom_list[i].shape[0]
|
||||
resp_len = target_resp_list[i].shape[0]
|
||||
|
||||
if target_text_list:
|
||||
logit_text = logit[:text_len]
|
||||
logits_text.append( logit_text )
|
||||
|
||||
# + 1 to include separator
|
||||
if target_prom_list:
|
||||
logit_prom = logit[text_len+1:text_len+1+prom_len]
|
||||
logits_prom.append( logit_prom )
|
||||
|
||||
if target_resp_list:
|
||||
logit_resp = logit[-resp_len:]
|
||||
logits_resp.append( logit_resp )
|
||||
logits_text.append( logit[:text_len] )
|
||||
logits_prom.append( logit[text_len+1:text_len+1+prom_len] ) # + 1 to include separator
|
||||
logits_resp.append( logit[-resp_len:] )
|
||||
|
||||
|
||||
# modify only for the AR so it can properly behave like a transformer
|
||||
|
|
Loading…
Reference in New Issue
Block a user