overhauled how the right resp level / classifier gets picked to avoid cringemath
parent 269648605e
commit 910033343c
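The gist of the change: instead of deriving which resps embedding and classifier head to use from arithmetic over quant levels, tasks, and capabilities (the "cringemath"), every head now carries a name ("AR:0:0", "NAR:0:1", ..., "NAR:0:0" for the masked NAR-len path, "stt" for text), each batch entry gets tagged with a "classifier_level" name when its inputs are built, and lookups resolve names to indices. A minimal sketch of the idea (the names are from this diff; the surrounding plumbing is illustrative only):

# illustrative: name-to-index lookup replacing per-task index arithmetic
names = ['AR:0:0', 'NAR:0:1', 'NAR:1:2', 'NAR:0:0', 'stt']

def index_of( name: str ) -> int:
	return names.index( name )

assert index_of( 'NAR:1:2' ) == 2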
@@ -260,7 +260,7 @@ class ModelExperimentalSettings:
 
 	masking_train_p: float = 0.0 # odds of training with masking
 	masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
-	masking_separate_embeddings: bool = False
+	masking_separate_embeddings: bool = False # to-do: explain
 
 	# classifier-free guidance shit
 	cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
@@ -883,6 +883,10 @@ def example_usage():
 
 	import numpy as np
 	import re
 
+	cfg.model.experimental.masking_train_p = 0.5
+	cfg.hyperparameters.batch_size = 1
+	cfg.hyperparameters.gradient_accumulation_steps = 1
+
 	setup_logging()
 
@@ -896,7 +900,6 @@ def example_usage():
 
 	text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
 	batch_size = cfg.hyperparameters.batch_size
-	cfg.model.experimental.masking_train_p = 1.0
 
 	text_list = [ text ] * batch_size
 	proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] * batch_size
@@ -46,6 +46,9 @@ LossStats = namedtuple('LossStats', ['loss', 'stats'])
 from ..utils.pattern import DelayedPatternProvider, VALLEPattern
 """
 
+summed_embeddings_task = [ "stt" ]
+special_tasks = [ "len", "stt" ]
+
 def _dropout_mask( input, p=None ):
 	# cosine scheduling
 	if p is None:
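Only the head of _dropout_mask appears in this hunk; for orientation, a plausible completion under the cosine schedule its comment references (the body is unchanged by this commit and the actual code may differ):

import math
import random
import torch

def _dropout_mask( input, p=None ):
	# cosine scheduling: sample a timestep uniformly and map it through
	# a cosine curve to get the masking probability
	if p is None:
		t = random.random()
		p = math.cos(t * math.pi * 0.5)
	# boolean mask over the sequence: True marks positions to mask out
	return torch.rand( input.shape[0], device=input.device ) < p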
@@ -182,78 +185,26 @@ class AudioEmbedding(nn.Module):
 		l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
 		token_dim: int, # dimensionality of the embedding
 		sums: bool = True, # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
-		external_mode: str | None = None, # "exclusive" | "inclusive", whether to include the original audio backend's embeddings
-		capabilities: list[str] | None = None, # helper shit
+		l_names: list[str] = [], # names to map to indices
 	):
 		super().__init__()
 		# array of embeddings
 		# proms are [0, resp_levels]
 		# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
-		# + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
 		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
 		# further experimentation is needed to see if this actually is useful
 		self.sums = sums
+		#
+		self.names = l_names
 
-		self.external_mode = external_mode
-		self.capabilities = capabilities
-
-		# set initial weights to zero
-		if self.external_mode == "inclusive":
-			for i, embedding in enumerate(self.embeddings):
-				embedding.weight = torch.nn.Parameter(torch.zeros( embedding.weight.shape ))
-
-	def external_embeddings(self, input: Tensor, quant_level: int | None = None ) -> Tensor:
-		if quant_level is None:
-			quant_level = 0 if input.dim() == 1 else input.shape[-1] - 1
-
-		# for AR, trim any stop tokens
-		has_stop_token = False
-
-		# this block apparently doesn't work
-		"""
-		if quant_level == 0:
-			stop_token = self.embeddings[0].weight.shape[0] - 1
-			stop_token_indices = (input == stop_token).nonzero()
-			has_stop_token = len(stop_token_indices) > 0
-
-			if has_stop_token:
-				input = input[:stop_token_indices.min().item()]
-		"""
-		has_stop_token = False
-
-		if quant_level == 0:
-			stop_token = self.embeddings[0].weight.shape[0] - 1
-			has_stop_token = input[-1] == stop_token
-
-		if has_stop_token:
-			input = input[:-1]
-
-		# get external embedding
-		embedding = encode_as_embedding( input, quant_level, sums=self.sums ).to(device=input.device, dtype=self.embeddings[quant_level].weight.dtype)
-		# resize if necessary (in case the external embeddings do not match our model dim)
-		embedding = ml.resize_weight( embedding, self.embeddings[quant_level].weight.shape[-1], dim=-1, random=False )
-
-		# reintroduce stop token
-		if has_stop_token:
-			stop_token = self.internal_forward( torch.tensor([stop_token]).to(device=input.device, dtype=torch.int16), 0 )
-			embedding = torch.concat( [ embedding, stop_token ] )
-
-		return embedding
-
-	def internal_forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, sums = None ) -> Tensor:
-		if offset is None:
-			# prom
-			if self.capabilities is None:
-				offset = 0
-			elif "nar" not in self.capabilities:
-				offset = 0
-			elif quant_level > 0:
-				offset = 1
-
+	def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, name: str | None = None, sums = None ) -> Tensor:
 		if sums is None:
 			sums = self.sums
+
+		# handle mapping from name
+		if name in self.names:
+			offset = self.names.index( name )
 
 		if quant_level is None:
 			quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
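With l_names wired in, callers select the embedding table by name and forward resolves the offset internally. A toy usage sketch (token counts and dimensions are made up):

import torch

emb = AudioEmbedding(
	l_tokens=[1025, 1024, 1024],	# AR level carries an extra stop token
	token_dim=512,
	l_names=['AR:0:0', 'NAR:0:1', 'NAR:1:2'],
)
tokens = torch.randint( 0, 1024, (75,) )
x = emb( tokens, name='NAR:0:1' )	# name lookup sets offset = 1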
@@ -265,17 +216,6 @@ class AudioEmbedding(nn.Module):
 
 		return x
 
-	def forward(self, xi: Tensor, offset: int | None = None, quant_level: int | None = None, sums = None ) -> Tensor:
-		x = self.internal_forward( xi, offset = offset, quant_level = quant_level, sums = sums ) if self.external_mode != "exclusive" or xi.shape[0] == 0 else None
-
-		if self.external_mode and xi.shape[0] > 0:
-			external_embeddings = self.external_embeddings( xi, quant_level = quant_level )
-			if self.external_mode == "exclusive":
-				return external_embeddings
-			x += external_embeddings
-
-		return x
-
 # time-step embedding
 # for the NAR-len, since it probably most likely requires encoding the timestep
 class TimeEmbedding(nn.Module):
@@ -304,14 +244,32 @@ class Classifiers(nn.Module):
 		self,
 		l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
 		token_dim: int, # dimensionality of the embedding
+		l_names: list[str] | None = None, # list of names to map to each classifier
 	):
 		super().__init__()
 		self.proj = nn.ModuleList([nn.Linear(token_dim, n_tokens) for n_tokens in l_tokens])
+		self.names = l_names
 
-	def forward(self, xi: Tensor, levels: list[int] ) -> Tensor:
+	def indices(
+		self,
+		names
+	):
+		if isinstance( names[-1], int ):
+			return names
+		return [ self.names.index(name) for name in names ]
+
+	def forward(self, xi: Tensor, levels: list[int] | None = None, names: list[str] | None = None ) -> Tensor:
 		dtype = xi.dtype
 		device = xi.device
+
+		if levels and isinstance( levels[-1], str ):
+			names = levels
+			levels = []
+
+		# map names to levels
+		if names and not levels:
+			levels = [ self.names.index(name) for name in names ]
+
 		xi = [ self.proj[l]( x ) for x, l in zip(xi, levels) ]
 		# pad if needed
 		# to-do: validate that this causes ZERO issues
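forward now accepts either integer levels or string names (and indices() normalizes a mixed caller-side list); a toy usage sketch under the same made-up mapping as above:

import torch

heads = Classifiers(
	l_tokens=[1025, 1024, 1024, 256],	# per-level heads plus a text head
	token_dim=512,
	l_names=['AR:0:0', 'NAR:0:1', 'NAR:1:2', 'stt'],
)
assert heads.indices( ['NAR:0:1', 'stt'] ) == [1, 3]

hidden = [ torch.randn(75, 512), torch.randn(40, 512) ]
logits = heads( hidden, names=['AR:0:0', 'stt'] )	# one head per batch entry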
@@ -349,11 +307,11 @@ class Metrics(nn.Module):
 			ignore_index=ignore_index,
 		) for n_tokens in l_tokens ])
 
-	def calc_accuracy( self, inputs, targets, quant_levels ):
-		return sum( [ self.accuracy[l]( input[:, :self.accuracy[l].num_classes], target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
+	def calc_accuracy( self, inputs, targets, classifier_levels ):
+		return sum( [ self.accuracy[l]( input[:, :self.accuracy[l].num_classes], target ) for target, input, l in zip( targets, inputs, classifier_levels ) ] ) / len( inputs )
 
-	def calc_precision( self, inputs, targets, quant_levels ):
-		return sum( [ self.precision[l]( input[:, :self.precision[l].num_classes], target ) for target, input, l in zip( targets, inputs, quant_levels ) ] ) / len( inputs )
+	def calc_precision( self, inputs, targets, classifier_levels ):
+		return sum( [ self.precision[l]( input[:, :self.precision[l].num_classes], target ) for target, input, l in zip( targets, inputs, classifier_levels ) ] ) / len( inputs )
 
 	def __call__(self, *args, **kwargs):
 		return dict(
@@ -486,21 +444,29 @@ class Base(nn.Module):
 		if "nar" not in self.capabilities:
 			n_resp_tokens = n_audio_tokens + 1
 			l_tokens = [n_resp_tokens] * self.n_resp_levels
+			resp_l_names = [f'AR:{i}:{i}' for i in range( self.n_resp_levels )]
 		# NAR-len model
 		elif "len" in self.capabilities and masking_separate_embeddings:
 			# +1 to include the stop or mask token
 			n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
-			l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
+			l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+			resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
+
+			if masking_separate_embeddings:
+				l_tokens += [n_resp_tokens]
+				resp_l_names += ['NAR:0:0']
 		# AR+NAR model
 		else:
 			# +1 to include the stop or mask token
 			n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
 			l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
+			resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
+
+		classifier_l_names = resp_l_names + ["stt"]
 
 		self.unified_position_ids = unified_position_ids
 		self.interleave = interleave
 		self.layerskip = layerskip
-		self.special_tasks = [ "len", "stt" ]
 		self.inject_timestep_embedding = False # results in bad output
 		self.masking_separate_embeddings = masking_separate_embeddings
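To make the bookkeeping concrete, a worked example of what the NAR-len branch now produces, under assumed settings (n_audio_tokens=1024, n_resp_levels=8, causal_size=1, masking_separate_embeddings=True):

n_audio_tokens, n_resp_levels = 1024, 8
n_resp_tokens = n_audio_tokens + 1	# +1 for the stop/mask token
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1)
resp_l_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range(n_resp_levels - 1)]
# separate embedding/classifier for the masked NAR-len path
l_tokens += [n_resp_tokens]
resp_l_names += ['NAR:0:0']
classifier_l_names = resp_l_names + ['stt']
# l_tokens     == [1025, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1025]
# resp_l_names == ['AR:0:0', 'NAR:0:1', ..., 'NAR:6:7', 'NAR:0:0']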
@@ -534,14 +500,11 @@ class Base(nn.Module):
 		self.proms_emb = AudioEmbedding(
 			[n_audio_tokens] * self.n_resp_levels, d_model,
 			sums=audio_embedding_sums,
-			external_mode=audio_embedding_mode,
-			capabilities=None,
 		)
 		self.resps_emb = AudioEmbedding(
 			l_tokens, d_model,
 			sums=audio_embedding_sums,
-			external_mode=audio_embedding_mode,
-			capabilities=self.capabilities,
+			l_names=resp_l_names,
 		)
 
 		if self.version >= 3:
@@ -842,7 +805,7 @@ class Base(nn.Module):
 			self.metrics = None
 		else:
 			self.classifier = None
-			self.classifiers = Classifiers( l_tokens + [ n_text_tokens ], d_model )
+			self.classifiers = Classifiers( l_tokens + [ n_text_tokens ], d_model, l_names=classifier_l_names )
 			self.accuracy_metric = None
 			self.precision_metric = None
 			self.metrics = Metrics( l_tokens + [ n_text_tokens ] )
@@ -1002,6 +965,7 @@ class Base(nn.Module):
 			quant_level = quant_levels[i] if quant_levels is not None else 0
 			task_type = task_list[i] if task_list is not None else "tts"
 			timestep = time_list[i] if time_list is not None else None
+			classifier_level = None
 
 			# insert task type as a string
 			inputs[i].append( ( "task", task_type ) )
@@ -1012,7 +976,7 @@ class Base(nn.Module):
 			# Base-line TTS task
 			# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
 			# prom /may/ include <task> tokens inside to help guide things, per SpeechX
-			if f'<{task_type}>' in get_task_symmap() and task_type not in self.special_tasks:
+			if f'<{task_type}>' in get_task_symmap() and task_type not in special_tasks:
 				# insert the text prompt
 				if text_list is not None and text_list[i] is not None:
 					inputs[i].append( ( "text", text_list[i] ) )
@@ -1022,19 +986,22 @@ class Base(nn.Module):
 				# insert RVQ level guidance token if the model is versioned for it
 				if self.rvq_l_emb is not None and not self.interleave:
 					inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
+
+					classifier_level = "AR:0:0" if quant_level == 0 else f'NAR:{quant_level-1}:{quant_level}'
 				# insert input audio prompt
 				if proms_list is not None and proms_list[i] is not None:
 					inputs[i].append( ( "prom", proms_list[i] ) )
 				# insert tone token if we're trained for it
 				if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
 					inputs[i].append( ( "tone", tone_list[i] ) )
-				# it does not seem to matter whether this is provided or not, I assume the model attends more to the amount of masked tokens in the sequence
-				"""
 				# insert timestep token
 				if timestep is not None:
+					# it does not seem to matter whether this is provided or not, I assume the model attends more to the amount of masked tokens in the sequence
+					"""
 					# store timestep information
 					inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
 					"""
+					classifier_level = "NAR:0:0"
 				# insert the current output response
 				if resps_list is not None and resps_list[i] is not None:
 					inputs[i].append( ( "resp", resps_list[i] ) )
@@ -1050,6 +1017,7 @@ class Base(nn.Module):
 					dropout_mask = _dropout_mask( resps_list[i], p )
 					inputs[i].append( ("dropout_mask", dropout_mask ) )
 
+				inputs[i].append( ("classifier_level", classifier_level) )
 			# Audio length prediction task
 			# Sequence: <text><sep><rvq lvl><prom><sep><len>
 			elif task_type == "len":
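After this hunk every batch entry carries its classifier name alongside its tokens; for a TTS step at quant_level 2 the entry ends up shaped roughly like this (tensors elided):

# illustrative contents of inputs[i] for one TTS batch entry
# [
#   ("task",             "tts"),
#   ("text",             <text tensor>),
#   ("quant_level",      tensor([2])),
#   ("prom",             <prom tensor>),
#   ("resp",             <resp tensor>),
#   ("classifier_level", "NAR:1:2"),
# ]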
@@ -1080,6 +1048,8 @@ class Base(nn.Module):
 			elif resps_list is not None and resps_list[i] is not None:
 				# yes this could be encoded better
 				inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) )
+
+				inputs[i].append( ("classifier_level", "stt") )
 			# Speech-to-Text prediction task
 			# Sequence: <resp><sep><rvq lvl><sep><text>
 			elif task_type == "stt":
@@ -1095,6 +1065,8 @@ class Base(nn.Module):
 				# insert the output text prompt
 				if text_list is not None and text_list[i] is not None:
 					inputs[i].append( ( "text", text_list[i] ) )
+
+				inputs[i].append( ("classifier_level", "stt") )
 			else:
 				raise Exception(f'Unrecognized task: {task_type}')
 		return inputs
@@ -1131,7 +1103,6 @@ class Base(nn.Module):
 		if not token_dropout_rvq_levels:
 			token_dropout_rvq_levels = [1, self.resp_levels]
 
-		summed_embeddings_task = [ "stt" ]
 
 		x_list = []
 		for batch_index, batch_input in enumerate(inputs):
@@ -1140,6 +1111,7 @@ class Base(nn.Module):
 
 			task_type = "tts"
 			input_prom = None
+			classifier_level = None
 			dropout_mask = None
 			timestep = None
 
@@ -1147,6 +1119,8 @@ class Base(nn.Module):
 			for name, input in batch_input:
 				if name == "dropout_mask":
 					dropout_mask = input
+				elif name == "classifier_level":
+					classifier_level = input
 
 			for name, input in batch_input:
 				# technically can provide a map for input_name => embedding, but some embedding requires additional processing
@@ -1179,8 +1153,9 @@ class Base(nn.Module):
 				if self.interleave:
 					embeddings = [ self.resps_emb(
 						input[:, :l+1],
-						offset = 0,
-						quant_level = l
+						#offset = 0,
+						#quant_level = l,
+						name = 'AR:0:0' if l == 0 else f'NAR:{l-1}:{l}',
 					) for l in range( input.shape[-1] ) ]
 
 					embedding = _interleave_sequence_reshape( embeddings )
@@ -1190,16 +1165,18 @@ class Base(nn.Module):
 					embedding = self.resps_emb(
 						# if masked use masked token, else original token
 						torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
-						offset = -1 if self.masking_separate_embeddings else 0, # pick last
-						quant_level = 0,
+						#offset = -1 if self.masking_separate_embeddings else 0, # pick last
+						#quant_level = 0,
+						name = classifier_level,
 					)
 				# cheat-y way to handle performing STT across all levels
 				elif task_type in summed_embeddings_task:
 					# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
 					embedding = sum([ self.resps_emb(
 						input[:, :l+1],
-						offset = 0 if l == 0 else 1, # or maybe set to 1
-						quant_level = l,
+						#offset = 0 if l == 0 else 1, # or maybe set to 1
+						#quant_level = l,
+						name = 'AR:0:0' if l == 0 else f'NAR:{l-1}:{l}',
 						sums = False
 					) for l in range( input.shape[-1] - 1 ) ])
 				else:
@@ -1210,6 +1187,7 @@ class Base(nn.Module):
 						quant_level
 					)
 				else:
+					"""
 					offset = 0
 					if "nar" not in self.capabilities:
 						offset = 0
@@ -1221,6 +1199,13 @@ class Base(nn.Module):
 						offset = offset,
 						quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
 					)
+					"""
+
+					embedding = self.resps_emb(
+						input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
+						offset = 0 if classifier_level.startswith("AR:") else 1,
+						quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
+					)
 
 				# apply token dropout
 				if token_dropout_rate > 0.0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
@@ -1283,16 +1268,14 @@ class Base(nn.Module):
 		# there's a better way
 		if not self.unified_position_ids:
 			x_list = []
+			non_tokens = ["task", "dropout_mask", "classifier_level"]
+			last_input = ["resp", "len"]
 
 			def get_input_token_length( name, input ):
 				# task token
 				if isinstance(input, str):
 					return 1
 
-				# a mask
-				if name in ["dropout_mask"]:
-					return 0
-
 				# list of tokens
 				if not isinstance(input, torch.Tensor):
 					return sum( [ i.shape[0] for i in input if isinstance(i, torch.Tensor) ] ) + 1
@@ -1302,12 +1285,12 @@ class Base(nn.Module):
 					return input.shape[0] * input.shape[1]
 
 				# ending input will not have a separator later
-				return input.shape[0] + (0 if name in ["resp", "len"] else 1)
+				return input.shape[0] + (0 if name in last_input else 1)
 
 			for batch_index, batch_input in enumerate(inputs):
 				batch = torch.cat( [
 					torch.tensor([*range(get_input_token_length(name, input))], device=device, dtype=torch.int32)
-					for name, input in batch_input if name != "task"
+					for name, input in batch_input if name not in non_tokens
 				] )
 
 				delta = ids[batch_index].shape[0] - batch.shape[0]
|
||||||
inputs: list,
|
inputs: list,
|
||||||
logits,
|
logits,
|
||||||
|
|
||||||
quant_levels: int | list[int] | Tensor | None = None,
|
quant_levels: list[int] | None = None,
|
||||||
):
|
):
|
||||||
loss = dict(ce = dict())
|
loss = dict(ce = dict())
|
||||||
stats = dict(acc = dict())
|
stats = dict(acc = dict())
|
||||||
|
|
||||||
device = logits[0].device
|
device = logits[0].device
|
||||||
batch_size = len(logits)
|
batch_size = len(logits)
|
||||||
summed_embeddings_task = [ "stt" ]
|
classifier_levels = self.get_input( inputs, "classifier_level" )
|
||||||
tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
|
|
||||||
is_nar_len = [ self.get_input(inputs, "dropout_mask", at=i) is not None and self.masking_separate_embeddings for i in range( batch_size ) ]
|
|
||||||
classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else (-2 if is_nar_len[i] else l) for i, l in enumerate( quant_levels ) ]
|
|
||||||
|
|
||||||
# handles tasks where the prompt has task tokens injected in the middle
|
# handles tasks where the prompt has task tokens injected in the middle
|
||||||
def prompt_input_to_token( input, quant_level ):
|
def prompt_input_to_token( input, quant_level ):
|
||||||
|
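Where the old code recomputed per-batch levels from tasks and dropout masks, the loss path now just reads the tag back out of the inputs. self.get_input lives elsewhere in this file; its effect here amounts to something like this sketch:

def gather_classifier_levels( inputs ):
	# assumed equivalent of self.get_input( inputs, "classifier_level" ):
	# pull the tagged name out of each batch entry, None if untagged
	levels = []
	for batch_input in inputs:
		level = None
		for name, input in batch_input:
			if name == "classifier_level":
				level = input
		levels.append( level )
	return levels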
@@ -1369,7 +1349,7 @@ class Base(nn.Module):
 			if name == "task":
 				task_type = input
 				task_list.append( input )
-				if task_type in ["len", "stt"]:
+				if task_type in special_tasks:
 					causal = True
 			elif name == "prom":
 				proms = [ input ] if isinstance(input, torch.Tensor) else input
@@ -1423,7 +1403,7 @@ class Base(nn.Module):
 				loss = dict(
 					nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
 				)
-				stats = self.metrics( inputs, targets, classifier_quant_levels ) if self.metrics is not None else dict(
+				stats = self.metrics( inputs, targets, classifier_levels ) if self.metrics is not None else dict(
 					acc = self.accuracy_metric( inputs, target ),
 					# precision = self.precision_metric( inputs, target ),
 				)
@@ -1432,7 +1412,7 @@ class Base(nn.Module):
 				loss = dict(
 					nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
 				)
-				stats = self.metrics( logits, target_list, classifier_quant_levels ) if self.metrics is not None else dict(
+				stats = self.metrics( logits, target_list, self.classifiers.indices( classifier_levels ) ) if self.metrics is not None else dict(
 					acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
 				)
 
@@ -1466,7 +1446,7 @@ class Base(nn.Module):
 			# meta-input, no corresponding token at the moment
 			if name == "task":
 				task_name = input
-				if task_type in ["len", "stt"]:
+				if task_type in special_tasks:
 					causal = True
 				continue
 			# do not use resp as-is
@@ -1529,7 +1509,7 @@ class Base(nn.Module):
 			else:
 				loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
 			if self.metrics is not None:
-				metrics = self.metrics( batch["logits"], batch["targets"], classifier_quant_levels )
+				metrics = self.metrics( batch["logits"], batch["targets"], self.classifiers.indices( classifier_levels ) )
 				stats["acc"][name] = metrics["acc"]
 			else:
 				stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
@@ -1540,7 +1520,7 @@ class Base(nn.Module):
 		self,
 		inputs: list,
 
-		quant_levels: int | list[int] | Tensor | None = None,
+		quant_levels: list[int] | None = None,
 		state: dict | list | None = None,
 
 		layer_skip_variables: dict | None = None,
@@ -1549,7 +1529,7 @@ class Base(nn.Module):
 		output_hidden_states: bool = False,
 	):
 		# return early if it's "good" enough
-		# lambda because we need to capture the classifier_quant_levels and mask
+		# lambda because we need to capture the classifier_levels and mask
 		exited_layer = self.n_layers
 		def layer_skip_lambda( layer, logits ):
 			nonlocal exited_layer
@@ -1576,7 +1556,7 @@ class Base(nn.Module):
 			if self.classifier is not None:
 				x = self.classifier(x) # * m
 			elif self.classifiers is not None:
-				logits = self.classifiers(logits, levels = classifier_quant_levels) # * m
+				logits = self.classifiers(logits, levels = classifier_levels) # * m
 
 			# calculate metrics
 			metrics = calculate_entropix_metrics( logits )
@@ -1628,10 +1608,7 @@ class Base(nn.Module):
 		# needs to be done here as we still have our raw inputs
 		#position_ids = self.inputs_to_position_ids( inputs, mask=m.squeeze(-1).int() ) if not self.unified_position_ids else None
 		position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
-		tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
-		is_nar_len = [ self.get_input(inputs, "dropout_mask", at=i) is not None and self.masking_separate_embeddings for i in range( batch_size ) ]
-		classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else (-2 if is_nar_len[i] else l) for i, l in enumerate( quant_levels ) ]
+		classifier_levels = self.get_input( inputs, name="classifier_level" )
 
 		if self.inject_timestep_embedding:
 			timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
@@ -1664,11 +1641,11 @@ class Base(nn.Module):
 		# to-do: piece-wise classification, now that there's a head for text
 		# although again, one single monolithic head would be preferable instead......
 		elif self.classifiers is not None:
-			logits = self.classifiers(logits, levels = classifier_quant_levels) # * m
+			logits = self.classifiers(logits, levels = classifier_levels) # * m
 
 			if hidden_states is not None:
 				for i, state in enumerate( hidden_states ):
-					hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_quant_levels) # * m
+					hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_levels) # * m
 
 		# Remove padding
 		logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ]
@@ -1716,7 +1693,7 @@ class Base(nn.Module):
 		self,
 		logits: list[Tensor], # logit scores
 		prev_list: list[Tensor] | None = None, # previous tokens
-		quant_levels: int | list[int] | Tensor | None = None, # to-do: derive this from the prev_list
+		quant_levels: list[int] | None = None, # to-do: derive this from the prev_list
 		**sampling_kwargs,
 	):
 		# yikes