added experimental training setting to perform token dropout to MAYBE compensate for errors from the preceding RVQ level (two types: token error offset, token dropout embedding replace)
This commit is contained in:
parent
611a1c4bdc
commit
1acb0e9c84
|
@ -215,6 +215,11 @@ class ModelExperimentalSettings:
|
|||
unified_position_ids: bool = True # False will generate position IDs partitioned for each section
|
||||
tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing
|
||||
|
||||
# performs token dropout to compensate for errors
|
||||
token_dropout_error: float = 0.0 # probability to nudge a token by ±1
|
||||
token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
|
||||
token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
|
||||
|
||||
# I really need to clean this up
|
||||
@dataclass()
|
||||
class Model:
|
||||
|
|
|
@ -22,6 +22,9 @@ from ..emb.qnt import trim, encode_as_embedding
|
|||
|
||||
from .lora import enable_lora
|
||||
|
||||
def clamp(n, lo, hi):
|
||||
return max(lo, min(n, hi))
|
||||
|
||||
class AR_NAR(Base):
|
||||
@property
|
||||
def capabilities(self) -> list[str]:
|
||||
|
@ -139,6 +142,11 @@ class AR_NAR(Base):
|
|||
# determines which RVQ level to target per batch
|
||||
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
|
||||
|
||||
token_dropout_error = self.config.experimental.token_dropout_error
|
||||
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
|
||||
if not token_dropout_rvq_levels:
|
||||
token_dropout_rvq_levels = [0, self.resp_levels]
|
||||
|
||||
if p_rvq_levels == "equal":
|
||||
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
|
||||
|
@ -165,39 +173,49 @@ class AR_NAR(Base):
|
|||
quant_levels = [ random.choice( pool ) for i in range(batch_size) ]
|
||||
|
||||
# these two are techinically equivalent if the audio embeddings handle things properly
|
||||
"""
|
||||
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
||||
stop_sequence = torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16)
|
||||
|
||||
"""
|
||||
|
||||
resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
||||
stop_sequence = torch.Tensor([[self.stop_token] * 1]).to(device=device, dtype=torch.int16)
|
||||
"""
|
||||
|
||||
for i in range(batch_size):
|
||||
for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
|
||||
# cap quant_level if it exceeds its corresponding resp/prom
|
||||
if quant_levels[i] >= resps_list[i].shape[-1]:
|
||||
quant_levels[i] = resps_list[i].shape[-1] - 1
|
||||
if quant_level >= resps.shape[-1]:
|
||||
quant_levels[i] = resps.shape[-1] - 1
|
||||
|
||||
# proms_list[i] could be a Tensor, list[Tensor], or None
|
||||
if isinstance( proms_list[i], torch.Tensor ):
|
||||
if quant_levels[i] >= proms_list[i].shape[-1]:
|
||||
quant_levels[i] = proms_list[i].shape[-1] - 1
|
||||
# proms could be a Tensor, list[Tensor], or None
|
||||
if isinstance( proms, torch.Tensor ):
|
||||
if quant_level >= proms.shape[-1]:
|
||||
quant_levels[i] = proms.shape[-1] - 1
|
||||
|
||||
elif isinstance( proms_list[i], list ):
|
||||
for j, prom in enumerate( proms_list[i] ):
|
||||
elif isinstance( proms, list ):
|
||||
for j, prom in enumerate( proms ):
|
||||
if not isinstance( prom, torch.Tensor ):
|
||||
continue
|
||||
|
||||
if quant_levels[i] >= prom.shape[-1]:
|
||||
if quant_level >= prom.shape[-1]:
|
||||
quant_levels[i] = prom.shape[-1] - 1
|
||||
|
||||
# only apply stop token for RVQ level 0
|
||||
if quant_levels[i] > 0:
|
||||
continue
|
||||
# apply token dropout error compensation
|
||||
if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
|
||||
steps = resps.shape[0]
|
||||
for l in range( quant_level ):
|
||||
for t in range( steps ):
|
||||
token = resps[t, l].item()
|
||||
|
||||
if random.random() < token_dropout_error:
|
||||
offset = 1 * ( 1 if random.random() < 0.5 else -1 )
|
||||
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
|
||||
|
||||
# only apply stop token for RVQ level 0
|
||||
if quant_level <= 0:
|
||||
# append stop tokens for AR
|
||||
# could technically do it in the .inputs call
|
||||
resps_list[i] = torch.cat([ resps_list[i], stop_sequence ])
|
||||
resps_list[i] = torch.cat([ resps, stop_sequence ])
|
||||
|
||||
|
||||
inputs = self.inputs(
|
||||
text_list=text_list,
|
||||
|
|
|
@ -12,7 +12,7 @@ Additional functionality (preparing inputs, generating full audio) should be del
|
|||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import traceback
|
||||
import random
|
||||
import numpy as np
|
||||
import re
|
||||
|
||||
|
@ -440,6 +440,10 @@ class Base(nn.Module):
|
|||
self.rvq_l_emb = None
|
||||
self.len_emb = None
|
||||
|
||||
# it would be nicer for these to be a token or live inside an embedding
|
||||
self.sep = nn.Parameter(torch.randn(d_model))
|
||||
self.dropout_token = nn.Parameter(torch.zeros(d_model)) # zeros sounds nicer than randn for a special value
|
||||
|
||||
if self.version == 1: # legacy
|
||||
n_audio_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
|
||||
self.proms_emb = MultiEmbedding(self.n_resp_levels, n_audio_tokens, d_model)
|
||||
|
@ -484,9 +488,6 @@ class Base(nn.Module):
|
|||
# experimental NAR-only mode
|
||||
self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
|
||||
|
||||
# this would be nicer to be a stop token or live inside an embedding
|
||||
self.sep = nn.Parameter(torch.randn(d_model))
|
||||
|
||||
# ick, there has to be a better way
|
||||
hf_attention = self.config.attention if self.config is not None else None
|
||||
|
||||
|
@ -970,6 +971,16 @@ class Base(nn.Module):
|
|||
|
||||
return self.proms_emb( input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level], offset = 0 )
|
||||
|
||||
# yuck
|
||||
token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0
|
||||
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels if self.config else None
|
||||
|
||||
if self.dropout_token is None or not self.training:
|
||||
token_dropout_rate = 0.0
|
||||
|
||||
if not token_dropout_rvq_levels:
|
||||
token_dropout_rvq_levels = [1, self.resp_levels]
|
||||
|
||||
x_list = []
|
||||
for batch_index, batch_input in enumerate(inputs):
|
||||
batch = []
|
||||
|
@ -1018,6 +1029,16 @@ class Base(nn.Module):
|
|||
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
|
||||
offset = 0 if quant_level == 0 or "len" in self.capabilities else 1
|
||||
)
|
||||
|
||||
# apply token dropout
|
||||
if token_dropout_rate > 0.0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
|
||||
steps = embedding.shape[0] - (1 if quant_level == 0 else 0) # do not mess with stop token
|
||||
for i in range( steps ):
|
||||
if random.random() > token_dropout_rate:
|
||||
continue
|
||||
|
||||
embedding[i] = self.dropout_token
|
||||
|
||||
elif name == "len" and self.len_emb is not None:
|
||||
embedding = self.len_emb( input )
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue
Block a user