too brainlet to diagnose why low temp / greedy sampling is randomly unstable some of the time
This commit is contained in:
parent
8eb9a4056b
commit
910571ad34
|
@ -266,15 +266,17 @@ class AR_NAR(Base):
|
||||||
] * batch_size if sampling_mirostat_tau > 0.0 else None
|
] * batch_size if sampling_mirostat_tau > 0.0 else None
|
||||||
|
|
||||||
scores = [ 1.0 ] * sampling_beam_width
|
scores = [ 1.0 ] * sampling_beam_width
|
||||||
entropies = []
|
metrics = []
|
||||||
|
|
||||||
# ick
|
# ick
|
||||||
low_temperature = False # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 #
|
"""
|
||||||
|
low_temperature = False # sampling_temperature < 0.6 # sampling_repetition_penalty == 1.0 and sampling_temperature == 0.0 #
|
||||||
low_temperature_range = cfg.dataset.frames_per_second * 5
|
low_temperature_range = cfg.dataset.frames_per_second * 5
|
||||||
|
|
||||||
original_sampling_temperature = sampling_temperature
|
original_sampling_temperature = sampling_temperature
|
||||||
original_sampling_repetition_penalty = sampling_repetition_penalty
|
original_sampling_repetition_penalty = sampling_repetition_penalty
|
||||||
original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
|
original_sampling_repetition_penalty_decay = sampling_repetition_penalty_decay
|
||||||
|
"""
|
||||||
|
|
||||||
for i, sequence in enumerate( sequence_list ):
|
for i, sequence in enumerate( sequence_list ):
|
||||||
# add <bos> to text for STT
|
# add <bos> to text for STT
|
||||||
|
@ -300,11 +302,13 @@ class AR_NAR(Base):
|
||||||
# naturally, rep pen wrangles this initial burst of noise, but naively relying on rep_pen is no good, as it fails after ~6 seconds of audio
|
# naturally, rep pen wrangles this initial burst of noise, but naively relying on rep_pen is no good, as it fails after ~6 seconds of audio
|
||||||
# however, switching to a default sampling temperature with "clean greedy sampled codes" will make the rest of sequence sound as if it were greedy sampled
|
# however, switching to a default sampling temperature with "clean greedy sampled codes" will make the rest of sequence sound as if it were greedy sampled
|
||||||
# to-do: tune these values, maybe have it factor based on confidence scores or something
|
# to-do: tune these values, maybe have it factor based on confidence scores or something
|
||||||
|
"""
|
||||||
if low_temperature:
|
if low_temperature:
|
||||||
enabled = n < low_temperature_range
|
enabled = n < low_temperature_range
|
||||||
sampling_repetition_penalty = 1.125 if enabled else original_sampling_repetition_penalty
|
sampling_repetition_penalty = 1.125 if enabled else 1.25
|
||||||
sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
|
#sampling_repetition_penalty_decay = 0.0 if enabled else original_sampling_repetition_penalty_decay
|
||||||
sampling_temperature = original_sampling_temperature if enabled else 1.0
|
#sampling_temperature = original_sampling_temperature if enabled else 1.0
|
||||||
|
"""
|
||||||
|
|
||||||
inputs = self.inputs(
|
inputs = self.inputs(
|
||||||
text_list=text_list,
|
text_list=text_list,
|
||||||
|
@ -351,7 +355,14 @@ class AR_NAR(Base):
|
||||||
r = sampled[0]
|
r = sampled[0]
|
||||||
|
|
||||||
if sampled.entropy:
|
if sampled.entropy:
|
||||||
entropies.append( sampled.entropy )
|
metrics.append( sampled.entropy )
|
||||||
|
"""
|
||||||
|
elif sampled.confidence:
|
||||||
|
metrics.append( sampled.confidence )
|
||||||
|
"""
|
||||||
|
elif False:
|
||||||
|
p = [ { "p": torch.nn.functional.softmax(logit[-1, :].cpu(), dim=0)[token.item()].item() } for logit, token in zip(logits, r) ]
|
||||||
|
metrics.append( p )
|
||||||
|
|
||||||
if mirostat is not None:
|
if mirostat is not None:
|
||||||
mirostat = sampled.scores
|
mirostat = sampled.scores
|
||||||
|
@ -381,9 +392,9 @@ class AR_NAR(Base):
|
||||||
if stopped.all().item():
|
if stopped.all().item():
|
||||||
break
|
break
|
||||||
|
|
||||||
if entropies:
|
if metrics:
|
||||||
from ..plot import plot_entropies
|
from ..plot import plot_sample_metrics
|
||||||
plot_entropies( entropies )
|
plot_sample_metrics( metrics )
|
||||||
|
|
||||||
# pick the best scoring candidate
|
# pick the best scoring candidate
|
||||||
# desu this is always going to be candidate 0
|
# desu this is always going to be candidate 0
|
||||||
|
|
|
@ -93,13 +93,13 @@ def plot(paths, args):
|
||||||
#bbox_to_anchor=(1.04, 0.5),
|
#bbox_to_anchor=(1.04, 0.5),
|
||||||
)
|
)
|
||||||
|
|
||||||
def plot_entropies( entropies ):
|
def plot_sample_metrics( metrics ):
|
||||||
"""
|
"""
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
fig.set_figwidth( 16 * len(entropies) // cfg.dataset.frames_per_second )
|
fig.set_figwidth( 16 * len(metrics) // cfg.dataset.frames_per_second )
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data = { key: [ e[0][key] for e in entropies ] for key in entropies[0][0].keys() }
|
data = { key: [ e[0][key] for e in metrics ] for key in metrics[0][0].keys() }
|
||||||
|
|
||||||
df = pd.DataFrame(data)
|
df = pd.DataFrame(data)
|
||||||
df.plot()
|
df.plot()
|
||||||
|
|
|
@ -134,8 +134,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
|
||||||
raise Exception("No YAML loaded.")
|
raise Exception("No YAML loaded.")
|
||||||
|
|
||||||
if kwargs.pop("dynamic-sampling", False):
|
if kwargs.pop("dynamic-sampling", False):
|
||||||
kwargs['min-ar-temp'] = 0.85 if kwargs['ar-temp'] > 0.85 else 0.0
|
kwargs['min-ar-temp'] = 0.01 if kwargs['ar-temp'] > 0.01 else 0.0
|
||||||
kwargs['min-nar-temp'] = 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
|
kwargs['min-nar-temp'] = 0.0 # 0.85 if kwargs['nar-temp'] > 0.85 else 0.0 # should probably disable it for the NAR
|
||||||
else:
|
else:
|
||||||
kwargs['min-ar-temp'] = -1
|
kwargs['min-ar-temp'] = -1
|
||||||
kwargs['min-nar-temp'] = -1
|
kwargs['min-nar-temp'] = -1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user