m
This commit is contained in:
parent
ee25d2e62e
commit
516b0894d7
|
@ -105,9 +105,7 @@ class AudioEmbedding_Old(nn.Module):
|
||||||
self,
|
self,
|
||||||
l_tokens: int, # list of token counts, one per embedding (needed because the AR resps include a stop token)
|
l_tokens: int, # list of token counts, one per embedding (needed because the AR resps include a stop token)
|
||||||
token_dim: int, # dimensionality of the embedding
|
token_dim: int, # dimensionality of the embedding
|
||||||
mode: "old", # old | prom | resp
|
|
||||||
levels: int | None = None, # number of RVQ-bins (I don't remember the specifics)
|
levels: int | None = None, # number of RVQ-bins (I don't remember the specifics)
|
||||||
sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# array of embeddings
|
# array of embeddings
|
||||||
|
@ -115,30 +113,18 @@ class AudioEmbedding_Old(nn.Module):
|
||||||
# resps are split so that [0] is for the AR, and [1:] are reserved for the NAR
|
# resps are split so that [0] is for the AR, and [1:] are reserved for the NAR
|
||||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
|
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
|
||||||
# per-level scaling weight (honestly this should be pretty useless, because the weights in the embeddings themselves should already factor this in)
|
# per-level scaling weight (honestly this should be pretty useless, because the weights in the embeddings themselves should already factor this in)
|
||||||
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None and mode == "old" else None
|
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
|
||||||
#
|
|
||||||
self.mode = mode
|
|
||||||
#
|
|
||||||
self.sums = sums
|
|
||||||
|
|
||||||
def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
|
def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
|
||||||
# prom
|
# prom
|
||||||
if quant_levels is None and xi.shape[-1] > 1:
|
if quant_levels is None and xi.shape[-1] > 1:
|
||||||
if self.sums:
|
|
||||||
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
||||||
else:
|
|
||||||
k = 0 # only use the most significant RVQ bin level for the input prom
|
|
||||||
x = self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1)
|
|
||||||
# AR resp
|
# AR resp
|
||||||
elif quant_levels is None or quant_levels == 0:
|
elif quant_levels is None or quant_levels == 0:
|
||||||
x = self.embeddings[0]( xi if len(xi.shape) == 1 else xi[:, 0] )
|
x = self.embeddings[0]( xi if len(xi.shape) == 1 else xi[:, 0] )
|
||||||
# NAR resp
|
# NAR resp
|
||||||
else:
|
else:
|
||||||
if self.sums:
|
|
||||||
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
x = sum( [ self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
||||||
else:
|
|
||||||
k = xi.shape[-1] - 1 # only use the previous RVQ bin level for the current resp embedding
|
|
||||||
x = self.embeddings[k+1]( xi[:, k] ) * (self.weight[k+1] if self.weight is not None else 1)
|
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -165,16 +151,16 @@ class AudioEmbedding(nn.Module):
|
||||||
if quant_level is None:
|
if quant_level is None:
|
||||||
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
|
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
|
||||||
|
|
||||||
# embeddings for AR/NAR cannot be shared
|
# jank but needed
|
||||||
offset = 0 if self.mode == "prom" or quant_level == 0 else 1
|
|
||||||
|
|
||||||
if xi.dim() == 1:
|
if xi.dim() == 1:
|
||||||
x = self.embeddings[quant_level]( xi )
|
return self.embeddings[quant_level]( xi )
|
||||||
elif self.sums and quant_level > 0:
|
|
||||||
|
offset = 0 if self.mode == "prom" else 1
|
||||||
|
if self.sums and quant_level > 0:
|
||||||
x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
|
x = sum( [ self.embeddings[k + offset]( xi[:, k] ) for k in range( quant_level ) ] )
|
||||||
else:
|
else:
|
||||||
k = quant_level
|
k = quant_level - 1
|
||||||
x = self.embeddings[k + offset]( xi[:, k] )
|
x = self.embeddings[k + offset]( xi if xi.dim() == 1 else xi[:, k] )
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -302,15 +288,11 @@ class Base(nn.Module):
|
||||||
self.proms_emb = AudioEmbedding_Old(
|
self.proms_emb = AudioEmbedding_Old(
|
||||||
[n_prom_tokens] * self.n_prom_levels, d_model,
|
[n_prom_tokens] * self.n_prom_levels, d_model,
|
||||||
levels=self.n_prom_levels if self.version > 3 else None,
|
levels=self.n_prom_levels if self.version > 3 else None,
|
||||||
mode="prom" if self.version >= 5 else "old",
|
|
||||||
sums=self.config.audio_embedding_sums if self.config is not None else True,
|
|
||||||
)
|
)
|
||||||
# [1024 + STOP] + [1024] * 8
|
# [1024 + STOP] + [1024] * 8
|
||||||
self.resps_emb = AudioEmbedding_Old(
|
self.resps_emb = AudioEmbedding_Old(
|
||||||
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
|
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
|
||||||
levels=self.n_resp_levels if self.version > 3 else None,
|
levels=self.n_resp_levels if self.version > 3 else None,
|
||||||
mode="resp" if self.version >= 5 else "old",
|
|
||||||
sums=self.config.audio_embedding_sums if self.config is not None else True
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.proms_emb = AudioEmbedding(
|
self.proms_emb = AudioEmbedding(
|
||||||
|
@ -686,7 +668,7 @@ class Base(nn.Module):
|
||||||
elif name == "lang" and self.langs_emb is not None:
|
elif name == "lang" and self.langs_emb is not None:
|
||||||
embedding = self.langs_emb( input )
|
embedding = self.langs_emb( input )
|
||||||
elif name == "prom":
|
elif name == "prom":
|
||||||
embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level] )
|
embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level if self.version >= 5 else None )
|
||||||
elif name == "tone" and self.tones_emb is not None:
|
elif name == "tone" and self.tones_emb is not None:
|
||||||
embedding = self.tones_emb( input )
|
embedding = self.tones_emb( input )
|
||||||
elif name == "resp":
|
elif name == "resp":
|
||||||
|
|
Loading…
Reference in New Issue
Block a user