# vall-e/vall_e/models/base.py
#
# 1723 lines
# 72 KiB
# Python

"""
Core model for handling all VALL-E tasks.
This should handle all the "low" level things such as:
* parsing inputs to sequences
* converting sequences to embeddings
* forward pass
* processing loss and returning logits
Additional functionality (preparing inputs, generating full audio) should be delegated to classes that inherit from the base model
"""
# to-do: clean this whole mess up
import math
import torch
import torch.nn.functional as F
import random
import numpy as np
import re
from time import perf_counter
from collections import namedtuple
from typing import Literal, overload, Optional, Tuple
from functools import partial
from einops import rearrange
from torch import Tensor, einsum, nn
from torch.nn import Embedding
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
from .arch import *
from ..utils import ml, clamp
from ..samplers import *
# yuck, kind of needed
from ..data import get_task_symmap
import logging
_logger = logging.getLogger(__name__)
# these seem more elegant than a dict
# as built by `Base._forward`, `logits` holds the backbone's last hidden state
Logits = namedtuple('Logits', 'logits state inputs loss attentions hidden_states')
# presumably the result of a sampling step (token ids plus per-step diagnostics)
Sampled = namedtuple('Sampled', 'ids logits scores entropy')
# presumably loss terms paired with their accompanying statistics
LossStats = namedtuple('LossStats', 'loss stats')

# tasks whose response input is embedded by summing every RVQ level
summed_embeddings_task = [ "stt" ]
# tasks that do not follow the base-line TTS sequence layout
special_tasks = [ "len", "stt", "phn", "text", "un-phn" ]
# input entries that steer processing but are never embedded as tokens
non_tokened_names = ["task", "dropout_mask", "classifier_level"]
# task name => name of the input segment the task predicts
task_outputs = {
    "tts": "resp",
    "ns": "resp",
    "sr": "resp",
    "stt": "phn",
    "len": "len",
    "phn": "phn",
    "un-phn": "text",
}
def _get_offsets():
    """
    Return the half-open (start, end) token range that each input type occupies
    in the unified vocabulary, keyed by input name, in vocabulary order.

    The ranges are contiguous and together span 17702 tokens.
    """
    # (name, token count) pairs laid out back-to-back
    layout = (
        [ ("phn", 256), ("quant_level", 8), ("lang", 6), ("task", 9), ("len", 11), ("tone", 1), ("sep", 1) ]
        + [ (f"prom|{level}", 1024) for level in range(8) ]
        # one extra slot for the stop/mask token (see n_resp_tokens in Base.__init__)
        + [ ("resps|AR:0:0", 1025) ]
        + [ (f"resps|NAR:{level}:{level + 1}", 1024) for level in range(7) ]
        + [ ("resps|NAR:0:0", 1025) ]
    )
    offsets = {}
    cursor = 0
    for name, size in layout:
        offsets[name] = (cursor, cursor + size)
        cursor += size
    return offsets
def _dropout_mask( input, p ):
    """Sample a boolean mask over the first dimension of `input`; each entry is True with probability `p`."""
    count = input.shape[0]
    draws = torch.rand(count, device=input.device)
    return draws < p
def _create_mask(l, device):
    """Build a (b, t) float mask from per-item lengths `l`: 1.0 marks valid positions, 0.0 padding."""
    lengths = torch.tensor(l, device=device)[:, None]  # (b 1)
    positions = torch.arange(max(l), device=device)[None, :]  # (1 t)
    return (positions < lengths).float()  # (b t)
def _join(x: tuple[Tensor], sep: Tensor):
    """
    Concatenate segments along dim 0, inserting `sep` between consecutive segments.

    Args:
        x: (k t d) — a sequence of k segments
        sep: (d) separator vector
    Returns:
        The single segment itself when k == 1, otherwise the joined tensor.
    """
    head = x[0]
    if len(x) == 1:
        return head
    pieces = [ head ]
    for segment in x[1:]:
        pieces.append( sep[None] )
        pieces.append( segment )
    return torch.cat( pieces, dim=0 )
def list_to_tensor(x_list: list[Tensor]):
    """
    Right-pad a list of sequences into one batch tensor and build the matching mask.

    Returns:
        (x, m): x is the zero-padded batch (batch_first), m is a (b, t) int mask
        with 1 on valid positions and 0 on padding.
    """
    lengths = [ len(seq) for seq in x_list ]
    padded = pad_sequence(x_list, batch_first=True)
    mask = _create_mask(lengths, x_list[0].device).to(padded).int()
    return padded, mask
def _interleave_sequence_reshape( input: list[torch.Tensor], dim=-1 ):
    """
    Interleave the rows of k tensors.

    For 2-D inputs of identical shape (t, d) and dim=-1, output row j*k + m is
    row j of input[m], giving a (t*k, d) tensor.
    """
    first = input[0]
    out_shape = ( first.shape[0] * len(input), first.shape[dim] )
    transposed = [ tensor.t() for tensor in input ]
    stacked = torch.concat( transposed )
    return stacked.t().reshape( out_shape )
def _interleave_sequence_flatten( input: list[torch.Tensor] ):
    """
    Interleave k tensors and flatten the result to 1-D.

    For 2-D inputs of shape (t, d) this yields row 0 of every input, then row 1
    of every input, and so on; 1-D inputs are simply concatenated, because
    `.t()` leaves them unchanged.
    """
    transposed = [ tensor.t() for tensor in input ]
    stacked = torch.concat( transposed )
    return stacked.t().flatten()
# Embedding that sums each codebook level within a given input acoustic prompt
# Mostly to handle some oversights and errors during testing
class AudioEmbedding(nn.Module):
def __init__(
self,
l_embedding_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
token_dim: int, # dimensionality of the embedding
sums: bool = True, # whether to sum all previous layers of embeddings to factor in other codebook levels (I do not know which way is better)
l_embedding_names: list[str] = [], # names to map to indices
):
super().__init__()
# array of embeddings
# proms are [0, resp_levels]
# resp are split to where [0] is for the AR, and [1:] are reserved for NAR (except [-1] for NAR-len if utilized)
self.embeddings = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for n_tokens in l_embedding_tokens])
# further experimentation is needed to see if this actually is useful
self.sums = sums
# index of name maps to its corresponding embedding in the list
self.names = l_embedding_names
def forward(
self,
xi: Tensor, # input tensor
offset: int | None = None, # explicit offset, interop for the older codebase. use `name` instead
quant_level: int | None = None, # the codebook level of the audio we currently have (our `input_quant_level`)
name: str | None = None, # specifies where in the embeddings list to start from and iterate through
sums = None
) -> Tensor:
# if not explicitly requested, use the default setting at instantiation time
if sums is None:
sums = self.sums
# if not explicitly requested, assume input quant_level based on shape
if quant_level is None:
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
# handle mapping embedding index offset
if name in self.names:
offset = self.names.index( name )
offset -= quant_level # offset by quant_level since it'll iterate up that many levels
# sum all prior codebook levels if requested (as quant_level = 0 does not have any other codebooks to sum through)
if sums and quant_level > 0:
x = sum( [ self.embeddings[input_quant_level + offset]( xi[:, input_quant_level] ) for input_quant_level in range( quant_level + 1 ) ] )
else:
input_quant_level = quant_level
x = self.embeddings[input_quant_level + offset]( xi if xi.dim() == 1 else xi[:, input_quant_level] )
return x
# per-level classification
# it might actually be "better" in the long run to only have one output head like a traditional LM, and just de-stitch it here instead of doing modulus math and whatever like the HF/experimental impl
class Classifiers(nn.Module):
def __init__(
self,
l_embedding_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
l_embedding_names: list[str], # list of names to map to each classifier,
d_model: int, # dimensionality of the embedding
bias: bool = True,
):
super().__init__()
self.proj = nn.ModuleList([nn.Linear(d_model, n_tokens, bias=bias) for n_tokens in l_embedding_tokens])
self.names = l_embedding_names
def indices(
self,
names
):
if isinstance( names[-1], int ):
return names
return [ self.names.index(name) for name in names ]
def forward(
self,
xi: Tensor,
levels: list[int] | None = None,
names: list[str] | None = None,
stack = False,
) -> Tensor:
dtype = xi[0].dtype
device = xi[0].device
if levels and isinstance( levels[-1], str ):
names = levels
levels = []
# map names to levels
if names and not levels:
levels = [ None if name not in self.names else self.names.index(name) for name in names ]
xi = [ x if l == None else self.proj[l]( x ) for x, l in zip(xi, levels) ]
if not stack:
return xi
# pad if needed
# to-do: validate that this causes ZERO issues
# addendum: this does cause problems
max_size = max([ x.shape[-1] for x in xi ])
xi = [
#x if l == 0 else
x if x.shape[-1] == max_size else
torch.cat( [x, torch.full( (x.shape[0], max_size - x.shape[-1]), -float("inf"), device=device, dtype=dtype) ], dim=-1 )
for x, l in zip(xi, levels)
]
return torch.stack( xi )
def _dropout_codes( x, dropout_mask, dropout_token, swapped=False ):
    """
    Return a detached copy of the code tensor `x` with masked positions replaced.

    For every index of the last dimension, positions where `dropout_mask` is True
    become `dropout_token`. When `swapped` is set, the original code is kept where
    True and `dropout_token` is used where False. `x` itself is not modified.
    """
    out = x.clone().detach()
    for level in range( out.shape[-1] ):
        column = out[..., level]
        if swapped:
            out[..., level] = torch.where( dropout_mask, column, dropout_token )
        else:
            out[..., level] = torch.where( dropout_mask, dropout_token, column )
    return out
class Metrics(nn.Module):
def __init__(
self,
l_embedding_tokens: int | list[int],
top_k = 10,
average="micro",
multidim_average="global",
ignore_index = -100
):
super().__init__()
self.accuracy = nn.ModuleList([ MulticlassAccuracy(
n_tokens,
top_k=top_k,
average=average,
multidim_average=multidim_average,
ignore_index=ignore_index,
) for n_tokens in l_embedding_tokens ])
self.precision = nn.ModuleList([ MulticlassPrecision(
n_tokens,
top_k=top_k,
average=average,
multidim_average=multidim_average,
ignore_index=ignore_index,
) for n_tokens in l_embedding_tokens ])
def calc_accuracy( self, inputs, targets, classifier_levels ):
return sum( [ self.accuracy[l]( input[:, :self.accuracy[l].num_classes], target ) for target, input, l in zip( targets, inputs, classifier_levels ) ] ) / len( inputs )
def calc_precision( self, inputs, targets, classifier_levels ):
return sum( [ self.precision[l]( input[:, :self.precision[l].num_classes], target ) for target, input, l in zip( targets, inputs, classifier_levels ) ] ) / len( inputs )
def __call__(self, *args, **kwargs):
return dict(
acc=self.calc_accuracy(*args, **kwargs),
)
class Base(nn.Module):
def loss_factor(self, k):
    """Return the loss weight for term `k` from the config, or 1.0 when no config is attached."""
    config = self.config
    return 1.0 if config is None else config.loss_factor(k)
