reticulating splines
This commit is contained in:
parent
ead3e2f0cb
commit
8d068fa3f9
|
@ -16,6 +16,23 @@ def get_model(config, training=True):
|
|||
|
||||
l_padding = config.input_alignment,
|
||||
|
||||
training = training,
|
||||
config = config,
|
||||
)
|
||||
elif "len" in config.capabilities:
|
||||
from .nar import NAR
|
||||
model = NAR(
|
||||
n_text_tokens=config.text_tokens,
|
||||
n_audio_tokens=config.audio_tokens,
|
||||
d_model=config.dim,
|
||||
n_heads=config.heads,
|
||||
n_layers=config.layers,
|
||||
n_experts=config.experts,
|
||||
|
||||
p_dropout=config.dropout,
|
||||
|
||||
l_padding = config.input_alignment,
|
||||
|
||||
training = training,
|
||||
config = config,
|
||||
)
|
||||
|
|
|
@ -1,3 +1,11 @@
|
|||
"""
|
||||
# an AR + NAR model that handles:
|
||||
* inferencing the primary RVQ level in an autoregressive manner (AR)
|
||||
* inferencing the remaining RVQ levels in parallel (NAR)
|
||||
|
||||
This model can fully handle being trained as a unified model (AR + NAR) or separate models (AR | NAR).
|
||||
It's recommended to train as a unified model, then "distill" knowledge of each tasks separately, just in case.
|
||||
"""
|
||||
from .base import Base, list_to_tensor, Categorical
|
||||
from ..config import cfg
|
||||
|
||||
|
@ -21,7 +29,7 @@ class AR_NAR(Base):
|
|||
|
||||
@property
|
||||
def causal(self):
|
||||
return "ar" in self.capabilities or "len" in self.capabilities
|
||||
return "ar" in self.capabilities
|
||||
|
||||
@property
|
||||
def norm_type(self):
|
||||
|
@ -139,8 +147,7 @@ class AR_NAR(Base):
|
|||
if n_levels == self.n_resp_levels:
|
||||
# to-do: make this YAML configurable
|
||||
def sample_task():
|
||||
p_len_task = 0.25 if "len" in self.capabilities else 0
|
||||
return "len" if random.random() < p_len_task else "tts"
|
||||
return "tts"
|
||||
|
||||
# generate task list to train against
|
||||
task_list = [ sample_task() for _ in range(batch_size) ]
|
||||
|
@ -158,21 +165,20 @@ class AR_NAR(Base):
|
|||
index = i
|
||||
return int(index)
|
||||
|
||||
quant_levels = [ 0 if task_list[i] == "len" else generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
|
||||
quant_levels = [ generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
|
||||
else:
|
||||
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||
quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
|
||||
quant_levels = [ random.randint(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
|
||||
|
||||
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
||||
|
||||
# append stop tokens for AR
|
||||
# could technically do it in the .inputs call
|
||||
if "len" not in self.capabilities:
|
||||
for i in range(batch_size):
|
||||
# only apply stop token for RVQ level 0
|
||||
if quant_levels[i] > 0:
|
||||
continue
|
||||
resps_list[i] = torch.cat([resps_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
|
||||
for i in range(batch_size):
|
||||
# only apply stop token for RVQ level 0
|
||||
if quant_levels[i] > 0:
|
||||
continue
|
||||
resps_list[i] = torch.cat([resps_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
|
||||
|
||||
inputs = self.inputs(
|
||||
text_list=text_list,
|
||||
|
@ -241,94 +247,13 @@ class AR_NAR(Base):
|
|||
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
|
||||
|
||||
return prev_list
|
||||
|
||||
# other NAR
|
||||
if len_list is not None:
|
||||
# is NAR
|
||||
if max_levels == 0:
|
||||
max_levels = self.n_resp_levels
|
||||
|
||||
# fill with mock tokens
|
||||
prev_list = [ torch.Tensor([ self.stop_token for _ in range(resp_len) ]).to(device=device, dtype=torch.int16) for resp_len in len_list ]
|
||||
|
||||
start = True
|
||||
for n in trange( max_levels, desc="NAR" ):
|
||||
level = 0 if n == 0 else prev_list[0].shape[-1]
|
||||
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
|
||||
break
|
||||
|
||||
quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
|
||||
|
||||
inputs = self.inputs(
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=prev_list,
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
|
||||
logits = super().forward(
|
||||
inputs=inputs,
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
|
||||
resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ]
|
||||
|
||||
if n == 0:
|
||||
prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ]
|
||||
else:
|
||||
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
|
||||
|
||||
return prev_list
|
||||
|
||||
# is AR
|
||||
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
|
||||
stopped = torch.zeros(batch_size, device=device).bool()
|
||||
|
||||
stop_token = 10 if "len" in self.capabilities else self.stop_token
|
||||
task_list = [ "len" if "len" in self.capabilities else "tts" for _ in range(batch_size) ]
|
||||
|
||||
if "len" in self.capabilities:
|
||||
sequence_list = [ torch.Tensor([0]).to(device=device,dtype=torch.int16) for _ in range(batch_size) ]
|
||||
for n in trange(10, desc="AR"):
|
||||
len_list = sequence_list
|
||||
|
||||
inputs = self.inputs(
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=resps_list,
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
len_list=len_list,
|
||||
task_list=task_list,
|
||||
quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
|
||||
)
|
||||
|
||||
logits = super().forward(
|
||||
inputs=inputs,
|
||||
)
|
||||
|
||||
r = [ logit[-1:].argmax(dim=1) for logit in logits ]
|
||||
# sanitize
|
||||
for i, token in enumerate(r):
|
||||
if token > 10:
|
||||
r[i] = 0
|
||||
|
||||
# append tokens
|
||||
for i, ri in enumerate(r):
|
||||
if stop_token in ri:
|
||||
stopped[i] = True
|
||||
sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
|
||||
|
||||
# stop token found
|
||||
stopped |= r == stop_token
|
||||
if stopped.all().item():
|
||||
break
|
||||
|
||||
# convert tokens into int
|
||||
return [ int("".join([ str(token.item()) for token in r if token != stop_token ])) for r in sequence_list ]
|
||||
|
||||
stop_token = self.stop_token
|
||||
task_list = [ "tts" for _ in range(batch_size) ]
|
||||
|
||||
recurrent_state = [] if cfg.inference.recurrent_forward else None
|
||||
mirostat = [
|
||||
|
@ -579,17 +504,13 @@ def example_usage():
|
|||
return
|
||||
|
||||
engine.eval()
|
||||
if "len" in cfg.model.capabilities:
|
||||
len_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
|
||||
resps_list = engine( text_list, proms_list, len_list=len_list, sampling_temperature=0.2 )
|
||||
if "ar" in cfg.model.capabilities:
|
||||
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
|
||||
else:
|
||||
if "ar" in cfg.model.capabilities:
|
||||
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
|
||||
else:
|
||||
resps_list = [ qnt[:, 0].to( device ) ]
|
||||
resps_list = [ qnt[:, 0].to( device ) ]
|
||||
|
||||
if "nar" in cfg.model.capabilities:
|
||||
resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
|
||||
if "nar" in cfg.model.capabilities:
|
||||
resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
|
||||
|
||||
for i, o in enumerate(resps_list):
|
||||
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
|
||||
|
|
|
@ -1,3 +1,14 @@
|
|||
"""
|
||||
Core model for handling all VALL-E tasks.
|
||||
This should handle all the "low" level things such as:
|
||||
* parsing inputs to sequences
|
||||
* converting sequences to embeddings
|
||||
* forward pass
|
||||
* processing loss and returning logits
|
||||
|
||||
Additional functionality (preparing inputs, generating full audio) should be delegated to classes that inheret the base model
|
||||
"""
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -100,6 +111,7 @@ class MultiEmbedding(nn.Module):
|
|||
return x_list
|
||||
|
||||
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
|
||||
# _Old, to preserve compat with previous models.
|
||||
class AudioEmbedding_Old(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -115,12 +127,12 @@ class AudioEmbedding_Old(nn.Module):
|
|||
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
|
||||
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
|
||||
|
||||
def forward(self, xi: Tensor, offset: int | None = 0 ) -> Tensor:
|
||||
def forward(self, xi: Tensor, quant_level: Tensor | None = None ) -> Tensor:
|
||||
# prom
|
||||
if offset == 0 and xi.shape[-1] > 1:
|
||||
if quant_level is None and xi.shape[-1] > 1:
|
||||
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
||||
# AR resp
|
||||
elif quant_level == 0:
|
||||
# prom / AR resp
|
||||
elif quant_level is None or quant_level == 0:
|
||||
x = self.embeddings[0]( xi if xi.dim() == 1 else xi[:, 0] )
|
||||
# NAR resp
|
||||
else:
|
||||
|
@ -128,26 +140,26 @@ class AudioEmbedding_Old(nn.Module):
|
|||
|
||||
return x
|
||||
|
||||
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
|
||||
# Mostly to handle some oversights and errors during testing
|
||||
class AudioEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
l_tokens: int, # list of number of tokens (needed because AR resps includes stop token)
|
||||
token_dim: int, # dimensionality of the embedding
|
||||
mode: str, # prom | resp
|
||||
sums: bool = True # whether to sum all previous layers of embeddings to factor in other RVQ bin levels (I do not know which way is better)
|
||||
):
|
||||
super().__init__()
|
||||
# array of embeddings
|
||||
# proms are [0, prom_levels]
|
||||
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR
|
||||
# + resps cannot share the AR and NAR embeddings, since they do encode whether to predict the same level but in the next token or predict in place but the next level
|
||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
|
||||
#
|
||||
self.mode = mode
|
||||
#
|
||||
self.sums = sums
|
||||
|
||||
# maintaining compat is hard
|
||||
def forward(self, xi: Tensor, quant_level: int | Tensor | None = None, offset: int = 0 ) -> Tensor:
|
||||
def forward(self, xi: Tensor, offset: int = 0 ) -> Tensor:
|
||||
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
|
||||
|
||||
if self.sums and quant_level > 0:
|
||||
|
@ -159,6 +171,8 @@ class AudioEmbedding(nn.Module):
|
|||
return x
|
||||
|
||||
class Base(nn.Module):
|
||||
# to-do: clean up this property mess
|
||||
|
||||
@property
|
||||
def causal(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
@ -292,12 +306,10 @@ class Base(nn.Module):
|
|||
else:
|
||||
self.proms_emb = AudioEmbedding(
|
||||
[n_prom_tokens] * self.n_prom_levels, d_model,
|
||||
"prom",
|
||||
sums=self.config.audio_embedding_sums if self.config is not None else True
|
||||
)
|
||||
self.resps_emb = AudioEmbedding(
|
||||
[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
|
||||
"resp:len" if "len" in self.capabilities else "resp",
|
||||
sums=self.config.audio_embedding_sums if self.config is not None else True
|
||||
)
|
||||
|
||||
|
@ -700,16 +712,31 @@ class Base(nn.Module):
|
|||
embedding = self.langs_emb( input )
|
||||
elif name == "prom":
|
||||
# get RVQ level 0, or up to targetted RVQ level inference
|
||||
embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], offset = 0 )
|
||||
if self.version <= 4:
|
||||
embedding = self.proms_emb( input if quant_level == 0 else input[:, :quant_level] )
|
||||
else:
|
||||
if quant_level == 0:
|
||||
embedding = self.proms_emb( input if input.dim() == 1 else input[:, :1], offset = 0 )
|
||||
else:
|
||||
embedding = self.proms_emb( input if input.dim() == 1 else input[:, :quant_level], offset = 0 )
|
||||
elif name == "tone" and self.tones_emb is not None:
|
||||
embedding = self.tones_emb( input )
|
||||
elif name == "resp":
|
||||
if "len" in self.capabilities and quant_level == 0:
|
||||
# fill with "stop" tokens for NAR-only model
|
||||
embedding = self.resps_emb( torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token), offset = 0 )
|
||||
embedding = self.resps_emb(
|
||||
torch.full_like(input if input.dim() == 1 else input[..., 0], self.stop_token),
|
||||
offset = 0
|
||||
)
|
||||
else:
|
||||
# get RVQ level 0, or up to targetted RVQ level inference
|
||||
embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], offset = 0 if quant_level == 0 or "len" in self.capabilities else 1 )
|
||||
if self.version <= 4:
|
||||
embedding = self.resps_emb( input if quant_level == 0 else input[:, :quant_level], quant_level )
|
||||
else:
|
||||
embedding = self.resps_emb(
|
||||
input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
|
||||
offset = 0 if quant_level == 0 or "len" in self.capabilities else 1
|
||||
)
|
||||
elif name == "len" and self.len_emb is not None:
|
||||
embedding = self.len_emb( input )
|
||||
else:
|
||||
|
|
462
vall_e/models/nar.py
Normal file
462
vall_e/models/nar.py
Normal file
|
@ -0,0 +1,462 @@
|
|||
"""
|
||||
A (mostly) NAR model that handles inferencing all RVQ levels in parallel (NAR).
|
||||
I believe Meta's Voicebox does this too (predict the utterance length, then decode in parallel)
|
||||
It *does* have to inference the initial length in an autoregresssive-ish manner (it can technically also be done in parallel)
|
||||
|
||||
Initial experiments show this only really "works" for the a few brief seconds before going to silence. I imagine I need to read more papers or just need to train longer.
|
||||
"""
|
||||
|
||||
from .base import Base, list_to_tensor, Categorical
|
||||
from ..config import cfg
|
||||
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
import random
|
||||
import math
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
from tqdm import trange
|
||||
|
||||
from ..emb.qnt import trim
|
||||
|
||||
class NAR(Base):
|
||||
@property
|
||||
def capabilities(self) -> list[str]:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.capabilities
|
||||
return cfg.model.capabilities
|
||||
|
||||
@property
|
||||
def causal(self):
|
||||
return "len" in self.capabilities
|
||||
|
||||
@property
|
||||
def norm_type(self):
|
||||
return "ln" # if self.n_resp_levels == 1 else "adaln"
|
||||
|
||||
@property
|
||||
def arch_type(self) -> str:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.arch_type
|
||||
return cfg.model.arch_type
|
||||
|
||||
@property
|
||||
def n_prom_levels(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.prom_levels
|
||||
return cfg.model.prom_levels
|
||||
|
||||
@property
|
||||
def n_resp_levels(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.resp_levels
|
||||
return cfg.model.resp_levels
|
||||
|
||||
@property
|
||||
def n_max_levels(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.max_levels
|
||||
return cfg.model.max_levels
|
||||
|
||||
@property
|
||||
def n_tasks(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.tasks
|
||||
return cfg.model.tasks
|
||||
|
||||
@property
|
||||
def n_langs(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.langs
|
||||
return cfg.model.langs
|
||||
|
||||
@property
|
||||
def n_tones(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.tones
|
||||
return cfg.model.tones
|
||||
|
||||
@property
|
||||
def causal_size(self) -> int:
|
||||
# 1 for the stop token
|
||||
# governs how much to shift the logits by
|
||||
# could *technically* make it work to where it can also predict *ALL* RVQ levels in one step, but experimental.py is the better way to go about it
|
||||
return 1 # if self.causal else 0
|
||||
|
||||
@property
|
||||
def interleave(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def monolithic(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def version(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.version
|
||||
return cfg.model.version
|
||||
|
||||
def _prune(self, l: Tensor, stop = None):
|
||||
if stop is None:
|
||||
stop = self.stop_token
|
||||
indices = (l == stop).nonzero()
|
||||
if len(indices) == 0:
|
||||
return l
|
||||
return l[: indices.min().item()]
|
||||
|
||||
@staticmethod
|
||||
def _unsqueeze_list(x_list, axis=-1):
|
||||
return [x.unsqueeze(dim=axis) for x in x_list]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_list: list[Tensor],
|
||||
proms_list: list[Tensor],
|
||||
resps_list: list[Tensor] | None = None,
|
||||
|
||||
lang_list: list[Tensor] | None = None,
|
||||
tone_list: list[Tensor] | None = None,
|
||||
len_list: list[Tensor] | None = None,
|
||||
|
||||
max_steps: int = 1000,
|
||||
max_levels: int = 0,
|
||||
max_resp_context: int = -1,
|
||||
|
||||
sampling_temperature: float = 1.0,
|
||||
sampling_min_temperature: float = -1.0,
|
||||
sampling_top_k: int = -100,
|
||||
sampling_top_p: float = 1.0,
|
||||
sampling_repetition_penalty: float = 1.0,
|
||||
sampling_repetition_penalty_decay: float = 0.0,
|
||||
sampling_length_penalty: float = 0.0,
|
||||
sampling_beam_width: int = 0,
|
||||
sampling_mirostat_tau: float = 0.0,
|
||||
sampling_mirostat_eta: float = 0.1,
|
||||
):
|
||||
device = text_list[0].device
|
||||
batch_size = len(text_list)
|
||||
|
||||
# is training
|
||||
if resps_list is not None:
|
||||
p_len_task = 0.25
|
||||
|
||||
n_levels_set = {r.shape[-1] for r in resps_list}
|
||||
n_levels = next(iter(n_levels_set))
|
||||
|
||||
# is training
|
||||
assert n_levels == self.n_resp_levels
|
||||
# to-do: make this YAML configurable
|
||||
def sample_task():
|
||||
return "len" if random.random() < p_len_task else "tts"
|
||||
|
||||
# generate task list to train against
|
||||
task_list = [ sample_task() for _ in range(batch_size) ]
|
||||
|
||||
# determines which RVQ level to target per batch
|
||||
quant_level_range = [ 0 if self.causal else 1, self.n_resp_levels ]
|
||||
|
||||
if cfg.experimental:
|
||||
# makes higher levels less likely
|
||||
def generate( lo=0, hi=8 ):
|
||||
index = lo
|
||||
p = random.random()
|
||||
for i in range(lo, hi):
|
||||
if p < 1.0 / (2 ** i):
|
||||
index = i
|
||||
return int(index)
|
||||
|
||||
quant_levels = [ 0 if task_list[i] == "len" else generate(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
|
||||
else:
|
||||
# randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||
quant_levels = [ 0 if task_list[i] == "len" else random.randint(quant_level_range[0], quant_level_range[1]) for i in range(batch_size) ]
|
||||
|
||||
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
|
||||
|
||||
inputs = self.inputs(
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=resps_list,
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
task_list=task_list,
|
||||
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
|
||||
return super().forward(
|
||||
inputs=inputs,
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
|
||||
# NAR
|
||||
if len_list is not None:
|
||||
# is NAR
|
||||
if max_levels == 0:
|
||||
max_levels = self.n_resp_levels
|
||||
|
||||
# fill with mock tokens
|
||||
prev_list = [ torch.Tensor([ self.stop_token for _ in range(resp_len) ]).to(device=device, dtype=torch.int16) for resp_len in len_list ]
|
||||
|
||||
start = True
|
||||
for n in trange( max_levels, desc="NAR" ):
|
||||
level = 0 if n == 0 else prev_list[0].shape[-1]
|
||||
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
|
||||
break
|
||||
|
||||
quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
|
||||
|
||||
inputs = self.inputs(
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=prev_list,
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
|
||||
logits = super().forward(
|
||||
inputs=inputs,
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
|
||||
resps_list = [ logit[-l:].argmax(dim=1) for logit, l in zip(logits, len_list) ]
|
||||
|
||||
if n == 0:
|
||||
prev_list = [ r.unsqueeze(-1).to(device) for r in resps_list ]
|
||||
else:
|
||||
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
|
||||
|
||||
return prev_list
|
||||
|
||||
# is AR
|
||||
sequence_list = [ torch.Tensor([0]).to(device=device,dtype=torch.int16) for _ in range(batch_size) ]
|
||||
stopped = torch.zeros(batch_size, device=device).bool()
|
||||
|
||||
stop_token = 10
|
||||
task_list = [ "len" for _ in range(batch_size) ]
|
||||
|
||||
for n in trange(10, desc="AR"):
|
||||
len_list = sequence_list
|
||||
|
||||
inputs = self.inputs(
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=resps_list,
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
len_list=len_list,
|
||||
task_list=task_list,
|
||||
quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
|
||||
)
|
||||
|
||||
logits = super().forward(
|
||||
inputs=inputs,
|
||||
)
|
||||
|
||||
r = [ logit[-1:].argmax(dim=1) for logit in logits ]
|
||||
# sanitize
|
||||
for i, token in enumerate(r):
|
||||
if token > 10:
|
||||
r[i] = 0
|
||||
|
||||
# append tokens
|
||||
for i, ri in enumerate(r):
|
||||
if stop_token in ri:
|
||||
stopped[i] = True
|
||||
sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
|
||||
|
||||
# stop token found
|
||||
stopped |= r == stop_token
|
||||
if stopped.all().item():
|
||||
break
|
||||
|
||||
# convert tokens into int
|
||||
return [ int("".join([ str(token.item()) for token in r if token != stop_token ])) for r in sequence_list ]
|
||||
|
||||
|
||||
def example_usage():
|
||||
cfg.trainer.backend = "local"
|
||||
cfg.hyperparameters.gradient_accumulation_steps = 1
|
||||
if cfg.audio_backend == "dac":
|
||||
cfg.sample_rate = 44_000
|
||||
|
||||
from functools import partial
|
||||
from einops import repeat
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..emb.qnt import decode_to_file, unload_model
|
||||
from ..engines import Engine
|
||||
from ..utils import wrapper as ml
|
||||
|
||||
import numpy as np
|
||||
import re
|
||||
|
||||
device = "cuda"
|
||||
|
||||
# mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
|
||||
"""
|
||||
if "mamba" in cfg.model.arch_type:
|
||||
cfg.model.prom_levels = 1
|
||||
cfg.model.resp_levels = 1
|
||||
"""
|
||||
# cfg.model.loss_factors = {}
|
||||
|
||||
def tokenize(content):
|
||||
return torch.tensor( cfg.tokenizer.encode(content) )
|
||||
|
||||
def _load_quants(path) -> Tensor:
|
||||
qnt = np.load(path, allow_pickle=True)[()]
|
||||
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.prom_levels, :].t().to(torch.int16)
|
||||
|
||||
qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
|
||||
|
||||
|
||||
text_list = [
|
||||
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
|
||||
#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
|
||||
]
|
||||
proms_list = [
|
||||
qnt[:cfg.dataset.frames_per_second, :].to(device),
|
||||
#qnt[:cfg.dataset.frames_per_second, :].to(device),
|
||||
]
|
||||
resps_list = [
|
||||
qnt[:, :].to(device),
|
||||
#qnt[:cfg.dataset.frames_per_second, :].to(device),
|
||||
]
|
||||
|
||||
text_list = text_list[:1]
|
||||
proms_list = proms_list[:1]
|
||||
resps_list = resps_list[:1]
|
||||
|
||||
# rentet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
|
||||
kwargs = {
|
||||
'n_text_tokens': 256,
|
||||
'n_audio_tokens': 1024,
|
||||
|
||||
'd_model': 1024, # 256, # 1024, # 1536
|
||||
'n_heads': 16, # 4, # 16, # 24
|
||||
'n_layers': 12, # 32
|
||||
'n_experts': 1,
|
||||
|
||||
'p_dropout': 0.1,
|
||||
|
||||
'l_padding': 8 if cfg.optimizations.fp8 else 0,
|
||||
|
||||
'config': cfg.model
|
||||
}
|
||||
|
||||
"""
|
||||
try:
|
||||
kwargs['config'] = cfg.model
|
||||
except Exception as e:
|
||||
pass
|
||||
"""
|
||||
|
||||
model = NAR(**kwargs).to(device)
|
||||
steps = 200
|
||||
|
||||
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
|
||||
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
|
||||
learning_rate = cfg.hyperparameters.learning_rate if cfg.cfg_path is not None else None
|
||||
|
||||
if cfg.optimizations.dadaptation:
|
||||
# do not combine the two
|
||||
if scheduler == "schedulefree":
|
||||
scheduler = ""
|
||||
|
||||
learning_rate = 1.0
|
||||
|
||||
if optimizer == "prodigy":
|
||||
if learning_rate is None:
|
||||
learning_rate = 1.0
|
||||
|
||||
optimizer = ml.Prodigy
|
||||
elif optimizer == "adagrad":
|
||||
if learning_rate is None:
|
||||
learning_rate = 1.0e-2
|
||||
|
||||
optimizer = ml.Adagrad
|
||||
elif optimizer == "adamw":
|
||||
if learning_rate is None:
|
||||
learning_rate = 1.0e-4
|
||||
|
||||
optimizer = ml.AdamW
|
||||
elif optimizer == "sdg":
|
||||
if learning_rate is None:
|
||||
learning_rate = 1.0e-4
|
||||
|
||||
optimizer = ml.SGD
|
||||
else:
|
||||
raise ValueError(f"Unrecognized optimizer: {optimizer}")
|
||||
|
||||
print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
|
||||
|
||||
optimizer = optimizer(model.parameters(), lr=learning_rate)
|
||||
|
||||
if scheduler == "schedulefree":
|
||||
if isinstance(optimizer, ml.AdamW):
|
||||
scheduler = ml.schedulefree.AdamWScheduleFree
|
||||
elif isinstance(optimizer, ml.SGD):
|
||||
scheduler = ml.schedulefree.SGDScheduleFree
|
||||
else:
|
||||
scheduler = None
|
||||
|
||||
if scheduler is not None:
|
||||
print("Scheduler:", scheduler)
|
||||
optimizer = scheduler( model.parameters(), lr = learning_rate )
|
||||
|
||||
if cfg.optimizations.replace and cfg.optimizations.linear:
|
||||
model = ml.replace_linear( model )
|
||||
|
||||
if cfg.optimizations.replace and cfg.optimizations.embedding:
|
||||
model = ml.replace_embedding( model )
|
||||
|
||||
engine = Engine(model=model, optimizer=optimizer)
|
||||
|
||||
"""
|
||||
torch.save( {
|
||||
'module': model.state_dict()
|
||||
}, f"./data/{cfg.model.arch_type}.pth" )
|
||||
"""
|
||||
|
||||
print(f"NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
||||
|
||||
@torch.inference_mode()
|
||||
def sample( name, steps=1000 ):
|
||||
if cfg.audio_backend == "dac" and name == "init":
|
||||
return
|
||||
|
||||
engine.eval()
|
||||
|
||||
len_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
|
||||
resps_list = engine( text_list, proms_list, len_list=len_list, sampling_temperature=0.2 )
|
||||
|
||||
for i, o in enumerate(resps_list):
|
||||
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
|
||||
|
||||
unload_model()
|
||||
|
||||
def train():
|
||||
engine.train()
|
||||
t = trange(steps)
|
||||
for i in t:
|
||||
stats = {"step": i}
|
||||
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
|
||||
stats |= {"grad_norm": engine.get_global_grad_norm()}
|
||||
|
||||
tqdm.write(f"{stats}")
|
||||
|
||||
"""
|
||||
torch.save( {
|
||||
'module': model.state_dict()
|
||||
}, f"./data/{cfg.model.arch_type}.pth" )
|
||||
"""
|
||||
|
||||
#sample("init", 5)
|
||||
train()
|
||||
sample("final")
|
||||
|
||||
if __name__ == "__main__":
|
||||
example_usage()
|
Loading…
Reference in New Issue
Block a user