a gorillionth time's the charm (aka: the encoder/decoder pill is a tough pill to swallow)

mrq 2025-02-28 17:56:50 -06:00
parent 09d82a26fe
commit a174c33db6
8 changed files with 224 additions and 122 deletions

View File

@ -1,5 +1,8 @@
# Model Notes
> [!NOTE]
> Most of this information is outdated due to slightly wrong assumptions
The underlying model is a robust transformer, where:
* inputs are passed through an embedding
* the embedded inputs are then passed through each layer of the transformer (or other model type)
@ -12,9 +15,6 @@ The inputs are automatically sequenced in a way that a given task requires, and
While the original paper called for a separate AR model and a NAR model, by treating the AR and the NAR as unique tasks, you can actually train a unified model (`AR+NAR`) effectively for free, as the internal states of the two should overlap quite a lot.
* Additionally, you can even train a `NAR-len` model on top of an existing model.
Later papers for discrete TTS solutions work around the multiple-codebook problem by introducing exotic interleaving patterns. For all intents and purposes, these aren't necessary, as the current sequencing of prioritizing the first codebook (RVQ level 0) is sufficient. The remaining RVQ levels can easily be deduced from the prior level in parallel.
* Exotic solutions aren't necessary at all, as the summed embeddings can be good enough to represent the original waveform. Output codes can be inferenced in parallel with a wider head, removing the need to train separate levels.
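As a rough illustration of the above (a minimal sketch, not the repo's actual classes; the names and shapes here are made up): each RVQ level gets its own embedding table, the per-level embeddings are summed into a single token embedding, and a widened head projects the last hidden state into one logit distribution per level so all codebooks can be sampled in parallel.

```python
import torch
import torch.nn as nn

class SummedCodebookEmbedding(nn.Module):
    # one embedding table per RVQ level, summed into a single d_model vector
    def __init__(self, n_tokens: int, n_levels: int, d_model: int):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(n_tokens, d_model) for _ in range(n_levels)])

    def forward(self, codes: torch.Tensor) -> torch.Tensor:
        # codes: (seq_len, n_levels) -> (seq_len, d_model)
        return sum(emb(codes[:, level]) for level, emb in enumerate(self.embs))

class ParallelCodebookHead(nn.Module):
    # one widened projection that emits logits for every level in a single pass
    def __init__(self, d_model: int, n_tokens: int, n_levels: int):
        super().__init__()
        self.n_levels, self.n_tokens = n_levels, n_tokens
        self.proj = nn.Linear(d_model, n_tokens * n_levels)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        # hidden: (seq_len, d_model) -> (n_levels, seq_len, n_tokens)
        seq_len = hidden.shape[0]
        return self.proj(hidden).view(seq_len, self.n_levels, self.n_tokens).transpose(0, 1)
```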
## The AR (Autoregressive) Model
The AR is responsible for generating the first RVQ level of the audio codes for a given output. References to "outputs from the AR" refer to this level, as it contributes to the final waveform the most.
@ -28,7 +28,7 @@ One way to work around the time cost is to instead decode more than one token at
* In theory, for a unified AR+NAR model, this *should* be an easy task, as the model can already decode tokens in parallel.
* In reality, this isn't the case. Specifying `cfg.model.experimental.causal_size > 1` with adequate training will have the output sound *fine* on every Nth timestep but not so fine on the timesteps in between, as the following tokens aren't predictable enough (see the sketch after this list).
+ *However*, this may simply be a sampling problem, as this experiment was done with outdated ideas on how to sample the AR, and should be worth revisiting.
* VALL-E 2's paper proposes merging code sequences together into one embedded token for a speedup, but their solution seems rather complex to warrant a fundamental retrain.
* VALL-E 2's paper proposes merging code sequences together into one embedded token for a speedup.
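A minimal sketch of what decoding more than one token per step implies for training (hedged; only `causal_size` mirrors the actual config knob, the rest is illustrative): the causal shift applied during loss computation becomes `causal_size` instead of 1, so each logit is trained against the token that many positions ahead, and at inference the last `causal_size` logits are sampled in one forward pass.

```python
import torch
import torch.nn.functional as F

def causal_nll(logits: torch.Tensor, targets: torch.Tensor, causal_size: int = 1) -> torch.Tensor:
    # logits: (seq_len, vocab), targets: (seq_len,) of token ids
    # shift by `causal_size` so that position n is trained to predict token n + causal_size
    logits = logits[:-causal_size, :]
    targets = targets[causal_size:]
    return F.cross_entropy(logits, targets)
```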
Sampling the AR does not necessarily require a specific sampling temperature, as:
* lower temperatures follow the prompt better, at the cost of variety in the outputs, and the need to either use classifier-free guidance or repetition penalty to wrangle the output.
@ -38,10 +38,6 @@ Traditional samplers for text-gen models can apply to the AR (especially rep/len
Compared to non-autoregressive decoding, I personally feel that autoregressive decoding offers a specific-yet-hard-to-quantify expressive quality that the NAR (and pure NAR solutions) does not offer.
### Pure AR
Technically, with `cfg.model.version >= 7`, a model can be purely AR, as that version of the model encodes and decodes all codebooks of audio in a single pass.
## The NAR (Non-autoregressive) Model
The NAR is responsible for generating the remaining RVQ levels of the audio codes for a given output. References to the "outputs from the NAR" refer to the underlying "levels" for a given waveform, as each further level contributes to the final waveform less significantly than the previous.
@ -119,16 +115,10 @@ It is not required to train a model from scratch to use this modality, as traini
The "magic" of subjugating a transformer for audio use lies within the ensemble of the embeddings. This is necessary as each piece of a sequence is fundamentally different, but a HF-compatible model can get away with treating each sequence as separate ranges within a total token sequence.
While embeddings *can* be tied to the output head, testing showed that the model ***really*** does not like to do this, although my implementation could very well be flawed.
With attention-based transformers, most embeddings can serve as tokens themselves and have the attention mechanism attend to them. Theoretically, there should be little to no functional difference between "tokenizing" an embedding and summing a modifying embedding, but experimentation is needed to verify this assertion.
* EnCodec seems to function perfectly fine with summing and without, but other codecs such as Descript-Audio-Codec might absolutely require summing.
With attention-based transformers, most embeddings can serve as tokens themselves and have the attention mechanism attend to them (a sketch of this versus summing follows below).
Other solutions such as TorToiSe make use of additional embeddings/classifiers for each portion of the sequence as well.
Other solutions will rely on conditioning latents or extracted features as the input. This *technically* isn't necessary since portions of the model seem to be allocated as an encoder anyways from the embeddings to some arbitrary depth, and as a decoder from some arbitrary depth to the output heads.
* This might also mean it makes more sense to increase the model's size post-hoc by injecting new layers in the middle, outside these pseudo-encoder/decoder layers, where doing so won't make any difference.
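For illustration, a hedged sketch of the distinction above between summing a modifying embedding onto the sequence and letting it stand as its own token for attention to attend to (both functions are hypothetical, not from the codebase):

```python
import torch

def sum_modifier(x: torch.Tensor, modifier: torch.Tensor) -> torch.Tensor:
    # x: (seq_len, d_model), modifier: (d_model,)
    # the modifying embedding is summed onto every position
    return x + modifier

def tokenize_modifier(x: torch.Tensor, modifier: torch.Tensor) -> torch.Tensor:
    # the modifying embedding becomes its own "token" in the sequence,
    # and the attention mechanism is left to attend to it
    return torch.cat([modifier.unsqueeze(0), x], dim=0)
```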
### Classifiers
Classifiers are the final output head / projection layer that processes the last hidden states of a model into a probability distribution for each token.
@ -333,6 +323,12 @@ This script aims to implement everything as required per VALL-E agnostically, to
A very naive implementation of using the model can be found under the `__main__` invocation.
### `models/base_v2.py`
This script implements a newer model aimed at sampling *all* codebooks in a given step.
Due to major enough differences, this code is segregated from the original `models/base.py` to not break things further.
## `models/ar_nar.py`
This script implements VALL-E as a unified autoregressive and non-autoregressive model, where RVQ level 0 is inferenced autoregressively and the remaining levels are inferenced non-autoregressively, if requested.
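Roughly, the unified flow looks like the sketch below (hedged; `model`, its keyword arguments, and `stop_token` are hypothetical stand-ins rather than this file's actual API): an autoregressive loop produces RVQ level 0 token by token, then each remaining level is filled in with a single non-autoregressive pass conditioned on the levels beneath it.

```python
import torch

@torch.no_grad()
def unified_inference(model, prompt: torch.Tensor, n_levels: int = 8, max_steps: int = 500) -> torch.Tensor:
    # --- AR phase: RVQ level 0, one token per forward pass ---
    level0 = []
    for _ in range(max_steps):
        logits = model(prompt=prompt, resps=torch.tensor(level0), quant_level=0)
        token = logits[-1].argmax(dim=-1).item()
        if token == model.stop_token:
            break
        level0.append(token)

    # --- NAR phase: each remaining level in a single parallel pass ---
    resps = torch.tensor(level0).unsqueeze(-1)            # (seq_len, 1)
    for level in range(1, n_levels):
        logits = model(prompt=prompt, resps=resps, quant_level=level)
        next_level = logits.argmax(dim=-1).unsqueeze(-1)  # (seq_len, 1)
        resps = torch.cat([resps, next_level], dim=-1)    # (seq_len, level + 1)
    return resps
```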

View File

@ -102,6 +102,12 @@ setup(
# other audio backend that doesn't prove fruitful
"descript-audio-codec",
# nemo (to-do: cut this down)
"nemo-toolkit",
"hydra-core",
"lightning",
"sentencepiece"
]
},
url="https://git.ecker.tech/mrq/vall-e",

View File

@ -277,7 +277,10 @@ class ModelExperimentalSettings:
predict_causally: bool = False # predicts the next token even for the non-causal/NAR tasks, in theory this should also bolster the model, as
# * NAR-demask would semi-doubly train for AR
# * the model wouldn't also need to learn when to predict the token in place
audio_encoder_mode: str = "sum" # audio encoder mode for version >= 7, because I cannot make up my damn mind
#
logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298
per_level_normalization: bool = True # moves the final norm out from the underlying model into the decoder
# these technically should be as hyperparameters
# performs token dropout to compensate for errors

View File

@ -76,9 +76,6 @@ class AR_NAR(Base):
# RVQ levels to apply masking training on
masking_train_rvq_levels = self.config.experimental.masking_train_rvq_levels
if cfg.audio_backend == "nemo":
rvq_levels_p = [ i for i in range( quant_level_range[0], quant_level_range[1] + 1 ) ]
# CFG
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
@ -105,8 +102,19 @@ class AR_NAR(Base):
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
if rvq_levels_p == "equal":
rvq_levels_p = [ i for i in range( lo, hi ) ]
else:
elif rvq_levels_p == "normal":
# yuck
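# hand-rolled pyramid over levels 0..7: the middle levels (3/4) are sampled most often, tapering off toward 0 and 7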
rvq_levels_p = [
0,
1, 1,
2, 2, 2, 2,
3, 3, 3, 3, 3, 3, 3, 3,
4, 4, 4, 4, 4, 4, 4, 4,
5, 5, 5, 5,
6, 6,
7,
]
else:
rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
# input RVQ levels

View File

@ -105,8 +105,19 @@ class AR_NAR_V2(Base_V2):
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
if rvq_levels_p == "equal":
rvq_levels_p = [ i for i in range( lo, hi ) ]
else:
elif rvq_levels_p == "normal":
# yuck
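# hand-rolled pyramid over levels 0..7: the middle levels (3/4) are sampled most often, tapering off toward 0 and 7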
rvq_levels_p = [
0,
1, 1,
2, 2, 2, 2,
3, 3, 3, 3, 3, 3, 3, 3,
4, 4, 4, 4, 4, 4, 4, 4,
5, 5, 5, 5,
6, 6,
7,
]
else:
rvq_levels_p = sum([[i for _ in range(hi - i)] for i in range( lo, hi ) ], [])
# input RVQ levels

View File

@ -510,7 +510,7 @@ class Model(LlamaPreTrainedModel):
self.layers = nn.ModuleList(
[DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) if config.output_norm else nn.Identity()
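# (with per_level_normalization, output_norm is False and the final norm instead lives in the audio decoder)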
self.rotary_emb = RotaryEmbedding(config=config)
self.gradient_checkpointing = False
@ -729,7 +729,6 @@ class Model(LlamaPreTrainedModel):
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer

View File

@ -81,49 +81,20 @@ def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
x[..., level] = torch.where( dropout_mask, lhs, rhs )
return x
# naively embeds each level of a codebook, then merges the embeddings with a Linear
class AudioEncoder(nn.Module):
# aims to properly encode RVQ-encoded token sequence into an embedding
class ResidualAudioEncoder(nn.Module):
def __init__(
self,
n_tokens: int,
n_levels: int,
token_dim: int,
enc_mode: str = "sum",
l_weights: list[float] | None = None,
training: bool = True,
):
super().__init__()
self.enc_mode = enc_mode
d_ffn = 4
if not l_weights:
l_weights = [1 for _ in range(n_levels)]
if enc_mode == "sum":
self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
self.proj = None
self.weights = nn.Parameter(torch.tensor(l_weights))
elif enc_mode == "sub_interleave":
self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim // n_levels) for l in range(n_levels)])
self.proj = None
elif enc_mode == "interleave":
self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
#self.proj = nn.Linear(n_levels * token_dim, token_dim)
self.proj = nn.Sequential(
nn.Linear(n_levels * token_dim, d_ffn * token_dim),
nn.GELU(),
nn.Linear(d_ffn * token_dim, token_dim)
)
elif enc_mode == "attn":
self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
self.cross_attn = nn.MultiheadAttention(embed_dim=token_dim,num_heads=n_levels,dropout=0.1)
self.proj = nn.Sequential(
nn.Linear(n_levels * token_dim, d_ffn * token_dim),
nn.GELU(),
nn.Linear(d_ffn * token_dim, token_dim)
)
for emb in self.embs:
nn.init.normal_(emb.weight, mean=0.0, std=0.02)
self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim)) # i still don't understand why this needs to be explicitly added instead of it being deduced in the embedding itself
self.cross_attn = nn.MultiheadAttention( embed_dim=token_dim, num_heads=8, dropout=0.1 if training else 0.0, batch_first=True )
self.proj = nn.Linear(token_dim, token_dim) # i don't understand why this is necessary
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
# empty
@ -133,58 +104,130 @@ class AudioEncoder(nn.Module):
if dropout_mask is not None:
xi = _dropout_codes( xi, dropout_mask, dropout_token )
# old way
# in theory RVQ-based codecs should prefer this, but this doesn't yield good results
if self.enc_mode == "sum":
weights = F.softmax( self.weights, dim=0 )
x = sum([ weights[l] * emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
# attention-based crunge
elif self.enc_mode == "attn":
x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
attn, _ = self.cross_attn(
x.permute(1, 0, 2),
x.permute(1, 0, 2),
x.permute(1, 0, 2),
)
attn = attn.permute(1, 0, 2)
x = x + attn
x = x.view(x.shape[0], -1)
# x = attn.reshape(x.shape[0], -1)
# encode by interleaving embeddings into one "token"
# this "works" but I imagine it being excessive and doesn't seem to help the model all that much
else:
x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
x = x.view(x.shape[0], -1)
if self.proj is not None:
x = self.proj(x)
# ( seq_len, dim ) => ( seq_len, levels, dim )
x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
x = x + self.pos_embedding
x, _ = self.cross_attn( x, x, x )
x = self.proj( x.mean(dim=1) )
return x
class AudioDecoder(nn.Module):
# aims to properly decode the last hidden states from a model into logits for an RVQ-encoded token sequence
class ResidualAudioDecoder(nn.Module):
def __init__(
self,
d_model,
vocab_size,
resp_levels,
training: bool = True,
use_ln: bool = False,
):
super().__init__()
self.resp_levels = resp_levels
self.head = nn.Linear( d_model, vocab_size * resp_levels )
self.projs = nn.ModuleList([nn.Sequential(
(nn.LayerNorm(d_model) if use_ln else nn.Identity()),
nn.Linear(d_model, d_model),
) for _ in range(resp_levels)]) # per-level projs
def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
# prior way up-projected then down-projected, but that's silly
self.cross_attn = nn.MultiheadAttention( embed_dim=d_model, num_heads=8, dropout=0.1 if training else 0.0, batch_first=True ) # xattn so each level can attend to others per residual-ness
self.head = nn.Linear(d_model, vocab_size) # final output head, i feel it would be better to have it per-level but i assume the proj handles it
# forward for one sequence
def _forward( self, x: Tensor ) -> Tensor:
seq_len, resp_levels = x.shape[0], len(self.projs)
x = torch.stack([proj(x) for proj in self.projs], dim=1)
x, _ = self.cross_attn( x, x, x )
x = self.head( x )
x = x.view( resp_levels, seq_len, -1 )
return x
# interleave by reshaping / permuting
# at least I hope this does it properly, it checks out against my OCR classifier
batch_size, seq_len, dim = x.shape
x = x.view( batch_size, seq_len, self.resp_levels, -1 )
x = x.permute( 0, 2, 1, 3 )
# required to act on per sequence and not a batch due to headed-ness
def forward( self, x_i: Tensor ) -> Tensor:
return torch.stack([ self._forward(x) for x in x_i ], dim=0)
# the above, but for FSQ codecs, as each level is independent from one another
class FiniteAudioEncoder(nn.Module):
def __init__(
self,
n_tokens: int,
n_levels: int,
token_dim: int,
training: bool = True,
):
super().__init__()
self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
self.pos_embedding = nn.Parameter(torch.randn(1, n_levels, token_dim))
self.proj = nn.Linear(token_dim, token_dim)
self.level_weights = nn.Parameter(torch.ones(n_levels))
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
# empty
if xi.shape[0] == 0:
dim = self.embs[0].weight.shape[-1] # self.proj.weight.shape[0]
return torch.zeros((0, dim), device=xi.device, dtype=xi.dtype)
if dropout_mask is not None:
xi = _dropout_codes( xi, dropout_mask, dropout_token )
x = torch.stack([ emb(xi[:, i]) for i, emb in enumerate(self.embs) ], dim=1)
x = x + self.pos_embedding
x = self.proj( x )
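# collapse the level dimension with a learned softmax weighting over levels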
weights = F.softmax(self.level_weights, dim=0).view(1, -1, 1)
x = (x * weights).sum(dim=1)
return x
# aims to decode the last hidden state into independent codebooks
# uses an MLP instead of Attn since it's not residual I guess (the problem with caving to consult claude-3-5-sonnet is that it will blindly agree with you if you suggest anything)
# optional per-level LN, might be beneficial
class FiniteAudioDecoder(nn.Module):
def __init__(
self,
d_model: int,
vocab_size: int,
n_levels: int,
d_ffn: int = 4,
use_ln: bool = True,
shared_levels: bool = False,
training: bool = False,
):
super().__init__()
self.n_levels = n_levels
self.shared_levels = shared_levels
if not shared_levels:
self.head = nn.ModuleList([nn.Sequential(
# ln
(nn.LayerNorm(d_model) if use_ln else nn.Identity()),
# ffn
nn.Linear(d_model, d_ffn * d_model),
nn.GELU(),
nn.Linear(d_ffn * d_model, d_model),
# head
nn.Linear(d_model, vocab_size)
) for _ in range(n_levels)])
else:
self.head = nn.Sequential(
# ffn
nn.Linear(d_model, d_ffn * d_model),
nn.GELU(),
nn.Linear(d_ffn * d_model, d_model),
# head
nn.Linear(d_model, vocab_size * n_levels)
)
def forward(self, x: Tensor) -> Tensor:
batch_size, seq_len, _ = x.shape
if not self.shared_levels:
x = torch.stack([head(x) for head in self.head], dim=1)
else:
x = self.head(x)
x = x.view(batch_size, seq_len, self.n_levels, -1)
x = x.transpose(1, 2)
return x
# handles simple output projections into logits for other tasks
class AuxDecoder(nn.Module):
def __init__(
self,
@ -255,8 +298,9 @@ class Base_V2(nn.Module):
resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
predict_causally = config.experimental.predict_causally if config is not None else False
monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
audio_encoder_mode = config.experimental.audio_encoder_mode if config is not None else "sum"
audio_level_weights = [1.0 / (i + 1) for i in range(n_resp_levels)] # to-do: find the weights for FSQ
logit_normalization = config.experimental.logit_normalization if config is not None else 0
per_level_normalization = config.experimental.per_level_normalization if config is not None else True
n_vocab = 256
n_tasks = config.tasks if config is not None else 8
@ -274,6 +318,12 @@ class Base_V2(nn.Module):
if attention_backend not in AVAILABLE_ATTENTIONS:
raise ValueError(f"Requesting attention `{attention_backend}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")
# to-do: deduce nemo better-er
if n_audio_tokens == 1000:
# assume midrange contains important details
center = n_resp_levels // 2
audio_level_weights = [1.0 - abs(i - center) / n_resp_levels for i in range(n_resp_levels)]
self.training = training
self.teaching = False
self.config = config
@ -323,6 +373,7 @@ class Base_V2(nn.Module):
self.ignore_inputs_for_loss = ignore_inputs_for_loss
self.noncausal_masks = noncausal_masks
self.audio_level_weights = audio_level_weights
self.logit_normalization = logit_normalization
self.sep = nn.Parameter(torch.randn(d_model))
@ -337,34 +388,41 @@ class Base_V2(nn.Module):
self.proms_emb = None
self.resps_emb = None
# to-do: deduce nemo-ness better-er
if n_audio_tokens == 1000:
AudioEncoder = FiniteAudioEncoder
AudioDecoder = FiniteAudioDecoder
else:
AudioEncoder = ResidualAudioEncoder
AudioDecoder = ResidualAudioDecoder
if monolithic_audio_encoder:
self.audio_emb = AudioEncoder(
n_tokens=n_audio_tokens + 2, # stop + masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
enc_mode=audio_encoder_mode,
l_weights=audio_level_weights,
training=training,
)
else:
self.proms_emb = AudioEncoder(
n_tokens=n_audio_tokens,
n_levels=self.n_resp_levels,
token_dim=d_model,
enc_mode=audio_encoder_mode,
l_weights=audio_level_weights,
training=training,
)
self.resps_emb = AudioEncoder(
n_tokens=n_audio_tokens + 2, # stop + masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
enc_mode=audio_encoder_mode,
l_weights=audio_level_weights,
training=training,
)
self.audio_decoder = AudioDecoder(
d_model,
(n_audio_tokens + 1),
self.n_resp_levels,
training=training,
use_ln=per_level_normalization,
)
self.len_decoder = AuxDecoder( d_model, 11 )
self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
@ -391,6 +449,7 @@ class Base_V2(nn.Module):
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
output_norm=not per_level_normalization, # moves the LN out to the decoder
#gradient_checkpointing=self.gradient_checkpointing,
)
self.model_config.attn_mode = attention_backend
@ -540,7 +599,7 @@ class Base_V2(nn.Module):
p = self.masking_ratio
# store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
if self.training:
if self.training and p > 0:
dropout_mask = _dropout_mask( resps_list[i], p )
inputs[i].append( ("dropout_mask", dropout_mask ) )
# insert the current output response
@ -806,6 +865,10 @@ class Base_V2(nn.Module):
return input
def _logit_normalization( logit ):
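# scales logits to logit / (tau * ||logit||_2) per https://arxiv.org/abs/2205.09310, with tau = self.logit_normalization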
norms = torch.norm(logit, p=2, dim=-1, keepdim=True) + 1e-7
return torch.div(logit, norms) / self.logit_normalization
def _calc_loss( logit, sequence, causal = True, level = None ):
# filter tokens that exceed the vocab size
sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
@ -818,10 +881,20 @@ class Base_V2(nn.Module):
l = self.causal_size
logit = logit[..., :-l, :] # shift the target so that token n...
sequence = sequence[..., l:] # ...predicts token n + 1
batched = sequence.dim() > 1
# logit normalization
if self.logit_normalization:
# it would probably be better to unsqueeze then squeeze to avoid code duplication but who cares
if not batched:
logit = _logit_normalization( logit )
else:
for i, l in enumerate( logit ):
logit[i] = _logit_normalization( l )
# flatten batch
parallel = sequence.dim() > 1
if parallel:
if batched:
logit = logit.reshape(-1, logit.shape[-1])
sequence = sequence.reshape(-1)
@ -829,10 +902,12 @@ class Base_V2(nn.Module):
metrics = None
if compute_hard_loss:
nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index, reduction='mean' if not parallel else 'none' ) * (level_weights[level] if level is not None and not parallel else 1)
reduction = 'mean' if not batched else 'none'
weight = level_weights[level] if level is not None and not batched else 1
nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index, reduction=reduction ) * weight
# manually weigh each level
if parallel:
if batched:
nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_weights, device=device)
if compute_acc:
@ -844,6 +919,7 @@ class Base_V2(nn.Module):
ignore_index = -100
).to(logit.device)
metrics = accuracy_metric( logit, sequence )
return nll, metrics
for batch_index, batch in enumerate(inputs):
@ -949,7 +1025,7 @@ class Base_V2(nn.Module):
nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
if nll is not None:
nll = nll.sum()
nll = nll.mean()
loss_key = f'{name}.nll'
acc_key = f'{name}.acc'
@ -966,7 +1042,8 @@ class Base_V2(nn.Module):
else:
target.append( token )
# perofrm loss calculation on the entire sequence
# perform loss calculation on the entire sequence
if not self.config.loss_factors:
if logits[batch_index].dim() < 3:
sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
@ -984,18 +1061,22 @@ class Base_V2(nn.Module):
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
else:
nlls = []
metrics = []
accs = []
for level, logit in enumerate( logits[batch_index] ):
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
nll, metric = _calc_loss( logit, sequence, causal, level )
nlls.append( nll )
metrics.append( metric )
nll = sum(nlls)
metrics = mean(metrics)
nll, metrics = _calc_loss( logit, sequence, causal, level )
if nll:
nlls.append( nll )
if metrics:
accs.append( metrics )
if nlls:
nll = sum(nlls) / len(nlls)
if accs:
metrics = sum(accs) / len(accs)
if nll is not None:
if 'nll' not in loss:
@ -1008,8 +1089,8 @@ class Base_V2(nn.Module):
stats["acc"].append( metrics )
# average
loss = { name: mean( loss[name] ) for name in loss.keys() }
stats = { name: mean( stats[name] ) for name in stats.keys() }
loss = { name: sum( loss[name] ) / len( loss[name] ) for name in loss.keys() }
stats = { name: sum( stats[name] ) / len( stats[name] ) for name in stats.keys() }
return LossStats(loss, stats)
@ -1076,7 +1157,6 @@ class Base_V2(nn.Module):
logits = [ logit for logit in output.logits ]
hidden_states = output.hidden_states
grouped_logits = {}
for batch_index in range( batch_size ):

View File

@ -35,8 +35,7 @@ T = TypeVar("T")
def mean( l ):
if not l:
return 0
_l = [ _ for _ in l if _ is not None ]
return sum(_l) / len(_l)
return sum(l) / len(l)
# removes prefix from key in a dict
# useful for mapping args like ar_temperature => temperature