fixes, again
This commit is contained in:
parent
94cf81d38c
commit
341e19162b
|
@ -1055,8 +1055,6 @@ class Dataset(_Dataset):
|
|||
|
||||
# Base STT (<resp> => <text>)
|
||||
elif task == "stt":
|
||||
# easier to just keep it instead of wrangling around trying to remove it
|
||||
# it might also help to provide a guidance prompt but who knows right now
|
||||
proms = [
|
||||
task
|
||||
]
|
||||
|
|
|
@ -135,9 +135,8 @@ class AR_NAR(Base):
|
|||
for j, prom in enumerate( proms ):
|
||||
if not isinstance( prom, torch.Tensor ):
|
||||
continue
|
||||
|
||||
if quant_level >= prom.shape[-1]:
|
||||
quant_levels[i] = prom.shape[-1] - 1
|
||||
if quant_level >= prom.shape[-1]:
|
||||
quant_levels[i] = prom.shape[-1] - 1
|
||||
|
||||
# apply token dropout error compensation
|
||||
if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
|
||||
|
|
|
@ -280,7 +280,7 @@ class Classifiers(nn.Module):
|
|||
xi = [
|
||||
#x if l == 0 else
|
||||
x if x.shape[-1] == max_size else
|
||||
torch.cat( [ x, torch.full( (x.shape[0], max_size - x.shape[-1]), -float("inf"), device=device, dtype=dtype) ], dim=-1 )
|
||||
torch.cat( [x, torch.full( (x.shape[0], max_size - x.shape[-1]), -float("inf"), device=device, dtype=dtype) ], dim=-1 )
|
||||
for x, l in zip(xi, levels)
|
||||
]
|
||||
return torch.stack( xi )
|
||||
|
@ -1057,7 +1057,7 @@ class Base(nn.Module):
|
|||
embedding = self.langs_emb( input )
|
||||
elif name == "prom":
|
||||
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
||||
input_prom = torch.cat([ prom for prom in proms if isinstance(input, torch.Tensor) ])
|
||||
input_prom = torch.cat([ prom for prom in proms if isinstance(prom, torch.Tensor) ])
|
||||
|
||||
embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
|
||||
elif name == "tone" and self.tones_emb is not None:
|
||||
|
@ -1164,7 +1164,7 @@ class Base(nn.Module):
|
|||
|
||||
# list of tokens
|
||||
if not isinstance(input, torch.Tensor):
|
||||
return sum( [ i.shape[0] for i in input if isinstance(i, torch.tensor) ] ) + 1
|
||||
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
|
||||
|
||||
# interleaved model
|
||||
if self.interleave and name == "resp":
|
||||
|
|
Loading…
Reference in New Issue
Block a user