I forgot about the time embedding...
parent 811b15d280
commit 5a09a5f6e9
@@ -281,6 +281,27 @@ class AudioEmbedding(nn.Module):
 		return x
 
+# time-step embedding
+# for the NAR-len, since it most likely requires encoding the timestep
+class TimeEmbedding(nn.Module):
+	def __init__(
+		self,
+		d_model
+	):
+		super().__init__()
+		self.emb = SinusoidalEmbedding(d_model)
+		self.mlp = nn.Sequential(
+			nn.Linear(d_model, d_model*4),
+			nn.SiLU(),
+			nn.Linear(d_model*4, d_model),
+		)
+
+	def forward( self, t ):
+		t = self.emb(t)
+		t = self.mlp(t)
+
+		return t
+
 # per-level classification
 # it might actually be "better" in the long run to only have one output head like a traditional LM, and just de-stitch it here instead of doing modulus math and whatever like the HF/experimental impl
 class Classifiers(nn.Module):
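SinusoidalEmbedding itself isn't part of this diff. A minimal sketch of what TimeEmbedding likely assumes, using the standard log-spaced sin/cos timestep encoding from transformer positional embeddings and diffusion models (the frequency base and shapes here are assumptions, not this repo's implementation):

import math
import torch
from torch import nn

class SinusoidalEmbedding(nn.Module):
	# assumed: maps a scalar timestep in [0, 1] to a d_model-dim sin/cos feature vector
	def __init__(self, d_model, base=10000):
		super().__init__()
		half = d_model // 2  # d_model assumed even
		# log-spaced frequencies, as in the original transformer positional encoding
		freqs = torch.exp(-math.log(base) * torch.arange(half) / half)
		self.register_buffer("freqs", freqs)

	def forward(self, t):
		t = torch.as_tensor(t, dtype=torch.float32).reshape(-1)  # [batch]
		x = t[:, None] * self.freqs[None, :]                     # [batch, d_model//2]
		return torch.cat([x.sin(), x.cos()], dim=-1)             # [batch, d_model]

With that in place, TimeEmbedding(d_model)(torch.tensor(0.3)) yields a [1, d_model] conditioning vector: the sinusoidal features give a smooth, unique code for each timestep, and the Linear-SiLU-Linear stack projects it into the model's embedding space.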
@@ -545,6 +566,7 @@ class Base(nn.Module):
 
 		# experimental NAR-only mode
 		self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
+		self.time_emb = TimeEmbedding(d_model) if "len" in self.capabilities else None
 
 		if attention_backend == "auto":
 			attention_backend = "sdpa"
@@ -967,6 +989,7 @@ class Base(nn.Module):
 		tone_list: list[Tensor] | None = None,
 		len_list: list[Tensor] | None = None,
 		task_list: list[str] | None = None,
+		time_list: list[Tensor] | None = None,
 
 		quant_levels: int | list[int] | Tensor | None = None
 	):
@@ -1012,6 +1035,8 @@ class Base(nn.Module):
 				t = random.random()
 				p = math.cos(t * math.pi * 0.5)
 				dropout_mask = _dropout_mask( resps_list[i], p=p )
 
+				inputs[i].append( ("timestep", torch.tensor(t, device=device) ) )
 				inputs[i].append( ("dropout_mask", dropout_mask ) )
+
 			# Audio length prediction task
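This is the MaskGIT-style cosine masking schedule: with t drawn uniformly from [0, 1), p = cos(t·π/2) is the per-token masking probability, so t near 0 masks almost the whole response and t near 1 masks almost nothing. _dropout_mask isn't shown in this diff; a plausible sketch of the pattern under that assumption (the i.i.d. Bernoulli draw and mask polarity are guesses):

import math
import random
import torch

def _dropout_mask(resps, p=0.8):
	# hypothetical stand-in: True marks response positions to mask out, drawn with probability p
	return torch.rand(resps.shape[0], device=resps.device) < p

t = random.random()               # timestep in [0, 1)
p = math.cos(t * math.pi * 0.5)   # t=0 -> mask ~everything, t=1 -> mask ~nothing
mask = _dropout_mask(torch.zeros(10, 8), p=p)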
@@ -1108,16 +1133,14 @@ class Base(nn.Module):
 			task_type = "tts"
 			input_prom = None
 			dropout_mask = None
+			timestep = None
 
 			# pre-iterate
 			for name, input in batch_input:
-				"""
-				if name == "prop":
-					proms = [ input ] if isinstance(input, torch.Tensor) else input
-					input_prom = torch.cat([ prom for prom in proms if isinstance(prom, torch.Tensor) ])
-				"""
 				if name == "dropout_mask":
 					dropout_mask = input
+				elif name == "timestep":
+					timestep = input
 
 			for name, input in batch_input:
 				# technically can provide a map for input_name => embedding, but some embeddings require additional processing
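Each batch entry's inputs are a flat list of (name, value) tuples, so values like the timestep have to be fished out in a first pass before the main embedding loop can use them. A toy illustration of that structure (the tuple names match the diff; the values are invented):

import torch

batch_input = [
	("timestep", torch.tensor(0.42)),
	("dropout_mask", torch.tensor([True, False, True])),
]

dropout_mask, timestep = None, None
for name, input in batch_input:
	if name == "dropout_mask":
		dropout_mask = input
	elif name == "timestep":
		timestep = input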
@@ -1160,6 +1183,9 @@ class Base(nn.Module):
 					offset = 0,
 					quant_level = 0,
 				)
+
+				t_emb = self.time_emb( timestep )
+				embedding += t_emb
 			# cheat-y way to handle performing STT across all levels
 			elif task_type in summed_embeddings_task:
 				# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
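embedding += t_emb conditions every position on the timestep via broadcasting: with embedding shaped [seq_len, d_model] and t_emb shaped [1, d_model], the same time vector is added at every token. A toy shape check (dimensions invented for illustration):

import torch

seq_len, d_model = 6, 8
embedding = torch.randn(seq_len, d_model)  # per-token embeddings
t_emb = torch.randn(1, d_model)            # one timestep vector for the whole sequence
embedding = embedding + t_emb              # broadcasts across seq_len
assert embedding.shape == (seq_len, d_model)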
@@ -1236,7 +1262,7 @@ class Base(nn.Module):
 				return 1
 
 			# a mask
-			if name == "dropout_mask":
+			if name in ["dropout_mask", "timestep"]:
 				return 0
 
 			# list of tokens
@@ -298,6 +298,7 @@ class NAR(Base):
 				resps_list=resps_list,
 				lang_list=lang_list,
 				tone_list=tone_list,
+				time_list=[ timestep ],
 				quant_levels=quant_levels,
 			)
 			output = _super.forward(
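Passing time_list=[ timestep ] on every call points at the usual iterative demasking sampler: start fully masked, decode, keep the most confident tokens, and re-mask the rest on the same cosine schedule while conditioning each pass on the current timestep. A rough sketch of that loop with invented names (model, mask_token, steps); the actual sampler isn't part of this diff:

import math
import torch

seq_len, vocab, mask_token, steps = 100, 1024, 1024, 25
model = lambda resps, t: torch.randn(seq_len, vocab)  # stand-in: returns logits per position

resps = torch.full((seq_len,), mask_token)            # start fully masked
for i in range(steps):
	timestep = torch.tensor(i / steps)                # t walks from 0 toward 1
	probs = model(resps, timestep).softmax(dim=-1)
	confidence, candidate = probs.max(dim=-1)
	resps = candidate
	# cosine schedule: re-mask the least confident tokens, fewer with each step
	n_mask = int(math.cos((i + 1) / steps * math.pi * 0.5) * seq_len)
	if n_mask:
		resps[confidence.topk(n_mask, largest=False).indices] = mask_token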