I suppose I just have a shit training method since the sampler is as solid as I can get it...............

This commit is contained in:
mrq 2024-11-08 22:05:41 -06:00
parent 13b54953bd
commit 811b15d280
3 changed files with 107 additions and 128 deletions

View File

@ -1711,7 +1711,7 @@ class Base(nn.Module):
"""
# perform repetition penalizing
if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
if prev_list is not None and repetition_penalty != 1.0:
# to-do: figure out a faster way to handle tolist()
logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]

View File

@ -21,7 +21,6 @@ from tqdm import trange
from .base import Base, list_to_tensor, Categorical, _dropout_mask
from ..config import cfg
from ..emb.qnt import trim, repeat_extend_audio
from ..samplers import SampleScheduler
def clamp(n, lo, hi):
return max(lo, min(n, hi))
@ -237,77 +236,123 @@ class NAR(Base):
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
def log(x, eps = 1e-20):
return torch.log(x.clamp(min = eps))
def gumbel_sample(x, temperature = 1., dim = -1):
return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
test_artifact = None
"""
if False:
path = "./data/237_134500_000036_000004.enc"
test_artifact = np.load(path, allow_pickle=True)[()]
text_list = [ torch.tensor( cfg.tokenizer.encode( test_artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device) ]
resps_list = [ torch.from_numpy(test_artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device) ]
proms_list = [ resps for resps in resps_list ]
len_list = [ resps.shape[0] for resps in resps_list ]
"""
_super = super()
def forward_lambda( ids, step, temperature ):
def demask_sampling( seq_len, max_steps=10, temperature=0.3 ):
starting_temperature = temperature
input_ids = torch.ones((seq_len,), dtype=torch.long, device=device) * self.stop_token
scores = torch.zeros((seq_len,), dtype=torch.float32, device=device)
quant_levels = [ level for _ in range(batch_size) ]
prev_list = [ ids ]
seq_len = ids.shape[-1]
prev_list = [ input_ids ]
sampling_top_k = math.floor( seq_len * 0.9 )
noise_scale = 1.0
inputs = _super.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=prev_list,
lang_list=lang_list,
tone_list=tone_list,
quant_levels=quant_levels,
)
"""
if test_artifact is not None:
nonlocal resps_list
input = resps_list[0][:, 0]
noise_scale = 1.0
input_ids = torch.tensor( [ self.stop_token if random.random() < noise_scale else token for _, token in enumerate( input ) ], dtype=torch.int16, device=device )
print( input )
print( input_ids )
"""
output = _super.forward(
inputs=inputs,
quant_levels=quant_levels,
for timestep, steps_until_x0 in zip(torch.linspace(0, 1, max_steps), reversed(range(max_steps))):
# anneal temperature
temperature = starting_temperature * (steps_until_x0 / max_steps)
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 ) * noise_scale
# number of tokens to mask off to "noise" the input sequence
masked_tokens_n = max(int( noise_p * seq_len ), 1)
# pick the worst scoring tokens to mask off
masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices
# mask off inputs
input_ids = input_ids.scatter(0, masked_indices, self.stop_token)
# boolean mask
is_masked = input_ids == self.stop_token
# sample
sampling_top_k = math.floor( seq_len * 0.9 )
resps_list = [ input_ids ]
inputs = _super.inputs(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
quant_levels=quant_levels,
)
output = _super.forward(
inputs=inputs,
quant_levels=quant_levels,
layer_skip_variables=sampling_layer_skip_variables,
)
layer_skip_variables=sampling_layer_skip_variables,
)
logits = output.logits
# sample with sampler settings
filtered_sampled = _super.sample(
logits=output.logits,
prev_list=prev_list,
quant_levels=quant_levels,
# sample with sampler settings
sampled = _super.sample(
logits=logits,
prev_list=prev_list,
quant_levels=quant_levels,
temperature=temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
min_p=sampling_min_p,
repetition_penalty=sampling_repetition_penalty,
repetition_penalty_decay=sampling_repetition_penalty_decay,
length_penalty=sampling_length_penalty,
)
temperature=temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
min_p=sampling_min_p,
repetition_penalty=sampling_repetition_penalty,
repetition_penalty_decay=sampling_repetition_penalty_decay,
length_penalty=sampling_length_penalty,
#beam_width=sampling_beam_width,
#mirostat=mirostat,
)
# retrieves unfiltered logits
unfiltered_sampled = _super.sample(
logits=output.logits,
prev_list=prev_list,
quant_levels=quant_levels,
temperature=0.0,
)
# update previous list of tokens
prev_list = [ input_ids ]
# greedy sample
greedy_sampled = _super.sample(
logits=logits,
prev_list=prev_list,
quant_levels=quant_levels,
# extract logits
filtered_logits = filtered_sampled.logits[0]
unfiltered_logits = unfiltered_sampled.logits[0]
temperature=0.0,
#min_temperature=sampling_min_temperature,
#top_p=sampling_top_p,
#top_k=sampling_top_k,
#min_p=sampling_min_p,
#repetition_penalty=sampling_repetition_penalty,
#repetition_penalty_decay=sampling_repetition_penalty_decay,
#length_penalty=sampling_length_penalty,
#beam_width=sampling_beam_width,
#mirostat=mirostat,
)
# extract scores
filtered_scores = filtered_sampled.scores[0]
unfiltered_scores = unfiltered_sampled.scores[0]
return sampled, greedy_sampled
# sample with gumbelnoise
sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
# keep unmasked tokens
input_ids = torch.where( is_masked, sampled_ids, input_ids )
# update scores (conjugated to put the worst scores at the top)
scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
scheduler = SampleScheduler(
device=device,
mask_token=self.stop_token,
max_steps=5,
forward_lambda=forward_lambda,
sampling_temperature=sampling_temperature,
)
prev_list = [ scheduler.sample( seq_len=len_list[0] ) ]
# print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
return input_ids
# perform demasked sampling (mock diffusion)
prev_list = [ demask_sampling( seq_len=l ) for l in len_list ]
# expand if given a raw 1D tensor
for i, resp in enumerate(prev_list):

View File

@ -520,70 +520,4 @@ def sample_entropix(
metrics["min_p"] = min_p
"""
return res, metrics
"""
def add_gumbel_noise(t, temperature, device):
return (t + torch.Tensor(temperature * np.random.gumbel(size=t.shape)).to(device))
"""
def log(t, eps = 1e-20):
return torch.log(t.clamp(min = eps))
def gumbel_noise(t):
noise = torch.zeros_like(t).uniform_(0, 1)
return -log(-log(noise))
def gumbel_sample(t, temperature = 1., dim = -1):
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
# this provides mostly poor output, but it might just be a matter of how I'm naively training the model for """diffusion"""
class SampleScheduler:
def __init__(
self,
forward_lambda = None,
mask_token = -1,
max_steps = 25,
device = "cuda",
sampling_temperature=1.0,
):
self.forward_lambda = forward_lambda
self.max_steps = max_steps
self.mask_token = mask_token
self.device = device
def sample( self, seq_len ):
starting_temperature = 0.2
input_ids = torch.ones((seq_len,), dtype=torch.long, device=self.device) * self.mask_token
scores = torch.zeros((seq_len,), dtype=torch.float32, device=self.device)
for timestep, steps_until_x0 in zip(torch.linspace(0, 1, self.max_steps), reversed(range(self.max_steps))):
# anneal temperature
temperature = starting_temperature * (steps_until_x0 / self.max_steps)
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 )
# number of tokens to mask off to "noise" the input sequence
masked_tokens_n = max(int( noise_p * seq_len ), 1)
# pick the worst scoring tokens to mask off
masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices
# mask off inputs
input_ids = input_ids.scatter(0, masked_indices, self.mask_token)
# boolean mask
is_masked = input_ids == self.mask_token
# sample
sampled, greedy_sampled = self.forward_lambda( input_ids, step=timestep, temperature=temperature )
# extract logits
logits = greedy_sampled.logits[0]
filtered_logits = sampled.logits[0]
# sample with gumbelnoise
sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
# keep unmasked tokens
input_ids = torch.where( is_masked, sampled_ids, input_ids )
# update scores (conjugated to put the worst scores at the top)
scores = 1.0 - torch.concat([ F.softmax(logits[i, :], dim=0)[token, None] for i, token in enumerate(input_ids) ])
# print( timestep, steps_until_x0, noise_p, masked_tokens_n, temperature, input_ids, scores )
return input_ids
return res, metrics