sped up inferencing by not doing .tolist() for rep pen / length pen (and a bug fix in the web UI from prev commit)

mrq 2024-10-04 22:18:20 -05:00
parent 4a8e3ccf06
commit a507b769a1
5 changed files with 36 additions and 25 deletions

View File

@@ -17,11 +17,14 @@ import math
 from einops import rearrange
 from torch import Tensor
 from tqdm import trange
+from time import perf_counter
 import logging
 _logger = logging.getLogger(__name__)
 from ..emb.qnt import trim, encode_as_embedding
+from ..utils import get_devices, setup_logging, timer
 from .lora import enable_lora
@@ -301,7 +304,7 @@ class AR_NAR(Base):
 		r = super().sample(
 			logits=logits,
-			prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
+			prev_list=None if sampling_repetition_penalty == 1.0 and sampling_length_penalty == 0.0 else [ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
 			temperature=sampling_temperature,
 			min_temperature=sampling_min_temperature,
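For context: this call-site change is the actual speedup. When the repetition penalty factor is 1.0 and the length penalty is 0.0 (both no-ops), prev_list is passed as None, so sample() can skip the penalty passes and the per-step .tolist() host transfers they require. A minimal sketch of the gating pattern (names are illustrative, not from the repo):

    # return None when both penalties are at their neutral values,
    # letting the sampler skip the penalty passes (and their .tolist()) entirely
    def gated_prev_list(prevs, repetition_penalty=1.0, length_penalty=0.0):
        if repetition_penalty == 1.0 and length_penalty == 0.0:
            return None
        return prevs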

View File

@@ -1447,7 +1447,7 @@ class Base(nn.Module):
 	def sample(
 		self,
 		logits: list[Tensor], # logit scores
-		prev_list: list[Tensor], # previous tokens
+		prev_list: list[Tensor] | None = None, # previous tokens
 		quant_levels: int | list[int] | Tensor | None = None,
 		# base sampling parameters
 		temperature: float = 1.0,
@@ -1494,11 +1494,12 @@
 			return [ logit.argmax(dim=1) for logit in logits ]
 		# perform repetition penalizing
-		if "len" not in self.capabilities:
+		if "len" not in self.capabilities and prev_list is not None:
+			# to-do: figure out a faster way to handle tolist()
 			logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
 		# (AR) perform length penalizing
-		if quant_levels is None and self.causal:
+		if quant_levels is None and self.causal and prev_list is not None:
 			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
 		# perform top_k/top_p filtering of our logits
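The new guards matter because Tensor.tolist() copies the tensor back to the host and forces a device synchronization on every sampling step; with prev_list=None those stalls are skipped whenever the penalties would change nothing anyway. A rough way to observe the cost being avoided (assumes a CUDA device is available):

    import torch
    from time import perf_counter

    prevs = torch.randint(0, 1024, (1, 500), device="cuda")
    torch.cuda.synchronize()
    start = perf_counter()
    for _ in range(100):
        prevs.tolist()  # each call blocks on a GPU->CPU copy
    torch.cuda.synchronize()
    print(f"100 .tolist() calls: {perf_counter() - start:.6f}s")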

View File

@@ -10,5 +10,6 @@ from .utils import (
 	set_seed,
 	passes_policy,
 	get_devices,
-	truncate_json
+	truncate_json,
+	timer
 )

View File

@@ -25,8 +25,29 @@ from torch import Tensor, nn
 from tqdm.auto import tqdm
 from typing import Callable, TypeVar, overload
 from contextlib import contextmanager
+from time import perf_counter
+from datetime import datetime

 T = TypeVar("T")

+class timer:
+	def __init__(self, msg="Elapsed time:", callback=None):
+		self.msg = msg
+		self.callback = callback
+
+	def __enter__(self):
+		self.start = perf_counter()
+		return self
+
+	def __exit__(self, type, value, traceback):
+		msg = f'{self.msg} {(perf_counter() - self.start):.9f}s'
+
+		if self.callback:
+			self.callback(msg)
+
+		print(f'[{datetime.now().isoformat()}] {msg}')
+
 def truncate_json( str ):
 	def fun( match ):
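The timer context manager moves out of the web UI into the shared utils, gaining an optional callback so each caller decides where the message goes (the web UI passes gr.Info), and it now prints with .9f precision instead of the old .3f. A small usage sketch, assuming the package imports as vall_e and with a hypothetical run_decode() standing in for the measured work:

    from vall_e.utils import timer

    with timer("Decoded in"):
        run_decode()
    # on exit prints e.g. "[2024-10-04T22:18:20.123456] Decoded in 0.123456789s"

    with timer("Decoded in", callback=print):
        run_decode()  # the callback also receives the same message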

View File

@@ -9,17 +9,14 @@ import functools
 import torch
 import numpy as np
-from datetime import datetime
 import torchaudio
 import gradio as gr
-from time import perf_counter
 from pathlib import Path
 from .inference import TTS, cfg
 from .train import train
-from .utils import get_devices, setup_logging
+from .utils import get_devices, setup_logging, timer
 from .utils.io import json_read, json_stringify
 from .emb.qnt import decode_to_wave
 from .data import get_lang_symmap
@@ -52,20 +49,6 @@ def gradio_wrapper(inputs):
 		return wrapped_function
 	return decorated

-class timer:
-	def __init__(self, msg="Elapsed time:"):
-		self.msg = msg
-
-	def __enter__(self):
-		self.start = perf_counter()
-		return self
-
-	def __exit__(self, type, value, traceback):
-		msg = f'{self.msg} {(perf_counter() - self.start):.3f}s'
-
-		gr.Info(msg)
-		print(f'[{datetime.now().isoformat()}] {msg}')
-
 # returns a list of models, assuming the models are placed under ./training/ or ./models/
 def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
 	yamls = []
@@ -164,7 +147,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--references", type=str, default=kwargs["reference"])
 	parser.add_argument("--language", type=str, default=kwargs["language"])
 	parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
-	parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
+	#parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
+	parser.add_argument("--input-prompt-prefix", action='store_true')
 	parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
 	parser.add_argument("--max-nar-levels", type=int, default=0), # kwargs["max-nar-levels"])
 	parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
@@ -194,7 +178,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	tts = init_tts()

 	gr.Info("Inferencing...")
-	with timer("Inferenced in") as t:
+	with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t:
 		wav, sr = tts.inference(
 			text=args.text,
 			language=args.language,
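A side effect of the callback design: the UI notification is now injected at the call site, so the shared utils module stays free of any gradio dependency (the old web-UI-local timer called gr.Info directly). The same timer works anywhere a callable is supplied, e.g. (logger usage is a sketch, not from the repo):

    import logging
    from vall_e.utils import timer

    _logger = logging.getLogger(__name__)

    with timer("Inferenced in", callback=_logger.info):
        ...  # inference body elided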