sped up inferencing by not doing .tolist() for rep pen / length pen (and a bug fix in the web UI from prev commit)

This commit is contained in:
mrq 2024-10-04 22:18:20 -05:00
parent 4a8e3ccf06
commit a507b769a1
5 changed files with 36 additions and 25 deletions

View File

@ -17,11 +17,14 @@ import math
from einops import rearrange
from torch import Tensor
from tqdm import trange
from time import perf_counter
import logging
_logger = logging.getLogger(__name__)
from ..emb.qnt import trim, encode_as_embedding
from ..utils import get_devices, setup_logging, timer
from .lora import enable_lora
@ -301,7 +304,7 @@ class AR_NAR(Base):
r = super().sample(
logits=logits,
prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
prev_list=None if sampling_repetition_penalty == 1.0 and sampling_length_penalty == 0.0 else [ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
temperature=sampling_temperature,
min_temperature=sampling_min_temperature,

View File

@ -1447,7 +1447,7 @@ class Base(nn.Module):
def sample(
self,
logits: list[Tensor], # logit scores
prev_list: list[Tensor], # previous tokens
prev_list: list[Tensor] | None = None, # previous tokens
quant_levels: int | list[int] | Tensor | None = None,
# base sampling parameters
temperature: float = 1.0,
@ -1494,11 +1494,12 @@ class Base(nn.Module):
return [ logit.argmax(dim=1) for logit in logits ]
# perform repetition penalizing
if "len" not in self.capabilities:
if "len" not in self.capabilities and prev_list is not None:
# to-do: figure out a faster way to handle tolist()
logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
# (AR) perform length penalizing
if quant_levels is None and self.causal:
if quant_levels is None and self.causal and prev_list is not None:
logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
# perform top_k/top_p filtering of our logits

View File

@ -10,5 +10,6 @@ from .utils import (
set_seed,
passes_policy,
get_devices,
truncate_json
truncate_json,
timer
)

View File

@ -25,8 +25,29 @@ from torch import Tensor, nn
from tqdm.auto import tqdm
from typing import Callable, TypeVar, overload
from contextlib import contextmanager
from time import perf_counter
from datetime import datetime
T = TypeVar("T")
class timer:
    """Context manager that measures wall-clock elapsed time of its body.

    On exit it formats a message (prefix + elapsed seconds to 9 decimal
    places), optionally forwards it to ``callback``, and always prints it
    to stdout prefixed with an ISO-format timestamp.
    """
    def __init__(self, msg="Elapsed time:", callback=None):
        """
        Args:
            msg: prefix for the elapsed-time message.
            callback: optional callable invoked with the formatted message
                (e.g. a UI notifier) in addition to the stdout print.
        """
        self.msg = msg
        self.callback = callback

    def __enter__(self):
        # high-resolution monotonic clock; immune to wall-clock adjustments
        self.start = perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # renamed params so `type` no longer shadows the builtin;
        # returns None (falsy) so exceptions still propagate
        msg = f'{self.msg} {(perf_counter() - self.start):.9f}s'
        if self.callback:
            self.callback(msg)
        print(f'[{datetime.now().isoformat()}] {msg}')
def truncate_json( str ):
def fun( match ):

View File

@ -9,17 +9,14 @@ import functools
import torch
import numpy as np
from datetime import datetime
import torchaudio
import gradio as gr
from time import perf_counter
from pathlib import Path
from .inference import TTS, cfg
from .train import train
from .utils import get_devices, setup_logging
from .utils import get_devices, setup_logging, timer
from .utils.io import json_read, json_stringify
from .emb.qnt import decode_to_wave
from .data import get_lang_symmap
@ -52,20 +49,6 @@ def gradio_wrapper(inputs):
return wrapped_function
return decorated
class timer:
    """Context manager that times its body and reports the elapsed
    seconds (3 decimal places) via ``gr.Info`` and a timestamped print."""

    def __init__(self, msg="Elapsed time:"):
        self.msg = msg

    def __enter__(self):
        # monotonic high-resolution start mark
        self.start = perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        elapsed = perf_counter() - self.start
        message = f'{self.msg} {elapsed:.3f}s'
        gr.Info(message)
        print(f'[{datetime.now().isoformat()}] {message}')
# returns a list of models, assuming the models are placed under ./training/ or ./models/
def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
yamls = []
@ -164,7 +147,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
parser.add_argument("--references", type=str, default=kwargs["reference"])
parser.add_argument("--language", type=str, default=kwargs["language"])
parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
#parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
parser.add_argument("--input-prompt-prefix", action='store_true')
parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
parser.add_argument("--max-nar-levels", type=int, default=0), # kwargs["max-nar-levels"])
parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
@ -194,7 +178,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
tts = init_tts()
gr.Info("Inferencing...")
with timer("Inferenced in") as t:
with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t:
wav, sr = tts.inference(
text=args.text,
language=args.language,