def _prune(self, l: Tensor, stop = None):
    """Truncate `l` just before the first occurrence of `stop` (default: `self.stop_token`); return `l` unchanged if absent."""
    stop_value = self.stop_token if stop is None else stop
    hits = (l == stop_value).nonzero()
    if len(hits) > 0:
        return l[: hits.min().item()]
    return l
def __init__(
    self,
    n_phn_tokens: int = 256,
    n_audio_tokens: int = 1024,
    n_text_tokens: int = 8575,
    d_model: int = 512,
    d_ffn: int = 4,
    n_heads: int = 8,
    n_layers: int = 12,
    p_dropout: float = 0.1,
    n_experts: int = 1,
    l_padding: int = 0,
    training = True,
    attention = None,
    config = None,
    n_resp_levels: int = 8,
):
    """
    Build the embedding tables, the Llama backbone and the output head(s).

    When `config` is given, most settings (levels, capabilities, version, arch,
    attention, `config.experimental.*`) are read from it; the keyword arguments
    and hard-coded fallbacks apply when `config` is None.

    Args:
        n_phn_tokens: phoneme vocabulary size.
        n_audio_tokens: codebook size per RVQ level; this value doubles as the stop/mask token.
        n_text_tokens: raw-text vocabulary size.
        d_model: hidden size.
        d_ffn: feed-forward multiplier (intermediate size = d_model * d_ffn).
        n_heads: attention heads (also used as key/value heads).
        n_layers: decoder layers.
        p_dropout: attention dropout, applied only when `training`.
        n_experts: number of experts; >1 enables router logits while training.
        l_padding: requested padding length (overridden by some attention backends).
        training: initial training flag.
        attention: attention backend; falls back to `config.attention`, then "auto".
        config: model config object, or None.
        n_resp_levels: number of RVQ levels when `config` is None
            (previously undefined in that case, raising NameError).

    Raises:
        The ERROR_ARCHES entry for an unavailable arch type.
        ValueError: if a non-HF attention backend is not in AVAILABLE_ATTENTIONS.
    """
    super().__init__()
    self.training = training
    self.teaching = False
    self.config = config
    self.n_phn_tokens = n_phn_tokens
    self.n_audio_tokens = n_audio_tokens
    self.n_text_tokens = n_text_tokens
    self.d_model = d_model
    self.n_heads = n_heads
    self.n_layers = n_layers
    self.n_experts = n_experts
    self.l_padding = l_padding
    self.ignore_index = -100
    self.n_resp_levels = self.config.resp_levels if self.config else n_resp_levels
    self.n_max_levels = self.config.max_levels if self.config else n_resp_levels
    self.capabilities = self.config.capabilities if self.config else ["ar", "nar", "len"]
    self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
    self.stop_token = self.n_audio_tokens
    self.mask_token = self.stop_token
    self.causal = True
    self.version = self.config.version if self.config is not None else 6
    self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
    self.arch_type = self.config.arch_type if self.config is not None else "llama"

    # check if requested arch is unavailable
    if self.arch_type in ERROR_ARCHES:
        raise ERROR_ARCHES[self.arch_type]

    if not attention:
        attention = self.config.attention if self.config is not None else "auto"

    # crunge
    if self.config is not None and config.teacher:
        self.teaching = True
        self.training = False

    attention_backend = attention
    audio_embedding_sums = self.config.experimental.audio_embedding_sums if self.config is not None else False
    split_classifiers = self.config.experimental.split_classifiers if self.config is not None else False
    # read but currently unused in this method
    tie_classifier_to_embedding = self.config.experimental.tie_classifier_to_embedding if self.config is not None else False
    audio_embedding_mode = self.config.experimental.audio_embedding_mode if self.config is not None else ""
    unified_position_ids = self.config.experimental.unified_position_ids if self.config is not None else True
    #interleave = self.config.experimental.interleave if self.config is not None else False
    noncausal_masks = self.config.experimental.noncausal_masks if self.config is not None else False
    classifiers_bias = self.config.experimental.classifiers_bias if self.config is not None else False
    max_position_embeddings = self.config.experimental.max_position_embeddings if self.config is not None else (75 * 60 * 5)
    masking_ratio = self.config.experimental.masking_ratio if self.config is not None else False
    ignore_inputs_for_loss = self.config.experimental.ignore_inputs_for_loss if self.config is not None else False
    resp_parallel_training = self.config.experimental.resp_parallel_training if self.config is not None else True
    predict_causally = self.config.experimental.predict_causally if self.config is not None else False
    # read but currently unused in this method
    monolithic_audio_encoder = self.config.experimental.monolithic_audio_encoder if self.config is not None else False
    self.resp_parallel_training = resp_parallel_training
    self.predict_causally = predict_causally

    n_tasks = self.config.tasks if self.config is not None else 8
    n_langs = self.config.langs if self.config is not None else 2
    n_tones = self.config.tones if self.config is not None else 1

    # one extra token (the stop/mask token) when causal
    n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
    l_embedding_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
    l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens - 1]
    l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )] + ['NAR:0:0']
    n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1

    # copy: the classifier-only heads appended below must not leak into the
    # embedding names handed to `resps_emb`
    l_classifier_names = list(l_embedding_names)
    # STT
    l_classifier_names += [ "stt" ]
    l_classifier_tokens += [ n_phn_tokens ]
    # LEN
    if "len" in self.capabilities:
        l_classifier_tokens += [ 11 ]
        l_classifier_names += ["len"]
    # TEXT => PHN / PHN => TEXT
    if self.version >= 6:
        l_classifier_tokens += [ n_text_tokens ]
        l_classifier_names += [ "text" ]

    self.n_vocab = n_vocab
    self.unified_position_ids = unified_position_ids
    self.inject_timestep_embedding = False # results in bad output
    self.masking_ratio = masking_ratio
    self.ignore_inputs_for_loss = ignore_inputs_for_loss
    self.noncausal_masks = noncausal_masks

    self.text_emb = Embedding(n_phn_tokens, d_model)
    self.raw_text_emb = None
    self.langs_emb = None
    self.tones_emb = None
    self.tasks_emb = None
    self.rvq_l_emb = None
    self.len_emb = None
    # never created here, but read by `inputs()` / `inputs_to_embeddings()`
    self.time_emb = None

    # it would be nicer for these to be a token or live inside an embedding
    self.sep = nn.Parameter(torch.randn(d_model))
    self.dropout_token = nn.Parameter(torch.randn(d_model))

    self.proms_emb = AudioEmbedding(
        [n_audio_tokens] * self.n_resp_levels, d_model,
        sums=audio_embedding_sums == "prom" or audio_embedding_sums == True,
    )
    self.resps_emb = AudioEmbedding(
        l_embedding_tokens, d_model,
        sums=audio_embedding_sums == "resp" or audio_embedding_sums == True,
        l_embedding_names=l_embedding_names,
    )
    self.langs_emb = Embedding(n_langs, d_model) if n_langs > 0 else None
    self.tasks_emb = Embedding(n_tasks, d_model) if n_tasks > 0 else None
    # NOTE(review): with a config this extends config.capabilities in place (kept, as
    # other code may read it); guarded so repeated instantiation doesn't duplicate it
    if "lang" not in self.capabilities:
        self.capabilities += ["lang"]
    # never actually got added... I kept forgetting to classify all my audio for speaker's tone
    self.tones_emb = Embedding(n_tones, d_model) if n_tones > 0 else None
    self.rvq_l_emb = Embedding(self.n_resp_levels, d_model)
    self.len_emb = Embedding(11, d_model)
    self.raw_text_emb = Embedding(self.n_text_tokens, d_model)

    if attention_backend == "auto":
        attention_backend = "sdpa"

    hf_attention = attention_backend
    HF_ATTENTIONS = ["eager", "sdpa", "flash_attention_2"]
    # only non-HF backends need to be validated against what this install provides
    if attention_backend not in HF_ATTENTIONS:
        hf_attention = None
        if attention_backend not in AVAILABLE_ATTENTIONS:
            raise ValueError(f"Requesting attention `{attention_backend}` but is not available. Currently available: {AVAILABLE_ATTENTIONS}")

    # override any requested padding size
    if attention_backend == "flash_attn_v100":
        self.l_padding = 32
    elif attention_backend == "fused_attn":
        self.l_padding = 128

    self.model_config = LlamaConfig(
        vocab_size=n_vocab,
        hidden_size=d_model,
        max_position_embeddings=max_position_embeddings,
        intermediate_size=d_model*d_ffn,
        num_hidden_layers=n_layers,
        num_attention_heads=n_heads,
        attention_dropout=p_dropout if training else 0.0,
        num_key_value_heads=n_heads,
        hidden_act="gelu",
        is_encoder_decoder=False,
        is_decoder=True,
        #gradient_checkpointing=self.gradient_checkpointing,
    )
    self.model_config.attn_mode = attention_backend
    self.model = LlamaModel(self.model_config)

    if not split_classifiers:
        self.classifier = nn.Linear(d_model, n_vocab, bias=classifiers_bias)
        self.classifiers = None
        self.metrics = None
    else:
        self.classifier = None
        self.classifiers = Classifiers( l_classifier_tokens, l_classifier_names, d_model, bias=classifiers_bias )
        self.metrics = Metrics( l_classifier_tokens )
def _forward(
    self,
    inputs,
    mask = None,
    is_causal = None,
    position_ids = None,
    state = None,
    output_attentions = False,
    output_hidden_states = False,
):
    """
    Run the backbone on already-embedded inputs.

    Args:
        inputs: input embeddings, passed as `inputs_embeds`.
        mask: attention mask, passed through unchanged.
        is_causal, position_ids: passed through to the backbone.
        state: past key/values; when not None, it is replaced by the backbone's returned cache.
        output_attentions / output_hidden_states: request those extra outputs.

    Returns:
        Logits(x, state, inputs, aux_loss, attentions, hidden_states), where `x`
        is the backbone's last hidden state, `aux_loss` is the MoE load-balancing
        loss (only with n_experts > 1 while training), and `hidden_states` holds
        one normalized tensor per layer.
    """
    x = inputs
    m = mask #.squeeze(-1).int()
    aux_loss = None
    attentions = None
    hidden_states = None

    # HF transformer derived model
    kwargs = dict(
        inputs_embeds=x,
        attention_mask=m,
        past_key_values=state,
        position_ids=position_ids,
        use_cache=False, # not self.training,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
        is_causal=is_causal,
    )
    if self.n_experts > 1 and self.training:
        kwargs["output_router_logits"] = True

    output = self.model(**kwargs)
    x = output["last_hidden_state"]

    # to-do: figure out why KV caching doesn't work
    #if not self.training:
    if state is not None:
        state = output["past_key_values"]
    if output_attentions:
        attentions = output["attentions"]
    if output_hidden_states:
        hidden_states = output["hidden_states"]
    if self.n_experts > 1 and self.training:
        router_logits = output["router_logits"]
        aux_loss = self.model.config.router_aux_loss_coef * load_balancing_loss_func( router_logits, self.model.config.num_local_experts, self.model.config.num_experts_per_tok, m )

    # process it into a format that I like
    if output_hidden_states:
        # hidden_states is actually layers + 1, as hidden_states[0] == embedding;
        # drop it so that index i is layer i's output
        per_layer = list( hidden_states[1:] )
        # normalize each layer's own state, except the last, which is `x` and already normalized
        # (this previously normalized output.hidden_states[i], i.e. the *previous* layer's state)
        hidden_states = [
            x if i == self.n_layers - 1 else self.model.norm( layer_state )
            for i, layer_state in enumerate( per_layer )
        ]

    return Logits(x, state, inputs, aux_loss, attentions, hidden_states)
