layer skip training implemented (need to gut the inferencing from the repo, and to actually see if the model can benefit from this)

mrq 2024-10-30 20:05:45 -05:00
parent 4049f51ba9
commit a22534e8f4
7 changed files with 251 additions and 39 deletions

View File

@@ -255,9 +255,11 @@ class ModelExperimentalSettings:
 	# it just seems like a bitch to try and train something worthwhile with it, since there's crackles every other token
 	# RetNet's chunked inferencing might be a better place for this
-	p_len_train: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
+	len_train_p: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
 	# to-to: just incorporate this as a task instead
+	layerskip: bool = False

 # I really need to clean this up
 @dataclass()
 class Model:
@@ -870,9 +872,6 @@ class Config(BaseConfig):
 	def format( self, training=True ):
-		print( self.models )
-		print( self.loras )
 		if isinstance(self.dataset, type):
 			self.dataset = dict()
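For reference, a minimal sketch of what the two experimental knobs above look like when set. This is illustrative only: the field names come from the hunk, but the import path is assumed, and in practice the dataclass is populated from the YAML config rather than constructed by hand.

	from vall_e.config import ModelExperimentalSettings  # module path assumed for illustration

	experimental = ModelExperimentalSettings(
		len_train_p = 0.05, # renamed from p_len_train: odds of injecting a "len" task for NAR-len
		layerskip = True,   # opt the llama backbone into layer dropout + early-exit training
	)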

View File

@@ -43,7 +43,6 @@ def main():
 	parser.add_argument("--yaml", type=Path, default=None)
 	parser.add_argument("--model", type=Path, default=None)
-	parser.add_argument("--lora", type=Path, default=None)
 	parser.add_argument("--demo-dir", type=Path, default=None)
 	parser.add_argument("--skip-existing", action="store_true")
@@ -245,8 +244,8 @@ def main():
 		metadata = batch["metadata"]

-		#text = get_random_prompt() if args.random_prompts else metadata["text"]
-		text = get_random_prompt() if i >= (num // 2) else metadata["text"]
+		text = get_random_prompt() if args.random_prompts else metadata["text"]
+		#text = get_random_prompt() if i >= (num // 2) else metadata["text"]
 		language = metadata["language"].lower()

 		prompt = dir / "prompt.wav"

View File

@@ -496,7 +496,7 @@ def example_usage():
 	bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )

 	#available_tasks = cfg.dataset.tasks_list
-	available_tasks = ["tts", "stt"]
+	available_tasks = ["tts"] # , "stt"]

 	model = AR_NAR(**kwargs).to(device)
 	steps = 500 # 150 * len(available_tasks) # * cfg.model.experimental.causal_size

View File

@@ -30,7 +30,7 @@ except Exception as e:
 	pass

 try:
-	from .llama import LlamaModel, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaForCausalLM
+	from .llama import LlamaModel, LlamaModel_Adapted, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaForCausalLM

 	AVAILABLE_ARCHES.append("llama")
 except Exception as e:
 	ERROR_ARCHES["llama"] = e

View File

@@ -3,19 +3,24 @@
 import math
 import torch
 import logging
+import random

-from typing import Literal, overload, Optional, Tuple
+from typing import Literal, overload, Optional, Tuple, Union, List, Unpack

 from torch import Tensor, nn
 from transformers.cache_utils import Cache

 from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv

 _logger = logging.getLogger(__name__)

 AVAILABLE_ATTENTIONS = []

+LN_2 = 0.69314718056
+
 try:
 	from transformers.utils import is_flash_attn_2_available
@@ -131,11 +136,7 @@ if AVAILABLE_ATTENTIONS:
 class LlamaAttention_Adapted(LlamaAttention):
 	def __init__(self, *args, **kwargs):
-		if 'mode' in kwargs:
-			self.mode = kwargs['mode']
-			kwargs.pop("mode")
-		else:
-			self.mode = "sdpa"
+		self.mode = kwargs.pop("mode", "sdpa")

 		if self.mode == "math":
 			self.mode = torch.nn.attention.SDPBackend.MATH
@@ -302,3 +303,159 @@ class LlamaAttention_Adapted(LlamaAttention):
 		attn_output = self.o_proj(attn_output)
 		return attn_output, attn_scores, past_key_value
+
+class LlamaModel_Adapted(LlamaModel):
+	def __init__(self, *args, **kwargs):
+		self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0.1)
+		self.early_exit_scale = kwargs.pop("early_exit_scale", 0.1)
+
+		super().__init__(*args, **kwargs)
+
+		self.layers_n = len(self.layers)
+
+	def dropoff_layer( self, l ):
+		if not self.training:
+			return False
+
+		# this could probably a LUT but I'm not fiending for aggressive mal-optimizations
+		D = math.exp((l * LN_2) / (self.layers_n - 1)) - 1
+		P = D * self.layer_dropout_p
+		return random.random() < P
+
+	# to-do: properly implement this per the paper
+	# this probably is a function of layer number and training step to decide what layer to apply layerskip to for training
+	def cirriculum( self, l, t=0 ):
+		return 1 # self.layers_n - 1
+
+	def early_exit_loss( self, losses, t=0 ):
+		return sum([ self.normalized_per_layer_loss_scale( l, t ) * losses[l] for l in range(0, self.layers_n) ])
+
+	def normalized_per_layer_loss_scale( self, l, t=0 ):
+		return (self.cirriculum(l, t) * self.early_exit_factor( l )) / (sum([ self.cirriculum(l, t) * self.early_exit_factor( i ) for i in range(0, self.layers_n) ]))
+
+	def early_exit_factor( self, l ):
+		if 0 <= l and l < self.layers_n:
+			return self.early_exit_scale * sum([ i for i in range(0, l) ])
+		return self.layers_n - 1 + self.early_exit_scale * sum([ i for i in range(0, self.layers_n - 1) ])
+
+	def forward(
+		self,
+		input_ids: torch.LongTensor = None,
+		attention_mask: Optional[torch.Tensor] = None,
+		position_ids: Optional[torch.LongTensor] = None,
+		past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+		inputs_embeds: Optional[torch.FloatTensor] = None,
+		use_cache: Optional[bool] = None,
+		output_attentions: Optional[bool] = None,
+		output_hidden_states: Optional[bool] = None,
+		return_dict: Optional[bool] = None,
+		cache_position: Optional[torch.LongTensor] = None,
+	) -> Union[Tuple, BaseModelOutputWithPast]:
+		output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+		output_hidden_states = (
+			output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+		)
+		use_cache = use_cache if use_cache is not None else self.config.use_cache
+		return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+		if (input_ids is None) ^ (inputs_embeds is not None):
+			raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+		if self.gradient_checkpointing and self.training and use_cache:
+			logger.warning_once(
+				"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+			)
+			use_cache = False
+
+		if inputs_embeds is None:
+			inputs_embeds = self.embed_tokens(input_ids)
+
+		# kept for BC (non `Cache` `past_key_values` inputs)
+		return_legacy_cache = False
+		if use_cache and not isinstance(past_key_values, Cache):
+			return_legacy_cache = True
+			if past_key_values is None:
+				past_key_values = DynamicCache()
+			else:
+				past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+				logger.warning_once(
+					"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+					"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+					"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+				)

+		if cache_position is None:
+			past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+			cache_position = torch.arange(
+				past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+			)
+		if position_ids is None:
+			position_ids = cache_position.unsqueeze(0)
+
+		causal_mask = self._update_causal_mask(
+			attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+		)
+		hidden_states = inputs_embeds
+
+		# create position embeddings to be shared across the decoder layers
+		position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+		# decoder layers
+		all_hidden_states = () if output_hidden_states else None
+		all_self_attns = () if output_attentions else None
+		next_decoder_cache = None
+
+		for l, decoder_layer in enumerate(self.layers):
+			if output_hidden_states:
+				all_hidden_states += (hidden_states,)
+
+			if self.gradient_checkpointing and self.training:
+				layer_outputs = self._gradient_checkpointing_func(
+					decoder_layer.__call__,
+					hidden_states,
+					causal_mask,
+					position_ids,
+					past_key_values,
+					output_attentions,
+					use_cache,
+					cache_position,
+					position_embeddings,
+				)
+			else:
+				layer_outputs = decoder_layer(
+					hidden_states,
+					attention_mask=causal_mask,
+					position_ids=position_ids,
+					past_key_value=past_key_values,
+					output_attentions=output_attentions,
+					use_cache=use_cache,
+					cache_position=cache_position,
+					position_embeddings=position_embeddings,
+				)
+
+			if not self.dropoff_layer( l ):
+				hidden_states = layer_outputs[0]
+
+				if use_cache:
+					next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+			if output_attentions:
+				all_self_attns += (layer_outputs[1],)
+
+		hidden_states = self.norm(hidden_states)
+
+		# add hidden states from the last decoder layer
+		if output_hidden_states:
+			all_hidden_states += (hidden_states,)
+
+		next_cache = next_decoder_cache if use_cache else None
+		if return_legacy_cache:
+			next_cache = next_cache.to_legacy_cache()
+
+		if not return_dict:
+			return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+		return BaseModelOutputWithPast(
+			last_hidden_state=hidden_states,
+			past_key_values=next_cache,
+			hidden_states=all_hidden_states,
+			attentions=all_self_attns,
+		)
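To make the dropout schedule above concrete, here is a small self-contained sketch (not part of the commit) of the per-layer drop probability that dropoff_layer samples from. The 12-layer depth is a made-up example; layer_dropout_p matches the default above. The curve exp(l * ln 2 / (L - 1)) - 1 rises from 0 at the first layer to 1 at the last, so the final layer is dropped with probability layer_dropout_p and early layers are almost never dropped:

	import math

	LN_2 = 0.69314718056
	layers_n = 12          # hypothetical depth, for illustration only
	layer_dropout_p = 0.1  # default used by LlamaModel_Adapted above

	def drop_probability( l ):
		# same curve as LlamaModel_Adapted.dropoff_layer, without the coin flip
		D = math.exp((l * LN_2) / (layers_n - 1)) - 1
		return D * layer_dropout_p

	for l in range(layers_n):
		print(f"layer {l:2d}: p(drop) = {drop_probability(l):.3f}")
	# layer 0 comes out to 0.000 and layer 11 to 0.100; only training ever skips a
	# layer's output, since dropoff_layer returns False when self.training is False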

View File

@@ -38,8 +38,9 @@ from ..emb.qnt import encode_as_embedding
 # yuck, kind of needed
 from ..data import get_task_symmap

-Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions'])
+Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states'])
 Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy']) # these seem more elegant than a dict
+LossStats = namedtuple('LossStats', ['loss', 'stats'])

 """
 from ..utils.pattern import DelayedPatternProvider, VALLEPattern
@@ -442,6 +443,7 @@ class Base(nn.Module):
 		audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
 		unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
 		interleave = self.config.experimental.interleave if self.config is not None else False
+		layerskip = self.config.experimental.layerskip if self.config is not None else False

 		n_tasks = self.config.tasks if self.config is not None else 8
 		n_langs = self.config.langs if self.config is not None else 2
@@ -469,6 +471,7 @@ class Base(nn.Module):
 		self.unified_position_ids = unified_position_ids
 		self.interleave = interleave
+		self.layerskip = layerskip

 		self.text_emb = Embedding(n_text_tokens, d_model)
 		self.langs_emb = None
@@ -601,8 +604,9 @@ class Base(nn.Module):
 				use_reentrant=False
 			))
 		elif self.arch_type == "llama":
+			LlamaClass = LlamaModel_Adapted if self.layerskip else LlamaModel
 			if n_experts <= 1:
-				self.model = LlamaModel(LlamaConfig(
+				self.model = LlamaClass(LlamaConfig(
 					vocab_size=n_resp_tokens,
 					hidden_size=d_model,
 					max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
@@ -814,12 +818,14 @@ class Base(nn.Module):
 		state = None,

 		output_attentions = False,
+		output_hidden_states = False,
 	):
 		x = inputs
 		m = mask.squeeze(-1).int()

 		aux_loss = None
 		attentions = None
+		hidden_states = None

 		# HF transformer derived model
 		if self.arch_type in ["llama", "mistral", "mixtral"]:
@@ -830,6 +836,7 @@ class Base(nn.Module):
 				position_ids=position_ids,
 				use_cache=not self.training,
 				output_attentions=output_attentions,
+				output_hidden_states=output_hidden_states,
 				return_dict=True,
 			)
 			if self.n_experts > 1 and self.training:
@@ -846,6 +853,9 @@ class Base(nn.Module):
 			if output_attentions:
 				attentions = output["attentions"]

+			if output_hidden_states:
+				hidden_states = output["hidden_states"]
+
 			if self.n_experts > 1 and self.training:
 				router_logits = output["aux_loss"]
 				aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
@@ -904,11 +914,19 @@ class Base(nn.Module):
 			x = x[0]

+		# process it into a format that I like
+		if output_hidden_states:
+			hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i in range( self.n_layers ) ]
+
 		# output projection layer with masking
 		if self.classifier is not None:
 			x = self.classifier(x) * mask

-		return Logits(x, state, aux_loss, attentions)
+			if output.hidden_states:
+				for i in range( self.n_layers ):
+					hidden_states[i] = self.classifier(hidden_states[i]) * m
+
+		return Logits(x, state, aux_loss, attentions, hidden_states)

 	# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
 	def inputs(
@@ -1217,6 +1235,9 @@ class Base(nn.Module):
 		quant_levels: int | list[int] | Tensor | None = None,
 	):
+		loss = dict(ce = dict())
+		stats = dict(acc = dict())
+
 		device = logits[0].device
 		special_tasks = [ "len", "stt" ]
 		summed_embeddings_task = [ "stt" ]
@@ -1285,23 +1306,22 @@ class Base(nn.Module):
 			if False:
 				target = torch.cat( target_list )
 				inputs = torch.cat( logits )
-				self.loss = dict(
+				loss = dict(
+					# "nll" was in the original implementation and should actually just be called something else
 					nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
 				)
-				self.stats = self.metrics( inputs, targets, classifier_quant_levels ) if self.metrics is not None else dict(
+				stats = self.metrics( inputs, targets, classifier_quant_levels ) if self.metrics is not None else dict(
 					acc = self.accuracy_metric( inputs, target ),
 					# precision = self.precision_metric( inputs, target ),
 				)
 			else:
-				self.loss = dict(
+				loss = dict(
 					nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
 				)
-				self.stats = self.metrics( logits, target_list, classifier_quant_levels ) if self.metrics is not None else dict(
+				stats = self.metrics( logits, target_list, classifier_quant_levels ) if self.metrics is not None else dict(
 					acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
 				)

-			return
+			return LossStats(loss, stats)

 		"""
 		# considerations:
@@ -1311,9 +1331,6 @@ class Base(nn.Module):
 		# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
 		# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
 		"""
-		self.loss = dict()
-		self.stats = dict(acc = dict())

 		info = {}
 		batch_size = len( inputs )
@@ -1385,17 +1402,19 @@ class Base(nn.Module):
 			if False:
 				targets = torch.cat( batch["targets"] ).long()
 				inputs = torch.cat( batch["logits"] )
-				self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
-				self.stats["acc"][name] = self.accuracy_metric( inputs, targets )
+				loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
+				stats["acc"][name] = self.accuracy_metric( inputs, targets )
 			# probably consumes less memory due to not having to allocate memory
 			# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
 			else:
-				self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
+				loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
 				if self.metrics is not None:
 					metrics = self.metrics( batch["logits"], batch["targets"], classifier_quant_levels )
-					self.stats["acc"][name] = metrics["acc"]
+					stats["acc"][name] = metrics["acc"]
 				else:
-					self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
+					stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
+
+		return LossStats(loss, stats)

 	def forward(
 		self,
@@ -1404,6 +1423,7 @@ class Base(nn.Module):
 		quant_levels: int | list[int] | Tensor | None = None,
 		state: dict | list | None = None,

 		output_attentions = False,
+		output_hidden_states = False,
 	):
 		x_list = self.inputs_to_embeddings( inputs, quant_levels )
 		x, m = list_to_tensor(x_list)
@@ -1412,11 +1432,13 @@ class Base(nn.Module):
 		device = x.device
 		batch_size = len(x_list)

 		# pure AR
 		if quant_levels is None:
 			quant_levels = [ 0 for _ in range(batch_size) ]

+		if self.layerskip:
+			output_hidden_states = True
+
 		# pad our input and mask, but retain the original length by doing it after
 		if self.l_padding and x.shape[1] % self.l_padding != 0:
 			# pad input
@@ -1440,9 +1462,11 @@ class Base(nn.Module):
 			state=state,
 			position_ids=position_ids,
 			output_attentions = output_attentions,
+			output_hidden_states = output_hidden_states,
 		)

 		logits = output.logits
+		hidden_states = output.hidden_states

 		# to-do: piece-wise classification, now that there's a head for text
 		# although again, one single monolithic head would be preferable instead......
@@ -1451,19 +1475,52 @@ class Base(nn.Module):
 			classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
 			logits = self.classifiers(logits, levels = classifier_quant_levels) * m

+			if hidden_states is not None:
+				for i in range( self.n_layers ):
+					hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_quant_levels) * m
+
 		# Remove padding
 		logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ]

 		# compute loss if the target is given
 		if training:
-			self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
+			if not self.layerskip:
+				loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
+			else:
+				self.loss = {}
+				self.stats = {}
+				for i in range( self.n_layers ):
+					# remove padding
+					hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
+					loss, stats = self.calc_loss( inputs=inputs, logits=hidden_states[i], quant_levels=quant_levels )
+
+					for k, v in loss.items():
+						if k not in self.loss:
+							self.loss[k] = []
+						self.loss[k].append( v )
+
+					for k, v in stats.items():
+						if k not in self.stats:
+							self.stats[k] = []
+						self.stats[k].append( v )
+
+				for k, v in self.loss.items():
+					self.loss[k] = self.model.early_exit_loss( losses=v )
+
+				for k, v in self.stats.items():
+					self.stats[k] = sum( v ) / len( v )
+
 			# include any additional losses (for example: MoE router)
 			if output.aux_loss is not None:
-				self.loss["aux_loss"] = output.aux_loss
+				loss["aux_loss"] = output.aux_loss
+
+			self.loss = loss
+			self.stats = stats

 		# rewrap, because we're modifying the logits here
-		return Logits(logits, output.state, output.aux_loss, output.attentions)
+		return Logits(logits, output.state, output.aux_loss, output.attentions, hidden_states)

 	def sample(
 		self,
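For clarity, a standalone sketch (also not part of the commit) of how the layerskip branch above combines its per-layer losses: forward() runs calc_loss() on each layer's normed-and-classified hidden state, collects every loss key into a list indexed by layer, then hands each list to LlamaModel_Adapted.early_exit_loss, a weighted sum whose weights come from early_exit_factor and normalize to 1, so deeper layers dominate. The depth and the per-layer loss values below are made up; cirriculum() currently always returns 1, so it cancels out of the ratio and is omitted here:

	layers_n = 12           # hypothetical depth
	early_exit_scale = 0.1  # default used by LlamaModel_Adapted

	def early_exit_factor( l ):
		# mirrors LlamaModel_Adapted.early_exit_factor: zero for the first layers, growing with depth
		if 0 <= l < layers_n:
			return early_exit_scale * sum(range(l))
		return layers_n - 1 + early_exit_scale * sum(range(layers_n - 1))

	def normalized_per_layer_loss_scale( l ):
		total = sum( early_exit_factor(i) for i in range(layers_n) )
		return early_exit_factor(l) / total

	def early_exit_loss( losses ):
		return sum( normalized_per_layer_loss_scale(l) * losses[l] for l in range(layers_n) )

	per_layer_nll = [ 3.0 - 0.2 * l for l in range(layers_n) ]  # fake per-layer cross-entropy values
	print(early_exit_loss(per_layer_nll))  # a weighted average biased toward the deeper layers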

View File

@@ -60,7 +60,7 @@ class NAR(Base):
 		# is training
 		if resps_list is not None:
-			p_len_task = self.config.experimental.p_len_train if self.config is not None else 0.05
+			len_train_p = self.config.experimental.len_train_p if self.config is not None else 0.05

 			n_levels_set = {r.shape[-1] for r in resps_list}
 			n_levels = next(iter(n_levels_set))
@@ -69,7 +69,7 @@ class NAR(Base):
 			# to-do: make this YAML configurable
 			def sample_task():
-				return "len" if random.random() < p_len_task else "tts"
+				return "len" if random.random() < len_train_p else "tts"

 			# generate task list to train against
 			task_list = [ sample_task() for _ in range(batch_size) ]