agony

parent c127c4e488
commit 13b54953bd
@@ -40,7 +40,7 @@ from ..data import get_task_symmap
 
 # these seem more elegant than a dict
 Logits = namedtuple('Logits', ['logits', 'state', 'aux_loss', 'attentions', 'hidden_states', 'exited_layer'])
-Sampled = namedtuple('Sampled', ['out', 'scores', 'entropy'])
+Sampled = namedtuple('Sampled', ['out', 'logits', 'scores', 'entropy'])
 LossStats = namedtuple('LossStats', ['loss', 'stats'])
 
 """
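Note: the `Sampled` record grows a `logits` field so the raw distribution travels alongside the drawn tokens. A minimal sketch of the knock-on effect for callers (toy tensors, not code from the repo): any positional unpacking has to account for the new slot, which is why the `sampled[0]`-style call sites below also change.

    import torch
    from collections import namedtuple

    Sampled = namedtuple('Sampled', ['out', 'logits', 'scores', 'entropy'])

    logits = torch.randn(10, 1024)          # toy (seq_len, vocab) model output
    s = Sampled(logits.argmax(dim=-1), logits, None, None)

    # old callers unpacked three fields; there is now one more slot:
    out, logits, scores, entropy = s
    confidence = logits.softmax(dim=-1)     # the raw distribution is available for rescoring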
@@ -1681,7 +1681,7 @@ class Base(nn.Module):
            ) for batch, logit in enumerate(logits) ]
 
            if res:
-               return Sampled([ r[0] for r in res ], scores, [ r[1] for r in res ])
+               return Sampled([ r[0] for r in res ], logits, scores, [ r[1] for r in res ])
            """
        elif quant_levels is None:
            seq_lens = [ logit.shape[0] for logit in logits ]
@@ -1772,4 +1772,4 @@ class Base(nn.Module):
            for logit, tokens in zip(logits, res)
        ]
 
-       return Sampled(res, scores, entropy)
+       return Sampled(res, logits, scores, entropy)
@@ -226,7 +226,7 @@ class NAR(Base):
                sampling_layer_skip_variables["max_layer"] = sampling_layer_skip_exit_layer
 
            # initial condition
-           len_list = [ min(l, 500) for l in len_list ]
+           len_list = [ min(l, 75*3) for l in len_list ]
            metrics = []
 
            mask_token = torch.tensor([self.stop_token], dtype=torch.int16, device=device)
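The new cap reads as 75 codec frames per second times 3 seconds (EnCodec at 24 kHz emits 75 frames per second), replacing the flat 500-frame (~6.7 s) ceiling. A sketch with that assumption spelled out:

    # assumption: the codec emits 75 frames per second (true of EnCodec at 24 kHz)
    FRAMES_PER_SECOND = 75
    MAX_SECONDS = 3

    len_list = [120, 300, 900]
    len_list = [ min(l, FRAMES_PER_SECOND * MAX_SECONDS) for l in len_list ]
    # -> [120, 225, 225]; the previous ceiling was a flat 500 frames (~6.7 s)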
@@ -240,9 +240,11 @@ class NAR(Base):
            _super = super()
            def forward_lambda( ids, step, temperature ):
                quant_levels = [ level for _ in range(batch_size) ]
-               prev_list = [ ids[0] ]
+               prev_list = [ ids ]
                seq_len = ids.shape[-1]
 
+               sampling_top_k = math.floor( seq_len * 0.9 )
+
                inputs = _super.inputs(
                    text_list=text_list,
                    proms_list=proms_list,
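Worth noting: `sampling_top_k` is derived from the sequence length (90% of `seq_len`), but the matching `top_k=sampling_top_k` argument in both sampler calls below stays commented out, so in this revision the value is computed and never consumed.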
@@ -260,6 +262,7 @@ class NAR(Base):
                )
                logits = output.logits
 
+               # sample with sampler settings
                sampled = _super.sample(
                    logits=logits,
                    prev_list=prev_list,
@@ -277,14 +280,30 @@ class NAR(Base):
                    #mirostat=mirostat,
                )
 
-               ids = sampled[0]
+               # greedy sample
+               greedy_sampled = _super.sample(
+                   logits=logits,
+                   prev_list=prev_list,
+                   quant_levels=quant_levels,
 
-               return logits[0][-seq_len:].unsqueeze(0), ids[0].unsqueeze(0)
+                   temperature=0.0,
+                   #min_temperature=sampling_min_temperature,
+                   #top_p=sampling_top_p,
+                   #top_k=sampling_top_k,
+                   #min_p=sampling_min_p,
+                   #repetition_penalty=sampling_repetition_penalty,
+                   #repetition_penalty_decay=sampling_repetition_penalty_decay,
+                   #length_penalty=sampling_length_penalty,
+                   #beam_width=sampling_beam_width,
+                   #mirostat=mirostat,
+               )
+
+               return sampled, greedy_sampled
 
            scheduler = SampleScheduler(
                device=device,
                mask_token=self.stop_token,
-               max_steps=30,
+               max_steps=5,
                forward_lambda=forward_lambda,
                sampling_temperature=sampling_temperature,
            )
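`forward_lambda` now runs the shared sampler twice over the same logits: once with the live temperature to draw token proposals, and once greedily (`temperature=0.0`) so the scheduler can score confidence from an un-tempered distribution. A toy of that dual-pass shape (the `sample` helper here is a stand-in, not the repo's `Base.sample`):

    import torch
    from collections import namedtuple

    Sampled = namedtuple('Sampled', ['out', 'logits', 'scores', 'entropy'])

    def sample(logits, temperature=1.0):
        # stand-in for Base.sample: temperature 0.0 collapses to argmax
        if temperature <= 0.0:
            out = logits.argmax(dim=-1)
        else:
            out = torch.multinomial((logits / temperature).softmax(dim=-1), 1).squeeze(-1)
        return Sampled(out, logits, None, None)

    logits = torch.randn(10, 1024)                    # one forward pass
    sampled = sample(logits, temperature=0.8)         # tempered draw: token proposals
    greedy_sampled = sample(logits, temperature=0.0)  # greedy pass: stable confidences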
@@ -537,13 +537,6 @@ def gumbel_noise(t):
 def gumbel_sample(t, temperature = 1., dim = -1):
     return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
 
-def top_k(logits, thres = 0.9):
-    k = math.ceil((1 - thres) * logits.shape[-1])
-    val, ind = logits.topk(k, dim = -1)
-    probs = torch.full_like(logits, float('-inf'))
-    probs.scatter_(2, ind, val)
-    return probs
-
 # this provides mostly poor output, but it might just be a matter of how I'm naively training the model for """diffusion"""
 class SampleScheduler:
     def __init__(
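The deleted `top_k` helper hard-coded `scatter_(2, ...)`, which assumes 3-D `(batch, seq, vocab)` logits; the scheduler now leans on the shared sampler's filtering instead. For reference, a dimension-agnostic rewrite of the same filter would scatter on the last dim (a sketch, not repo code):

    import math
    import torch

    def top_k(logits, thres = 0.9):
        # keep the top (1 - thres) fraction of the vocab, -inf everything else;
        # scattering on the last dim works for 2-D and 3-D logits alike
        k = math.ceil((1 - thres) * logits.shape[-1])
        val, ind = logits.topk(k, dim = -1)
        probs = torch.full_like(logits, float('-inf'))
        probs.scatter_(-1, ind, val)
        return probs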
@@ -558,66 +551,39 @@ class SampleScheduler:
        self.max_steps = max_steps
        self.mask_token = mask_token
        self.device = device
 
-       """
-       self.ratios = (np.cos(np.linspace(0, math.pi / 2, self.max_steps + 1)))[1:-1]
-       self.annealed_temperatures = (1 - np.linspace(0, 1, self.max_steps + 1))[:-2]
-       self.sampling_temperatures = [sampling_temperature for _ in range(self.max_steps)]
-       """
-
-   # lifted from https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/muse_maskgit_pytorch.py#L493
    def sample( self, seq_len ):
-       ids = torch.full((1, seq_len), self.mask_token, dtype = torch.long, device = self.device)
-       scores = torch.zeros((1, seq_len), dtype = torch.float32, device = self.device)
-
-       for step in range( self.max_steps ):
-           t = step / self.max_steps
-           mask_ratio = math.cos(t * math.pi * 0.5)
-           sampling_temperature = 1.0
-           annealed_temperature = sampling_temperature * (1.0 - t)
-
-           num_token_masked = max(int(mask_ratio * seq_len), 1)
-           masked_indices = scores.topk(num_token_masked, dim = -1).indices
-
-           ids = ids.scatter(1, masked_indices, self.mask_token)
-
-           logits, _ = self.forward_lambda( ids, step=step, temperature=annealed_temperature )
-           filtered_logits = top_k( logits )
-           sampled_ids = gumbel_sample( filtered_logits, temperature=annealed_temperature, dim=-1 )
-
-           is_masked = ids == self.mask_token
-           ids = torch.where( is_masked, sampled_ids, ids )
-
-           probs_without_temperature = logits.softmax(dim = -1)
-
-           scores = 1 - probs_without_temperature.gather(2, sampled_ids[..., None])
-           scores = rearrange(scores, '... 1 -> ...')
-           #scores = scores.to(dtype=torch.float64).masked_fill(~is_masked, -1e5)
-
-           """
-           if step + 1 == self.max_steps:
-               break
-
-           # lifted from https://github.com/LeapLabTHU/ImprovedNAT/blob/main/libs/nat_misc.py#L39
-           # create next input sequence
-           mask = (ids == self.mask_token)
-           mask_len = torch.Tensor([np.floor(seq_len * mask_ratio)]).to(self.device)
-           mask_len = torch.maximum(
-               torch.Tensor([1]).to(self.device),
-               torch.minimum( torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len )
-           )[0].squeeze()
-
-           logits = torch.log_softmax(logits, dim=-1)
-           sampled_logits = torch.squeeze(torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
-           sampled_ids = torch.where(mask, sampled_ids, ids)
-           sampled_logits = torch.where(mask, sampled_logits, +np.inf).float()
-
-           confidence = add_gumbel_noise(sampled_logits, annealed_temperature, self.device)
-           sorted_confidence, _ = torch.sort(confidence, axis=-1)
-           cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()]
-           masking = (confidence <= cut_off)
-
-           ids = torch.where(masking, self.mask_token, sampled_ids)
-           """
-
-       return sampled_ids[0]
+       starting_temperature = 0.2
+
+       input_ids = torch.ones((seq_len,), dtype=torch.long, device=self.device) * self.mask_token
+       scores = torch.zeros((seq_len,), dtype=torch.float32, device=self.device)
+
+       for timestep, steps_until_x0 in zip(torch.linspace(0, 1, self.max_steps), reversed(range(self.max_steps))):
+           # anneal temperature
+           temperature = starting_temperature * (steps_until_x0 / self.max_steps)
+           # get noise level, per cosine scheduling
+           noise_p = math.cos( timestep * math.pi * 0.5 )
+           # number of tokens to mask off to "noise" the input sequence
+           masked_tokens_n = max(int( noise_p * seq_len ), 1)
+           # pick the worst scoring tokens to mask off
+           masked_indices = scores.topk( masked_tokens_n, dim=-1 ).indices
+           # mask off inputs
+           input_ids = input_ids.scatter(0, masked_indices, self.mask_token)
+           # boolean mask
+           is_masked = input_ids == self.mask_token
+           # sample
+           sampled, greedy_sampled = self.forward_lambda( input_ids, step=timestep, temperature=temperature )
+           # extract logits
+           logits = greedy_sampled.logits[0]
+           filtered_logits = sampled.logits[0]
+
+           # sample with gumbel noise
+           sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
+           # keep unmasked tokens
+           input_ids = torch.where( is_masked, sampled_ids, input_ids )
+           # update scores (conjugated to put the worst scores at the top)
+           scores = 1.0 - torch.concat([ F.softmax(logits[i, :], dim=0)[token, None] for i, token in enumerate(input_ids) ])
+
+           # print( timestep, steps_until_x0, noise_p, masked_tokens_n, temperature, input_ids, scores )
+
+       return input_ids
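The rewritten `sample` is a bare MaskGIT-style demasking loop: start fully masked, then at each of `max_steps` iterations re-mask the cosine-scheduled fraction of lowest-confidence positions, resample them with Gumbel noise at an annealing temperature, and rescore every position. A self-contained toy of that loop against a dummy model (the model and all sizes here are fabricated for illustration; the real code scores from the greedy pass's logits while drawing tokens from the tempered pass):

    import math
    import torch
    import torch.nn.functional as F

    MASK, VOCAB, SEQ_LEN, MAX_STEPS = 0, 32, 16, 5

    def forward_lambda(ids):
        # dummy model: random logits per position (the real code calls the NAR here)
        return torch.randn(SEQ_LEN, VOCAB)

    input_ids = torch.full((SEQ_LEN,), MASK, dtype=torch.long)
    scores = torch.zeros(SEQ_LEN)

    for timestep, steps_until_x0 in zip(torch.linspace(0, 1, MAX_STEPS), reversed(range(MAX_STEPS))):
        temperature = 0.2 * (steps_until_x0 / MAX_STEPS)        # anneal toward greedy
        noise_p = math.cos(timestep * math.pi * 0.5)            # cosine schedule: 1 -> 0
        n_masked = max(int(noise_p * SEQ_LEN), 1)               # tokens to re-noise
        masked_indices = scores.topk(n_masked).indices          # worst-scoring positions
        input_ids = input_ids.scatter(0, masked_indices, MASK)  # re-mask them
        is_masked = input_ids == MASK

        logits = forward_lambda(input_ids)
        gumbel = -torch.log(-torch.log(torch.rand_like(logits)))
        sampled_ids = (logits / max(temperature, 1e-10) + gumbel).argmax(dim=-1)
        input_ids = torch.where(is_masked, sampled_ids, input_ids)

        # low confidence -> high score -> first to be re-masked next step
        probs = F.softmax(logits, dim=-1)
        scores = 1.0 - probs.gather(-1, input_ids[:, None]).squeeze(-1)

    print(input_ids)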