unable to diagnose why low-temperature / greedy sampling is randomly unstable some of the time
This commit is contained in:
parent
8eb9a4056b
commit
910571ad34
|
@ -266,15 +266,17 @@ class AR_NAR(Base):
|
|||
] * batch_size if sampling_mirostat_tau > 0.0 else None
|
||||
|
||||
scores = [ 1.0 ] * sampling_beam_width
|
||||
entropies = []
|
||||
metrics = []
|
||||
|
||||
# ick
|
||||
low_temperature = False # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 #
|
||||
"""
|
||||
low_temperature = False # sampling_temperature < 0.6 # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 #
|
||||
low_temperature_range = cfg.dataset.frames_per_second * 5
|
||||
|
||||
original_sampling_temperature = sampling_temperature
|
||||
original_sampling_repetition_penalty = sampling_repetition_penalty
|
||||
original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
|
||||
"""
|
||||
|
||||
for i, sequence in enumerate( sequence_list ):
|
||||
# add <bos> to text for STT
|
||||
|
@ -300,11 +302,13 @@ class AR_NAR(Base):
|
|||
# naturally, rep pen wrangles this initial burst of noise, but naively relying on rep_pen is no good, as it fails after ~6 seconds of audio
|
||||
# however, switching to a default sampling temperature with "clean greedy sampled codes" will make the rest of the sequence sound as if it were greedy sampled
|
||||
# to-do: tune these values, maybe have it factor based on confidence scores or something
|
||||
"""
|
||||
if low_temperature:
|
||||
enabled = n < low_temperature_range
|
||||
sampling_repetition_penalty = 1.125 if enabled else original_sampling_repetition_penalty
|
||||
sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
|
||||
sampling_temperature = original_sampling_temperature if enabled else 1.0
|
||||
sampling_repetition_penalty = 1.125 if enabled else 1.25
|
||||
#sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
|
||||
#sampling_temperature = original_sampling_temperature if enabled else 1.0
|
||||
"""
|
||||
|
||||
inputs = self.inputs(
|
||||
text_list=text_list,
|
||||
|
@ -351,7 +355,14 @@ class AR_NAR(Base):
|
|||
r = sampled[0]
|
||||
|
||||
if sampled.entropy:
|
||||
entropies.append( sampled.entropy )
|
||||
metrics.append( sampled.entropy )
|
||||
"""
|
||||
elif sampled.confidence:
|
||||
metrics.append( sampled.confidence )
|
||||
"""
|
||||
elif False:
|
||||
p = [ { "p": torch.nn.functional.softmax(logit[-1, :].cpu(), dim=0)[token.item()].item() } for logit, token in zip(logits, r) ]
|
||||
metrics.append( p )
|
||||
|
||||
if mirostat is not None:
|
||||
mirostat = sampled.scores
|
||||
|
@ -381,9 +392,9 @@ class AR_NAR(Base):
|
|||
if stopped.all().item():
|
||||
break
|
||||
|
||||
if entropies:
|
||||
from ..plot import plot_entropies
|
||||
plot_entropies( entropies )
|
||||
if metrics:
|
||||
from ..plot import plot_sample_metrics
|
||||
plot_sample_metrics( metrics )
|
||||
|
||||
# pick the best scoring candidate
|
||||
# to be honest, this is always going to be candidate 0
|
||||
|
|
|
@ -93,13 +93,13 @@ def plot(paths, args):
|
|||
#bbox_to_anchor=(1.04, 0.5),
|
||||
)
|
||||
|
||||
def plot_entropies( entropies ):
|
||||
def plot_sample_metrics( metrics ):
|
||||
"""
|
||||
fig = plt.figure()
|
||||
fig.set_figwidth( 16 * len(entropies) // cfg.dataset.frames_per_second )
|
||||
fig.set_figwidth( 16 * len(metrics) // cfg.dataset.frames_per_second )
|
||||
"""
|
||||
|
||||
data = { key: [ e[0][key] for e in entropies ] for key in entropies[0][0].keys() }
|
||||
data = { key: [ e[0][key] for e in metrics ] for key in metrics[0][0].keys() }
|
||||
|
||||
df = pd.DataFrame(data)
|
||||
df.plot()
|
||||
|
|
|
@ -134,8 +134,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
|||
raise Exception("No YAML loaded.")
|
||||
|
||||
if kwargs.pop("dynamic-sampling", False):
|
||||
kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
|
||||
kwargs['min-nar-temp'] = 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
|
||||
kwargs['min-ar-temp'] = 0.01 if kwargs['ar-temp'] > 0.01 else 0.0
|
||||
kwargs['min-nar-temp'] = 0.0 # 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
|
||||
else:
|
||||
kwargs['min-ar-temp'] = -1
|
||||
kwargs['min-nar-temp'] = -1
|
||||
|
|
Loading…
Reference in New Issue
Block a user