I forgot about the time embedding...

This commit is contained in:
mrq 2024-11-08 22:46:26 -06:00
parent 811b15d280
commit 5a09a5f6e9
2 changed files with 33 additions and 6 deletions

View File

@ -281,6 +281,27 @@ class AudioEmbedding(nn.Module):
return x
# time-step embedding
# for the NAR-len, since it probably most likely requires encoding the timestep
class TimeEmbedding(nn.Module):
def __init__(
self,
d_model
):
super().__init__()
self.emb = SinusoidalEmbedding(d_model)
self.mlp = nn.Sequential(
nn.Linear(d_model, d_model*4),
nn.SiLU(),
nn.Linear(d_model*4, d_model),
)
def forward( self, t ):
t = self.emb(t)
t = self.mlp(t)
return t
# per-level classification
# it might actually be "better" in the long run to only have one output head like a traditional LM, and just de-stitch it here instead of doing modulus math and whatever like the HF/experimental impl
class Classifiers(nn.Module):
@ -545,6 +566,7 @@ class Base(nn.Module):
# experimental NAR-only mode
self.len_emb = Embedding(11, d_model) if "len" in self.capabilities else None
self.time_emb = TimeEmbedding(d_model) if "len" in self.capabilities else None
if attention_backend == "auto":
attention_backend = "sdpa"
@ -967,6 +989,7 @@ class Base(nn.Module):
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
task_list: list[str] | None = None,
time_list: list[Tensor] | None = None,
quant_levels: int | list[int] | Tensor | None = None
):
@ -1012,6 +1035,8 @@ class Base(nn.Module):
t = random.random()
p = math.cos(t * math.pi * 0.5)
dropout_mask = _dropout_mask( resps_list[i], p=p )
inputs[i].append( ("timestep", torch.tensor(t, device=device) ) )
inputs[i].append( ("dropout_mask", dropout_mask ) )
# Audio length prediction task
@ -1108,16 +1133,14 @@ class Base(nn.Module):
task_type = "tts"
input_prom = None
dropout_mask = None
timestep = None
# pre-iterate
for name, input in batch_input:
"""
if name == "prop":
proms = [ input ] if isinstance(input, torch.Tensor) else input
input_prom = torch.cat([ prom for prom in proms if isinstance(prom, torch.Tensor) ])
"""
if name == "dropout_mask":
dropout_mask = input
elif name == "timestep":
timestep = input
for name, input in batch_input:
# technically can provide a map for input_name => embedding, but some embedding requires additional processing
@ -1160,6 +1183,9 @@ class Base(nn.Module):
offset = 0,
quant_level = 0,
)
t_emb = self.time_emb( timestep )
embedding += t_emb
# cheat-y way to handle performing STT across all levels
elif task_type in summed_embeddings_task:
# we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
@ -1236,7 +1262,7 @@ class Base(nn.Module):
return 1
# a mask
if name == "dropout_mask":
if name in ["dropout_mask", "timestep"]:
return 0
# list of tokens

View File

@ -298,6 +298,7 @@ class NAR(Base):
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
time_list=[ timestep ],
quant_levels=quant_levels,
)
output = _super.forward(