reduced the dynamic temperature threshold to > 1.0, as it does not seem to be very useful for audio LMs; sped up any sampling that touches logits by copying them to the CPU first, since accessing tensors on the GPU is extremely slow

This commit is contained in:
mrq 2023-10-09 14:46:17 -05:00
parent 29873e6ded
commit 26fbb92ec6
3 changed files with 31 additions and 10 deletions

View File

@ -15,6 +15,16 @@ from .models import get_models
from .train import load_engines from .train import load_engines
from .data import get_phone_symmap, _load_quants from .data import get_phone_symmap, _load_quants
use_deepspeed_inference = False
# to-do: integrate this for windows
"""
try:
import deepspeed
use_deepspeed_inference = True
except Exception as e:
pass
"""
class TTS(): class TTS():
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device=None, amp=None, dtype=None ): def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device=None, amp=None, dtype=None ):
self.loading = True self.loading = True
@ -41,7 +51,6 @@ class TTS():
cfg.mode = "inferencing" cfg.mode = "inferencing"
cfg.device = device cfg.device = device
cfg.trainer.load_state_dict = True cfg.trainer.load_state_dict = True
cfg.trainer.backend = "local"
cfg.trainer.weight_dtype = dtype cfg.trainer.weight_dtype = dtype
cfg.inference.weight_dtype = dtype cfg.inference.weight_dtype = dtype
@ -85,12 +94,16 @@ class TTS():
self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32) self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32) self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
if use_deepspeed_inference:
self.ar = deepspeed.init_inference(model=self.ar, mp_size=1, replace_with_kernel_inject=True, dtype=self.dtype if not self.amp else torch.float32).module.eval()
self.nar = deepspeed.init_inference(model=self.nar, mp_size=1, replace_with_kernel_inject=True, dtype=self.dtype if not self.amp else torch.float32).module.eval()
else:
self.ar.eval()
self.nar.eval()
if self.symmap is None: if self.symmap is None:
self.symmap = get_phone_symmap() self.symmap = get_phone_symmap()
self.ar.eval()
self.nar.eval()
self.loading = False self.loading = False
def load_models( self ): def load_models( self ):

View File

@ -120,7 +120,7 @@ def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"
return logits return logits
# credit to https://github.com/LostRuins/koboldcpp/pull/464 # credit to https://github.com/LostRuins/koboldcpp/pull/464
def dynamic_temperature( logits, temperature=1.0, min_temperature = 1.0/256.0, k = 10, sigmoidCenterPoint = 0.5 ): def dynamic_temperature( logits, temperature=1.0, min_temperature = 0.00390625, k = 10, sigmoidCenterPoint = 0.5 ):
# loop over logits[:], as the NAR will have logits.shape[0] > 1 # loop over logits[:], as the NAR will have logits.shape[0] > 1
for i in range(logits.shape[0]): for i in range(logits.shape[0]):
maximum = 0.0 maximum = 0.0
@ -133,6 +133,11 @@ def dynamic_temperature( logits, temperature=1.0, min_temperature = 1.0/256.0, k
prob_max_token_before_temp = 1.0 / sum_exp prob_max_token_before_temp = 1.0 / sum_exp
dynamic_temperature = temperature - (temperature - min_temperature) / (1 + math.exp(-k * (prob_max_token_before_temp - sigmoidCenterPoint))) dynamic_temperature = temperature - (temperature - min_temperature) / (1 + math.exp(-k * (prob_max_token_before_temp - sigmoidCenterPoint)))
#print( "sum_exp:", sum_exp )
#print( "prob_max_token_before_temp:", prob_max_token_before_temp )
#print( "dynamic temperature:", dynamic_temperature )
logits[i] /= dynamic_temperature logits[i] /= dynamic_temperature
return logits return logits
@ -560,6 +565,9 @@ class Base(nn.Module):
else: else:
logits = [ logit[-1:] for logit in logits ] logits = [ logit[-1:] for logit in logits ]
devices = [ logit.device for logit in logits ]
logits = [ logit.cpu() for logit in logits ]
# perform repetition penalizing # perform repetition penalizing
logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ] logits = [ reptition_penalize(logit, previous=resps[:, -1], factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
@ -571,8 +579,8 @@ class Base(nn.Module):
if top_k > 0 or top_p < 1.0: if top_k > 0 or top_p < 1.0:
logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ] logits = [ top_k_top_p_filtering(logit, top_k=top_k, top_p=top_p) for logit in logits ]
# our dynamic temperature threshold is considered to be anything over 1.25. # our dynamic temperature threshold is considered to be anything over 1.0.
if temperature > 1.25: if temperature > 1.0:
logits = [ dynamic_temperature(logit, temperature=temperature) for logit in logits ] logits = [ dynamic_temperature(logit, temperature=temperature) for logit in logits ]
else: else:
logits = [ logit / temperature for logit in logits ] logits = [ logit / temperature for logit in logits ]
@ -594,7 +602,7 @@ class Base(nn.Module):
return res, scores return res, scores
# and sample # and sample
return [ Categorical(logits=logit).sample() for logit in logits ] return [ Categorical(logits=logit).sample().to(device) for logit, device in zip(logits, devices) ]
def example_usage(): def example_usage():
from ..config import cfg from ..config import cfg

View File

@ -190,8 +190,8 @@ with ui:
layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.") layout["inference"]["inputs"]["max-nar-levels"] = gr.Slider(value=7, minimum=0, maximum=7, step=1, label="Max NAR Levels", info="Limits how many steps to perform in the NAR pass.")
layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.") layout["inference"]["inputs"]["input-prompt-length"] = gr.Slider(value=3.0, minimum=0.0, maximum=12.0, step=0.05, label="Input Prompt Trim Length", info="Trims the input prompt down to X seconds. Set 0 to disable.")
with gr.Row(): with gr.Row():
layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=3.0, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.") layout["inference"]["inputs"]["ar-temp"] = gr.Slider(value=0.95, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (AR)", info="Modifies the randomness from the samples in the AR.")
layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=3.0, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.") layout["inference"]["inputs"]["nar-temp"] = gr.Slider(value=0.25, minimum=0.0, maximum=1.5, step=0.05, label="Temperature (NAR)", info="Modifies the randomness from the samples in the NAR.")
with gr.Row(): with gr.Row():
layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.") layout["inference"]["inputs"]["top-p"] = gr.Slider(value=1.0, minimum=0.0, maximum=1.0, step=0.05, label="Top P", info=r"Limits the samples that are outside the top P% of probabilities.")