added experimental training setting to perform token dropout to MAYBE compensate for errors from the preceding RVQ level (two types: token error offset, token dropout embedding replace)
parent 611a1c4bdc
commit 1acb0e9c84

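In short: during training, the ground-truth codes for the RVQ levels below a batch's target level can be corrupted, either by nudging a code by ±1 (token_dropout_error) or by replacing a timestep's embedding with a learned dropout embedding (token_dropout_rate), so the model learns to tolerate mistakes carried over from the preceding quantizer levels. A minimal standalone sketch of the two mechanisms, with illustrative names rather than the repo's API, and assuming roughly 1024-entry codebooks (to match the clamp to [1, 1022] in the diff below):

import random
import torch

def nudge_tokens(resps: torch.Tensor, quant_level: int, p: float) -> torch.Tensor:
	# "token error offset": with probability p, shift each code at the RVQ
	# levels *below* the target level by ±1, simulating a slightly-wrong
	# prediction from the preceding quantizer level; the target level's
	# labels stay clean
	resps = resps.clone()
	for l in range(quant_level):
		for t in range(resps.shape[0]):
			if random.random() < p:
				offset = 1 if random.random() < 0.5 else -1
				resps[t, l] = max(1, min(resps[t, l].item() + offset, 1022))
	return resps

def drop_embeddings(embedding: torch.Tensor, dropout_token: torch.Tensor, p: float) -> torch.Tensor:
	# "token dropout embedding replace": with probability p, swap a timestep's
	# embedding for a learned special vector (the commit's self.dropout_token)
	embedding = embedding.clone()
	for t in range(embedding.shape[0]):
		if random.random() < p:
			embedding[t] = dropout_token
	return embedding
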
@@ -215,6 +215,11 @@ class ModelExperimentalSettings:
 	unified_position_ids: bool = True # False will generate position IDs partitioned for each section
 	tie_classifier_to_embedding: bool = False # Ties the classifier output to their respective embeddings; this does not seem to do anything good in testing
 
+	# performs token dropout to compensate for errors
+	token_dropout_error: float = 0.0 # probability to nudge a token by ±1
+	token_dropout_rate: float = 0.0 # probability to randomly set a token to a special dropout value
+	token_dropout_rvq_levels: list = field(default_factory=lambda: [1,8]) # determines which levels to do dropout on; by default, do not do dropout on RVQ level 0
+
 # I really need to clean this up
 @dataclass()
 class Model:
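Assuming ModelExperimentalSettings can be constructed directly (in practice these fields are presumably populated from the training config), enabling both behaviors might look like:

settings = ModelExperimentalSettings(
	token_dropout_error=0.05,        # 5% chance to nudge each token by ±1
	token_dropout_rate=0.05,         # 5% chance to replace a token's embedding
	token_dropout_rvq_levels=[1, 8], # inclusive bounds; leaves RVQ level 0 alone
)
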
@@ -22,6 +22,9 @@ from ..emb.qnt import trim, encode_as_embedding
 
 from .lora import enable_lora
 
+def clamp(n, lo, hi):
+	return max(lo, min(n, hi))
+
 class AR_NAR(Base):
 	@property
 	def capabilities(self) -> list[str]:

@@ -139,6 +142,11 @@ class AR_NAR(Base):
 		# determines which RVQ level to target per batch
 		quant_level_range = self.config.experimental.rvq_level_range if self.config is not None and self.config.experimental.rvq_level_range else [ 0 if self.causal else 1, self.n_resp_levels ]
 
+		token_dropout_error = self.config.experimental.token_dropout_error
+		token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels
+		if not token_dropout_rvq_levels:
+			token_dropout_rvq_levels = [0, self.resp_levels]
+
 		if p_rvq_levels == "equal":
 			# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
 			quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1] - 1) for i in range(batch_size) ]

@@ -165,39 +173,49 @@ class AR_NAR(Base):
 			quant_levels = [ random.choice( pool ) for i in range(batch_size) ]
 
 		# these two are technically equivalent if the audio embeddings handle things properly
+		"""
 		resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
 		stop_sequence = torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16)
 		"""
 		resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
 		stop_sequence = torch.Tensor([[self.stop_token] * 1]).to(device=device, dtype=torch.int16)
-		"""
 
-		for i in range(batch_size):
+		for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
 			# cap quant_level if it exceeds its corresponding resp/prom
-			if quant_levels[i] >= resps_list[i].shape[-1]:
-				quant_levels[i] = resps_list[i].shape[-1] - 1
+			if quant_level >= resps.shape[-1]:
+				quant_levels[i] = resps.shape[-1] - 1
 
-			# proms_list[i] could be a Tensor, list[Tensor], or None
-			if isinstance( proms_list[i], torch.Tensor ):
-				if quant_levels[i] >= proms_list[i].shape[-1]:
-					quant_levels[i] = proms_list[i].shape[-1] - 1
+			# proms could be a Tensor, list[Tensor], or None
+			if isinstance( proms, torch.Tensor ):
+				if quant_level >= proms.shape[-1]:
+					quant_levels[i] = proms.shape[-1] - 1
 
-			elif isinstance( proms_list[i], list ):
-				for j, prom in enumerate( proms_list[i] ):
+			elif isinstance( proms, list ):
+				for j, prom in enumerate( proms ):
 					if not isinstance( prom, torch.Tensor ):
 						continue
 
-					if quant_levels[i] >= prom.shape[-1]:
+					if quant_level >= prom.shape[-1]:
 						quant_levels[i] = prom.shape[-1] - 1
 
-			# only apply stop token for RVQ level 0
-			if quant_levels[i] > 0:
-				continue
+			# apply token dropout error compensation
+			if token_dropout_error > 0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
+				steps = resps.shape[0]
+				for l in range( quant_level ):
+					for t in range( steps ):
+						token = resps[t, l].item()
+
+						if random.random() < token_dropout_error:
+							offset = 1 * ( 1 if random.random() < 0.5 else -1 )
+							resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
+
+			# only apply stop token for RVQ level 0
+			if quant_level <= 0:
 				# append stop tokens for AR
 				# could technically do it in the .inputs call
-				resps_list[i] = torch.cat([ resps_list[i], stop_sequence ])
+				resps_list[i] = torch.cat([ resps, stop_sequence ])
 
 		inputs = self.inputs(
 			text_list=text_list,
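The nested Python loops above run per batch entry, per level, per timestep; they could presumably be vectorized. A sketch of an equivalent masked formulation, under the same codebook assumptions and not the repo's actual code:

import torch

def nudge_tokens_vectorized(resps: torch.Tensor, quant_level: int, p: float) -> torch.Tensor:
	# resps: (timesteps, levels); corrupt only levels 0 .. quant_level-1,
	# i.e. the codes the NAR conditions on
	resps = resps.clone()
	sub = resps[:, :quant_level]
	mask = torch.rand(sub.shape, device=resps.device) < p
	flip = torch.rand(sub.shape, device=resps.device) < 0.5
	offsets = torch.where(flip, torch.ones_like(sub), -torch.ones_like(sub))
	resps[:, :quant_level] = torch.where(mask, (sub + offsets).clamp(1, 1022), sub)
	return resps
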
@@ -12,7 +12,7 @@ Additional functionality (preparing inputs, generating full audio) should be del
 import math
 import torch
 import torch.nn.functional as F
-import traceback
+import random
 import numpy as np
 import re

@@ -440,6 +440,10 @@ class Base(nn.Module):
 		self.rvq_l_emb = None
 		self.len_emb = None
 
+		# it would be nicer for these to be a token or live inside an embedding
+		self.sep = nn.Parameter(torch.randn(d_model))
+		self.dropout_token = nn.Parameter(torch.zeros(d_model)) # zeros sounds nicer than randn for a special value
+
 		if self.version == 1: # legacy
 			n_audio_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
 			self.proms_emb = MultiEmbedding(self.n_resp_levels, n_audio_tokens, d_model)

@@ -484,9 +488,6 @@ class Base(nn.Module):
 		# experimental NAR-only mode
 		self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
 
-		# this would be nicer to be a stop token or live inside an embedding
-		self.sep = nn.Parameter(torch.randn(d_model))
-
 		# ick, there has to be a better way
 		hf_attention = self.config.attention if self.config is not None else None

@@ -970,6 +971,16 @@ class Base(nn.Module):
 			return self.proms_emb( input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level], offset = 0 )
 
+		# yuck
+		token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0
+		token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels if self.config else None
+
+		if self.dropout_token is None or not self.training:
+			token_dropout_rate = 0.0
+
+		if not token_dropout_rvq_levels:
+			token_dropout_rvq_levels = [1, self.resp_levels]
+
 		x_list = []
 		for batch_index, batch_input in enumerate(inputs):
 			batch = []

@@ -1018,6 +1029,16 @@ class Base(nn.Module):
 					input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
 					offset = 0 if quant_level == 0 or "len" in self.capabilities else 1
 				)
 
+				# apply token dropout
+				if token_dropout_rate > 0.0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
+					steps = embedding.shape[0] - (1 if quant_level == 0 else 0) # do not mess with stop token
+					for i in range( steps ):
+						if random.random() > token_dropout_rate:
+							continue
+
+						embedding[i] = self.dropout_token
+
 			elif name == "len" and self.len_emb is not None:
 				embedding = self.len_emb( input )
 			else:
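Similarly, the per-timestep dropout loop could be expressed out-of-place with a Bernoulli mask; a sketch (illustrative, not the repo's code), where protect_last mirrors the quant_level == 0 case that skips the stop token:

import torch

def drop_embeddings_vectorized(embedding: torch.Tensor, dropout_token: torch.Tensor, p: float, protect_last: bool = False) -> torch.Tensor:
	# embedding: (timesteps, d_model); dropout_token: (d_model,)
	steps = embedding.shape[0] - (1 if protect_last else 0)
	mask = torch.rand(steps, 1, device=embedding.device) < p
	head = torch.where(mask, dropout_token.to(embedding.dtype), embedding[:steps])
	return torch.cat([head, embedding[steps:]], dim=0)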