disabled preparation of SpeechX tasks, added dynamic temperature sampling (to-do: test it; credited in the function)
This commit is contained in:
parent
2deb995cc9
commit
27483e56f0
@@ -291,8 +291,8 @@ class Dataset(_Dataset):
 	# shuffle it up a bit
 	prom_length = 0
-	trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
-	#trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
+	#trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
+	trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)

 	for _ in range(cfg.dataset.max_prompts):
 		path = random.choice(choices)
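The swap above re-centers the prompt trim: instead of a uniform 3-9 second window, the prompt is now trimmed to cfg.dataset.prompt_duration seconds plus or minus one second of jitter. The file counts EnCodec frames at 75 per second (the original comment's own 75 * 3 = 3 seconds). A minimal sketch of the new arithmetic, with a hypothetical prompt_duration of 3.0:

import random

FRAMES_PER_SECOND = 75  # EnCodec frame rate assumed throughout this dataset code
prompt_duration = 3.0   # hypothetical stand-in for cfg.dataset.prompt_duration
trim_length = int(prompt_duration * FRAMES_PER_SECOND) + random.randint(-75, 75)
# 225 frames +/- 75: anywhere from 2 to 4 seconds of prompt audio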
@@ -336,15 +336,19 @@ class Dataset(_Dataset):
	text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
	resps = _load_quants(path)

	task = "tts"
	trim_length = int(cfg.dataset.prompt_duration * 75)
	proms = self.sample_prompts(spkr_name, ignore=path) if random.random() < cfg.dataset.random_utterance else resps

	# Disabled until I swap over to a better method
	"""
	task = random.choice(self.tasks)

	# ensure a speaker has at least four utterances
	# default to tts if not
	if len(set(self.paths_by_spkr_name[spkr_name]) - {path}) < 4:
		task = "tts"

	noise_scale = 0.25
	# text-to-speech
	if task == "tts" or task == "tts-c":
		trim_length = int(cfg.dataset.prompt_duration * 75)
		# demote if the target is too short
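With task sampling disabled, every sample defaults to plain TTS, and proms is picked by a weighted coin flip: with probability cfg.dataset.random_utterance the prompt comes from a different utterance by the same speaker, otherwise the target utterance doubles as its own prompt. A hedged sketch of just that selection, with hypothetical stand-ins for the real dataset members:

import random

random_utterance = 1.0  # hypothetical stand-in for cfg.dataset.random_utterance

def pick_prompt(other_utterances, resps):
	# with probability `random_utterance`, prompt with a different clip from the
	# same speaker; otherwise reuse the target utterance as its own prompt
	if other_utterances and random.random() < random_utterance:
		return random.choice(other_utterances)
	return resps

proms = pick_prompt(["utt_a_codes", "utt_b_codes"], "target_codes")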
@@ -480,6 +484,7 @@ class Dataset(_Dataset):
		)
	else:
		raise Exception(f'Undefined task: {task}')
	"""

	"""
	# emulate SVC
@@ -336,7 +336,7 @@ def example_usage():
 	tqdm.write(f"{stats}")

-	sample("init", 75)
+	sample("init", 5)
 	train()
 	sample("final")
@@ -119,6 +119,26 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"
 	return logits

+# credit to https://github.com/LostRuins/koboldcpp/pull/464
+def dynamic_temperature( logits, temperature=1.0, min_temperature = 1.0/256.0, k = 10, sigmoidCenterPoint = 0.5 ):
+	# loop over logits[:], as the NAR will have logits.shape[0] > 1
+	for i in range(logits.shape[0]):
+		maximum = 0.0
+		for logit in logits[i]:
+			maximum = max( maximum, logit )
+
+		sum_exp = 0.0
+		for logit in logits[i]:
+			sum_exp += math.exp( logit - maximum )
+
+		prob_max_token_before_temp = 1.0 / sum_exp
+		dynamic_temperature = temperature - (temperature - min_temperature) / (1 + math.exp(-k * (prob_max_token_before_temp - sigmoidCenterPoint)))
+		logits[i] /= dynamic_temperature
+
+	return logits
+
+
 # picks the top K tokens amongst a batch of logits
 # logits: [Tensor] list of logits
 # candidates: [(batch, token)] list, where batch indicates the index of the logits the given token is from
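Two caveats in the loop above: maximum starts at 0.0, so on a row whose logits are all negative, prob_max_token_before_temp is no longer the top token's softmax probability; and the local dynamic_temperature shadows the function name. Since 1 / (1 + exp(-x)) is just sigmoid(x), the same mapping can be written as one vectorized pass. A minimal sketch, not part of this commit, assuming logits is a 2-D float tensor:

import torch

def dynamic_temperature_vectorized( logits, temperature=1.0, min_temperature=1.0/256.0, k=10.0, sigmoid_center=0.5 ):
	# probability of the most-likely token per row, before any temperature is applied
	prob_max = torch.softmax( logits, dim=-1 ).max( dim=-1 ).values
	# sigmoid-blend between min_temperature (confident rows) and temperature (flat rows)
	dyn_temp = temperature - (temperature - min_temperature) * torch.sigmoid( k * (prob_max - sigmoid_center) )
	return logits / dyn_temp.unsqueeze(-1)

torch.softmax does the max-subtraction for numerical stability internally, which also sidesteps the all-negative-row edge case.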
@@ -547,13 +567,16 @@ class Base(nn.Module):
 		if quant_levels is None and self.causal:
 			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]

-		# scale our logits by the temp
-		logits = [ logit / temperature for logit in logits ]
-
 		# perform top_k/top_p filtering of our logits
 		if top_k > 0 or top_p < 1.0:
 			logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]

+		# our dynamic temperature threshold is considered to be anything over 1.25.
+		if temperature > 1.25:
+			logits = [ dynamic_temperature(logit, temperature=temperature) for logit in logits ]
+		else:
+			logits = [ logit / temperature for logit in logits ]

 		# do mirostat sampling
 		# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work
 		if mirostat is not None:
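A quick usage sketch of the new gate, with hypothetical values and the vectorized helper from above standing in for dynamic_temperature: only requests hotter than the hard-coded 1.25 threshold take the dynamic path, and within it confident rows collapse toward near-greedy sampling while flat rows keep most of the requested heat:

import torch

temperature = 1.5  # above the 1.25 threshold, so the dynamic path is taken
logits = [         # the sampler operates on a list of per-sequence logit tensors
	torch.tensor([[8.0, 0.0, 0.0, 0.0],      # confident row: top prob ~0.999 -> effective temp ~0.014
	              [0.2, 0.0, 0.1, 0.0]]),    # flat row: top prob ~0.28 -> effective temp ~1.35
]

# mirrors the branch the commit adds to Base
if temperature > 1.25:
	logits = [ dynamic_temperature_vectorized(logit, temperature=temperature) for logit in logits ]
else:
	logits = [ logit / temperature for logit in logits ]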