commit b52c5c5d80
parent e029a8804d

    this seems to work in testing

BIN data/qnt.enc (binary file not shown)
@@ -397,6 +397,7 @@ def load_engines(training=True, **model_kwargs):
 			key_name = cfg.lora.full_name

 		kwargs['name'] = 'job'
+		kwargs['resume'] = 'allow'
 		if world_size() > 1:
 			kwargs["group"] = "DDP"
 			kwargs['name'] = f'job-{global_rank()}'
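As an aside (not part of the commit): these kwargs look like they feed a wandb.init(...) call, in which case the added key maps to wandb's run-resumption behavior. A minimal, hypothetical illustration with assumed values:

import wandb

# hypothetical sketch only; names/values are assumptions, not the repo's actual call
wandb.init(
	name="job",        # becomes f"job-{rank}" per process under DDP
	group="DDP",       # groups the per-rank runs together
	resume="allow",    # reattach to an existing run instead of erroring out
)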
@@ -134,13 +134,11 @@ class AR_NAR(Base):
 		# trim resps to only contain all levels below the target level
 		if self.version < 7:
 			resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
-		elif not self.parallel_decoding:
-			resps_list = [r if t in text_task else r[..., l] for r, l, t in zip(resps_list, quant_levels, task_list)]

 		# tensor to cat for RVQ level 0
 		text_stop_sequence = torch.tensor([2], device=device, dtype=torch.int16)
 		text_start_stop_sequence = torch.tensor([1, 2], device=device, dtype=torch.int16)
-		audio_stop_sequence = torch.tensor([[self.stop_token]], device=device, dtype=torch.int16)
+		audio_stop_sequence = torch.tensor([[self.stop_token] * (1 if self.version < 7 else self.n_resp_levels)], device=device, dtype=torch.int16)

 		# final validations and stuff
 		for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
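For reference, a minimal sketch (not part of the commit) of what the widened stop sequence looks like, assuming a stop token id of 1024 and 8 response levels:

import torch

stop_token = 1024       # assumed placeholder id
n_resp_levels = 8       # assumed codebook count
version = 7

# one stop column per level for version >= 7, a single column otherwise
audio_stop_sequence = torch.tensor(
	[[stop_token] * (1 if version < 7 else n_resp_levels)],
	dtype=torch.int16,
)

# appended along the time axis of a (timesteps, levels) response
resps = torch.zeros(10, n_resp_levels, dtype=torch.int16)
resps = torch.cat([resps, audio_stop_sequence])  # (10, 8) -> (11, 8)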
@@ -173,7 +171,7 @@ class AR_NAR(Base):
 					resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1

 			# only apply stop token for RVQ level 0
-			if quant_level <= 0 and timesteps[i] is None and not self.parallel_decoding:
+			if quant_level <= 0 and timesteps[i] is None:
 				# append stop tokens for AR
 				if task not in text_task:
 					resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
@@ -1103,54 +1101,56 @@ class AR_NAR(Base):
 		# is NAR
 		if (len_list is not None or resps_list is not None) and text_list is not None:
 			if self.version >= 7:
-				if self.parallel_decoding:
-					return self.forward_nar_masked_parallel(
-						task_list=task_list,
-
-						text_list=text_list,
-						proms_list=proms_list,
-						resps_list=resps_list,
-
-						lang_list=lang_list,
-						tone_list=tone_list,
-						len_list=len_list,
-						raw_text_list=raw_text_list,
-
-						disable_tqdm=disable_tqdm,
-						use_lora=use_lora,
-						**sampling_kwargs,
-					)
-				else:
-					resps_lists = [ None for _ in range(batch_size) ]
-					for level in range(self.n_resp_levels):
-						resp_list = self.forward_nar_masked(
-							task_list=task_list,
-
-							text_list=text_list,
-							proms_list=proms_list,
-							resps_list=resps_list,
-
-							lang_list=lang_list,
-							tone_list=tone_list,
-							len_list=len_list,
-							raw_text_list=raw_text_list,
-
-							disable_tqdm=disable_tqdm,
-							use_lora=use_lora,
-							quant_levels=[ level for _ in range(batch_size) ],
-							**sampling_kwargs,
-						)
-
-						for batch_index, resp in enumerate(resp_list):
-							if resps_lists[batch_index] is None:
-								resps_lists[batch_index] = []
-
-							resps_lists[batch_index].append( resp )
-
-					for batch_index, resps in enumerate(resps_lists):
-						resps_lists[batch_index] = torch.stack( resps, dim=-1 )
-
-					return resps_lists
+				return self.forward_nar_masked_parallel(
+					task_list=task_list,
+
+					text_list=text_list,
+					proms_list=proms_list,
+					resps_list=resps_list,
+
+					lang_list=lang_list,
+					tone_list=tone_list,
+					len_list=len_list,
+					raw_text_list=raw_text_list,
+
+					disable_tqdm=disable_tqdm,
+					use_lora=use_lora,
+					**sampling_kwargs,
+				)
+
+				# NAR demasking for all levels
+				"""
+				resps_lists = [ None for _ in range(batch_size) ]
+				for level in range(self.n_resp_levels):
+					resp_list = self.forward_nar_masked(
+						task_list=task_list,
+
+						text_list=text_list,
+						proms_list=proms_list,
+						resps_list=resps_list,
+
+						lang_list=lang_list,
+						tone_list=tone_list,
+						len_list=len_list,
+						raw_text_list=raw_text_list,
+
+						disable_tqdm=disable_tqdm,
+						use_lora=use_lora,
+						quant_levels=[ level for _ in range(batch_size) ],
+						**sampling_kwargs,
+					)
+
+					for batch_index, resp in enumerate(resp_list):
+						if resps_lists[batch_index] is None:
+							resps_lists[batch_index] = []
+
+						resps_lists[batch_index].append( resp )
+
+				for batch_index, resps in enumerate(resps_lists):
+					resps_lists[batch_index] = torch.stack( resps, dim=-1 )
+
+				return resps_lists
+				"""

 			return self.forward_nar(
 				task_list=task_list,
@@ -1254,7 +1254,7 @@ def example_usage():
 	available_tasks = ["tts-nar"]

 	model = AR_NAR(**kwargs).to(cfg.device)
-	steps = 500 // batch_size
+	steps = 750 // batch_size

 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -245,21 +245,6 @@ class AudioEmbedding(nn.Module):

 		return x

-class AudioEmbedding_Sums(nn.Module):
-	def __init__(
-		self,
-		n_tokens: int,
-		n_levels: int,
-		token_dim: int,
-	):
-		super().__init__()
-		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)])
-
-	def forward(self, xi: Tensor ) -> Tensor:
-		x = sum( [ emb( xi[:, l] ) for l, emb in enumerate(self.embeddings) ] )
-
-		return x
-
 # time-step embedding
 # for the NAR-len, since it probably most likely requires encoding the timestep
 class TimeEmbedding(nn.Module):
@@ -318,10 +303,6 @@ class Classifiers(nn.Module):
 		levels = []

 		# map names to levels
-		"""
-		if names and not levels:
-			levels = [ None if name =="NAR" else self.names.index(name) for name in names ]
-		"""
 		if names and not levels:
 			levels = [ None if name not in self.names else self.names.index(name) for name in names ]

@@ -341,9 +322,36 @@ class Classifiers(nn.Module):
 		]
 		return torch.stack( xi )

+# naively embeds each level of a codebook, then merges the embeddings with a Linear
+class AudioEncoder(nn.Module):
+	def __init__(
+		self,
+		n_tokens: int,
+		n_levels: int,
+		token_dim: int,
+	):
+		super().__init__()
+		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for l in range(n_levels)])
+		self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
+
+	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
+		if dropout_mask is not None:
+			xi = xi.clone().detach().t()
+			for l, t in enumerate( xi ):
+				xi[l] = torch.where( dropout_mask, dropout_token, xi[l] )
+			xi = xi.t()
+
+		x = torch.cat([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ], dim=-1)
+		x = self.proj(x)
+		"""
+		x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
+		"""
+
+		return x
+
 # Pseudo-MoE by doing additional decoding from the main transformer's last hidden output
 # ironically, not using a classifier to hidden_dim => audio_tokens causes problems with fitment
-class ParallelDecoder(nn.Module):
+class AudioDecoder(nn.Module):
 	def __init__(
 		self,
 		levels,
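For orientation, a self-contained sketch (not from the repo; 1025-entry codebooks, 8 levels, and 1024-dim embeddings are assumed sizes) of the embed-per-level, concatenate, then Linear-merge step that the new AudioEncoder performs:

import torch
import torch.nn as nn

n_tokens, n_levels, token_dim = 1025, 8, 1024   # assumed sizes

embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
proj = nn.Linear(n_levels * token_dim, token_dim)

codes = torch.randint(0, n_tokens, (75, n_levels))  # 75 frames of 8-level RVQ codes
x = torch.cat([emb(codes[:, l]) for l, emb in enumerate(embs)], dim=-1)  # (75, 8 * 1024)
x = proj(x)                                          # (75, 1024), one vector per frame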
@@ -356,60 +364,39 @@ class ParallelDecoder(nn.Module):
 		attention_backend = config_kwargs.pop("attention_backend", "default")
 		gradient_checkpointing = config_kwargs.pop("gradient_checkpointing", True)

+		config_kwargs["hidden_size"] *= levels
+		config_kwargs["vocab_size"] *= levels
+
 		hidden_size = config_kwargs.get("hidden_size")
 		vocab_size = config_kwargs.get("vocab_size")

 		#self.d_model = d_model
 		self.vocab_size = vocab_size
-
-		downs = []
-		modules = []
-		ups = []
-		for level in range(levels):
-			module = LlamaModel_Adapted(LlamaConfig(**config_kwargs))
-			module = ml.replace_attention( module, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
-
-			if hasattr( module, "embeddings" ):
-				del module.embeddings
-
-			if gradient_checkpointing and not module.gradient_checkpointing:
-				module.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
-					use_reentrant=False
-				))
-
-			modules.append(module)
-			downs.append(nn.Linear(d_model, hidden_size, bias=False))
-			ups.append(nn.Linear(hidden_size, vocab_size, bias=False))
-
-		self.levels = levels
-		self.decoders = nn.ModuleList(modules)
-		self.downs = nn.ModuleList(downs)
-		self.ups = nn.ModuleList(ups)
+		self.up = nn.Linear( d_model, hidden_size )
+		self.down = nn.Linear( hidden_size, vocab_size )
+		self.transformer = None
+		"""
+		self.transformer = LlamaModel_Adapted(LlamaConfig(**config_kwargs))
+		self.transformer = ml.replace_attention( self.transformer, klass=LlamaAttention_Adapted, target=LlamaAttention, mode=attention_backend )
+
+		if hasattr( self.transformer, "embeddings" ):
+			del self.transformer.embeddings
+
+		if gradient_checkpointing and not self.transformer.gradient_checkpointing:
+			self.transformer.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
+				use_reentrant=False
+			))
+		"""

 	def forward(self, x: Tensor, level: int | None = None, stack: bool = True, **kwargs ) -> Tensor:
-		# split into levels
-		if level == None:
-			x = [ self.forward( x, l, **kwargs ) for l in range(self.levels) ]
-			x = torch.stack( x )
-			x = x.permute( 1, 0, 2, 3 ) # ( level, batch, token, classification => batch, level, token, classification )
-			return x
-
-		# do one level
-
-		# attention + feedforward
-		"""
-		x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"]
-		# this really hates an output head, so just treat the final output as one
-		x = x[..., :self.vocab_size]
-
-		"""
-		# downscale to head's dimensionality
-		x = self.downs[level]( x )
-		# attention + feed forward
-		x = self.decoders[level](inputs_embeds=x, **kwargs)["last_hidden_state"]
-		# upscale to vocab logits
-		x = self.ups[level]( x )
+		x = self.up( x )
+		if self.transformer is not None:
+			x = self.transformer( inputs_embeds=x, **kwargs )["last_hidden_state"]
+		x = self.down( x )
+
+		batch_size, seq_len, dim = x.shape
+		x = x.reshape( batch_size, seq_len, 8, dim // 8 )
+		x = x.permute( 0, 2, 1, 3 )

 		return x
 	"""
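Again as a sketch only (dimensions are assumptions, not the repo's), the reworked decoder path boils down to a pair of Linear projections plus a reshape into per-level logits, since the inner transformer is currently disabled (self.transformer = None):

import torch
import torch.nn as nn

d_model, levels, vocab = 1024, 8, 1000          # assumed sizes
hidden = (d_model // 4) * levels                # mirrors hidden_size *= levels

up = nn.Linear(d_model, hidden)
down = nn.Linear(hidden, vocab * levels)        # mirrors vocab_size *= levels

h = torch.randn(2, 50, d_model)                 # (batch, seq, d_model) hidden states
x = down(up(h))                                 # (2, 50, 8 * 1000)
b, t, dim = x.shape
x = x.reshape(b, t, levels, dim // levels)      # (2, 50, 8, 1000)
x = x.permute(0, 2, 1, 3)                       # (2, 8, 50, 1000): per-level logits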
@@ -572,7 +559,6 @@ class Base(nn.Module):
 		self.causal = "ar" in self.capabilities or "len" in self.capabilities
 		self.version = self.config.version if self.config is not None else 5
 		self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
-		self.parallel_decoding = self.config.experimental.parallel_decoding if self.config is not None else False

 		self.arch_type = self.config.arch_type if self.config is not None else "llama"

@@ -634,22 +620,10 @@ class Base(nn.Module):
 			l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
 			l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
 		else:
-			if self.parallel_decoding:
-				n_resp_tokens = n_audio_tokens + 1
-				l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
-				l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )]
-				l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels
-			else:
-				"""
-				n_resp_tokens = n_audio_tokens + 1
-				l_embedding_tokens = [n_resp_tokens * self.n_resp_levels]
-				l_embedding_names = ["NAR"]
-				l_classifier_tokens = [n_audio_tokens * self.n_resp_levels]
-				"""
-				n_resp_tokens = n_audio_tokens + 1
-				l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
-				l_classifier_tokens = [n_audio_tokens] * self.n_resp_levels
-				l_embedding_names = [ f'NAR:{i}:{i}' for i in range( self.n_resp_levels ) ]
+			n_resp_tokens = n_audio_tokens + 1
+			l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
+			l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )]
+			l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels

 		n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1

@@ -692,6 +666,7 @@ class Base(nn.Module):
 				n_audio_tokens += (n_tasks - 1) # old models have the task tokens in the prom
 			self.proms_emb = MultiEmbedding(self.n_resp_levels, n_audio_tokens, d_model)
 			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
+			self.audio_emb = None
 		elif self.version < 5:
 			# [1024] * 8
 			self.proms_emb = AudioEmbedding_Old(
@@ -703,7 +678,8 @@ class Base(nn.Module):
 				l_embedding_tokens, d_model,
 				levels=self.n_resp_levels if self.version > 3 else None,
 			)
-		elif not self.parallel_decoding:
+			self.audio_emb = None
+		elif self.version < 7:
 			self.proms_emb = AudioEmbedding(
 				[n_audio_tokens] * self.n_resp_levels, d_model,
 				sums=audio_embedding_sums == "prom" or audio_embedding_sums == True,
@@ -713,17 +689,11 @@ class Base(nn.Module):
 				sums=audio_embedding_sums == "resp" or audio_embedding_sums == True,
 				l_embedding_names=l_embedding_names,
 			)
+			self.audio_emb = None
 		else:
-			self.proms_emb = AudioEmbedding_Sums(
-				n_tokens=n_audio_tokens,
-				n_levels=self.n_resp_levels,
-				token_dim=d_model,
-			)
-			self.resps_emb = AudioEmbedding_Sums(
-				n_tokens=n_audio_tokens + 1,
-				n_levels=self.n_resp_levels,
-				token_dim=d_model,
-			)
+			self.proms_emb = None
+			self.resps_emb = None
+			self.audio_emb = None

 		if self.version >= 3:
 			self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None
@@ -739,7 +709,7 @@ class Base(nn.Module):
 		# this ***might*** let me also unify the proms_emb and resps_embedding
 		if self.version >= 5:
 			# "len" RVQ level-0 gets an additional token
-			if self.version < 7 or not self.parallel_decoding:
+			if self.version < 7:
 				self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)

 			# experimental NAR-only mode
@@ -747,6 +717,53 @@ class Base(nn.Module):
 		if self.version >= 6:
 			self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)

+		if self.version >= 7:
+			pd_model = d_model // 4
+			pd_ffn = pd_model * 4
+			pd_heads = n_heads // 4
+			pd_layers = 1
+
+			if False:
+				self.audio_emb = AudioEncoder(
+					n_tokens=n_audio_tokens + 1, # masked token
+					n_levels=self.n_resp_levels,
+					token_dim=d_model,
+				)
+			else:
+				self.proms_emb = AudioEncoder(
+					n_tokens=n_audio_tokens,
+					n_levels=self.n_resp_levels,
+					token_dim=d_model,
+				)
+				self.resps_emb = AudioEncoder(
+					n_tokens=n_audio_tokens + 1, # masked token
+					n_levels=self.n_resp_levels,
+					token_dim=d_model,
+				)
+
+			self.audio_decoder = AudioDecoder(
+				self.n_resp_levels,
+				d_model,
+				dict(
+					vocab_size=n_audio_tokens,
+					hidden_size=pd_model,
+					max_position_embeddings=max_position_embeddings,
+					intermediate_size=pd_ffn,
+					num_hidden_layers=pd_layers,
+					num_attention_heads=pd_heads,
+					attention_dropout=p_dropout if training else 0.0,
+					num_key_value_heads=pd_heads,
+					hidden_act="gelu",
+					is_encoder_decoder=False,
+					is_decoder=True,
+					attn_implementation="eager",
+
+					training=self.training,
+					attention_backend=attention_backend,
+					gradient_checkpointing=self.gradient_checkpointing,
+				)
+			)
+
 		if attention_backend == "auto":
 			attention_backend = "sdpa"
 		"""
@@ -906,33 +923,6 @@ class Base(nn.Module):
 			self.classifiers = Classifiers( l_classifier_tokens, l_classifier_names, d_model, bias=classifiers_bias )
 			self.metrics = Metrics( l_classifier_tokens )

-		self.parallel_decoder = None
-		if self.parallel_decoding:
-			pd_model = d_model # // 2
-			pd_ffn = pd_model * 2
-			pd_heads = n_heads // 2
-			pd_layers = 1
-
-			config = dict(
-				vocab_size=n_audio_tokens,
-				hidden_size=pd_model,
-				max_position_embeddings=max_position_embeddings,
-				intermediate_size=pd_ffn,
-				num_hidden_layers=pd_layers,
-				num_attention_heads=pd_heads,
-				attention_dropout=p_dropout if training else 0.0,
-				num_key_value_heads=pd_heads,
-				hidden_act="gelu",
-				is_encoder_decoder=False,
-				is_decoder=True,
-				attn_implementation="eager",
-
-				training=self.training,
-				attention_backend=attention_backend,
-				gradient_checkpointing=self.gradient_checkpointing,
-			)
-			self.parallel_decoder = ParallelDecoder( self.n_resp_levels, d_model, config )
-
 	def _forward(
 		self,
 		inputs,
@@ -1126,8 +1116,8 @@ class Base(nn.Module):
 			inputs[i].append( ( "resp", resps_list[i] ) )

 			if self.version >= 7:
-				classifier_level = f"NAR:{quant_level}:{quant_level}" if not self.parallel_decoding else "NAR"
+				classifier_level = f"NAR:{quant_level}:{quant_level}"

 			inputs[i].append( ("classifier_level", classifier_level) )
 		# Audio length prediction task
 		# Sequence: <text><sep><rvq lvl><prom><sep><len>
@@ -1269,29 +1259,16 @@ class Base(nn.Module):
 				input if quant_level == 0 else input[:, :quant_level]
 			)

-		if self.version < 7: # or not self.parallel_decoding:
+		if self.version < 7:
 			return self.proms_emb(
 				input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
 				quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
 				offset = 0,
 			)

-		if not self.parallel_decoding:
-			"""
-			# provides only one
-			return self.proms_emb(
-				input if input.dim() == 1 else input[:, quant_level],
-				quant_level = 0, # if input.dim() == 1 else input.shape[-1],
-				offset = 0,
-			)
-			"""
-			# sums all
-			return self.proms_emb(
-				input,
-				quant_level = quant_level if input.dim() == 1 else input.shape[-1],
-				offset = 0,
-			)
-
+		if self.audio_emb is not None:
+			return self.audio_emb( input )
+
 		return self.proms_emb( input )

 		# yuck
@@ -1358,11 +1335,11 @@ class Base(nn.Module):
 			elif name == "tone" and self.tones_emb is not None:
 				embedding = self.tones_emb( input )
 			elif name == "resp":
-				if self.parallel_decoding:
-					if dropout_mask is not None:
-						embedding = self.resps_emb( torch.where( dropout_mask, self.stop_token, input.t() ).t() )
+				if self.version >= 7:
+					if self.audio_emb is not None:
+						embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token )
 					else:
-						embedding = self.resps_emb( input )
+						embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token )
 				# if training NAR-len RVQ level 0
 				elif dropout_mask is not None:
 					embedding = self.resps_emb(
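A small illustration (not from the repo) of what the dropout_mask/dropout_token arguments amount to: masked timesteps are replaced by the mask/stop token across every RVQ level before being embedded.

import torch

mask_token = 1024                        # assumed placeholder id
codes = torch.randint(0, 1024, (6, 8))   # (timesteps, levels)
dropout_mask = torch.tensor([True, False, True, False, False, True])

masked = codes.clone()
masked[dropout_mask] = mask_token        # whole frames masked across all 8 levels at once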
@@ -1513,7 +1490,7 @@ class Base(nn.Module):

 		return ids.to(device=device, dtype=torch.int32)

-	def calc_loss_parallel(
+	def calc_loss_new(
 		self,
 		inputs: list,
 		logits,
@@ -1589,6 +1566,9 @@ class Base(nn.Module):
 				if name != task_outputs.get(task_type, name):
 					if self.ignore_inputs_for_loss:
 						ignored = True
+					# cringe
+					if task_type != "tts":
+						ignored = True
 				else:
 					output_len = seq_len

@@ -1602,7 +1582,7 @@ class Base(nn.Module):
 					# perform loss calculation on the individual piece
 					target.append( token )

-					if classifier_level != "NAR":
+					if logits[batch_index].dim() != 3:
 						seq = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
 						logit = logits[batch_index]

@@ -1620,7 +1600,7 @@ class Base(nn.Module):

 						if compute_acc and False:
 							if self.metrics is not None:
-								metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ "NAR:0" if classifier_level == "NAR" else classifier_level ]) )
+								metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
 							else:
 								accuracy_metric = MulticlassAccuracy(
 									logit.shape[-1],
@@ -1652,7 +1632,7 @@ class Base(nn.Module):

 					if compute_acc and False:
 						if self.metrics is not None:
-							metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ "NAR:0" if classifier_level == "NAR" else classifier_level ]) )
+							metrics = self.metrics.calc_accuracy( [ logit ], [ token ], self.classifiers.indices([ classifier_level ]) )
 						else:
 							accuracy_metric = MulticlassAccuracy(
 								logit.shape[-1],
@@ -1701,9 +1681,6 @@ class Base(nn.Module):
 			if self.version < 7:
 				return input if input.dim() == 1 else input[:, quant_level]

-			if not self.parallel_decoding:
-				return input if input.dim() == 1 else input[:, quant_level]
-
 			return input

 		for batch_index, batch in enumerate(inputs):
@@ -1729,8 +1706,6 @@ class Base(nn.Module):
 			# nonautoregressive, parallel
 			elif classifier_level.startswith("NAR:"):
 				causal = False
-			elif classifier_level == "NAR":
-				causal = False

 			it = 0
 			for name, input in batch:
@@ -1773,6 +1748,9 @@ class Base(nn.Module):
 				if name != task_outputs.get(task_type, name):
 					if self.ignore_inputs_for_loss:
 						ignored = True
+					# cringe
+					if task_type != "tts":
+						ignored = True
 				else:
 					output_len = seq_len

@@ -1909,10 +1887,10 @@ class Base(nn.Module):
 		# needs to be done here as we still have our raw inputs
 		position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
 		classifier_levels = self.get_input( inputs, name="classifier_level" )
-		casual_levels = [ "AR:0:0", "stt", "len", "phn" ]
+		causal_levels = [ "AR:0:0", "stt", "len", "phn" ]

 		# right now limit to new versions because I need to retrain the model for noncausal masks...
-		is_causal = [ l in casual_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]
+		is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]

 		output = self._forward(
 			inputs=x,
@@ -1928,26 +1906,19 @@ class Base(nn.Module):

 		logits = [ logit for logit in logits ]

-		if self.version >= 7 and self.parallel_decoding:
-			p_indices = [ batch_index for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ]
+		if self.version >= 7:
+			p_indices = [ batch_index for batch_index in range(batch_size) if classifier_levels[batch_index] not in causal_levels ]
 			if p_indices:
-				p_logits = torch.stack([ logits[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0)
-				p_mask = torch.stack([ mask[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0)
-				p_ids = torch.stack([ position_ids[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ], dim=0)
-				p_causal = [ is_causal[batch_index] for batch_index in range(batch_size) if classifier_levels[batch_index] == "NAR" ]
-
-				p_logits = self.parallel_decoder( p_logits, attention_mask=p_mask, position_ids=p_ids, use_cache=False, return_dict=True, is_causal=p_causal )
+				p_logits = torch.stack([ logits[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
+				p_mask = torch.stack([ mask[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
+				p_ids = torch.stack([ position_ids[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ], dim=0)
+				p_causal = [ is_causal[batch_index] for batch_index in range(batch_size) if batch_index in p_indices ]
+
+				p_logits = self.audio_decoder( p_logits, attention_mask=p_mask, position_ids=p_ids, use_cache=False, return_dict=True, is_causal=p_causal )

 				for i, logit in enumerate(p_logits):
 					logits[p_indices[i]] = logit

-			"""
-			logits = [ self.parallel_decoder( logit.unsqueeze(0), attention_mask=mask,
-				position_ids=position_ids,
-				use_cache=False,
-				return_dict=True,
-				is_causal=is_causal )[0] if level == "NAR" else logit for logit, level in zip(logits, classifier_levels) ]
-			"""

 		# output projection layer
 		# the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
@@ -1958,15 +1929,6 @@ class Base(nn.Module):
 		elif self.classifiers is not None:
 			logits = self.classifiers(logits, levels = classifier_levels )

-		# Reshape
-		"""
-		if self.version >= 7 and not self.parallel_decoding:
-			for batch_index, logit in enumerate( logits ):
-				if classifier_levels[batch_index] != "NAR":
-					continue
-				logits[batch_index] = logit.reshape( logit.shape[0], 8, 1000 ).permute( 1, 0, 2 )
-		"""
-
 		# Remove padding
 		logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]

@@ -1977,8 +1939,8 @@ class Base(nn.Module):
 			self.loss = None
 			self.stats = None
 		# compute loss if the target is given
-		elif self.version >= 7 and self.parallel_decoding:
-			loss, stats = self.calc_loss_parallel( inputs=inputs, logits=logits )
+		elif self.version >= 7:
+			loss, stats = self.calc_loss_new( inputs=inputs, logits=logits )

 			# include any additional losses (for example: MoE router)
 			if output.loss is not None: