seems that my PromEmbedding/RespEmbedding doesn't actually work all that well; naively using dedicated MultiEmbeddings for AR/NAR in the monolithic model is the best way to go
parent 67617d7d69
commit b2907ae7e0
@@ -42,7 +42,17 @@ class TTS():
         models = get_models(cfg.models.get())
         for name, model in models.items():
-            if name.startswith("ar"):
+            if name.startswith("ar+nar"):
+                self.ar = model
+                state = torch.load(self.ar_ckpt)
+                if "symmap" in state:
+                    self.symmap = state['symmap']
+                if "module" in state:
+                    state = state['module']
+                self.ar.load_state_dict(state)
+                self.ar = self.ar.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
+                self.nar = self.ar
+            elif name.startswith("ar"):
                 self.ar = model
                 state = torch.load(self.ar_ckpt)
                 if "symmap" in state:
@@ -74,7 +84,10 @@ class TTS():
     def load_models( self ):
         engines = load_engines()
         for name, engine in engines.items():
-            if name[:2] == "ar":
+            if name[:6] == "ar+nar":
+                self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
+                self.nar = self.ar
+            elif name[:2] == "ar":
                 self.ar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
             elif name[:3] == "nar":
                 self.nar = engine.module.to(self.device, dtype=cfg.inference.dtype if not cfg.inference.amp else torch.float32)
@@ -82,7 +82,7 @@ class AR_NAR(Base):
         # is training
         if n_levels == self.n_resp_levels:
-            if random.random() < 0.5:
+            if random.random() < 0.25:
                 quant_levels = None

                 targ_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
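For context on the threshold change above: in the monolithic AR_NAR, each training step rolls a die to decide whether it trains the AR target (the first RVQ level, signalled by quant_levels = None) or one of the NAR levels, so lowering 0.5 to 0.25 simply spends more steps on the NAR levels. A minimal sketch of that selection, assuming NAR levels are drawn uniformly; the helper name and the level-sampling details are illustrative, not the repo's exact code:

import random
import torch

def pick_training_targets(resps_list, n_resp_levels, p_ar=0.25):
    """Sketch: decide whether this step trains the AR (level 0) or a NAR level.

    resps_list: list of [T, n_resp_levels] integer code tensors, one per utterance.
    Returns (quant_levels, targ_list); quant_levels is None for an AR step,
    otherwise a per-utterance tensor of NAR levels in [1, n_resp_levels - 1].
    """
    if random.random() < p_ar:
        # AR step: only the first RVQ level is the target, no quant_levels given
        return None, [r[..., 0] for r in resps_list]
    # NAR step: pick a random non-zero level per utterance (assumed uniform here)
    quant_levels = torch.randint(1, n_resp_levels, (len(resps_list),))
    targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)]
    return quant_levels, targ_list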
@@ -62,11 +62,15 @@ class MultiEmbedding(nn.Embedding):
     This embedding sums embeddings on different levels.
     """

-    def __init__(self, max_n_levels, n_tokens, token_dim):
+    def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
         super().__init__(max_n_levels, token_dim)
         self.max_n_levels = max_n_levels
         self.n_tokens = n_tokens
-        self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
+        self.monolithic = monolithic
+        if self.monolithic:
+            self.weights = nn.Parameter(torch.randn(2, max_n_levels, n_tokens, token_dim))
+        else:
+            self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))

     # to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
     # I imagine this is an oversight in the NAR.
@@ -74,7 +78,10 @@ class MultiEmbedding(nn.Embedding):
         if len(x_list) == 0:
             return []

-        w = self.weight
+        if self.monolithic:
+            w = self.weights[0 if quant_levels is None else 1]
+        else:
+            w = self.weight

         padded_x_list = []
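Putting the two MultiEmbedding hunks above together: in monolithic mode the module keeps two stacked weight banks and indexes bank 0 when forward is called without quant_levels (the AR pass) and bank 1 when levels are supplied (the NAR pass), then sums the per-level embeddings as before. A self-contained sketch of that idea; the class name and the one-hot/einsum lookup are assumptions, and it skips the padding/stacking the real forward does:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class MonolithicMultiEmbedding(nn.Module):
    """Sketch: a level-summing embedding with separate AR/NAR weight banks."""
    def __init__(self, max_n_levels: int, n_tokens: int, token_dim: int, monolithic: bool = False):
        super().__init__()
        self.max_n_levels = max_n_levels
        self.n_tokens = n_tokens
        self.monolithic = monolithic
        if monolithic:
            # bank 0 serves the AR pass, bank 1 serves the NAR pass
            self.weights = nn.Parameter(torch.randn(2, max_n_levels, n_tokens, token_dim))
        else:
            self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))

    def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
        if len(x_list) == 0:
            return []
        # absent quant_levels => AR pass => bank 0; otherwise NAR pass => bank 1
        w = (self.weights[0 if quant_levels is None else 1]
             if self.monolithic else self.weight)
        out = []
        for x in x_list:  # x: [T, n_levels] LongTensor of RVQ codes
            one_hot = F.one_hot(x, num_classes=self.n_tokens).to(w.dtype)  # [T, n_levels, n_tokens]
            # sum the per-level embeddings into a single [T, token_dim] sequence
            out.append(torch.einsum("tln,lnd->td", one_hot, w[: x.shape[-1]]))
        return out

With this, the AR pass and the NAR pass stop sharing response-embedding tables even though they share every transformer weight, which is what the commit message means by dedicated MultiEmbeddings for AR/NAR in the monolithic model.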
@@ -91,6 +98,7 @@ class MultiEmbedding(nn.Embedding):
         return x_list

+"""
 # Embedding that sums each RVQ-bin level within a given input acoustic prompt
 class PromEmbedding(nn.Module):
     def __init__(self, n_levels, n_tokens, token_dim):
@@ -110,6 +118,7 @@ class RespEmbedding(nn.Module):
     def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
         return [ self.embeddings[min(self.n_levels, quant_levels[i]) if quant_levels is not None else 0](xi)[:, 0, :] for i, xi in enumerate(x_list) ]
+"""

 class Base(nn.Module):
     @property
@@ -154,7 +163,7 @@ class Base(nn.Module):
     @property
     def n_embeddings(self) -> int:
-        return self.n_resp_levels if self.monolithic else 1
+        return 2 if self.monolithic else 1

     @property
     def stop_token(self):
@@ -183,15 +192,11 @@ class Base(nn.Module):
         p_dropout: float = 0.1,

         config = None,
-        use_multiembedding = False,
     ):
         super().__init__()
         self.config = config
         self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True

-        if self.config is not None and hasattr(self.config, "use_multiembedding"):
-            use_multiembedding = self.config.use_multiembedding

         self.n_tokens = n_tokens
         self.d_model = d_model
         self.n_heads = n_heads
@@ -203,19 +208,8 @@ class Base(nn.Module):
         self.text_emb = Embedding(n_tokens, d_model)

-        # use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested
-        # n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt
-        if self.n_embeddings == self.n_prom_levels or not use_multiembedding:
-            self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
-        else:
-            self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
-
-        # use dedicated embeddings for each RVQ-bin level in the output response / target if requested
-        # n_embeddings > 1 because the using the MultiEmbedding "works" fine enough for split AR/NARs.
-        if self.n_embeddings > 1 or not use_multiembedding:
-            self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model)
-        else:
-            self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
+        self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model) #, monolithic=self.monolithic)
+        self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)

         self.sep = nn.Parameter(torch.randn(d_model))
@@ -52,10 +52,13 @@ def run_eval(engines, disabled_engines, eval_name, dl):
     AR = None
     NAR = None
+    AR_NAR = None

     names = []
     for name, engine in engines.items():
-        if name[:2] == "ar":
+        if name[:6] == "ar+nar":
+            AR_NAR = engine
+        elif name[:2] == "ar":
             AR = engine
         elif name[:3] == "nar":
             NAR = engine
@@ -127,7 +130,11 @@ def run_eval(engines, disabled_engines, eval_name, dl):
         for name in engines:
             model = engines[name]

-            if name.startswith("ar"):
+            if name.startswith("ar+nar"):
+                resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], max_steps=cfg.evaluation.steps, sampling_temperature=cfg.evaluation.ar_temperature)
+                resps_list = [ r.unsqueeze(-1) for r in resps_list ]
+                resps_list = AR_NAR(text_list=batch["text"], proms_list=batch["proms"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)
+            elif name.startswith("ar"):
                 resps_list = model(
                     text_list=batch["text"],
                     proms_list=batch["proms"],
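The evaluation branch added above runs the monolithic engine twice: the first call (no resps_list) behaves as the AR and returns one first-level code sequence per utterance, which then gets a trailing RVQ-level axis and is fed back so the same engine, now acting as the NAR, fills in the remaining levels. A hedged sketch of that flow, with the keyword arguments taken from the diff and the wrapper itself assumed:

def monolithic_two_pass(ar_nar, batch, max_steps, ar_temperature=1.0, nar_temperature=1.0):
    """Sketch of the AR+NAR eval flow shown in the diff (wrapper name assumed)."""
    # pass 1: AR behaviour, yields a list of [T] first-level code sequences
    resps_list = ar_nar(text_list=batch["text"], proms_list=batch["proms"],
                        max_steps=max_steps, sampling_temperature=ar_temperature)
    # give each response an RVQ-level axis: [T] -> [T, 1]
    resps_list = [r.unsqueeze(-1) for r in resps_list]
    # pass 2: NAR behaviour, fills the remaining levels -> [T, n_resp_levels]
    return ar_nar(text_list=batch["text"], proms_list=batch["proms"],
                  resps_list=resps_list, sampling_temperature=nar_temperature)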