cleaned up classifier-free guidance logit processing (in order to try and cope with a bad nar-len model)
parent 5ba80686e1
commit 0e621354e7
@@ -27,6 +27,7 @@ from ..emb.qnt import trim, encode_as_embedding, get_silence
 from ..utils import get_devices, setup_logging, timer, clamp, convert_kwargs

 from .lora import enable_lora
+from ..samplers import cfg_logits

 text_task = [ "stt" ]
@@ -223,8 +224,8 @@ class AR_NAR(Base):
 		if cfg.lora is not None:
 			enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )

 		"""
 		# to-do: check if gumbel sampling works / helps
 		"""
 		def log(x, eps = 1e-20):
 			return torch.log(x.clamp(min = eps))
@@ -232,6 +233,18 @@ class AR_NAR(Base):
 			return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
 		"""
+
+		def log(t, eps=1e-10):
+			return torch.log(t + eps)
+
+		def gumbel_noise(t):
+			noise = torch.zeros_like(t).uniform_(0, 1)
+			return -log(-log(noise))
+
+		def gumbel_sample(t, temperature=1.0, dim=-1):
+			return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
+
 		# convert (N)AR specific args
 		sampling_kwargs = convert_kwargs( sampling_kwargs, "ar_" )
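For context, the helpers added above are the standard Gumbel-max trick: perturbing logits with Gumbel(0, 1) noise and taking the argmax is equivalent to drawing a sample from the softmax of those logits. A minimal, self-contained sketch of the same idea (names mirror the diff; shapes are illustrative):

    import torch

    def gumbel_sample(logits: torch.Tensor, temperature: float = 1.0, dim: int = -1) -> torch.Tensor:
        # U ~ Uniform(0, 1); -log(-log(U)) is Gumbel(0, 1) noise
        noise = torch.rand_like(logits).clamp(min=1e-10, max=1 - 1e-10)
        gumbel = -torch.log(-torch.log(noise))
        # argmax over (logits / T) + Gumbel noise ~ Categorical(softmax(logits / T))
        return ((logits / max(temperature, 1e-10)) + gumbel).argmax(dim=dim)

    ids = gumbel_sample(torch.randn(4, 1025), temperature=0.8)  # -> shape [4]

As temperature approaches zero this degenerates toward plain argmax; the commented note further down ("This actually lobotomizes things") suggests it was tried for re-sampling masked tokens and rejected.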
@@ -243,6 +256,7 @@ class AR_NAR(Base):
 		temperature = sampling_kwargs.pop("temperature", 1.0)
 		cfg_strength = sampling_kwargs.get("cfg_strength", 3.0) # this really helps keep audio coherent so far
+		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
 		start_noise = sampling_kwargs.get("denoise_start", 0.0)
 		end_noise = sampling_kwargs.get("denoise_end", 1.0)
 		max_steps = math.floor(max_steps * (end_noise - start_noise))
@@ -286,6 +300,7 @@ class AR_NAR(Base):
 			annealing = 1.0 - timestep
 			# get noise level, per cosine scheduling
 			noise_p = math.cos( timestep * math.pi * 0.5 )
+			#noise_p = annealing
 			# pick the worst scoring tokens to mask off
 			masked_indices = [ score.topk( max(int( noise_p * seq_len ), 1), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
 			# mask off inputs
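For intuition: with timestep going 0 → 1 over the decoding steps, noise_p = cos(t·π/2) starts at 1 (everything masked) and decays to 0, so each step re-masks a shrinking fraction of the sequence, always choosing the worst-scoring positions (the scores are conjugated elsewhere so topk returns the least confident tokens). A toy sketch of the schedule, assuming a per-token score tensor:

    import math
    import torch

    def positions_to_remask(score: torch.Tensor, timestep: float) -> torch.Tensor:
        # cosine schedule: fraction of the sequence still masked at this timestep
        noise_p = math.cos(timestep * math.pi * 0.5)
        # re-mask at least one token; score is "badness", so topk picks the worst
        k = max(int(noise_p * score.shape[0]), 1)
        return score.topk(k, dim=-1).indices

    score = torch.rand(100)                          # hypothetical conjugated scores
    print(positions_to_remask(score, 0.25).numel())  # 92: cos(pi/8) ~= 0.92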
@@ -335,8 +350,8 @@ class AR_NAR(Base):
 					quant_levels=quant_levels,
 					#layer_skip_variables=sampling_layer_skip_variables,
 				)
-				for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
-					logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * sampling_cfg
+
+				logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ l for l in len_list ] )

 			# sample with sampler settings
 			filtered_sampled = super().sample(
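This is the core of the cleanup: the inline per-sequence CFG loop (which referenced a sampling_cfg variable rather than cfg_strength) is replaced by the shared cfg_logits helper, which applies the usual classifier-free-guidance interpolation, logit = null + (cond − null) · strength, over only the trailing seq_len positions, since the null pass sees a shorter input sequence. Note that [ l for l in len_list ] is just a shallow copy; passing len_list directly would be equivalent. Schematically, the two forms compute the same thing when rescale ≤ 0:

    # before: hand-rolled, mutates output.logits in place, no rescale
    for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
        logit[-seq_len:] = null_logit[-seq_len:] + (logit[-seq_len:] - null_logit[-seq_len:]) * cfg_strength

    # after: one call, with optional variance rescaling of the guided logits
    logits = cfg_logits(logits=output.logits, null=null_output.logits,
                        strength=cfg_strength, rescale=cfg_rescale, lens=len_list)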
@@ -361,60 +376,26 @@ class AR_NAR(Base):
 			"""
 			# update previous list of tokens
 			prev_list = resps_list

 			# sample with gumbelnoise
 			# This actually lobotomizes things
 			#sampled_ids = [ gumbel_sample( logits, temperature=temperature * annealing, dim=-1 ) for logits in filtered_sampled.logits[0] ]
 			# get sampled tokens
 			sampled_ids = filtered_sampled.ids
 			# keep unmasked tokens
 			resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
 			# get probability scores (conjugate to have worse scoring tokens picked for topk)
 			scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]

 			"""
 			# maskgct does some funny stuff but it doesn't amount to anything
 			if annealing < 1.0e-3:
 				sampled_ids = filtered_sampled.ids
 			else:
 				sampled_ids = [ gumbel_sample( logits, temperature=temperature * annealing, dim=-1 ) for logits in filtered_sampled.logits ]

 			# keep unmasked tokens
 			resps_list = [ torch.where( masked, input_ids, resps ) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
 			# update scores (conjugated to put the worst scores at the top)
 			scores = [ 1.0 - torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]

 			# refinement step
 			if refine_on_stop:
 				inputs = super().inputs(
 					text_list=text_list,
 					proms_list=proms_list,
 					resps_list=resps_list,
 					lang_list=lang_list,
 					tone_list=tone_list,
 					quant_levels=quant_levels,
 				)
 				output = super().forward(
 					inputs=inputs,
 					quant_levels=quant_levels,
 					#layer_skip_variables=sampling_layer_skip_variables,
 				)

 				logits = output.logits

 				if cfg_strength > 0:
 					null_inputs = super().inputs(
 						text_list=null_text,
 						proms_list=null_prom,
 						resps_list=resps_list,
 						lang_list=lang_list,
 						tone_list=tone_list,
 						quant_levels=quant_levels,
 					)
 					null_output = super().forward(
 						inputs=null_inputs,
 						quant_levels=quant_levels,
 						#layer_skip_variables=sampling_layer_skip_variables,
 					)
 					for seq_len, logit, null_logit in zip(len_list, output.logits, null_output.logits):
 						logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength

 				logits = [ logit[-length-1:-1] for logit, length in zip(logits, len_list) ]
 				# greedy sample from the sequence
 				refined_list = [ logit.argmax(dim=-1) for logit in logits ]

 			"""
 			if cfg.experimental and max_steps > 0:
 				print( timestep, steps_until_x0, noise_p, resps_list, scores )
 			"""
 			scores = [ torch.tensor([score for score in scores], device=device) for scores in filtered_sampled.scores ]
 			scores = [ 1.0 - (choice_temperature * annealing * gumbel_noise( score ) + score) for score in scores ]
 			"""

 		return resps_list
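The live part of this loop (outside the commented-out maskgct and refinement experiments) is the MaskGIT-style accept/carry step: freshly sampled ids are committed only where the position was masked this iteration, and confidences are conjugated so the next topk re-masks the weakest tokens. Condensed from the diff:

    # accept new samples only at positions masked this step; keep the rest
    resps_list = [
        torch.where(masked, sampled, prev)
        for masked, sampled, prev in zip(is_masked, sampled_ids, resps_list)
    ]
    # conjugate: high confidence -> low score, so topk() surfaces the worst tokens
    scores = [1.0 - torch.tensor([s for s in seq_scores], device=device)
              for seq_scores in filtered_sampled.scores]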
@@ -449,6 +430,7 @@ class AR_NAR(Base):
 		max_levels = sampling_kwargs.get("max_levels", 0)
 		cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
+		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)

 		if max_levels == 0:
 			max_levels = self.n_max_levels - 1
@@ -541,9 +523,8 @@ class AR_NAR(Base):
 					quant_levels=quant_levels,
 					#layer_skip_variables=sampling_layer_skip_variables,
 				)
-				for resp, logit, null_logit in zip(resps_list, output.logits, null_output.logits):
-					seq_len = resp.shape[0]
-					logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
+
+				logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ resp.shape[0] for resp in resps_list ] )

 			sampled = super().sample(
 				logits=logits,
@@ -591,6 +572,7 @@ class AR_NAR(Base):
 		temperature = sampling_kwargs.get("temperature", 1.0)
 		cfg_strength = sampling_kwargs.get("cfg_strength", 0.0)
+		cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.7)
 		min_temperature = sampling_kwargs.get("min_temperature", -1.0)
 		max_duration = sampling_kwargs.get("max_duration", 500)
 		beam_width = sampling_kwargs.get("beam_width", 0)
@@ -736,9 +718,7 @@ class AR_NAR(Base):
 					quant_levels=quant_levels,
 					#layer_skip_variables=sampling_layer_skip_variables,
 				)
-				for resp, logit, null_logit in zip(resps_list, output.logits, null_output.logits):
-					seq_len = resp.shape[0] + 1
-					logit[-seq_len:] = null_logit[-seq_len:] + ( logit[-seq_len:] - null_logit[-seq_len:] ) * cfg_strength
+				logits = cfg_logits( logits=output.logits, null=null_output.logits, strength=cfg_strength, rescale=cfg_rescale, lens=[ resp.shape[0] + 1 for resp in resps_list ] )

 			logits, state = output.logits, output.state
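One AR-specific detail survives the refactor: the guided slice is resp.shape[0] + 1 long, because in the autoregressive path the logit row after the last committed token is the one that scores the stop token, and it should be guided too. Roughly:

    # resp holds T committed tokens; row T of the logits decides <stop>,
    # so CFG must cover the last T + 1 rows
    lens = [resp.shape[0] + 1 for resp in resps_list]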
@@ -171,7 +171,24 @@ def top_no_logits_processing( logits, n = 1.0 ):

 	return logits

+# perform classifier-free guidance given positive logits and negative/null logits
+# some funny nonsense with needing to operate on slices since this is performed before sampling, where the logits are the entire sequence
+# (and because the null logits have a shorter input sequence compared to the positive logits)
+def cfg_logits( logits, null, strength, lens, rescale=0.0 ):
+	for i, seq_len in enumerate( lens ):
+		pos = logits[i][-seq_len:]
+		neg = null[i][-seq_len:]
+
+		summed = neg + (pos - neg) * strength
+
+		if rescale <= 0:
+			logits[i][-seq_len:] = summed
+		else:
+			dims = tuple(range(1, summed.ndim - 1))
+			factor = rescale * (pos.std(dims, keepdim=True) / summed.std(dims, keepdim=True)) + (1 - rescale)
+			logits[i][-seq_len:] = summed * factor
+
+	return logits
+
 # Credit to: https://github.com/basusourya/mirostat/
 # performs mirostat-based sampling
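For reference on the rescale branch added above: plain CFG inflates the spread of the guided logits as strength grows, so the helper blends them back toward the conditional distribution's standard deviation with factor = φ·σ(pos)/σ(summed) + (1 − φ), where φ is cfg_rescale (0.7 by default in the callers above); φ = 0 recovers plain CFG. This mirrors the CFG-rescale trick used in diffusion models. A standalone numeric sketch, using a global std for simplicity where the helper reduces over inner dims per slice:

    import torch

    pos = torch.randn(80, 1025)     # conditional logits for one sequence
    neg = torch.randn(80, 1025)     # null/unconditional logits
    strength, phi = 3.0, 0.7

    summed = neg + (pos - neg) * strength    # plain CFG: the spread grows with strength
    factor = phi * (pos.std() / summed.std()) + (1 - phi)
    rescaled = summed * factor               # pulled back toward std(pos)
    print(summed.std().item(), rescaled.std().item())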