# takes a bunch of separate lists and parses them into an ordered array of tuples to guide input sequence creation
def inputs(
    self,
    phns_list: list[Tensor] | None = None,
    text_list: list[Tensor] | None = None,
    proms_list: list[Tensor] | None = None,
    resps_list: list[Tensor] | None = None,
    lang_list: list[Tensor] | None = None,
    tone_list: list[Tensor] | None = None,
    len_list: list[Tensor] | None = None,
    task_list: list[str] | None = None,
    time_list: list[Tensor] | None = None,
    quant_levels: int | list[int] | Tensor | None = None
):
    """
    Parse per-field lists into an ordered list of (name, value) tuples per batch item.

    The layout of each item depends on its task (default "tts"). Besides token
    segments, items carry bookkeeping entries: "task", "classifier_level" and,
    when training with a timestep, "dropout_mask".

    Args:
        *_list: per-item values (None entries are skipped).
        quant_levels: target RVQ level per item, or a single int for every item; 0 when None.

    Returns:
        list (one per batch item) of lists of (name, value) tuples.

    Raises:
        ValueError: if none of phns/text/proms/resps lists is provided, so the batch size can't be inferred.
        Exception: for an unrecognized task, or `len` without a length embedding.
    """
    # infer batch size and device from the first populated source list
    device = None
    batch_size = None
    for source in ( phns_list, text_list, proms_list, resps_list ):
        if source and source[0] is not None:
            device = source[0].device
            batch_size = len(source)
            break
    if batch_size is None:
        raise ValueError("Unable to infer batch size: provide at least one of phns_list, text_list, proms_list or resps_list")

    # a single int applies to every batch item
    if isinstance( quant_levels, int ):
        quant_levels = [ quant_levels ] * batch_size

    inputs = [ [] for _ in range(batch_size) ]
    for i in range(batch_size):
        quant_level = quant_levels[i] if quant_levels is not None else 0
        task_type = task_list[i] if task_list is not None else "tts"
        timestep = time_list[i] if time_list is not None else None
        classifier_level = None

        # insert task type as a string
        inputs[i].append( ( "task", task_type ) )

        # to-do: maybe not split the below blocks up
        # might be beneficial in the event I need to use a difference sequence, such as STT tasks

        # Base-line TTS task
        # Sequence: <text><sep><lang><sep><rvq lvl><sep><prom><sep><tone><sep><resp> (lang/tone only when provided)
        # prom /may/ include <task> tokens inside to help guide things, per SpeechX
        if task_type in get_task_symmap() and task_type not in special_tasks:
            # insert the text prompt
            if phns_list is not None and phns_list[i] is not None:
                inputs[i].append( ( "phn", phns_list[i] ) )
            elif text_list is not None and text_list[i] is not None:
                inputs[i].append( ( "text", text_list[i] ) )
            # insert lang token if we're trained for it
            if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
                inputs[i].append( ( "lang", lang_list[i] ) )
            # insert RVQ level guidance token if the model is versioned for it
            if self.rvq_l_emb is not None:
                inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
                classifier_level = "AR:0:0" if quant_level == 0 else f'NAR:{quant_level-1}:{quant_level}'
            # insert input audio prompt
            if proms_list is not None and proms_list[i] is not None:
                inputs[i].append( ( "prom", proms_list[i] ) )
            # insert tone token if we're trained for it
            if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
                inputs[i].append( ( "tone", tone_list[i] ) )
            # insert timestep token
            if timestep is not None:
                # force set to use this classifier level
                classifier_level = "NAR:0:0"
                # store timestep information
                if self.masking_ratio in ["random", "rand"]:
                    # cosine scheduled timestep => masking ratio
                    p = math.cos(timestep * math.pi * 0.5)
                    # I don't think is is necessary as the timestep is encoded in the sequence by the number of masked tokens, probably.
                    if self.inject_timestep_embedding:
                        inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
                else:
                    # a paper said to use a fixed masking ratio of 0.8 for training
                    # ...but I want to make it user adjustable
                    p = self.masking_ratio
                # store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
                if self.training:
                    dropout_mask = _dropout_mask( resps_list[i], p )
                    inputs[i].append( ("dropout_mask", dropout_mask ) )
            # insert the current output response
            if resps_list is not None and resps_list[i] is not None:
                inputs[i].append( ( "resp", resps_list[i] ) )
            inputs[i].append( ("classifier_level", classifier_level) )

        # Audio length prediction task
        # Sequence: <text><sep><lang><sep><rvq lvl><sep><prom><sep><tone><sep><len>
        elif task_type == "len":
            # throw an error so we don't silently train without this
            if self.len_emb is None:
                raise Exception(f"Requesting task `{task_type}` but corresponding embedding is not defined.")
            # insert the text prompt
            if phns_list is not None and phns_list[i] is not None:
                inputs[i].append( ( "phn", phns_list[i] ) )
            elif text_list is not None and text_list[i] is not None:
                inputs[i].append( ( "text", text_list[i] ) )
            # insert lang token if we're trained for it
            if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
                inputs[i].append( ( "lang", lang_list[i] ) )
            # technically will always be level 0 but for the sake of keeing the input formatting coherent...
            if self.rvq_l_emb is not None:
                inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
            # insert input audio prompt
            if proms_list is not None and proms_list[i] is not None:
                inputs[i].append( ( "prom", proms_list[i] ) )
            # insert tone token if we're trained for it
            if "tone" in self.capabilities and tone_list is not None and tone_list[i] is not None:
                inputs[i].append( ( "tone", tone_list[i] ) )
            # insert output length tokens (if it exists)
            if len_list is not None and len_list[i] is not None:
                inputs[i].append( ( "len", len_list[i] ) )
            # "encode" length to tokens for 0-9 + stop
            elif resps_list is not None and resps_list[i] is not None:
                # yes this could be encoded better: a leading 0, the decimal digits of the frame count, then 10 as stop
                digits = [ int(digit) for digit in str( resps_list[i].shape[0] ) ]
                inputs[i].append( ( "len", torch.tensor([ 0 ] + digits + [ 10 ], device=device, dtype=torch.int16) ) )
            inputs[i].append( ("classifier_level", "len") )

        # Speech-to-Text prediction task
        # Sequence: <resp><sep><lang><sep><rvq lvl><sep><phonemes>
        elif task_type == "stt":
            # insert the input response
            if resps_list is not None and resps_list[i] is not None:
                inputs[i].append( ( "resp", resps_list[i] ) )
            # insert lang token if we're trained for it
            if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
                inputs[i].append( ( "lang", lang_list[i] ) )
            # insert RVQ level guidance token if the model is versioned for it
            if self.rvq_l_emb is not None:
                inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
            # insert the output phonemes
            if phns_list is not None and phns_list[i] is not None:
                inputs[i].append( ( "phn", phns_list[i] ) )
            inputs[i].append( ("classifier_level", "phn") )

        # Text phonemizing task
        # Sequence: <text><sep><lang><sep><rvq lvl><sep><phonemes>
        elif task_type == "phn":
            # insert the input text
            if text_list is not None and text_list[i] is not None:
                inputs[i].append( ( "text", text_list[i] ) )
            # insert lang token if we're trained for it
            if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
                inputs[i].append( ( "lang", lang_list[i] ) )
            if self.rvq_l_emb is not None:
                inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
            # insert the output phonemes
            if phns_list is not None and phns_list[i] is not None:
                inputs[i].append( ( "phn", phns_list[i] ) )
            inputs[i].append( ("classifier_level", "phn") )

        # Text de-phonemizing task
        # Sequence: <phonemes><sep><lang><sep><rvq lvl><sep><text>
        elif task_type == "un-phn":
            # insert the input phonemes
            if phns_list is not None and phns_list[i] is not None:
                inputs[i].append( ( "phn", phns_list[i] ) )
            # insert lang token if we're trained for it
            if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
                inputs[i].append( ( "lang", lang_list[i] ) )
            if self.rvq_l_emb is not None:
                inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
            # insert the output text
            if text_list is not None and text_list[i] is not None:
                inputs[i].append( ( "text", text_list[i] ) )
            inputs[i].append( ("classifier_level", "text") )

        else:
            raise Exception(f'Unrecognized task: {task_type}')

    return inputs
def offset_inputs(
    self,
    inputs: list,
    direction: int = 1, # -1 to de-offset
):
    """
    Shift token ids in place into (direction=1) or out of (direction=-1) their
    ranges in the unified vocabulary returned by `_get_offsets()`.

    Tensor entries whose name maps to a range are shifted by that range's start.
    "prom" is keyed by the item's quant level, and "resp" by its classifier level.
    Non-tensor entries and unmapped names are left alone.

    Returns:
        `inputs` (same object, mutated in place).
    """
    offsets = _get_offsets()
    level_start, level_end = offsets["quant_level"]
    for batch_input in inputs:
        quant_level = None
        classifier_level = None
        # pre-iterate
        for name, input in batch_input:
            if name == "quant_level":
                quant_level = input
            elif name == "classifier_level":
                classifier_level = input
        # the quant level is stored as a 1-element tensor; formatting that directly
        # into an f-string yields "tensor([...])", so no "prom|N" key ever matched
        if isinstance( quant_level, torch.Tensor ):
            quant_level = int( quant_level.flatten()[0].item() )
            # undo a previous offset of the quant_level token itself (e.g. when de-offsetting)
            if level_start <= quant_level < level_end:
                quant_level -= level_start
        for name, input in batch_input:
            if not isinstance( input, torch.Tensor ):
                continue
            if name == "prom":
                k = f'prom|{quant_level}'
            elif name == "resp":
                k = f'resps|{classifier_level}'
            else:
                k = name
            if k not in offsets:
                continue
            start, _ = offsets[k]
            # in-place, same as shifting every row individually
            input += start * direction
    return inputs
