tweaks and fixes
This commit is contained in:
parent
b2c2dec291
commit
ab5134f385
|
@ -33,6 +33,7 @@ class TTS():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
cfg.mode = "inferencing"
|
cfg.mode = "inferencing"
|
||||||
|
cfg.trainer.load_module_only = True
|
||||||
|
|
||||||
self.symmap = None
|
self.symmap = None
|
||||||
if ar_ckpt and nar_ckpt:
|
if ar_ckpt and nar_ckpt:
|
||||||
|
|
|
@ -150,6 +150,7 @@ def example_usage():
|
||||||
from ..emb.qnt import decode_to_file
|
from ..emb.qnt import decode_to_file
|
||||||
from ..engines import Engine
|
from ..engines import Engine
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from ..utils import wrapper as ml
|
||||||
|
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
|
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
|
||||||
|
@ -184,15 +185,19 @@ def example_usage():
|
||||||
'n_layers': 24,
|
'n_layers': 24,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
kwargs['config'] = cfg.models.ar
|
kwargs['config'] = cfg.models.ar
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
"""
|
||||||
|
|
||||||
model = AR(**kwargs).to(device)
|
model = AR(**kwargs).to(device)
|
||||||
engine = Engine(model=model, optimizer=torch.optim.SGD(model.parameters(), lr=0.1))
|
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
|
||||||
|
engine = Engine(model=model, optimizer=optimizer)
|
||||||
|
steps = 500
|
||||||
|
|
||||||
def sample( name, steps=400 ):
|
def sample( name, steps=600 ):
|
||||||
engine.eval()
|
engine.eval()
|
||||||
out = engine(text_list, proms_list, max_steps=steps)
|
out = engine(text_list, proms_list, max_steps=steps)
|
||||||
for i, o in enumerate(out):
|
for i, o in enumerate(out):
|
||||||
|
@ -200,7 +205,7 @@ def example_usage():
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
engine.train()
|
engine.train()
|
||||||
t = trange(60)
|
t = trange(steps)
|
||||||
for i in t:
|
for i in t:
|
||||||
stats = {"step": i}
|
stats = {"step": i}
|
||||||
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
|
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from .base import Base, list_to_tensor, Categorical
|
from .base import Base, list_to_tensor, Categorical
|
||||||
from ..utils import wrapper as ml
|
|
||||||
from ..config import cfg
|
from ..config import cfg
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -173,6 +172,7 @@ def example_usage():
|
||||||
from ..emb.qnt import decode_to_file, unload_model
|
from ..emb.qnt import decode_to_file, unload_model
|
||||||
from ..engines import Engine
|
from ..engines import Engine
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from ..utils import wrapper as ml
|
||||||
|
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
|
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
|
||||||
|
@ -215,8 +215,8 @@ def example_usage():
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model = AR_NAR(**kwargs).to(device)
|
model = AR_NAR(**kwargs).to(device)
|
||||||
|
steps = 500
|
||||||
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
|
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
|
||||||
#optimizer = ml.AdamW(model.parameters(), lr=0.0001)
|
|
||||||
engine = Engine(model=model, optimizer=optimizer)
|
engine = Engine(model=model, optimizer=optimizer)
|
||||||
|
|
||||||
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
||||||
|
@ -238,7 +238,7 @@ def example_usage():
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
engine.train()
|
engine.train()
|
||||||
t = trange(500)
|
t = trange(steps)
|
||||||
for i in t:
|
for i in t:
|
||||||
stats = {"step": i}
|
stats = {"step": i}
|
||||||
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
|
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
|
||||||
|
|
|
@ -68,7 +68,9 @@ class MultiEmbedding(nn.Embedding):
|
||||||
self.n_tokens = n_tokens
|
self.n_tokens = n_tokens
|
||||||
self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
|
self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
|
||||||
|
|
||||||
def forward(self, x_list: list[Tensor]) -> list[Tensor]:
|
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
|
||||||
|
# I imagine this is an oversight in the NAR.
|
||||||
|
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
|
||||||
if len(x_list) == 0:
|
if len(x_list) == 0:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@ -151,9 +153,13 @@ class Base(nn.Module):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_embeddings(self):
|
def n_embeddings(self) -> int:
|
||||||
return self.n_resp_levels if self.monolithic else 1
|
return self.n_resp_levels if self.monolithic else 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_old_embeddings(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def stop_token(self):
|
def stop_token(self):
|
||||||
if not self.causal:
|
if not self.causal:
|
||||||
|
@ -199,14 +205,14 @@ class Base(nn.Module):
|
||||||
|
|
||||||
# use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested
|
# use dedicated embeddings for each RVQ-bin level in the input acoustic prompt if requested
|
||||||
# n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt
|
# n_embeddings == prom_levels because using the MultiEmbedding is more than fine for the input acoustic prompt
|
||||||
if self.n_embeddings == self.n_prom_levels:
|
if self.n_embeddings == self.n_prom_levels or not self.use_old_embeddings:
|
||||||
self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
self.proms_emb = PromEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
||||||
else:
|
else:
|
||||||
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
||||||
|
|
||||||
# use dedicated embeddings for each RVQ-bin level in the output response / target if requested
|
# use dedicated embeddings for each RVQ-bin level in the output response / target if requested
|
||||||
# n_embeddings > 1 because the using the MultiEmbedding "works" fine enough for split AR/NARs.
|
# n_embeddings > 1 because the using the MultiEmbedding "works" fine enough for split AR/NARs.
|
||||||
if self.n_embeddings > 1:
|
if self.n_embeddings > 1 or not self.use_old_embeddings:
|
||||||
self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model)
|
self.resps_emb = RespEmbedding(self.n_embeddings, n_resp_tokens, d_model)
|
||||||
else:
|
else:
|
||||||
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
|
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
|
||||||
|
@ -409,6 +415,7 @@ def example_usage():
|
||||||
from ..emb.qnt import decode_to_file
|
from ..emb.qnt import decode_to_file
|
||||||
from ..engines import Engine, Engines
|
from ..engines import Engine, Engines
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
from ..utils import wrapper as ml
|
||||||
|
|
||||||
from .ar import AR
|
from .ar import AR
|
||||||
from .nar import NAR
|
from .nar import NAR
|
||||||
|
@ -432,7 +439,7 @@ def example_usage():
|
||||||
for name, model in models.items():
|
for name, model in models.items():
|
||||||
print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
||||||
|
|
||||||
engines = Engines({ name: Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
|
engines = Engines({ name: Engine(model=model, optimizer=ml.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })
|
||||||
|
|
||||||
train = True
|
train = True
|
||||||
|
|
||||||
|
@ -449,7 +456,7 @@ def example_usage():
|
||||||
qnt.to(device),
|
qnt.to(device),
|
||||||
]
|
]
|
||||||
|
|
||||||
def sample( name, steps=400 ):
|
def sample( name, steps=600 ):
|
||||||
AR = None
|
AR = None
|
||||||
NAR = None
|
NAR = None
|
||||||
|
|
||||||
|
@ -471,7 +478,7 @@ def example_usage():
|
||||||
sample("init", 15)
|
sample("init", 15)
|
||||||
|
|
||||||
engines.train()
|
engines.train()
|
||||||
t = trange(60)
|
t = trange(500)
|
||||||
for i in t:
|
for i in t:
|
||||||
stats = {"step": i}
|
stats = {"step": i}
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -130,6 +130,7 @@ def example_usage():
|
||||||
from ..emb.qnt import decode_to_file
|
from ..emb.qnt import decode_to_file
|
||||||
from ..engines import Engine
|
from ..engines import Engine
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from ..utils import wrapper as ml
|
||||||
|
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
|
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
|
||||||
|
@ -162,7 +163,9 @@ def example_usage():
|
||||||
'n_layers': 12,
|
'n_layers': 12,
|
||||||
}
|
}
|
||||||
model = NAR(**kwargs).to(device)
|
model = NAR(**kwargs).to(device)
|
||||||
engine = Engine(model=model, optimizer=torch.optim.AdamW(model.parameters(), lr=1e-4))
|
steps = 500
|
||||||
|
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
|
||||||
|
engine = Engine(model=model, optimizer=optimizer)
|
||||||
|
|
||||||
def sample( name ):
|
def sample( name ):
|
||||||
engine.eval()
|
engine.eval()
|
||||||
|
@ -171,7 +174,7 @@ def example_usage():
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
engine.train()
|
engine.train()
|
||||||
t = trange(60)
|
t = trange(steps)
|
||||||
for i in t:
|
for i in t:
|
||||||
stats = {"step": i}
|
stats = {"step": i}
|
||||||
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
|
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
|
||||||
|
|
|
@ -43,6 +43,7 @@ def load_engines(invert=False):
|
||||||
engines = dict()
|
engines = dict()
|
||||||
|
|
||||||
for name, model in models.items():
|
for name, model in models.items():
|
||||||
|
if cfg.mode != "inferencing":
|
||||||
# load only the models for training initially
|
# load only the models for training initially
|
||||||
# loads disabled models at evaluation time (to load updated weights if training separately)
|
# loads disabled models at evaluation time (to load updated weights if training separately)
|
||||||
# I'm sure there's a more elegant solution to this
|
# I'm sure there's a more elegant solution to this
|
||||||
|
|
Loading…
Reference in New Issue
Block a user