I suppose I just have a shit training method since the sampler is as solid as I can get it...............
This commit is contained in:
parent
13b54953bd
commit
811b15d280
|
@ -1711,7 +1711,7 @@ class Base(nn.Module):
|
|||
"""
|
||||
|
||||
# perform repetition penalizing
|
||||
if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
|
||||
if prev_list is not None and repetition_penalty != 1.0:
|
||||
# to-do: figure out a faster way to handle tolist()
|
||||
logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
|
||||
|
||||
|
|
|
@ -21,7 +21,6 @@ from tqdm import trange
|
|||
from .base import Base, list_to_tensor, Categorical, _dropout_mask
|
||||
from ..config import cfg
|
||||
from ..emb.qnt import trim, repeat_extend_audio
|
||||
from ..samplers import SampleScheduler
|
||||
|
||||
def clamp(n, lo, hi):
|
||||
return max(lo, min(n, hi))
|
||||
|
@ -237,77 +236,123 @@ class NAR(Base):
|
|||
if cfg.lora is not None:
|
||||
enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
|
||||
|
||||
def log(x, eps = 1e-20):
|
||||
return torch.log(x.clamp(min = eps))
|
||||
|
||||
def gumbel_sample(x, temperature = 1., dim = -1):
|
||||
return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
|
||||
|
||||
test_artifact = None
|
||||
|
||||
"""
|
||||
if False:
|
||||
path = "./data/237_134500_000036_000004.enc"
|
||||
test_artifact = np.load(path, allow_pickle=True)[()]
|
||||
text_list = [ torch.tensor( cfg.tokenizer.encode( test_artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device) ]
|
||||
resps_list = [ torch.from_numpy(test_artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device) ]
|
||||
proms_list = [ resps for resps in resps_list ]
|
||||
len_list = [ resps.shape[0] for resps in resps_list ]
|
||||
"""
|
||||
|
||||
_super = super()
|
||||
def forward_lambda( ids, step, temperature ):
|
||||
def demask_sampling( seq_len, max_steps=10, temperature=0.3 ):
|
||||
starting_temperature = temperature
|
||||
|
||||
input_ids = torch.ones((seq_len,), dtype=torch.long, device=device) * self.stop_token
|
||||
scores = torch.zeros((seq_len,), dtype=torch.float32, device=device)
|
||||
|
||||
quant_levels = [ level for _ in range(batch_size) ]
|
||||
prev_list = [ ids ]
|
||||
seq_len = ids.shape[-1]
|
||||
prev_list = [ input_ids ]
|
||||
|
||||
sampling_top_k = math.floor( seq_len * 0.9 )
|
||||
noise_scale = 1.0
|
||||
|
||||
inputs = _super.inputs(
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=prev_list,
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
"""
|
||||
if test_artifact is not None:
|
||||
nonlocal resps_list
|
||||
input = resps_list[0][:, 0]
|
||||
noise_scale = 1.0
|
||||
input_ids = torch.tensor( [ self.stop_token if random.random() < noise_scale else token for _, token in enumerate( input ) ], dtype=torch.int16, device=device )
|
||||
print( input )
|
||||
print( input_ids )
|
||||
"""
|
||||
|
||||
output = _super.forward(
|
||||
inputs=inputs,
|
||||
quant_levels=quant_levels,
|
||||
for timestep, steps_until_x0 in zip(torch.linspace(0, 1, max_steps), reversed(range(max_steps))):
|
||||
# anneal temperature
|
||||
temperature = starting_temperature * (steps_until_x0 / max_steps)
|
||||
# get noise level, per cosine scheduling
|
||||
noise_p = math.cos( timestep * math.pi * 0.5 ) * noise_scale
|
||||
# number of tokens to mask off to "noise" the input sequence
|
||||
masked_tokens_n = max(int( noise_p * seq_len ), 1)
|
||||
# pick the worst scoring tokens to mask off
|
||||
masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices
|
||||
# mask off inputs
|
||||
input_ids = input_ids.scatter(0, masked_indices, self.stop_token)
|
||||
# boolean mask
|
||||
is_masked = input_ids == self.stop_token
|
||||
# sample
|
||||
sampling_top_k = math.floor( seq_len * 0.9 )
|
||||
resps_list = [ input_ids ]
|
||||
inputs = _super.inputs(
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
resps_list=resps_list,
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
output = _super.forward(
|
||||
inputs=inputs,
|
||||
quant_levels=quant_levels,
|
||||
layer_skip_variables=sampling_layer_skip_variables,
|
||||
)
|
||||
|
||||
layer_skip_variables=sampling_layer_skip_variables,
|
||||
)
|
||||
logits = output.logits
|
||||
# sample with sampler settings
|
||||
filtered_sampled = _super.sample(
|
||||
logits=output.logits,
|
||||
prev_list=prev_list,
|
||||
quant_levels=quant_levels,
|
||||
|
||||
# sample with sampler settings
|
||||
sampled = _super.sample(
|
||||
logits=logits,
|
||||
prev_list=prev_list,
|
||||
quant_levels=quant_levels,
|
||||
temperature=temperature,
|
||||
min_temperature=sampling_min_temperature,
|
||||
top_p=sampling_top_p,
|
||||
top_k=sampling_top_k,
|
||||
min_p=sampling_min_p,
|
||||
repetition_penalty=sampling_repetition_penalty,
|
||||
repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||
length_penalty=sampling_length_penalty,
|
||||
)
|
||||
|
||||
temperature=temperature,
|
||||
min_temperature=sampling_min_temperature,
|
||||
top_p=sampling_top_p,
|
||||
top_k=sampling_top_k,
|
||||
min_p=sampling_min_p,
|
||||
repetition_penalty=sampling_repetition_penalty,
|
||||
repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||
length_penalty=sampling_length_penalty,
|
||||
#beam_width=sampling_beam_width,
|
||||
#mirostat=mirostat,
|
||||
)
|
||||
# retrieves unfiltered logits
|
||||
unfiltered_sampled = _super.sample(
|
||||
logits=output.logits,
|
||||
prev_list=prev_list,
|
||||
quant_levels=quant_levels,
|
||||
temperature=0.0,
|
||||
)
|
||||
# update previous list of tokens
|
||||
prev_list = [ input_ids ]
|
||||
|
||||
# greedy sample
|
||||
greedy_sampled = _super.sample(
|
||||
logits=logits,
|
||||
prev_list=prev_list,
|
||||
quant_levels=quant_levels,
|
||||
# extract logits
|
||||
filtered_logits = filtered_sampled.logits[0]
|
||||
unfiltered_logits = unfiltered_sampled.logits[0]
|
||||
|
||||
temperature=0.0,
|
||||
#min_temperature=sampling_min_temperature,
|
||||
#top_p=sampling_top_p,
|
||||
#top_k=sampling_top_k,
|
||||
#min_p=sampling_min_p,
|
||||
#repetition_penalty=sampling_repetition_penalty,
|
||||
#repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||
#length_penalty=sampling_length_penalty,
|
||||
#beam_width=sampling_beam_width,
|
||||
#mirostat=mirostat,
|
||||
)
|
||||
# extract scores
|
||||
filtered_scores = filtered_sampled.scores[0]
|
||||
unfiltered_scores = unfiltered_sampled.scores[0]
|
||||
|
||||
return sampled, greedy_sampled
|
||||
# sample with gumbelnoise
|
||||
sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
|
||||
# keep unmasked tokens
|
||||
input_ids = torch.where( is_masked, sampled_ids, input_ids )
|
||||
# update scores (conjugated to put the worst scores at the top)
|
||||
scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
|
||||
|
||||
scheduler = SampleScheduler(
|
||||
device=device,
|
||||
mask_token=self.stop_token,
|
||||
max_steps=5,
|
||||
forward_lambda=forward_lambda,
|
||||
sampling_temperature=sampling_temperature,
|
||||
)
|
||||
prev_list = [ scheduler.sample( seq_len=len_list[0] ) ]
|
||||
# print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
|
||||
|
||||
return input_ids
|
||||
|
||||
# perform demasked sampling (mock diffusion)
|
||||
prev_list = [ demask_sampling( seq_len=l ) for l in len_list ]
|
||||
|
||||
# expand if given a raw 1D tensor
|
||||
for i, resp in enumerate(prev_list):
|
||||
|
|
|
@ -520,70 +520,4 @@ def sample_entropix(
|
|||
metrics["min_p"] = min_p
|
||||
"""
|
||||
|
||||
return res, metrics
|
||||
|
||||
"""
|
||||
def add_gumbel_noise(t, temperature, device):
|
||||
return (t + torch.Tensor(temperature * np.random.gumbel(size=t.shape)).to(device))
|
||||
"""
|
||||
|
||||
def log(t, eps = 1e-20):
|
||||
return torch.log(t.clamp(min = eps))
|
||||
|
||||
def gumbel_noise(t):
|
||||
noise = torch.zeros_like(t).uniform_(0, 1)
|
||||
return -log(-log(noise))
|
||||
|
||||
def gumbel_sample(t, temperature = 1., dim = -1):
|
||||
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
|
||||
|
||||
# this provides mostly poor output, but it might just be a matter of how I'm naively training the model for """diffusion"""
|
||||
class SampleScheduler:
|
||||
def __init__(
|
||||
self,
|
||||
forward_lambda = None,
|
||||
mask_token = -1,
|
||||
max_steps = 25,
|
||||
device = "cuda",
|
||||
sampling_temperature=1.0,
|
||||
):
|
||||
self.forward_lambda = forward_lambda
|
||||
self.max_steps = max_steps
|
||||
self.mask_token = mask_token
|
||||
self.device = device
|
||||
|
||||
def sample( self, seq_len ):
|
||||
starting_temperature = 0.2
|
||||
|
||||
input_ids = torch.ones((seq_len,), dtype=torch.long, device=self.device) * self.mask_token
|
||||
scores = torch.zeros((seq_len,), dtype=torch.float32, device=self.device)
|
||||
|
||||
for timestep, steps_until_x0 in zip(torch.linspace(0, 1, self.max_steps), reversed(range(self.max_steps))):
|
||||
# anneal temperature
|
||||
temperature = starting_temperature * (steps_until_x0 / self.max_steps)
|
||||
# get noise level, per cosine scheduling
|
||||
noise_p = math.cos( timestep * math.pi * 0.5 )
|
||||
# number of tokens to mask off to "noise" the input sequence
|
||||
masked_tokens_n = max(int( noise_p * seq_len ), 1)
|
||||
# pick the worst scoring tokens to mask off
|
||||
masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices
|
||||
# mask off inputs
|
||||
input_ids = input_ids.scatter(0, masked_indices, self.mask_token)
|
||||
# boolean mask
|
||||
is_masked = input_ids == self.mask_token
|
||||
# sample
|
||||
sampled, greedy_sampled = self.forward_lambda( input_ids, step=timestep, temperature=temperature )
|
||||
# extract logits
|
||||
logits = greedy_sampled.logits[0]
|
||||
filtered_logits = sampled.logits[0]
|
||||
|
||||
# sample with gumbelnoise
|
||||
sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
|
||||
# keep unmasked tokens
|
||||
input_ids = torch.where( is_masked, sampled_ids, input_ids )
|
||||
# update scores (conjugated to put the worst scores at the top)
|
||||
scores = 1.0 - torch.concat([ F.softmax(logits[i, :], dim=0)[token, None] for i, token in enumerate(input_ids) ])
|
||||
|
||||
# print( timestep, steps_until_x0, noise_p, masked_tokens_n, temperature, input_ids, scores )
|
||||
|
||||
return input_ids
|
||||
return res, metrics
|
Loading…
Reference in New Issue
Block a user