tweaks and fixes

commit ab5134f385 (parent b2c2dec291)
mrq, 2023-09-07 17:08:38 -05:00
6 changed files with 48 additions and 31 deletions

View File

@@ -33,6 +33,7 @@ class TTS():
 			pass
 		cfg.mode = "inferencing"
+		cfg.trainer.load_module_only = True
 		self.symmap = None
 		if ar_ckpt and nar_ckpt:

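The inference path now sets `cfg.trainer.load_module_only` before loading checkpoints. A minimal sketch of what such a flag typically gates when restoring a checkpoint (assumed semantics and state-dict layout; the repo's trainer may differ):

```python
import torch

def load_checkpoint(path, model, optimizer=None, load_module_only=True):
    # assumed layout: the checkpoint stores model weights and optimizer state
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["module"])
    if not load_module_only and optimizer is not None:
        # training resumes optimizer state; inference skips it entirely
        optimizer.load_state_dict(state["optimizer"])
    return model
```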
View File

@@ -150,6 +150,7 @@ def example_usage():
 	from ..emb.qnt import decode_to_file
 	from ..engines import Engine
 	from tqdm import tqdm
+	from ..utils import wrapper as ml

 	device = "cuda"
 	x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
@@ -184,15 +185,19 @@ def example_usage():
 		'n_layers': 24,
 	}
+	"""
 	try:
 		kwargs['config'] = cfg.models.ar
 	except Exception as e:
 		pass
+	"""

 	model = AR(**kwargs).to(device)
-	engine = Engine(model=model, optimizer=torch.optim.SGD(model.parameters(), lr=0.1))
+	optimizer = ml.Prodigy(model.parameters(), lr=1.0)
+	engine = Engine(model=model, optimizer=optimizer)
+
+	steps = 500

-	def sample( name, steps=400 ):
+	def sample( name, steps=600 ):
 		engine.eval()
 		out = engine(text_list, proms_list, max_steps=steps)
 		for i, o in enumerate(out):
@@ -200,7 +205,7 @@ def example_usage():
 	def train():
 		engine.train()
-		t = trange(60)
+		t = trange(steps)
 		for i in t:
 			stats = {"step": i}
 			stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)

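The AR example's SGD optimizer is swapped for Prodigy through the repo's `ml` wrapper. Prodigy is a learning-rate-free method whose reference implementation recommends `lr=1.0`, letting the optimizer estimate the effective step size itself. A minimal sketch using the standalone `prodigyopt` package, assuming the wrapper forwards to it (the model here is a stand-in for `AR(**kwargs)`):

```python
import torch
from prodigyopt import Prodigy  # pip install prodigyopt

model = torch.nn.Linear(16, 16)            # stand-in for AR(**kwargs)
# lr=1.0 is the recommended setting: Prodigy adapts the step size on its own
optimizer = Prodigy(model.parameters(), lr=1.0)

for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 16)).square().mean()  # stand-in loss
    loss.backward()
    optimizer.step()
```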
View File

@@ -1,5 +1,4 @@
 from .base import Base, list_to_tensor, Categorical
-from ..utils import wrapper as ml
 from ..config import cfg

 import torch
@@ -173,6 +172,7 @@ def example_usage():
 	from ..emb.qnt import decode_to_file, unload_model
 	from ..engines import Engine
 	from tqdm import tqdm
+	from ..utils import wrapper as ml

 	device = "cuda"
 	x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
@@ -215,8 +215,8 @@ def example_usage():
 	"""

 	model = AR_NAR(**kwargs).to(device)
+	steps = 500
 	optimizer = ml.Prodigy(model.parameters(), lr=1.0)
-	#optimizer = ml.AdamW(model.parameters(), lr=0.0001)
 	engine = Engine(model=model, optimizer=optimizer)

 	print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@@ -238,7 +238,7 @@ def example_usage():
 	def train():
 		engine.train()
-		t = trange(500)
+		t = trange(steps)
 		for i in t:
 			stats = {"step": i}
 			stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)

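Both the AR and AR+NAR examples now derive their loop length from a single `steps` constant instead of hard-coded `trange(60)` / `trange(500)` calls. A self-contained sketch of the resulting pattern; the model, optimizer, and loss are stand-ins for `AR_NAR(**kwargs)`, `ml.Prodigy`, and `engine.traverse(...)`:

```python
import torch
from tqdm import trange

steps = 500
model = torch.nn.Linear(8, 8)                               # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # stand-in optimizer

t = trange(steps)
for i in t:
    optimizer.zero_grad()
    loss = model(torch.randn(4, 8)).square().mean()  # stand-in for engine.traverse(...)
    loss.backward()
    optimizer.step()
    t.set_description(f"step {i} loss {loss.item():.4f}")
```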
View File

@@ -68,7 +68,9 @@ class MultiEmbedding(nn.Embedding):
 		self.n_tokens = n_tokens
 		self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))

-	def forward(self, x_list: list[Tensor]) -> list[Tensor]:
+	# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
+	# I imagine this is an oversight in the NAR.
+	def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
 		if len(x_list) == 0:
 			return []
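The commit only adds the `quant_levels` argument and leaves the actual level selection as a to-do. A hedged sketch of what honoring it could look like, assuming the usual one-hot-times-weight formulation of this embedding; the `MultiEmbeddingSketch` name and the exact selection rule are illustrative, not the repo's code:

```python
import torch
from torch import Tensor, nn
import torch.nn.functional as F

class MultiEmbeddingSketch(nn.Module):
    def __init__(self, max_n_levels: int, n_tokens: int, token_dim: int):
        super().__init__()
        self.n_tokens = n_tokens
        self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))

    def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
        if len(x_list) == 0:
            return []
        out = []
        for i, x in enumerate(x_list):  # x: (t, l) integer token ids
            if quant_levels is not None:
                # select only this sample's target level instead of summing all levels
                l = int(quant_levels[i])
                w = self.weight[l : l + 1]     # (1, k, d)
                x = x[:, l : l + 1]            # (t, 1)
            else:
                w = self.weight[: x.shape[1]]  # (l, k, d)
            one_hot = F.one_hot(x, num_classes=self.n_tokens).to(w.dtype)  # (t, l, k)
            out.append(torch.einsum("lkd,tlk->td", w, one_hot))
        return out
```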
@@ -151,9 +153,13 @@ class Base(nn.Module):
 		return False

 	@property
-	def n_embeddings(self):
+	def n_embeddings(self) -> int:
 		return self.n_resp_levels if self.monolithic else 1

+	@property
+	def use_old_embeddings(self) -> bool:
+		return True
+
 	@property
 	def stop_token(self):
 		if not self.causal:
@@ -199,14 +205,14 @@ class Base(nn.Module):
 		# use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested
 		# n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt
-		if self.n_embeddings == self.n_prom_levels:
+		if self.n_embeddings == self.n_prom_levels or not self.use_old_embeddings:
 			self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
 		else:
 			self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)

 		# use dedicated embeddings for each RVQ-bin level in the output response / target if requested
 		# n_embeddings > 1 because using the MultiEmbedding "works" fine enough for split AR/NARs.
-		if self.n_embeddings > 1:
+		if self.n_embeddings > 1 or not self.use_old_embeddings:
 			self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model)
 		else:
 			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
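The new `use_old_embeddings` property defaults to `True`, so existing configurations keep the `MultiEmbedding` fallback; the two `or not self.use_old_embeddings` clauses let a subclass opt in to the dedicated per-level embeddings unconditionally. A hypothetical subclass illustrating the intended override:

```python
class BaseWithNewEmbeddings(Base):  # hypothetical subclass name
    @property
    def use_old_embeddings(self) -> bool:
        # forces PromEmbedding/RespEmbedding regardless of level counts
        return False
```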
@@ -409,6 +415,7 @@ def example_usage():
 	from ..emb.qnt import decode_to_file
 	from ..engines import Engine, Engines
 	from tqdm import tqdm, trange
+	from ..utils import wrapper as ml

 	from .ar import AR
 	from .nar import NAR
@@ -432,7 +439,7 @@ def example_usage():
 	for name, model in models.items():
 		print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

-	engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
+	engines = Engines({ name: Engine(model=model, optimizer=ml.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })

 	train = True
@@ -449,7 +456,7 @@ def example_usage():
 		qnt.to(device),
 	]

-	def sample( name, steps=400 ):
+	def sample( name, steps=600 ):
 		AR = None
 		NAR = None
@@ -471,7 +478,7 @@ def example_usage():
 	sample("init", 15)

 	engines.train()
-	t = trange(60)
+	t = trange(500)
 	for i in t:
 		stats = {"step": i}
 	"""

View File

@@ -130,6 +130,7 @@ def example_usage():
 	from ..emb.qnt import decode_to_file
 	from ..engines import Engine
 	from tqdm import tqdm
+	from ..utils import wrapper as ml

 	device = "cuda"
 	x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
@@ -162,7 +163,9 @@ def example_usage():
 		'n_layers': 12,
 	}
 	model = NAR(**kwargs).to(device)
-	engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
+	steps = 500
+	optimizer = ml.Prodigy(model.parameters(), lr=1.0)
+	engine = Engine(model=model, optimizer=optimizer)

 	def sample( name ):
 		engine.eval()
@@ -171,7 +174,7 @@ def example_usage():
 	def train():
 		engine.train()
-		t = trange(60)
+		t = trange(steps)
 		for i in t:
 			stats = {"step": i}
 			stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)

View File

@@ -43,6 +43,7 @@ def load_engines(invert=False):
 	engines = dict()

 	for name, model in models.items():
+		if cfg.mode != "inferencing":
 		# load only the models for training initially
 		# loads disabled models at evaluation time (to load updated weights if training separately)
 		# I'm sure there's a more elegant solution to this
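The added condition means the model-filtering logic below it only applies while training; during inference every requested model is loaded as-is. A rough, self-contained sketch of that control flow under those assumptions (the dicts and the per-model flag are stand-ins, not the repo's structures):

```python
# stand-ins; the real code builds Engine objects from the repo's config
cfg_mode = "inferencing"
models = {"ar": object(), "nar": object()}
training_enabled = {"ar": True, "nar": False}  # illustrative per-model flag

engines = {}
for name, model in models.items():
    if cfg_mode != "inferencing":
        # training: load only enabled models now; disabled ones get picked up
        # at evaluation time so their latest weights are used
        if not training_enabled[name]:
            continue
    engines[name] = model
print(sorted(engines))  # inference mode loads everything: ['ar', 'nar']
```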