added experimental training setting to perform token dropout to MAYBE compensate for errors from the preceding RVQ level (two types: token error offset, token dropout embedding replace)

This commit is contained in:
mrq 2024-07-24 19:35:17 -05:00
parent 611a1c4bdc
commit 1acb0e9c84
3 changed files with 67 additions and 23 deletions

View File

@ -214,6 +214,11 @@ class ModelExperimentalSettings:
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range for LoRAs, isn't necesary
unified_position_ids: bool = True # False will generate position IDs partitioned for each section
tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings, this does not seem to do anything good in testing
# performs token dropout to compensate for errors
token_dropout_error: float = 0.0 # probability to nudge a token by ±1
token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout, by default do not do dropout on RVQ level 0
# I really need to clean this up
@dataclass()

View File

@ -22,6 +22,9 @@ from ..emb.qnt import trim, encode_as_embedding
from .lora import enable_lora
def clamp(n, lo, hi):
return max(lo, min(n, hi))
class AR_NAR(Base):
@property
def capabilities(self) -> list[str]:
@ -139,6 +142,11 @@ class AR_NAR(Base):
# determines which RVQ level to target per batch
quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
token_dropout_error = self.config.experimental.token_dropout_error
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [0, self.resp_levels]
if p_rvq_levels == "equal":
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]
@ -165,39 +173,49 @@ class AR_NAR(Base):
quant_levels = [ random.choice( pool ) for i in range(batch_size) ]
# these two are techinically equivalent if the audio embeddings handle things properly
"""
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
stop_sequence = torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16)
"""
resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
stop_sequence = torch.Tensor([[self.stop_token] * 1]).to(device=device, dtype=torch.int16)
"""
for i in range(batch_size):
for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
# cap quant_level if it exceeds its corresponding resp/prom
if quant_levels[i] >= resps_list[i].shape[-1]:
quant_levels[i] = resps_list[i].shape[-1] - 1
if quant_level >= resps.shape[-1]:
quant_levels[i] = resps.shape[-1] - 1
# proms_list[i] could be a Tensor, list[Tensor], or None
if isinstance( proms_list[i], torch.Tensor ):
if quant_levels[i] >= proms_list[i].shape[-1]:
quant_levels[i] = proms_list[i].shape[-1] - 1
# proms could be a Tensor, list[Tensor], or None
if isinstance( proms, torch.Tensor ):
if quant_level >= proms.shape[-1]:
quant_levels[i] = proms.shape[-1] - 1
elif isinstance( proms_list[i], list ):
for j, prom in enumerate( proms_list[i] ):
elif isinstance( proms, list ):
for j, prom in enumerate( proms ):
if not isinstance( prom, torch.Tensor ):
continue
if quant_levels[i] >= prom.shape[-1]:
if quant_level >= prom.shape[-1]:
quant_levels[i] = prom.shape[-1] - 1
# only apply stop token for RVQ level 0
if quant_levels[i] > 0:
continue
# apply token dropout error compensation
if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
steps = resps.shape[0]
for l in range( quant_level ):
for t in range( steps ):
token = resps[t, l].item()
# append stop tokens for AR
# could technically do it in the .inputs call
resps_list[i] = torch.cat([ resps_list[i], stop_sequence ])
if random.random() < token_dropout_error:
offset = 1 * ( 1 if random.random() < 0.5 else -1 )
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
# only apply stop token for RVQ level 0
if quant_level <= 0:
# append stop tokens for AR
# could technically do it in the .inputs call
resps_list[i] = torch.cat([ resps, stop_sequence ])
inputs = self.inputs(
text_list=text_list,

View File

@ -12,7 +12,7 @@ Additional functionality (preparing inputs, generating full audio) should be del
import math
import torch
import torch.nn.functional as F
import traceback
import random
import numpy as np
import re
@ -439,6 +439,10 @@ class Base(nn.Module):
self.tasks_emb = None
self.rvq_l_emb = None
self.len_emb = None
# it would be nicer for these to be a token or live inside an embedding
self.sep = nn.Parameter(torch.randn(d_model))
self.dropout_token = nn.Parameter(torch.zeros(d_model)) # zeros sounds nicer than randn for a special value
if self.version == 1: # legacy
n_audio_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
@ -484,9 +488,6 @@ class Base(nn.Module):
# experimental NAR-only mode
self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
# this would be nicer to be a stop token or live inside an embedding
self.sep = nn.Parameter(torch.randn(d_model))
# ick, there has to be a better way
hf_attention = self.config.attention if self.config is not None else None
@ -970,6 +971,16 @@ class Base(nn.Module):
return self.proms_emb( input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level], offset = 0 )
# yuck
token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0
token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels if self.config else None
if self.dropout_token is None or not self.training:
token_dropout_rate = 0.0
if not token_dropout_rvq_levels:
token_dropout_rvq_levels = [1, self.resp_levels]
x_list = []
for batch_index, batch_input in enumerate(inputs):
batch = []
@ -1018,6 +1029,16 @@ class Base(nn.Module):
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
offset = 0 if quant_level == 0 or "len" in self.capabilities else 1
)
# apply token dropout
if token_dropout_rate > 0.0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
steps = embedding.shape[0] - (1 if quant_level == 0 else 0) # do not mess with stop token
for i in range( steps ):
if random.random() > token_dropout_rate:
continue
embedding[i] = self.dropout_token
elif name == "len" and self.len_emb is not None:
embedding = self.len_emb( input )
else: