len prediction for new model (and remove logit normalization since it kills inferencing)
This commit is contained in:
parent
5f98543d4d
commit
5c512717a6
|
@ -360,20 +360,16 @@ As the core of VALL-E makes use of a language model, various LLM architectures c
|
|||
|
||||
The wide support for various backends exists solely while I try to figure out which is the "best" for a core foundation model.
|
||||
|
||||
### `models/arch/bitnet.py`
|
||||
|
||||
This script modifies modules of BitNet to play nicely with my existing code.
|
||||
|
||||
### `models/arch/llama.py`
|
||||
|
||||
This script modifies modules of LLaMA provided through `transformers`.
|
||||
This script contains its own copy of the LLaMA implementation provided through `transformers`, with its own modifications, independent of any upstream updates that may break it.
|
||||
|
||||
The bulk of it pertains to modifying `LlamaAttention` and detecting available attention mechanisms, allowing different ones to be used (a kernel-selection sketch follows the list):
|
||||
* `torch.nn.functional.scaled_dot_product_attention`-based attention:
|
||||
* `math`: torch's SDPA's `math` kernel
|
||||
* `mem_efficient`: torch's SDPA's memory efficient (`xformers` adjacent) kernel
|
||||
* `cudnn`: torch's SDPA's `cudnn` kernel
|
||||
* `flash`: torch's SDPA's flash attention kernel
|
||||
* `flash_(sdpa)`: torch's SDPA's flash attention kernel
|
||||
* internal implementations of external attention backends:
|
||||
* `xformers`: [facebookresearch/xformers](https://github.com/facebookresearch/xformers/)'s memory efficient attention
|
||||
* `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper)
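
For reference, the `torch.nn.functional.scaled_dot_product_attention` options above map onto PyTorch's selectable SDPA backends. A minimal sketch of pinning one kernel (assuming PyTorch ≥ 2.3 and a CUDA/ROCm device; this is not this repository's code):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# only the requested backend is tried; other choices are SDPBackend.MATH,
# SDPBackend.EFFICIENT_ATTENTION, and SDPBackend.CUDNN_ATTENTION
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
```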
|
||||
|
@ -404,6 +400,8 @@ Modifications to `LlamaModel` is also provided to implement LayerSkip-aware trai
|
|||
```
|
||||
* install with `pip install -U git+https://github.com/ROCm/flash-attention@howiejay/navi_support --no-build-isolation`
|
||||
|
||||
Later versions of PyTorch with ROCm natively support full (both inferencing and training) Flash Attention through SDPA's interface. To use it, just set the model's attention to `flash_(sdpa)`.
|
||||
|
||||
### `models/arch/mamba.py`
|
||||
|
||||
This script modifies modules of Mamba to allow it to play nicely with my existing code.
|
||||
|
|
69
docs/models_v2.md
Normal file
|
@ -0,0 +1,69 @@
|
|||
# Model V2 Notes
|
||||
|
||||
This section aims to document the `_v2` class of models.
|
||||
|
||||
Unlike the original, this implementation strives to operate on *all* codebooks at once, rather than requiring the model to operate on one codebook level at a time.
|
||||
|
||||
This model might *not* scale up well, as the `nemo-smaller-44khz-llama-8` model seems to perform at a similar quality to `nemo-larger-44khz-llama-8`.
|
||||
* However, the latter had speech emerge much quicker than the former; both seem to have a problem with working consistently across various speakers, unlike the previous series of models.
|
||||
|
||||
Documentation here might be all over the place, as it was extracted from four weeks' worth of agonizing experiments.
|
||||
|
||||
## Audio Codecs
|
||||
|
||||
*Technically* this implementation should work for *any* codec, as it seems to "work" adequately for `nvidia/audio-codec-44khz` (an FSQ codec with 86 frames per second, 8 codebooks, and 1000 codes per codebook). The previously elusive DAC (an RVQ codec with 87 frames per second, 9 codebooks, and 1024 codes per codebook) should produce decent results, as will the tried and true EnCodec codec (an RVQ codec with 75 frames per second, 8 codebooks, and 1024 codes per codebook). In theory, the RVQ codecs should work better, as FSQ codecs make confidence issues in specific codebook levels much more apparent, since the levels aren't residual.
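
As a rough sense of scale, the frame rate and codebook count determine the shape of the discrete sequence the model operates on. A small sketch using the figures quoted above (the dictionary keys are illustrative labels, not package names):

```python
# frames-per-second and codebook count determine the (frames, codebooks) shape the
# _v2 models see per utterance
codecs = {
    "nvidia/audio-codec-44khz": dict(fps=86, codebooks=8, codes=1000),  # FSQ
    "descript-audio-codec":     dict(fps=87, codebooks=9, codes=1024),  # RVQ (DAC)
    "encodec":                  dict(fps=75, codebooks=8, codes=1024),  # RVQ
}

seconds = 5.0
for name, c in codecs.items():
    frames = int(seconds * c["fps"])
    print(f"{name}: {frames} frames x {c['codebooks']} codebooks, {c['codes']} codes each")
```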
|
||||
|
||||
## `AudioEncoder` / `AudioDecoder`
|
||||
|
||||
Because this model operates on the full audio sequence at once, extra care is required to ensure the model accurately operates on it, rather than leave it to chance that the model will inherently encode/decode from its latent space. Extra care is also required for how the audio is encoded (residually, or finitely).
|
||||
|
||||
### Residual*
|
||||
|
||||
To-do: document this version.
|
||||
|
||||
Honestly, I haven't gotten to test this, as I grew impatient with using EnCodec as the audio codec, and went all in on `nvidia/audio-codec-44khz`.
|
||||
|
||||
I *feel* this might not work from the cursory experiments I did before giving up on it, but residual codecs might also work with the FSQ-targeted encoder/decoders.
|
||||
|
||||
### Finite*
|
||||
|
||||
The `AudioEncoder` embeds each codebook level (and injects level-position embedding information), stacks them, passes the stack through an MLP (a residual feed-forward network), then weighs each level through learned weights before summing it down to one sequence (a sketch follows the list below).
|
||||
* I feel most of this is kind of overkill, since I believe layer 0 could do this better, but it might also allow better tuning of the model's "encoder" with an explicit one over an inherent one.
|
||||
* Attention could also be used in place of the learned weights, as some speakers *could* prioritize different codebook levels for FSQ sequences.
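
A minimal sketch of the encoder idea described above; the module name, dimensions, and shapes are assumptions for illustration, not the actual implementation:

```python
import torch
import torch.nn as nn

class AudioEncoderSketch(nn.Module):
    """Illustrative only: embed each codebook level, add a per-level embedding,
    mix through a residual feed-forward block, then collapse the levels into one
    sequence with learned weights."""
    def __init__(self, n_levels=8, n_codes=1000, d_model=1024):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(n_codes, d_model) for _ in range(n_levels)])
        self.level_emb = nn.Parameter(torch.randn(n_levels, d_model) * 0.02)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
        self.level_weights = nn.Parameter(torch.ones(n_levels))

    def forward(self, codes):  # codes: (timesteps, n_levels) of token ids
        x = torch.stack([emb(codes[:, i]) for i, emb in enumerate(self.embs)], dim=1)
        x = x + self.level_emb                     # (timesteps, n_levels, d_model)
        x = x + self.ffn(x)                        # residual feed-forward
        w = self.level_weights.softmax(dim=0)      # learned per-level weighing
        return (x * w[None, :, None]).sum(dim=1)   # (timesteps, d_model)
```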
|
||||
|
||||
The `AudioDecoder` projects the last hidden state through another feed-forward network (non-residual, with its own pre-layer norm). The decoder can be configured to either share the head for all levels, or dedicate a head for each level.
|
||||
* I feel non-shared heads might also be overkill, but they allow the decoder to better extract the dedicated codebook level from the last hidden state (see the sketch below).
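
And a matching sketch of the decoder side, again with assumed names and shapes rather than the actual implementation:

```python
import torch
import torch.nn as nn

class AudioDecoderSketch(nn.Module):
    """Illustrative only: pre-layer norm plus a non-residual feed-forward projection,
    with either one shared classification head or one head per codebook level."""
    def __init__(self, n_levels=8, n_codes=1000, d_model=1024, shared_head=False):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
        self.heads = nn.ModuleList([nn.Linear(d_model, n_codes) for _ in range(1 if shared_head else n_levels)])

    def forward(self, hidden):  # hidden: (timesteps, d_model) last hidden state
        hidden = self.ffn(self.norm(hidden))
        logits = [head(hidden) for head in self.heads]
        return torch.stack(logits, dim=1)  # (timesteps, heads, n_codes)
```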
|
||||
|
||||
## Pure NAR
|
||||
|
||||
Like the previous implementation, this model can operate entirely non-autoregressively (and with non-causal attention) as a masked transformer. The demasking inference loop is the same, where each demasking step can mask off an entire timestep on the sum of the logit scores, or independently (where each level has its own mask).
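
For illustration, one step of the timestep-level variant might look like the following sketch (this is not the repository's sampler; the names and the keep-ratio schedule are assumptions):

```python
import torch

def demask_step(logits, masked, keep_ratio):
    """logits: (timesteps, levels, n_codes); masked: (timesteps,) bool mask.
    Keep the most confident still-masked timesteps (scored by the summed
    confidence across levels) and leave the rest masked for the next step."""
    scores = logits.softmax(dim=-1).amax(dim=-1).sum(dim=-1)  # (timesteps,)
    codes = logits.argmax(dim=-1)                             # greedy picks per level
    n_keep = int(keep_ratio * masked.sum().item())
    keep = scores.masked_fill(~masked, -torch.inf).topk(n_keep).indices
    still_masked = masked.clone()
    still_masked[keep] = False
    return codes, still_masked
```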
|
||||
|
||||
Unlike the previous implementation, duration prediction is trained in parallel with the base `tts` task, where the output feature is always at the separator after the input prompt. This moves away from the kludge of treating the duration as an extra "language" task with a vocab size of 11 that was decoded autoregressively.
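
In other words, the duration comes from the hidden state at that separator position, projected down to a scalar and scaled by the codec frame rate. A rough sketch with stand-in names (the actual head is the model's `len_decoder`):

```python
import torch
import torch.nn as nn

d_model, frames_per_second = 1024, 86
len_head = nn.Linear(d_model, 1)          # stand-in for the model's duration head

hidden = torch.randn(128, d_model)        # last hidden states for one sample
sep_index = 42                            # position of the separator after the input prompt
duration = len_head(hidden[sep_index])    # scalar prediction (treated as seconds here)
n_frames = int(duration.item() * frames_per_second)
```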
|
||||
|
||||
## Pure AR
|
||||
|
||||
Unlike the previous implementation, this model can also operate entirely autoregressively as a causal transformer, where each `forward` samples *all* codebooks at one code-frame.
|
||||
|
||||
More experimentation is needed for this modality, but seeing as the pure NAR approach works, I imagine a model can either be trained purely autoregressively, or mixed (such as the confusingly named `ar+nar-len` model).
|
||||
|
||||
However, this modality was not trained for either model, as there seems to be some weird inferencing quirk that's caught under CUDA, but not ROCm. This doesn't seem to "go away" with more training, unfortunately.
|
||||
|
||||
## Training Regimen
|
||||
|
||||
The `nemo-smaller-44khz-llama-8` model is a 512-dim, 12-layer, 8-head attention-based transformer with rotary position embedding. Training was performed on four V100s with AMP+`float16`, a batch size of 8 samples per GPU, and an AdamW optimizer with adequate parameters (`1.0e-4` learning rate, betas of `[0.8, 0.95]`, weight decay of `0.01`, linear warmup to 5K steps before holding) for 400K steps before introducing training for duration prediction in parallel. The dataloader sorts the dataset by duration, starting from 2-second and ending with 8-second utterances. For the first 70% of the epoch, when speech starts to emerge, training consists of computing the loss for one codebook level at a time (a level is randomly assigned to each sample per a "normal" distribution), with each level's loss weighed "normal"ly. The model was then trained to compute the loss in parallel (all levels have their loss computed) without weighing the loss per level. Audio quality was lacking for most speakers, as the model failed to handle all codebook levels adequately. Additional training slowly helps, but by-the-numbers metrics don't show much improvement.
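
For reference, the quoted optimizer settings in plain PyTorch (a sketch, not the repository's trainer):

```python
import torch

model = torch.nn.Linear(512, 512)  # placeholder for the actual transformer
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-4, betas=(0.8, 0.95), weight_decay=0.01)

# linear warmup to 5K steps, then hold
warmup_steps = 5_000
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / warmup_steps))
```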
|
||||
|
||||
The `nemo-larger-44khz-llama-8` model is similar to its immediate predecessor, with 1024 dims, 24 layers, and 16 heads. Training is similar, the only difference being a learning rate of `3.0e-4`. Speech emerged quicker than with its predecessor, at `?`% of the epoch, but quality remains about the same.
|
||||
|
||||
Training of both models periodically experienced degradation in quality, where the loss would rise, spike, then climb back down. It's reasonable to assume duration sorting was the cause, as the model might somehow "overfit" based on duration; the problem disappeared after re-initializing the dataloader to instead batch samples by duration and then shuffle the batches (a sketch follows). However, training throughput significantly dropped for the larger model.
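
The dataloader change amounts to bucketing by duration and shuffling the buckets rather than the samples; a sketch under assumed names:

```python
import random

def duration_bucketed_batches(durations, batch_size):
    """Group samples of similar duration into batches, then shuffle the batches
    themselves so an epoch no longer marches from shortest to longest utterances."""
    order = sorted(range(len(durations)), key=lambda i: durations[i])
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    random.shuffle(batches)
    return batches
```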
|
||||
|
||||
The differences between the two models suggest there are no outright immediate benefits from scaling up, as it "costs" more to train the larger model. Benefits may be discovered through manual evaluation, which somewhat depends on the duration predictor (which wasn't added until much later into training, out of neglect).
|
||||
|
||||
## Benefits and Caveats
|
||||
|
||||
To be evaluated.
|
||||
|
||||
At a glance, compared to the prior model setup, this implementation allows the model to better represent speech, as it's able to see the entire signal and account for it in its latent space, rather than only specific levels of it.
|
||||
|
||||
Additionally, this implementation paves the way for live decoding of the audio under the autoregressive mode (if trained for it).
|
||||
|
||||
However, I'm not sure if the additional complexity is justified.
|
|
@ -280,7 +280,6 @@ class ModelExperimentalSettings:
|
|||
len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
|
||||
|
||||
#
|
||||
logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298
|
||||
per_level_normalization: bool = True # moves the final norm out from the underlying model into the decoder
|
||||
audio_level_loss_factors: list[float] | str = "auto" # the loss factors per-level when training
|
||||
# "auto" will pick best for codec
|
||||
|
|
|
@ -523,6 +523,8 @@ class TTS():
|
|||
# add an additional X seconds
|
||||
len_list = [ int(l * duration_padding) for l in len_list ]
|
||||
|
||||
print( len_list )
|
||||
|
||||
kwargs = {}
|
||||
if prefix_context is not None:
|
||||
kwargs["prefix_context"] = prefix_context
|
||||
|
|
|
@ -410,7 +410,7 @@ class AR_NAR_V2(Base_V2):
|
|||
|
||||
return resps_list
|
||||
|
||||
def forward_ar_len(
|
||||
def forward_len(
|
||||
self,
|
||||
|
||||
task_list: list[Tensor],
|
||||
|
@ -437,83 +437,35 @@ class AR_NAR_V2(Base_V2):
|
|||
elif proms_list:
|
||||
device = proms_list[0].device
|
||||
batch_size = len(proms_list)
|
||||
elif resps_list:
|
||||
device = resps_list[0].device
|
||||
batch_size = len(resps_list)
|
||||
|
||||
if cfg.lora is not None:
|
||||
enable_lora( self, cfg.lora.active_level( 0 ) if use_lora is None else use_lora )
|
||||
|
||||
# convert AR specific args
|
||||
sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
|
||||
task_list = [ "len" for _ in range( batch_size ) ]
|
||||
quant_levels = [ 0 for _ in range( batch_size ) ]
|
||||
|
||||
temperature = sampling_kwargs.get("temperature", 1.0)
|
||||
cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
|
||||
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
|
||||
min_temperature = sampling_kwargs.get("min_temperature", -1.0)
|
||||
max_duration = sampling_kwargs.get("max_duration", 500)
|
||||
beam_width = sampling_kwargs.get("beam_width", 0)
|
||||
entropix_sampling = sampling_kwargs.get("entropix_sampling", False)
|
||||
refine_on_stop = sampling_kwargs.get("refine_on_stop", False)
|
||||
input_prompt_prefix = sampling_kwargs.get("input_prompt_prefix", False)
|
||||
layer_skip = sampling_kwargs.get("layer_skip", False)
|
||||
prefix_silence = sampling_kwargs.get("prefix_silence", 0.0)
|
||||
mirostat_tau = sampling_kwargs.get("mirostat_tau", 0.0)
|
||||
mirostat_eta = sampling_kwargs.get("mirostat_eta", 0.0)
|
||||
inputs = self.inputs(
|
||||
task_list=task_list,
|
||||
|
||||
phns_list=phns_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=None,
|
||||
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
len_list=None,
|
||||
text_list=text_list,
|
||||
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
|
||||
# inference len
|
||||
sequence_list = [ torch.tensor([0], device=device,dtype=torch.int16) for _ in range(batch_size) ]
|
||||
stopped = torch.zeros(batch_size, device=device).bool()
|
||||
|
||||
stop_token = 10
|
||||
task_list = [ "len" for _ in range(batch_size) ]
|
||||
quant_levels = [ 0 for _ in range( max( batch_size, beam_width ) ) ]
|
||||
output = super().forward(
|
||||
inputs=inputs,
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
logits = output.logits
|
||||
|
||||
iterator = trange(10, desc="AR", disable=disable_tqdm)
|
||||
for n in iterator:
|
||||
len_list = sequence_list
|
||||
|
||||
inputs = self.inputs(
|
||||
task_list=task_list,
|
||||
|
||||
phns_list=phns_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=resps_list,
|
||||
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
len_list=len_list,
|
||||
text_list=text_list,
|
||||
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
|
||||
output = super().forward(
|
||||
inputs=inputs,
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
logits = output.logits
|
||||
|
||||
r = [ logit[-1:].argmax(dim=1) for logit in logits ]
|
||||
# sanitize
|
||||
for i, token in enumerate(r):
|
||||
if token > stop_token:
|
||||
r[i][0] = stop_token
|
||||
|
||||
# append tokens
|
||||
for i, ri in enumerate(r):
|
||||
if stop_token in ri:
|
||||
stopped[i] = True
|
||||
sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
|
||||
|
||||
# stop token found
|
||||
stopped |= r == stop_token
|
||||
if stopped.all().item():
|
||||
iterator.close()
|
||||
break
|
||||
|
||||
# convert tokens into int
|
||||
return [ int("".join([ str(token.item()) for token in r if token != stop_token ])) for r in sequence_list ]
|
||||
return [ int(logit * cfg.dataset.frames_per_second) for logit in logits ]
|
||||
|
||||
def forward_ar(
|
||||
self,
|
||||
|
@ -691,7 +643,6 @@ class AR_NAR_V2(Base_V2):
|
|||
use_lora=None,
|
||||
**sampling_kwargs,
|
||||
):
|
||||
# deduce batch_size
|
||||
# deduce batch_size
|
||||
if phns_list:
|
||||
device = phns_list[0].device
|
||||
|
@ -782,7 +733,7 @@ class AR_NAR_V2(Base_V2):
|
|||
"""
|
||||
|
||||
if task_list is not None and task_list[0] == "len":
|
||||
return self.forward_ar_len(
|
||||
return self.forward_len(
|
||||
task_list=task_list,
|
||||
|
||||
phns_list=phns_list,
|
||||
|
|
|
@ -628,9 +628,12 @@ class Model(LlamaPreTrainedModel):
|
|||
prom_start, prom_end = text_end, text_end + aux_len[1]
|
||||
output_start, output_end = prom_end, prom_end + aux_len[2]
|
||||
|
||||
expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = 1.0
|
||||
expanded_mask[batch_index, 0, prom_start:prom_end, text_start:prom_end] = 1.0
|
||||
expanded_mask[batch_index, 0, output_start:output_end, text_start:output_end] = 1.0
|
||||
if aux_len[0]:
|
||||
expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = 1.0
|
||||
if aux_len[1]:
|
||||
expanded_mask[batch_index, 0, prom_start:prom_end, text_start:prom_end] = 1.0
|
||||
if aux_len[2]:
|
||||
expanded_mask[batch_index, 0, output_start:output_end, text_start:output_end] = 1.0
|
||||
|
||||
# apply the original attention mask
|
||||
expanded_mask = expanded_mask * attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
|
||||
|
|
|
@ -424,7 +424,7 @@ class Base_V2(nn.Module):
|
|||
self.ignore_inputs_for_loss = ignore_inputs_for_loss
|
||||
self.noncausal_masks = noncausal_masks
|
||||
self.audio_level_loss_factors = audio_level_loss_factors
|
||||
self.logit_normalization = logit_normalization
|
||||
self.logit_normalization = False # this actually kills the model's demasking capabilities
|
||||
self.use_segmented_attention_mask = use_segmented_attention_mask
|
||||
|
||||
self.sep = nn.Parameter(torch.randn(d_model))
|
||||
|
@ -685,14 +685,6 @@ class Base_V2(nn.Module):
|
|||
# insert tone token if we're trained for it
|
||||
if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
|
||||
inputs[i].append( ( "tone", tone_list[i] ) )
|
||||
|
||||
# insert output length tokens (if it exists)
|
||||
if len_list is not None and len_list[i] is not None:
|
||||
inputs[i].append( ( "len", len_list[i] ) )
|
||||
# "encode" length to tokens for 0-9 + stop
|
||||
elif resps_list is not None and resps_list[i] is not None:
|
||||
# yes this could be encoded better
|
||||
inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) )
|
||||
|
||||
inputs[i].append( ("classifier_level", "len") )
|
||||
# Speech-to-Text prediction task
|
||||
|
@ -818,6 +810,10 @@ class Base_V2(nn.Module):
|
|||
|
||||
batch.append(embedding)
|
||||
|
||||
# needed, cringe
|
||||
if task_type == "len":
|
||||
batch[-1] = torch.cat( [ batch[-1], self.sep[None] ] )
|
||||
|
||||
x_list.append( _join( batch, self.sep ) )
|
||||
|
||||
return x_list
|
||||
|
@ -1006,8 +1002,10 @@ class Base_V2(nn.Module):
|
|||
logit = logit[..., :-l, :] # shift the target so that token n...
|
||||
token = sequence[..., l:] # ...predicts token n + 1
|
||||
|
||||
"""
|
||||
if self.logit_normalization:
|
||||
logit = logit_normalization( logit, self.logit_normalization )
|
||||
"""
|
||||
|
||||
loss_targets.append( token.long() )
|
||||
loss_logits.append( logit )
|
||||
|
@ -1026,8 +1024,10 @@ class Base_V2(nn.Module):
|
|||
logit = logit[..., :-l, :] # shift the target so that token n...
|
||||
token = sequence[..., l:] # ...predicts token n + 1
|
||||
|
||||
"""
|
||||
if self.logit_normalization:
|
||||
logit = logit_normalization( logit, self.logit_normalization )
|
||||
"""
|
||||
|
||||
loss_targets.append( token[:, level].long() )
|
||||
loss_logits.append( logit )
|
||||
|
@ -1108,7 +1108,7 @@ class Base_V2(nn.Module):
|
|||
|
||||
# check if len logits are provided
|
||||
if logits_aux is not None:
|
||||
len_factor = 0.01
|
||||
len_factor = 0.001 # to-do: user adjustable (it's really small because mse_loss causes wildly bigly losses)
|
||||
aux_loss_logit = torch.cat( logits_aux )
|
||||
#aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 )
|
||||
#loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
|
||||
|
@ -1164,7 +1164,7 @@ class Base_V2(nn.Module):
|
|||
# needs to be done here as we still have our raw inputs
|
||||
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
|
||||
classifier_levels = self.get_input( inputs, name="classifier_level" )
|
||||
causal_levels = [ "len", "phn", "text" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ]
|
||||
causal_levels = [ "phn", "text" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ]
|
||||
|
||||
# right now limit to new versions because I need to retrain the model for noncausal masks...
|
||||
is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
|
||||
|
@ -1246,11 +1246,14 @@ class Base_V2(nn.Module):
|
|||
self.loss = None
|
||||
self.stats = None
|
||||
|
||||
# grab duration if no resp is provided
|
||||
if aux_lens[0][2] == 0:
|
||||
# this can all technically be grabbed outside of this forward and manually invoke len_decoder on the last hidden states
|
||||
tasks = self.get_input( inputs, name="task" )
|
||||
|
||||
# grab duration if no resp is provided or len task is requested
|
||||
if tasks[0] == "len" or aux_lens[0][2] == 0:
|
||||
# do duration prediction
|
||||
logits_aux = self.len_decoder( output.logits )
|
||||
# only keep the input
|
||||
# only keep the designated token (although this should technically be logit[-1, :1])
|
||||
logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
|
||||
|
||||
logits = logits_aux
|
||||
|
|