diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 7342766..b29056e 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -266,15 +266,17 @@ class AR_NAR(Base):
 		] * batch_size if sampling_mirostat_tau > 0.0 else None
 
 		scores = [ 1.0 ] * sampling_beam_width
-		entropies = []
+		metrics = []
 
 		# ick
-		low_temperature = False # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 #
+		"""
+		low_temperature = False # sampling_temperature < 0.6 # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 #
 		low_temperature_range = cfg.dataset.frames_per_second * 5
 
 		original_sampling_temperature = sampling_temperature
 		original_sampling_repetition_penalty = sampling_repetition_penalty
 		original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
+		"""
 
 		for i, sequence in enumerate( sequence_list ):
 			# add to text for STT
@@ -300,11 +302,13 @@ class AR_NAR(Base):
 			# naturally, rep pen wrangles this initial burst of noise, but naively relying on rep_pen is no good, as it fails after ~6 seconds of audio
 			# however, switching to a default sampling temperature with "clean greedy sampled codes" will make the rest of sequence sound as if it were greedy sampled
 			# to-do: tune these values, maybe have it factor based on confidence scores or something
+			"""
 			if low_temperature:
 				enabled = n < low_temperature_range
-				sampling_repetition_penalty = 1.125 if enabled else original_sampling_repetition_penalty
-				sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
-				sampling_temperature = original_sampling_temperature if enabled else 1.0
+				sampling_repetition_penalty = 1.125 if enabled else 1.25
+				#sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
+				#sampling_temperature = original_sampling_temperature if enabled else 1.0
+			"""
 
 			inputs = self.inputs(
 				text_list=text_list,
@@ -351,7 +355,14 @@ class AR_NAR(Base):
 			r = sampled[0]
 
 			if sampled.entropy:
-				entropies.append( sampled.entropy )
+				metrics.append( sampled.entropy )
+			# commented out with # instead of a """ block, as a bare string between if/elif is a SyntaxError
+			#elif sampled.confidence:
+			#	metrics.append( sampled.confidence )
+			elif False:
+				# probability the model assigned to each sampled token, from the final timestep's logits
+				p = [ { "p": torch.nn.functional.softmax(logit[-1, :].cpu(), dim=0)[token.item()].item() } for logit, token in zip(logits, r) ]
+				metrics.append( p )
 
 			if mirostat is not None:
 				mirostat = sampled.scores
@@ -381,9 +392,9 @@ class AR_NAR(Base):
 			if stopped.all().item():
 				break
 
-		if entropies:
-			from ..plot import plot_entropies
-			plot_entropies( entropies )
+		if metrics:
+			from ..plot import plot_sample_metrics
+			plot_sample_metrics( metrics )
 
 		# pick the best scoring candidate
 		# desu this is always going to be candidate 0
diff --git a/vall_e/plot.py b/vall_e/plot.py
index c2231f3..d5111ae 100644
--- a/vall_e/plot.py
+++ b/vall_e/plot.py
@@ -93,13 +93,13 @@ def plot(paths, args):
 		#bbox_to_anchor=(1.04, 0.5),
 	)
 
-def plot_entropies( entropies ):
+def plot_sample_metrics( metrics ):
 	"""
 	fig = plt.figure()
-	fig.set_figwidth( 16 * len(entropies) // cfg.dataset.frames_per_second )
+	fig.set_figwidth( 16 * len(metrics) // cfg.dataset.frames_per_second )
 	"""
 
-	data = { key: [ e[0][key] for e in entropies ] for key in entropies[0][0].keys() }
+	data = { key: [ e[0][key] for e in metrics ] for key in metrics[0][0].keys() }
 
 	df = pd.DataFrame(data)
 	df.plot()
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 0bb36e0..6797911 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -134,8 +134,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 		raise Exception("No YAML loaded.")
 
 	if kwargs.pop("dynamic-sampling", False):
-		kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
-		kwargs['min-nar-temp'] = 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
+		kwargs['min-ar-temp'] = 0.01 if kwargs['ar-temp'] > 0.01 else 0.0
+		kwargs['min-nar-temp'] = 0.0 # 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
 	else:
 		kwargs['min-ar-temp'] = -1
 		kwargs['min-nar-temp'] = -1