def inputs_to_embeddings(
    self,
    inputs: list,
    quant_levels: int | list[int] | Tensor | None = None
):
    """
    Convert parsed inputs (see `inputs()`) into one embedded sequence per batch item.

    Each named segment is embedded with its matching table, and the segments are
    joined with `self.sep`. Bookkeeping entries ("task", "dropout_mask",
    "classifier_level") steer embedding but are not emitted. Segments without
    a matching embedding are skipped.

    Args:
        inputs: per-item lists of (name, value) tuples.
        quant_levels: target RVQ level per item, or a single int for every item; 0 when None.

    Returns:
        list of embedded sequences, one per batch item.
    """
    # handles tasks where the prompt has task tokens injected in the middle
    def prompt_input_to_embedding( input, quant_level ):
        if isinstance(input, str):
            # look task ids up on the model's device; this used to depend on a
            # "phn"/"text" segment having been embedded first (NameError otherwise)
            return self.tasks_emb( torch.tensor( [ get_task_symmap()[input] ], device=self.sep.device, dtype=torch.int16) )
        # get RVQ level 0, or up to targetted RVQ level inference
        if self.version <= 4:
            return self.proms_emb(
                input if quant_level == 0 else input[:, :quant_level]
            )
        return self.proms_emb(
            input if input.dim() == 1 else input[:, : 1 if quant_level == 0 else quant_level],
            quant_level = 0 if quant_level == 0 else quant_level - 1, # input is one below the target quant level
            offset = 0,
        )

    # yuck
    # token dropout settings; only the disabled token-dropout code below uses them
    token_dropout_rate = self.config.experimental.token_dropout_rate if self.config else 0.0
    token_dropout_rvq_levels = self.config.experimental.token_dropout_rvq_levels if self.config else None
    if self.dropout_token is None or not self.training:
        token_dropout_rate = 0.0
    if not token_dropout_rvq_levels:
        # `self.resp_levels` is not set by Base.__init__; the level count is `n_resp_levels`
        token_dropout_rvq_levels = [1, self.n_resp_levels]

    # not created by Base.__init__, so don't assume the attribute exists
    time_emb = getattr( self, "time_emb", None )

    # a single int applies to every batch item
    if isinstance( quant_levels, int ):
        quant_levels = [ quant_levels ] * len(inputs)

    x_list = []
    for batch_index, batch_input in enumerate(inputs):
        batch = []
        quant_level = quant_levels[batch_index] if quant_levels is not None else 0
        task_type = "tts"
        classifier_level = None
        dropout_mask = None

        # pre-iterate
        for name, input in batch_input:
            if name == "classifier_level":
                classifier_level = input
            elif name == "dropout_mask":
                dropout_mask = input

        for name, input in batch_input:
            # technically can provide a map for input_name => embedding, but some embedding requires additional processing
            embedding = None
            # is already an embedding
            if name == "task":
                # noop
                # *maybe* inject a token for specifying task type
                task_type = input
                continue
            elif name == "phn":
                embedding = self.text_emb( input )
            elif name == "text" and self.raw_text_emb is not None:
                embedding = self.raw_text_emb( input )
            elif name == "quant_level" and self.rvq_l_emb is not None:
                embedding = self.rvq_l_emb( input )
            elif name == "lang" and self.langs_emb is not None:
                embedding = self.langs_emb( input )
            elif name == "prom":
                proms = [ input ] if isinstance(input, torch.Tensor) else input
                # to-do: probably insert separators if task requires it?
                embedding = torch.cat( [ prompt_input_to_embedding( prom, quant_level ) for prom in proms if prom is not None ] )
            elif name == "tone" and self.tones_emb is not None:
                embedding = self.tones_emb( input )
            elif name == "resp":
                # if training NAR-len RVQ level 0
                if dropout_mask is not None:
                    embedding = self.resps_emb(
                        # if masked use masked token, else original token
                        torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, quant_level] ),
                        #quant_level = 0,
                        name = classifier_level,
                    )
                # NAR-len
                elif classifier_level == f"NAR:{quant_level}:{quant_level}":
                    embedding = self.resps_emb(
                        input if input.dim() == 1 else input[:, quant_level],
                        #quant_level = 0,
                        name = classifier_level,
                    )
                # cheat-y way to handle performing STT across all levels
                elif task_type in summed_embeddings_task:
                    # we do a manual sum because I trained it to use the AR embeddings + NAR embeddings for STT......
                    embedding = sum([ self.resps_emb(
                        input[:, :l+1],
                        offset = 0 if l == 0 else 1, # or maybe set to 1
                        quant_level = l,
                        #name = 'AR:0:0' if l == 0 else f'NAR:{l-1}:{l}',
                        sums = False
                    ) for l in range( input.shape[-1] - 1 ) ])
                else:
                    # get RVQ level 0, or up to targetted RVQ level inference
                    if self.version <= 4:
                        embedding = self.resps_emb(
                            input if quant_level == 0 else input[:, :quant_level],
                            quant_level
                        )
                    else:
                        input_quant_level = 0 if quant_level == 0 else quant_level - 1 # input is one below the target quant level
                        embedding = self.resps_emb(
                            input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level],
                            #offset = 0 if classifier_level.startswith("AR:") else 1,
                            name = classifier_level,
                            quant_level = input_quant_level,
                        )
                # token dropout is currently disabled; the former implementation was:
                # if token_dropout_rate > 0.0 and (token_dropout_rvq_levels[0] <= quant_level and quant_level <= token_dropout_rvq_levels[1]):
                #     steps = embedding.shape[0] - (1 if quant_level == 0 else 0) # do not mess with stop token
                #     for i in range( steps ):
                #         if random.random() > token_dropout_rate:
                #             continue
                #         embedding[i] = self.dropout_token
            elif name == "timestep" and time_emb is not None:
                embedding = time_emb( input )
            elif name == "len" and self.len_emb is not None:
                embedding = self.len_emb( input )
            else:
                # should probably raise an exception so things aren't processed silently
                continue
            batch.append(embedding)

        x_list.append( _join( batch, self.sep ) )

    return x_list
# get an attribute from a given input list
def get_input(
    self,
    inputs,
    name,
    at=None,
):
    """
    Look up entries called `name` in a parsed input list.

    With `at=None`, return a list of every matching value across all batch items,
    in batch-then-sequence order. Otherwise return the first matching value in
    batch item `at`, or None if that item has none (or doesn't exist).
    """
    if at is None:
        return [ value for batch_input in inputs for key, value in batch_input if key == name ]
    for batch_index, batch_input in enumerate( inputs ):
        if batch_index != at:
            continue
        for key, value in batch_input:
            if key == name:
                return value
    return None
# creates position ids from a given input list
# if not unified_position_ids, then each input segment will have its own sequence
def inputs_to_position_ids(
    self,
    inputs: list,
    mask: Tensor,
):
    """
    Build position ids for a padded batch.

    With `self.unified_position_ids`, positions count valid tokens, and padded
    positions get 1. Otherwise every embedded segment restarts at 0. Each
    segment's range includes one extra slot for the separator that follows it,
    except for the task's output segment.

    Args:
        inputs: parsed inputs as returned by `inputs()`.
        mask: (b, t) mask, nonzero on valid positions.

    Returns:
        (b, t) int32 tensor on `mask.device`.
    """
    device = mask.device
    # shamelessly grabbed from modeling_llama.py
    ids = mask.long().cumsum(-1) - 1
    ids.masked_fill_( mask == 0, 1 )
    if self.unified_position_ids:
        return ids.to(device=device, dtype=torch.int32)

    # there's a better way
    def segment_length( value ):
        # task token
        if isinstance(value, str):
            return 1
        # list of tokens (e.g. a prompt mixing tensors and task strings)
        if not isinstance(value, torch.Tensor):
            return sum( [ v.shape[0] for v in value if isinstance(v, torch.Tensor) ] )
        return value.shape[0]

    seq_len = ids.shape[-1]
    x_list = []
    for batch_input in inputs:
        # pre-iterate
        task = "tts"
        for name, value in batch_input:
            if name == "task":
                task = value
        batch = torch.cat( [
            # +1 for the trailing separator, except on the task's output segment
            torch.arange( segment_length(value) + (0 if name == task_outputs.get(task, name) else 1), device=device, dtype=torch.int32 )
            for name, value in batch_input if name not in non_tokened_names
        ] )
        delta = seq_len - batch.shape[0]
        if delta > 0:
            batch = torch.cat( [ batch, torch.ones(delta, device=device, dtype=torch.int32) ] )
        elif delta < 0:
            # e.g. the final segment is not the output segment, so it was counted with a
            # separator that `_join` never emits; drop the surplus so every row stacks
            batch = batch[:seq_len]
        x_list.append( batch )
    ids = torch.stack( x_list )
    return ids.to(device=device, dtype=torch.int32)
def calc_loss(
    self,
    inputs: list,
    logits,
    quant_levels: list[int] | None = None,
    compute_hard_loss = True,
    compute_acc = True,
):
    """
    Compute the loss and accuracy statistics for a batch of logits.

    Targets are rebuilt from ``inputs`` (per-batch lists of ``(name, value)`` pairs).
    When ``self.config.loss_factors`` is set, every tokened input segment is scored on its
    own slice of the logits, weighted by ``self.loss_factor(name)``, and recorded under
    ``"{name}.nll"`` / ``"{name}.acc"``. Otherwise all segments are joined (separated by
    ``self.ignore_index``) and scored as one sequence under ``"nll"`` / ``"acc"``.

    :param inputs: per-batch lists of ``(name, value)`` pairs
    :param logits: per-batch logits with padding already removed (presumably
        ``(seq_len, vocab)`` each -- TODO confirm for multi-level heads)
    :param quant_levels: RVQ level per batch entry; defaults to level 0 for every entry
    :param compute_hard_loss: compute the cross-entropy loss
    :param compute_acc: compute accuracy statistics
    :returns: ``LossStats(loss, stats)``, each a dict of values averaged over the batch
    """
    loss = {}
    stats = {}

    device = logits[0].device

    # mirror forward()'s fallback instead of crashing on quant_levels[batch_index]
    if quant_levels is None:
        quant_levels = [ 0 for _ in range(len(inputs)) ]

    # handles tasks where the prompt has task tokens injected in the middle
    def prompt_input_to_token( input, quant_level ):
        if isinstance(input, str):
            return torch.tensor( [ get_task_symmap()[input] ], device=device, dtype=torch.int16)

        # ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
        if self.version < 4 or (self.version >= 5 and self.config and self.config.experimental.audio_embedding_sums):
            return torch.full_like(input[..., 0], self.ignore_index)

        return input if input.dim() == 1 else input[:, quant_level]

    def _calc_loss( logit, sequence, classifier_level, causal = True ):
        # filter tokens that exceed the vocab size
        sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )

        # shift if causal
        if causal or self.predict_causally:
            l = self.causal_size
            logit = logit[..., :-l, :] # shift the target so that token n...
            sequence = sequence[..., l:] # ...predicts token n + 1

        # drop if every remaining target is ignored. This must be checked *after* the shift:
        # a short segment can shift down to nothing (or only ignored tokens), and
        # cross_entropy over zero valid targets returns NaN, poisoning the batch average.
        if (sequence == self.ignore_index).all():
            return None, None

        nll = None
        metrics = None

        if compute_hard_loss:
            nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index )

        if compute_acc:
            if self.metrics is not None and classifier_level in self.classifiers.names:
                metrics = self.metrics.calc_accuracy( [ logit ], [ sequence ], self.classifiers.indices([ classifier_level ]) )
            else:
                accuracy_metric = MulticlassAccuracy(
                    logit.shape[-1],
                    # top_k may not exceed the number of classes (was bounded by sequence length)
                    top_k = min(logit.shape[-1], 10),
                    average="micro",
                    multidim_average="global",
                    ignore_index = self.ignore_index,
                ).to(logit.device)
                metrics = accuracy_metric( logit, sequence )

        return nll, metrics

    for batch_index, batch in enumerate(inputs):
        quant_level = quant_levels[batch_index]
        target = []
        causal = True
        task_type = "tts"
        dropout_mask = None
        classifier_level = None

        for name, input in batch:
            if name == "task":
                task_type = input
            elif name == "dropout_mask":
                dropout_mask = input
            elif name == "classifier_level":
                classifier_level = input

        if classifier_level is not None:
            # autoregressive, causal
            if classifier_level.startswith("AR:"):
                causal = True
            # nonautoregressive, parallel
            elif classifier_level.startswith("NAR:"):
                causal = False

        it = 0
        for name, input in batch:
            token = None
            ignored = False

            # non-tokened tasks
            if name in non_tokened_names:
                continue
            # prom can either be a tensor itself or a list of tensors and strings
            if name == "prom":
                # expand to list if not a list
                proms = [ input ] if isinstance(input, torch.Tensor) else input
                # iterate over the list to inject their tokens
                token = torch.cat( [ prompt_input_to_token( p, quant_level ) for p in proms if p is not None ] )

                if logits[batch_index].dim() < 3 and token.dim() >= 2:
                    token = token[..., 0]
            elif name == "resp":
                token = input if input.dim() == 1 else input[:, quant_level]
                # mask found, apply it
                if dropout_mask is not None:
                    token = torch.where( dropout_mask, token, self.ignore_index )
            # not a special input, inject as-is
            else:
                token = input

            if not isinstance(token, torch.Tensor):
                continue

            if token.is_floating_point():
                ignored = True

            # grab range of our logits for later
            seq_len = token.shape[0]
            start, end = it, it+seq_len
            it += seq_len + 1 # +1 to incorporate the separator

            # deduce if a name for a task is an input or output; inputs may be excluded from the loss
            if name != task_outputs.get(task_type, name) and self.ignore_inputs_for_loss:
                ignored = True

            if ignored:
                # pruned
                if self.config.loss_factors:
                    continue
                # fill with ignored out tensor
                token = torch.tensor( [ self.ignore_index ] * token.shape[0], device=device, dtype=torch.int16)

            # perform loss calculation on the individual piece
            if self.config.loss_factors:
                loss_factor = self.loss_factor(name)
                if loss_factor == 0.0:
                    continue

                nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), classifier_level, causal )

                if nll is not None:
                    loss.setdefault(f"{name}.nll", []).append( nll * loss_factor )
                if metrics is not None:
                    stats.setdefault(f"{name}.acc", []).append( metrics )
            # add to list
            else:
                target.append( token )

        # perform loss calculation on the entire sequence (skipped when nothing was tokened)
        if not self.config.loss_factors and target:
            sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
            nll, metrics = _calc_loss( logits[batch_index], sequence.long(), classifier_level, causal )

            if nll is not None:
                loss.setdefault("nll", []).append( nll )
            if metrics is not None:
                stats.setdefault("acc", []).append( metrics )

    # average
    loss = { name: sum( values ) / len( values ) for name, values in loss.items() }
    stats = { name: sum( values ) / len( values ) for name, values in stats.items() }

    return LossStats(loss, stats)
