agony
This commit is contained in:
parent
c127c4e488
commit
13b54953bd
|
@ -40,7 +40,7 @@ from ..data import get_task_symmap
|
|||
|
||||
# these seem more elegant than a dict
|
||||
Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states', 'exited_layer'])
|
||||
Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy'])
|
||||
Sampled = namedtuple('Sampled', ['out', 'logits', 'scores', 'entropy'])
|
||||
LossStats = namedtuple('LossStats', ['loss', 'stats'])
|
||||
|
||||
"""
|
||||
|
@ -1681,7 +1681,7 @@ class Base(nn.Module):
|
|||
) for batch, logit in enumerate(logits) ]
|
||||
|
||||
if res:
|
||||
return Sampled([ r[0] for r in res ], scores, [ r[1] for r in res ])
|
||||
return Sampled([ r[0] for r in res ], logits, scores, [ r[1] for r in res ])
|
||||
"""
|
||||
elif quant_levels is None:
|
||||
seq_lens = [ logit.shape[0] for logit in logits ]
|
||||
|
@ -1772,4 +1772,4 @@ class Base(nn.Module):
|
|||
for logit, tokens in zip(logits, res)
|
||||
]
|
||||
|
||||
return Sampled(res, scores, entropy)
|
||||
return Sampled(res, logits, scores, entropy)
|
|
@ -226,7 +226,7 @@ class NAR(Base):
|
|||
sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
|
||||
|
||||
# initial condition
|
||||
len_list = [ min(l, 500) for l in len_list ]
|
||||
len_list = [ min(l, 75*3) for l in len_list ]
|
||||
metrics = []
|
||||
|
||||
mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
|
||||
|
@ -240,9 +240,11 @@ class NAR(Base):
|
|||
_super = super()
|
||||
def forward_lambda( ids, step, temperature ):
|
||||
quant_levels = [ level for _ in range(batch_size) ]
|
||||
prev_list = [ ids[0] ]
|
||||
prev_list = [ ids ]
|
||||
seq_len = ids.shape[-1]
|
||||
|
||||
sampling_top_k = math.floor( seq_len * 0.9 )
|
||||
|
||||
inputs = _super.inputs(
|
||||
text_list=text_list,
|
||||
proms_list=proms_list,
|
||||
|
@ -260,6 +262,7 @@ class NAR(Base):
|
|||
)
|
||||
logits = output.logits
|
||||
|
||||
# sample with sampler settings
|
||||
sampled = _super.sample(
|
||||
logits=logits,
|
||||
prev_list=prev_list,
|
||||
|
@ -277,14 +280,30 @@ class NAR(Base):
|
|||
#mirostat=mirostat,
|
||||
)
|
||||
|
||||
ids = sampled[0]
|
||||
# greedy sample
|
||||
greedy_sampled = _super.sample(
|
||||
logits=logits,
|
||||
prev_list=prev_list,
|
||||
quant_levels=quant_levels,
|
||||
|
||||
return logits[0][-seq_len:].unsqueeze(0), ids[0].unsqueeze(0)
|
||||
temperature=0.0,
|
||||
#min_temperature=sampling_min_temperature,
|
||||
#top_p=sampling_top_p,
|
||||
#top_k=sampling_top_k,
|
||||
#min_p=sampling_min_p,
|
||||
#repetition_penalty=sampling_repetition_penalty,
|
||||
#repetition_penalty_decay=sampling_repetition_penalty_decay,
|
||||
#length_penalty=sampling_length_penalty,
|
||||
#beam_width=sampling_beam_width,
|
||||
#mirostat=mirostat,
|
||||
)
|
||||
|
||||
return sampled, greedy_sampled
|
||||
|
||||
scheduler = SampleScheduler(
|
||||
device=device,
|
||||
mask_token=self.stop_token,
|
||||
max_steps=30,
|
||||
max_steps=5,
|
||||
forward_lambda=forward_lambda,
|
||||
sampling_temperature=sampling_temperature,
|
||||
)
|
||||
|
|
|
@ -537,13 +537,6 @@ def gumbel_noise(t):
|
|||
def gumbel_sample(t, temperature = 1., dim = -1):
|
||||
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
|
||||
|
||||
def top_k(logits, thres = 0.9):
|
||||
k = math.ceil((1 - thres) * logits.shape[-1])
|
||||
val, ind = logits.topk(k, dim = -1)
|
||||
probs = torch.full_like(logits, float('-inf'))
|
||||
probs.scatter_(2, ind, val)
|
||||
return probs
|
||||
|
||||
# this provides mostly poor output, but it might just be a matter of how I'm naively training the model for """diffusion"""
|
||||
class SampleScheduler:
|
||||
def __init__(
|
||||
|
@ -558,66 +551,39 @@ class SampleScheduler:
|
|||
self.max_steps = max_steps
|
||||
self.mask_token = mask_token
|
||||
self.device = device
|
||||
|
||||
"""
|
||||
self.ratios = (np.cos(np.linspace(0, math.pi / 2, self.max_steps + 1)))[1:-1]
|
||||
self.annealed_temperatures = (1 - np.linspace(0, 1, self.max_steps + 1))[:-2]
|
||||
self.sampling_temperatures = [sampling_temperature for _ in range(self.max_steps)]
|
||||
"""
|
||||
|
||||
# lifted from https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/muse_maskgit_pytorch.py#L493
|
||||
def sample( self, seq_len ):
|
||||
ids = torch.full((1, seq_len), self.mask_token, dtype = torch.long, device = self.device)
|
||||
scores = torch.zeros((1, seq_len), dtype = torch.float32, device = self.device)
|
||||
starting_temperature = 0.2
|
||||
|
||||
for step in range( self.max_steps ):
|
||||
t = step / self.max_steps
|
||||
mask_ratio = math.cos(t * math.pi * 0.5)
|
||||
sampling_temperature = 1.0
|
||||
annealed_temperature = sampling_temperature * (1.0 - t)
|
||||
input_ids = torch.ones((seq_len,), dtype=torch.long, device=self.device) * self.mask_token
|
||||
scores = torch.zeros((seq_len,), dtype=torch.float32, device=self.device)
|
||||
|
||||
num_token_masked = max(int(mask_ratio * seq_len), 1)
|
||||
masked_indices = scores.topk(num_token_masked, dim = -1).indices
|
||||
for timestep, steps_until_x0 in zip(torch.linspace(0, 1, self.max_steps), reversed(range(self.max_steps))):
|
||||
# anneal temperature
|
||||
temperature = starting_temperature * (steps_until_x0 / self.max_steps)
|
||||
# get noise level, per cosine scheduling
|
||||
noise_p = math.cos( timestep * math.pi * 0.5 )
|
||||
# number of tokens to mask off to "noise" the input sequence
|
||||
masked_tokens_n = max(int( noise_p * seq_len ), 1)
|
||||
# pick the worst scoring tokens to mask off
|
||||
masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices
|
||||
# mask off inputs
|
||||
input_ids = input_ids.scatter(0, masked_indices, self.mask_token)
|
||||
# boolean mask
|
||||
is_masked = input_ids == self.mask_token
|
||||
# sample
|
||||
sampled, greedy_sampled = self.forward_lambda( input_ids, step=timestep, temperature=temperature )
|
||||
# extract logits
|
||||
logits = greedy_sampled.logits[0]
|
||||
filtered_logits = sampled.logits[0]
|
||||
|
||||
ids = ids.scatter(1, masked_indices, self.mask_token)
|
||||
# sample with gumbelnoise
|
||||
sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
|
||||
# keep unmasked tokens
|
||||
input_ids = torch.where( is_masked, sampled_ids, input_ids )
|
||||
# update scores (conjugated to put the worst scores at the top)
|
||||
scores = 1.0 - torch.concat([ F.softmax(logits[i, :], dim=0)[token, None] for i, token in enumerate(input_ids) ])
|
||||
|
||||
logits, _ = self.forward_lambda( ids, step=step, temperature=annealed_temperature )
|
||||
filtered_logits = top_k( logits )
|
||||
sampled_ids = gumbel_sample( filtered_logits, temperature=annealed_temperature, dim=-1 )
|
||||
# print( timestep, steps_until_x0, noise_p, masked_tokens_n, temperature, input_ids, scores )
|
||||
|
||||
is_masked = ids == self.mask_token
|
||||
ids = torch.where( is_masked, sampled_ids, ids )
|
||||
|
||||
probs_without_temperature = logits.softmax(dim = -1)
|
||||
|
||||
scores = 1 - probs_without_temperature.gather(2, sampled_ids[..., None])
|
||||
scores = rearrange(scores, '... 1 -> ...')
|
||||
#scores = scores.to(dtype=torch.float64).masked_fill(~is_masked, -1e5)
|
||||
|
||||
"""
|
||||
if step + 1 == self.max_steps:
|
||||
break
|
||||
|
||||
# lifted from https://github.com/LeapLabTHU/ImprovedNAT/blob/main/libs/nat_misc.py#L39
|
||||
# create next input sequence
|
||||
mask = (ids == self.mask_token)
|
||||
mask_len = torch.Tensor([np.floor(seq_len * mask_ratio)]).to(self.device)
|
||||
mask_len = torch.maximum(
|
||||
torch.Tensor([1]).to(self.device),
|
||||
torch.minimum( torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len )
|
||||
)[0].squeeze()
|
||||
|
||||
logits = torch.log_softmax(logits, dim=-1)
|
||||
sampled_logits = torch.squeeze(torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
|
||||
sampled_ids = torch.where(mask, sampled_ids, ids)
|
||||
sampled_logits = torch.where(mask, sampled_logits, +np.inf).float()
|
||||
|
||||
confidence = add_gumbel_noise(sampled_logits, annealed_temperature, self.device)
|
||||
sorted_confidence, _ = torch.sort(confidence, axis=-1)
|
||||
cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
|
||||
masking = (confidence <= cut_off)
|
||||
|
||||
ids = torch.where(masking, self.mask_token, sampled_ids)
|
||||
"""
|
||||
|
||||
return sampled_ids[0]
|
||||
return input_ids
|
||||
|
|
Loading…
Reference in New Issue
Block a user