This commit is contained in:
mrq 2024-10-12 09:57:34 -05:00
parent d6f7c86a5c
commit 40b089daf3
2 changed files with 18 additions and 15 deletions

View File

@@ -99,10 +99,7 @@ def plot_entropies( entropies ):
fig.set_figwidth( 16 * len(entropies) // cfg.dataset.frames_per_second ) fig.set_figwidth( 16 * len(entropies) // cfg.dataset.frames_per_second )
""" """
data = {} data = { key: [ e[0][key] for e in entropies ] for key in entropies[0][0].keys() }
for key in entropies[0][0].keys():
data[key] = [ e[0][key].item() if hasattr( e[0][key], "item" ) else e[0][key] for e in entropies ]
df = pd.DataFrame(data) df = pd.DataFrame(data)
df.plot() df.plot()

View File

@@ -252,19 +252,19 @@ def calculate_entropix_metrics( logits, attention_scores=None, dim=-1 ):
interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)) interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3))
return { return {
"logits_entropy": torch.mean(entropy), "logits_entropy": torch.mean(entropy).item(),
"logits_varentropy": torch.mean(varentropy), "logits_varentropy": torch.mean(varentropy).item(),
"attn_entropy": torch.mean(attn_entropy), "attn_entropy": torch.mean(attn_entropy).item(),
"attn_varentropy": torch.mean(attn_varentropy), "attn_varentropy": torch.mean(attn_varentropy).item(),
"agreement": torch.mean(agreement), "agreement": torch.mean(agreement).item(),
"interaction_strength": torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)), "interaction_strength": torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)).item(),
"action": -1 "action": -1
} }
# to-do: play around with these values # to-do: play around with these values
@dataclass() @dataclass()
class EntropixSamplerConfig: class EntropixSamplerConfig:
temp: float = 0.85 temp: float = 0.666
top_p: float = 0.90 top_p: float = 0.90
top_k: int = 27 top_k: int = 27
min_p: float = 0.01 # was 0.03 # Turn this down to 0.01 to reduce the shoggoth min_p: float = 0.01 # was 0.03 # Turn this down to 0.01 to reduce the shoggoth
@@ -315,6 +315,8 @@ class EntropixSamplerConfig:
min_p_max: int = 0.5 min_p_max: int = 0.5
Exponential = torch.distributions.exponential.Exponential(1.0) Exponential = torch.distributions.exponential.Exponential(1.0)
# Doing as close to the original sampling method just to reduce variance
def _sample_entropix( def _sample_entropix(
logits, logits,
temperature=1.0, temperature=1.0,
@@ -365,6 +367,10 @@ def sample_entropix(
min_p=0.0, min_p=0.0,
cfg=EntropixSamplerConfig(), cfg=EntropixSamplerConfig(),
): ):
temperature = cfg.temp
top_k = cfg.top_k
top_p = cfg.top_p
metrics = calculate_entropix_metrics( logits, attentions ) metrics = calculate_entropix_metrics( logits, attentions )
ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"] ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
@@ -403,10 +409,10 @@ def sample_entropix(
logits_uncertainty = ent + vent logits_uncertainty = ent + vent
attn_uncertainty = attn_ent + attn_vent attn_uncertainty = attn_ent + attn_vent
temperature *= float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement) temperature *= 1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement
top_p = float(top_p * (1 + cfg.ada_top_p * attn_vent)) top_p = top_p * (1 + cfg.ada_top_p * attn_vent)
top_k = int(round(float(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement)))) top_k = round(float(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement)))
min_p = float(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty)) min_p = cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty)
samples = [ _sample_entropix( logits.clone(), temperature, top_k, top_p, min_p, cfg=cfg ) for _ in range(cfg.n_adaptive_samples) ] samples = [ _sample_entropix( logits.clone(), temperature, top_k, top_p, min_p, cfg=cfg ) for _ in range(cfg.n_adaptive_samples) ]