def forward(
    self,
    inputs: list,
    quant_levels: list[int] | None = None,
    state: dict | list | None = None,
    output_attentions: bool = False,
    output_hidden_states: bool = False,
):
    """
    Embed ``inputs``, run the backbone, project to logits and, in training mode, compute the loss.

    :param inputs: per-batch lists of ``(name, value)`` pairs
    :param quant_levels: RVQ level per batch entry; read from the ``"quant_level"`` inputs when
        omitted, and reset to all zeros when its length does not match the batch
    :param state: passed through to ``self._forward``
    :param output_attentions: passed through to ``self._forward``
    :param output_hidden_states: currently unused
    :returns: ``Logits`` whose per-batch logits are trimmed to each entry's unpadded length;
        ``loss`` is ``None`` outside training
    """
    # derive quant levels from inputs if not provided
    if quant_levels is None:
        quant_levels = [ x.item() for x in self.get_input( inputs, "quant_level" ) ]
    # inputs don't have quant levels added, pure AR
    if len(quant_levels) != len(inputs):
        quant_levels = [ 0 for _ in range(len(inputs)) ]

    x_list = self.inputs_to_embeddings( inputs, quant_levels )
    x, mask = list_to_tensor(x_list)

    training = self.training

    # pad our input and mask, but retain the original length by doing it after
    if self.l_padding and x.shape[1] % self.l_padding != 0:
        # pad input
        shape = list(x.shape)
        shape[1] = self.l_padding - shape[1] % self.l_padding

        padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
        x = torch.cat([x, padding], dim=1)

        # pad mask with the mask's own dtype; using x.dtype made torch.cat promote the mask
        # to the embedding dtype, so its dtype depended on whether padding happened
        padding = torch.zeros(shape[:2], dtype=mask.dtype, device=mask.device)
        mask = torch.cat([mask, padding], dim=1)

    # needs to be done here as we still have our raw inputs
    position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None

    classifier_levels = self.get_input( inputs, name="classifier_level" )
    # NOTE(review): this list previously held "phn" twice; one entry may have been meant as another level
    causal_levels = [ "phn", "len" ] + [ f"AR:{_}:{_}" for _ in range( self.n_resp_levels) ]

    # right now limit to new versions because I need to retrain the model for noncausal masks...
    is_causal = [ l in causal_levels for l in classifier_levels ] if self.noncausal_masks else [ True for l in classifier_levels ]

    output = self._forward(
        inputs=x,
        mask=mask,
        state=state,
        is_causal=is_causal,
        position_ids=position_ids,
        output_attentions = output_attentions,
    )

    logits = output.logits
    hidden_states = output.hidden_states

    # output projection layer
    # the very, very original implementation multiplied by the mask, but the mask only attends to padding, and the padding gets removed anyways
    if self.classifier is not None:
        logits = self.classifier(logits)
    # to-do: piece-wise classification, now that there's a head for text
    # although again, one single monolithic head would be preferable instead......
    elif self.classifiers is not None:
        logits = self.classifiers(logits, levels = classifier_levels )

    # Remove padding
    logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]

    if not training:
        loss = None
        stats = None

        self.loss = None
        self.stats = None
    # compute loss if the target is given
    else:
        loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )

        # include any additional losses (for example: MoE router)
        if output.loss is not None:
            loss["aux_loss"] = output.loss

        self.loss = loss
        self.stats = stats

    # rewrap, because we're modifying the logits here
    return Logits(logits, output.state, inputs, loss, output.attentions, hidden_states)
