It seems that my PromEmbedding/RespEmbedding don't actually work all that well; naively using dedicated MultiEmbeddings for the AR/NAR in the monolithic model is the best way to go.
parent 67617d7d69
commit b2907ae7e0
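
In short: instead of the per-level PromEmbedding/RespEmbedding modules, the monolithic model gets a MultiEmbedding carrying two full weight tables, and whether quant_levels is passed decides which table (AR vs. NAR) embeds the input. A minimal, self-contained sketch of that idea; the class name, the one-hot einsum, and the shapes are illustrative assumptions, not the exact code in the diff below:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MonolithicMultiEmbedding(nn.Module):
	# sketch: two full sets of summed-over-levels embeddings, one per task
	def __init__(self, max_n_levels, n_tokens, token_dim):
		super().__init__()
		# weights[0] serves the AR (quant_levels is None), weights[1] the NAR
		self.weights = nn.Parameter(torch.randn(2, max_n_levels, n_tokens, token_dim))

	def forward(self, x, quant_levels=None):
		# x: [seq, levels] integer RVQ codes
		w = self.weights[0 if quant_levels is None else 1][: x.shape[1]]
		# one-hot the codes, then sum the per-level embeddings for each frame
		x_onehot = F.one_hot(x, num_classes=w.shape[1]).to(w.dtype)
		return torch.einsum("lkd,slk->sd", w, x_onehot)
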
@@ -42,7 +42,17 @@ class TTS():
 
 		models = get_models(cfg.models.get())
 		for name, model in models.items():
-			if name.startswith("ar"):
+			if name.startswith("ar+nar"):
+				self.ar = model
+				state = torch.load(self.ar_ckpt)
+				if "symmap" in state:
+					self.symmap = state['symmap']
+				if "module" in state:
+					state = state['module']
+				self.ar.load_state_dict(state)
+				self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
+				self.nar = self.ar
+			elif name.startswith("ar"):
 				self.ar = model
 				state = torch.load(self.ar_ckpt)
 				if "symmap" in state:
@@ -74,7 +84,10 @@ class TTS():
 	def load_models( self ):
 		engines = load_engines()
 		for name, engine in engines.items():
-			if name[:2] == "ar":
+			if name[:6] == "ar+nar":
+				self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
+				self.nar = self.ar
+			elif name[:2] == "ar":
 				self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
 			elif name[:3] == "nar":
 				self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
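
The `ar+nar` branches above just alias one engine to both roles. A hypothetical helper mirroring that load path (the function and argument names are placeholders; the "module" key is the wrapped-checkpoint nesting the diff already handles):

import torch

def load_monolithic(tts, model, ckpt_path):
	state = torch.load(ckpt_path)
	if "module" in state:		# unwrap a trainer-exported checkpoint
		state = state["module"]
	model.load_state_dict(state)
	tts.ar = tts.nar = model	# one module serves both roles
	return tts
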
@@ -82,7 +82,7 @@ class AR_NAR(Base):
 
 		# is training
 		if n_levels == self.n_resp_levels:
-			if random.random() < 0.5:
+			if random.random() < 0.25:
 				quant_levels = None
 
 				targ_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
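
This hunk only retunes the task mix for monolithic training: an AR step (no quant_levels, targets are the first RVQ level) is now drawn with probability 0.25 instead of 0.5, biasing training toward the NAR levels. A hedged sketch of how that sampling plausibly fits together; the NAR branch is an assumption, since the hunk only shows the AR side:

import random
import torch

def sample_task(resps_list, n_resp_levels):
	if random.random() < 0.25:
		# AR step: no quant levels, predict the first RVQ level autoregressively
		quant_levels = None
		targ_list = [r[..., 0] for r in resps_list]
	else:
		# NAR step (assumed): pick a random non-zero level per sample as the target
		quant_levels = torch.randint(1, n_resp_levels, (len(resps_list),))
		targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)]
	return quant_levels, targ_list
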
@@ -62,11 +62,15 @@ class MultiEmbedding(nn.Embedding):
 	This embedding sums embeddings on different levels.
 	"""
 
-	def __init__(self, max_n_levels, n_tokens, token_dim):
+	def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
 		super().__init__(max_n_levels, token_dim)
 		self.max_n_levels = max_n_levels
 		self.n_tokens = n_tokens
-		self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
+		self.monolithic = monolithic
+		if self.monolithic:
+			self.weights = nn.Parameter(torch.randn(2, max_n_levels, n_tokens, token_dim))
+		else:
+			self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
 
 	# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
 	# I imagine this is an oversight in the NAR.
@@ -74,7 +78,10 @@ class MultiEmbedding(nn.Embedding):
 
 		if len(x_list) == 0:
 			return []
 
-		w = self.weight
+		if self.monolithic:
+			w = self.weights[0 if quant_levels is None else 1]
+		else:
+			w = self.weight
 
 		padded_x_list = []
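
That dispatch is the whole trick: passing quant_levels flips which table embeds the response codes. A toy demonstration against the sketch class from the top of this page:

emb = MonolithicMultiEmbedding(max_n_levels=8, n_tokens=1024, token_dim=512)
frame = torch.randint(0, 1024, (10, 1))
h_ar  = emb(frame)                              # quant_levels is None -> weights[0]
h_nar = emb(frame, quant_levels=torch.zeros(1)) # quant_levels given   -> weights[1]
print(torch.allclose(h_ar, h_nar))              # False: the two tables are independent
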
@@ -91,6 +98,7 @@ class MultiEmbedding(nn.Embedding):
 
 		return x_list
 
+"""
 # Embedding that sums each RVQ-bin level within a given input acoustic prompt
 class PromEmbedding(nn.Module):
 	def __init__(self, n_levels, n_tokens, token_dim):
@@ -110,6 +118,7 @@ class RespEmbedding(nn.Module):
 
 	def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
 		return [ self.embeddings[min(self.n_levels, quant_levels[i]) if quant_levels is not None else 0](xi)[:, 0, :] for i, xi in enumerate(x_list) ]
+"""
 
 class Base(nn.Module):
 	@property
@@ -154,7 +163,7 @@ class Base(nn.Module):
 
 	@property
 	def n_embeddings(self) -> int:
-		return self.n_resp_levels if self.monolithic else 1
+		return 2 if self.monolithic else 1
 
 	@property
 	def stop_token(self):
@@ -183,15 +192,11 @@ class Base(nn.Module):
 		p_dropout: float = 0.1,
 
 		config = None,
-		use_multiembedding = False,
 	):
 		super().__init__()
 		self.config = config
 		self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
 
-		if self.config is not None and hasattr(self.config, "use_multiembedding"):
-			use_multiembedding = self.config.use_multiembedding
-
 		self.n_tokens = n_tokens
 		self.d_model = d_model
 		self.n_heads = n_heads
@@ -203,19 +208,8 @@ class Base(nn.Module):
 
 		self.text_emb = Embedding(n_tokens, d_model)
 
-		# use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested
-		# n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt
-		if self.n_embeddings == self.n_prom_levels or not use_multiembedding:
-			self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
-		else:
-			self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
-
-		# use dedicated embeddings for each RVQ-bin level in the output response / target if requested
-		# n_embeddings > 1 because the using the MultiEmbedding "works" fine enough for split AR/NARs.
-		if self.n_embeddings > 1 or not use_multiembedding:
-			self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model)
-		else:
-			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
+		self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model) #, monolithic=self.monolithic)
+		self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
 
 		self.sep = nn.Parameter(torch.randn(d_model))
 
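
One consequence of threading monolithic= into resps_emb only: the response embedding doubles its parameters while the prompt embedding stays shared, since the acoustic prompt is consumed the same way by both tasks. Rough, purely illustrative accounting (the level/token/width numbers below are assumptions, not the repo's config):

n_resp_levels, n_resp_tokens, d_model = 8, 1025, 1024
split_params      = n_resp_levels * n_resp_tokens * d_model  # single table
monolithic_params = 2 * split_params                         # separate AR + NAR tables
print(f"{split_params:,} vs {monolithic_params:,} embedding parameters")
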
@@ -52,10 +52,13 @@ def run_eval(engines, disabled_engines, eval_name, dl):
 
 	AR = None
 	NAR = None
+	AR_NAR = None
 
 	names = []
 	for name, engine in engines.items():
-		if name[:2] == "ar":
+		if name[:6] == "ar+nar":
+			AR_NAR = engine
+		elif name[:2] == "ar":
 			AR = engine
 		elif name[:3] == "nar":
 			NAR = engine
@@ -127,7 +130,11 @@ def run_eval(engines, disabled_engines, eval_name, dl):
 		for name in engines:
 			model = engines[name]
 
-			if name.startswith("ar"):
+			if name.startswith("ar+nar"):
+				resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
+				resps_list = [ r.unsqueeze(-1) for r in resps_list ]
+				resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
+			elif name.startswith("ar"):
 				resps_list = model(
 					text_list=batch["text"],
 					proms_list=batch["proms"],
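
The eval branch above spells out the monolithic inference contract: call the same engine twice, first as the AR (no resps_list, decode up to max_steps), then as the NAR, seeded with the AR output widened to a single RVQ level. Condensed, with the shape comments as assumptions:

# pass 1, AR: decodes the first RVQ level; returns a list of [seq] code tensors
resps_list = AR_NAR(text_list=text_list, proms_list=proms_list,
                    max_steps=max_steps, sampling_temperature=ar_temperature)
# widen each to [seq, 1] so the NAR sees them as "levels filled in so far"
resps_list = [r.unsqueeze(-1) for r in resps_list]
# pass 2, NAR: fills the remaining RVQ levels, conditioned on the AR's level 0
resps_list = AR_NAR(text_list=text_list, proms_list=proms_list,
                    resps_list=resps_list, sampling_temperature=nar_temperature)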