all I can do for now while I wait for the model to (re)train for pure NAR

mrq 2024-11-09 22:57:34 -06:00
parent ad7e290a5e
commit a9d2faf2d7
13 changed files with 103 additions and 116 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -52,6 +52,10 @@ def main():
parser.add_argument("--layer-skip-entropy-threshold", type=int, default=0.1)
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=0.1)
parser.add_argument("--refine-on-stop", action="store_true")
# experimental settings
parser.add_argument("--load-from-artifact", type=Path, default=None)
parser.add_argument("--denoise-start", type=float, default=0.0)
parser.add_argument("--seed", type=int, default=None)
@@ -92,6 +96,10 @@ def main():
layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
refine_on_stop=args.refine_on_stop,
load_from_artifact=args.load_from_artifact,
denoise_start=args.denoise_start,
seed=args.seed,
)
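Taken in isolation, the new flags wire through argparse as usual; a minimal, self-contained sketch of the conversion behavior (flag names from the parser above, values illustrative):

# minimal sketch of the new argparse wiring
import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument("--load-from-artifact", type=Path, default=None)
parser.add_argument("--denoise-start", type=float, default=0.0)
parser.add_argument("--seed", type=int, default=None)

args = parser.parse_args(["--load-from-artifact", "./data/reference.enc", "--denoise-start", "0.5"])
# dashes become underscores on the namespace
assert isinstance(args.load_from_artifact, Path) and args.denoise_start == 0.5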

View File

@@ -3,6 +3,7 @@ import torchaudio
import soundfile
import time
import logging
import numpy as np
_logger = logging.getLogger(__name__)
@@ -12,7 +13,7 @@ from pathlib import Path
from .emb import g2p, qnt
from .emb.qnt import trim, trim_random, unload_model, repeat_extend_audio
from .utils import to_device, set_seed, wrapper as ml
from .utils import to_device, set_seed, clamp, wrapper as ml
from .config import cfg, Config
from .models import get_models
@@ -229,6 +230,9 @@ class TTS():
refine_on_stop=False,
#
seed = None,
#
load_from_artifact = None,
denoise_start = 0.0,
out_path=None,
@@ -355,17 +359,34 @@
use_lora=use_lora,
)
elif model_len is not None:
len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=10, disable_tqdm=not tqdm ) # don't need more than that
len_list = [ min(l, max_ar_steps) for l in len_list ]
len_list = model_len( text_list=[phns], proms_list=[prom], max_steps=5, disable_tqdm=not tqdm ) # don't need more than that
len_list = [ clamp(l, 1, max_ar_steps) for l in len_list ]
kwargs = {}
# nasty hardcode to load a reference file and have that as the input target
if load_from_artifact and load_from_artifact.exists():
artifact = np.load(load_from_artifact, allow_pickle=True)[()]
phns = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=self.device)
resp = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=self.device)
prom = resp[:75*3, :] # first three seconds (75 frames/sec) as the acoustic prompt
len_list = [ resp.shape[0] ]
kwargs["resps_list"] = [ resp[:, :1] ]
resps_list = model_nar( text_list=[phns], proms_list=[prom], len_list=len_list,
max_steps=max_ar_steps,
max_levels=max_nar_levels,
sampling_temperature=nar_temp,
sampling_min_temperature=min_nar_temp,
sampling_top_p=top_p, sampling_top_k=top_k, sampling_min_p=min_p,
sampling_repetition_penalty=repetition_penalty, sampling_repetition_penalty_decay=repetition_penalty_decay,
denoise_start=denoise_start,
disable_tqdm=not tqdm,
use_lora=use_lora,
**kwargs,
)
else:
raise Exception("no model available to infer with")
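For readability, the hardcoded branch above is equivalent to a standalone helper along these lines (a sketch: load_artifact_inputs and its parameters are illustrative names; it assumes artifacts are np.save'd dicts carrying metadata["phonemes"] and codes shaped (1, levels, frames), with 75*3 being three seconds at 75 frames per second):

import numpy as np
import torch

def load_artifact_inputs(path, tokenizer, device="cpu", fps=75, prompt_seconds=3):
    # hypothetical helper mirroring the inline hardcode above
    artifact = np.load(path, allow_pickle=True)[()]
    # phonemes -> token ids
    phns = torch.tensor(tokenizer.encode(artifact["metadata"]["phonemes"])).to(dtype=torch.uint8, device=device)
    # codes come in as (1, levels, frames); transpose to (frames, levels)
    resp = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device)
    prom = resp[:fps * prompt_seconds, :]   # first N seconds as the acoustic prompt
    len_list = [resp.shape[0]]              # target length is the reference length
    resps_list = [resp[:, :1]]              # first quant level only, as the denoise target
    return phns, prom, len_list, resps_list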

View File

@@ -19,13 +19,10 @@ import logging
_logger = logging.getLogger(__name__)
from ..utils import clamp
from ..emb.qnt import trim, encode_as_embedding
from .lora import enable_lora
def clamp(n, lo, hi):
return max(lo, min(n, hi))
class AR(Base):
def forward(
self,

View File

@@ -24,13 +24,10 @@ import logging
_logger = logging.getLogger(__name__)
from ..emb.qnt import trim, encode_as_embedding, get_silence
from ..utils import get_devices, setup_logging, timer
from ..utils import get_devices, setup_logging, timer, clamp
from .lora import enable_lora
def clamp(n, lo, hi):
return max(lo, min(n, hi))
class AR_NAR(Base):
def forward(
self,
@@ -490,32 +487,19 @@ def example_usage():
"""
# cfg.model.loss_factors = {}
def tokenize(content):
return torch.tensor( cfg.tokenizer.encode(content) )
def load_artifact( path ):
artifact = np.load(path, allow_pickle=True)[()]
def _load_quants(path) -> Tensor:
qnt = np.load(path, allow_pickle=True)[()]
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
text = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device)
audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device)
qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
noise = _load_quants(f"./data/noise.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
return text, audio
text_list = [
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
]
proms_list = [
qnt[:cfg.dataset.frames_per_second, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
resps_list = [
qnt[:, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
text_list = text_list[:1]
proms_list = proms_list[:1]
resps_list = resps_list[:1]
text_list = [ text ]
proms_list = [ audio[:cfg.dataset.frames_per_second, :] ]
resps_list = [ audio ]
batch_size = len(text_list)
@@ -721,7 +705,7 @@ def example_usage():
resps = engine( texts, proms, resps, task_list=tasks, sampling_temperature=0.2 )
for i, o in enumerate(resps):
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{name}.wav", device=device)
unload_model()
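The consolidated load_artifact replaces the old _load_quants; the one subtlety it keeps is the axis shuffle on codes. A toy shape check (sizes assumed for illustration):

import numpy as np
import torch

codes = np.zeros((1, 8, 120), dtype=np.int64)  # (batch=1, levels=8, frames=120), illustrative
audio = torch.from_numpy(codes.astype(np.int16))[0, :, :].t()
assert audio.shape == (120, 8)                 # (frames, levels), what proms_list/resps_list expect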

View File

@@ -30,9 +30,8 @@ from torch.utils.checkpoint import checkpoint
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MulticlassPrecision
from .arch import *
from ..utils import wrapper as ml
from ..utils import wrapper as ml, clamp
from ..samplers import *
from ..emb.qnt import encode_as_embedding
# yuck, kind of needed
@@ -57,9 +56,6 @@ def _dropout_mask( input, p=None ):
mask = torch.tensor( seq, dtype=torch.bool, device=input.device )
return mask
def clamp(n, lo, hi):
return max(lo, min(n, hi))
def _create_mask(l, device):
"""1 is valid region and 0 is invalid."""
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
@@ -1853,19 +1849,19 @@ class Base(nn.Module):
else:
# argmax instead
if temperature <= 0.0:
res = [ logit.argmax(dim=1) for logit in logits ]
res = [ logit.argmax(dim=-1) for logit in logits ]
else:
res = [ Categorical(logits=logit).sample() for logit in logits ]
# calculate token probabilities
if "len" in self.capabilities:
scores = [
[ F.softmax(logit[i, :], dim=0)[token].item() for i, token in enumerate(tokens) ]
[ F.softmax(logit[i, :], dim=-1)[token].item() for i, token in enumerate(tokens) ]
for logit, tokens in zip(logits, res)
]
else:
scores = [
[ F.softmax(logit[-1, :], dim=0)[token].item() for token in tokens ]
[ F.softmax(logit[-1, :], dim=-1)[token].item() for token in tokens ]
for logit, tokens in zip(logits, res)
]
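For the current shapes the dim changes are no-ops: on a 2D (positions, vocab) logit, argmax(dim=1) equals argmax(dim=-1), and on a 1D slice softmax(dim=0) equals softmax(dim=-1). Writing dim=-1 everywhere pins "vocab is the last axis" so the code stays correct if a batch dimension ever appears. A quick check:

import torch
import torch.nn.functional as F

logit = torch.randn(7, 256)  # (positions, vocab)
assert torch.equal(logit.argmax(dim=1), logit.argmax(dim=-1))
row = logit[-1, :]           # 1D slice: only one axis
assert torch.allclose(F.softmax(row, dim=0), F.softmax(row, dim=-1))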

View File

@@ -21,9 +21,7 @@ from tqdm import trange
from .base import Base, list_to_tensor, Categorical, _dropout_mask
from ..config import cfg
from ..emb.qnt import trim, repeat_extend_audio
def clamp(n, lo, hi):
return max(lo, min(n, hi))
from ..utils import clamp
_logger = logging.getLogger(__name__)
@@ -46,6 +44,7 @@ class NAR(Base):
input_prompt_prefix: bool = False,
prefix_silence: float = 1.0,
denoise_start: float = 0.0,
sampling_temperature: float = 1.0,
sampling_min_temperature: float = -1.0,
@@ -245,7 +244,11 @@ class NAR(Base):
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
# initial condition
len_list = [ clamp(l, 1, 75*3) for l in len_list ]
"""
print( len_list )
len_list = [ clamp(l, 1, max_steps) for l in len_list ]
print( len_list )
"""
metrics = []
mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
@@ -262,50 +265,39 @@ class NAR(Base):
def gumbel_sample(x, temperature = 1., dim = -1):
return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
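gumbel_sample is the Gumbel-max trick: adding -log(-log(U)) noise to temperature-scaled logits and taking the argmax samples from the same categorical distribution as Categorical(logits=x/temperature). An empirical sanity check (n and the logits are arbitrary):

import torch
from torch.distributions import Categorical

torch.manual_seed(0)
logits = torch.tensor([1.0, 2.0, 0.5])
n = 100_000
gumbel = -torch.log(-torch.log(torch.rand(n, 3)))  # standard Gumbel noise
via_gumbel = (logits + gumbel).argmax(dim=-1)
via_categorical = Categorical(logits=logits).sample((n,))
# both estimate softmax(logits); the empirical frequencies agree to ~1e-2
print(torch.bincount(via_gumbel, minlength=3) / n)
print(torch.bincount(via_categorical, minlength=3) / n)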
test_artifact = None
# nasty hardcode to load a reference file and have that as the input target
# to-do: expose a way to provide the initial sequence instead through CLI
"""
if False:
path = "./data/00_part2_success-1.enc"
test_artifact = np.load(path, allow_pickle=True)[()]
text_list = [ torch.tensor( cfg.tokenizer.encode( test_artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device) ]
resps_list = [ torch.from_numpy(test_artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device) ]
proms_list = [ resps[:75*3, :] for resps in resps_list ]
#proms_list = [ resps for resps in resps_list ]
len_list = [ resps.shape[0] for resps in resps_list ]
"""
_super = super()
def demask_sampling( seq_len, max_steps=5, temperature=1.0 ):
starting_temperature = temperature
def demask_sampling( batch_index, seq_len ):
# overrides
max_steps = 10
temperature = 0.3
cfg_strength = 1.0
sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied......
sampling_top_p = 0.9 # many demasking samplers use a top-k of seq_len * 0.9; top-p 0.9 stands in for that here
input_ids = torch.ones((seq_len,), dtype=torch.long, device=device) * self.stop_token
# if we're denoising from an existing sequence
if denoise_start > 0.0 and resps_list is not None:
start_noise = denoise_start
noise_p = math.cos( start_noise * math.pi * 0.5 )
mask = torch.tensor( [ random.random() < noise_p for _ in range( seq_len ) ], dtype=torch.bool, device=device )
input_ids = torch.where( mask, self.stop_token, resps_list[batch_index][:, 0] )
else:
input_ids = torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token
scores = torch.zeros((seq_len,), dtype=torch.float32, device=device)
quant_levels = [ level for _ in range(batch_size) ]
prev_list = [ input_ids ]
start_temperature = temperature
start_noise = 0.0
end_noise = 1.0
sampling_repetition_penalty = 1.0 # force rep pen off, because this caused false positives due to how rep pen was being naively applied......
# use hardcoded reference file to test inference capabilities
if test_artifact is not None:
# because we "set" it later on, it's not implicitly captured
nonlocal resps_list
start_noise = 0.5
noise_p = math.cos( start_noise * math.pi * 0.5 )
input_ids = torch.tensor( [ self.stop_token if random.random() < noise_p else token for _, token in enumerate( resps_list[0][:, 0] ) ], dtype=torch.int16, device=device )
null_text = torch.tensor([1, 2], device=device, dtype=torch.int16)
null_prom = None
cfg_strength = 1.0
for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
# anneal temperature
temperature = starting_temperature * (steps_until_x0 / max_steps)
temperature = start_temperature * (steps_until_x0 / max_steps)
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 )
# number of tokens to mask off to "noise" the input sequence
@@ -317,12 +309,11 @@ class NAR(Base):
# boolean mask
is_masked = input_ids == self.stop_token
# setup inputs
resps_list = [ input_ids ]
inputs = _super.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
resps_list=[ input_ids ],
lang_list=lang_list,
tone_list=tone_list,
time_list=[ timestep ],
@@ -333,11 +324,14 @@ class NAR(Base):
quant_levels=quant_levels,
layer_skip_variables=sampling_layer_skip_variables,
)
logits = output.logits
if cfg_strength > 0:
null_inputs = _super.inputs(
text_list=[ null_text ],
proms_list=[ null_prom ],
resps_list=resps_list,
resps_list=[ input_ids ],
lang_list=lang_list,
tone_list=tone_list,
time_list=[ timestep ],
@@ -348,12 +342,10 @@ class NAR(Base):
quant_levels=quant_levels,
layer_skip_variables=sampling_layer_skip_variables,
)
logits = [ logits + ( logits - null_logits ) * cfg_strength for logits, null_logits in zip(output.logits, null_output.logits) ]
else:
logits = output.logits
for logit, null_logits in zip(output.logits, null_output.logits):
logit[-seq_len:] = logit[-seq_len:] + ( logit[-seq_len:] - null_logits[-seq_len:] ) * cfg_strength
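The in-place update above is classifier-free guidance applied at the logit level, restricted to the last seq_len positions so only the response region is steered while text and prompt positions are left alone. In isolation (apply_cfg is an illustrative name):

import torch

def apply_cfg(cond_logits, null_logits, seq_len, cfg_strength=1.0):
    # guided = cond + (cond - null) * strength, over the response tail only
    guided = cond_logits.clone()
    guided[-seq_len:] = cond_logits[-seq_len:] + (cond_logits[-seq_len:] - null_logits[-seq_len:]) * cfg_strength
    return guided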
# sample with sampler settings
sampling_top_p = 0.9
filtered_sampled = _super.sample(
logits=logits,
prev_list=prev_list,
@@ -401,12 +393,13 @@ class NAR(Base):
# update scores (conjugated to put the worst scores at the top)
scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
if cfg.experimental:
print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
return input_ids
# perform demasked sampling (mock diffusion)
prev_list = [ demask_sampling( seq_len=l ) for l in len_list ]
prev_list = [ demask_sampling( batch_index=i, seq_len=l ) for i, l in enumerate( len_list ) ]
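Stepping back, demask_sampling is a MaskGIT-style loop: per step, the cosine schedule decides how many tokens should still be masked, the lowest-confidence positions get re-masked, all positions are predicted in parallel, and per-token scores carry over to pick the next re-mask set. A compressed skeleton under those assumptions (predict stands in for the model forward plus sampling; CFG and sampler settings elided):

import math
import torch

def demask_skeleton(seq_len, predict, mask_id, max_steps=10, device="cpu"):
    # predict(input_ids) is assumed to return (token_ids, per_token_confidence), both shaped (seq_len,)
    input_ids = torch.full((seq_len,), mask_id, dtype=torch.long, device=device)
    scores = torch.zeros(seq_len, device=device)             # "badness": higher = re-mask first
    for timestep in torch.linspace(0.0, 1.0, max_steps):
        noise_p = math.cos(timestep.item() * math.pi * 0.5)  # fraction that should still be masked
        masked_n = max(int(noise_p * seq_len), 1)
        input_ids[scores.topk(masked_n).indices] = mask_id   # re-mask the worst-scoring positions
        tokens, confidence = predict(input_ids)
        is_masked = input_ids == mask_id
        input_ids = torch.where(is_masked, tokens, input_ids)
        scores = 1.0 - confidence                            # conjugate so worst scores sort first
    return input_ids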
# expand if given a raw 1D tensor
for i, resp in enumerate(prev_list):
@@ -530,39 +523,20 @@ def example_usage():
import re
device = "cuda"
# mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
"""
if "mamba" in cfg.model.arch_type:
cfg.model.resp_levels = 1
"""
# cfg.model.loss_factors = {}
def tokenize(content):
return torch.tensor( cfg.tokenizer.encode(content) )
def load_artifact( path ):
artifact = np.load(path, allow_pickle=True)[()]
def _load_quants(path) -> Tensor:
qnt = np.load(path, allow_pickle=True)[()]
return torch.from_numpy(qnt["codes"].astype(np.int16))[0, :cfg.model.resp_levels, :].t().to(torch.int16)
text = torch.tensor( cfg.tokenizer.encode( artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device)
audio = torch.from_numpy(artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device)
qnt = _load_quants(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
return text, audio
text_list = [
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
]
proms_list = [
qnt[:cfg.dataset.frames_per_second, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
resps_list = [
qnt[:, :].to(device),
#qnt[:cfg.dataset.frames_per_second, :].to(device),
]
text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
text_list = text_list[:1]
proms_list = proms_list[:1]
resps_list = resps_list[:1]
text_list = [ text ]
proms_list = [ audio[:cfg.dataset.frames_per_second, :] ]
resps_list = [ audio ]
# retnet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
kwargs = {

View File

@@ -8,8 +8,7 @@ from torch import Tensor, einsum, nn
from einops import rearrange
from dataclasses import asdict, dataclass, field
def clamp(n, lo, hi):
return max(lo, min(n, hi))
from .utils import clamp
# Simple filter to modify a token's probability if it shows up in the past
# `one_time` will only apply the penalty once

View File

@@ -147,9 +147,13 @@ def run_eval(engines, eval_name, dl, args=None):
elif "len" in engine.hyper_config.capabilities:
kwargs = base_kwargs | cfg.evaluation.ar_kwargs
max_steps = kwargs.pop("max_steps", 500)
kwargs["max_steps"] = 10
len_list = engine( **kwargs ) # don't need more than that
len_list = engine( max_steps=5, **kwargs )
len_list = [ min( l, max_steps ) for l in len_list ]
if True: # quick hack: take target lengths from the ground truth and denoise from the reference responses
len_list = [ resp.shape[0] for resp in batch["resps"] ]
kwargs["resps_list"] = [ resp[:, :1] for resp in batch["resps"] ]
kwargs["denoise_start"] = 0.5
kwargs = base_kwargs | cfg.evaluation.nar_kwargs
resps_list = engine( **kwargs, len_list=len_list )
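With denoise_start=0.5, this seeds demasking from the ground-truth first quant level with roughly 71% of positions re-masked, since the cosine schedule gives cos(0.5 * pi/2) ≈ 0.707. The fraction for other settings:

import math
for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(t, math.cos(t * math.pi * 0.5))  # 1.0, 0.924, 0.707, 0.383, 0.0 -> fraction re-masked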

View File

@@ -12,5 +12,6 @@ from .utils import (
get_devices,
truncate_json,
timer,
prune_missing
prune_missing,
clamp
)

View File

@@ -49,6 +49,9 @@ def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, retu
missing += m
return (keep, missing) if return_missing else keep
def clamp(n, lo, hi):
return max(lo, min(n, hi))
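Note the argument order: clamp takes the value first, then the bounds. Call sites must pass (n, lo, hi) or the bounds silently swallow the value:

def clamp(n, lo, hi):
    return max(lo, min(n, hi))

assert clamp(500, 1, 225) == 225  # value above the ceiling
assert clamp(0, 1, 225) == 1      # value below the floor
assert clamp(100, 1, 225) == 100  # value in range passes through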
class timer:
def __init__(self, msg="Elapsed time:", callback=None):
self.msg = msg