diff --git a/vall_e/plot.py b/vall_e/plot.py
index ee7ca01..c2231f3 100644
--- a/vall_e/plot.py
+++ b/vall_e/plot.py
@@ -99,10 +99,7 @@ def plot_entropies( entropies ):
 		fig.set_figwidth( 16 * len(entropies) // cfg.dataset.frames_per_second )
 	"""
 
-	data = {}
-
-	for key in entropies[0][0].keys():
-		data[key] = [ e[0][key].item() if hasattr( e[0][key], "item" ) else e[0][key] for e in entropies ]
+	data = { key: [ e[0][key] for e in entropies ] for key in entropies[0][0].keys() }
 
 	df = pd.DataFrame(data)
 	df.plot()
diff --git a/vall_e/samplers.py b/vall_e/samplers.py
index 0a2563c..76f0cf4 100644
--- a/vall_e/samplers.py
+++ b/vall_e/samplers.py
@@ -252,19 +252,19 @@ def calculate_entropix_metrics( logits, attention_scores=None, dim=-1 ):
 	interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3))
 
 	return {
-		"logits_entropy": torch.mean(entropy),
-		"logits_varentropy": torch.mean(varentropy),
-		"attn_entropy": torch.mean(attn_entropy),
-		"attn_varentropy": torch.mean(attn_varentropy),
-		"agreement": torch.mean(agreement),
-		"interaction_strength": torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)),
+		"logits_entropy": torch.mean(entropy).item(),
+		"logits_varentropy": torch.mean(varentropy).item(),
+		"attn_entropy": torch.mean(attn_entropy).item(),
+		"attn_varentropy": torch.mean(attn_varentropy).item(),
+		"agreement": torch.mean(agreement).item(),
+		"interaction_strength": torch.mean(torch.abs(attention_scores), dim=(1, 2, 3)).item(),
 		"action": -1
 	}
 
 # to-do: play around with these values
 @dataclass()
 class EntropixSamplerConfig:
-	temp: float = 0.85
+	temp: float = 0.666
 	top_p: float = 0.90
 	top_k: int = 27
 	min_p: float = 0.01 # was 0.03 # Turn this down to 0.01 to reduce the shoggoth
@@ -315,6 +315,8 @@ class EntropixSamplerConfig:
 	min_p_max: int = 0.5
 
 Exponential = torch.distributions.exponential.Exponential(1.0)
+
+# Doing as close to the original sampling method just to reduce variance
 def _sample_entropix(
 	logits,
 	temperature=1.0,
@@ -365,6 +367,10 @@ def sample_entropix(
 	min_p=0.0,
 	cfg=EntropixSamplerConfig(),
 ):
+	temperature = cfg.temp
+	top_k = cfg.top_k
+	top_p = cfg.top_p
+
 	metrics = calculate_entropix_metrics( logits, attentions )
 
 	ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
@@ -403,10 +409,10 @@ def sample_entropix(
 	logits_uncertainty = ent + vent
 	attn_uncertainty = attn_ent + attn_vent
 
-	temperature *= float(1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement)
-	top_p = float(top_p * (1 + cfg.ada_top_p * attn_vent))
-	top_k = int(round(float(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement))))
-	min_p = float(cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty))
+	temperature *= 1 + cfg.ada_temp_logits * logits_uncertainty + cfg.ada_temp_attn * attn_uncertainty - cfg.ada_temp_agree * agreement
+	top_p = top_p * (1 + cfg.ada_top_p * attn_vent)
+	top_k = round(float(top_k * (1 + cfg.ada_top_k_int * interaction_strength - cfg.ada_top_k_agree * agreement)))
+	min_p = cfg.min_p * (1 - cfg.ada_min_p * logits_uncertainty)
 
 	samples = [ _sample_entropix( logits.clone(), temperature, top_k, top_p, min_p, cfg=cfg ) for _ in range(cfg.n_adaptive_samples) ]
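A note on how the hunks above connect: once calculate_entropix_metrics() calls .item() and returns plain Python floats, downstream consumers no longer need to be tensor-aware, which is what lets plot.py drop its hasattr( e[0][key], "item" ) guard and sample_entropix() drop its float(...) casts. Below is a minimal sketch of the simplified plot path; the metric values are made up for illustration, and the real per-step dicts come from calculate_entropix_metrics().

import pandas as pd

# Two fake per-step metric dicts; the real ones are produced by
# calculate_entropix_metrics() and are already plain floats after .item().
entropies = [
	[{ "logits_entropy": 0.91, "logits_varentropy": 0.12, "action": -1 }],
	[{ "logits_entropy": 0.87, "logits_varentropy": 0.10, "action": 2 }],
]

# Same dict comprehension as the new plot.py hunk: one column per metric,
# no per-value tensor check needed.
data = { key: [ e[0][key] for e in entropies ] for key in entropies[0][0].keys() }

df = pd.DataFrame(data)
df.plot()  # requires matplotlib; draws one line per metric over sampling steps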