vall-e/vall_e/models/base.py

592 lines
21 KiB
Python
Raw Normal View History

2023-08-02 21:53:35 +00:00
import math
import torch
import torch.nn.functional as F
import traceback
import numpy as np
2023-08-02 21:53:35 +00:00
from typing import Literal, overload
from functools import partial
from einops import rearrange
from torch import Tensor, einsum, nn
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
from .retnet import RetNetDecoder, RetNetConfig
from .transformer import SinusoidalEmbedding, Block as TransformerBlock
def _create_mask(l, device):
"""1 is valid region and 0 is invalid."""
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
return (seq < stop).float() # (b t)
def _join(x: tuple[Tensor], sep: Tensor):
"""
Args:
x: (k t d)
sep: (d)
"""
ret = x[0]
for i in range(1, len(x)):
ret = torch.cat((ret, sep[None], x[i]), dim=0)
return ret
def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
    """Pad a list of variable-length sequences into one batch plus its mask.

    Args:
        x_list: [(t d)] -- one tensor per sample; lengths may differ.
        pattern: einops pattern applied to both the padded batch (t b c)
            and the mask (t b 1).

    Returns:
        x: the padded batch, laid out as `pattern` dictates ((b t c) by default).
        m: the validity mask in the same layout with a trailing axis of size 1
            ((b t 1) by default), cast to x's dtype and device.
    """
    lengths = [len(item) for item in x_list]
    padded = pad_sequence(x_list)  # (t b c)
    x = rearrange(padded, pattern)

    mask = _create_mask(lengths, x_list[0].device)  # (b t)
    mask = mask.t().unsqueeze(-1)  # (t b 1)
    m = rearrange(mask, pattern).to(x)
    return x, m
# Simple filter to lower a token's probability if it shows up in the past
# `one_time` will only apply the penalty once per distinct token
# `decay` is an exponent applied to how far away the token is (1 = most recent)
def reptition_penalize( logits, previous, factor=1.0, decay=0.0, one_time=True ):
    """Penalize tokens that already occur in `previous`.

    The penalty for a token seen `distance` steps ago is
    `factor * distance ** decay`. Positive logits are divided by it and
    negative logits are multiplied by it, so a penalty above 1 always makes
    the token less likely. (Dividing a negative logit would move it towards
    zero and make the token *more* likely.)

    Args:
        logits: 2-D tensor indexed as `[:, token]`; modified in place.
        previous: 1-D tensor of earlier token ids (oldest first), or None.
        factor: base penalty; 1.0 disables the filter.
        decay: exponent applied to the distance from the end of `previous`.
        one_time: if True, only the most recent occurrence of each token
            is penalized.

    Returns:
        The same `logits` tensor, after the in-place update.
    """
    if factor == 1.0 or previous is None:
        return logits

    seen = set()
    # walk from the most recent token backwards; distance starts at 1
    for distance, token in enumerate(reversed(previous.tolist()), start=1):
        # skip if we're only applying the penalty once
        if one_time and token in seen:
            continue

        penalty = factor * (distance ** decay)
        column = logits[:, token]
        # sign-aware so that the penalty never boosts a negative logit
        logits[:, token] = torch.where(column < 0, column * penalty, column / penalty)

        # add to set if we care about it
        if one_time:
            seen.add(token)

    return logits
# Simple "filter" that modifies the logit for the stop token, based on the sequence length
# `length` is the length of the sequence currently
# `factor` is the power the length is raised to, so values > 0 will yield longer sequences, values < 0 will yield shorter sequences
# `token` is the stop token.
def length_penalize( logits, length, factor=0.0, token=-1 ):
    """Rescale the stop-token logit by `length ** factor`.

    A positive logit is divided by the scale and a negative logit is
    multiplied by it, so that `factor > 0` always makes the stop token less
    likely (longer sequences) and `factor < 0` always makes it more likely
    (shorter sequences), whatever the sign of the logit.

    Args:
        logits: 2-D tensor indexed as `[:, token]`; modified in place.
        length: current sequence length.
        factor: exponent applied to `length`; 0.0 disables the filter.
        token: index of the stop token along the last axis.

    Returns:
        The same `logits` tensor, after the in-place update.
    """
    if factor == 0.0:
        return logits

    scale = length ** factor
    column = logits[:, token]
    # sign-aware so that the direction of the adjustment does not flip for negative logits
    logits[:, token] = torch.where(column < 0, column * scale, column / scale)
    return logits
# Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).

    Args:
        logits: logits with the vocabulary on the last axis, typically
            (batch size, vocabulary size); a bare (vocabulary size) vector is
            accepted too. Modified in place.
        top_k: if > 0, keep only the k highest-scoring tokens (raised to
            `min_tokens` and capped at the vocabulary size).
        top_p: if < 1.0, keep the highest-scoring tokens up to and including
            the first one whose cumulative probability exceeds `top_p`.
        filter_value: value written over every removed logit.
        min_tokens: lower bound on the number of tokens kept per row.

    Returns:
        The same `logits` tensor, after the in-place update.
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens), logits.size(-1))  # Safety check
        # Remove all tokens scoring below the last token of the top-k
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens > 1:
            # Protect the leading `min_tokens` entries. Together with the
            # right-shift below this keeps up to min_tokens + 1 tokens.
            sorted_indices_to_remove[..., :min_tokens] = 0
        # Shift the flags to the right to also keep the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the flags back to the original ordering. Using the last axis
        # (rather than a hard-coded 1) lets this work for any leading batch shape.
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value

    return logits
# picks the top K tokens amongst a batch of logits
# logits_list: [Tensor] list of logits
# returns: [(batch, token)] list, where batch is the row of the concatenated logits the given token is from
def top_k_logits_list( logits_list, k ):
    """Return the coordinates of the `k` highest logits across the whole list.

    The tensors are concatenated along dim 0 and the top-k is taken over the
    flattened result; every flat index is then converted back into a
    coordinate tuple of the concatenated tensor (one entry per dimension).
    """
    # ( batch, tokens ) => ( batch x tokens )
    stacked = torch.cat( logits_list )
    shape = tuple(stacked.size())
    # perform top-k across all logits
    flat_hits = torch.topk(stacked.flatten(), k).indices.tolist()
    # row-major unravel of every flat index
    return [ tuple(np.unravel_index(flat, shape)) for flat in flat_hits ]
# embeds a whole batch-list in one lookup and hands it back as a per-sample list
class Embedding(nn.Embedding):
    """`nn.Embedding` that accepts a list of index tensors instead of one tensor."""

    def forward(self, x_list: list[Tensor]) -> list[Tensor]:
        """Embed every tensor of `x_list`.

        The inputs are concatenated, looked up in a single call and split
        back to the original lengths. An empty input yields an empty list.
        """
        if not x_list:
            return []

        lengths = [len(x) for x in x_list]
        embedded = super().forward(torch.cat(x_list))
        return embedded.split(lengths)
class MultiEmbedding(nn.Module):
2023-08-02 21:53:35 +00:00
"""
This embedding sums embeddings on different levels.
"""
def __init__(self, max_n_levels, n_tokens, token_dim, monolithic=False):
2023-08-02 21:53:35 +00:00
super().__init__(max_n_levels, token_dim)
2023-09-08 20:36:26 +00:00
self.monolithic = monolithic
2023-08-02 21:53:35 +00:00
self.max_n_levels = max_n_levels
self.n_tokens = n_tokens
2023-09-08 20:36:26 +00:00
self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
2023-08-02 21:53:35 +00:00
2023-09-07 22:08:38 +00:00
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
# I imagine this is an oversight in the NAR.
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
2023-08-02 21:53:35 +00:00
if len(x_list) == 0:
return []
2023-09-08 20:36:26 +00:00
# this "strategy" will reserve the weight[0] for te AR and weight[1:] for the NAR
# the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
if self.monolithic:
2023-09-08 20:36:26 +00:00
w = self.weight[:1] if quant_levels is None else self.weight[1:]
else:
w = self.weight
2023-08-02 21:53:35 +00:00
padded_x_list = []
2023-08-02 21:53:35 +00:00
2023-09-08 20:36:26 +00:00
for i, xi in enumerate(x_list):
2023-08-02 21:53:35 +00:00
xi = F.one_hot(xi.to(torch.int64), num_classes=self.n_tokens) # t l' k
wi = w.shape[0] - xi.shape[1]
xi = F.pad(xi, (0, 0, 0, wi)) # t l k
2023-08-02 21:53:35 +00:00
padded_x_list.append(xi.to(w))
x = torch.cat(padded_x_list) # n l k
x = einsum("l k d, n l k -> n d", w, x)
2023-08-02 21:53:35 +00:00
x_list = x.split([*map(len, x_list)])
return x_list
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
class AudioEmbedding(nn.Module):
2023-09-07 14:14:03 +00:00
def __init__(self, n_levels, n_tokens, token_dim):
super().__init__()
self.n_levels = n_levels
# would it be better to have embeddings[1:] reduced to 1024 tokens to attend to, so it's *not* factoring in the stop token?
2023-09-07 14:14:03 +00:00
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
res_list = []
2023-08-02 21:53:35 +00:00
for i, xi in enumerate(x_list):
# prom
if quant_levels is None and xi.shape[-1] > 1:
x = sum( [ self.embeddings[k]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
# AR resp
elif quant_levels is None or quant_levels[i] == 0:
x = self.embeddings[0]( xi[:, 0] )
# NAR resp
else:
x = sum( [ self.embeddings[k+1]( xi[:, k] ) for k in range(xi.shape[-1]) ] )
res_list.append(x)
return res_list
2023-08-02 21:53:35 +00:00
class Base(nn.Module):
@property
def causal(self) -> bool:
raise NotImplementedError
@property
def arch_type(self) -> str:
raise NotImplementedError
@property
def norm_type(self):
raise NotImplementedError
@property
def n_prom_levels(self) -> int:
raise NotImplementedError
@property
def n_resp_levels(self) -> int:
raise NotImplementedError
@property
def n_max_levels(self) -> int:
raise NotImplementedError
2023-08-19 01:58:07 +00:00
@property
def n_tasks(self) -> int:
raise NotImplementedError
2023-08-02 21:53:35 +00:00
@property
def recurrent_chunk_size(self) -> int:
raise NotImplementedError
@property
def interleave(self) -> bool:
return False
@property
def monolithic(self) -> bool:
return False
2023-09-07 14:14:03 +00:00
@property
def version(self) -> int:
return 1
2023-09-07 14:14:03 +00:00
@property
def stop_token(self):
if not self.causal:
raise ValueError("Not using stop token!")
return self.n_tokens
@property
def ignore_index(self):
return -100
@staticmethod
def _samplewise_merge_tensors(*l, sep: Tensor | None):
if sep is None:
cat = torch.cat
else:
cat = partial(_join, sep=sep)
return [*map(cat, zip(*l))]
2023-08-02 21:53:35 +00:00
def __init__(
self,
2023-08-19 01:58:07 +00:00
n_tokens: int = 1024,
2023-08-02 21:53:35 +00:00
d_model: int = 512,
n_heads: int = 8,
n_layers: int = 12,
p_dropout: float = 0.1,
config = None,
2023-08-02 21:53:35 +00:00
):
super().__init__()
self.config = config
self.activation_checkpointing = self.config.activation_checkpointing if self.config is not None else True
2023-08-02 21:53:35 +00:00
self.n_tokens = n_tokens
self.d_model = d_model
self.n_heads = n_heads
self.n_layers = n_layers
# +1 to include the stop token
2023-08-19 01:58:07 +00:00
n_prom_tokens = n_tokens + (self.n_tasks - 1) # - 1 because tts is an inherent task
n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
2023-08-02 21:53:35 +00:00
self.text_emb = Embedding(n_tokens, d_model)
if self.version == 1: # legacy
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
else:
self.proms_emb = AudioEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
self.resps_emb = AudioEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
2023-08-02 21:53:35 +00:00
self.sep = nn.Parameter(torch.randn(d_model))
2023-08-02 21:53:35 +00:00
if self.arch_type == "transformer":
self.sin_emb = SinusoidalEmbedding(d_model)
self.blocks = nn.ModuleList([TransformerBlock(
d_model=d_model,
n_heads=n_heads,
p_dropout=p_dropout,
2023-08-19 01:58:07 +00:00
causal=self.causal,
2023-08-02 21:53:35 +00:00
norm_type=self.norm_type,
n_levels=self.n_resp_levels,
) for _ in range(n_layers) ])
elif self.arch_type == "retnet":
2023-08-19 01:58:07 +00:00
self.retnet = RetNetDecoder(RetNetConfig(
2023-08-02 21:53:35 +00:00
vocab_size=n_tokens,
decoder_embed_dim=d_model,
decoder_retention_heads=n_heads,
decoder_ffn_embed_dim=d_model * 4,
decoder_layers=n_layers,
dropout=p_dropout,
checkpoint_activations=self.activation_checkpointing,
2023-08-02 21:53:35 +00:00
chunkwise_recurrent=self.causal and self.recurrent_chunk_size > 0,
recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
2023-08-02 21:53:35 +00:00
no_output_layer=True,
decoder_normalize_before=True,
2023-08-19 01:58:07 +00:00
))
2023-08-04 01:26:36 +00:00
2023-08-02 21:53:35 +00:00
self.classifier = nn.Linear(d_model, n_resp_tokens)
self.accuracy_metric = MulticlassAccuracy(
n_resp_tokens,
top_k=10,
average="micro",
multidim_average="global",
ignore_index=self.ignore_index,
)
self.precision_metric = MulticlassPrecision(
n_resp_tokens,
top_k=10,
average="micro",
multidim_average="global",
ignore_index=self.ignore_index,
)
def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor],
targ_list: list[Tensor] | None = None,
2023-08-02 21:53:35 +00:00
quant_levels: Tensor | None = None,
state: dict | None = None,
2023-08-02 21:53:35 +00:00
):
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.proms_emb(proms_list),
self.resps_emb(resps_list, quant_levels),
sep=self.sep,
)
2023-08-02 21:53:35 +00:00
x, m = list_to_tensor(x_list)
batch_size = len(text_list)
2023-08-19 01:58:07 +00:00
device = x.device
if state is not None:
# prefill
if len(state) == 0:
prefill_size = x.shape[1]
# run the initial prompt to fill the KV cache
for n in range(prefill_size):
xi = x[:, n, :].unsqueeze(1)
self.retnet(xi, incremental_state=state, token_embeddings=xi, features_only=True)
# grab last token(s)
x = x[:, -1, :].unsqueeze(1)
2023-08-02 21:53:35 +00:00
if self.arch_type == "transformer":
2023-09-12 21:04:45 +00:00
# ensures we specify a quant_level for the transformer implementation's AdaLN
2023-09-07 14:14:03 +00:00
l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
l = l.to(device)
2023-09-12 21:04:45 +00:00
# inject position information
x = self.sin_emb.add_pe(x)
# pass our inputs through the transformer
2023-08-02 21:53:35 +00:00
for block in self.blocks:
2023-09-07 14:14:03 +00:00
x = block(x, m, l)
2023-08-02 21:53:35 +00:00
elif self.arch_type == "retnet":
2023-09-12 21:04:45 +00:00
# pass our inputs through the RetNet
2023-08-02 21:53:35 +00:00
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
2023-09-12 21:04:45 +00:00
# output projection layer with masking
2023-08-02 21:53:35 +00:00
x = self.classifier(x) * m
# Remove padding
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
2023-08-02 21:53:35 +00:00
# compute loss if the target is given
if targ_list is not None:
2023-08-19 01:58:07 +00:00
ignore_sep = torch.tensor(self.ignore_index, device=device)
# create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not neeeded for computing the loss against
prom_list = [ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ]
# remake input sequence
text_prom_list = self._samplewise_merge_tensors( text_list, prom_list, sep=ignore_sep )
2023-08-02 21:53:35 +00:00
# process each batch
2023-08-02 21:53:35 +00:00
for i in range(len(text_prom_list)):
2023-09-12 21:04:45 +00:00
# for the AR, shift the text/input prompt and target prompt into the future by 1, and ignore the rolled back text token
if quant_levels is None or quant_levels[i] == 0:
2023-08-02 21:53:35 +00:00
text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
2023-09-12 21:04:45 +00:00
targ_list[i] = targ_list[i].clone().roll(-1, dims=0) # clone ensures it's not an aliased copy/view of resps
2023-08-02 21:53:35 +00:00
text_prom_list[i][-1] = self.ignore_index
targ_list[i][-1] = self.stop_token
# for the NAR, ignore completely computing the loss against the text prompt
else:
text_prom_list[i][:] = self.ignore_index
2023-08-02 21:53:35 +00:00
2023-08-04 01:26:36 +00:00
# create the new target sequence to compute the loss against
target = torch.cat( self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep ) )
inputs = torch.cat( logits )
2023-08-02 21:53:35 +00:00
self.loss = dict(
# "nll" was in the original implementation and should actually just be called something else
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
2023-08-02 21:53:35 +00:00
)
self.stats = dict(
acc = self.accuracy_metric( inputs, target ),
precision = self.precision_metric( inputs, target ),
)
return logits
def sample(
self,
logits: list[Tensor],
resps_list: list[Tensor],
quant_levels: Tensor | None = None,
temperature: float = 1.0,
top_k: int = -100,
top_p: float = 1.0,
repetition_penalty: float = 1.0,
repetition_penalty_decay: float = 0.0,
length_penalty: float = 0.0,
beam_width: int = 0,
):
# (NAR) return the entire generated response
if quant_levels is not None:
logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
# (AR chunkwise) return the last chunkwise piece
elif self.causal and self.recurrent_chunk_size > 0:
logits = [ logit[-l:] for logit, l in zip(logits, self.recurrent_chunk_size) ]
# (AR) return just the last code
2023-08-02 21:53:35 +00:00
else:
logits = [ logit[-1:] for logit in logits ]
# perform repetition penalizing
logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
# (AR) perform length penalizing
if quant_levels is None and self.causal:
logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
# scale our logits by the temp
logits = [ logit / temperature for logit in logits ]
# perform top_k/top_p filtering of our logits
if top_k > 0 or top_p < 1.0:
logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
# do beam search (naive implementation)
# picks the top-k across all batches, and re-batches those resultant tokens
# returns the logit scores as well to be P-concatted with the previous scores
# to-do: not naively implement beam searching
if beam_width > 1:
candidates = top_k_logits_list( logits, beam_width )
res = [ torch.tensor(token, device=logits[batch].device, dtype=torch.int16).unsqueeze(dim=-1) for batch, token in candidates ]
scores = [ logits[batch].flatten()[token] for batch, token in candidates ]
return res, scores
# and sample
# the original implementation used this instead of argmax; it's probably placebo but it performs better than argmax
return [ Categorical(logits=logit).sample() for logit in logits ]
2023-08-02 21:53:35 +00:00
def example_usage():
    """Smoke test: build an AR and a NAR model, train both on a single
    utterance, and write sampled audio before and after training.

    Reads `data/qnt.pt` (and `./data/<name>.pth` when `train` is False),
    writes wav files under `./data/`, and places the models on "cuda".
    """
    from ..config import cfg
    cfg.trainer.backend = "local"
    cfg.trainer.check_for_oom = False

    from functools import partial

    from einops import repeat

    from ..emb.qnt import decode_to_file
    from ..engines import Engine, Engines
    from tqdm import tqdm, trange
    from ..utils import wrapper as ml

    from .ar import AR
    from .nar import NAR

    device = "cuda"
    x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
    # NOTE(review): several keys below are empty strings and '?' / 'p' share id 7;
    # this looks like characters lost in transit -- verify against the real symbol map.
    symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}

    def tokenize(content, lang_marker="en"):
        """Map a space-separated phoneme string to a tensor of symbol ids."""
        split = content.split(" ")
        # an empty piece comes from two consecutive spaces, i.e. a literal space symbol
        phones = ["<s>"] + [ " " if not p else p for p in split ] + ["</s>"]
        # direct indexing so an unknown phoneme raises a KeyError naming it,
        # rather than a None slipping into torch.tensor
        return torch.tensor([ symmap[p] for p in phones ])

    kwargs = {
        'n_tokens': 1024,
        'd_model': 1024,
        'n_heads': 16,
        'n_layers': 12,
    }
    models = { "ar": AR(**kwargs).to(device), "nar": NAR(**kwargs).to(device) }

    for name, model in models.items():
        print(f"{name} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    engines = Engines({ name: Engine(model=model, optimizer=ml.AdamW(model.parameters(), lr=1e-4)) for name, model in models.items() })

    train = True

    # NOTE: torch.load unpickles its input; only point it at trusted files.
    qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device)
    text_list = [
        tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
        #tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device),
    ]

    proms_list = [
        qnt.to(device),
    ]
    resps_list = [
        qnt.to(device),
    ]

    def sample( name, steps=600 ):
        """Generate with AR then NAR and write both results as `*.{name}.wav`."""
        ar_engine = None
        nar_engine = None

        engines.eval()
        # the loop variable must not be called `name`: that would overwrite the
        # tag used in the output file names below
        for engine_name, engine in engines.items():
            if engine_name[:2] == "ar":
                ar_engine = engine
            elif engine_name[:3] == "nar":
                nar_engine = engine

        resps_list = ar_engine(text_list, proms_list, max_steps=steps, sampling_temperature=1.0)
        resps_list = [r.unsqueeze(-1) for r in resps_list]
        codes = nar_engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )

        decode_to_file(resps_list[0], f"./data/ar.{name}.wav", device=device)
        decode_to_file(codes[0], f"./data/ar+nar.{name}.wav", device=device)

    if train:
        sample("init", 15)

        engines.train()
        for _ in trange(500):
            stats = engines.step({"text_list": text_list, "proms_list": proms_list, "resps_list": resps_list})
            tqdm.write(f"{stats}")
    else:
        for name, engine in engines.items():
            engine.module.load_state_dict(torch.load(f"./data/{name}.pth"))

    sample("final")
2023-08-02 21:53:35 +00:00
# Script entry point: run the training/sampling demo when executed directly.
if __name__ == "__main__":
    example_usage()