this seems to work in testing
parent e029a8804d
commit b52c5c5d80
BIN data/qnt.enc
Binary file not shown.
@@ -397,6 +397,7 @@ def load_engines(training=True, **model_kwargs):
 			key_name = cfg.lora.full_name

 		kwargs['name'] = 'job'
 		kwargs['resume'] = 'allow'
 		if world_size() > 1:
 			kwargs["group"] = "DDP"
+			kwargs['name'] = f'job-{global_rank()}'

@@ -134,13 +134,11 @@ class AR_NAR(Base):
 		# trim resps to only contain all levels below the target level
 		if self.version < 7:
 			resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
-		elif not self.parallel_decoding:
-			resps_list = [r if t in text_task else r[..., l] for r, l, t in zip(resps_list, quant_levels, task_list)]

 		# tensor to cat for RVQ level 0
 		text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
 		text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
-		audio_stop_sequence = torch.tensor([[self.stop_token]], device=device, dtype=torch.int16)
+		audio_stop_sequence = torch.tensor([[self.stop_token] * (1 if self.version < 7 else self.n_resp_levels)], device=device, dtype=torch.int16)

 		# final validations and stuff
 		for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):

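A small sketch (not part of the commit) of what the widened stop sequence looks like when appended; the stop token value and level count here are illustrative:

import torch

stop_token, n_resp_levels = 1024, 8                                    # illustrative values
resps = torch.randint(0, 1024, (300, n_resp_levels))                   # [seq_len, levels] response codes
audio_stop_sequence = torch.tensor([[stop_token] * n_resp_levels], dtype=resps.dtype)
resps = torch.cat([ resps, audio_stop_sequence ])                      # one stop frame spanning every level -> [301, 8]
print(resps.shape)
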
@@ -173,7 +171,7 @@ class AR_NAR(Base):
 					resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1

 			# only apply stop token for RVQ level 0
-			if quant_level <= 0 and timesteps[i] is None and not self.parallel_decoding:
+			if quant_level <= 0 and timesteps[i] is None:
 				# append stop tokens for AR
 				if task not in text_task:
 					resps_list[i] = torch.cat([ resps, audio_stop_sequence ])

@@ -1103,54 +1101,56 @@ class AR_NAR(Base):
 		# is NAR
 		if (len_list is not None or resps_list is not None) and text_list is not None:
 			if self.version >= 7:
-				if self.parallel_decoding:
-					return self.forward_nar_masked_parallel(
-						task_list=task_list,
-
-						text_list=text_list,
-						proms_list=proms_list,
-						resps_list=resps_list,
-
-						lang_list=lang_list,
-						tone_list=tone_list,
-						len_list=len_list,
-						raw_text_list=raw_text_list,
-
-						disable_tqdm=disable_tqdm,
-						use_lora=use_lora,
-						**sampling_kwargs,
-					)
-				else:
-					resps_lists = [ None for _ in range(batch_size) ]
-					for level in range(self.n_resp_levels):
-						resp_list = self.forward_nar_masked(
-							task_list=task_list,
-
-							text_list=text_list,
-							proms_list=proms_list,
-							resps_list=resps_list,
-
-							lang_list=lang_list,
-							tone_list=tone_list,
-							len_list=len_list,
-							raw_text_list=raw_text_list,
-
-							disable_tqdm=disable_tqdm,
-							use_lora=use_lora,
-							quant_levels=[ level for _ in range(batch_size) ],
-							**sampling_kwargs,
-						)
-
-						for batch_index, resp in enumerate(resp_list):
-							if resps_lists[batch_index] is None:
-								resps_lists[batch_index] = []
-
-							resps_lists[batch_index].append( resp )
-
-					for batch_index, resps in enumerate(resps_lists):
-						resps_lists[batch_index] = torch.stack( resps, dim=-1 )
-
-					return resps_lists
+				return self.forward_nar_masked_parallel(
+					task_list=task_list,
+
+					text_list=text_list,
+					proms_list=proms_list,
+					resps_list=resps_list,
+
+					lang_list=lang_list,
+					tone_list=tone_list,
+					len_list=len_list,
+					raw_text_list=raw_text_list,
+
+					disable_tqdm=disable_tqdm,
+					use_lora=use_lora,
+					**sampling_kwargs,
+				)
+				# NAR demasking for all levels
+				"""
+				resps_lists = [ None for _ in range(batch_size) ]
+				for level in range(self.n_resp_levels):
+					resp_list = self.forward_nar_masked(
+						task_list=task_list,
+
+						text_list=text_list,
+						proms_list=proms_list,
+						resps_list=resps_list,
+
+						lang_list=lang_list,
+						tone_list=tone_list,
+						len_list=len_list,
+						raw_text_list=raw_text_list,
+
+						disable_tqdm=disable_tqdm,
+						use_lora=use_lora,
+						quant_levels=[ level for _ in range(batch_size) ],
+						**sampling_kwargs,
+					)
+
+					for batch_index, resp in enumerate(resp_list):
+						if resps_lists[batch_index] is None:
+							resps_lists[batch_index] = []
+
+						resps_lists[batch_index].append( resp )
+
+				for batch_index, resps in enumerate(resps_lists):
+					resps_lists[batch_index] = torch.stack( resps, dim=-1 )
+
+				return resps_lists
+				"""

 			return self.forward_nar(
 				task_list=task_list,

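For context, a sketch (not part of the commit) of the stacking convention used by the commented-out per-level path: one [seq_len] code sequence per RVQ level, stacked into a [seq_len, levels] response tensor. Sizes here are illustrative:

import torch

n_resp_levels, seq_len = 8, 300                                        # illustrative sizes
per_level = [ torch.randint(0, 1024, (seq_len,)) for _ in range(n_resp_levels) ]
stacked = torch.stack(per_level, dim=-1)                               # [seq_len, n_resp_levels]
print(stacked.shape)                                                   # torch.Size([300, 8])
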
@@ -1254,7 +1254,7 @@ def example_usage():
 	available_tasks = ["tts-nar"]

 	model = AR_NAR(**kwargs).to(cfg.device)
-	steps = 500 // batch_size
+	steps = 750 // batch_size

 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""

@@ -245,21 +245,6 @@ class AudioEmbedding(nn.Module):
 		return x

-class AudioEmbedding_Sums(nn.Module):
-	def __init__(
-		self,
-		n_tokens: int,
-		n_levels: int,
-		token_dim: int,
-	):
-		super().__init__()
-		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)])
-
-	def forward(self, xi: Tensor ) -> Tensor:
-		x = sum( [ emb( xi[:, l] ) for l, emb in enumerate(self.embeddings) ] )
-
-		return x
-
 # time-step embedding
 # for the NAR-len, since it probably most likely requires encoding the timestep
 class TimeEmbedding(nn.Module):

@@ -318,10 +303,6 @@ class Classifiers(nn.Module):
 			levels = []

 		# map names to levels
-		"""
-		if names and not levels:
-			levels = [ None if name =="NAR" else self.names.index(name) for name in names ]
-		"""
 		if names and not levels:
 			levels = [ None if name not in self.names else self.names.index(name) for name in names ]

@@ -341,9 +322,36 @@ class Classifiers(nn.Module):
 		]
 		return torch.stack( xi )

+# naively embeds each level of a codebook, then merges the embeddings with a Linear
+class AudioEncoder(nn.Module):
+	def __init__(
+		self,
+		n_tokens: int,
+		n_levels: int,
+		token_dim: int,
+	):
+		super().__init__()
+		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)])
+		self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
+
+	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
+		if dropout_mask is not None:
+			xi = xi.clone().detach().t()
+			for l, t in enumerate( xi ):
+				xi[l] = torch.where( dropout_mask, dropout_token, xi[l] )
+			xi = xi.t()
+
+		x = torch.cat([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ], dim=-1)
+		x = self.proj(x)
+		"""
+		x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
+		"""
+
+		return x
+
 # Pseudo-MoE by doing additional decoding from the main transformer's last hidden output
 # ironically, not using a classifier to hidden_dim => audio_tokens causes problems with fitment
-class ParallelDecoder(nn.Module):
+class AudioDecoder(nn.Module):
 	def __init__(
 		self,
 		levels,

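For reference, a rough shape check of the AudioEncoder added above (a sketch, not part of the commit; the sizes are illustrative and assume the usual 8 RVQ levels, so the concatenation matches the 8 * token_dim projection):

import torch

enc = AudioEncoder(n_tokens=1025, n_levels=8, token_dim=1024)   # hypothetical sizes
codes = torch.randint(0, 1024, (300, 8))                        # [seq_len, n_levels] codebook indices
h = enc(codes)                                                  # each level embedded, concatenated to 8*1024, projected back to 1024
print(h.shape)                                                  # torch.Size([300, 1024])
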
@@ -356,60 +364,39 @@ class ParallelDecoder(nn.Module):
 		attention_backend = config_kwargs.pop("attention_backend", "default")
 		gradient_checkpointing = config_kwargs.pop("gradient_checkpointing", True)

+		config_kwargs["hidden_size"] *= levels
+		config_kwargs["vocab_size"] *= levels
+
 		hidden_size = config_kwargs.get("hidden_size")
 		vocab_size = config_kwargs.get("vocab_size")

 		#self.d_model = d_model
+		self.vocab_size = vocab_size
+		self.up = nn.Linear( d_model, hidden_size )
+		self.down = nn.Linear( hidden_size, vocab_size )
+		self.transformer = None
+		"""
+		self.transformer = LlamaModel_Adapted(LlamaConfig(**config_kwargs))
+		self.transformer = ml.replace_attention( self.transformer, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
+
+		if hasattr( self.transformer, "embeddings" ):
+			del self.transformer.embeddings
+
+		if gradient_checkpointing and not self.transformer.gradient_checkpointing:
+			self.transformer.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
+				use_reentrant=False
+			))
+		"""
-		downs = []
-		modules = []
-		ups = []
-		for level in range(levels):
-			module = LlamaModel_Adapted(LlamaConfig(**config_kwargs))
-
-			module = ml.replace_attention( module, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
-
-			if hasattr( module, "embeddings" ):
-				del module.embeddings
-
-			if gradient_checkpointing and not module.gradient_checkpointing:
-				module.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
-					use_reentrant=False
-				))
-
-			modules.append(module)
-			downs.append(nn.Linear(d_model, hidden_size, bias=False))
-			ups.append(nn.Linear(hidden_size, vocab_size, bias=False))
-
-		self.levels = levels
-		self.decoders = nn.ModuleList(modules)
-		self.downs = nn.ModuleList(downs)
-		self.ups = nn.ModuleList(ups)

 	def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
-		# split into levels
-		if level == None:
-			x = [ self.forward( x, l, **kwargs ) for l in range(self.levels) ]
-			x = torch.stack( x )
-			x = x.permute( 1, 0, 2, 3 ) # ( level, batch, token, classification => batch, level, token, classification )
-			return x
-
-		# do one level
-
-		# attention + feedforward
-		"""
-		x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"]
-		# this really hates an output head, so just treat the final output as one
-		x = x[..., :self.vocab_size]
-		"""
-		# downscale to head's dimensionality
-		x = self.downs[level]( x )
-		# attention + feed forward
-		x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"]
-		# upscale to vocab logits
-		x = self.ups[level]( x )
+		x = self.up( x )
+		if self.transformer is not None:
+			x = self.transformer( inputs_embeds=x, **kwargs )["last_hidden_state"]
+		x = self.down( x )
+
+		batch_size, seq_len, dim = x.shape
+		x = x.reshape( batch_size, seq_len, 8, dim // 8 )
+		x = x.permute( 0, 2, 1, 3 )

 		return x
 		"""

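A rough sketch (not part of the commit) of what the final reshape/permute in the decoder's forward implies about the logit layout; the batch, sequence, and vocab sizes are illustrative, with 8 codebook levels assumed:

import torch

batch_size, seq_len, n_levels, vocab = 2, 300, 8, 1000    # illustrative sizes
x = torch.randn(batch_size, seq_len, n_levels * vocab)    # fused logits from the shared projection
x = x.reshape(batch_size, seq_len, 8, x.shape[-1] // 8)   # split the last dim back into per-level logits
x = x.permute(0, 2, 1, 3)                                 # -> (batch, level, token, classification)
print(x.shape)                                            # torch.Size([2, 8, 300, 1000])
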
@@ -572,7 +559,6 @@ class Base(nn.Module):
 		self.causal = "ar" in self.capabilities or "len" in self.capabilities
 		self.version = self.config.version if self.config is not None else 5
 		self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
-		self.parallel_decoding = self.config.experimental.parallel_decoding if self.config is not None else False

 		self.arch_type = self.config.arch_type if self.config is not None else "llama"

@@ -634,22 +620,10 @@ class Base(nn.Module):
 			l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
 			l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
 		else:
-			if self.parallel_decoding:
-				n_resp_tokens = n_audio_tokens + 1
-				l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
-				l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )]
-				l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels
-			else:
-				"""
-				n_resp_tokens = n_audio_tokens + 1
-				l_embedding_tokens = [n_resp_tokens * self.n_resp_levels]
-				l_embedding_names = ["NAR"]
-				l_classifier_tokens = [n_audio_tokens * self.n_resp_levels]
-				"""
-				n_resp_tokens = n_audio_tokens + 1
-				l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
-				l_classifier_tokens = [n_audio_tokens] * self.n_resp_levels
-				l_embedding_names = [ f'NAR:{i}:{i}' for i in range( self.n_resp_levels ) ]
+			n_resp_tokens = n_audio_tokens + 1
+			l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
+			l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )]
+			l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels

 		n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1

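A hedged arithmetic note (not part of the commit) on what the new else-branch produces, using illustrative sizes of n_audio_tokens=1024 and n_resp_levels=8:

n_audio_tokens, n_resp_levels = 1024, 8                  # illustrative values
n_resp_tokens = n_audio_tokens + 1                       # +1 for the stop/mask token
l_embedding_tokens = [n_resp_tokens] * n_resp_levels     # [1025] * 8
l_embedding_names = []                                   # no per-level embeddings registered
l_classifier_tokens = []                                 # no per-level classifiers registered
print(n_resp_tokens, l_embedding_tokens)
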
@@ -692,6 +666,7 @@ class Base(nn.Module):
 			n_audio_tokens += (n_tasks - 1) # old models have the task tokens in the prom
 			self.proms_emb = MultiEmbedding(self.n_resp_levels, n_audio_tokens, d_model)
 			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
+			self.audio_emb = None
 		elif self.version < 5:
 			# [1024] * 8
 			self.proms_emb = AudioEmbedding_Old(

@@ -703,7 +678,8 @@ class Base(nn.Module):
 				l_embedding_tokens, d_model,
 				levels=self.n_resp_levels if self.version > 3 else None,
 			)
-		elif not self.parallel_decoding:
+			self.audio_emb = None
+		elif self.version < 7:
 			self.proms_emb = AudioEmbedding(
 				[n_audio_tokens] * self.n_resp_levels, d_model,
 				sums=audio_embedding_sums == "prom" or audio_embedding_sums == True,

@@ -713,17 +689,11 @@ class Base(nn.Module):
 				sums=audio_embedding_sums == "resp" or audio_embedding_sums == True,
 				l_embedding_names=l_embedding_names,
 			)
+			self.audio_emb = None
 		else:
-			self.proms_emb = AudioEmbedding_Sums(
-				n_tokens=n_audio_tokens,
-				n_levels=self.n_resp_levels,
-				token_dim=d_model,
-			)
-			self.resps_emb = AudioEmbedding_Sums(
-				n_tokens=n_audio_tokens + 1,
-				n_levels=self.n_resp_levels,
-				token_dim=d_model,
-			)
+			self.proms_emb = None
+			self.resps_emb = None
+			self.audio_emb = None

 		if self.version >= 3:
 			self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None

@@ -739,7 +709,7 @@ class Base(nn.Module):
 		# this ***might*** let me also unify the proms_emb and resps_embedding
 		if self.version >= 5:
 			# "len" RVQ level-0 gets an additional token
-			if self.version < 7 or not self.parallel_decoding:
+			if self.version < 7:
 				self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)

 			# experimental NAR-only mode

@@ -747,6 +717,53 @@ class Base(nn.Module):
 		if self.version >= 6:
 			self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)

+		if self.version >= 7:
+			pd_model = d_model // 4
+			pd_ffn = pd_model * 4
+			pd_heads = n_heads // 4
+			pd_layers = 1
+
+			if False:
+				self.audio_emb = AudioEncoder(
+					n_tokens=n_audio_tokens + 1, # masked token
+					n_levels=self.n_resp_levels,
+					token_dim=d_model,
+				)
+			else:
+				self.proms_emb = AudioEncoder(
+					n_tokens=n_audio_tokens,
+					n_levels=self.n_resp_levels,
+					token_dim=d_model,
+				)
+				self.resps_emb = AudioEncoder(
+					n_tokens=n_audio_tokens + 1, # masked token
+					n_levels=self.n_resp_levels,
+					token_dim=d_model,
+				)
+
+			self.audio_decoder = AudioDecoder(
+				self.n_resp_levels,
+				d_model,
+				dict(
+					vocab_size=n_audio_tokens,
+					hidden_size=pd_model,
+					max_position_embeddings=max_position_embeddings,
+					intermediate_size=pd_ffn,
+					num_hidden_layers=pd_layers,
+					num_attention_heads=pd_heads,
+					attention_dropout=p_dropout if training else 0.0,
+					num_key_value_heads=pd_heads,
+					hidden_act="gelu",
+					is_encoder_decoder=False,
+					is_decoder=True,
+					attn_implementation="eager",
+
+					training=self.training,
+					attention_backend=attention_backend,
+					gradient_checkpointing=self.gradient_checkpointing,
+				)
+			)
+
 		if attention_backend == "auto":
 			attention_backend = "sdpa"
 		"""

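A quick arithmetic sketch (not part of the commit) of how the auxiliary decoder is sized down from the backbone, assuming an illustrative d_model of 1024 and 16 attention heads:

d_model, n_heads = 1024, 16   # illustrative backbone sizes
pd_model  = d_model // 4      # 256-dim hidden for the audio decoder
pd_ffn    = pd_model * 4      # 1024-dim feed-forward
pd_heads  = n_heads // 4      # 4 attention heads
pd_layers = 1                 # a single decoder layer
print(pd_model, pd_ffn, pd_heads, pd_layers)   # 256 1024 4 1
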
@@ -906,33 +923,6 @@ class Base(nn.Module):
 		self.classifiers = Classifiers( l_classifier_tokens, l_classifier_names, d_model, bias=classifiers_bias )
 		self.metrics = Metrics( l_classifier_tokens )

-		self.parallel_decoder = None
-		if self.parallel_decoding:
-			pd_model = d_model # // 2
-			pd_ffn = pd_model * 2
-			pd_heads = n_heads // 2
-			pd_layers = 1
-
-			config = dict(
-				vocab_size=n_audio_tokens,
-				hidden_size=pd_model,
-				max_position_embeddings=max_position_embeddings,
-				intermediate_size=pd_ffn,
-				num_hidden_layers=pd_layers,
-				num_attention_heads=pd_heads,
-				attention_dropout=p_dropout if training else 0.0,
-				num_key_value_heads=pd_heads,
-				hidden_act="gelu",
-				is_encoder_decoder=False,
-				is_decoder=True,
-				attn_implementation="eager",
-
-				training=self.training,
-				attention_backend=attention_backend,
-				gradient_checkpointing=self.gradient_checkpointing,
-			)
-			self.parallel_decoder = ParallelDecoder( self.n_resp_levels, d_model, config )

 	def _forward(
 		self,
 		inputs,

@@ -1126,8 +1116,8 @@ class Base(nn.Module):
 				inputs[i].append( ( "resp", resps_list[i] ) )

 			if self.version >= 7:
-				classifier_level = f"NAR:{quant_level}:{quant_level}" if not self.parallel_decoding else "NAR"
+				classifier_level = f"NAR:{quant_level}:{quant_level}"

 			inputs[i].append( ("classifier_level", classifier_level) )
 		# Audio length prediction task
 		# Sequence: <text><sep><rvq lvl><prom><sep><len>

@@ -1269,29 +1259,16 @@ class Base(nn.Module):
 				input if quant_level == 0 else input[:, :quant_level]
 			)

-		if self.version < 7: # or not self.parallel_decoding:
+		if self.version < 7:
 			return self.proms_emb(
 				input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
 				quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
 				offset = 0,
 			)

-		if not self.parallel_decoding:
-			"""
-			# provides only one
-			return self.proms_emb(
-				input if input.dim() == 1 else input[:, quant_level],
-				quant_level = 0, # if input.dim() == 1 else input.shape[-1],
-				offset = 0,
-			)
-			"""
-			# sums all
-			return self.proms_emb(
-				input,
-				quant_level = quant_level if input.dim() == 1 else input.shape[-1],
-				offset = 0,
-			)
+		if self.audio_emb is not None:
+			return self.audio_emb( input )
+
+		return self.proms_emb( input )

 		# yuck

@@ -1358,11 +1335,11 @@ class Base(nn.Module):
 			elif name == "tone" and self.tones_emb is not None:
 				embedding = self.tones_emb( input )
 			elif name == "resp":
-				if self.parallel_decoding:
-					if dropout_mask is not None:
-						embedding = self.resps_emb( torch.where( dropout_mask, self.stop_token, input.t() ).t() )
+				if self.version >= 7:
+					if self.audio_emb is not None:
+						embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token )
 					else:
-						embedding = self.resps_emb( input )
+						embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token )
 				# if training NAR-len RVQ level 0
 				elif dropout_mask is not None:
 					embedding = self.resps_emb(

@@ -1513,7 +1490,7 @@ class Base(nn.Module):

 		return ids.to(device=device, dtype=torch.int32)

-	def calc_loss_parallel(
+	def calc_loss_new(
 		self,
 		inputs: list,
 		logits,

@@ -1589,6 +1566,9 @@ class Base(nn.Module):
 				if name != task_outputs.get(task_type, name):
 					if self.ignore_inputs_for_loss:
 						ignored = True
+					# cringe
+					if task_type != "tts":
+						ignored = True
 				else:
 					output_len = seq_len

@@ -1602,7 +1582,7 @@ class Base(nn.Module):
 			# perform loss calculation on the individual piece
 			target.append( token )

-			if classifier_level != "NAR":
+			if logits[batch_index].dim() != 3:
 				seq = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
 				logit = logits[batch_index]

@@ -1620,7 +1600,7 @@ class Base(nn.Module):

 				if compute_acc and False:
 					if self.metrics is not None:
-						metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ "NAR:0" if classifier_level == "NAR" else classifier_level ]) )
+						metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
 					else:
 						accuracy_metric = MulticlassAccuracy(
 							logit.shape[-1],

@@ -1652,7 +1632,7 @@ class Base(nn.Module):

 				if compute_acc and False:
 					if self.metrics is not None:
-						metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ "NAR:0" if classifier_level == "NAR" else classifier_level ]) )
+						metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
 					else:
 						accuracy_metric = MulticlassAccuracy(
 							logit.shape[-1],

@@ -1701,9 +1681,6 @@ class Base(nn.Module):
 			if self.version < 7:
 				return input if input.dim() == 1 else input[:, quant_level]

-			if not self.parallel_decoding:
-				return input if input.dim() == 1 else input[:, quant_level]
-
 			return input

 		for batch_index, batch in enumerate(inputs):

@@ -1729,8 +1706,6 @@ class Base(nn.Module):
 			# nonautoregressive, parallel
 			elif classifier_level.startswith("NAR:"):
 				causal = False
-			elif classifier_level == "NAR":
-				causal = False

 			it = 0
 			for name, input in batch:

@@ -1773,6 +1748,9 @@ class Base(nn.Module):
 				if name != task_outputs.get(task_type, name):
 					if self.ignore_inputs_for_loss:
 						ignored = True
+					# cringe
+					if task_type != "tts":
+						ignored = True
 				else:
 					output_len = seq_len

@@ -1909,10 +1887,10 @@ class Base(nn.Module):
 		# needs to be done here as we still have our raw inputs
 		position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
 		classifier_levels = self.get_input( inputs, name="classifier_level" )
-		casual_levels = [ "AR:0:0", "stt", "len", "phn" ]
+		causal_levels = [ "AR:0:0", "stt", "len", "phn" ]

 		# right now limit to new versions because I need to retrain the model for noncausal masks...
-		is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
+		is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]

 		output = self._forward(
 			inputs=x,

@@ -1928,26 +1906,19 @@ class Base(nn.Module):

 		logits = [ logit for logit in logits ]

-		if self.version >= 7 and self.parallel_decoding:
-			p_indices = [ batch_index for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ]
+		if self.version >= 7:
+			p_indices = [ batch_index for batch_index in range(batch_size) if classifier_levels[batch_index] not in causal_levels ]
 			if p_indices:
-				p_logits = torch.stack([ logits[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0)
-				p_mask = torch.stack([ mask[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0)
-				p_ids = torch.stack([ position_ids[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0)
-				p_causal = [ is_causal[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ]
+				p_logits = torch.stack([ logits[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
+				p_mask = torch.stack([ mask[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
+				p_ids = torch.stack([ position_ids[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
+				p_causal = [ is_causal[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ]

-				p_logits = self.parallel_decoder( p_logits, attention_mask=p_mask, position_ids=p_ids, use_cache=False, return_dict=True, is_causal=p_causal )
+				p_logits = self.audio_decoder( p_logits, attention_mask=p_mask, position_ids=p_ids, use_cache=False, return_dict=True, is_causal=p_causal )

 				for i, logit in enumerate(p_logits):
 					logits[p_indices[i]] = logit

-		"""
-		logits = [ self.parallel_decoder( logit.unsqueeze(0), attention_mask=mask,
-			position_ids=position_ids,
-			use_cache=False,
-			return_dict=True,
-			is_causal=is_causal )[0] if level == "NAR" else logit for logit, level in zip(logits, classifier_levels) ]
-		"""

 		# output projection layer
 		# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways

@@ -1958,15 +1929,6 @@ class Base(nn.Module):
 		elif self.classifiers is not None:
 			logits = self.classifiers(logits, levels = classifier_levels )

-		# Reshape
-		"""
-		if self.version >= 7 and not self.parallel_decoding:
-			for batch_index, logit in enumerate( logits ):
-				if classifier_levels[batch_index] != "NAR":
-					continue
-				logits[batch_index] = logit.reshape( logit.shape[0], 8, 1000 ).permute( 1, 0, 2 )
-		"""
-
 		# Remove padding
 		logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]

@@ -1977,8 +1939,8 @@ class Base(nn.Module):
 			self.loss = None
 			self.stats = None
 		# compute loss if the target is given
-		elif self.version >= 7 and self.parallel_decoding:
-			loss, stats = self.calc_loss_parallel( inputs=inputs, logits=logits )
+		elif self.version >= 7:
+			loss, stats = self.calc_loss_new( inputs=inputs, logits=logits )

 		# include any additional losses (for example: MoE router)
 		if output.loss is not None: