removed the need to supply targ_list + different AudioEmbedding + other things

mrq 2024-06-06 18:52:41 -05:00
parent fcac9503e2
commit ee25d2e62e
2 changed files with 70 additions and 66 deletions


@@ -150,20 +150,8 @@ class AR_NAR(Base):
quant_levels = torch.Tensor([ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
else:
quant_levels = torch.randint(0 if self.causal else 1, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
"""
if cfg.model.p_ar_level == "auto" or cfg.model.p_ar_level is None:
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
else:
quant_levels = torch.Tensor([ 0 if random.random() < cfg.model.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ])
"""
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
resps_list = [r[..., 0] if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # r if l == 0 is technically correct since only r[:, 0] is passed through the embedding, but this should save some VRAM
"""
if cfg.experimental:
proms_list = [ r if l == 0 else trim(r, cfg.dataset.frames_per_second * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds
"""
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)] # r if l == 0 is technically correct since only r[:, 0] is passed through the embedding, but this should save some VRAM
# append stop tokens for AR
for i in range(batch_size):
@@ -171,13 +159,11 @@ class AR_NAR(Base):
continue
resps_list[i] = torch.cat([resps_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
targ_list[i] = torch.cat([targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
targ_list=targ_list,
lang_list=lang_list,
tone_list=tone_list,
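
For context, a minimal sketch (toy shapes, not part of the diff) of the new batch preparation in AR_NAR.forward: with targ_list gone, the target RVQ level is simply kept as the last column of each trimmed response, and the stop token is only appended for AR (level 0) entries.

import torch

n_resp_levels, stop_token = 8, 1024
# two toy responses of shape [seq_len, n_resp_levels]
resps_list = [torch.randint(0, 1024, (5, n_resp_levels)) for _ in range(2)]
quant_levels = [0, 3]  # entry 0 trains the AR, entry 1 the NAR at level 3

# keep levels up to and including the target level; no separate targ_list
resps_list = [r[..., 0] if l == 0 else r[..., :l + 1] for r, l in zip(resps_list, quant_levels)]

# append the stop token only for AR entries
for i, l in enumerate(quant_levels):
    if l > 0:
        continue
    resps_list[i] = torch.cat([resps_list[i], torch.tensor([stop_token], dtype=resps_list[i].dtype)])

print([tuple(r.shape) for r in resps_list])  # [(6,), (5, 4)]
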


@@ -100,11 +100,12 @@ class MultiEmbedding(nn.Module):
return x_list
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
class AudioEmbedding(nn.Module):
class AudioEmbedding_Old(nn.Module):
def __init__(
self,
l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
token_dim: int, # dimensionality of the embedding
mode: str = "old", # old | prom | resp
levels: int | None = None, # number of RVQ-bins (I don't remember the specifics)
sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
):
@@ -114,10 +115,12 @@ class AudioEmbedding(nn.Module):
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None and mode == "old" else None
#
self.mode = mode
#
self.sums = sums
def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
# prom
if quant_levels is None and xi.shape[-1] > 1:
@@ -139,6 +142,42 @@ class AudioEmbedding(nn.Module):
return x
class AudioEmbedding(nn.Module):
def __init__(
self,
l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
token_dim: int, # dimensionality of the embedding
mode: str, # prom | resp
sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
):
super().__init__()
# array of embeddings
# proms are [0, prom_levels]
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
#
self.mode = mode
#
self.sums = sums
# maintaining compat is hard
def forward(self, xi: Tensor, quant_level: Tensor | None = None ) -> Tensor:
if quant_level is None:
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
# embeddings for AR/NAR cannot be shared
offset = 0 if self.mode == "prom" or quant_level == 0 else 1
if xi.dim() == 1:
x = self.embeddings[quant_level]( xi )
elif self.sums and quant_level > 0:
x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
else:
k = quant_level
x = self.embeddings[k + offset]( xi[:, k] )
return x
class Base(nn.Module):
@property
def causal(self) -> bool:
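
As an aside, a rough usage sketch of the new AudioEmbedding (assuming the class above and its torch/nn imports are importable; dimensions are made up): a 1-D input takes the AR path through embeddings[0], while a 2-D input infers its quant level from the trailing dimension and, with sums enabled, sums the lower-level embeddings with an offset of 1 in "resp" mode.

import torch

d_model, n_tokens, n_levels = 16, 1025, 8
resps_emb = AudioEmbedding([n_tokens] + [n_tokens - 1] * (n_levels - 1), d_model, "resp")

ar_seq  = torch.randint(0, n_tokens, (5,))         # 1-D: AR path, embeds with embeddings[0]
nar_seq = torch.randint(0, n_tokens - 1, (5, 3))   # 2-D: quant level inferred as 2; sums embeddings[1] and [2] over columns 0 and 1

print(resps_emb(ar_seq).shape)   # torch.Size([5, 16])
print(resps_emb(nar_seq).shape)  # torch.Size([5, 16])
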
@@ -258,17 +297,30 @@ class Base(nn.Module):
n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
else:
elif self.version < 5:
# [1024] * 8
self.proms_emb = AudioEmbedding(
self.proms_emb = AudioEmbedding_Old(
[n_prom_tokens] * self.n_prom_levels, d_model,
levels=self.n_prom_levels if self.version > 3 else None,
mode="prom" if self.version >= 5 else "old",
sums=self.config.audio_embedding_sums if self.config is not None else True,
)
# [1024 + STOP] + [1024] * 8
self.resps_emb = AudioEmbedding(
self.resps_emb = AudioEmbedding_Old(
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
levels=self.n_resp_levels if self.version > 3 else None,
mode="resp" if self.version >= 5 else "old",
sums=self.config.audio_embedding_sums if self.config is not None else True
)
else:
self.proms_emb = AudioEmbedding(
[n_prom_tokens] * self.n_prom_levels, d_model,
"prom",
sums=self.config.audio_embedding_sums if self.config is not None else True
)
self.resps_emb = AudioEmbedding(
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
"resp",
sums=self.config.audio_embedding_sums if self.config is not None else True
)
@@ -522,38 +574,6 @@ class Base(nn.Module):
x = inputs
m = mask.squeeze(-1).int()
aux_loss = None
"""
# Broken
if state is not None and (self.arch_type == "retnet" or self.arch_type == "retnet-hf"):
# prefill
if len(state) == 0:
prefill_size = x.shape[1]
# run the initial prompt to fill the KV cache
if self.arch_type == "retnet":
for n in range(prefill_size):
xi = x[:, n, :].unsqueeze(1)
self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
elif self.arch_type == "retnet-hf":
state = None
for n in range(prefill_size):
xi = x[:, n, :].unsqueeze(1)
kwargs = dict(
attention_mask=m,
inputs_embeds=xi,
past_key_values=state,
use_cache=True,
forward_impl='recurrent',
# return_dict=True,
)
out = self.model(**kwargs)
state = out.past_key_values
# grab last token(s)
x = x[:, -1, :].unsqueeze(1)
"""
# HF transformer derived model
if self.arch_type in ["llama", "mistral", "mixtral"]:
@@ -564,7 +584,7 @@ class Base(nn.Module):
use_cache=True,
# return_dict=True,
)
if self.n_experts > 1 and targ_list is not None:
if self.n_experts > 1 and self.training:
kwargs["output_router_logits"] = True
t = self.model(**kwargs)
@@ -574,7 +594,7 @@ class Base(nn.Module):
if state is not None:
state = t[1]
if self.n_experts > 1 and targ_list is not None:
if self.n_experts > 1 and self.training:
router_logits = t[-1]
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
elif self.arch_type == "transformer":
@@ -622,7 +642,6 @@ class Base(nn.Module):
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor],
targ_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
@@ -646,8 +665,6 @@ class Base(nn.Module):
inputs[i].append( ( "prom", proms_list[i] ) )
if resps_list is not None:
inputs[i].append( ( "resp", resps_list[i] ) )
if targ_list is not None:
inputs[i].append( ( "targ", targ_list[i] ) )
return inputs
@@ -669,11 +686,11 @@ class Base(nn.Module):
elif name == "lang" and self.langs_emb is not None:
embedding = self.langs_emb( input )
elif name == "prom":
embedding = self.proms_emb( input )
embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level] )
elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input )
elif name == "resp":
embedding = self.resps_emb( input, quant_level )
embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level )
else:
continue
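
A small illustration (again a sketch, not part of the diff) of the new call convention above: only the RVQ levels below the target are handed to the embeddings, with both the prompt and the response sliced to their first quant_level columns.

import torch

# toy embeddings built with the new AudioEmbedding class from the earlier hunk
d_model, quant_level = 16, 3
proms_emb = AudioEmbedding([1024] * 8, d_model, "prom")
resps_emb = AudioEmbedding([1025] + [1024] * 7, d_model, "resp")

prom = torch.randint(0, 1024, (4, 8))  # full prompt, all RVQ levels
resp = torch.randint(0, 1024, (5, 4))  # response trimmed to levels 0..quant_level

prom_x = proms_emb(prom if prom.dim() == 1 or quant_level == 0 else prom[:, :quant_level])
resp_x = resps_emb(resp if resp.dim() == 1 or quant_level == 0 else resp[:, :quant_level], quant_level)

print(prom_x.shape, resp_x.shape)  # torch.Size([4, 16]) torch.Size([5, 16])
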
@@ -698,7 +715,9 @@ class Base(nn.Module):
for name, input in batch:
if name == "prom":
target.append( torch.full_like(input[..., 0], self.ignore_index) )
elif name in ["text", "quant_level", "lang", "tone", "targ"]:
elif name == "resp":
target.append( input if input.dim() == 1 else input[:, quant_level-1] )
elif name in ["text", "quant_level", "lang", "tone"]:
target.append( input )
target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
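
For reference, a hedged sketch of what the target construction above produces for one batch entry (the ignore_index value and the toy tensors are assumptions, not taken from the diff): prompt positions are masked out of the loss, and the response target now comes straight from the "resp" input rather than a separate "targ" entry.

import torch

ignore_index, quant_level = -100, 2  # ignore_index value assumed for illustration

# one toy batch entry, as produced by self.inputs(): (name, tensor) pairs
batch = [
    ("text", torch.tensor([1, 2, 3])),
    ("prom", torch.randint(0, 1024, (4, 8))),
    ("resp", torch.randint(0, 1024, (5, quant_level))),
]

target = []
for name, input in batch:
    if name == "prom":
        target.append(torch.full_like(input[..., 0], ignore_index))  # masked out of the loss
    elif name == "resp":
        target.append(input if input.dim() == 1 else input[:, quant_level - 1])
    elif name in ["text", "quant_level", "lang", "tone"]:
        target.append(input)

print(torch.cat(target))  # text tokens, four ignore_index entries, then the resp targets
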
@@ -755,10 +774,7 @@ class Base(nn.Module):
for name, input in batch:
# do not use resp
if name == "resp":
continue
# rename to resp
if name == "targ":
name = "resp"
input = input if input.dim() == 1 else input[:, quant_level]
# select prom level
elif name == "prom" and quant_level is not None:
input = input[:, quant_level]
@@ -825,13 +841,15 @@ class Base(nn.Module):
x_list = self.inputs_to_embeddings( inputs, quant_levels )
x, m = list_to_tensor(x_list)
training = self.training
# yes, there's a better way.
"""
training = False
for batch_index, batch in enumerate(inputs):
for name, input in batch:
if name == "targ":
training = True
"""
device = x.device
batch_size = len(x_list)