removed the need to supply targ_list + different AudioEmbedding + other things
parent fcac9503e2
commit ee25d2e62e
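The net effect of the commit: training targets are no longer computed by the caller and passed as a separate `targ_list`; each resp tensor now carries its target RVQ level as its last column, and the model slices inputs and targets out itself. A minimal, self-contained sketch of that data flow (shapes and token ranges are assumed, not taken from the repo's config):

```python
import torch

# resps as EnCodec-style RVQ codes: [timesteps, n_resp_levels] per sample
torch.manual_seed(0)
resps_list = [torch.randint(0, 1024, (t, 8)) for t in (5, 7)]
quant_levels = torch.tensor([0, 3])  # 0 selects the AR, 1+ select a NAR level

# before: a separate targ_list had to be supplied alongside resps_list
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)]

# after: resps_list keeps the target level as its last column ...
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
# ... so the model can recover both halves on its own
inputs  = [r if r.dim() == 1 else r[:, :-1] for r in resps_list]  # levels fed to the embedding
targets = [r if r.dim() == 1 else r[:, -1]  for r in resps_list]  # level being predicted
assert all(torch.equal(a, b) for a, b in zip(targets, targ_list))
```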
@@ -150,20 +150,8 @@ class AR_NAR(Base):
                 quant_levels = torch.Tensor([ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
             else:
                 quant_levels = torch.randint(0 if self.causal else 1, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-            """
-            if cfg.model.p_ar_level == "auto" or cfg.model.p_ar_level is None:
-                quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
-            else:
-                quant_levels = torch.Tensor([ 0 if random.random() < cfg.model.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ])
-            """

-            targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
-            resps_list = [r[..., 0] if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # r if l == 0 is technically correct since only r[:, 0] is passed through the embedding, but this should save some VRAM
-
-            """
-            if cfg.experimental:
-                proms_list = [ r if l == 0 else trim(r, cfg.dataset.frames_per_second * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds
-            """
+            resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)] # r if l == 0 is technically correct since only r[:, 0] is passed through the embedding, but this should save some VRAM

             # append stop tokens for AR
             for i in range(batch_size):
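For context, the one-line semantic change in this hunk: `r[..., :l]` kept only the levels below the target (the target itself lived in `targ_list`), while `r[..., :l+1]` keeps the target level too. A toy check of the slice, with assumed shapes:

```python
import torch

r = torch.randint(0, 1024, (50, 8))  # one resp: [timesteps, n_resp_levels]
l = 3                                # target RVQ level for this sample
old = r[..., :l]                     # levels 0..2 only; target was separate
new = r[..., :l+1]                   # levels 0..2 plus target level 3
assert new.shape[-1] == old.shape[-1] + 1
assert torch.equal(new[:, -1], r[:, l])  # the target now rides along
```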
@@ -171,13 +159,11 @@ class AR_NAR(Base):
                     continue
                 resps_list[i] = torch.cat([resps_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
-                targ_list[i] = torch.cat([targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])

             inputs = self.inputs(
                 text_list=text_list,
                 proms_list=proms_list,
                 resps_list=resps_list,
-                targ_list=targ_list,
                 lang_list=lang_list,
                 tone_list=tone_list,
@@ -100,11 +100,12 @@ class MultiEmbedding(nn.Module):
         return x_list

 # Embedding that sums each RVQ-bin level within a given input acoustic prompt
-class AudioEmbedding(nn.Module):
+class AudioEmbedding_Old(nn.Module):
     def __init__(
         self,
         l_tokens: list[int], # list of token counts per level (needed because AR resps include a stop token)
         token_dim: int, # dimensionality of the embedding
+        mode: str = "old", # old | prom | resp
         levels: int | None = None, # number of RVQ-bins (I don't remember the specifics)
         sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
     ):
@@ -114,10 +115,12 @@ class AudioEmbedding(nn.Module):
         # resp are split to where [0] is for the AR, and [1:] are reserved for NAR
         self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
         # weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
-        self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
+        self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None and mode == "old" else None
+        #
+        self.mode = mode
         #
         self.sums = sums

     def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
         # prom
         if quant_levels is None and xi.shape[-1] > 1:
@@ -139,6 +142,42 @@ class AudioEmbedding(nn.Module):

         return x

+class AudioEmbedding(nn.Module):
+    def __init__(
+        self,
+        l_tokens: list[int], # list of token counts per level (needed because AR resps include a stop token)
+        token_dim: int, # dimensionality of the embedding
+        mode: str, # prom | resp
+        sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
+    ):
+        super().__init__()
+        # array of embeddings
+        # proms are [0, prom_levels]
+        # resp are split to where [0] is for the AR, and [1:] are reserved for NAR
+        self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
+        #
+        self.mode = mode
+        #
+        self.sums = sums
+
+    # maintaining compat is hard
+    def forward(self, xi: Tensor, quant_level: int | Tensor | None = None ) -> Tensor:
+        if quant_level is None:
+            quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
+
+        # embeddings for AR/NAR cannot be shared
+        offset = 0 if self.mode == "prom" or quant_level == 0 else 1
+
+        if xi.dim() == 1:
+            x = self.embeddings[quant_level]( xi )
+        elif self.sums and quant_level > 0:
+            x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
+        else:
+            k = quant_level
+            x = self.embeddings[k + offset]( xi[:, k] )
+
+        return x
+
 class Base(nn.Module):
     @property
     def causal(self) -> bool:
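A standalone sketch of the new class's lookup rules (sizes are assumed; this mirrors the forward above rather than importing it): "prom" mode indexes embeddings by level directly, "resp" mode shifts NAR levels by one because `embeddings[0]` is reserved for the AR, and `sums=True` accumulates every level below the target.

```python
import torch
from torch import nn

token_dim = 16
embeddings = nn.ModuleList([nn.Embedding(1025, token_dim) for _ in range(8)])

def embed(xi, mode, quant_level, sums=True):
    # mirrors AudioEmbedding.forward above
    offset = 0 if mode == "prom" or quant_level == 0 else 1
    if xi.dim() == 1:                 # AR path: a single RVQ level
        return embeddings[quant_level](xi)
    if sums and quant_level > 0:      # NAR path: sum levels 0..quant_level-1
        return sum(embeddings[k + offset](xi[:, k]) for k in range(quant_level))
    k = quant_level                   # no summing: embed only level k
    return embeddings[k + offset](xi[:, k])

ar  = embed(torch.randint(0, 1024, (10,)), "resp", 0)    # uses embeddings[0]
nar = embed(torch.randint(0, 1024, (10, 3)), "resp", 3)  # sums embeddings[1..3]
assert ar.shape == nar.shape == (10, token_dim)
```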
@@ -258,17 +297,30 @@ class Base(nn.Module):
             n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
             self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
             self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
-        else:
+        elif self.version < 5:
             # [1024] * 8
-            self.proms_emb = AudioEmbedding(
+            self.proms_emb = AudioEmbedding_Old(
                 [n_prom_tokens] * self.n_prom_levels, d_model,
                 levels=self.n_prom_levels if self.version > 3 else None,
+                mode="prom" if self.version >= 5 else "old",
                 sums=self.config.audio_embedding_sums if self.config is not None else True,
             )
             # [1024 + STOP] + [1024] * 8
-            self.resps_emb = AudioEmbedding(
+            self.resps_emb = AudioEmbedding_Old(
                 [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
                 levels=self.n_resp_levels if self.version > 3 else None,
+                mode="resp" if self.version >= 5 else "old",
                 sums=self.config.audio_embedding_sums if self.config is not None else True
             )
+        else:
+            self.proms_emb = AudioEmbedding(
+                [n_prom_tokens] * self.n_prom_levels, d_model,
+                "prom",
+                sums=self.config.audio_embedding_sums if self.config is not None else True
+            )
+            self.resps_emb = AudioEmbedding(
+                [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
+                "resp",
+                sums=self.config.audio_embedding_sums if self.config is not None else True
+            )
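The construction above is now gated three ways. A hedged summary of the dispatch (the guard on the first MultiEmbedding branch sits above the hunk, so its condition is an assumption here):

```python
# Hypothetical summary of the dispatch above, not repo code; the guard for
# the oldest MultiEmbedding branch precedes the hunk and is assumed here.
def pick_embedding_class(version: int, monolithic_era: bool = False) -> str:
    if monolithic_era:                # branch preceding the hunk (assumed condition)
        return "MultiEmbedding"
    if version < 5:
        return "AudioEmbedding_Old"   # retains `levels` weights, mode="old"
    return "AudioEmbedding"           # new class, mode passed as "prom"/"resp"

assert pick_embedding_class(4) == "AudioEmbedding_Old"
assert pick_embedding_class(5) == "AudioEmbedding"
```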
@@ -522,38 +574,6 @@ class Base(nn.Module):
         x = inputs
         m = mask.squeeze(-1).int()
         aux_loss = None

-        """
-        # Broken
-        if state is not None and (self.arch_type == "retnet" or self.arch_type == "retnet-hf"):
-            # prefill
-            if len(state) == 0:
-                prefill_size = x.shape[1]
-                # run the initial prompt to fill the KV cache
-                if self.arch_type == "retnet":
-                    for n in range(prefill_size):
-                        xi = x[:, n, :].unsqueeze(1)
-                        self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
-                elif self.arch_type == "retnet-hf":
-                    state = None
-                    for n in range(prefill_size):
-                        xi = x[:, n, :].unsqueeze(1)
-
-                        kwargs = dict(
-                            attention_mask=m,
-                            inputs_embeds=xi,
-                            past_key_values=state,
-                            use_cache=True,
-                            forward_impl='recurrent',
-                            # return_dict=True,
-                        )
-
-                        out = self.model(**kwargs)
-                        state = out.past_key_values
-
-            # grab last token(s)
-            x = x[:, -1, :].unsqueeze(1)
-        """
-
         # HF transformer derived model
         if self.arch_type in ["llama", "mistral", "mixtral"]:
@@ -564,7 +584,7 @@ class Base(nn.Module):
                 use_cache=True,
                 # return_dict=True,
             )
-            if self.n_experts > 1 and targ_list is not None:
+            if self.n_experts > 1 and self.training:
                 kwargs["output_router_logits"] = True

             t = self.model(**kwargs)
@@ -574,7 +594,7 @@ class Base(nn.Module):
             if state is not None:
                 state = t[1]

-            if self.n_experts > 1 and targ_list is not None:
+            if self.n_experts > 1 and self.training:
                 router_logits = t[-1]
                 aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
         elif self.arch_type == "transformer":
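Both router-logit hunks swap `targ_list is not None` (a proxy for "we are training") for `self.training`, the standard nn.Module flag. A toy check of the flag's behavior, since the repo's model is not constructible here:

```python
from torch import nn

def aux_loss_enabled(n_experts: int, module: nn.Module) -> bool:
    # the condition the two hunks above now use
    return n_experts > 1 and module.training

m = nn.Linear(4, 4)        # stand-in module; modules default to train mode
assert aux_loss_enabled(8, m)
m.eval()                   # inference: router aux loss stays disabled
assert not aux_loss_enabled(8, m)
```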
@@ -622,7 +642,6 @@ class Base(nn.Module):
         text_list: list[Tensor],
         proms_list: list[Tensor],
         resps_list: list[Tensor],
-        targ_list: list[Tensor] | None = None,

         lang_list: list[Tensor] | None = None,
         tone_list: list[Tensor] | None = None,
@@ -646,8 +665,6 @@ class Base(nn.Module):
             inputs[i].append( ( "prom", proms_list[i] ) )
             if resps_list is not None:
                 inputs[i].append( ( "resp", resps_list[i] ) )
-            if targ_list is not None:
-                inputs[i].append( ( "targ", targ_list[i] ) )

         return inputs

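After this change, each batch entry built by inputs() is an ordered list of (name, tensor) pairs with no "targ" entry. A sketch of the resulting layout, with assumed shapes:

```python
import torch

text = torch.randint(0, 256, (12,))      # phoneme/text tokens
prom = torch.randint(0, 1024, (75, 8))   # acoustic prompt, all RVQ levels
resp = torch.randint(0, 1024, (50, 4))   # response trimmed to r[..., :l+1]
batch_inputs = [[("text", text), ("prom", prom), ("resp", resp)]]  # one sample
```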
@@ -669,11 +686,11 @@ class Base(nn.Module):
             elif name == "lang" and self.langs_emb is not None:
                 embedding = self.langs_emb( input )
             elif name == "prom":
-                embedding = self.proms_emb( input )
+                embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level] )
             elif name == "tone" and self.tones_emb is not None:
                 embedding = self.tones_emb( input )
             elif name == "resp":
-                embedding = self.resps_emb( input, quant_level )
+                embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level )
             else:
                 continue

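The two changed lines implement the slicing rule that replaces targ_list: for a NAR step at level l, only levels 0..l-1 are embedded, while level l is left for the loss. A toy check with assumed shapes:

```python
import torch

resp = torch.randint(0, 1024, (50, 4))  # kept levels 0..3 via r[..., :l+1]
quant_level = 3
emb_input = resp if resp.dim() == 1 or quant_level == 0 else resp[:, :quant_level]
assert emb_input.shape[-1] == quant_level  # levels 0..2 feed the embedding
assert torch.equal(resp[:, quant_level], resp[:, -1])  # level 3 is the target
```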
@@ -698,7 +715,9 @@ class Base(nn.Module):
             for name, input in batch:
                 if name == "prom":
                     target.append( torch.full_like(input[..., 0], self.ignore_index) )
-                elif name in ["text", "quant_level", "lang", "tone", "targ"]:
+                elif name == "resp":
+                    target.append( input if input.dim() == 1 else input[:, quant_level-1] )
+                elif name in ["text", "quant_level", "lang", "tone"]:
                     target.append( input )

             target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
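A sketch of the target row the loop above assembles: prom positions are filled with ignore_index so the loss skips them, and the resp target comes straight from the trimmed resp tensor. The -100 value and shapes are assumptions; the indexing mirrors the hunk, including its `quant_level-1`:

```python
import torch

ignore_index = -100  # assumed; torch cross-entropy's default ignore value
prom = torch.randint(0, 1024, (75, 8))
resp = torch.randint(0, 1024, (50, 4))
quant_level = 4      # as passed to this function in the hunk above
target = torch.cat([
    torch.full_like(prom[..., 0], ignore_index),          # prom: masked out
    resp if resp.dim() == 1 else resp[:, quant_level-1],  # resp target level
])
assert target.shape == (125,)
```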
@@ -755,10 +774,7 @@ class Base(nn.Module):
             for name, input in batch:
-                # do not use resp
                 if name == "resp":
-                    continue
-                # rename to resp
-                if name == "targ":
-                    name = "resp"
+                    input = input if input.dim() == 1 else input[:, quant_level]
                 # select prom level
                 elif name == "prom" and quant_level is not None:
                     input = input[:, quant_level]
@@ -825,13 +841,15 @@ class Base(nn.Module):
         x_list = self.inputs_to_embeddings( inputs, quant_levels )
         x, m = list_to_tensor(x_list)

+        training = self.training
+        # yes, there's a better way.
+        """
         training = False
         for batch_index, batch in enumerate(inputs):
             for name, input in batch:
                 if name == "targ":
                     training = True
+        """

         device = x.device
         batch_size = len(x_list)