reticulating splines

mrq 2024-06-08 20:30:15 -05:00
parent ead3e2f0cb
commit 8d068fa3f9
4 changed files with 543 additions and 116 deletions

View File

@@ -16,6 +16,23 @@ def get_model(config, training=True):
l_padding = config.input_alignment,
training = training,
config = config,
)
elif "len" in config.capabilities:
from .nar import NAR
model = NAR(
n_text_tokens=config.text_tokens,
n_audio_tokens=config.audio_tokens,
d_model=config.dim,
n_heads=config.heads,
n_layers=config.layers,
n_experts=config.experts,
p_dropout=config.dropout,
l_padding = config.input_alignment,
training = training,
config = config,
)

View File

@@ -1,3 +1,11 @@
"""
# an AR + NAR model that handles:
* inferencing the primary RVQ level in an autoregressive manner (AR)
* inferencing the remaining RVQ levels in parallel (NAR)
This model can fully handle being trained as a unified model (AR + NAR) or separate models (AR | NAR).
It's recommended to train as a unified model, then "distill" knowledge of each task separately, just in case.
"""
from .base import Base, list_to_tensor, Categorical
from ..config import cfg
@@ -21,7 +29,7 @@ class AR_NAR(Base):
@property
def causal(self):
return "ar" in self.capabilities or "len" in self.capabilities
return "ar" in self.capabilities
@property
def norm_type(self):
@@ -139,8 +147,7 @@ class AR_NAR(Base):
if n_levels == self.n_resp_levels:
# to-do: make this YAML configurable
def sample_task():
p_len_task = 0.25 if "len" in self.capabilities else 0
return "len" if random.random() < p_len_task else "tts"
return "tts"
# generate task list to train against
task_list = [ sample_task() for _ in range(batch_size) ]
@@ -158,21 +165,20 @@ class AR_NAR(Base):
index = i
return int(index)
quant_levels = [ 0 if task_list[i] == "len" else generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
else:
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
# append stop tokens for AR
# could technically do it in the .inputs call
if "len" not in self.capabilities:
for i in range(batch_size):
# only apply stop token for RVQ level 0
if quant_levels[i] > 0:
continue
resps_list[i] = torch.cat([resps_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
for i in range(batch_size):
# only apply stop token for RVQ level 0
if quant_levels[i] > 0:
continue
resps_list[i] = torch.cat([resps_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
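# e.g. a level-0 target [c_1, ..., c_T] becomes [c_1, ..., c_T, stop_token],
# while targets for quant_level > 0 (NAR levels) are left untouched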
inputs = self.inputs(
text_list=text_list,
@@ -241,94 +247,13 @@ class AR_NAR(Base):
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
return prev_list
# other NAR
if len_list is not None:
# is NAR
if max_levels == 0:
max_levels = self.n_resp_levels
# fill with mock tokens
prev_list = [ torch.Tensor([ self.stop_token for _ in range(resp_len) ]).to(device=device, dtype=torch.int16) for resp_len in len_list ]
start = True
for n in trange( max_levels, desc="NAR" ):
level = 0 if n == 0 else prev_list[0].shape[-1]
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
break
quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=prev_list,
lang_list=lang_list,
tone_list=tone_list,
quant_levels=quant_levels,
)
logits = super().forward(
inputs=inputs,
quant_levels=quant_levels,
)
resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ]
if n == 0:
prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ]
else:
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
return prev_list
# is AR
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool()
stop_token = 10 if "len" in self.capabilities else self.stop_token
task_list = [ "len" if "len" in self.capabilities else "tts" for _ in range(batch_size) ]
if "len" in self.capabilities:
sequence_list = [ torch.Tensor([0]).to(device=device,dtype=torch.int16) for _ in range(batch_size) ]
for n in trange(10, desc="AR"):
len_list = sequence_list
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
task_list=task_list,
quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
)
logits = super().forward(
inputs=inputs,
)
r = [ logit[-1:].argmax(dim=1) for logit in logits ]
# sanitize
for i, token in enumerate(r):
if token > 10:
r[i] = 0
# append tokens
for i, ri in enumerate(r):
if stop_token in ri:
stopped[i] = True
sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
# stop token found
stopped |= r == stop_token
if stopped.all().item():
break
# convert tokens into int
return [ int("".join([ str(token.item()) for token in r if token != stop_token ])) for r in sequence_list ]
stop_token = self.stop_token
task_list = [ "tts" for _ in range(batch_size) ]
recurrent_state = [] if cfg.inference.recurrent_forward else None
mirostat = [
@@ -579,17 +504,13 @@ def example_usage():
return
engine.eval()
if "len" in cfg.model.capabilities:
len_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
resps_list = engine( text_list, proms_list, len_list=len_list, sampling_temperature=0.2 )
if "ar" in cfg.model.capabilities:
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
else:
if "ar" in cfg.model.capabilities:
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
else:
resps_list = [ qnt[:, 0].to( device ) ]
resps_list = [ qnt[:, 0].to( device ) ]
if "nar" in cfg.model.capabilities:
resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
if "nar" in cfg.model.capabilities:
resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
for i, o in enumerate(resps_list):
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)

View File

@@ -1,3 +1,14 @@
"""
Core model for handling all VALL-E tasks.
This should handle all the "low-level" things, such as:
* parsing inputs to sequences
* converting sequences to embeddings
* forward pass
* processing loss and returning logits
Additional functionality (preparing inputs, generating full audio) should be delegated to classes that inherit from the base model
"""
import math
import torch
import torch.nn.functional as F
@@ -100,6 +111,7 @@ class MultiEmbedding(nn.Module):
return x_list
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
# _Old, to preserve compat with previous models.
class AudioEmbedding_Old(nn.Module):
def __init__(
self,
@@ -115,12 +127,12 @@ class AudioEmbedding_Old(nn.Module):
# per-level weight for how much each RVQ level influences the sum (desu this should be useless, since the embedding weights themselves should already account for it)
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
def forward(self, xi: Tensor, offset: int | None = 0 ) -> Tensor:
def forward(self, xi: Tensor, quant_level: Tensor | None = None ) -> Tensor:
# prom
if offset == 0 and xi.shape[-1] > 1:
if quant_level is None and xi.shape[-1] > 1:
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
# AR resp
elif quant_level == 0:
# prom / AR resp
elif quant_level is None or quant_level == 0:
x = self.embeddings[0]( xi if xi.dim() == 1 else xi[:, 0] )
# NAR resp
else:
@@ -128,26 +140,26 @@ class AudioEmbedding_Old(nn.Module):
return x
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
# Mostly to handle some oversights and errors during testing
class AudioEmbedding(nn.Module):
def __init__(
self,
l_tokens: list[int], # number of tokens per level (needed because the AR resp level includes a stop token)
token_dim: int, # dimensionality of the embedding
mode: str, # prom | resp
sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
):
super().__init__()
# array of embeddings
# proms are [0, prom_levels]
# resps are split so that [0] is for the AR and [1:] are reserved for the NAR
# + resps cannot share the AR and NAR embeddings, since they encode whether to predict the same level at the next token (AR) or the next level in place (NAR)
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
#
self.mode = mode
#
self.sums = sums
# maintaining compat is hard
def forward(self, xi: Tensor, quant_level: int | Tensor | None = None, offset: int = 0 ) -> Tensor:
def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
if self.sums and quant_level > 0:
@@ -159,6 +171,8 @@ class AudioEmbedding(nn.Module):
return x
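# A hedged reading of what `sums` controls (the exact indexing lives in the elided body above):
#   sums=True :  the embeddings of the lower RVQ levels in xi are summed into the targeted level,
#                i.e. something like emb[0](xi[:, 0]) + emb[1](xi[:, 1]) + ... for quant_level > 0
#   sums=False:  only the targeted level's own embedding is used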
class Base(nn.Module):
# to-do: clean up this property mess
@property
def causal(self) -> bool:
raise NotImplementedError
@@ -292,12 +306,10 @@ class Base(nn.Module):
else:
self.proms_emb = AudioEmbedding(
[n_prom_tokens] * self.n_prom_levels, d_model,
"prom",
sums=self.config.audio_embedding_sums if self.config is not None else True
)
self.resps_emb = AudioEmbedding(
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
"resp:len" if "len" in self.capabilities else "resp",
sums=self.config.audio_embedding_sums if self.config is not None else True
)
@@ -700,16 +712,31 @@ class Base(nn.Module):
embedding = self.langs_emb( input )
elif name == "prom":
# get RVQ level 0, or up to the targeted RVQ level for inference
embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], offset = 0 )
if self.version <= 4:
embedding = self.proms_emb( input if quant_level == 0 else input[:, :quant_level] )
else:
if quant_level == 0:
embedding = self.proms_emb( input if input.dim() == 1 else input[:, :1], offset = 0 )
else:
embedding = self.proms_emb( input if input.dim() == 1 else input[:, :quant_level], offset = 0 )
elif name == "tone" and self.tones_emb is not None:
embedding = self.tones_emb( input )
elif name == "resp":
if "len" in self.capabilities and quant_level == 0:
# fill with "stop" tokens for NAR-only model
embedding = self.resps_emb( torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token), offset = 0 )
embedding = self.resps_emb(
torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token),
offset = 0
)
else:
# get RVQ level 0, or up to the targeted RVQ level for inference
embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], offset = 0 if quant_level == 0 or "len" in self.capabilities else 1 )
if self.version <= 4:
embedding = self.resps_emb( input if quant_level == 0 else input[:, :quant_level], quant_level )
else:
embedding = self.resps_emb(
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
offset = 0 if quant_level == 0 or "len" in self.capabilities else 1
)
elif name == "len" and self.len_emb is not None:
embedding = self.len_emb( input )
else:

vall_e/models/nar.py (new file, 462 lines added)
View File

@@ -0,0 +1,462 @@
"""
A (mostly) NAR model that handles inferencing all RVQ levels in parallel (NAR).
I believe Meta's Voicebox does this too (predict the utterance length, then decode in parallel)
It *does* have to inference the initial length in an autoregressive-ish manner (it can technically also be done in parallel)
Initial experiments show this only really "works" for a few brief seconds before going to silence. I imagine I need to read more papers or just need to train longer.
"""
from .base import Base, list_to_tensor, Categorical
from ..config import cfg
import torch
from torch.nn.utils.rnn import pad_sequence
import random
import math
from einops import rearrange
from torch import Tensor
from tqdm import trange
from ..emb.qnt import trim
class NAR(Base):
@property
def capabilities(self) -> list[str]:
if hasattr(self, "config") and self.config:
return self.config.capabilities
return cfg.model.capabilities
@property
def causal(self):
return "len" in self.capabilities
@property
def norm_type(self):
return "ln" # if self.n_resp_levels == 1 else "adaln"
@property
def arch_type(self) -> str:
if hasattr(self, "config") and self.config:
return self.config.arch_type
return cfg.model.arch_type
@property
def n_prom_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.prom_levels
return cfg.model.prom_levels
@property
def n_resp_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.resp_levels
return cfg.model.resp_levels
@property
def n_max_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.max_levels
return cfg.model.max_levels
@property
def n_tasks(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.tasks
return cfg.model.tasks
@property
def n_langs(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.langs
return cfg.model.langs
@property
def n_tones(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.tones
return cfg.model.tones
@property
def causal_size(self) -> int:
# 1 for the stop token
# governs how much to shift the logits by
# could *technically* make it work so that it also predicts *ALL* RVQ levels in one step, but experimental.py is the better way to go about it
return 1 # if self.causal else 0
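# (i.e. with causal_size = 1 the logits are compared against the sequence shifted by one position,
#  leaving room for the stop token at the end of the causal "len" sequence)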
@property
def interleave(self) -> bool:
return False
@property
def monolithic(self) -> bool:
return True
@property
def version(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.version
return cfg.model.version
def _prune(self, l: Tensor, stop = None):
if stop is None:
stop = self.stop_token
indices = (l == stop).nonzero()
if len(indices) == 0:
return l
return l[: indices.min().item()]
@staticmethod
def _unsqueeze_list(x_list, axis=-1):
return [x.unsqueeze(dim=axis) for x in x_list]
def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
len_list: list[Tensor] | None = None,
max_steps: int = 1000,
max_levels: int = 0,
max_resp_context: int = -1,
sampling_temperature: float = 1.0,
sampling_min_temperature: float = -1.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_repetition_penalty: float = 1.0,
sampling_repetition_penalty_decay: float = 0.0,
sampling_length_penalty: float = 0.0,
sampling_beam_width: int = 0,
sampling_mirostat_tau: float = 0.0,
sampling_mirostat_eta: float = 0.1,
):
device = text_list[0].device
batch_size = len(text_list)
# is training
if resps_list is not None:
p_len_task = 0.25
n_levels_set = {r.shape[-1] for r in resps_list}
n_levels = next(iter(n_levels_set))
# is training
assert n_levels == self.n_resp_levels
# to-do: make this YAML configurable
def sample_task():
return "len" if random.random() < p_len_task else "tts"
# generate task list to train against
task_list = [ sample_task() for _ in range(batch_size) ]
# determines which RVQ level to target per batch
quant_level_range = [ 0 if self.causal else 1, self.n_resp_levels ]
if cfg.experimental:
# makes higher levels less likely
def generate( lo=0, hi=8 ):
index = lo
p = random.random()
for i in range(lo, hi):
if p < 1.0 / (2 ** i):
index = i
return int(index)
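# (with lo=0 this roughly halves the probability per level: level 0 ~1/2, level 1 ~1/4, level 2 ~1/8, ...)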
quant_levels = [ 0 if task_list[i] == "len" else generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
else:
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
task_list=task_list,
quant_levels=quant_levels,
)
return super().forward(
inputs=inputs,
quant_levels=quant_levels,
)
# NAR
if len_list is not None:
# is NAR
if max_levels == 0:
max_levels = self.n_resp_levels
# fill with mock tokens
prev_list = [ torch.Tensor([ self.stop_token for _ in range(resp_len) ]).to(device=device, dtype=torch.int16) for resp_len in len_list ]
start = True
for n in trange( max_levels, desc="NAR" ):
level = 0 if n == 0 else prev_list[0].shape[-1]
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
break
quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=prev_list,
lang_list=lang_list,
tone_list=tone_list,
quant_levels=quant_levels,
)
logits = super().forward(
inputs=inputs,
quant_levels=quant_levels,
)
resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ]
if n == 0:
prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ]
else:
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
return prev_list
# is AR
sequence_list = [ torch.Tensor([0]).to(device=device,dtype=torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool()
stop_token = 10
task_list = [ "len" for _ in range(batch_size) ]
for n in trange(10, desc="AR"):
len_list = sequence_list
inputs = self.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
len_list=len_list,
task_list=task_list,
quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
)
logits = super().forward(
inputs=inputs,
)
r = [ logit[-1:].argmax(dim=1) for logit in logits ]
# sanitize
for i, token in enumerate(r):
if token > 10:
r[i] = 0
# append tokens
for i, ri in enumerate(r):
if stop_token in ri:
stopped[i] = True
sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
# stop token found
stopped |= r == stop_token
if stopped.all().item():
break
# convert tokens into int
return [ int("".join([ str(token.item()) for token in r if token != stop_token ])) for r in sequence_list ]
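# e.g. (illustrative) a sampled sequence of digit tokens [0, 2, 4, 7, 10] with stop_token == 10
# decodes to int("0247") == 247, i.e. a predicted utterance length of 247 frames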
def example_usage():
cfg.trainer.backend = "local"
cfg.hyperparameters.gradient_accumulation_steps = 1
if cfg.audio_backend == "dac":
cfg.sample_rate = 44_000
from functools import partial
from einops import repeat
from tqdm import tqdm
from ..emb.qnt import decode_to_file, unload_model
from ..engines import Engine
from ..utils import wrapper as ml
import numpy as np
import re
device = "cuda"
# mamba seems to ONLY be used as an AR (any NAR attempt lobotomizes it)
"""
if "mamba" in cfg.model.arch_type:
cfg.model.prom_levels = 1
cfg.model.resp_levels = 1
"""
# cfg.model.loss_factors = {}
def tokenize(content):
return torch.tensor( cfg.tokenizer.encode(content) )
def _load_quants(path) -> Tensor:
qnt = np.load(path, allow_pickle=True)[()]
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
text_list = [
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
]
proms_list = [
qnt[:cfg.dataset.frames_per_second, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
resps_list = [
qnt[:, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
text_list = text_list[:1]
proms_list = proms_list[:1]
resps_list = resps_list[:1]
# retnet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
kwargs = {
'n_text_tokens': 256,
'n_audio_tokens': 1024,
'd_model': 1024, # 256, # 1024, # 1536
'n_heads': 16, # 4, # 16, # 24
'n_layers': 12, # 32
'n_experts': 1,
'p_dropout': 0.1,
'l_padding': 8 if cfg.optimizations.fp8 else 0,
'config': cfg.model
}
"""
try:
kwargs['config'] = cfg.model
except Exception as e:
pass
"""
model = NAR(**kwargs).to(device)
steps = 200
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
learning_rate = cfg.hyperparameters.learning_rate if cfg.cfg_path is not None else None
if cfg.optimizations.dadaptation:
# do not combine the two
if scheduler == "schedulefree":
scheduler = ""
learning_rate = 1.0
if optimizer == "prodigy":
if learning_rate is None:
learning_rate = 1.0
optimizer = ml.Prodigy
elif optimizer == "adagrad":
if learning_rate is None:
learning_rate = 1.0e-2
optimizer = ml.Adagrad
elif optimizer == "adamw":
if learning_rate is None:
learning_rate = 1.0e-4
optimizer = ml.AdamW
elif optimizer == "sdg":
if learning_rate is None:
learning_rate = 1.0e-4
optimizer = ml.SGD
else:
raise ValueError(f"Unrecognized optimizer: {optimizer}")
print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
optimizer = optimizer(model.parameters(), lr=learning_rate)
if scheduler == "schedulefree":
if isinstance(optimizer, ml.AdamW):
scheduler = ml.schedulefree.AdamWScheduleFree
elif isinstance(optimizer, ml.SGD):
scheduler = ml.schedulefree.SGDScheduleFree
else:
scheduler = None
if scheduler is not None:
print("Scheduler:", scheduler)
optimizer = scheduler( model.parameters(), lr = learning_rate )
if cfg.optimizations.replace and cfg.optimizations.linear:
model = ml.replace_linear( model )
if cfg.optimizations.replace and cfg.optimizations.embedding:
model = ml.replace_embedding( model )
engine = Engine(model=model, optimizer=optimizer)
"""
torch.save( {
'module': model.state_dict()
}, f"./data/{cfg.model.arch_type}.pth" )
"""
print(f"NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.inference_mode()
def sample( name, steps=1000 ):
if cfg.audio_backend == "dac" and name == "init":
return
engine.eval()
len_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
resps_list = engine( text_list, proms_list, len_list=len_list, sampling_temperature=0.2 )
for i, o in enumerate(resps_list):
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
unload_model()
def train():
engine.train()
t = trange(steps)
for i in t:
stats = {"step": i}
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
stats |= {"grad_norm": engine.get_global_grad_norm()}
tqdm.write(f"{stats}")
"""
torch.save( {
'module': model.state_dict()
}, f"./data/{cfg.model.arch_type}.pth" )
"""
#sample("init", 5)
train()
sample("final")
if __name__ == "__main__":
example_usage()