fixes, again
This commit is contained in:
parent
94cf81d38c
commit
341e19162b
|
@ -1055,8 +1055,6 @@ class Dataset(_Dataset):
|
||||||
|
|
||||||
# Base STT (<resp> => <text>)
|
# Base STT (<resp> => <text>)
|
||||||
elif task == "stt":
|
elif task == "stt":
|
||||||
# easier to just keep it instead of wrangling around trying to remove it
|
|
||||||
# it might also help to provide a guidance prompt but who knows right now
|
|
||||||
proms = [
|
proms = [
|
||||||
task
|
task
|
||||||
]
|
]
|
||||||
|
|
|
@ -135,9 +135,8 @@ class AR_NAR(Base):
|
||||||
for j, prom in enumerate( proms ):
|
for j, prom in enumerate( proms ):
|
||||||
if not isinstance( prom, torch.Tensor ):
|
if not isinstance( prom, torch.Tensor ):
|
||||||
continue
|
continue
|
||||||
|
if quant_level >= prom.shape[-1]:
|
||||||
if quant_level >= prom.shape[-1]:
|
quant_levels[i] = prom.shape[-1] - 1
|
||||||
quant_levels[i] = prom.shape[-1] - 1
|
|
||||||
|
|
||||||
# apply token dropout error compensation
|
# apply token dropout error compensation
|
||||||
if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
|
if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
|
||||||
|
|
|
@ -280,7 +280,7 @@ class Classifiers(nn.Module):
|
||||||
xi = [
|
xi = [
|
||||||
#x if l == 0 else
|
#x if l == 0 else
|
||||||
x if x.shape[-1] == max_size else
|
x if x.shape[-1] == max_size else
|
||||||
torch.cat( [ x, torch.full( (x.shape[0], max_size - x.shape[-1]), -float("inf"), device=device, dtype=dtype) ], dim=-1 )
|
torch.cat( [x, torch.full( (x.shape[0], max_size - x.shape[-1]), -float("inf"), device=device, dtype=dtype) ], dim=-1 )
|
||||||
for x, l in zip(xi, levels)
|
for x, l in zip(xi, levels)
|
||||||
]
|
]
|
||||||
return torch.stack( xi )
|
return torch.stack( xi )
|
||||||
|
@ -1057,7 +1057,7 @@ class Base(nn.Module):
|
||||||
embedding = self.langs_emb( input )
|
embedding = self.langs_emb( input )
|
||||||
elif name == "prom":
|
elif name == "prom":
|
||||||
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
proms = [ input ] if isinstance(input, torch.Tensor) else input
|
||||||
input_prom = torch.cat([ prom for prom in proms if isinstance(input, torch.Tensor) ])
|
input_prom = torch.cat([ prom for prom in proms if isinstance(prom, torch.Tensor) ])
|
||||||
|
|
||||||
embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
|
embedding = torch.cat( [ prompt_input_to_embedding( input, quant_level ) for input in proms if input is not None ] )
|
||||||
elif name == "tone" and self.tones_emb is not None:
|
elif name == "tone" and self.tones_emb is not None:
|
||||||
|
@ -1164,7 +1164,7 @@ class Base(nn.Module):
|
||||||
|
|
||||||
# list of tokens
|
# list of tokens
|
||||||
if not isinstance(input, torch.Tensor):
|
if not isinstance(input, torch.Tensor):
|
||||||
return sum( [ i.shape[0] for i in input if isinstance(i, torch.tensor) ] ) + 1
|
return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
|
||||||
|
|
||||||
# interleaved model
|
# interleaved model
|
||||||
if self.interleave and name == "resp":
|
if self.interleave and name == "resp":
|
||||||
|
|
Loading…
Reference in New Issue
Block a user