reduced dynamic temperature threshold to > 1.0, as it seems to not quite be useful for audio LMs, sped up any sampling that touches logits by copying them to CPU first, as accessing tensors on the GPU is slow as balls)
This commit is contained in:
parent
29873e6ded
commit
26fbb92ec6
|
@ -15,6 +15,16 @@ from .models import get_models
|
||||||
from .train import load_engines
|
from .train import load_engines
|
||||||
from .data import get_phone_symmap, _load_quants
|
from .data import get_phone_symmap, _load_quants
|
||||||
|
|
||||||
|
use_deepspeed_inference = False
|
||||||
|
# to-do: integrate this for windows
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import deepspeed
|
||||||
|
use_deepspeed_inference = True
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
"""
|
||||||
|
|
||||||
class TTS():
|
class TTS():
|
||||||
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device=None, amp=None, dtype=None ):
|
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device=None, amp=None, dtype=None ):
|
||||||
self.loading = True
|
self.loading = True
|
||||||
|
@ -41,7 +51,6 @@ class TTS():
|
||||||
cfg.mode = "inferencing"
|
cfg.mode = "inferencing"
|
||||||
cfg.device = device
|
cfg.device = device
|
||||||
cfg.trainer.load_state_dict = True
|
cfg.trainer.load_state_dict = True
|
||||||
cfg.trainer.backend = "local"
|
|
||||||
cfg.trainer.weight_dtype = dtype
|
cfg.trainer.weight_dtype = dtype
|
||||||
cfg.inference.weight_dtype = dtype
|
cfg.inference.weight_dtype = dtype
|
||||||
|
|
||||||
|
@ -85,12 +94,16 @@ class TTS():
|
||||||
self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
||||||
self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
||||||
|
|
||||||
|
if use_deepspeed_inference:
|
||||||
|
self.ar = deepspeed.init_inference(model=self.ar, mp_size=1, replace_with_kernel_inject=True, dtype=self.dtype if not self.amp else torch.float32).module.eval()
|
||||||
|
self.nar = deepspeed.init_inference(model=self.nar, mp_size=1, replace_with_kernel_inject=True, dtype=self.dtype if not self.amp else torch.float32).module.eval()
|
||||||
|
else:
|
||||||
|
self.ar.eval()
|
||||||
|
self.nar.eval()
|
||||||
|
|
||||||
if self.symmap is None:
|
if self.symmap is None:
|
||||||
self.symmap = get_phone_symmap()
|
self.symmap = get_phone_symmap()
|
||||||
|
|
||||||
self.ar.eval()
|
|
||||||
self.nar.eval()
|
|
||||||
|
|
||||||
self.loading = False
|
self.loading = False
|
||||||
|
|
||||||
def load_models( self ):
|
def load_models( self ):
|
||||||
|
|
|
@ -120,7 +120,7 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
# credit to https://github.com/LostRuins/koboldcpp/pull/464
|
# credit to https://github.com/LostRuins/koboldcpp/pull/464
|
||||||
def dynamic_temperature( logits, temperature=1.0, min_temperature = 1.0/256.0, k = 10, sigmoidCenterPoint = 0.5 ):
|
def dynamic_temperature( logits, temperature=1.0, min_temperature = 0.00390625, k = 10, sigmoidCenterPoint = 0.5 ):
|
||||||
# loop over logits[:], as the NAR will have logits.shape[0] > 1
|
# loop over logits[:], as the NAR will have logits.shape[0] > 1
|
||||||
for i in range(logits.shape[0]):
|
for i in range(logits.shape[0]):
|
||||||
maximum = 0.0
|
maximum = 0.0
|
||||||
|
@ -133,6 +133,11 @@ def dynamic_temperature( logits, temperature=1.0, min_temperature = 1.0/256.0, k
|
||||||
|
|
||||||
prob_max_token_before_temp = 1.0 / sum_exp
|
prob_max_token_before_temp = 1.0 / sum_exp
|
||||||
dynamic_temperature = temperature - (temperature - min_temperature) / (1 + math.exp(-k * (prob_max_token_before_temp - sigmoidCenterPoint)))
|
dynamic_temperature = temperature - (temperature - min_temperature) / (1 + math.exp(-k * (prob_max_token_before_temp - sigmoidCenterPoint)))
|
||||||
|
|
||||||
|
#print( "sum_exp:", sum_exp )
|
||||||
|
#print( "prob_max_token_before_temp:", prob_max_token_before_temp )
|
||||||
|
#print( "dynamic temperature:", dynamic_temperature )
|
||||||
|
|
||||||
logits[i] /= dynamic_temperature
|
logits[i] /= dynamic_temperature
|
||||||
|
|
||||||
return logits
|
return logits
|
||||||
|
@ -560,6 +565,9 @@ class Base(nn.Module):
|
||||||
else:
|
else:
|
||||||
logits = [ logit[-1:] for logit in logits ]
|
logits = [ logit[-1:] for logit in logits ]
|
||||||
|
|
||||||
|
devices = [ logit.device for logit in logits ]
|
||||||
|
logits = [ logit.cpu() for logit in logits ]
|
||||||
|
|
||||||
# perform repetition penalizing
|
# perform repetition penalizing
|
||||||
logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
|
logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
|
||||||
|
|
||||||
|
@ -571,8 +579,8 @@ class Base(nn.Module):
|
||||||
if top_k > 0 or top_p < 1.0:
|
if top_k > 0 or top_p < 1.0:
|
||||||
logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
|
logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
|
||||||
|
|
||||||
# our dynamic temperature threshold is considered to be anything over 1.25.
|
# our dynamic temperature threshold is considered to be anything over 1.0.
|
||||||
if temperature > 1.25:
|
if temperature > 1.0:
|
||||||
logits = [ dynamic_temperature(logit, temperature=temperature) for logit in logits ]
|
logits = [ dynamic_temperature(logit, temperature=temperature) for logit in logits ]
|
||||||
else:
|
else:
|
||||||
logits = [ logit / temperature for logit in logits ]
|
logits = [ logit / temperature for logit in logits ]
|
||||||
|
@ -594,7 +602,7 @@ class Base(nn.Module):
|
||||||
return res, scores
|
return res, scores
|
||||||
|
|
||||||
# and sample
|
# and sample
|
||||||
return [ Categorical(logits=logit).sample() for logit in logits ]
|
return [ Categorical(logits=logit).sample().to(device) for logit, device in zip(logits, devices) ]
|
||||||
|
|
||||||
def example_usage():
|
def example_usage():
|
||||||
from ..config import cfg
|
from ..config import cfg
|
||||||
|
|
|
@ -190,8 +190,8 @@ with ui:
|
||||||
layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
|
layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
|
||||||
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
|
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=3.0, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
|
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
|
||||||
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=3.0, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")
|
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
|
layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user