reduced the dynamic temperature threshold to > 1.0, as it does not seem to be all that useful for audio LMs; sped up any sampling that touches logits by copying them to the CPU first, since accessing tensors element-by-element on the GPU is slow as balls)
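A rough sketch of the CPU round-trip described above (the helper name sample_on_cpu is illustrative, not the repo's actual API): record each logit tensor's device, do the element-wise sampling work on CPU copies, then move only the sampled token ids back to where the logits came from.

import torch
from torch.distributions import Categorical

def sample_on_cpu( logits_list, temperature=1.0 ):
	# remember where each logit tensor lives so the sampled tokens can be returned there
	devices = [ logit.device for logit in logits_list ]

	# do the element-wise work (penalties, filtering, temperature) on CPU copies;
	# poking at individual elements of a CUDA tensor forces a device sync per access
	logits_list = [ logit.cpu() for logit in logits_list ]
	logits_list = [ logit / temperature for logit in logits_list ]

	# sample on CPU, then ship just the token ids back to the original devices
	return [ Categorical(logits=logit).sample().to(device) for logit, device in zip(logits_list, devices) ]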
This commit is contained in:
parent 29873e6ded
commit 26fbb92ec6
@@ -15,6 +15,16 @@ from .models import get_models
 from .train import load_engines
 from .data import get_phone_symmap, _load_quants
 
+use_deepspeed_inference = False
+# to-do: integrate this for windows
+"""
+try:
+	import deepspeed
+	use_deepspeed_inference = True
+except Exception as e:
+	pass
+"""
+
 class TTS():
 	def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device=None, amp=None, dtype=None ):
 		self.loading = True
@@ -41,7 +51,6 @@ class TTS():
 		cfg.mode = "inferencing"
 		cfg.device = device
 		cfg.trainer.load_state_dict = True
 		cfg.trainer.backend = "local"
 		cfg.trainer.weight_dtype = dtype
 		cfg.inference.weight_dtype = dtype
 
@@ -85,12 +94,16 @@ class TTS():
 		self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
 		self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
 
+		if use_deepspeed_inference:
+			self.ar = deepspeed.init_inference(model=self.ar, mp_size=1, replace_with_kernel_inject=True, dtype=self.dtype if not self.amp else torch.float32).module.eval()
+			self.nar = deepspeed.init_inference(model=self.nar, mp_size=1, replace_with_kernel_inject=True, dtype=self.dtype if not self.amp else torch.float32).module.eval()
+		else:
+			self.ar.eval()
+			self.nar.eval()
+
 		if self.symmap is None:
 			self.symmap = get_phone_symmap()
 
-		self.ar.eval()
-		self.nar.eval()
-
 		self.loading = False
 
 	def load_models( self ):
@@ -120,7 +120,7 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"
 	return logits
 
 # credit to https://github.com/LostRuins/koboldcpp/pull/464
-def dynamic_temperature( logits, temperature=1.0, min_temperature = 1.0/256.0, k = 10, sigmoidCenterPoint = 0.5 ):
+def dynamic_temperature( logits, temperature=1.0, min_temperature = 0.00390625, k = 10, sigmoidCenterPoint = 0.5 ):
 	# loop over logits[:], as the NAR will have logits.shape[0] > 1
 	for i in range(logits.shape[0]):
 		maximum = 0.0
@@ -133,6 +133,11 @@ def dynamic_temperature( logits, temperature=1.0, min_temperature = 1.0/256.0, k
 
 		prob_max_token_before_temp = 1.0 / sum_exp
 		dynamic_temperature = temperature - (temperature - min_temperature) / (1 + math.exp(-k * (prob_max_token_before_temp - sigmoidCenterPoint)))
+
+		#print( "sum_exp:", sum_exp )
+		#print( "prob_max_token_before_temp:", prob_max_token_before_temp )
+		#print( "dynamic temperature:", dynamic_temperature )
+
 		logits[i] /= dynamic_temperature
 
 	return logits
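The hunk above only shows part of dynamic_temperature; for reference, a self-contained sketch of the sigmoid gating it implements, using softmax to get the max-token probability (which should match the 1.0 / sum_exp computed in the part of the function outside this hunk). Names here are illustrative, not the repo's.

import math
import torch

def sigmoid_gated_temperature( logits, temperature=1.0, min_temperature=1.0/256.0, k=10, sigmoid_center=0.5 ):
	# interpolate between `temperature` and `min_temperature` based on model confidence:
	# a high max-token probability pushes the effective temperature toward min_temperature
	# (near-greedy sampling), a flat distribution leaves it near the requested temperature
	for i in range(logits.shape[0]):
		p_max = torch.softmax(logits[i], dim=-1).max().item()
		dyn_temp = temperature - (temperature - min_temperature) / (1 + math.exp(-k * (p_max - sigmoid_center)))
		logits[i] /= dyn_temp
	return logits

For example, with temperature=1.5 and a max-token probability of 0.95 the effective temperature works out to roughly 0.02 (close to greedy), while a max-token probability of 0.05 leaves it at roughly 1.48.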
@@ -560,6 +565,9 @@ class Base(nn.Module):
 		else:
 			logits = [ logit[-1:] for logit in logits ]
 
+		devices = [ logit.device for logit in logits ]
+		logits = [ logit.cpu() for logit in logits ]
+
 		# perform repetition penalizing
 		logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
 
@@ -571,8 +579,8 @@ class Base(nn.Module):
 		if top_k > 0 or top_p < 1.0:
 			logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
 
-		# our dynamic temperature threshold is considered to be anything over 1.25.
-		if temperature > 1.25:
+		# our dynamic temperature threshold is considered to be anything over 1.0.
+		if temperature > 1.0:
 			logits = [ dynamic_temperature(logit, temperature=temperature) for logit in logits ]
 		else:
 			logits = [ logit / temperature for logit in logits ]
@@ -594,7 +602,7 @@ class Base(nn.Module):
 			return res, scores
 
 		# and sample
-		return [ Categorical(logits=logit).sample() for logit in logits ]
+		return [ Categorical(logits=logit).sample().to(device) for logit, device in zip(logits, devices) ]
 
 def example_usage():
 	from ..config import cfg
@@ -190,8 +190,8 @@ with ui:
 				layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
 				layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
 			with gr.Row():
-				layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=3.0, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
-				layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=3.0, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")
+				layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
+				layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")
 
 			with gr.Row():
 				layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")