removed the need to supply targ_list + different AudioEmbedding + other things
This commit is contained in:
parent
fcac9503e2
commit
ee25d2e62e
|
@ -150,20 +150,8 @@ class AR_NAR(Base):
|
||||||
quant_levels = torch.Tensor([ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
|
quant_levels = torch.Tensor([ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
|
||||||
else:
|
else:
|
||||||
quant_levels = torch.randint(0 if self.causal else 1, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
quant_levels = torch.randint(0 if self.causal else 1, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||||
"""
|
|
||||||
if cfg.model.p_ar_level == "auto" or cfg.model.p_ar_level is None:
|
|
||||||
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
|
||||||
else:
|
|
||||||
quant_levels = torch.Tensor([ 0 if random.random() < cfg.model.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ])
|
|
||||||
"""
|
|
||||||
|
|
||||||
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
|
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)] # r if l == 0 is technically correct since only r[:, 0] is passed through the embedding, but this should save some VRAM
|
||||||
resps_list = [r[..., 0] if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # r if l == 0 is technically correct since only r[:, 0] is passed through the embedding, but this should save some VRAM
|
|
||||||
|
|
||||||
"""
|
|
||||||
if cfg.experimental:
|
|
||||||
proms_list = [ r if l == 0 else trim(r, cfg.dataset.frames_per_second * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds
|
|
||||||
"""
|
|
||||||
|
|
||||||
# append stop tokens for AR
|
# append stop tokens for AR
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
|
@ -171,13 +159,11 @@ class AR_NAR(Base):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
resps_list[i] = torch.cat([resps_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
|
resps_list[i] = torch.cat([resps_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
|
||||||
targ_list[i] = torch.cat([targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
|
|
||||||
|
|
||||||
inputs = self.inputs(
|
inputs = self.inputs(
|
||||||
text_list=text_list,
|
text_list=text_list,
|
||||||
proms_list=proms_list,
|
proms_list=proms_list,
|
||||||
resps_list=resps_list,
|
resps_list=resps_list,
|
||||||
targ_list=targ_list,
|
|
||||||
lang_list=lang_list,
|
lang_list=lang_list,
|
||||||
tone_list=tone_list,
|
tone_list=tone_list,
|
||||||
|
|
||||||
|
|
|
@ -100,11 +100,12 @@ class MultiEmbedding(nn.Module):
|
||||||
return x_list
|
return x_list
|
||||||
|
|
||||||
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
|
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
|
||||||
class AudioEmbedding(nn.Module):
|
class AudioEmbedding_Old(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
|
l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
|
||||||
token_dim: int, # dimensionality of the embedding
|
token_dim: int, # dimensionality of the embedding
|
||||||
|
mode: "old", # old | prom | resp
|
||||||
levels: int | None = None, # number of RVQ-bins (I don't remember the specifics)
|
levels: int | None = None, # number of RVQ-bins (I don't remember the specifics)
|
||||||
sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
|
sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
|
||||||
):
|
):
|
||||||
|
@ -114,7 +115,9 @@ class AudioEmbedding(nn.Module):
|
||||||
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
|
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
|
||||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
|
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
|
||||||
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
|
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
|
||||||
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
|
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None and mode == "old" else None
|
||||||
|
#
|
||||||
|
self.mode = mode
|
||||||
#
|
#
|
||||||
self.sums = sums
|
self.sums = sums
|
||||||
|
|
||||||
|
@ -139,6 +142,42 @@ class AudioEmbedding(nn.Module):
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
class AudioEmbedding(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
|
||||||
|
token_dim: int, # dimensionality of the embedding
|
||||||
|
mode: str, # prom | resp
|
||||||
|
sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# array of embeddings
|
||||||
|
# proms are [0, prom_levels]
|
||||||
|
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
|
||||||
|
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
|
||||||
|
#
|
||||||
|
self.mode = mode
|
||||||
|
#
|
||||||
|
self.sums = sums
|
||||||
|
|
||||||
|
# maintaining compat is hard
|
||||||
|
def forward(self, xi: Tensor, quant_level: Tensor | None = None ) -> Tensor:
|
||||||
|
if quant_level is None:
|
||||||
|
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
|
||||||
|
|
||||||
|
# embeddings for AR/NAR cannot be shared
|
||||||
|
offset = 0 if self.mode == "prom" or quant_level == 0 else 1
|
||||||
|
|
||||||
|
if xi.dim() == 1:
|
||||||
|
x = self.embeddings[quant_level]( xi )
|
||||||
|
elif self.sums and quant_level > 0:
|
||||||
|
x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
|
||||||
|
else:
|
||||||
|
k = quant_level
|
||||||
|
x = self.embeddings[k + offset]( xi[:, k] )
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
class Base(nn.Module):
|
class Base(nn.Module):
|
||||||
@property
|
@property
|
||||||
def causal(self) -> bool:
|
def causal(self) -> bool:
|
||||||
|
@ -258,17 +297,30 @@ class Base(nn.Module):
|
||||||
n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
|
n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
|
||||||
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
||||||
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
|
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
|
||||||
else:
|
elif self.version < 5:
|
||||||
# [1024] * 8
|
# [1024] * 8
|
||||||
self.proms_emb = AudioEmbedding(
|
self.proms_emb = AudioEmbedding_Old(
|
||||||
[n_prom_tokens] * self.n_prom_levels, d_model,
|
[n_prom_tokens] * self.n_prom_levels, d_model,
|
||||||
levels=self.n_prom_levels if self.version > 3 else None,
|
levels=self.n_prom_levels if self.version > 3 else None,
|
||||||
|
mode="prom" if self.version >= 5 else "old",
|
||||||
sums=self.config.audio_embedding_sums if self.config is not None else True,
|
sums=self.config.audio_embedding_sums if self.config is not None else True,
|
||||||
)
|
)
|
||||||
# [1024 + STOP] + [1024] * 8
|
# [1024 + STOP] + [1024] * 8
|
||||||
self.resps_emb = AudioEmbedding(
|
self.resps_emb = AudioEmbedding_Old(
|
||||||
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
|
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
|
||||||
levels=self.n_resp_levels if self.version > 3 else None,
|
levels=self.n_resp_levels if self.version > 3 else None,
|
||||||
|
mode="resp" if self.version >= 5 else "old",
|
||||||
|
sums=self.config.audio_embedding_sums if self.config is not None else True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.proms_emb = AudioEmbedding(
|
||||||
|
[n_prom_tokens] * self.n_prom_levels, d_model,
|
||||||
|
"prom",
|
||||||
|
sums=self.config.audio_embedding_sums if self.config is not None else True
|
||||||
|
)
|
||||||
|
self.resps_emb = AudioEmbedding(
|
||||||
|
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
|
||||||
|
"resp",
|
||||||
sums=self.config.audio_embedding_sums if self.config is not None else True
|
sums=self.config.audio_embedding_sums if self.config is not None else True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -523,38 +575,6 @@ class Base(nn.Module):
|
||||||
m = mask.squeeze(-1).int()
|
m = mask.squeeze(-1).int()
|
||||||
aux_loss = None
|
aux_loss = None
|
||||||
|
|
||||||
"""
|
|
||||||
# Broken
|
|
||||||
if state is not None and (self.arch_type == "retnet" or self.arch_type == "retnet-hf"):
|
|
||||||
# prefill
|
|
||||||
if len(state) == 0:
|
|
||||||
prefill_size = x.shape[1]
|
|
||||||
# run the initial prompt to fill the KV cache
|
|
||||||
if self.arch_type == "retnet":
|
|
||||||
for n in range(prefill_size):
|
|
||||||
xi = x[:, n, :].unsqueeze(1)
|
|
||||||
self.model(xi, incremental_state=state, token_embeddings=xi, features_only=True)
|
|
||||||
elif self.arch_type == "retnet-hf":
|
|
||||||
state = None
|
|
||||||
for n in range(prefill_size):
|
|
||||||
xi = x[:, n, :].unsqueeze(1)
|
|
||||||
|
|
||||||
kwargs = dict(
|
|
||||||
attention_mask=m,
|
|
||||||
inputs_embeds=xi,
|
|
||||||
past_key_values=state,
|
|
||||||
use_cache=True,
|
|
||||||
forward_impl='recurrent',
|
|
||||||
# return_dict=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
out = self.model(**kwargs)
|
|
||||||
state = out.past_key_values
|
|
||||||
|
|
||||||
# grab last token(s)
|
|
||||||
x = x[:, -1, :].unsqueeze(1)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# HF transformer derived model
|
# HF transformer derived model
|
||||||
if self.arch_type in ["llama", "mistral", "mixtral"]:
|
if self.arch_type in ["llama", "mistral", "mixtral"]:
|
||||||
kwargs = dict(
|
kwargs = dict(
|
||||||
|
@ -564,7 +584,7 @@ class Base(nn.Module):
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
# return_dict=True,
|
# return_dict=True,
|
||||||
)
|
)
|
||||||
if self.n_experts > 1 and targ_list is not None:
|
if self.n_experts > 1 and self.training:
|
||||||
kwargs["output_router_logits"] = True
|
kwargs["output_router_logits"] = True
|
||||||
|
|
||||||
t = self.model(**kwargs)
|
t = self.model(**kwargs)
|
||||||
|
@ -574,7 +594,7 @@ class Base(nn.Module):
|
||||||
if state is not None:
|
if state is not None:
|
||||||
state = t[1]
|
state = t[1]
|
||||||
|
|
||||||
if self.n_experts > 1 and targ_list is not None:
|
if self.n_experts > 1 and self.training:
|
||||||
router_logits = t[-1]
|
router_logits = t[-1]
|
||||||
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
|
aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok )
|
||||||
elif self.arch_type == "transformer":
|
elif self.arch_type == "transformer":
|
||||||
|
@ -622,7 +642,6 @@ class Base(nn.Module):
|
||||||
text_list: list[Tensor],
|
text_list: list[Tensor],
|
||||||
proms_list: list[Tensor],
|
proms_list: list[Tensor],
|
||||||
resps_list: list[Tensor],
|
resps_list: list[Tensor],
|
||||||
targ_list: list[Tensor] | None = None,
|
|
||||||
|
|
||||||
lang_list: list[Tensor] | None = None,
|
lang_list: list[Tensor] | None = None,
|
||||||
tone_list: list[Tensor] | None = None,
|
tone_list: list[Tensor] | None = None,
|
||||||
|
@ -646,8 +665,6 @@ class Base(nn.Module):
|
||||||
inputs[i].append( ( "prom", proms_list[i] ) )
|
inputs[i].append( ( "prom", proms_list[i] ) )
|
||||||
if resps_list is not None:
|
if resps_list is not None:
|
||||||
inputs[i].append( ( "resp", resps_list[i] ) )
|
inputs[i].append( ( "resp", resps_list[i] ) )
|
||||||
if targ_list is not None:
|
|
||||||
inputs[i].append( ( "targ", targ_list[i] ) )
|
|
||||||
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
|
@ -669,11 +686,11 @@ class Base(nn.Module):
|
||||||
elif name == "lang" and self.langs_emb is not None:
|
elif name == "lang" and self.langs_emb is not None:
|
||||||
embedding = self.langs_emb( input )
|
embedding = self.langs_emb( input )
|
||||||
elif name == "prom":
|
elif name == "prom":
|
||||||
embedding = self.proms_emb( input )
|
embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level] )
|
||||||
elif name == "tone" and self.tones_emb is not None:
|
elif name == "tone" and self.tones_emb is not None:
|
||||||
embedding = self.tones_emb( input )
|
embedding = self.tones_emb( input )
|
||||||
elif name == "resp":
|
elif name == "resp":
|
||||||
embedding = self.resps_emb( input, quant_level )
|
embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level )
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -698,7 +715,9 @@ class Base(nn.Module):
|
||||||
for name, input in batch:
|
for name, input in batch:
|
||||||
if name == "prom":
|
if name == "prom":
|
||||||
target.append( torch.full_like(input[..., 0], self.ignore_index) )
|
target.append( torch.full_like(input[..., 0], self.ignore_index) )
|
||||||
elif name in ["text", "quant_level", "lang", "tone", "targ"]:
|
elif name == "resp":
|
||||||
|
target.append( input if input.dim() == 1 else input[:, quant_level-1] )
|
||||||
|
elif name in ["text", "quant_level", "lang", "tone"]:
|
||||||
target.append( input )
|
target.append( input )
|
||||||
|
|
||||||
target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
|
target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
|
||||||
|
@ -755,10 +774,7 @@ class Base(nn.Module):
|
||||||
for name, input in batch:
|
for name, input in batch:
|
||||||
# do not use resp
|
# do not use resp
|
||||||
if name == "resp":
|
if name == "resp":
|
||||||
continue
|
input = input if input.dim() == 1 else input[:, quant_level]
|
||||||
# rename to resp
|
|
||||||
if name == "targ":
|
|
||||||
name = "resp"
|
|
||||||
# select prom level
|
# select prom level
|
||||||
elif name == "prom" and quant_level is not None:
|
elif name == "prom" and quant_level is not None:
|
||||||
input = input[:, quant_level]
|
input = input[:, quant_level]
|
||||||
|
@ -825,13 +841,15 @@ class Base(nn.Module):
|
||||||
x_list = self.inputs_to_embeddings( inputs, quant_levels )
|
x_list = self.inputs_to_embeddings( inputs, quant_levels )
|
||||||
x, m = list_to_tensor(x_list)
|
x, m = list_to_tensor(x_list)
|
||||||
|
|
||||||
|
training = self.training
|
||||||
# yes, there's a better way.
|
# yes, there's a better way.
|
||||||
|
"""
|
||||||
training = False
|
training = False
|
||||||
for batch_index, batch in enumerate(inputs):
|
for batch_index, batch in enumerate(inputs):
|
||||||
for name, input in batch:
|
for name, input in batch:
|
||||||
if name == "targ":
|
if name == "targ":
|
||||||
training = True
|
training = True
|
||||||
|
"""
|
||||||
|
|
||||||
device = x.device
|
device = x.device
|
||||||
batch_size = len(x_list)
|
batch_size = len(x_list)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user