diff --git a/vall_e/config.py b/vall_e/config.py
index 2c6626d..0eca960 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -259,6 +259,19 @@ class Trainer:
     zero_optimization_level: int = 0
     use_compression_training: bool = False

+
+@dataclass()
+class Inference:
+    use_vocos: bool = True
+
+@dataclass()
+class BitsAndBytes:
+    enabled: bool = False
+    injects: bool = False
+
+    linear: bool = False
+    embedding: bool = False
+
 @dataclass()
 class Config(_Config):
     device: str = "cuda"
@@ -268,8 +281,9 @@ class Config(_Config):
     hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
     evaluation: Evaluation = field(default_factory=lambda: Evaluation)
     trainer: Trainer = field(default_factory=lambda: Trainer)
+    inference: Inference = field(default_factory=lambda: Inference)
+    bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)

-    use_vocos: bool = True

     @property
     def sample_rate(self):
@@ -438,6 +452,8 @@ cfg.models = Models(**cfg.models)
 cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
 cfg.evaluation = Evaluation(**cfg.evaluation)
 cfg.trainer = Trainer(**cfg.trainer)
+cfg.inference = Inference(**cfg.inference)
+cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)

 # cached_property stopped working...
 if cfg.dataset.use_hdf5:
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 4f25aeb..1900d89 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -15,12 +15,10 @@ from einops import rearrange
 from torch import Tensor
 from tqdm import tqdm

-USE_VOCOS = False
 try:
     from vocos import Vocos
-    USE_VOCOS = True
 except Exception as e:
-    USE_VOCOS = False
+    cfg.inference.use_vocos = False

 @cache
 def _load_encodec_model(device="cuda"):
@@ -35,9 +33,11 @@ def _load_encodec_model(device="cuda"):
     elif cfg.models.levels == 8:
         bandwidth_id = 6.0

-    model = EncodecModel.encodec_model_24khz()
+    model = EncodecModel.encodec_model_24khz().to(device)
     model.set_target_bandwidth(bandwidth_id)
-    model.to(device)
+    model.bandwidth_id = bandwidth_id
+    model.sample_rate = cfg.sample_rate
+    model.backend = "encodec"

     return model

@@ -58,11 +58,12 @@ def _load_vocos_model(device="cuda"):

     model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
     model.sample_rate = cfg.sample_rate
+    model.backend = "vocos"

     return model

 @cache
-def _load_model(device="cuda", vocos=USE_VOCOS):
+def _load_model(device="cuda", vocos=cfg.inference.use_vocos):
     if vocos:
         model = _load_vocos_model(device)
     else:
@@ -99,7 +100,7 @@ def decode(codes: Tensor, device="cuda"):
     codes = codes.to(torch.int32)

     kwargs = {}
-    if USE_VOCOS:
+    if model.backend == "vocos":
         x = model.codes_to_features(codes[0])
         kwargs['bandwidth_id'] = model.bandwidth_id
     else:
@@ -107,7 +108,7 @@ def decode(codes: Tensor, device="cuda"):

     wav = model.decode(x, **kwargs)

-    if not USE_VOCOS:
+    if model.backend == "encodec":
         wav = wav[0]

     return wav, model.sample_rate
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 5973cbb..7eb3902 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -16,8 +16,6 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
 from .retnet import RetNetDecoder, RetNetConfig
 from .transformer import SinusoidalEmbedding, Block as TransformerBlock

-from ..utils import wrapper as ml
-
 def _create_mask(l, device):
     """1 is valid region and 0 is invalid."""
     seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
diff --git a/vall_e/utils/engines.py b/vall_e/utils/engines.py
index 64afaeb..6549b9a 100755
--- a/vall_e/utils/engines.py
+++ b/vall_e/utils/engines.py
@@ -5,7 +5,7 @@
 # to-do: replace this
 # to-do: swap out deepspeed

-from .config import Config
+from ..config import Config
 from .distributed import fix_unset_envs
 from .utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device

diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py
index 6fe7687..040762d 100755
--- a/vall_e/utils/wrapper.py
+++ b/vall_e/utils/wrapper.py
@@ -1,21 +1,19 @@
-# to-do: re-introduce bitsandbytes support
-
 from contextlib import contextmanager

 import torch
 import torch.nn.functional as F
+from ..config import cfg

 Embedding = torch.nn.Embedding
 Linear = torch.nn.Linear

-"""
-if cfg.bitsandbytes:
+if cfg.bitsandbytes.enabled:
     import bitsandbytes as bnb

-    if cfg.bitsandbytes_linear:
+    if cfg.bitsandbytes.linear:
         Linear = bnb.nn.Linear8bitLt

-    if cfg.bitsandbytes_embedding:
+    if cfg.bitsandbytes.embedding:
         Embedding = bnb.nn.StableEmbedding
         Embedding.forward = lambda self, input: ( self.norm(F.embedding(
             input,
@@ -26,28 +24,15 @@ if cfg.bitsandbytes:
             self.scale_grad_by_freq,
             self.sparse,
         )).to(self.weight.dtype) )
-"""

 Adam = torch.optim.Adam
 AdamW = torch.optim.AdamW

-"""
-if cfg.bitsandbytes:
+if cfg.bitsandbytes.enabled:
     import bitsandbytes as bnb

     Adam = bnb.optim.Adam
     AdamW = bnb.optim.AdamW
-"""
-
-# handles temporarily upcasting 'index tensors' so torch will stop bitching
-def autocast_forward( func ):
-    def wrapper( self, input, *args, **kwargs ):
-        if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8:
-            input = input.to(torch.int32)
-
-        return func( self, input, *args, **kwargs )
-    return wrapper
-Embedding.forward = autocast_forward(Embedding.forward)

 # handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
 @contextmanager
@@ -57,4 +42,34 @@ def autocast(input, from_dtype, to_dtype):
         yield input
         input = input.to(from_dtype)
     else:
-        yield input
\ No newline at end of file
+        yield input
+
+@contextmanager
+def autocasts(input, from_dtype, to_dtype):
+    if input.dtype in from_dtype:
+        from_dtype = input.dtype
+        input = input.to(to_dtype)
+        yield input
+        input = input.to(from_dtype)
+    else:
+        yield input
+
+# handles temporarily upcasting 'index tensors' so torch will stop bitching
+def autocast_forward( func ):
+    def wrapper( self, input, *args, **kwargs ):
+        with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k:
+            return func( self, k, *args, **kwargs )
+        """
+        if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8:
+            return func( self, input.to(torch.int32), *args, **kwargs )
+        return func( self, input, *args, **kwargs )
+        """
+    return wrapper
+Embedding.forward = autocast_forward(Embedding.forward)
+
+if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
+    torch.nn.Linear = Linear
+    torch.nn.Embedding = Embedding
+
+    torch.optim.Adam = Adam
+    torch.optim.AdamW = AdamW
\ No newline at end of file
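
Note on the new config sections: `Inference` and `BitsAndBytes` are hydrated the same way as the existing sections, with the parsed user config arriving as plain dicts that get keyword-splatted into the dataclasses at the bottom of `vall_e/config.py` (an unrecognized key then fails fast as a `TypeError`). A minimal standalone sketch of the pattern; the `raw` dict is illustrative, not the real config loader:

    from dataclasses import dataclass

    @dataclass()
    class Inference:
        use_vocos: bool = True

    @dataclass()
    class BitsAndBytes:
        enabled: bool = False
        injects: bool = False
        linear: bool = False
        embedding: bool = False

    # sections parsed from the user config arrive as plain dicts...
    raw = {"inference": {"use_vocos": False}, "bitsandbytes": {"enabled": True, "linear": True}}

    # ...and are promoted to dataclasses by splatting, mirroring
    # `cfg.inference = Inference(**cfg.inference)` in the patch
    inference = Inference(**raw["inference"])
    bitsandbytes = BitsAndBytes(**raw["bitsandbytes"])

    assert inference.use_vocos is False
    assert bitsandbytes.enabled and bitsandbytes.linear and not bitsandbytes.embedding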
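Note on `vall_e/emb/qnt.py`: the module-level `USE_VOCOS` flag is replaced by a `backend` attribute tagged onto each cached model, so `decode()` can dispatch per model instead of per process. A standalone sketch of that dispatch; the `Dummy*` classes are hypothetical stand-ins for the real EnCodec/Vocos models:

    class DummyEncodec:
        backend = "encodec"
        bandwidth_id = 6.0
        def decode(self, x, **kwargs):
            return ["wav"]  # EnCodec-style: batched output, so the caller unwraps wav[0]

    class DummyVocos:
        backend = "vocos"
        bandwidth_id = 6.0
        def decode(self, x, bandwidth_id=None):
            return "wav"    # Vocos-style: waveform returned directly

    def decode(model, x):
        # per-backend kwargs, mirroring the patched decode() in qnt.py
        kwargs = {"bandwidth_id": model.bandwidth_id} if model.backend == "vocos" else {}
        wav = model.decode(x, **kwargs)
        if model.backend == "encodec":
            wav = wav[0]
        return wav

    assert decode(DummyEncodec(), None) == decode(DummyVocos(), None) == "wav"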
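Note on `autocasts`/`autocast_forward` in `vall_e/utils/wrapper.py`: `F.embedding` rejects narrow integer index tensors, so the wrapper temporarily widens int8/int16/uint8 indices to int32 around the lookup. A runnable sketch of the same trick against a stock `torch.nn.Embedding`; it assumes a PyTorch recent enough to accept int32 indices:

    from contextlib import contextmanager

    import torch

    @contextmanager
    def autocasts(input, from_dtypes, to_dtype):
        # hand back a widened copy only when the dtype is one of the narrow ones
        if input.dtype in from_dtypes:
            yield input.to(to_dtype)
        else:
            yield input

    def autocast_forward(func):
        def wrapper(self, input, *args, **kwargs):
            with autocasts(input, [torch.int16, torch.int8, torch.uint8], torch.int32) as k:
                return func(self, k, *args, **kwargs)
        return wrapper

    torch.nn.Embedding.forward = autocast_forward(torch.nn.Embedding.forward)

    emb = torch.nn.Embedding(256, 4)
    out = emb(torch.tensor([1, 2, 3], dtype=torch.uint8))  # would raise without the patch
    assert out.shape == (3, 4)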
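Note on the new `cfg.bitsandbytes.injects` switch: besides rebinding the module-local `Linear`/`Embedding`/`Adam`/`AdamW` names, it rebinds them on `torch` itself, so code that constructs `torch.nn.Linear` or `torch.optim.AdamW` after `wrapper` is imported transparently gets the 8-bit variants; code that grabbed a reference before the patch keeps the originals, which is presumably why both mechanisms exist. A guarded sketch that falls back to stock torch when bitsandbytes is absent:

    import torch

    try:
        import bitsandbytes as bnb
        Linear = bnb.nn.Linear8bitLt  # same replacements the patch selects
        Adam, AdamW = bnb.optim.Adam, bnb.optim.AdamW
    except ImportError:
        Linear = torch.nn.Linear
        Adam, AdamW = torch.optim.Adam, torch.optim.AdamW

    # the "injects" behavior: later torch.nn.Linear / torch.optim.AdamW
    # constructions silently become the (possibly 8-bit) replacements
    torch.nn.Linear = Linear
    torch.optim.Adam, torch.optim.AdamW = Adam, AdamW

    layer = torch.nn.Linear(16, 16)  # Linear8bitLt when bitsandbytes is installed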