this seems to work in testing

This commit is contained in:
mrq 2025-02-12 16:16:04 -06:00
parent e029a8804d
commit b52c5c5d80
4 changed files with 190 additions and 227 deletions

Binary file not shown.

View File

@ -397,6 +397,7 @@ def load_engines(training=True, **model_kwargs):
key_name = cfg.lora.full_name
kwargs['name'] = 'job'
kwargs['resume'] = 'allow'
if world_size() > 1:
kwargs["group"] = "DDP"
kwargs['name'] = f'job-{global_rank()}'

View File

@ -134,13 +134,11 @@ class AR_NAR(Base):
# trim resps to only contain all levels below the target level
if self.version < 7:
resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
elif not self.parallel_decoding:
resps_list = [r if t in text_task else r[..., l] for r, l, t in zip(resps_list, quant_levels, task_list)]
# tensor to cat for RVQ level 0
text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
audio_stop_sequence = torch.tensor([[self.stop_token]], device=device, dtype=torch.int16)
audio_stop_sequence = torch.tensor([[self.stop_token] * (1 if self.version < 7 else self.n_resp_levels)], device=device, dtype=torch.int16)
# final validations and stuff
for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
@ -173,7 +171,7 @@ class AR_NAR(Base):
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
# only apply stop token for RVQ level 0
if quant_level <= 0 and timesteps[i] is None and not self.parallel_decoding:
if quant_level <= 0 and timesteps[i] is None:
# append stop tokens for AR
if task not in text_task:
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
@ -1103,54 +1101,56 @@ class AR_NAR(Base):
# is NAR
if (len_list is not None or resps_list is not None) and text_list is not None:
if self.version >= 7:
if self.parallel_decoding:
return self.forward_nar_masked_parallel(
task_list=task_list,
return self.forward_nar_masked_parallel(
task_list=task_list,
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
**sampling_kwargs,
)
else:
resps_lists = [ None for _ in range(batch_size) ]
for level in range(self.n_resp_levels):
resp_list = self.forward_nar_masked(
task_list=task_list,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
**sampling_kwargs,
)
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
# NAR demasking for all levels
"""
resps_lists = [ None for _ in range(batch_size) ]
for level in range(self.n_resp_levels):
resp_list = self.forward_nar_masked(
task_list=task_list,
disable_tqdm=disable_tqdm,
use_lora=use_lora,
quant_levels=[ level for _ in range(batch_size) ],
**sampling_kwargs,
)
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
raw_text_list=raw_text_list,
for batch_index, resp in enumerate(resp_list):
if resps_lists[batch_index] is None:
resps_lists[batch_index] = []
resps_lists[batch_index].append( resp )
disable_tqdm=disable_tqdm,
use_lora=use_lora,
quant_levels=[ level for _ in range(batch_size) ],
**sampling_kwargs,
)
for batch_index, resps in enumerate(resps_lists):
resps_lists[batch_index] = torch.stack( resps, dim=-1 )
for batch_index, resp in enumerate(resp_list):
if resps_lists[batch_index] is None:
resps_lists[batch_index] = []
resps_lists[batch_index].append( resp )
return resps_lists
for batch_index, resps in enumerate(resps_lists):
resps_lists[batch_index] = torch.stack( resps, dim=-1 )
return resps_lists
"""
return self.forward_nar(
task_list=task_list,
@ -1254,7 +1254,7 @@ def example_usage():
available_tasks = ["tts-nar"]
model = AR_NAR(**kwargs).to(cfg.device)
steps = 500 // batch_size
steps = 750 // batch_size
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""

View File

@ -245,21 +245,6 @@ class AudioEmbedding(nn.Module):
return x
class AudioEmbedding_Sums(nn.Module):
def __init__(
self,
n_tokens: int,
n_levels: int,
token_dim: int,
):
super().__init__()
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)])
def forward(self, xi: Tensor ) -> Tensor:
x = sum( [ emb( xi[:, l] ) for l, emb in enumerate(self.embeddings) ] )
return x
# time-step embedding
# for the NAR-len, since it probably most likely requires encoding the timestep
class TimeEmbedding(nn.Module):
@ -318,10 +303,6 @@ class Classifiers(nn.Module):
levels = []
# map names to levels
"""
if names and not levels:
levels = [ None if name =="NAR" else self.names.index(name) for name in names ]
"""
if names and not levels:
levels = [ None if name not in self.names else self.names.index(name) for name in names ]
@ -341,9 +322,36 @@ class Classifiers(nn.Module):
]
return torch.stack( xi )
# naively embeds each level of a codebook, then merges the embeddings with a Linear
class AudioEncoder(nn.Module):
def __init__(
self,
n_tokens: int,
n_levels: int,
token_dim: int,
):
super().__init__()
self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)])
self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
if dropout_mask is not None:
xi = xi.clone().detach().t()
for l, t in enumerate( xi ):
xi[l] = torch.where( dropout_mask, dropout_token, xi[l] )
xi = xi.t()
x = torch.cat([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ], dim=-1)
x = self.proj(x)
"""
x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
"""
return x
# Pseudo-MoE by doing additional decoding from the main transformer's last hidden output
# ironically, not using a classifier to hidden_dim => audio_tokens causes problems with fitment
class ParallelDecoder(nn.Module):
class AudioDecoder(nn.Module):
def __init__(
self,
levels,
@ -356,60 +364,39 @@ class ParallelDecoder(nn.Module):
attention_backend = config_kwargs.pop("attention_backend", "default")
gradient_checkpointing = config_kwargs.pop("gradient_checkpointing", True)
config_kwargs["hidden_size"] *= levels
config_kwargs["vocab_size"] *= levels
hidden_size = config_kwargs.get("hidden_size")
vocab_size = config_kwargs.get("vocab_size")
#self.d_model = d_model
self.vocab_size = vocab_size
self.up = nn.Linear( d_model, hidden_size )
self.down = nn.Linear( hidden_size, vocab_size )
self.transformer = None
"""
self.transformer = LlamaModel_Adapted(LlamaConfig(**config_kwargs))
self.transformer = ml.replace_attention( self.transformer, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
if hasattr( self.transformer, "embeddings" ):
del self.transformer.embeddings
downs = []
modules = []
ups = []
for level in range(levels):
module = LlamaModel_Adapted(LlamaConfig(**config_kwargs))
module = ml.replace_attention( module, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
if hasattr( module, "embeddings" ):
del module.embeddings
if gradient_checkpointing and not module.gradient_checkpointing:
module.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
modules.append(module)
downs.append(nn.Linear(d_model, hidden_size, bias=False))
ups.append(nn.Linear(hidden_size, vocab_size, bias=False))
self.levels = levels
self.decoders = nn.ModuleList(modules)
self.downs = nn.ModuleList(downs)
self.ups = nn.ModuleList(ups)
if gradient_checkpointing and not self.transformer.gradient_checkpointing:
self.transformer.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
"""
def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
# split into levels
if level == None:
x = [ self.forward( x, l, **kwargs ) for l in range(self.levels) ]
x = torch.stack( x )
x = x.permute( 1, 0, 2, 3 ) # ( level, batch, token, classification => batch, level, token, classification )
return x
x = self.up( x )
if self.transformer is not None:
x = self.transformer( inputs_embeds=x, **kwargs )["last_hidden_state"]
x = self.down( x )
# do one level
# attention + feedforward
"""
x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"]
# this really hates an output head, so just treat the final output as one
x = x[..., :self.vocab_size]
"""
# downscale to head's dimensionality
x = self.downs[level]( x )
# attention + feed forward
x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"]
# upscale to vocab logits
x = self.ups[level]( x )
batch_size, seq_len, dim = x.shape
x = x.reshape( batch_size, seq_len, 8, dim // 8 )
x = x.permute( 0, 2, 1, 3 )
return x
"""
@ -572,7 +559,6 @@ class Base(nn.Module):
self.causal = "ar" in self.capabilities or "len" in self.capabilities
self.version = self.config.version if self.config is not None else 5
self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
self.parallel_decoding = self.config.experimental.parallel_decoding if self.config is not None else False
self.arch_type = self.config.arch_type if self.config is not None else "llama"
@ -634,22 +620,10 @@ class Base(nn.Module):
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
else:
if self.parallel_decoding:
n_resp_tokens = n_audio_tokens + 1
l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )]
l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels
else:
"""
n_resp_tokens = n_audio_tokens + 1
l_embedding_tokens = [n_resp_tokens * self.n_resp_levels]
l_embedding_names = ["NAR"]
l_classifier_tokens = [n_audio_tokens * self.n_resp_levels]
"""
n_resp_tokens = n_audio_tokens + 1
l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
l_classifier_tokens = [n_audio_tokens] * self.n_resp_levels
l_embedding_names = [ f'NAR:{i}:{i}' for i in range( self.n_resp_levels ) ]
n_resp_tokens = n_audio_tokens + 1
l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )]
l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels
n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1
@ -692,6 +666,7 @@ class Base(nn.Module):
n_audio_tokens += (n_tasks - 1) # old models have the task tokens in the prom
self.proms_emb = MultiEmbedding(self.n_resp_levels, n_audio_tokens, d_model)
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
self.audio_emb = None
elif self.version < 5:
# [1024] * 8
self.proms_emb = AudioEmbedding_Old(
@ -703,7 +678,8 @@ class Base(nn.Module):
l_embedding_tokens, d_model,
levels=self.n_resp_levels if self.version > 3 else None,
)
elif not self.parallel_decoding:
self.audio_emb = None
elif self.version < 7:
self.proms_emb = AudioEmbedding(
[n_audio_tokens] * self.n_resp_levels, d_model,
sums=audio_embedding_sums == "prom" or audio_embedding_sums == True,
@ -713,17 +689,11 @@ class Base(nn.Module):
sums=audio_embedding_sums == "resp" or audio_embedding_sums == True,
l_embedding_names=l_embedding_names,
)
self.audio_emb = None
else:
self.proms_emb = AudioEmbedding_Sums(
n_tokens=n_audio_tokens,
n_levels=self.n_resp_levels,
token_dim=d_model,
)
self.resps_emb = AudioEmbedding_Sums(
n_tokens=n_audio_tokens + 1,
n_levels=self.n_resp_levels,
token_dim=d_model,
)
self.proms_emb = None
self.resps_emb = None
self.audio_emb = None
if self.version >= 3:
self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None
@ -739,7 +709,7 @@ class Base(nn.Module):
# this ***might*** let me also unify the proms_emb and resps_embedding
if self.version >= 5:
# "len" RVQ level-0 gets an additional token
if self.version < 7 or not self.parallel_decoding:
if self.version < 7:
self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)
# experimental NAR-only mode
@ -747,6 +717,53 @@ class Base(nn.Module):
if self.version >= 6:
self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
if self.version >= 7:
pd_model = d_model // 4
pd_ffn = pd_model * 4
pd_heads = n_heads // 4
pd_layers = 1
if False:
self.audio_emb = AudioEncoder(
n_tokens=n_audio_tokens + 1, # masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
)
else:
self.proms_emb = AudioEncoder(
n_tokens=n_audio_tokens,
n_levels=self.n_resp_levels,
token_dim=d_model,
)
self.resps_emb = AudioEncoder(
n_tokens=n_audio_tokens + 1, # masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
)
self.audio_decoder = AudioDecoder(
self.n_resp_levels,
d_model,
dict(
vocab_size=n_audio_tokens,
hidden_size=pd_model,
max_position_embeddings=max_position_embeddings,
intermediate_size=pd_ffn,
num_hidden_layers=pd_layers,
num_attention_heads=pd_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=pd_heads,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation="eager",
training=self.training,
attention_backend=attention_backend,
gradient_checkpointing=self.gradient_checkpointing,
)
)
if attention_backend == "auto":
attention_backend = "sdpa"
"""
@ -906,33 +923,6 @@ class Base(nn.Module):
self.classifiers = Classifiers( l_classifier_tokens, l_classifier_names, d_model, bias=classifiers_bias )
self.metrics = Metrics( l_classifier_tokens )
self.parallel_decoder = None
if self.parallel_decoding:
pd_model = d_model # // 2
pd_ffn = pd_model * 2
pd_heads = n_heads // 2
pd_layers = 1
config = dict(
vocab_size=n_audio_tokens,
hidden_size=pd_model,
max_position_embeddings=max_position_embeddings,
intermediate_size=pd_ffn,
num_hidden_layers=pd_layers,
num_attention_heads=pd_heads,
attention_dropout=p_dropout if training else 0.0,
num_key_value_heads=pd_heads,
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
attn_implementation="eager",
training=self.training,
attention_backend=attention_backend,
gradient_checkpointing=self.gradient_checkpointing,
)
self.parallel_decoder = ParallelDecoder( self.n_resp_levels, d_model, config )
def _forward(
self,
inputs,
@ -1126,8 +1116,8 @@ class Base(nn.Module):
inputs[i].append( ( "resp", resps_list[i] ) )
if self.version >= 7:
classifier_level = f"NAR:{quant_level}:{quant_level}" if not self.parallel_decoding else "NAR"
classifier_level = f"NAR:{quant_level}:{quant_level}"
inputs[i].append( ("classifier_level", classifier_level) )
# Audio length prediction task
# Sequence: <text><sep><rvq lvl><prom><sep><len>
@ -1269,29 +1259,16 @@ class Base(nn.Module):
input if quant_level == 0 else input[:, :quant_level]
)
if self.version < 7: # or not self.parallel_decoding:
if self.version < 7:
return self.proms_emb(
input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
offset = 0,
)
if not self.parallel_decoding:
"""
# provides only one
return self.proms_emb(
input if input.dim() == 1 else input[:, quant_level],
quant_level = 0, # if input.dim() == 1 else input.shape[-1],
offset = 0,
)
"""
# sums all
return self.proms_emb(
input,
quant_level = quant_level if input.dim() == 1 else input.shape[-1],
offset = 0,
)
if self.audio_emb is not None:
return self.audio_emb( input )
return self.proms_emb( input )
# yuck
@ -1358,11 +1335,11 @@ class Base(nn.Module):
elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input )
elif name == "resp":
if self.parallel_decoding:
if dropout_mask is not None:
embedding = self.resps_emb( torch.where( dropout_mask, self.stop_token, input.t() ).t() )
if self.version >= 7:
if self.audio_emb is not None:
embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token )
else:
embedding = self.resps_emb( input )
embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token )
# if training NAR-len RVQ level 0
elif dropout_mask is not None:
embedding = self.resps_emb(
@ -1513,7 +1490,7 @@ class Base(nn.Module):
return ids.to(device=device, dtype=torch.int32)
def calc_loss_parallel(
def calc_loss_new(
self,
inputs: list,
logits,
@ -1589,6 +1566,9 @@ class Base(nn.Module):
if name != task_outputs.get(task_type, name):
if self.ignore_inputs_for_loss:
ignored = True
# cringe
if task_type != "tts":
ignored = True
else:
output_len = seq_len
@ -1602,7 +1582,7 @@ class Base(nn.Module):
# perform loss calculation on the individual piece
target.append( token )
if classifier_level != "NAR":
if logits[batch_index].dim() != 3:
seq = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
logit = logits[batch_index]
@ -1620,7 +1600,7 @@ class Base(nn.Module):
if compute_acc and False:
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ "NAR:0" if classifier_level == "NAR" else classifier_level ]) )
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
@ -1652,7 +1632,7 @@ class Base(nn.Module):
if compute_acc and False:
if self.metrics is not None:
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ "NAR:0" if classifier_level == "NAR" else classifier_level ]) )
metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
else:
accuracy_metric = MulticlassAccuracy(
logit.shape[-1],
@ -1701,9 +1681,6 @@ class Base(nn.Module):
if self.version < 7:
return input if input.dim() == 1 else input[:, quant_level]
if not self.parallel_decoding:
return input if input.dim() == 1 else input[:, quant_level]
return input
for batch_index, batch in enumerate(inputs):
@ -1729,8 +1706,6 @@ class Base(nn.Module):
# nonautoregressive, parallel
elif classifier_level.startswith("NAR:"):
causal = False
elif classifier_level == "NAR":
causal = False
it = 0
for name, input in batch:
@ -1773,6 +1748,9 @@ class Base(nn.Module):
if name != task_outputs.get(task_type, name):
if self.ignore_inputs_for_loss:
ignored = True
# cringe
if task_type != "tts":
ignored = True
else:
output_len = seq_len
@ -1909,10 +1887,10 @@ class Base(nn.Module):
# needs to be done here as we still have our raw inputs
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
classifier_levels = self.get_input( inputs, name="classifier_level" )
casual_levels = [ "AR:0:0", "stt", "len", "phn" ]
causal_levels = [ "AR:0:0", "stt", "len", "phn" ]
# right now limit to new versions because I need to retrain the model for noncausal masks...
is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
output = self._forward(
inputs=x,
@ -1928,26 +1906,19 @@ class Base(nn.Module):
logits = [ logit for logit in logits ]
if self.version >= 7 and self.parallel_decoding:
p_indices = [ batch_index for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ]
if self.version >= 7:
p_indices = [ batch_index for batch_index in range(batch_size) if classifier_levels[batch_index] not in causal_levels ]
if p_indices:
p_logits = torch.stack([ logits[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0)
p_mask = torch.stack([ mask[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0)
p_ids = torch.stack([ position_ids[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0)
p_causal = [ is_causal[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ]
p_logits = torch.stack([ logits[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
p_logits = self.parallel_decoder( p_logits, attention_mask=p_mask, position_ids=p_ids, use_cache=False, return_dict=True, is_causal=p_causal )
p_mask = torch.stack([ mask[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
p_ids = torch.stack([ position_ids[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
p_causal = [ is_causal[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ]
p_logits = self.audio_decoder( p_logits, attention_mask=p_mask, position_ids=p_ids, use_cache=False, return_dict=True, is_causal=p_causal )
for i, logit in enumerate(p_logits):
logits[p_indices[i]] = logit
"""
logits = [ self.parallel_decoder( logit.unsqueeze(0), attention_mask=mask,
position_ids=position_ids,
use_cache=False,
return_dict=True,
is_causal=is_causal )[0] if level == "NAR" else logit for logit, level in zip(logits, classifier_levels) ]
"""
# output projection layer
# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
@ -1958,15 +1929,6 @@ class Base(nn.Module):
elif self.classifiers is not None:
logits = self.classifiers(logits, levels = classifier_levels )
# Reshape
"""
if self.version >= 7 and not self.parallel_decoding:
for batch_index, logit in enumerate( logits ):
if classifier_levels[batch_index] != "NAR":
continue
logits[batch_index] = logit.reshape( logit.shape[0], 8, 1000 ).permute( 1, 0, 2 )
"""
# Remove padding
logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]
@ -1977,8 +1939,8 @@ class Base(nn.Module):
self.loss = None
self.stats = None
# compute loss if the target is given
elif self.version >= 7 and self.parallel_decoding:
loss, stats = self.calc_loss_parallel( inputs=inputs, logits=logits )
elif self.version >= 7:
loss, stats = self.calc_loss_new( inputs=inputs, logits=logits )
# include any additional losses (for example: MoE router)
if output.loss is not None: