all I can do for now while I wait for the model to (re)train for pure NAR
commit a9d2faf2d7
parent ad7e290a5e
BIN  data/noise.dac (binary file not shown)
BIN  data/qnt.dac (binary file not shown)
BIN  data/qnt.enc (binary file not shown)
@@ -52,6 +52,10 @@ def main():
     parser.add_argument("--layer-skip-entropy-threshold", type=float, default=0.1)
     parser.add_argument("--layer-skip-varentropy-threshold", type=float, default=0.1)
     parser.add_argument("--refine-on-stop", action="store_true")
+
+    # experimental settings
+    parser.add_argument("--load-from-artifact", type=Path, default=None)
+    parser.add_argument("--denoise-start", type=float, default=0.0)

     parser.add_argument("--seed", type=int, default=None)
@@ -92,6 +96,10 @@ def main():
         layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
         layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
         refine_on_stop=args.refine_on_stop,
+
+        load_from_artifact=args.load_from_artifact,
+        denoise_start=args.denoise_start,
+
         seed=args.seed,
     )
@@ -3,6 +3,7 @@ import torchaudio
 import soundfile
 import time
 import logging
+import numpy as np

 _logger = logging.getLogger(__name__)
@@ -12,7 +13,7 @@ from pathlib import Path

 from .emb import g2p, qnt
 from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
-from .utils import to_device, set_seed, wrapper as ml
+from .utils import to_device, set_seed, clamp, wrapper as ml

 from .config import cfg, Config
 from .models import get_models
@@ -229,6 +230,9 @@ class TTS():
         refine_on_stop=False,
         #
         seed = None,
+        #
+        load_from_artifact = None,
+        denoise_start = 0.0,

         out_path=None,
@@ -355,17 +359,34 @@ class TTS():
             use_lora=use_lora,
         )
     elif model_len is not None:
-        len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10, disable_tqdm=not tqdm ) # don't need more than that
-        len_list = [ min(l, max_ar_steps) for l in len_list ]
+        len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=5, disable_tqdm=not tqdm ) # don't need more than that
+        len_list = [ clamp(1, max_ar_steps, l) for l in len_list ]
+
+        kwargs = {}
+        # nasty hardcode to load a reference file and have that as the input target
+        if load_from_artifact and load_from_artifact.exists():
+            artifact = np.load(load_from_artifact, allow_pickle=True)[()]
+
+            phns = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=self.device)
+            resp = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=self.device)
+            prom = resp[:75*3, :]
+            len_list = [ resp.shape[0] ]
+
+            kwargs["resps_list"] = [ resp[:, :1] ]

         resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list,
             max_steps=max_ar_steps,
             max_levels=max_nar_levels,
             sampling_temperature=nar_temp,
             sampling_min_temperature=min_nar_temp,
             sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
             sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
+            denoise_start=denoise_start,

             disable_tqdm=not tqdm,
             use_lora=use_lora,
+            **kwargs,
         )
     else:
         raise Exception("!")
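For context on the hardcode above: the artifact is a numpy-pickled dict carrying the phoneme transcript and the RVQ codes. A minimal standalone sketch of the layout it expects — names here are illustrative, not from the repo, and the 75*3 slice suggests 75 frames/sec, i.e. a roughly three-second acoustic prompt:

    import numpy as np
    import torch

    def load_artifact_sketch(path, device="cpu"):
        artifact = np.load(path, allow_pickle=True)[()]   # unwrap the 0-d object array
        phonemes = artifact["metadata"]["phonemes"]       # transcript for the tokenizer
        # codes are stored as (1, quant_levels, frames); transpose to (frames, quant_levels)
        codes = torch.from_numpy(artifact["codes"].astype(np.int16))[0].t().to(device)
        prompt = codes[:75*3, :]                          # first ~3 s as the acoustic prompt
        return phonemes, prompt, codes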
@@ -19,13 +19,10 @@ import logging

 _logger = logging.getLogger(__name__)

+from ..utils import clamp
 from ..emb.qnt import trim, encode_as_embedding

 from .lora import enable_lora

-def clamp(n, lo, hi):
-    return max(lo, min(n, hi))
-
 class AR(Base):
     def forward(
         self,
@@ -24,13 +24,10 @@ import logging
 _logger = logging.getLogger(__name__)

 from ..emb.qnt import trim, encode_as_embedding, get_silence
-from ..utils import get_devices, setup_logging, timer
+from ..utils import get_devices, setup_logging, timer, clamp

 from .lora import enable_lora

-def clamp(n, lo, hi):
-    return max(lo, min(n, hi))
-
 class AR_NAR(Base):
     def forward(
         self,
@@ -490,32 +487,19 @@ def example_usage():
     """
     # cfg.model.loss_factors = {}

     def tokenize(content):
         return torch.tensor( cfg.tokenizer.encode(content) )

-    def _load_quants(path) -> Tensor:
-        qnt = np.load(path, allow_pickle=True)[()]
-        return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
+    def load_artifact( path ):
+        artifact = np.load(path, allow_pickle=True)[()]
+
+        text = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device)
+        audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device)
+
+        return text, audio

-    qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
-    noise = _load_quants(f"./data/noise.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
-
-    text_list = [
-        tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
-        #tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
-    ]
-    proms_list = [
-        qnt[:cfg.dataset.frames_per_second, :].to(device),
-        #qnt[:cfg.dataset.frames_per_second, :].to(device),
-    ]
-    resps_list = [
-        qnt[:, :].to(device),
-        #qnt[:cfg.dataset.frames_per_second, :].to(device),
-    ]
-
-    text_list = text_list[:1]
-    proms_list = proms_list[:1]
-    resps_list = resps_list[:1]
+    text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
+
+    text_list = [ text ]
+    proms_list = [ audio[:cfg.dataset.frames_per_second, :] ]
+    resps_list = [ audio ]

     batch_size = len(text_list)
@@ -721,7 +705,7 @@ def example_usage():
             resps = engine( texts, proms, resps, task_list=tasks, sampling_temperature=0.2 )

             for i, o in enumerate(resps):
-                _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
+                _ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)

         unload_model()
@@ -30,9 +30,8 @@ from torch.utils.checkpoint import checkpoint
 from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision

 from .arch import *
-from ..utils import wrapper as ml
+from ..utils import wrapper as ml, clamp
 from ..samplers import *

 from ..emb.qnt import encode_as_embedding

 # yuck, kind of needed
@@ -57,9 +56,6 @@ def _dropout_mask( input, p=None ):
     mask = torch.tensor( seq, dtype=torch.bool, device=input.device )
     return mask

-def clamp(n, lo, hi):
-    return max(lo, min(n, hi))
-
 def _create_mask(l, device):
     """1 is valid region and 0 is invalid."""
     seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@@ -1853,19 +1849,19 @@ class Base(nn.Module):
         else:
             # argmax instead
             if temperature <= 0.0:
-                res = [ logit.argmax(dim=1) for logit in logits ]
+                res = [ logit.argmax(dim=-1) for logit in logits ]
             else:
                 res = [ Categorical(logits=logit).sample() for logit in logits ]

         # calculate token probabilities
         if "len" in self.capabilities:
             scores = [
-                [ F.softmax(logit[i, :], dim=0)[token].item() for i, token in enumerate(tokens) ]
+                [ F.softmax(logit[i, :], dim=-1)[token].item() for i, token in enumerate(tokens) ]
                 for logit, tokens in zip(logits, res)
             ]
         else:
             scores = [
-                [ F.softmax(logit[-1, :], dim=0)[token].item() for token in tokens ]
+                [ F.softmax(logit[-1, :], dim=-1)[token].item() for token in tokens ]
                 for logit, tokens in zip(logits, res)
             ]
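The dim changes above should be behavior-preserving: for 2-D (timesteps, vocab) logits, dim=1 and dim=-1 address the same axis, and for a 1-D row like logit[i, :], dim=0 and dim=-1 coincide; dim=-1 is simply the spelling that stays correct at either rank. A quick check:

    import torch
    import torch.nn.functional as F

    logit = torch.randn(4, 10)  # (timesteps, vocab)
    assert torch.equal(logit.argmax(dim=1), logit.argmax(dim=-1))         # same axis on 2-D
    row = logit[0, :]                                                     # 1-D slice
    assert torch.allclose(F.softmax(row, dim=0), F.softmax(row, dim=-1))  # same axis on 1-D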
@@ -21,9 +21,7 @@ from tqdm import trange
 from .base import Base, list_to_tensor, Categorical, _dropout_mask
 from ..config import cfg
 from ..emb.qnt import trim, repeat_extend_audio
-
-def clamp(n, lo, hi):
-    return max(lo, min(n, hi))
+from ..utils import clamp

 _logger = logging.getLogger(__name__)
@@ -46,6 +44,7 @@ class NAR(Base):

         input_prompt_prefix: bool = False,
         prefix_silence: float = 1.0,
+        denoise_start: float = 0.0,

         sampling_temperature: float = 1.0,
         sampling_min_temperature: float = -1.0,
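denoise_start = 0.0 keeps the old behavior (start fully masked), while a value in (0, 1] seeds the sequence from a provided response and re-masks only the cosine-scheduled fraction. A sketch of that seeding, mirroring the logic added further down (names illustrative):

    import math, random, torch

    def seed_input_ids(resp_level0, denoise_start, stop_token, device="cpu"):
        # fraction of tokens to re-mask, per the cosine schedule
        noise_p = math.cos(denoise_start * math.pi * 0.5)
        mask = torch.tensor([random.random() < noise_p for _ in range(len(resp_level0))],
                            dtype=torch.bool, device=device)
        # masked positions become the mask/stop token; the rest keep the reference token
        return torch.where(mask, stop_token, resp_level0)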
@@ -245,7 +244,11 @@ class NAR(Base):
             sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer

         # initial condition
-        len_list = [ clamp(1, 75*3, l) for l in len_list ]
+        """
+        print( len_list )
+        len_list = [ clamp(1, max_steps, l) for l in len_list ]
+        print( len_list )
+        """
         metrics = []

         mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
@@ -262,50 +265,39 @@ class NAR(Base):
         def gumbel_sample(x, temperature = 1., dim = -1):
             return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)

-        test_artifact = None
-
         # nasty hardcode to load a reference file and have that as the input target
+        # to-do: expose a way to provide the initial sequence instead through CLI
+        """
         if False:
             path = "./data/00_part2_success-1.enc"
             test_artifact = np.load(path, allow_pickle=True)[()]
             text_list = [ torch.tensor( cfg.tokenizer.encode( test_artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device) ]
             resps_list = [ torch.from_numpy(test_artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device) ]
             proms_list = [ resps[:75*3, :] for resps in resps_list ]
             #proms_list = [ resps for resps in resps_list ]
             len_list = [ resps.shape[0] for resps in resps_list ]
+        """

         _super = super()
-        def demask_sampling( seq_len, max_steps=5, temperature=1.0 ):
-            starting_temperature = temperature
+        def demask_sampling( batch_index, seq_len ):
+            # overrides
+            max_steps = 10
+            temperature = 0.3
+            cfg_strength = 1.0
+            sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied......
+            sampling_top_p = 0.9 # a lot of demasking samplers use a top-k of seq_len * 0.9

-            input_ids = torch.ones((seq_len,), dtype=torch.long, device=device) * self.stop_token
+            # if we're denoising from an existing sequence
+            if denoise_start > 0.0 and resps_list is not None:
+                start_noise = denoise_start
+                noise_p = math.cos( start_noise * math.pi * 0.5 )
+                mask = torch.tensor( [ random.random() < noise_p for _ in range( seq_len ) ], dtype=torch.bool, device=device )
+                input_ids = torch.where( mask, self.stop_token, resps_list[batch_index][:, 0] )
+            else:
+                input_ids = torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token

             scores = torch.zeros((seq_len,), dtype=torch.float32, device=device)

             quant_levels = [ level for _ in range(batch_size) ]
             prev_list = [ input_ids ]

+            start_temperature = temperature
             start_noise = 0.0
             end_noise = 1.0
-            sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied......
-
-            # use hardcoded reference file to test inference capabilities
-            if test_artifact is not None:
-                # because we "set" it later on, it's not implicitly captured
-                nonlocal resps_list
-                start_noise = 0.5
-                noise_p = math.cos( start_noise * math.pi * 0.5 )
-                input_ids = torch.tensor( [ self.stop_token if random.random() < noise_p else token for _, token in enumerate( resps_list[0][:, 0] ) ], dtype=torch.int16, device=device )

             null_text = torch.tensor([1, 2], device=device, dtype=torch.int16)
             null_prom = None
-            cfg_strength = 1.0

             for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
                 # anneal temperature
-                temperature = starting_temperature * (steps_until_x0 / max_steps)
+                temperature = start_temperature * (steps_until_x0 / max_steps)
                 # get noise level, per cosine scheduling
                 noise_p = math.cos( timestep * math.pi * 0.5 )
                 # number of tokens to mask off to "noise" the input sequence
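For reference, the loop above runs max_steps of MaskGIT-style demasking: the temperature anneals linearly toward zero while the cosine schedule shrinks the masked fraction from ~1.0 toward 0.0. A standalone sketch of the two schedules, using the hardcoded override values from the hunk:

    import math

    max_steps, start_temperature = 10, 0.3
    timesteps = [i / (max_steps - 1) for i in range(max_steps)]  # linspace(0, 1, max_steps)
    for timestep, steps_until_x0 in zip(timesteps, reversed(range(max_steps))):
        temperature = start_temperature * (steps_until_x0 / max_steps)  # anneals to 0
        noise_p = math.cos(timestep * math.pi * 0.5)                    # fraction still masked
        print(f"t={timestep:.2f}  temperature={temperature:.3f}  mask~{noise_p:.2f}")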
@@ -317,12 +309,11 @@ class NAR(Base):
                 # boolean mask
                 is_masked = input_ids == self.stop_token
                 # setup inputs
-                resps_list = [ input_ids ]

                 inputs = _super.inputs(
                     text_list=text_list,
                     proms_list=proms_list,
-                    resps_list=resps_list,
+                    resps_list=[ input_ids ],
                     lang_list=lang_list,
                     tone_list=tone_list,
                     time_list=[ timestep ],
@@ -333,11 +324,14 @@ class NAR(Base):
                     quant_levels=quant_levels,
                     layer_skip_variables=sampling_layer_skip_variables,
                 )

+                logits = output.logits
+
                 if cfg_strength > 0:
                     null_inputs = _super.inputs(
                         text_list=[ null_text ],
                         proms_list=[ null_prom ],
-                        resps_list=resps_list,
+                        resps_list=[ input_ids ],
                         lang_list=lang_list,
                         tone_list=tone_list,
                         time_list=[ timestep ],
@@ -348,12 +342,10 @@ class NAR(Base):
                         quant_levels=quant_levels,
                         layer_skip_variables=sampling_layer_skip_variables,
                     )
-                    logits = [ logits + ( logits - null_logits ) * cfg_strength for logits, null_logits in zip(output.logits, null_output.logits) ]
-                else:
-                    logits = output.logits
+                    for logit, null_logits in zip(output.logits, null_output.logits):
+                        logit[-seq_len:] = logit[-seq_len:] + ( logit[-seq_len:] - null_logits[-seq_len:] ) * cfg_strength

                 # sample with sampler settings
-                sampling_top_p = 0.9
                 filtered_sampled = _super.sample(
                     logits=logits,
                     prev_list=prev_list,
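The replacement above applies classifier-free guidance in place and only over the trailing seq_len response positions, rather than rebuilding the whole logits list: guided = cond + (cond - null) * cfg_strength. A standalone sketch of that rule (names illustrative):

    import torch

    def apply_cfg(cond_logits, null_logits, seq_len, cfg_strength=1.0):
        # push the conditional logits away from the unconditional ("null") ones,
        # touching only the last seq_len positions (the response tokens)
        cond = cond_logits[-seq_len:]
        null = null_logits[-seq_len:]
        cond_logits[-seq_len:] = cond + (cond - null) * cfg_strength
        return cond_logits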
@@ -401,12 +393,13 @@ class NAR(Base):
             # update scores (conjugated to put the worst scores at the top)
             scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)

+            if cfg.experimental:
                 print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )

             return input_ids

         # perform demasked sampling (mock diffusion)
-        prev_list = [ demask_sampling( seq_len=l ) for l in len_list ]
+        prev_list = [ demask_sampling( batch_index=i, seq_len=l ) for i, l in enumerate( len_list ) ]

         # expand if given a raw 1D tensor
         for i, resp in enumerate(prev_list):
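On the "conjugated" scores: storing 1 - confidence means the least confident tokens carry the largest values, so — assuming the MaskGIT-style selection used by samplers of this kind, which is not shown in this hunk — a topk over the scores picks which positions to re-mask on the next step:

    import torch

    confidences = torch.tensor([0.9, 0.2, 0.7, 0.1])
    scores = 1.0 - confidences                          # "badness", worst at the top
    masked_tokens_n = 2
    remask_idx = scores.topk(masked_tokens_n).indices   # least confident: indices 3 and 1
    print(remask_idx)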
@@ -530,39 +523,20 @@ def example_usage():
     import re

     device = "cuda"

     # mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
+    """
     if "mamba" in cfg.model.arch_type:
         cfg.model.resp_levels = 1
+    """
     # cfg.model.loss_factors = {}

     def tokenize(content):
         return torch.tensor( cfg.tokenizer.encode(content) )

-    def _load_quants(path) -> Tensor:
-        qnt = np.load(path, allow_pickle=True)[()]
-        return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
+    def load_artifact( path ):
+        artifact = np.load(path, allow_pickle=True)[()]
+
+        text = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device)
+        audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device)
+
+        return text, audio

-    qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
-
-    text_list = [
-        tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
-        #tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
-    ]
-    proms_list = [
-        qnt[:cfg.dataset.frames_per_second, :].to(device),
-        #qnt[:cfg.dataset.frames_per_second, :].to(device),
-    ]
-    resps_list = [
-        qnt[:, :].to(device),
-        #qnt[:cfg.dataset.frames_per_second, :].to(device),
-    ]
-
-    text_list = text_list[:1]
-    proms_list = proms_list[:1]
-    resps_list = resps_list[:1]
+    text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
+
+    text_list = [ text ]
+    proms_list = [ audio[:cfg.dataset.frames_per_second, :] ]
+    resps_list = [ audio ]

     # rentet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
     kwargs = {
@@ -8,8 +8,7 @@ from torch import Tensor, einsum, nn
 from einops import rearrange
 from dataclasses import asdict, dataclass, field

-def clamp(n, lo, hi):
-    return max(lo, min(n, hi))
+from .utils import clamp

 # Simple filter to modify a token's probability if it shows up in the past
 # `one_time` will only apply the penalty once
@@ -147,9 +147,13 @@ def run_eval(engines, eval_name, dl, args=None):
         elif "len" in engine.hyper_config.capabilities:
             kwargs = base_kwargs | cfg.evaluation.ar_kwargs
             max_steps = kwargs.pop("max_steps", 500)
-            kwargs["max_steps"] = 10
-            len_list = engine( **kwargs ) # don't need more than that
+            len_list = engine( max_steps=5, **kwargs ) # don't need more than that
+            len_list = [ min( l, max_steps ) for l in len_list ]
+
+            if True:
+                len_list = [ resp.shape[0] for resp in batch["resps"] ]
+                kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
+                kwargs["denoise_start"] = 0.5

             kwargs = base_kwargs | cfg.evaluation.nar_kwargs
             resps_list = engine( **kwargs, len_list=len_list )
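One detail in the eval path worth spelling out: kwargs is built with the dict union operator, and max_steps is popped before the call, so passing max_steps=5 explicitly cannot collide with a leftover key. A sketch of the pattern:

    base_kwargs = {"temperature": 1.0, "max_steps": 500}
    kwargs = base_kwargs | {"max_steps": 450}   # Python 3.9+ dict union: right side wins
    max_steps = kwargs.pop("max_steps", 500)    # 450, and removed from kwargs

    def engine(max_steps=0, **rest):
        return max_steps

    assert engine(max_steps=5, **kwargs) == 5   # no "multiple values" TypeError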
@@ -12,5 +12,6 @@ from .utils import (
     get_devices,
     truncate_json,
     timer,
-    prune_missing
+    prune_missing,
+    clamp
 )
@@ -49,6 +49,9 @@ def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, retu
         missing += m
     return (keep, missing) if return_missing else keep

+def clamp(n, lo, hi):
+    return max(lo, min(n, hi))
+
 class timer:
     def __init__(self, msg="Elapsed time:", callback=None):
         self.msg = msg
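With the per-file copies deleted, this is the single clamp everything now imports. Note the signature is clamp(n, lo, hi) — value first — while call sites elsewhere in this commit pass clamp(1, max_ar_steps, l), i.e. (lo, hi, value); under this definition that evaluates to max_ar_steps for any l >= 1, pinning every predicted length to the cap rather than clamping it. A quick demonstration:

    def clamp(n, lo, hi):
        return max(lo, min(n, hi))

    assert clamp(300, 1, 225) == 225   # clamp(value, lo, hi): clamps as intended
    assert clamp(50, 1, 225) == 50
    assert clamp(1, 225, 300) == 225   # clamp(lo, hi, value) reading: constant 225
    assert clamp(1, 225, 50) == 225    # ...regardless of the actual length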