sped up inferencing by not doing .tolist() for rep pen / length pen (and a bug fix in the web UI from prev commit)
parent 4a8e3ccf06
commit a507b769a1
@@ -17,11 +17,14 @@ import math
 from einops import rearrange
 from torch import Tensor
 from tqdm import trange
+from time import perf_counter

 import logging

 _logger = logging.getLogger(__name__)

 from ..emb.qnt import trim, encode_as_embedding
+from ..utils import get_devices, setup_logging, timer

 from .lora import enable_lora
@@ -301,7 +304,7 @@ class AR_NAR(Base):

 		r = super().sample(
 			logits=logits,
-			prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
+			prev_list=None if sampling_repetition_penalty == 1.0 and sampling_length_penalty == 0.0 else [ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],

 			temperature=sampling_temperature,
 			min_temperature=sampling_min_temperature,
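The call-site change above is the heart of the speed-up: a repetition penalty of 1.0 and a length penalty of 0.0 are both no-ops, so in that common case there is no reason to hand the sampler a token history at all. A minimal, self-contained sketch of the pattern (the names here are illustrative, not the repo's):

import torch
from torch import Tensor

def sample(logits: list[Tensor], prev_list: list[Tensor] | None = None) -> list[Tensor]:
	# penalty passes only run when a history is supplied; None skips them
	# (and the per-step .tolist() they would otherwise perform)
	if prev_list is not None:
		pass  # repetition / length penalties would be applied here
	return [ logit.argmax(dim=1) for logit in logits ]

# caller side: only pass the history when a penalty is actually in effect
repetition_penalty, length_penalty = 1.0, 0.0  # neutral values
history = [ torch.randint(0, 1024, (32,)) ]
neutral = repetition_penalty == 1.0 and length_penalty == 0.0
tokens = sample([ torch.randn(1, 1024) ], prev_list=None if neutral else history)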
@@ -1447,7 +1447,7 @@ class Base(nn.Module):
 	def sample(
 		self,
 		logits: list[Tensor], # logit scores
-		prev_list: list[Tensor], # previous tokens
+		prev_list: list[Tensor] | None = None, # previous tokens
 		quant_levels: int | list[int] | Tensor | None = None,
 		# base sampling parameters
 		temperature: float = 1.0,
@@ -1494,11 +1494,12 @@ class Base(nn.Module):
 			return [ logit.argmax(dim=1) for logit in logits ]

 		# perform repetition penalizing
-		if "len" not in self.capabilities:
+		if "len" not in self.capabilities and prev_list is not None:
+			# to-do: figure out a faster way to handle tolist()
 			logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]

 		# (AR) perform length penalizing
-		if quant_levels is None and self.causal:
+		if quant_levels is None and self.causal and prev_list is not None:
 			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]

 		# perform top_k/top_p filtering of our logits
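The .tolist() flagged by the new to-do still runs whenever a penalty is active; the win is in the now-common case where the guard skips it entirely. A rough, standalone illustration of the per-step cost being avoided (the sizes are invented):

import torch
from time import perf_counter

prevs = torch.randint(0, 1024, (4096,))  # hypothetical token history

start = perf_counter()
for _ in range(100):  # pretend 100 decode steps
	previous = prevs.tolist()  # what the penalty path pays on every step
print(f"with tolist(): {perf_counter() - start:.6f}s")

start = perf_counter()
for _ in range(100):
	pass  # guarded path: the penalty (and the conversion) never runs
print(f"guarded:       {perf_counter() - start:.6f}s")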
|
@ -10,5 +10,6 @@ from .utils import (
|
|||
set_seed,
|
||||
passes_policy,
|
||||
get_devices,
|
||||
truncate_json
|
||||
truncate_json,
|
||||
timer
|
||||
)
|
|
@@ -25,8 +25,29 @@ from torch import Tensor, nn
 from tqdm.auto import tqdm
 from typing import Callable, TypeVar, overload
 from contextlib import contextmanager

+from time import perf_counter
+from datetime import datetime
+
 T = TypeVar("T")

+class timer:
+	def __init__(self, msg="Elapsed time:", callback=None):
+		self.msg = msg
+		self.callback = callback
+
+	def __enter__(self):
+		self.start = perf_counter()
+		return self
+
+	def __exit__(self, type, value, traceback):
+		msg = f'{self.msg} {(perf_counter() - self.start):.9f}s'
+
+		if self.callback:
+			self.callback(msg)
+
+		print(f'[{datetime.now().isoformat()}] {msg}')
+
 def truncate_json( str ):

 	def fun( match ):
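The relocated timer can be exercised on its own; the absolute import below is an assumption, since the diff only shows relative imports:

from time import sleep
from vall_e.utils import timer  # assumed package name

with timer("Slept for", callback=lambda msg: print(f"callback got: {msg}")):
	sleep(0.1)
# __exit__ then also prints e.g. '[2024-...T...] Slept for 0.100123456s'

The new callback argument is what lets the web UI surface the same message through gr.Info without utils having to import gradio.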
@@ -9,17 +9,14 @@ import functools
 import torch
 import numpy as np

-from datetime import datetime
-
 import torchaudio
 import gradio as gr

-from time import perf_counter
 from pathlib import Path

 from .inference import TTS, cfg
 from .train import train
-from .utils import get_devices, setup_logging
+from .utils import get_devices, setup_logging, timer
 from .utils.io import json_read, json_stringify
 from .emb.qnt import decode_to_wave
 from .data import get_lang_symmap
@@ -52,20 +49,6 @@ def gradio_wrapper(inputs):
 		return wrapped_function
 	return decorated

-class timer:
-	def __init__(self, msg="Elapsed time:"):
-		self.msg = msg
-
-	def __enter__(self):
-		self.start = perf_counter()
-		return self
-
-	def __exit__(self, type, value, traceback):
-		msg = f'{self.msg} {(perf_counter() - self.start):.3f}s'
-
-		gr.Info(msg)
-		print(f'[{datetime.now().isoformat()}] {msg}')
-
 # returns a list of models, assuming the models are placed under ./training/ or ./models/
 def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
 	yamls = []
@@ -164,7 +147,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	parser.add_argument("--references", type=str, default=kwargs["reference"])
 	parser.add_argument("--language", type=str, default=kwargs["language"])
 	parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
-	parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
+	#parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
+	parser.add_argument("--input-prompt-prefix", action='store_true')
 	parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
 	parser.add_argument("--max-nar-levels", type=int, default=0), # kwargs["max-nar-levels"])
 	parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
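This is the web-UI bug fix from the commit message: the old line pulled its default from kwargs["input-prompt-prefix"], which raises a KeyError once the UI stops supplying that key (presumably what the previous commit broke). With action='store_true', argparse already defaults the flag to False, so no explicit default is needed:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--input-prompt-prefix", action='store_true')

print(parser.parse_args([]).input_prompt_prefix)  # False when the flag is absent
print(parser.parse_args(["--input-prompt-prefix"]).input_prompt_prefix)  # True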
@@ -194,7 +178,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
 	tts = init_tts()

 	gr.Info("Inferencing...")
-	with timer("Inferenced in") as t:
+
+	with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t:
 		wav, sr = tts.inference(
 			text=args.text,
 			language=args.language,
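Design note: passing gr.Info in as a callback, rather than hard-coding it the way the old web-UI timer did, keeps the shared timer class free of any gradio dependency; each front end supplies its own reporting hook.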