diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 74e2c8d..50035cf 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -17,11 +17,14 @@ import math
 from einops import rearrange
 from torch import Tensor
 from tqdm import trange
+from time import perf_counter
+
 import logging
 
 _logger = logging.getLogger(__name__)
 
 from ..emb.qnt import trim, encode_as_embedding
+from ..utils import get_devices, setup_logging, timer
 
 from .lora import enable_lora
 
@@ -301,7 +304,7 @@ class AR_NAR(Base):
 
 			r = super().sample(
 				logits=logits,
-				prev_list=[ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
+				prev_list=None if sampling_repetition_penalty == 1.0 and sampling_length_penalty == 0.0 else [ resps_list[i] if task not in text_task else text_list[i] for i, task in enumerate( task_list ) ],
 
 				temperature=sampling_temperature,
 				min_temperature=sampling_min_temperature,
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 620b349..6149800 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -1447,7 +1447,7 @@ class Base(nn.Module):
	def sample(
		self,
		logits: list[Tensor], # logit scores
-		prev_list: list[Tensor], # previous tokens
+		prev_list: list[Tensor] | None = None, # previous tokens
		quant_levels: int | list[int] | Tensor | None = None,
		# base sampling parameters
		temperature: float = 1.0,
@@ -1494,11 +1494,12 @@
			return [ logit.argmax(dim=1) for logit in logits ]
 
		# perform repetition penalizing
-		if "len" not in self.capabilities:
+		if "len" not in self.capabilities and prev_list is not None:
+			# to-do: figure out a faster way to handle tolist()
			logits = [ reptition_penalize(logit, previous=prevs[:, -1].tolist() if prevs.dim() > 1 else prevs.tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, prevs in zip( logits, prev_list ) ]
 
		# (AR) perform length penalizing
-		if quant_levels is None and self.causal:
+		if quant_levels is None and self.causal and prev_list is not None:
			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
 
		# perform top_k/top_p filtering of our logits
diff --git a/vall_e/utils/__init__.py b/vall_e/utils/__init__.py
index 70bc823..c941932 100755
--- a/vall_e/utils/__init__.py
+++ b/vall_e/utils/__init__.py
@@ -10,5 +10,6 @@ from .utils import (
	set_seed,
	passes_policy,
	get_devices,
-	truncate_json
+	truncate_json,
+	timer
 )
\ No newline at end of file
diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py
index 8bb98d2..b487c46 100755
--- a/vall_e/utils/utils.py
+++ b/vall_e/utils/utils.py
@@ -25,8 +25,29 @@
 from torch import Tensor, nn
 from tqdm.auto import tqdm
 from typing import Callable, TypeVar, overload
 from contextlib import contextmanager
+
+from time import perf_counter
+from datetime import datetime
+
 T = TypeVar("T")
 
+class timer:
+	def __init__(self, msg="Elapsed time:", callback=None):
+		self.msg = msg
+		self.callback = callback
+
+	def __enter__(self):
+		self.start = perf_counter()
+		return self
+
+	def __exit__(self, type, value, traceback):
+		msg = f'{self.msg} {(perf_counter() - self.start):.9f}s'
+
+		if self.callback:
+			self.callback(msg)
+
+		print(f'[{datetime.now().isoformat()}] {msg}')
+
 def truncate_json( str ):
	def fun( match ):
diff --git a/vall_e/webui.py b/vall_e/webui.py
index 5d2d51b..7d5a5e7 100644
--- a/vall_e/webui.py
+++ b/vall_e/webui.py
@@ -9,17 +9,14 @@ import functools
 
 import torch
 import numpy as np
-from datetime import datetime
-
 import torchaudio
 import gradio as gr
 
-from time import perf_counter
 from pathlib import Path
 
 from .inference import TTS, cfg
 from .train import train
-from .utils import get_devices, setup_logging
+from .utils import get_devices, setup_logging, timer
 from .utils.io import json_read, json_stringify
 from .emb.qnt import decode_to_wave
 from .data import get_lang_symmap
@@ -52,20 +49,6 @@ def gradio_wrapper(inputs):
		return wrapped_function
	return decorated
 
-class timer:
-	def __init__(self, msg="Elapsed time:"):
-		self.msg = msg
-
-	def __enter__(self):
-		self.start = perf_counter()
-		return self
-
-	def __exit__(self, type, value, traceback):
-		msg = f'{self.msg} {(perf_counter() - self.start):.3f}s'
-
-		gr.Info(msg)
-		print(f'[{datetime.now().isoformat()}] {msg}')
-
 # returns a list of models, assuming the models are placed under ./training/ or ./models/
 def get_model_paths( paths=[Path("./training/"), Path("./models/")] ):
	yamls = []
@@ -164,7 +147,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
	parser.add_argument("--references", type=str, default=kwargs["reference"])
	parser.add_argument("--language", type=str, default=kwargs["language"])
	parser.add_argument("--input-prompt-length", type=float, default=kwargs["input-prompt-length"])
-	parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
+	#parser.add_argument("--input-prompt-prefix", action='store_true', default=kwargs["input-prompt-prefix"])
+	parser.add_argument("--input-prompt-prefix", action='store_true')
	parser.add_argument("--max-ar-steps", type=int, default=int(kwargs["max-seconds"]*cfg.dataset.frames_per_second))
	parser.add_argument("--max-nar-levels", type=int, default=0), # kwargs["max-nar-levels"])
	parser.add_argument("--ar-temp", type=float, default=kwargs["ar-temp"])
@@ -194,7 +178,8 @@ def do_inference_tts( progress=gr.Progress(track_tqdm=True), *args, **kwargs ):
	tts = init_tts()
 
	gr.Info("Inferencing...")
-	with timer("Inferenced in") as t:
+
+	with timer("Inferenced in", callback=lambda msg: gr.Info( msg )) as t:
		wav, sr = tts.inference(
			text=args.text,
			language=args.language,
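
Note: the patch above moves `timer` out of `webui.py` into `vall_e/utils/utils.py`, generalizing the hard-coded `gr.Info(msg)` into an optional `callback` so non-gradio callers can use it too, and lets `Base.sample()` accept `prev_list=None` so the repetition/length penalty passes (slow due to the `tolist()` conversion flagged in the to-do) are skipped entirely when they would be no-ops. A minimal usage sketch of the relocated context manager, assuming the patched `vall_e` package is importable; `fake_work` is a hypothetical stand-in for `tts.inference(...)`:

```python
# Usage sketch only. `timer` and its callback signature come from the patch;
# `fake_work` is a hypothetical stand-in for tts.inference(...).
from time import sleep
from vall_e.utils import timer

def fake_work():
	sleep(0.1)

# Without a callback, __exit__ just prints the timestamped elapsed time.
with timer("Inferenced in"):
	fake_work()

# With a callback, the message is also forwarded; the web UI passes
# callback=lambda msg: gr.Info( msg ) to surface it as a toast.
with timer("Inferenced in", callback=lambda msg: print(f"[ui] {msg}")):
	fake_work()
```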