seems that my PromEmbedding/RespEmbedding don't actually work all that well; naively using dedicated MultiEmbeddings for AR/NAR in the monolithic model is the best way to go

This commit is contained in:
mrq 2023-09-08 01:03:24 -05:00
parent 67617d7d69
commit b2907ae7e0
4 changed files with 40 additions and 26 deletions

View File

@ -42,7 +42,17 @@ class TTS():
models = get_models(cfg.models.get())
for name, model in models.items():
if name.startswith("ar"):
if name.startswith("ar+nar"):
self.ar = model
state = torch.load(self.ar_ckpt)
if "symmap" in state:
self.symmap = state['symmap']
if "module" in state:
state = state['module']
self.ar.load_state_dict(state)
self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
self.nar = self.ar
elif name.startswith("ar"):
self.ar = model
state = torch.load(self.ar_ckpt)
if "symmap" in state:
@ -74,7 +84,10 @@ class TTS():
def load_models( self ):
engines = load_engines()
for name, engine in engines.items():
if name[:2] == "ar":
if name[:6] == "ar+nar":
self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
self.nar = self.ar
elif name[:2] == "ar":
self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
elif name[:3] == "nar":
self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)

View File

@ -82,7 +82,7 @@ class AR_NAR(Base):
# is training
if n_levels == self.n_resp_levels:
if random.random() < 0.5:
if random.random() < 0.25:
quant_levels = None
targ_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels

View File

@ -62,11 +62,15 @@ class MultiEmbedding(nn.Embedding):
This embedding sums embeddings on different levels.
"""
def __init__(self, max_n_levels, n_tokens, token_dim):
def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
super().__init__(max_n_levels, token_dim)
self.max_n_levels = max_n_levels
self.n_tokens = n_tokens
self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
self.monolithic = monolithic
if self.monolithic:
self.weights = nn.Parameter(torch.randn(2, max_n_levels, n_tokens, token_dim))
else:
self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
# I imagine this is an oversight in the NAR.
@ -74,7 +78,10 @@ class MultiEmbedding(nn.Embedding):
if len(x_list) == 0:
return []
w = self.weight
if self.monolithic:
w = self.weights[0 if quant_levels is None else 1]
else:
w = self.weight
padded_x_list = []
@ -91,6 +98,7 @@ class MultiEmbedding(nn.Embedding):
return x_list
"""
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
class PromEmbedding(nn.Module):
def __init__(self, n_levels, n_tokens, token_dim):
@ -110,6 +118,7 @@ class RespEmbedding(nn.Module):
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
return [ self.embeddings[min(self.n_levels, quant_levels[i]) if quant_levels is not None else 0](xi)[:, 0, :] for i, xi in enumerate(x_list) ]
"""
class Base(nn.Module):
@property
@ -154,7 +163,7 @@ class Base(nn.Module):
@property
def n_embeddings(self) -> int:
return self.n_resp_levels if self.monolithic else 1
return 2 if self.monolithic else 1
@property
def stop_token(self):
@ -183,15 +192,11 @@ class Base(nn.Module):
p_dropout: float = 0.1,
config = None,
use_multiembedding = False,
):
super().__init__()
self.config = config
self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
if self.config is not None and hasattr(self.config, "use_multiembedding"):
use_multiembedding = self.config.use_multiembedding
self.n_tokens = n_tokens
self.d_model = d_model
self.n_heads = n_heads
@ -203,19 +208,8 @@ class Base(nn.Module):
self.text_emb = Embedding(n_tokens, d_model)
# use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested
# n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt
if self.n_embeddings == self.n_prom_levels or not use_multiembedding:
self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
else:
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
# use dedicated embeddings for each RVQ-bin level in the output response / target if requested
# n_embeddings > 1 because using the MultiEmbedding "works" fine enough for split AR/NARs.
if self.n_embeddings > 1 or not use_multiembedding:
self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model)
else:
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model) #, monolithic=self.monolithic)
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
self.sep = nn.Parameter(torch.randn(d_model))

View File

@ -52,10 +52,13 @@ def run_eval(engines, disabled_engines, eval_name, dl):
AR = None
NAR = None
AR_NAR = None
names = []
for name, engine in engines.items():
if name[:2] == "ar":
if name[:6] == "ar+nar":
AR_NAR = engine
elif name[:2] == "ar":
AR = engine
elif name[:3] == "nar":
NAR = engine
@ -127,7 +130,11 @@ def run_eval(engines, disabled_engines, eval_name, dl):
for name in engines:
model = engines[name]
if name.startswith("ar"):
if name.startswith("ar+nar"):
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
resps_list = [ r.unsqueeze(-1) for r in resps_list ]
resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
elif name.startswith("ar"):
resps_list = model(
text_list=batch["text"],
proms_list=batch["proms"],