layer skip training implemented (need to gut the inferencing from the repo, and to actually see if the model can benefit from this)
This commit is contained in:
parent 4049f51ba9
commit a22534e8f4

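In rough terms, the change is LayerSkip-style training: the adapted LLaMA model can randomly drop whole decoder layers during training and hand back every layer's hidden state, and the base model then computes a loss per layer and folds them together with early_exit_loss(). A minimal standalone sketch of that flow, with hypothetical names (model stands in for LlamaModel_Adapted below, classifier for the output head in base.py):

import torch.nn.functional as F

def layerskip_training_step(model, classifier, inputs_embeds, targets):
	# ask the adapted decoder for every layer's hidden state
	output = model(inputs_embeds=inputs_embeds, output_hidden_states=True, return_dict=True)

	# run each layer's hidden state through the same output head;
	# base.py additionally re-applies the final norm to the intermediate layers first
	layer_losses = []
	for hidden in output.hidden_states[1:]:  # index 0 is the embedding output
		logits = classifier(hidden)  # [batch, seq, vocab]
		layer_losses.append(F.cross_entropy(logits.transpose(1, 2), targets))

	# weight shallow layers less and deep layers more, then sum (see early_exit_loss below)
	return model.early_exit_loss(losses=layer_losses)
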
@@ -255,8 +255,10 @@ class ModelExperimentalSettings:
 	# it just seems like a bitch to try and train something worthwhile with it, since there's crackles every other token
 	# RetNet's chunked inferencing might be a better place for this

-	p_len_train: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
+	len_train_p: float = 0.05 # odds of injecting a "len" task within the model for NAR-len
 	# to-to: just incorporate this as a task instead

+	layerskip: bool = False
+
 # I really need to clean this up
 @dataclass()

@@ -870,9 +872,6 @@ class Config(BaseConfig):


 	def format( self, training=True ):
-		print( self.models )
-		print( self.loras )
-
 		if isinstance(self.dataset, type):
 			self.dataset = dict()


@@ -43,7 +43,6 @@ def main():

 	parser.add_argument("--yaml", type=Path, default=None)
 	parser.add_argument("--model", type=Path, default=None)
-	parser.add_argument("--lora", type=Path, default=None)

 	parser.add_argument("--demo-dir", type=Path, default=None)
 	parser.add_argument("--skip-existing", action="store_true")

@@ -245,8 +244,8 @@ def main():

 		metadata = batch["metadata"]

-		#text = get_random_prompt() if args.random_prompts else metadata["text"]
-		text = get_random_prompt() if i >= (num // 2) else metadata["text"]
+		text = get_random_prompt() if args.random_prompts else metadata["text"]
+		#text = get_random_prompt() if i >= (num // 2) else metadata["text"]
 		language = metadata["language"].lower()

 		prompt = dir / "prompt.wav"

@@ -496,7 +496,7 @@ def example_usage():

 	bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
 	#available_tasks = cfg.dataset.tasks_list
-	available_tasks = ["tts", "stt"]
+	available_tasks = ["tts"] # , "stt"]

 	model = AR_NAR(**kwargs).to(device)
 	steps = 500 # 150 * len(available_tasks) # * cfg.model.experimental.causal_size

@@ -30,7 +30,7 @@ except Exception as e:
 	pass

 try:
-	from .llama import LlamaModel, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaForCausalLM
+	from .llama import LlamaModel, LlamaModel_Adapted, LlamaConfig, AVAILABLE_ATTENTIONS, LlamaAttention, LlamaAttention_Adapted, LlamaForCausalLM
 	AVAILABLE_ARCHES.append("llama")
 except Exception as e:
 	ERROR_ARCHES["llama"] = e

@@ -3,19 +3,24 @@
 import math
 import torch
 import logging
+import random

-from typing import Literal, overload, Optional, Tuple
+from typing import Literal, overload, Optional, Tuple, Union, List, Unpack

 from torch import Tensor, nn
 from transformers.cache_utils import Cache

 from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
 from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv

 _logger = logging.getLogger(__name__)

 AVAILABLE_ATTENTIONS = []

+LN_2 = 0.69314718056
+
 try:
 	from transformers.utils import is_flash_attn_2_available


@@ -131,11 +136,7 @@ if AVAILABLE_ATTENTIONS:

 class LlamaAttention_Adapted(LlamaAttention):
 	def __init__(self, *args, **kwargs):
-		if 'mode' in kwargs:
-			self.mode = kwargs['mode']
-			kwargs.pop("mode")
-		else:
-			self.mode = "sdpa"
+		self.mode = kwargs.pop("mode", "sdpa")

 		if self.mode == "math":
 			self.mode = torch.nn.attention.SDPBackend.MATH

@@ -301,4 +302,160 @@ class LlamaAttention_Adapted(LlamaAttention):

 		attn_output = self.o_proj(attn_output)

-		return attn_output, attn_scores, past_key_value
+		return attn_output, attn_scores, past_key_value
+
+class LlamaModel_Adapted(LlamaModel):
+	def __init__(self, *args, **kwargs):
+		self.layer_dropout_p = kwargs.pop("layer_dropout_p", 0.1)
+		self.early_exit_scale = kwargs.pop("early_exit_scale", 0.1)
+
+		super().__init__(*args, **kwargs)
+
+		self.layers_n = len(self.layers)
+	def dropoff_layer( self, l ):
+		if not self.training:
+			return False
+
+		# this could probably a LUT but I'm not fiending for aggressive mal-optimizations
+		D = math.exp((l * LN_2) / (self.layers_n - 1)) - 1
+		P = D * self.layer_dropout_p
+		return random.random() < P
+
+	# to-do: properly implement this per the paper
+	# this probably is a function of layer number and training step to decide what layer to apply layerskip to for training
+	def cirriculum( self, l, t=0 ):
+		return 1 # self.layers_n - 1
+
+	def early_exit_loss( self, losses, t=0 ):
+		return sum([ self.normalized_per_layer_loss_scale( l, t ) * losses[l] for l in range(0, self.layers_n) ])
+
+	def normalized_per_layer_loss_scale( self, l, t=0 ):
+		return (self.cirriculum(l, t) * self.early_exit_factor( l )) / (sum([ self.cirriculum(l, t) * self.early_exit_factor( i ) for i in range(0, self.layers_n) ]))
+
+	def early_exit_factor( self, l ):
+		if 0 <= l and l < self.layers_n:
+			return self.early_exit_scale * sum([ i for i in range(0, l) ])
+		return self.layers_n - 1 + self.early_exit_scale * sum([ i for i in range(0, self.layers_n - 1) ])
+
+	def forward(
+		self,
+		input_ids: torch.LongTensor = None,
+		attention_mask: Optional[torch.Tensor] = None,
+		position_ids: Optional[torch.LongTensor] = None,
+		past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+		inputs_embeds: Optional[torch.FloatTensor] = None,
+		use_cache: Optional[bool] = None,
+		output_attentions: Optional[bool] = None,
+		output_hidden_states: Optional[bool] = None,
+		return_dict: Optional[bool] = None,
+		cache_position: Optional[torch.LongTensor] = None,
+	) -> Union[Tuple, BaseModelOutputWithPast]:
+		output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+		output_hidden_states = (
+			output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+		)
+		use_cache = use_cache if use_cache is not None else self.config.use_cache
+		return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+		if (input_ids is None) ^ (inputs_embeds is not None):
+			raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+		if self.gradient_checkpointing and self.training and use_cache:
+			logger.warning_once(
+				"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+			)
+			use_cache = False
+
+		if inputs_embeds is None:
+			inputs_embeds = self.embed_tokens(input_ids)
+
+		# kept for BC (non `Cache` `past_key_values` inputs)
+		return_legacy_cache = False
+		if use_cache and not isinstance(past_key_values, Cache):
+			return_legacy_cache = True
+			if past_key_values is None:
+				past_key_values = DynamicCache()
+			else:
+				past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+				logger.warning_once(
+					"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+					"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+					"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+				)

+		if cache_position is None:
+			past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+			cache_position = torch.arange(
+				past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+			)
+		if position_ids is None:
+			position_ids = cache_position.unsqueeze(0)
+
+		causal_mask = self._update_causal_mask(
+			attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+		)
+		hidden_states = inputs_embeds
+
+		# create position embeddings to be shared across the decoder layers
+		position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+		# decoder layers
+		all_hidden_states = () if output_hidden_states else None
+		all_self_attns = () if output_attentions else None
+		next_decoder_cache = None
+
+		for l, decoder_layer in enumerate(self.layers):
+			if output_hidden_states:
+				all_hidden_states += (hidden_states,)
+
+			if self.gradient_checkpointing and self.training:
+				layer_outputs = self._gradient_checkpointing_func(
+					decoder_layer.__call__,
+					hidden_states,
+					causal_mask,
+					position_ids,
+					past_key_values,
+					output_attentions,
+					use_cache,
+					cache_position,
+					position_embeddings,
+				)
+			else:
+				layer_outputs = decoder_layer(
+					hidden_states,
+					attention_mask=causal_mask,
+					position_ids=position_ids,
+					past_key_value=past_key_values,
+					output_attentions=output_attentions,
+					use_cache=use_cache,
+					cache_position=cache_position,
+					position_embeddings=position_embeddings,
+				)
+
+			if not self.dropoff_layer( l ):
+				hidden_states = layer_outputs[0]
+
+			if use_cache:
+				next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+			if output_attentions:
+				all_self_attns += (layer_outputs[1],)
+
+		hidden_states = self.norm(hidden_states)
+
+		# add hidden states from the last decoder layer
+		if output_hidden_states:
+			all_hidden_states += (hidden_states,)
+
+		next_cache = next_decoder_cache if use_cache else None
+		if return_legacy_cache:
+			next_cache = next_cache.to_legacy_cache()
+
+		if not return_dict:
+			return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+		return BaseModelOutputWithPast(
+			last_hidden_state=hidden_states,
+			past_key_values=next_cache,
+			hidden_states=all_hidden_states,
+			attentions=all_self_attns,
+		)

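For intuition on the constants above: the skip probability follows D(l) = 2^(l / (N - 1)) - 1 scaled by layer_dropout_p, so the first layer is never dropped and the last is dropped at the full layer_dropout_p rate, while the early-exit loss weights (with cirriculum() fixed at 1) grow roughly quadratically with depth. A quick sanity check, assuming a hypothetical N = 12 layers and the 0.1 defaults:

import math

LN_2 = 0.69314718056
N, p_max, scale = 12, 0.1, 0.1  # layer count and defaults assumed for illustration

# per-layer skip probability: 0.0 at layer 0, ~0.0065 at layer 1, ..., 0.1 at layer 11
skip_p = [ (math.exp(l * LN_2 / (N - 1)) - 1) * p_max for l in range(N) ]

# early-exit loss weights: proportional to scale * (0 + 1 + ... + (l - 1)),
# so the first two layers get zero weight and the deepest layer gets 55/220 = 0.25
factor = [ scale * sum(range(l)) for l in range(N) ]
weights = [ f / sum(factor) for f in factor ]
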
@@ -38,8 +38,9 @@ from ..emb.qnt import encode_as_embedding
 # yuck, kind of needed
 from ..data import get_task_symmap

-Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions'])
+Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states'])
 Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy']) # these seem more elegant than a dict
+LossStats = namedtuple('LossStats', ['loss', 'stats'])

 """
 from ..utils.pattern import DelayedPatternProvider, VALLEPattern

@@ -442,6 +443,7 @@ class Base(nn.Module):
 		audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
 		unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
 		interleave = self.config.experimental.interleave if self.config is not None else False
+		layerskip = self.config.experimental.layerskip if self.config is not None else False

 		n_tasks = self.config.tasks if self.config is not None else 8
 		n_langs = self.config.langs if self.config is not None else 2

@@ -469,6 +471,7 @@ class Base(nn.Module):

 		self.unified_position_ids = unified_position_ids
 		self.interleave = interleave
+		self.layerskip = layerskip

 		self.text_emb = Embedding(n_text_tokens, d_model)
 		self.langs_emb = None

@@ -601,8 +604,9 @@ class Base(nn.Module):
 				use_reentrant=False
 			))
 		elif self.arch_type == "llama":
+			LlamaClass = LlamaModel_Adapted if self.layerskip else LlamaModel
 			if n_experts <= 1:
-				self.model = LlamaModel(LlamaConfig(
+				self.model = LlamaClass(LlamaConfig(
 					vocab_size=n_resp_tokens,
 					hidden_size=d_model,
 					max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds

@@ -814,12 +818,14 @@ class Base(nn.Module):

 		state = None,
 		output_attentions = False,
+		output_hidden_states = False,
 	):
 		x = inputs
 		m = mask.squeeze(-1).int()

 		aux_loss = None
 		attentions = None
+		hidden_states = None

 		# HF transformer derived model
 		if self.arch_type in ["llama", "mistral", "mixtral"]:

@@ -830,6 +836,7 @@ class Base(nn.Module):
 				position_ids=position_ids,
 				use_cache=not self.training,
 				output_attentions=output_attentions,
+				output_hidden_states=output_hidden_states,
 				return_dict=True,
 			)
 			if self.n_experts > 1 and self.training:

@@ -846,6 +853,9 @@ class Base(nn.Module):
 			if output_attentions:
 				attentions = output["attentions"]

+			if output_hidden_states:
+				hidden_states = output["hidden_states"]
+
 			if self.n_experts > 1 and self.training:
 				router_logits = output["aux_loss"]
 				aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )

@@ -904,11 +914,19 @@ class Base(nn.Module):

 			x = x[0]

+		# process it into a format that I like
+		if output_hidden_states:
+			hidden_states = [ x if i == self.n_layers - 1 else self.model.norm(output.hidden_states[i]) for i in range( self.n_layers ) ]
+
 		# output projection layer with masking
 		if self.classifier is not None:
 			x = self.classifier(x) * mask

+			if output.hidden_states:
+				for i in range( self.n_layers ):
+					hidden_states[i] = self.classifier(hidden_states[i]) * m
+
-		return Logits(x, state, aux_loss, attentions)
+		return Logits(x, state, aux_loss, attentions, hidden_states)

 	# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
 	def inputs(

@@ -1217,6 +1235,9 @@ class Base(nn.Module):

 		quant_levels: int | list[int] | Tensor | None = None,
 	):
+		loss = dict(ce = dict())
+		stats = dict(acc = dict())
+
 		device = logits[0].device
 		special_tasks = [ "len", "stt" ]
 		summed_embeddings_task = [ "stt" ]

@@ -1285,23 +1306,22 @@ class Base(nn.Module):
 		if False:
 			target = torch.cat( target_list )
 			inputs = torch.cat( logits )
-			self.loss = dict(
-				# "nll" was in the original implementation and should actually just be called something else
+			loss = dict(
 				nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
 			)
-			self.stats = self.metrics( inputs, targets, classifier_quant_levels ) if self.metrics is not None else dict(
+			stats = self.metrics( inputs, targets, classifier_quant_levels ) if self.metrics is not None else dict(
 				acc = self.accuracy_metric( inputs, target ),
 				# precision = self.precision_metric( inputs, target ),
 			)
 		else:
-			self.loss = dict(
+			loss = dict(
 				nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
 			)
-			self.stats = self.metrics( logits, target_list, classifier_quant_levels ) if self.metrics is not None else dict(
+			stats = self.metrics( logits, target_list, classifier_quant_levels ) if self.metrics is not None else dict(
 				acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
 			)

-		return
+		return LossStats(loss, stats)

 		"""
 		# considerations:

@@ -1311,9 +1331,6 @@ class Base(nn.Module):
 		# + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
 		# + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
 		"""
-		self.loss = dict()
-		self.stats = dict(acc = dict())
-
 		info = {}
 		batch_size = len( inputs )


@@ -1385,17 +1402,19 @@ class Base(nn.Module):
 		if False:
 			targets = torch.cat( batch["targets"] ).long()
 			inputs = torch.cat( batch["logits"] )
-			self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
-			self.stats["acc"][name] = self.accuracy_metric( inputs, targets )
+			loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
+			stats["acc"][name] = self.accuracy_metric( inputs, targets )
 		# probably consumes less memory due to not having to allocate memory
 		# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
 		else:
-			self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
+			loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / batch_size
 			if self.metrics is not None:
 				metrics = self.metrics( batch["logits"], batch["targets"], classifier_quant_levels )
-				self.stats["acc"][name] = metrics["acc"]
+				stats["acc"][name] = metrics["acc"]
 			else:
-				self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
+				stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / batch_size
+
+		return LossStats(loss, stats)

 	def forward(
 		self,

@@ -1404,6 +1423,7 @@ class Base(nn.Module):
 		quant_levels: int | list[int] | Tensor | None = None,
 		state: dict | list | None = None,
 		output_attentions = False,
+		output_hidden_states = False,
 	):
 		x_list = self.inputs_to_embeddings( inputs, quant_levels )
 		x, m = list_to_tensor(x_list)

@@ -1412,10 +1432,12 @@ class Base(nn.Module):
 		device = x.device
 		batch_size = len(x_list)

-
 		# pure AR
 		if quant_levels is None:
 			quant_levels = [ 0 for _ in range(batch_size) ]

+		if self.layerskip:
+			output_hidden_states = True
+
 		# pad our input and mask, but retain the original length by doing it after
 		if self.l_padding and x.shape[1] % self.l_padding != 0:

@@ -1440,9 +1462,11 @@ class Base(nn.Module):
 			state=state,
 			position_ids=position_ids,
 			output_attentions = output_attentions,
+			output_hidden_states = output_hidden_states,
 		)

 		logits = output.logits
+		hidden_states = output.hidden_states

 		# to-do: piece-wise classification, now that there's a head for text
 		# although again, one single monolithic head would be preferable instead......

@@ -1451,19 +1475,52 @@ class Base(nn.Module):
 			classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
 			logits = self.classifiers(logits, levels = classifier_quant_levels) * m

+			if hidden_states is not None:
+				for i in range( self.n_layers ):
+					hidden_states[i] = self.classifiers(hidden_states[i], levels = classifier_quant_levels) * m
+
 		# Remove padding
 		logits = [ hi[:li] for hi, li in zip(logits, map(len, x_list)) ]

 		# compute loss if the target is given
 		if training:
-			self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
+			if not self.layerskip:
+				loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
+			else:
+				self.loss = {}
+				self.stats = {}
+
+				for i in range( self.n_layers ):
+					# remove padding
+					hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
+					loss, stats = self.calc_loss( inputs=inputs, logits=hidden_states[i], quant_levels=quant_levels )
+
+					for k, v in loss.items():
+						if k not in self.loss:
+							self.loss[k] = []
+						self.loss[k].append( v )
+
+					for k, v in stats.items():
+						if k not in self.stats:
+							self.stats[k] = []
+						self.stats[k].append( v )
+
+				for k, v in self.loss.items():
+					self.loss[k] = self.model.early_exit_loss( losses=v )
+
+				for k, v in self.stats.items():
+					self.stats[k] = sum( v ) / len( v )
+

 			# include any additional losses (for example: MoE router)
 			if output.aux_loss is not None:
-				self.loss["aux_loss"] = output.aux_loss
+				loss["aux_loss"] = output.aux_loss
+
+			self.loss = loss
+			self.stats = stats

 		# rewrap, because we're modifying the logits here
-		return Logits(logits, output.state, output.aux_loss, output.attentions)
+		return Logits(logits, output.state, output.aux_loss, output.attentions, hidden_states)

 	def sample(
 		self,

@@ -60,7 +60,7 @@ class NAR(Base):

 		# is training
 		if resps_list is not None:
-			p_len_task = self.config.experimental.p_len_train if self.config is not None else 0.05
+			len_train_p = self.config.experimental.len_train_p if self.config is not None else 0.05

 			n_levels_set = {r.shape[-1] for r in resps_list}
 			n_levels = next(iter(n_levels_set))

@@ -69,7 +69,7 @@ class NAR(Base):

 			# to-do: make this YAML configurable
 			def sample_task():
-				return "len" if random.random() < p_len_task else "tts"
+				return "len" if random.random() < len_train_p else "tts"

 			# generate task list to train against
 			task_list = [ sample_task() for _ in range(batch_size) ]