unable to diagnose why low-temperature / greedy sampling is intermittently unstable

This commit is contained in:
mrq 2024-10-22 20:13:54 -05:00
parent 8eb9a4056b
commit 910571ad34
3 changed files with 25 additions and 14 deletions

View File

@ -266,15 +266,17 @@ class AR_NAR(Base):
] * batch_size if sampling_mirostat_tau > 0.0 else None
scores = [ 1.0 ] * sampling_beam_width
entropies = []
metrics = []
# ick
low_temperature = False # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 #
"""
low_temperature = False # sampling_temperature < 0.6 # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 #
low_temperature_range = cfg.dataset.frames_per_second * 5
original_sampling_temperature = sampling_temperature
original_sampling_repetition_penalty = sampling_repetition_penalty
original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
"""
for i, sequence in enumerate( sequence_list ):
# add <bos> to text for STT
@ -300,11 +302,13 @@ class AR_NAR(Base):
# naturally, rep pen wrangles this initial burst of noise, but naively relying on rep_pen is no good, as it fails after ~6 seconds of audio
# however, switching to a default sampling temperature with "clean greedy sampled codes" will make the rest of sequence sound as if it were greedy sampled
# to-do: tune these values, maybe have it factor based on confidence scores or something
"""
if low_temperature:
enabled = n < low_temperature_range
sampling_repetition_penalty = 1.125 if enabled else original_sampling_repetition_penalty
sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
sampling_temperature = original_sampling_temperature if enabled else 1.0
sampling_repetition_penalty = 1.125 if enabled else 1.25
#sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
#sampling_temperature = original_sampling_temperature if enabled else 1.0
"""
inputs = self.inputs(
text_list=text_list,
@ -351,7 +355,14 @@ class AR_NAR(Base):
r = sampled[0]
if sampled.entropy:
entropies.append( sampled.entropy )
metrics.append( sampled.entropy )
"""
elif sampled.confidence:
metrics.append( sampled.confidence )
"""
elif False:
p = [ { "p": torch.nn.functional.softmax(logit[-1, :].cpu(), dim=0)[token.item()].item() } for logit, token in zip(logits, r) ]
metrics.append( p )
if mirostat is not None:
mirostat = sampled.scores
@ -381,9 +392,9 @@ class AR_NAR(Base):
if stopped.all().item():
break
if entropies:
from ..plot import plot_entropies
plot_entropies( entropies )
if metrics:
from ..plot import plot_sample_metrics
plot_sample_metrics( metrics )
# pick the best scoring candidate
# desu this is always going to be candidate 0

View File

@ -93,13 +93,13 @@ def plot(paths, args):
#bbox_to_anchor=(1.04, 0.5),
)
def plot_entropies( entropies ):
def plot_sample_metrics( metrics ):
"""
fig = plt.figure()
fig.set_figwidth( 16 * len(entropies) // cfg.dataset.frames_per_second )
fig.set_figwidth( 16 * len(metrics) // cfg.dataset.frames_per_second )
"""
data = { key: [ e[0][key] for e in entropies ] for key in entropies[0][0].keys() }
data = { key: [ e[0][key] for e in metrics ] for key in metrics[0][0].keys() }
df = pd.DataFrame(data)
df.plot()

View File

@ -134,8 +134,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
raise Exception("No YAML loaded.")
if kwargs.pop("dynamic-sampling", False):
kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
kwargs['min-nar-temp'] = 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
kwargs['min-ar-temp'] = 0.01 if kwargs['ar-temp'] > 0.01 else 0.0
kwargs['min-nar-temp'] = 0.0 # 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
else:
kwargs['min-ar-temp'] = -1
kwargs['min-nar-temp'] = -1