def sample(
    self,
    logits: list[Tensor], # logit scores
    prev_list: list[Tensor] | None = None, # previous tokens
    quant_levels: list[int] | None = None, # to-do: derive this from the prev_list
    **sampling_kwargs,
):
    """
    Turn per-batch logits into sampled token ids.

    Sampling options are read from ``sampling_kwargs``: temperature, min_temperature, top_k,
    top_p, min_p, repetition_penalty, repetition_penalty_decay, length_penalty, beam_width,
    mirostat, dry_multiplier, dry_base, dry_allowed_length, top_no and attentions.

    Processing order (each step rebinds ``logits``): entropix early-exit (AR with attentions),
    trimming to the positions being predicted, stop-token ban (NAR), repetition penalty,
    length penalty (AR), min_p, top_k/top_p, temperature (dynamic or plain), top-no, DRY,
    then the final draw (mirostat, naive beam search, or argmax / categorical sampling).

    :param logits: per-batch logit tensors
    :param prev_list: previously generated tokens per batch entry; for 2D entries only the
        last column (``prevs[:, -1]``) is kept
    :param quant_levels: when not ``None`` the NAR path is taken: logits are trimmed to the
        last ``len(prev)`` positions (so ``prev_list`` must be provided) and the stop token is banned
    :returns: ``Sampled(ids, logits, scores, entropy)``; ``entropy`` is only filled by the
        entropix path
    """
    # yikes
    temperature = sampling_kwargs.get("temperature", 1.0)
    min_temperature = sampling_kwargs.get("min_temperature", -1.0)
    top_k = sampling_kwargs.get("top_k", -100)
    top_p = sampling_kwargs.get("top_p", 1.0)
    min_p = sampling_kwargs.get("min_p", 0.0)
    # repetition penalty parameters
    repetition_penalty = sampling_kwargs.get("repetition_penalty", 1.0)
    repetition_penalty_decay = sampling_kwargs.get("repetition_penalty_decay", 0.0)
    # length penalty parameters
    length_penalty = sampling_kwargs.get("length_penalty", 0.0)
    # beam sampling parameters
    beam_width = sampling_kwargs.get("beam_width", 0)
    # mirostat sampling parameters
    mirostat = sampling_kwargs.get("mirostat", None)
    # DRY sampling parameters
    dry_multiplier = sampling_kwargs.get("dry_multiplier", 0.0)
    dry_base = sampling_kwargs.get("dry_base", 1.75)
    dry_allowed_length = sampling_kwargs.get("dry_allowed_length", 2)
    # top-no parameter
    # NOTE(review): the default of 1.0 satisfies `top_no > 0.0` below, so top-no processing
    # runs unless the caller passes top_no <= 0 -- confirm this is intended
    top_no = sampling_kwargs.get("top_no", 1.0)
    # per-layer attention maps; only used by the entropix path
    attentions = sampling_kwargs.get("attentions", None)

    batch_size = len( logits )

    # a negative minimum temperature means "same as temperature", i.e. no dynamic temperature
    if min_temperature < 0:
        min_temperature = temperature

    # pick last RVQ level
    if prev_list is not None:
        prev_list = [ prevs if prevs.dim() == 1 else prevs[:, -1] for prevs in prev_list ]

    scores = None
    entropy = None
    #logits = [ logit.to(device="cpu", dtype=logit.dtype if logit.dtype != torch.float16 else torch.float32) for logit in logits ]
    #logits = [ logit.to(device="cpu") for logit in logits ]

    # (AR) entropix sampling
    # we do it before everything to retain logits for the entire sequence (even though it's still better to pass only the last token)
    if attentions is not None and quant_levels is None:
        # move to CPU for speedups
        seq_lens = [ logit.shape[0] for logit in logits ]
        attentions = torch.stack(attentions, dim=1).to(device="cpu") # ( batch, layer, heads, seq_len, seq_len )

        res = [ sample_entropix(
            logit[:seq_lens[batch], :], # ( seq_len, vocab )
            attentions[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # (layer, heads, seq_len, seq_len )
            temperature,
            top_k,
            top_p,
            min_p,
        ) for batch, logit in enumerate(logits) ]

        # early exit: the entropix path skips every other logit processor below
        if res:
            return Sampled([ r[0] for r in res ], logits, scores, [ r[1] for r in res ])
    """
    elif quant_levels is None:
        seq_lens = [ logit.shape[0] for logit in logits ]
        entropy = [ calculate_entropix_metrics(
            logit[:seq_lens[batch], :], # ( seq_len, vocab )
            #attentions[batch, :, :, :seq_lens[batch], :seq_lens[batch]], # (layer, heads, seq_len, seq_len )
        ) for batch, logit in enumerate(logits) ]
    """

    # (NAR) return the entire generated response
    # Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
    if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
        seq_lens = map(len, prev_list)
        logits = [ logit[-l:] for logit, l in zip(logits, seq_lens) ]
    # (AR chunkwise) return the last chunkwise piece
    elif self.causal:
        seq_lens = [ logit.shape[0] - self.causal_size for logit in logits ]
        logits = [ logit[-self.causal_size:] for logit in logits ]

    # (NAR) disable stop token
    if quant_levels is not None:
        logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, prev_list) ) ]

    # (AR-len) disable extraneous tokens
    """
    if quant_levels is None and "len" in self.capabilities:
        logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, prev_list) ) ]
    """

    # perform repetition penalizing
    if prev_list is not None and repetition_penalty != 1.0:
        logits = [ reptition_penalize(logit, previous=prevs, factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]

    # (AR) perform length penalizing
    if quant_levels is None and self.causal and prev_list is not None and length_penalty != 0.0:
        logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]

    # perform min_p filtering of our logits
    if min_p > 0.0:
        logits = [ min_p_filtering(logit, min_p=min_p) for logit in logits ]

    # perform top_k/top_p filtering of our logits
    if top_k > 0 or top_p < 1.0:
        logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]

    # trigger dynamic temperature sampling if the minimum temperature is not the same as the sampling temperature
    # epsilon float comparison because I don't trust Python
    if abs(temperature - min_temperature) >= 0.001:
        logits = [ dynamic_temperature(logit, temperature=temperature, min_temperature=min_temperature) for logit in logits ]
    elif temperature > 0.0:
        logits = [ logit / temperature for logit in logits ]

    # do top-no logit processing
    if top_no > 0.0:
        logits = [ top_no_logits_processing(logit) for logit in logits ]

    # do DRY sampling
    if dry_multiplier > 0.0 and prev_list is not None:
        logits = [ dry_sampling(logit, previous=prevs, factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, prevs in zip( logits, prev_list ) ]

    # do mirostat sampling
    # currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
    if mirostat is not None:
        # mirostat sampling; each per-batch state dict is returned as the score and carries the token
        scores = [ mirostat_sample(logit, state=state) for logit, state in zip(logits, mirostat) ]
        res = [ state["token"] for state in scores ]
    # do beam search (naive implementation)
    # picks the top-k across all batches, and re-batches those resultant tokens
    # returns the logit scores as well to be P-concatted with the previous scores
    # to-do: not naively implement beam searching
    elif beam_width > 1:
        candidates = top_k_logits_list( logits, beam_width )
        res = [ torch.tensor(token, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ]
        scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
    # basic sampling
    else:
        # argmax instead
        if temperature <= 0.0:
            res = [ logit.argmax(dim=-1) for logit in logits ]
        else:
            res = [ Categorical(logits=logit).sample() for logit in logits ]

        # calculate token probabilities (probability of each chosen token under the processed logits)
        scores = [
            F.softmax(logit, dim=-1).gather(1, tokens.unsqueeze(-1)).squeeze(-1)
            for logit, tokens in zip(logits, res)
        ]

    return Sampled(res, logits, scores, entropy)
# this is a VERY basic implementation to test if a HF-ified model works (it sort of does)
if __name__ == "__main__":
from transformers import LlamaForCausalLM, LlamaTokenizer
from ..models import download_model, DEFAULT_MODEL_PATH
from ..emb.qnt import decode_to_file
from ..utils.io import torch_load
# hack in a non-causal mask
def _update_noncausal_mask(
attention_mask,
inputs_embeds,
cache_positions,
past_key_values_length,
output_attentions,
):
# create noncausal mask
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
bsz, seq_len, _ = inputs_embeds.size()
# generate default mask based on input
if attention_mask is None:
attention_mask = torch.ones( (bsz, seq_len), dtype=torch.bool, device=inputs_embeds.device )
# make square
expanded_mask = attention_mask[:, None, None, :].expand( bsz, 1, seq_len, seq_len ).to( dtype=inputs_embeds.dtype )
# invert from 1.0 = attend, 0.0 = masked to 0.0 = valid, -inf = masked
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )
device = "cuda"
dtype = torch.bfloat16
is_from_pretrained = True
if is_from_pretrained:
# tokenizer = LlamaTokenizer.from_pretrained("ecker/vall-e", revision="hf")
hf_model = LlamaForCausalLM.from_pretrained("ecker/vall-e", revision="hf")
hf_model.to(device=device, dtype=dtype)
hf_model.eval()
model = hf_model.model
else:
download_model()
model = LlamaModel(LlamaConfig(
vocab_size=1024,
hidden_size=1024,
max_position_embeddings=75 * 60 * 5, # max-length of 60 seconds
intermediate_size=1024*4,
num_hidden_layers=12,
num_attention_heads=16,
attention_dropout=0.0,
num_key_value_heads=16,
sliding_window=75 * 12, # 12 second context window
hidden_act="gelu",
is_encoder_decoder=False,
is_decoder=True,
))
state_dict = torch_load(DEFAULT_MODEL_PATH)['module']
state_dict_model = {}
for k, v in state_dict.items():
if not k.startswith('model.'):
continue
state_dict_model[k.replace("model.", "")] = v
model.load_state_dict( state_dict_model, strict=False )
model.to(device=device, dtype=dtype)
model.eval()
model._original_update_causal_mask = model._update_causal_mask
model._update_noncausal_mask = _update_noncausal_mask
phn = [1,22,111,100,4,37,115,169,11,2]
prom = [
[62,835,835,835,339,395,798,537,537,537,537,222,76,989,548,65,705,375,261,375,297,503,529,571,707,346,266,862,148,496,574,115,115,438,934,339,865,876,63,40,779,461,602,794,10,220,507,869,639,705,869,917,705,893,917,705,869,938,439,175,139,506,375,529,297,705,651,238,962,461,195,441,377,581,473,795,644,626,459,981,767,670,696,73,779,257,738,1017,1019,133,133,1017,835,604,699,626,67,92,707,92,179,179,772,869,441,799,630,238,745,904,904,904,106,133,133,1017,1017,395,883,87,519,594,1002,682,996,540,186,855,430,202,347,889,61,92,542,297,67,669,571,707,346,67,359,571,707,669,604,395,1008,810,35,621,67,600,333,123,284,568,817,243,778,464,638,610,359,538,464,975,321,700,377,484,179,284,284,621,538,464,745,171,171,159,744,744,287,461,69,15,529,67,92,669,464,515,605,24,822,865,293,865,172,638,359,562,138,839,846,775,556,775,1006,917,346,312,148,331,496,646,67,314,15,705,131,855,662,287,172,85,107,519,374,450,391,609,643,778,80,287,794,794,115,785,794,461,699,519,932,522,652,262,508,902,932,932,391,769,18,507,90,442,762,610,610,669,605,35,855,56,989,863,195,464,604,257,904,632,786,951,461,239,195,878,771,146,481,146,481,434,643,917,280,67,464,115,744,744,115,115,115,819,709,63,907,359,519,996,616,682,996,616,519,762,917,841,772,568,954,600,422,893,592,464,626,86,143,615,171,744,744,196,115,821,415,521,799,654,839,644,473,592,953,523,855,738,855,855,876,1017,63,329],
[913,859,740,740,937,601,961,961,877,747,747,559,474,618,20,316,58,316,180,112,290,869,610,869,869,943,127,153,236,794,282,857,984,196,875,648,993,913,860,616,38,833,620,133,123,992,247,367,252,50,298,27,27,631,163,784,271,20,843,514,869,258,180,66,803,281,123,493,831,102,556,992,385,122,31,251,990,827,26,347,460,43,43,460,228,43,841,913,302,544,544,277,859,404,646,775,315,848,726,185,203,314,203,174,252,174,378,954,214,993,924,809,277,765,363,544,363,518,791,185,454,193,193,193,193,193,573,977,924,76,434,56,193,962,610,24,954,459,396,112,903,137,398,474,506,791,839,399,102,25,205,792,459,474,526,817,869,192,792,593,878,506,24,410,539,788,522,667,566,584,588,992,444,24,869,925,635,393,903,742,320,1023,833,136,216,924,220,24,563,630,968,96,708,24,708,127,399,364,67,740,381,981,203,248,607,744,252,996,474,582,248,527,423,25,387,94,229,775,122,474,792,367,650,371,413,448,448,784,506,795,848,298,27,526,96,905,70,693,956,1002,1002,37,747,857,993,124,193,193,193,193,732,732,732,992,447,792,929,291,289,524,451,27,27,524,202,693,374,1002,125,732,585,367,317,679,395,413,189,493,386,650,110,912,505,384,399,851,367,367,27,230,988,810,975,842,956,1002,4,551,729,956,1002,750,648,231,950,193,96,912,410,732,539,103,193,904,491,213,792,792,998,193,399,151,410,96,673,497,1002,241,833,956,630,43,399,775,732,792,792,792,792,917,750,185,812,812,700,859,841,363,833,630],
[786,36,821,937,1000,705,1016,345,345,470,165,581,95,404,95,95,1006,477,95,95,691,254,997,657,176,124,95,673,489,326,218,437,907,590,752,541,1016,821,445,563,181,555,181,345,576,190,987,0,265,997,488,12,598,687,152,108,52,95,95,71,87,945,95,997,754,488,955,694,925,82,18,1020,1006,542,788,441,325,532,246,132,560,532,947,655,653,842,732,36,36,829,36,937,989,989,752,651,87,489,677,260,789,462,95,227,986,955,95,810,624,435,280,868,832,879,863,821,829,937,168,270,489,544,909,562,957,0,593,714,675,690,626,227,794,489,489,563,489,298,269,741,249,516,360,240,516,336,93,808,1022,682,555,737,147,405,476,895,323,694,412,689,963,72,193,298,181,521,741,193,93,153,773,677,689,495,30,564,719,1020,559,940,53,53,53,929,360,971,403,1012,997,919,957,433,919,787,401,401,355,276,370,414,690,697,330,629,552,930,720,259,579,221,62,945,135,1020,626,663,401,153,997,381,830,185,587,853,207,126,66,529,410,113,997,488,431,563,488,488,719,746,790,296,843,752,790,23,984,292,41,27,120,249,124,900,358,801,227,978,95,997,997,997,371,561,86,388,52,667,601,894,545,997,498,900,494,365,852,986,95,841,664,256,18,1020,963,901,447,498,262,388,691,997,646,651,757,468,114,601,437,940,212,655,541,970,870,521,237,957,563,794,563,564,620,489,351,489,489,257,733,629,489,227,622,962,7,598,374,470,114,159,211,298,363,843,818,153,59,452,529,258,419,605,689,526,39,982,829,982,752,678,1005,312],
[673,673,919,866,762,961,52,674,528,528,675,526,12,753,297,967,661,845,482,303,338,1021,506,445,247,214,206,94,434,799,210,885,552,695,853,1022,916,762,764,721,445,434,529,999,771,708,767,498,282,736,227,150,299,12,536,767,321,561,12,530,147,530,262,325,196,990,874,997,944,875,426,12,282,571,571,282,365,534,365,424,89,388,563,222,31,1019,624,74,215,651,1018,74,956,1022,74,18,633,350,72,448,454,769,267,938,12,534,929,723,829,614,505,364,1018,1014,838,673,919,74,223,761,266,78,177,736,20,718,425,1001,366,58,874,58,153,627,312,197,801,530,767,674,196,633,327,425,376,413,1019,209,594,383,744,458,468,711,282,885,640,435,655,571,556,1020,310,116,273,116,504,633,15,736,633,448,662,612,487,345,19,612,665,556,198,778,705,403,706,31,196,197,536,805,427,339,161,241,116,504,58,945,853,734,670,424,807,19,397,175,144,419,19,221,697,68,321,800,210,824,972,712,911,362,427,694,182,651,972,863,684,887,548,806,27,627,639,432,193,103,198,436,837,366,212,125,1001,493,874,808,17,17,127,204,530,300,345,425,246,240,640,906,340,310,633,246,774,114,633,522,777,874,494,577,353,939,571,693,857,722,530,521,354,492,735,214,806,483,736,530,118,234,536,177,132,522,349,259,436,973,528,414,224,762,212,854,744,271,568,127,323,736,304,499,499,78,536,736,805,232,126,468,566,611,52,339,450,258,157,602,594,854,602,599,82,124,472,563,666,174,936,818,66,758,627,52,350,999,734,215,919,1018,874,885],
[528,448,646,190,222,884,939,907,907,673,413,786,527,517,710,449,119,531,565,762,531,501,522,246,162,871,8,594,206,937,462,712,862,151,103,261,882,990,1007,314,683,864,693,812,319,786,107,531,31,342,632,460,269,429,531,531,717,417,321,671,1015,152,467,863,285,875,941,417,475,825,596,957,117,460,162,162,117,630,735,527,272,558,38,39,605,375,39,900,862,646,712,804,622,963,407,93,828,796,306,415,70,667,371,531,1000,411,710,162,812,381,673,498,691,884,928,712,528,48,630,24,593,901,973,579,722,75,139,909,919,328,764,393,777,753,512,577,175,577,512,922,834,863,30,69,94,68,616,691,835,335,486,345,306,374,732,938,580,311,715,495,527,1008,306,369,663,512,369,320,360,80,42,1021,1021,1021,175,568,526,362,320,317,488,613,937,548,966,545,596,177,306,480,522,577,512,512,638,1008,82,100,696,89,714,531,639,460,679,718,492,509,492,624,460,572,531,306,19,473,915,558,285,319,713,1018,381,877,667,425,905,43,437,632,634,324,306,207,324,303,48,69,467,39,902,599,3,617,465,78,918,459,1009,427,751,145,531,349,356,1021,157,507,780,624,165,507,144,270,94,414,899,379,947,994,853,107,586,652,877,92,19,91,188,544,624,470,503,513,13,192,563,145,531,618,743,470,62,701,499,436,679,505,198,959,3,766,839,437,491,395,1021,512,306,512,356,851,1021,1021,78,690,856,735,286,280,4,1008,369,359,309,651,864,561,170,692,952,877,520,959,306,37,1021,31,236,162,773,522,254,446,606,691,804,882,58,974],
[1011,939,881,881,140,937,724,724,937,1011,381,229,965,251,745,69,305,206,566,813,503,116,940,127,353,621,57,779,595,744,755,530,701,862,760,443,293,768,156,281,960,504,327,979,55,790,545,953,830,759,667,485,861,63,485,55,898,581,520,49,99,651,940,945,685,621,728,487,650,530,934,378,522,522,522,996,534,522,739,534,378,543,94,602,390,948,692,692,41,41,768,412,982,692,692,774,176,791,526,497,57,940,542,685,694,916,813,890,357,193,430,863,929,412,412,903,140,763,465,707,569,925,859,985,24,411,835,298,293,791,837,460,182,296,137,474,809,111,376,1021,111,490,111,938,542,578,477,506,57,385,300,873,240,104,667,204,515,834,24,125,113,980,111,997,859,997,376,193,490,824,511,799,719,575,451,575,251,222,630,429,920,788,300,993,641,154,816,940,618,130,940,462,823,955,1001,569,508,632,2,903,399,333,709,489,726,932,725,777,970,843,717,940,211,534,274,161,392,103,31,462,813,985,638,213,352,219,236,381,287,111,87,818,953,112,336,980,1016,72,960,426,238,60,9,487,665,129,24,24,162,312,411,111,157,473,466,222,940,341,55,457,712,179,451,111,831,918,826,814,940,30,468,240,207,389,923,186,95,300,876,679,576,543,582,111,227,312,112,545,747,378,165,158,610,601,425,238,704,630,124,644,949,982,297,868,569,24,57,465,24,859,111,24,752,775,24,647,465,495,57,24,57,227,907,296,581,843,1013,514,555,319,937,347,478,186,684,15,241,534,369,381,846,578,314,711,814,435,41,986,673,991],
[485,748,562,562,485,380,834,997,78,963,755,142,978,135,362,421,217,79,530,1012,972,946,127,587,838,818,456,548,424,479,944,650,694,447,391,616,938,908,206,259,998,292,818,128,353,273,566,796,333,146,110,986,571,451,166,229,421,300,911,689,329,145,287,273,542,808,301,491,0,278,825,442,0,100,818,826,66,904,642,566,135,305,999,993,905,485,755,782,365,977,485,1015,570,1002,755,169,967,36,721,1019,273,931,273,166,216,31,346,946,32,290,362,828,464,748,782,1002,1015,755,1014,100,315,777,549,177,882,110,603,975,531,608,67,1011,950,465,368,416,798,941,635,602,553,300,200,644,498,325,786,734,342,222,403,1,716,175,899,273,40,333,999,74,54,644,408,976,407,631,577,338,435,612,333,273,162,709,882,555,384,995,173,459,442,72,72,200,72,711,219,282,716,442,431,801,976,130,622,72,582,384,516,772,0,440,1001,249,1,953,65,945,438,249,511,561,205,507,821,998,427,746,290,544,426,693,999,190,214,167,219,534,166,325,975,414,326,326,268,679,991,418,868,445,632,160,380,890,346,315,806,258,806,486,326,797,471,18,790,33,66,63,66,224,38,599,599,110,801,761,18,936,230,253,171,393,774,887,887,403,466,495,524,261,666,256,687,759,263,713,185,454,242,988,185,161,911,430,86,550,439,327,527,671,782,383,916,590,315,806,583,465,785,321,315,421,856,66,352,0,634,540,362,948,185,16,224,372,694,259,648,87,733,659,603,67,269,901,66,566,173,705,746,566,911,10,743,860,78,782,1002,755,389,175],
[948,948,975,975,948,322,672,639,902,55,916,439,498,389,407,682,451,401,386,440,499,348,736,891,603,762,783,407,886,76,543,699,137,458,639,253,63,475,55,436,502,888,542,131,524,167,738,131,907,29,378,545,227,382,478,399,218,872,917,202,330,2,371,264,667,355,1016,768,590,408,463,542,214,202,715,891,840,297,509,689,290,439,672,714,528,940,1019,534,975,475,1019,835,975,558,975,981,330,635,96,858,606,627,367,191,191,669,40,873,359,267,701,426,210,1012,899,975,475,1012,610,6,300,749,231,616,877,631,720,574,551,398,503,789,684,664,390,277,150,990,823,190,971,903,175,863,316,965,988,988,800,612,336,506,242,847,389,939,415,202,83,317,2,153,365,363,57,2,891,965,300,754,763,426,555,621,303,415,367,902,829,741,119,380,902,25,884,439,822,49,76,760,566,316,249,555,774,955,834,309,859,173,935,812,682,586,141,606,197,131,644,631,913,586,202,117,810,884,76,592,754,531,586,925,649,583,145,816,821,283,871,1017,316,377,646,339,201,76,780,76,976,217,38,598,977,617,825,833,49,231,749,749,633,205,231,271,50,249,684,555,982,526,895,288,22,57,722,996,260,1018,110,833,644,738,648,468,798,297,769,282,197,402,465,510,194,930,182,909,749,986,187,187,917,38,38,985,985,988,815,878,814,459,237,768,781,649,683,749,934,729,463,181,625,231,917,96,499,839,720,439,842,205,808,338,617,681,326,446,905,346,647,533,49,728,147,432,846,536,586,611,49,879,872,893,859,859,961,989,975,701,495,65],
]
resp = []
"""
resp = [
[922,738,461,341,341,10,416,416,416,416,346,346,346,346,346,484,484,484,484,484,484,333,442,442,359,359,359,459,459,975,975,626,626,626,626,626,610,359,359,359,359,359,359,359,359,359,610,610,442,90,90,90,90,90,90,90,90,90,90,90,90,90,90,90,90,638,638,638,638,975,975,672,875,63,144],
[993,700,384,213,794,10,305,778,58,225,118,260,768,768,260,474,903,732,70,992,447,70,1000,665,848,379,485,934,181,795,438,298,688,324,934,756,395,795,110,328,343,172,768,871,593,355,396,783,24,24,911,20,27,562,697,616,668,27,27,755,20,505,248,79,822,461,197,156,27,492,151,1013,669,669,562],
[626,989,936,488,511,624,997,112,112,648,210,650,563,650,41,41,490,920,977,986,920,927,131,167,167,968,346,168,167,168,120,355,766,599,712,390,558,810,948,332,332,867,994,346,955,392,920,452,576,346,52,254,52,307,897,307,968,920,167,563,167,167,167,968,167,488,968,488,1001,938,563,741,432,566,758],
[916,874,798,212,496,751,620,616,982,745,975,890,890,141,141,321,321,214,899,42,151,722,310,971,774,35,627,995,27,43,248,248,595,774,942,352,810,35,384,340,654,639,89,214,737,197,657,45,622,321,337,19,483,679,938,938,682,938,938,141,938,310,114,724,116,327,372,607,607,310,204,713,762,853,853],
[528,222,992,727,536,191,202,483,306,568,533,577,398,533,202,24,753,753,739,739,643,513,4,324,369,66,447,201,66,802,66,957,665,526,602,749,483,447,193,853,531,201,201,71,888,202,66,66,650,228,533,102,639,513,533,531,533,471,344,566,201,639,471,639,732,594,464,308,116,533,116,174,959,621,539],
[692,632,478,375,910,857,775,503,503,193,717,548,344,717,55,808,162,112,112,112,543,582,847,712,691,679,427,940,369,475,153,526,729,269,323,721,526,211,191,192,685,844,731,813,914,545,582,712,925,916,375,111,340,162,844,940,844,162,844,990,111,491,232,582,491,582,618,121,1020,664,670,254,315,438,723],
[365,908,896,819,206,153,515,471,75,79,664,145,145,801,135,321,79,216,233,223,79,66,724,517,135,474,818,818,105,892,971,337,818,19,932,981,469,135,163,75,135,818,999,555,135,710,256,105,590,31,539,1003,517,130,445,40,549,130,859,385,1003,1003,549,33,286,932,329,774,321,664,686,16,834,703,290],
[899,237,832,748,425,121,460,872,391,586,857,215,306,76,306,554,187,57,482,406,802,555,710,895,448,517,506,316,18,772,779,697,855,1005,792,96,402,96,517,775,506,938,114,986,986,503,749,984,524,527,506,749,463,490,188,374,506,49,537,188,494,900,526,524,524,500,500,345,630,338,982,761,700,598,749],
]
"""
# name, (start, end), classifier, src_name
io_map = {
'text': [(0, 256), 9, "text_emb.weight"],
'rvq_l': [(256, 264), -1, "rvq_l_emb.weight"],
'lang': [(264, 270), -1, "langs_emb.weight"],
'task': [(270, 279), -1, "tasks_emb.weight"],
'len': [(279, 290), 10, "len_emb.weight"],
'tone': [(290, 291), -1, "tones_emb.weight"],
'sep': [(291, 292), -1, "sep"],
'prom|0': [(292, 1316), -1, "proms_emb.embeddings.0.weight"],
'prom|1': [(1316, 2340), -1, "proms_emb.embeddings.1.weight"],
'prom|2': [(2340, 3364), -1, "proms_emb.embeddings.2.weight"],
'prom|3': [(3364, 4388), -1, "proms_emb.embeddings.3.weight"],
'prom|4': [(4388, 5412), -1, "proms_emb.embeddings.4.weight"],
'prom|5': [(5412, 6436), -1, "proms_emb.embeddings.5.weight"],
'prom|6': [(6436, 7460), -1, "proms_emb.embeddings.6.weight"],
'prom|7': [(7460, 8484), -1, "proms_emb.embeddings.7.weight"],
'resp|AR:0:0': [(8484, 9509), 0, "resps_emb.embeddings.0.weight"],
'resp|NAR:0:1': [(9509, 10533), 1, "resps_emb.embeddings.1.weight"],
'resp|NAR:1:2': [(10533, 11557), 2, "resps_emb.embeddings.2.weight"],
'resp|NAR:2:3': [(11557, 12581), 3, "resps_emb.embeddings.3.weight"],
'resp|NAR:3:4': [(12581, 13605), 4, "resps_emb.embeddings.4.weight"],
'resp|NAR:4:5': [(13605, 14629), 5, "resps_emb.embeddings.5.weight"],
'resp|NAR:5:6': [(14629, 15653), 6, "resps_emb.embeddings.6.weight"],
'resp|NAR:6:7': [(15653, 16677), 7, "resps_emb.embeddings.7.weight"],
'resp|NAR:0:0': [(16677, 17702), 8, "resps_emb.embeddings.8.weight"],
}
mode_lvl_map = {
'AR:0:0': 0,
'NAR:0:1': 1,
'NAR:1:2': 2,
'NAR:2:3': 3,
'NAR:3:4': 4,
'NAR:4:5': 5,
'NAR:5:6': 6,
'NAR:6:7': 7,
'NAR:0:0': 0,
'len': 0,
}
# per-name embedding tables and classifier heads, carved out of either the HF-ified
# monolithic embedding / lm_head (pretrained) or the original per-module state dict
embds = {}
heads = {}
n_embd = 1024

with torch.no_grad():
    for k, v in io_map.items():
        (start, end), classifier_idx, embd_name = v

        if is_from_pretrained:
            n_vocab = end - start

            # torch has no `ml` namespace; the original `torch.ml.Embedding` raised AttributeError
            embds[k] = torch.nn.Embedding( n_vocab, n_embd ).to(model.embed_tokens.weight)
            embds[k].weight[:] = model.embed_tokens.weight[start:end, :]

            if classifier_idx >= 0:
                # NAR:0:0 does not have a masked token output
                if k == "resp|NAR:0:0":
                    end -= 1
                    n_vocab -= 1
                heads[k] = torch.nn.Linear( n_embd, n_vocab, bias=False ).to(hf_model.lm_head.weight)
                heads[k].weight[:] = hf_model.lm_head.weight[start:end, :]
        else:
            # 1D weights (e.g. the separator) are promoted to a single-row table
            embd_weight = state_dict[embd_name].unsqueeze(0) if state_dict[embd_name].dim() == 1 else state_dict[embd_name]
            embds[k] = torch.nn.Embedding( embd_weight.shape[0], embd_weight.shape[1] ).to(device=device, dtype=dtype)
            embds[k].load_state_dict({ "weight": embd_weight })

            if classifier_idx >= 0:
                head_weight = state_dict[f'classifiers.proj.{classifier_idx}.weight']
                heads[k] = torch.nn.Linear( head_weight.shape[1], head_weight.shape[0], bias=False ).to(device=device, dtype=dtype)
                heads[k].load_state_dict({ "weight": head_weight })
def create_inputs( phn, prom, lang=0, seq=None, mode="AR:0:0" ):
    """
    Build input embeddings, position IDs and an attention mask for one forward pass.

    Layout: [phn, sep] [lang, sep] [rvq_l, sep] [prom, sep] [seq (optional)].
    Each segment is embedded with its own table from `embds`, and position IDs
    restart at 0 for every segment.

    Args:
        phn: phoneme token IDs.
        prom: prompt codes, iterated row by row; only rows 0..rvq_l are summed.
        lang: language token ID.
        seq: tokens decoded so far for `mode`. None or empty means no response
            segment is appended (e.g. the first AR step).
        mode: key of `mode_lvl_map` ("len", "AR:0:0", "NAR:{l-1}:{l}", ...).

    Returns:
        (inputs, pos_ids, attn_mask): stacked embeddings with one row per position,
        int32 position IDs, and an all-True bool mask of the same length.
    """
    level = mode_lvl_map[mode]

    phn = torch.tensor(phn, device=device, dtype=torch.int32)
    prom = torch.tensor(prom, device=device, dtype=torch.int32)
    lang = torch.tensor([lang], device=device, dtype=torch.int32)
    rvq_l = torch.tensor([level], device=device, dtype=torch.int32)
    zero = torch.tensor([0], device=device, dtype=torch.int32)

    if mode == "len":
        # the length sequence is always prefixed with a 0 token
        seq = zero if not seq else torch.concat([zero, torch.tensor(seq, device=device, dtype=torch.int32)])
    elif seq:
        seq = torch.tensor(seq, device=device, dtype=torch.int32)
        # NAR levels only see the response levels below them
        seq = seq[:level, :] if level > 0 else seq
    else:
        # None or empty: nothing decoded yet, so there is no response segment
        # (an empty list must not reach an Embedding lookup)
        seq = None

    sep_embd = embds["sep"](zero)
    phn_embd = embds["phn"](phn)
    rvq_l_embd = embds["rvq_l"](rvq_l)
    lang_embd = embds["lang"](lang)

    prom_embd = torch.zeros(prom.shape[-1], n_embd, device=device, dtype=dtype)
    for i, p in enumerate(prom):
        if i > level:
            break
        prom_embd += embds[f"prom|{i}"](p)

    seq_embd = None
    if seq is not None:
        if mode == "len":
            seq_embd = embds["len"](seq)
        elif mode == "AR:0:0":
            seq_embd = embds["resp|AR:0:0"](seq)
        else:
            # NAR: sum the per-level response embeddings
            seq_embd = torch.zeros(seq.shape[-1], n_embd, device=device, dtype=dtype)
            for i, r in enumerate(seq):
                seq_embd += embds[f"resp|NAR:{i}:{i+1}"](r)

    segments = [
        torch.concat([phn_embd, sep_embd]),
        torch.concat([lang_embd, sep_embd]),
        torch.concat([rvq_l_embd, sep_embd]),
        torch.concat([prom_embd, sep_embd]),
    ]
    if seq_embd is not None:
        segments.append(seq_embd)

    inputs = torch.concat(segments)
    # position IDs restart at 0 for each segment
    pos_ids = torch.tensor([ i for segment in segments for i in range(len(segment)) ], device=device, dtype=torch.int32)
    attn_mask = torch.tensor([ True for segment in segments for _ in range(len(segment)) ], device=device, dtype=torch.bool)
    return inputs, pos_ids, attn_mask
def generate( phn, prom, sequence=None, mode="resp|AR:0:0", max_tokens = 75 * 4, temperature = 1.0 ):
    """
    Decode tokens for `mode` by repeatedly running the exported model.

    * "len": greedy decoding, at most 5 tokens, stops on token 10.
    * "resp|AR:0:0": one token per step, sampled when temperature > 0 (argmax
      otherwise), until token 1024 or `max_tokens` tokens.
    * any other mode (NAR): a single greedy pass with the non-causal mask that
      predicts one token per timestep of `sequence[0]` and appends them as a new row.

    Args:
        phn, prom: forwarded to `create_inputs`.
        sequence: tokens already decoded. It is extended in place and returned.
            None starts from a fresh empty list.
        mode: key of `heads`; the part after "|" is the `create_inputs` mode.
        max_tokens: AR length limit (overridden for "len" and NAR).
        temperature: AR sampling temperature (forced to 0 for "len" and NAR).

    Returns:
        `sequence` with the newly decoded tokens appended.
    """
    # a fresh list per call: a shared `[]` default would leak tokens between calls
    if sequence is None:
        sequence = []

    lm_head = heads[mode]
    model._update_causal_mask = model._original_update_causal_mask
    is_nar = mode not in ("len", "resp|AR:0:0")
    n_outputs = 1
    stop_token = 1024
    if mode == "len":
        temperature = 0.0
        max_tokens = 5
        stop_token = 10
    elif is_nar:
        temperature = 0.0
        max_tokens = len(sequence)+1
        n_outputs = len(sequence[0])
        model._update_causal_mask = model._update_noncausal_mask

    # inference only: don't build an autograd graph
    with torch.no_grad():
        while len(sequence) < max_tokens:
            inputs, pos_ids, attn_mask = create_inputs( phn, prom, seq=sequence, mode=mode.split("|")[-1] )
            out = model(inputs_embeds=inputs.unsqueeze(0), position_ids=pos_ids.unsqueeze(0), attention_mask=attn_mask.unsqueeze(0))
            logits = lm_head(out[0]).float()
            # only the last `n_outputs` positions predict new tokens
            logits = logits[0, -n_outputs:, :]
            t = Categorical(logits=logits / temperature).sample() if temperature > 0 else logits.argmax(dim=-1)
            if is_nar:
                # decide by mode, not by n_outputs, so single-timestep NAR inputs still get a row
                sequence.append([ _.item() for _ in t ])
                break
            token = t[0].item()
            if token == stop_token:
                break
            sequence.append(token)
    return sequence
# check embds (debug toggle: flip to True to dump per-position embedding sums)
if False:
    inputs, pos_ids, attn_mask = create_inputs( phn, prom, mode="len" )
    flattened = [ sum(embd).item() for embd in inputs ]
    for i, embd in enumerate( flattened ):
        print(f'{i}: ', pos_ids[i].item(), "\t", embd )

# test len inferencing; pass a fresh list so no decoding state is shared with later calls
print( "len:", generate( phn, prom, sequence=[], mode="len" ) )

# test AR output: keep the first row of the reference `resp` if given, otherwise decode it
# NOTE(review): assumes `resp` is a list of per-level code lists ([level][timestep]) — confirm against the caller.
if resp:
    resp = [ resp[0] ]
else:
    resp = [ generate( phn, prom, sequence=[] ) ]
print( "AR:", resp )

# test NAR output: each NAR pass appends one more row to `resp`
for i in range(1, 8):
    resp = generate( phn, prom, sequence=resp, mode=f"resp|NAR:{i-1}:{i}" )
    print( f"NAR:{i-1}:{i}: ", resp[-1] )

# transpose the stacked rows before decoding to audio
decode_to_file( torch.tensor(resp, dtype=torch.int16, device=device).t(), "./data/test.wav" )