tweaks and fixes

parent b2c2dec291
commit ab5134f385
@@ -33,6 +33,7 @@ class TTS():
             pass
 
         cfg.mode = "inferencing"
+        cfg.trainer.load_module_only = True
 
         self.symmap = None
         if ar_ckpt and nar_ckpt:
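The TTS class now also flags that only module weights should be restored at inference time. A minimal sketch of the intent, assuming a DeepSpeed-style `load_module_only` option that skips optimizer and scheduler state when restoring a checkpoint (the `load_checkpoint` helper and its state-dict keys here are hypothetical, for illustration only):

    import torch

    # sketch: inference-time restore that only wants module weights
    def load_checkpoint(engine, path, load_module_only: bool):
        state = torch.load(path, map_location="cpu")
        engine.module.load_state_dict(state["module"])  # always restore weights
        if not load_module_only:
            # optimizer/scheduler state only matters when resuming training
            engine.optimizer.load_state_dict(state["optimizer"])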
@@ -150,6 +150,7 @@ def example_usage():
     from ..emb.qnt import decode_to_file
     from ..engines import Engine
     from tqdm import tqdm
+    from ..utils import wrapper as ml
 
     device = "cuda"
     x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
@@ -184,15 +185,19 @@ def example_usage():
         'n_layers': 24,
     }
 
+    """
     try:
         kwargs['config'] = cfg.models.ar
     except Exception as e:
         pass
+    """
 
     model = AR(**kwargs).to(device)
-    engine = Engine(model=model, optimizer=torch.optim.SGD(model.parameters(), lr=0.1))
+    optimizer = ml.Prodigy(model.parameters(), lr=1.0)
+    engine = Engine(model=model, optimizer=optimizer)
+    steps = 500
 
-    def sample( name, steps=400 ):
+    def sample( name, steps=600 ):
         engine.eval()
         out = engine(text_list, proms_list, max_steps=steps)
         for i, o in enumerate(out):
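The example usages now build their optimizer through the `ml` wrapper module instead of instantiating `torch.optim` classes directly. Prodigy is an adaptive optimizer that estimates its own step size, which is why it is constructed with `lr=1.0`. A plausible sketch of that kind of wrapper, assuming the `prodigyopt` package backs it (the fallback behavior is an assumption, not the repo's actual wrapper):

    # sketch: expose optimizers behind one namespace so example code can
    # swap implementations without touching call sites
    try:
        from prodigyopt import Prodigy  # adaptive step size; lr=1.0 is the usual setting
    except ImportError:
        Prodigy = None  # caller falls back to e.g. AdamW if unavailable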
@@ -200,7 +205,7 @@ def example_usage():
 
     def train():
         engine.train()
-        t = trange(60)
+        t = trange(steps)
         for i in t:
             stats = {"step": i}
             stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)

@@ -1,5 +1,4 @@
 from .base import Base, list_to_tensor, Categorical
-from ..utils import wrapper as ml
 from ..config import cfg
 
 import torch
@@ -173,6 +172,7 @@ def example_usage():
     from ..emb.qnt import decode_to_file, unload_model
     from ..engines import Engine
     from tqdm import tqdm
+    from ..utils import wrapper as ml
 
     device = "cuda"
     x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
@@ -215,8 +215,8 @@ def example_usage():
     """
 
     model = AR_NAR(**kwargs).to(device)
+    steps = 500
     optimizer = ml.Prodigy(model.parameters(), lr=1.0)
-    #optimizer = ml.AdamW(model.parameters(), lr=0.0001)
     engine = Engine(model=model, optimizer=optimizer)
 
     print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@@ -238,7 +238,7 @@ def example_usage():
 
     def train():
         engine.train()
-        t = trange(500)
+        t = trange(steps)
         for i in t:
             stats = {"step": i}
             stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)

@@ -68,7 +68,9 @@ class MultiEmbedding(nn.Embedding):
         self.n_tokens = n_tokens
         self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
 
-    def forward(self, x_list: list[Tensor]) -> list[Tensor]:
+    # to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
+    # I imagine this is an oversight in the NAR.
+    def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
         if len(x_list) == 0:
             return []
 
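The to-do describes indexing the embedding by the quantizer level actually being predicted, rather than always summing over every level. A minimal sketch of that idea, keeping the `(max_n_levels, n_tokens, token_dim)` weight shape; the selection logic is a guess at what the to-do intends, not the repo's implementation:

    import torch
    import torch.nn.functional as F
    from torch import Tensor, nn

    class LevelAwareMultiEmbedding(nn.Module):
        # hypothetical variant acting on (t, l) token-id tensors
        def __init__(self, max_n_levels: int, n_tokens: int, token_dim: int):
            super().__init__()
            self.n_tokens = n_tokens
            self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))

        def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
            res = []
            for i, x in enumerate(x_list):  # x: (t, l) token ids
                if quant_levels is not None:
                    l = int(quant_levels[i])
                    # embed only the level being predicted, with its own table
                    res.append(F.one_hot(x[:, l], self.n_tokens).float() @ self.weight[l])
                else:
                    # original behaviour: sum the per-level lookups
                    one_hot = F.one_hot(x, self.n_tokens).float()  # (t, l, n_tokens)
                    res.append(torch.einsum("tln,lnd->td", one_hot, self.weight[: x.shape[1]]))
            return res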
@@ -151,9 +153,13 @@ class Base(nn.Module):
         return False
 
     @property
-    def n_embeddings(self):
+    def n_embeddings(self) -> int:
         return self.n_resp_levels if self.monolithic else 1
 
+    @property
+    def use_old_embeddings(self) -> bool:
+        return True
+
     @property
     def stop_token(self):
         if not self.causal:
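`use_old_embeddings` defaults to True on the base class, so existing models keep the summed MultiEmbedding behavior; a subclass can opt into the dedicated per-level embeddings by overriding it. A sketch of how a model class might opt in (the class name is illustrative):

    class AR_NAR_New(Base):
        @property
        def use_old_embeddings(self) -> bool:
            # force PromEmbedding/RespEmbedding regardless of level counts
            return False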
@@ -199,14 +205,14 @@ class Base(nn.Module):
 
         # use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested
         # n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt
-        if self.n_embeddings == self.n_prom_levels:
+        if self.n_embeddings == self.n_prom_levels or not self.use_old_embeddings:
             self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
         else:
             self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
 
         # use dedicated embeddings for each RVQ-bin level in the output response / target if requested
         # n_embeddings > 1 because the using the MultiEmbedding "works" fine enough for split AR/NARs.
-        if self.n_embeddings > 1:
+        if self.n_embeddings > 1 or not self.use_old_embeddings:
             self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model)
         else:
             self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
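Both branches choose between one embedding table per RVQ level (PromEmbedding/RespEmbedding) and the shared MultiEmbedding. A minimal sketch of what a dedicated per-level embedding could look like, assuming it simply sums per-level lookups; this is a guess at the general shape, not PromEmbedding's actual definition:

    import torch
    from torch import Tensor, nn

    class PerLevelEmbedding(nn.Module):
        # one nn.Embedding per RVQ level, summed over levels
        def __init__(self, n_levels: int, n_tokens: int, d_model: int):
            super().__init__()
            self.embs = nn.ModuleList([nn.Embedding(n_tokens, d_model) for _ in range(n_levels)])

        def forward(self, x: Tensor) -> Tensor:
            # x: (t, l) token ids; embed each level with its own table and sum
            return sum(emb(x[:, l]) for l, emb in enumerate(self.embs))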
@@ -409,6 +415,7 @@ def example_usage():
     from ..emb.qnt import decode_to_file
     from ..engines import Engine, Engines
     from tqdm import tqdm, trange
+    from ..utils import wrapper as ml
 
     from .ar import AR
     from .nar import NAR
@@ -432,7 +439,7 @@ def example_usage():
     for name, model in models.items():
         print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
-    engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
+    engines = Engines({ name: Engine(model=model, optimizer=ml.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
 
     train = True
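This hunk routes AdamW through the same `ml` wrapper, keeping every example on one indirection. A sketch of the pattern, assuming the wrapper's job is to centralize which implementation backs each optimizer name; the bitsandbytes swap is an assumption about its purpose:

    # sketch: one module decides which AdamW implementation call sites get
    try:
        from bitsandbytes.optim import AdamW  # 8-bit variant when available
    except ImportError:
        from torch.optim import AdamW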
@@ -449,7 +456,7 @@ def example_usage():
         qnt.to(device),
     ]
 
-    def sample( name, steps=400 ):
+    def sample( name, steps=600 ):
         AR = None
         NAR = None
@@ -471,7 +478,7 @@ def example_usage():
     sample("init", 15)
 
     engines.train()
-    t = trange(60)
+    t = trange(500)
     for i in t:
         stats = {"step": i}
     """

@@ -130,6 +130,7 @@ def example_usage():
     from ..emb.qnt import decode_to_file
     from ..engines import Engine
     from tqdm import tqdm
+    from ..utils import wrapper as ml
 
     device = "cuda"
     x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
@@ -162,7 +163,9 @@ def example_usage():
         'n_layers': 12,
     }
     model = NAR(**kwargs).to(device)
-    engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
+    steps = 500
+    optimizer = ml.Prodigy(model.parameters(), lr=1.0)
+    engine = Engine(model=model, optimizer=optimizer)
 
     def sample( name ):
         engine.eval()
@@ -171,7 +174,7 @@ def example_usage():
 
     def train():
         engine.train()
-        t = trange(60)
+        t = trange(steps)
         for i in t:
             stats = {"step": i}
             stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)

@@ -43,6 +43,7 @@ def load_engines(invert=False):
     engines = dict()
 
     for name, model in models.items():
         if cfg.mode != "inferencing":
             # load only the models for training initially
+            # loads disabled models at evaluation time (to load updated weights if training separately)
             # I'm sure there's a more elegant solution to this
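With `cfg.mode` set to "inferencing" by the TTS class above, this check confines the training-only filtering to training runs, so inference loads everything. A rough sketch of the control flow under that reading; `wants_training` is a hypothetical predicate standing in for however the repo marks a model as trainable:

    # sketch: only filter out non-training models while in training mode
    for name, model in models.items():
        if cfg.mode != "inferencing" and not wants_training(name):  # hypothetical predicate
            # disabled models get loaded at evaluation time instead,
            # picking up freshly saved weights when trained separately
            continue
        engines[name] = Engine(model=model)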