mrq 2024-11-08 13:34:39 -06:00
parent c127c4e488
commit 13b54953bd
3 changed files with 56 additions and 71 deletions

View File

@@ -40,7 +40,7 @@ from ..data import get_task_symmap
# these seem more elegant than a dict
Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states', 'exited_layer'])
-Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy'])
+Sampled = namedtuple('Sampled', ['out', 'logits', 'scores', 'entropy'])
LossStats = namedtuple('LossStats', ['loss', 'stats'])
"""
@@ -1681,7 +1681,7 @@ class Base(nn.Module):
) for batch, logit in enumerate(logits) ]
if res:
-return Sampled([ r[0] for r in res ], scores, [ r[1] for r in res ])
+return Sampled([ r[0] for r in res ], logits, scores, [ r[1] for r in res ])
"""
elif quant_levels is None:
seq_lens = [ logit.shape[0] for logit in logits ]
@@ -1772,4 +1772,4 @@ class Base(nn.Module):
for logit, tokens in zip(logits, res)
]
-return Sampled(res, scores, entropy)
+return Sampled(res, logits, scores, entropy)

View File

@@ -226,7 +226,7 @@ class NAR(Base):
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
# initial condition
-len_list = [ min(l, 500) for l in len_list ]
+len_list = [ min(l, 75*3) for l in len_list ]
metrics = []
mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
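
Annotation: the length clamp drops from 500 frames to 75*3. Assuming the codec emits 75 frames per second (EnCodec's rate at 24kHz), this caps each generated utterance at roughly three seconds, presumably to keep these sampling experiments quick. The arithmetic, with the frame rate as a stated assumption:

# assumption: the codec emits 75 frames per second (EnCodec at 24kHz),
# so 75*3 caps each target length at about three seconds of audio
FRAMES_PER_SECOND = 75
len_list = [ 400, 120 ]
len_list = [ min(l, FRAMES_PER_SECOND * 3) for l in len_list ]
assert len_list == [ 225, 120 ]
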
@@ -240,9 +240,11 @@ class NAR(Base):
_super = super()
def forward_lambda( ids, step, temperature ):
quant_levels = [ level for _ in range(batch_size) ]
-prev_list = [ ids[0] ]
+prev_list = [ ids ]
+seq_len = ids.shape[-1]
+sampling_top_k = math.floor( seq_len * 0.9 )
inputs = _super.inputs(
text_list=text_list,
proms_list=proms_list,
@@ -260,6 +262,7 @@
)
logits = output.logits
+# sample with sampler settings
sampled = _super.sample(
logits=logits,
prev_list=prev_list,
@@ -277,14 +280,30 @@
#mirostat=mirostat,
)
-ids = sampled[0]
-return logits[0][-seq_len:].unsqueeze(0), ids[0].unsqueeze(0)
+# greedy sample
+greedy_sampled = _super.sample(
+logits=logits,
+prev_list=prev_list,
+quant_levels=quant_levels,
+temperature=0.0,
+#min_temperature=sampling_min_temperature,
+#top_p=sampling_top_p,
+#top_k=sampling_top_k,
+#min_p=sampling_min_p,
+#repetition_penalty=sampling_repetition_penalty,
+#repetition_penalty_decay=sampling_repetition_penalty_decay,
+#length_penalty=sampling_length_penalty,
+#beam_width=sampling_beam_width,
+#mirostat=mirostat,
+)
+return sampled, greedy_sampled
scheduler = SampleScheduler(
device=device,
mask_token=self.stop_token,
-max_steps=30,
+max_steps=5,
forward_lambda=forward_lambda,
sampling_temperature=sampling_temperature,
)
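
Annotation: forward_lambda now decodes the same logits twice, once with the caller's sampler settings and once greedily at temperature 0, and returns both Sampled tuples; the scheduler takes its tokens from the stochastic pass and its confidence scores from the greedy pass. A stripped-down sketch of that pattern, with a hypothetical decode in place of Base.sample:

import torch
import torch.nn.functional as F

# hypothetical stand-in for the dual decode inside forward_lambda: one set of
# logits, decoded stochastically for tokens and greedily for confidence
def dual_decode( logits: torch.Tensor, temperature: float = 1.0 ):
    # stochastic pick: temperature-scaled softmax, then a multinomial draw
    probs = F.softmax(logits / max(temperature, 1e-10), dim=-1)
    sampled = torch.multinomial(probs, 1).squeeze(-1)
    # greedy pick: temperature 0 degenerates to a plain argmax
    greedy = logits.argmax(dim=-1)
    return sampled, greedy

logits = torch.randn(8, 32)  # [seq_len, vocab]
sampled_ids, greedy_ids = dual_decode(logits, temperature=1.0)
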

View File

@@ -537,13 +537,6 @@ def gumbel_noise(t):
def gumbel_sample(t, temperature = 1., dim = -1):
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
-def top_k(logits, thres = 0.9):
-k = math.ceil((1 - thres) * logits.shape[-1])
-val, ind = logits.topk(k, dim = -1)
-probs = torch.full_like(logits, float('-inf'))
-probs.scatter_(2, ind, val)
-return probs
# this provides mostly poor output, but it might just be a matter of how I'm naively training the model for """diffusion"""
class SampleScheduler:
def __init__(
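
Annotation: top_k above appears to be dead code after this rewrite, since the scheduler now receives already-filtered logits back from the sampler, hence its removal. The surviving gumbel_sample is the Gumbel-max trick: adding i.i.d. Gumbel noise to temperature-scaled logits and taking the argmax is equivalent to drawing from the corresponding softmax. A self-contained sketch (the clamps guarding log(0) are an added safety, not in the original):

import torch

def gumbel_noise(t: torch.Tensor) -> torch.Tensor:
    # standard Gumbel noise: -log(-log(U)) with U ~ Uniform(0, 1)
    u = torch.zeros_like(t).uniform_(0, 1).clamp(min=1e-20)
    return -torch.log((-torch.log(u)).clamp(min=1e-20))

def gumbel_sample(t: torch.Tensor, temperature: float = 1.0, dim: int = -1) -> torch.Tensor:
    # Gumbel-max trick: argmax over noised, temperature-scaled logits draws a
    # sample from softmax(t / temperature) without materializing probabilities
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)

logits = torch.randn(4, 100)
ids = gumbel_sample(logits, temperature=0.2)  # low temperature hugs the argmax
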
@@ -558,66 +551,39 @@ class SampleScheduler:
self.max_steps = max_steps
self.mask_token = mask_token
self.device = device
"""
self.ratios = (np.cos(np.linspace(0, math.pi / 2, self.max_steps + 1)))[1:-1]
self.annealed_temperatures = (1 - np.linspace(0, 1, self.max_steps + 1))[:-2]
self.sampling_temperatures = [sampling_temperature for _ in range(self.max_steps)]
"""
# lifted from https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/muse_maskgit_pytorch.py#L493
def sample( self, seq_len ):
ids = torch.full((1, seq_len), self.mask_token, dtype = torch.long, device = self.device)
scores = torch.zeros((1, seq_len), dtype = torch.float32, device = self.device)
starting_temperature = 0.2
for step in range( self.max_steps ):
t = step / self.max_steps
mask_ratio = math.cos(t * math.pi * 0.5)
sampling_temperature = 1.0
annealed_temperature = sampling_temperature * (1.0 - t)
input_ids = torch.ones((seq_len,), dtype=torch.long, device=self.device) * self.mask_token
scores = torch.zeros((seq_len,), dtype=torch.float32, device=self.device)
num_token_masked = max(int(mask_ratio * seq_len), 1)
masked_indices = scores.topk(num_token_masked, dim = -1).indices
for timestep, steps_until_x0 in zip(torch.linspace(0, 1, self.max_steps), reversed(range(self.max_steps))):
# anneal temperature
temperature = starting_temperature * (steps_until_x0 / self.max_steps)
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 )
# number of tokens to mask off to "noise" the input sequence
masked_tokens_n = max(int( noise_p * seq_len ), 1)
# pick the worst scoring tokens to mask off
masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices
# mask off inputs
input_ids = input_ids.scatter(0, masked_indices, self.mask_token)
# boolean mask
is_masked = input_ids == self.mask_token
# sample
sampled, greedy_sampled = self.forward_lambda( input_ids, step=timestep, temperature=temperature )
# extract logits
logits = greedy_sampled.logits[0]
filtered_logits = sampled.logits[0]
ids = ids.scatter(1, masked_indices, self.mask_token)
# sample with gumbelnoise
sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
# keep unmasked tokens
input_ids = torch.where( is_masked, sampled_ids, input_ids )
# update scores (conjugated to put the worst scores at the top)
scores = 1.0 - torch.concat([ F.softmax(logits[i, :], dim=0)[token, None] for i, token in enumerate(input_ids) ])
logits, _ = self.forward_lambda( ids, step=step, temperature=annealed_temperature )
filtered_logits = top_k( logits )
sampled_ids = gumbel_sample( filtered_logits, temperature=annealed_temperature, dim=-1 )
# print( timestep, steps_until_x0, noise_p, masked_tokens_n, temperature, input_ids, scores )
is_masked = ids == self.mask_token
ids = torch.where( is_masked, sampled_ids, ids )
probs_without_temperature = logits.softmax(dim = -1)
scores = 1 - probs_without_temperature.gather(2, sampled_ids[..., None])
scores = rearrange(scores, '... 1 -> ...')
#scores = scores.to(dtype=torch.float64).masked_fill(~is_masked, -1e5)
"""
if step + 1 == self.max_steps:
break
# lifted from https://github.com/LeapLabTHU/ImprovedNAT/blob/main/libs/nat_misc.py#L39
# create next input sequence
mask = (ids == self.mask_token)
mask_len = torch.Tensor([np.floor(seq_len * mask_ratio)]).to(self.device)
mask_len = torch.maximum(
torch.Tensor([1]).to(self.device),
torch.minimum( torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len )
)[0].squeeze()
logits = torch.log_softmax(logits, dim=-1)
sampled_logits = torch.squeeze(torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
sampled_ids = torch.where(mask, sampled_ids, ids)
sampled_logits = torch.where(mask, sampled_logits, +np.inf).float()
confidence = add_gumbel_noise(sampled_logits, annealed_temperature, self.device)
sorted_confidence, _ = torch.sort(confidence, axis=-1)
cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
masking = (confidence <= cut_off)
ids = torch.where(masking, self.mask_token, sampled_ids)
"""
return sampled_ids[0]
return input_ids
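
Annotation: the rewritten sample() is MaskGIT-style iterative demasking: start fully masked, then at each step re-mask the least confident positions according to a cosine schedule, refill them by Gumbel sampling at an annealed temperature, and score every position by the probability the model assigns its own pick. A self-contained toy version of the loop; the model, vocabulary, and constants are hypothetical stand-ins:

import math
import torch
import torch.nn.functional as F

# toy stand-ins; the real model is the NAR forward pass behind forward_lambda
MASK, VOCAB, SEQ_LEN, STEPS = 0, 32, 16, 5

def toy_model(ids: torch.Tensor) -> torch.Tensor:
    # maps a [seq_len] token sequence to [seq_len, vocab] logits
    return torch.randn(ids.shape[-1], VOCAB)

input_ids = torch.full((SEQ_LEN,), MASK, dtype=torch.long)
scores = torch.zeros(SEQ_LEN)

for timestep, steps_left in zip(torch.linspace(0, 1, STEPS), reversed(range(STEPS))):
    # anneal the sampling temperature toward zero over the schedule
    temperature = 0.2 * (steps_left / STEPS)
    # cosine noise schedule: mask most positions early, few positions late
    noise_p = math.cos(timestep * math.pi * 0.5)
    n_masked = max(int(noise_p * SEQ_LEN), 1)
    # re-mask the least confident positions (highest scores)
    worst = scores.topk(n_masked, dim=-1).indices
    input_ids = input_ids.scatter(0, worst, MASK)
    is_masked = input_ids == MASK
    logits = toy_model(input_ids)
    # Gumbel-max sampling, as in the gumbel_sample sketch above
    u = torch.rand_like(logits).clamp(min=1e-20)
    gumbel = -torch.log((-torch.log(u)).clamp(min=1e-20))
    sampled_ids = (logits / max(temperature, 1e-10) + gumbel).argmax(dim=-1)
    # fill only the masked slots; confident tokens survive
    input_ids = torch.where(is_masked, sampled_ids, input_ids)
    # score = 1 - p(picked token); a high score marks a low-confidence position
    probs = F.softmax(logits, dim=-1)
    scores = 1.0 - probs.gather(-1, input_ids[:, None]).squeeze(-1)

print(input_ids.tolist())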