lol
This commit is contained in:
parent
d6f7c86a5c
commit
40b089daf3
|
@ -99,10 +99,7 @@ def plot_entropies( entropies ):
|
||||||
fig.set_figwidth( 16 * len(entropies) // cfg.dataset.frames_per_second )
|
fig.set_figwidth( 16 * len(entropies) // cfg.dataset.frames_per_second )
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data = {}
|
data = { key: [ e[0][key] for e in entropies ] for key in entropies[0][0].keys() }
|
||||||
|
|
||||||
for key in entropies[0][0].keys():
|
|
||||||
data[key] = [ e[0][key].item() if hasattr( e[0][key], "item" ) else e[0][key] for e in entropies ]
|
|
||||||
|
|
||||||
df = pd.DataFrame(data)
|
df = pd.DataFrame(data)
|
||||||
df.plot()
|
df.plot()
|
||||||
|
|
|
@ -252,19 +252,19 @@ def calculate_entropix_metrics( logits, attention_scores=None, dim=-1 ):
|
||||||
|
|
||||||
interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3))
|
interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3))
|
||||||
return {
|
return {
|
||||||
"logits_entropy": torch.mean(entropy),
|
"logits_entropy": torch.mean(entropy).item(),
|
||||||
"logits_varentropy": torch.mean(varentropy),
|
"logits_varentropy": torch.mean(varentropy).item(),
|
||||||
"attn_entropy": torch.mean(attn_entropy),
|
"attn_entropy": torch.mean(attn_entropy).item(),
|
||||||
"attn_varentropy": torch.mean(attn_varentropy),
|
"attn_varentropy": torch.mean(attn_varentropy).item(),
|
||||||
"agreement": torch.mean(agreement),
|
"agreement": torch.mean(agreement).item(),
|
||||||
"interaction_strength": torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)),
|
"interaction_strength": torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)).item(),
|
||||||
"action": -1
|
"action": -1
|
||||||
}
|
}
|
||||||
|
|
||||||
# to-do: play around with these values
|
# to-do: play around with these values
|
||||||
@dataclass()
|
@dataclass()
|
||||||
class EntropixSamplerConfig:
|
class EntropixSamplerConfig:
|
||||||
temp: float = 0.85
|
temp: float = 0.666
|
||||||
top_p: float = 0.90
|
top_p: float = 0.90
|
||||||
top_k: int = 27
|
top_k: int = 27
|
||||||
min_p: float = 0.01 # was 0.03 # Turn this down to 0.01 to reduce the shoggoth
|
min_p: float = 0.01 # was 0.03 # Turn this down to 0.01 to reduce the shoggoth
|
||||||
|
@ -315,6 +315,8 @@ class EntropixSamplerConfig:
|
||||||
min_p_max: int = 0.5
|
min_p_max: int = 0.5
|
||||||
|
|
||||||
Exponential = torch.distributions.exponential.Exponential(1.0)
|
Exponential = torch.distributions.exponential.Exponential(1.0)
|
||||||
|
|
||||||
|
# Doing as close to the original sampling method just to reduce variance
|
||||||
def _sample_entropix(
|
def _sample_entropix(
|
||||||
logits,
|
logits,
|
||||||
temperature=1.0,
|
temperature=1.0,
|
||||||
|
@ -365,6 +367,10 @@ def sample_entropix(
|
||||||
min_p=0.0,
|
min_p=0.0,
|
||||||
cfg=EntropixSamplerConfig(),
|
cfg=EntropixSamplerConfig(),
|
||||||
):
|
):
|
||||||
|
temperature = cfg.temp
|
||||||
|
top_k = cfg.top_k
|
||||||
|
top_p = cfg.top_p
|
||||||
|
|
||||||
metrics = calculate_entropix_metrics( logits, attentions )
|
metrics = calculate_entropix_metrics( logits, attentions )
|
||||||
|
|
||||||
ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
|
ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
|
||||||
|
@ -403,10 +409,10 @@ def sample_entropix(
|
||||||
logits_uncertainty = ent + vent
|
logits_uncertainty = ent + vent
|
||||||
attn_uncertainty = attn_ent + attn_vent
|
attn_uncertainty = attn_ent + attn_vent
|
||||||
|
|
||||||
temperature *= float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement)
|
temperature *= 1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement
|
||||||
top_p = float(top_p * (1 + cfg.ada_top_p * attn_vent))
|
top_p = top_p * (1 + cfg.ada_top_p * attn_vent)
|
||||||
top_k = int(round(float(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement))))
|
top_k = round(float(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement)))
|
||||||
min_p = float(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty))
|
min_p = cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty)
|
||||||
|
|
||||||
samples = [ _sample_entropix( logits.clone(), temperature, top_k, top_p, min_p, cfg=cfg ) for _ in range(cfg.n_adaptive_samples) ]
|
samples = [ _sample_entropix( logits.clone(), temperature, top_k, top_p, min_p, cfg=cfg ) for _ in range(cfg.n_adaptive_samples) ]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user