I suppose I just have a shit training method since the sampler is as solid as I can get it...............
commit 811b15d280 (parent 13b54953bd)
@@ -1711,7 +1711,7 @@ class Base(nn.Module):
 		"""
 
 		# perform repetition penalizing
-		if "len" not in self.capabilities and prev_list is not None and repetition_penalty != 1.0:
+		if prev_list is not None and repetition_penalty != 1.0:
 			# to-do: figure out a faster way to handle tolist()
 			logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
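The `reptition_penalize` helper this hunk calls is defined elsewhere in the repo; as a reference point, here is a minimal sketch of a decayed repetition penalty in the same spirit. The function name, the exact decay form, and the defaults are assumptions for illustration, not the repo's implementation:

```python
import torch

def repetition_penalize_sketch(logits, previous, factor=1.25, decay=0.5):
    # logits: (vocab,) scores for the next position
    # previous: previously emitted token ids, oldest first
    # hypothetical CTRL-style penalty whose strength decays the further back
    # a token occurred; factor == 1.0 is a no-op, matching the
    # `repetition_penalty != 1.0` guard above
    for distance, token in enumerate(reversed(previous)):
        strength = 1.0 + (factor - 1.0) * (decay ** distance)
        if logits[token] > 0:
            logits[token] = logits[token] / strength
        else:
            logits[token] = logits[token] * strength
    return logits
```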
@@ -21,7 +21,6 @@ from tqdm import trange
 from .base import Base, list_to_tensor, Categorical, _dropout_mask
 from ..config import cfg
 from ..emb.qnt import trim, repeat_extend_audio
-from ..samplers import SampleScheduler
 
 def clamp(n, lo, hi):
 	return max(lo, min(n, hi))
@@ -237,34 +236,79 @@ class NAR(Base):
 		if cfg.lora is not None:
 			enable_lora( self, cfg.lora.active_level( level ) if use_lora is None else use_lora )
 
+		def log(x, eps = 1e-20):
+			return torch.log(x.clamp(min = eps))
+
+		def gumbel_sample(x, temperature = 1., dim = -1):
+			return ((x / max(temperature, 1e-10)) + -log(-log(torch.zeros_like(x).uniform_(0, 1)))).argmax(dim = dim)
+
+		test_artifact = None
+
+		"""
+		if False:
+			path = "./data/237_134500_000036_000004.enc"
+			test_artifact = np.load(path, allow_pickle=True)[()]
+			text_list = [ torch.tensor( cfg.tokenizer.encode( test_artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device) ]
+			resps_list = [ torch.from_numpy(test_artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device) ]
+			proms_list = [ resps for resps in resps_list ]
+			len_list = [ resps.shape[0] for resps in resps_list ]
+		"""
+
 		_super = super()
-		def forward_lambda( ids, step, temperature ):
+		def demask_sampling( seq_len, max_steps=10, temperature=0.3 ):
+			starting_temperature = temperature
+
+			input_ids = torch.ones((seq_len,), dtype=torch.long, device=device) * self.stop_token
+			scores = torch.zeros((seq_len,), dtype=torch.float32, device=device)
+
 			quant_levels = [ level for _ in range(batch_size) ]
-			prev_list = [ ids ]
-			seq_len = ids.shape[-1]
+			prev_list = [ input_ids ]
+
+			noise_scale = 1.0
+
+			"""
+			if test_artifact is not None:
+				nonlocal resps_list
+				input = resps_list[0][:, 0]
+				noise_scale = 1.0
+				input_ids = torch.tensor( [ self.stop_token if random.random() < noise_scale else token for _, token in enumerate( input ) ], dtype=torch.int16, device=device )
+				print( input )
+				print( input_ids )
+			"""
+
+			for timestep, steps_until_x0 in zip(torch.linspace(0, 1, max_steps), reversed(range(max_steps))):
+				# anneal temperature
+				temperature = starting_temperature * (steps_until_x0 / max_steps)
+				# get noise level, per cosine scheduling
+				noise_p = math.cos( timestep * math.pi * 0.5 ) * noise_scale
+				# number of tokens to mask off to "noise" the input sequence
+				masked_tokens_n = max(int( noise_p * seq_len ), 1)
+				# pick the worst scoring tokens to mask off
+				masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices
+				# mask off inputs
+				input_ids = input_ids.scatter(0, masked_indices, self.stop_token)
+				# boolean mask
+				is_masked = input_ids == self.stop_token
+				# sample
 				sampling_top_k = math.floor( seq_len * 0.9 )
+
+				resps_list = [ input_ids ]
 				inputs = _super.inputs(
 					text_list=text_list,
 					proms_list=proms_list,
-					resps_list=prev_list,
+					resps_list=resps_list,
 					lang_list=lang_list,
 					tone_list=tone_list,
 					quant_levels=quant_levels,
 				)
 
 				output = _super.forward(
 					inputs=inputs,
 					quant_levels=quant_levels,
 					layer_skip_variables=sampling_layer_skip_variables,
 				)
-				logits = output.logits
 
 				# sample with sampler settings
-				sampled = _super.sample(
-					logits=logits,
+				filtered_sampled = _super.sample(
+					logits=output.logits,
 					prev_list=prev_list,
 					quant_levels=quant_levels,
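The new `demask_sampling` above is MaskGIT-style iterative demasking: start from a fully masked sequence, re-mask the least-confident positions on each step according to a cosine schedule, and resample only those. A self-contained sketch of the same loop, assuming a hypothetical `model` callable in place of the NAR forward pass and a `mask_id` playing the role of `self.stop_token`:

```python
import math
import torch

def demask_sample(model, seq_len, mask_id, max_steps=10, start_temp=0.3, device="cpu"):
    # start fully masked, with no confidence anywhere
    input_ids = torch.full((seq_len,), mask_id, dtype=torch.long, device=device)
    scores = torch.zeros(seq_len, device=device)  # higher = less confident
    for t, steps_left in zip(torch.linspace(0, 1, max_steps), reversed(range(max_steps))):
        temperature = start_temp * (steps_left / max_steps)  # anneal toward greedy
        noise_p = math.cos(t.item() * math.pi * 0.5)         # cosine: mask most early on
        n_mask = max(int(noise_p * seq_len), 1)
        worst = scores.topk(n_mask).indices                  # least-confident positions
        input_ids.scatter_(0, worst, mask_id)                # re-mask them
        is_masked = input_ids == mask_id
        logits = model(input_ids)                            # (seq_len, vocab)
        u = torch.rand_like(logits).clamp(1e-20, 1 - 1e-7)
        gumbel = -torch.log(-torch.log(u))                   # Gumbel(0, 1) noise
        sampled = (logits / max(temperature, 1e-10) + gumbel).argmax(dim=-1)
        input_ids = torch.where(is_masked, sampled, input_ids)
        probs = logits.softmax(dim=-1)
        # low probability => high score => re-masked on the next step
        scores = 1.0 - probs.gather(-1, input_ids[:, None]).squeeze(-1)
    return input_ids

# toy usage: an untrained "model" that emits random logits over a 32-token vocab
model = lambda ids: torch.randn(ids.shape[0], 32)
print(demask_sample(model, seq_len=16, mask_id=31))
```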
@@ -276,38 +320,39 @@ class NAR(Base):
 					repetition_penalty=sampling_repetition_penalty,
 					repetition_penalty_decay=sampling_repetition_penalty_decay,
 					length_penalty=sampling_length_penalty,
-					#beam_width=sampling_beam_width,
-					#mirostat=mirostat,
 				)
 
-				# greedy sample
-				greedy_sampled = _super.sample(
-					logits=logits,
+				# retrieves unfiltered logits
+				unfiltered_sampled = _super.sample(
+					logits=output.logits,
 					prev_list=prev_list,
 					quant_levels=quant_levels,
 					temperature=0.0,
-					#min_temperature=sampling_min_temperature,
-					#top_p=sampling_top_p,
-					#top_k=sampling_top_k,
-					#min_p=sampling_min_p,
-					#repetition_penalty=sampling_repetition_penalty,
-					#repetition_penalty_decay=sampling_repetition_penalty_decay,
-					#length_penalty=sampling_length_penalty,
-					#beam_width=sampling_beam_width,
-					#mirostat=mirostat,
 				)
+				# update previous list of tokens
+				prev_list = [ input_ids ]
 
-			return sampled, greedy_sampled
-
-		scheduler = SampleScheduler(
-			device=device,
-			mask_token=self.stop_token,
-			max_steps=5,
-			forward_lambda=forward_lambda,
-			sampling_temperature=sampling_temperature,
-		)
-		prev_list = [ scheduler.sample( seq_len=len_list[0] ) ]
+				# extract logits
+				filtered_logits = filtered_sampled.logits[0]
+				unfiltered_logits = unfiltered_sampled.logits[0]
+
+				# extract scores
+				filtered_scores = filtered_sampled.scores[0]
+				unfiltered_scores = unfiltered_sampled.scores[0]
+
+				# sample with gumbel noise
+				sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
+				# keep unmasked tokens
+				input_ids = torch.where( is_masked, sampled_ids, input_ids )
+				# update scores (conjugated to put the worst scores at the top)
+				scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
+
+				# print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
+
+			return input_ids
+
+		# perform demasked sampling (mock diffusion)
+		prev_list = [ demask_sampling( seq_len=l ) for l in len_list ]
 
 		# expand if given a raw 1D tensor
 		for i, resp in enumerate(prev_list):
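Note the division of labor between the two `sample()` passes above: token identities come from `filtered_sampled` (top-k, repetition-penalized), while the confidence used for remasking comes from `unfiltered_sampled`, the temperature-0 pass, so the masking decision is not skewed by the sampler's own truncation. A hedged sketch of that scoring step (names are illustrative, not the repo's API):

```python
import torch

def score_tokens(unfiltered_logits, chosen_ids):
    # unfiltered_logits: (seq_len, vocab); chosen_ids: (seq_len,)
    probs = unfiltered_logits.softmax(dim=-1)
    confidence = probs.gather(-1, chosen_ids[:, None]).squeeze(-1)
    # "conjugate" so the worst tokens float to the top of scores.topk()
    return 1.0 - confidence
```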
@@ -521,69 +521,3 @@ def sample_entropix(
 	"""
 	return res, metrics
 
-"""
-def add_gumbel_noise(t, temperature, device):
-	return (t + torch.Tensor(temperature * np.random.gumbel(size=t.shape)).to(device))
-"""
-
-def log(t, eps = 1e-20):
-	return torch.log(t.clamp(min = eps))
-
-def gumbel_noise(t):
-	noise = torch.zeros_like(t).uniform_(0, 1)
-	return -log(-log(noise))
-
-def gumbel_sample(t, temperature = 1., dim = -1):
-	return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
-
-# this provides mostly poor output, but it might just be a matter of how I'm naively training the model for """diffusion"""
-class SampleScheduler:
-	def __init__(
-		self,
-		forward_lambda = None,
-		mask_token = -1,
-		max_steps = 25,
-		device = "cuda",
-		sampling_temperature=1.0,
-	):
-		self.forward_lambda = forward_lambda
-		self.max_steps = max_steps
-		self.mask_token = mask_token
-		self.device = device
-
-	def sample( self, seq_len ):
-		starting_temperature = 0.2
-
-		input_ids = torch.ones((seq_len,), dtype=torch.long, device=self.device) * self.mask_token
-		scores = torch.zeros((seq_len,), dtype=torch.float32, device=self.device)
-
-		for timestep, steps_until_x0 in zip(torch.linspace(0, 1, self.max_steps), reversed(range(self.max_steps))):
-			# anneal temperature
-			temperature = starting_temperature * (steps_until_x0 / self.max_steps)
-			# get noise level, per cosine scheduling
-			noise_p = math.cos( timestep * math.pi * 0.5 )
-			# number of tokens to mask off to "noise" the input sequence
-			masked_tokens_n = max(int( noise_p * seq_len ), 1)
-			# pick the worst scoring tokens to mask off
-			masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices
-			# mask off inputs
-			input_ids = input_ids.scatter(0, masked_indices, self.mask_token)
-			# boolean mask
-			is_masked = input_ids == self.mask_token
-			# sample
-			sampled, greedy_sampled = self.forward_lambda( input_ids, step=timestep, temperature=temperature )
-			# extract logits
-			logits = greedy_sampled.logits[0]
-			filtered_logits = sampled.logits[0]
-
-			# sample with gumbel noise
-			sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
-			# keep unmasked tokens
-			input_ids = torch.where( is_masked, sampled_ids, input_ids )
-			# update scores (conjugated to put the worst scores at the top)
-			scores = 1.0 - torch.concat([ F.softmax(logits[i, :], dim=0)[token, None] for i, token in enumerate(input_ids) ])
-
-			# print( timestep, steps_until_x0, noise_p, masked_tokens_n, temperature, input_ids, scores )
-
-		return input_ids
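The removed `gumbel_sample` helper (now duplicated into the NAR module above) is the standard Gumbel-max trick: adding Gumbel(0, 1) noise to logits and taking the argmax draws from the same distribution as sampling from softmax(logits / temperature). A quick empirical self-check, not part of the repo:

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5])
n = 100_000
u = torch.rand(n, 3).clamp(1e-20, 1 - 1e-7)
g = -torch.log(-torch.log(u))                        # Gumbel(0, 1) noise
samples = (logits + g).argmax(dim=-1)                # Gumbel-max draws
empirical = torch.bincount(samples, minlength=3).float() / n
print(empirical)                                     # approx. [0.6285, 0.2312, 0.1403]
print(logits.softmax(dim=-1))
```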