cleaned up classifier-free guidance logit processing (to try to cope with a bad nar-len model)

mrq 2024-11-19 10:30:05 -06:00
parent 5ba80686e1
commit 0e621354e7
2 changed files with 55 additions and 58 deletions
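For context, the guidance this commit centralizes is the standard classifier-free guidance interpolation applied to logits before sampling. A minimal sketch with made-up shapes (the actual implementation is the cfg_logits helper added to the samplers module below):

import torch

torch.manual_seed(0)
cond_logits = torch.randn(4, 8)  # conditioned pass: (positions, vocab)
null_logits = torch.randn(4, 8)  # unconditioned/null pass
strength = 3.0                   # cfg_strength: 0 = null only, 1 = conditional only, >1 over-emphasizes the condition

# lerp from the null logits toward (and past) the conditional logits
guided = null_logits + (cond_logits - null_logits) * strength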

vall_e/models/ar_nar.py

@@ -27,6 +27,7 @@ from ..emb.qnt import trim, encode_as_embedding, get_silence
from ..utils import get_devices, setup_logging, timer, clamp, convert_kwargs
from .lora import enable_lora
+from ..samplers import cfg_logits
text_task = [ "stt" ]
@@ -223,8 +224,8 @@ class AR_NAR(Base):
if cfg.lora is not None:
    enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
"""
# to-do: check if gumbel sampling works / helps
"""
def log(x, eps = 1e-20):
    return torch.log(x.clamp(min = eps))
@@ -232,6 +233,18 @@ class AR_NAR(Base):
return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
"""
+def log(t, eps=1e-10):
+    return torch.log(t + eps)
+
+def gumbel_noise(t):
+    noise = torch.zeros_like(t).uniform_(0, 1)
+    return -log(-log(noise))
+
+def gumbel_sample(t, temperature=1.0, dim=-1):
+    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
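An aside on the to-do above: gumbel_sample is the Gumbel-max trick, where taking the argmax over logits perturbed with Gumbel noise is equivalent to drawing a sample from the softmax distribution. A standalone check (my own sketch, not repo code):

import torch

torch.manual_seed(0)
logits = torch.tensor([1.0, 2.0, 3.0])

def gumbel_sample(t, temperature=1.0, dim=-1):
    noise = torch.zeros_like(t).uniform_(0, 1)
    gumbel = -torch.log(-torch.log(noise.clamp(min=1e-10)) + 1e-10)
    return ((t / max(temperature, 1e-10)) + gumbel).argmax(dim=dim)

# the argmax histogram converges to softmax(logits)
counts = torch.zeros(3)
for _ in range(10000):
    counts[gumbel_sample(logits)] += 1
print(counts / counts.sum())          # ≈ tensor([0.0900, 0.2447, 0.6652])
print(torch.softmax(logits, dim=-1))  # tensor([0.0900, 0.2447, 0.6652])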
# convert (N)AR specific args
sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
@@ -243,6 +256,7 @@ class AR_NAR(Base):
temperature = sampling_kwargs.pop("temperature", 1.0)
cfg_strength = sampling_kwargs.get("cfg_strength", 3.0) # this really helps keep audio coherent so far
+cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
start_noise = sampling_kwargs.get("denoise_start", 0.0)
end_noise = sampling_kwargs.get("denoise_end", 1.0)
# only run the fraction of steps covered by the requested denoising window
max_steps = math.floor(max_steps * (end_noise - start_noise))
@@ -286,6 +300,7 @@ class AR_NAR(Base):
annealing = 1.0 - timestep
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 )
#noise_p = annealing # (alternative: linear schedule)
# pick the worst scoring tokens to mask off
masked_indices = [ score.topk( max(int( noise_p * seq_len ), 1), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
# mask off inputs
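To illustrate the cosine schedule above: at timestep ≈ 0 nearly the whole sequence is re-masked, and the masked fraction decays toward zero as the timestep approaches 1. A standalone sketch (the step count and sequence length are made up):

import math

seq_len, steps = 100, 10
for step in range(steps):
    timestep = step / steps  # runs 0 → 1 across sampling
    noise_p = math.cos(timestep * math.pi * 0.5)
    n_masked = max(int(noise_p * seq_len), 1)
    print(f"step {step}: re-mask {n_masked}/{seq_len} worst-scoring tokens")
# step 0 re-masks all 100 tokens; by the last step only ~15 remain masked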
@@ -335,8 +350,8 @@ class AR_NAR(Base):
quant_levels=quant_levels,
#layer_skip_variables=sampling_layer_skip_variables,
)
-for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
-    logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * sampling_cfg
+logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=len_list )
# sample with sampler settings
filtered_sampled = super().sample(
@@ -361,60 +376,26 @@ class AR_NAR(Base):
"""
# update previous list of tokens
prev_list = resps_list
# sample with gumbel noise
# This actually lobotomizes things
#sampled_ids = [ gumbel_sample( logits, temperature=temperature * annealing, dim=-1 ) for logits in filtered_sampled.logits[0] ]
# get sampled tokens
sampled_ids = filtered_sampled.ids
# keep unmasked tokens
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
# get probability scores (conjugate to have worse scoring tokens picked for topk)
scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
"""
# maskgct does some funny stuff but it doesn't amount to anything
if annealing < 1.0e-3:
    sampled_ids = filtered_sampled.ids
else:
    sampled_ids = [ gumbel_sample( logits, temperature=temperature * annealing, dim=-1 ) for logits in filtered_sampled.logits ]
# keep unmasked tokens
resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
# update scores (conjugated to put the worst scores at the top)
scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
# refinement step
if refine_on_stop:
    inputs = super().inputs(
        text_list=text_list,
        proms_list=proms_list,
        resps_list=resps_list,
        lang_list=lang_list,
        tone_list=tone_list,
        quant_levels=quant_levels,
    )
    output = super().forward(
        inputs=inputs,
        quant_levels=quant_levels,
        #layer_skip_variables=sampling_layer_skip_variables,
    )
    logits = output.logits
    if cfg_strength > 0:
        null_inputs = super().inputs(
            text_list=null_text,
            proms_list=null_prom,
            resps_list=resps_list,
            lang_list=lang_list,
            tone_list=tone_list,
            quant_levels=quant_levels,
        )
        null_output = super().forward(
            inputs=null_inputs,
            quant_levels=quant_levels,
            #layer_skip_variables=sampling_layer_skip_variables,
        )
        for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
            logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
    # shift each slice back by one so the final (stop) position is excluded
    logits = [ logit[-length-1:-1] for logit, length in zip(logits, len_list) ]
    # greedy sample from the sequence
    refined_list = [ logit.argmax(dim=-1) for logit in logits ]
"""
if cfg.experimental and max_steps > 0:
    print( timestep, steps_until_x0, noise_p, resps_list, scores )
"""
scores = [ torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
scores = [ 1.0 - (choice_temperature * annealing * gumbel_noise( score ) + score) for score in scores ]
"""
return resps_list
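The commented-out scoring variant above is MaskGIT/MaskGCT-style confidence sampling: temperature-annealed Gumbel noise is added to each token's confidence so that early iterations explore more, and the result is conjugated (1.0 - x) so topk selects the least confident tokens for re-masking. A sketch of the idea with hypothetical names (not the repo's code):

import torch

def noisy_confidence(scores, choice_temperature, annealing):
    # scores: per-token confidences in [0, 1]; higher = more confident
    noise = torch.zeros_like(scores).uniform_(0, 1)
    gumbel = -torch.log(-torch.log(noise.clamp(min=1e-10)) + 1e-10)
    # early in sampling (annealing ≈ 1) the noise dominates; by the end it vanishes
    return scores + choice_temperature * annealing * gumbel

scores = torch.tensor([0.9, 0.2, 0.6])
print(noisy_confidence(scores, choice_temperature=1.0, annealing=0.5))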
@@ -449,6 +430,7 @@ class AR_NAR(Base):
max_levels = sampling_kwargs.get("max_levels", 0)
cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
+cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
if max_levels == 0:
    max_levels = self.n_max_levels - 1
@@ -541,9 +523,8 @@ class AR_NAR(Base):
quant_levels=quant_levels,
#layer_skip_variables=sampling_layer_skip_variables,
)
-for resp, logit, null_logit in zip(resps_list, output.logits, null_output.logits):
-    seq_len = resp.shape[0]
-    logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
+logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ resp.shape[0] for resp in resps_list ] )
sampled = super().sample(
    logits=logits,
@@ -591,6 +572,7 @@ class AR_NAR(Base):
temperature = sampling_kwargs.get("temperature", 1.0)
cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
+cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
min_temperature = sampling_kwargs.get("min_temperature", -1.0)
max_duration = sampling_kwargs.get("max_duration", 500)
beam_width = sampling_kwargs.get("beam_width", 0)
@@ -736,9 +718,7 @@ class AR_NAR(Base):
quant_levels=quant_levels,
#layer_skip_variables=sampling_layer_skip_variables,
)
-for resp, logit, null_logit in zip(resps_list, output.logits, null_output.logits):
-    seq_len = resp.shape[0] + 1
-    logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
+# the +1 presumably includes the position that predicts the stop token
+logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ resp.shape[0] + 1 for resp in resps_list ] )
logits, state = output.logits, output.state

vall_e/samplers.py

@@ -171,7 +171,24 @@ def top_no_logits_processing( logits, n = 1.0 ):
return logits
+# performs classifier-free guidance given positive (conditioned) logits and negative/null (unconditioned) logits
+# this has to operate on slices because it runs before sampling, while the logits still span the entire sequence
+# (and because the null logits come from a shorter input sequence than the positive logits)
+def cfg_logits( logits, null, strength, lens, rescale=0.0 ):
+    for i, seq_len in enumerate( lens ):
+        # guide only the last seq_len positions (the span both passes share)
+        pos = logits[i][-seq_len:]
+        neg = null[i][-seq_len:]
+        # classifier-free guidance: lerp from the null logits toward (and past) the positive logits
+        summed = neg + (pos - neg) * strength
+        if rescale <= 0:
+            logits[i][-seq_len:] = summed
+        else:
+            # pull the guided logits' deviation back toward that of the positive logits
+            dims = tuple(range(1, summed.ndim - 1))
+            factor = rescale * (pos.std(dims, keepdim=True) / summed.std(dims, keepdim=True)) + (1 - rescale)
+            logits[i][-seq_len:] = summed * factor
+    return logits
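A usage sketch for cfg_logits with synthetic tensors (the shapes are mine; the conditioned pass carries extra prompt positions up front, the null pass does not):

import torch

cond = [ torch.randn(10, 32) ]  # prompt + response positions, (len, vocab)
null = [ torch.randn(6, 32) ]   # null condition drops the prompt, so it is shorter
out = cfg_logits( logits=cond, null=null, strength=3.0, lens=[ 6 ], rescale=0.0 )
assert out is cond  # the last 6 positions of cond[0] were guided in place

Two hedged observations: the rescale branch appears to follow the guidance-rescale trick from Lin et al., "Common Diffusion Noise Schedules and Sample Steps are Flawed" (whose recommended 0.7 matches the cfg_rescale default here), pulling the guided logits' standard deviation back toward that of the positive logits to counter the variance blow-up a large strength causes. Also, for the 2-D (len, vocab) slices used here, tuple(range(1, summed.ndim - 1)) is the empty tuple, which PyTorch has historically treated as a full reduction; whether that is intended isn't clear from the diff.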
# Credit to: https://github.com/basusourya/mirostat/
# performs mirostat-based sampling