adjustments

mrq 2023-08-02 18:36:26 -05:00
parent 0f9b81de75
commit f6597e2dfe
5 changed files with 63 additions and 33 deletions

View File

@@ -259,6 +259,19 @@ class Trainer:
     zero_optimization_level: int = 0
     use_compression_training: bool = False
 
+@dataclass()
+class Inference:
+    use_vocos: bool = True
+
+@dataclass()
+class BitsAndBytes:
+    enabled: bool = False
+    injects: bool = False
+
+    linear: bool = False
+    embedding: bool = False
+
 @dataclass()
 class Config(_Config):
     device: str = "cuda"
@@ -268,8 +281,9 @@ class Config(_Config):
     hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
     evaluation: Evaluation = field(default_factory=lambda: Evaluation)
     trainer: Trainer = field(default_factory=lambda: Trainer)
+    inference: Inference = field(default_factory=lambda: Inference)
+    bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
-    use_vocos: bool = True
 
     @property
     def sample_rate(self):
@@ -438,6 +452,8 @@ cfg.models = Models(**cfg.models)
 cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
 cfg.evaluation = Evaluation(**cfg.evaluation)
 cfg.trainer = Trainer(**cfg.trainer)
+cfg.inference = Inference(**cfg.inference)
+cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
 
 # cached_property stopped working...
 if cfg.dataset.use_hdf5:
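A quick sketch of how the two new sections get populated, mirroring the splat-into-dataclass pattern the hunk above uses for the existing sections; the override dict here is an assumed example, with values matching the dataclass defaults introduced by this commit:

    # assumed example: a parsed user config provides one plain dict per section
    raw = {
        "inference": {"use_vocos": True},
        "bitsandbytes": {"enabled": False, "injects": False, "linear": False, "embedding": False},
    }
    cfg.inference = Inference(**raw["inference"])
    cfg.bitsandbytes = BitsAndBytes(**raw["bitsandbytes"])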

View File

@@ -15,12 +15,10 @@ from einops import rearrange
 from torch import Tensor
 from tqdm import tqdm
 
-USE_VOCOS = False
 try:
     from vocos import Vocos
-    USE_VOCOS = True
 except Exception as e:
-    USE_VOCOS = False
+    cfg.inference.use_vocos = False
 
 @cache
 def _load_encodec_model(device="cuda"):
@@ -35,9 +33,11 @@ def _load_encodec_model(device="cuda"):
     elif cfg.models.levels == 8:
         bandwidth_id = 6.0
 
-    model = EncodecModel.encodec_model_24khz()
+    model = EncodecModel.encodec_model_24khz().to(device)
     model.set_target_bandwidth(bandwidth_id)
-    model.to(device)
+    model.bandwidth_id = bandwidth_id
+    model.sample_rate = cfg.sample_rate
+    model.backend = "encodec"
 
     return model
@@ -58,11 +58,12 @@ def _load_vocos_model(device="cuda"):
     model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
     model.sample_rate = cfg.sample_rate
+    model.backend = "vocos"
 
     return model
 
 @cache
-def _load_model(device="cuda", vocos=USE_VOCOS):
+def _load_model(device="cuda", vocos=cfg.inference.use_vocos):
     if vocos:
         model = _load_vocos_model(device)
     else:
@@ -99,7 +100,7 @@ def decode(codes: Tensor, device="cuda"):
     codes = codes.to(torch.int32)
 
     kwargs = {}
-    if USE_VOCOS:
+    if model.backend == "vocos":
         x = model.codes_to_features(codes[0])
         kwargs['bandwidth_id'] = model.bandwidth_id
     else:
@@ -107,7 +108,7 @@
     wav = model.decode(x, **kwargs)
 
-    if not USE_VOCOS:
+    if model.backend == "encodec":
         wav = wav[0]
 
     return wav, model.sample_rate
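The running change in this file is that the cached model now carries its own metadata (backend, sample_rate, bandwidth_id), so decode() keys off the object it actually loaded instead of a module-level USE_VOCOS flag, and forcing _load_model(device, vocos=False) stays consistent even when vocos imports fine. A minimal sketch of the resulting dispatch, with the helper name _finish_decode assumed purely for illustration:

    # illustrative helper (not part of this commit): dispatch on the metadata
    # attached at load time rather than on a global flag
    def _finish_decode(model, x, **kwargs):
        wav = model.decode(x, **kwargs)
        if model.backend == "encodec":
            # as in the hunk above, the encodec result is indexed to drop its
            # leading (batch) dimension; vocos output is returned as-is
            wav = wav[0]
        return wav, model.sample_rate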

View File

@@ -16,8 +16,6 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
 from .retnet import RetNetDecoder, RetNetConfig
 from .transformer import SinusoidalEmbedding, Block as TransformerBlock
 
-from ..utils import wrapper as ml
-
 def _create_mask(l, device):
     """1 is valid region and 0 is invalid."""
     seq = torch.arange(max(l), device=device).unsqueeze(0)  # (1 t)

View File

@@ -5,7 +5,7 @@
 # to-do: replace this
 # to-do: swap out deepspeed
 
-from .config import Config
+from ..config import Config
 from .distributed import fix_unset_envs
 from .utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device

View File

@ -1,21 +1,19 @@
# to-do: re-introduce bitsandbytes support
from contextlib import contextmanager from contextlib import contextmanager
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from ..config import cfg
Embedding = torch.nn.Embedding Embedding = torch.nn.Embedding
Linear = torch.nn.Linear Linear = torch.nn.Linear
""" if cfg.bitsandbytes.enabled:
if cfg.bitsandbytes:
import bitsandbytes as bnb import bitsandbytes as bnb
if cfg.bitsandbytes_linear: if cfg.bitsandbytes.linear:
Linear = bnb.nn.Linear8bitLt Linear = bnb.nn.Linear8bitLt
if cfg.bitsandbytes_embedding: if cfg.bitsandbytes.embedding:
Embedding = bnb.nn.StableEmbedding Embedding = bnb.nn.StableEmbedding
Embedding.forward = lambda self, input: ( self.norm(F.embedding( Embedding.forward = lambda self, input: ( self.norm(F.embedding(
input, input,
@@ -26,28 +24,15 @@ if cfg.bitsandbytes:
             self.scale_grad_by_freq,
             self.sparse,
         )).to(self.weight.dtype) )
-"""
 
 Adam = torch.optim.Adam
 AdamW = torch.optim.AdamW
 
-"""
-if cfg.bitsandbytes:
+if cfg.bitsandbytes.enabled:
     import bitsandbytes as bnb
 
     Adam = bnb.optim.Adam
     AdamW = bnb.optim.AdamW
-"""
-
-# handles temporarily upcasting 'index tensors' so torch will stop bitching
-def autocast_forward( func ):
-    def wrapper( self, input, *args, **kwargs ):
-        if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8:
-            input = input.to(torch.int32)
-        return func( self, input, *args, **kwargs )
-    return wrapper
-Embedding.forward = autocast_forward(Embedding.forward)
-
 
 # handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
 @contextmanager
@@ -58,3 +43,33 @@ def autocast(input, from_dtype, to_dtype):
         input = input.to(from_dtype)
     else:
         yield input
+
+@contextmanager
+def autocasts(input, from_dtype, to_dtype):
+    if input.dtype in from_dtype:
+        from_dtype = input.dtype
+        input = input.to(to_dtype)
+        yield input
+        input = input.to(from_dtype)
+    else:
+        yield input
+
+# handles temporarily upcasting 'index tensors' so torch will stop bitching
+def autocast_forward( func ):
+    def wrapper( self, input, *args, **kwargs ):
+        with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k:
+            return func( self, k, *args, **kwargs )
+        """
+        if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8:
+            return func( self, input.to(torch.int32), *args, **kwargs )
+        return func( self, input, *args, **kwargs )
+        """
+    return wrapper
+Embedding.forward = autocast_forward(Embedding.forward)
+
+if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
+    torch.nn.Linear = Linear
+    torch.nn.Embedding = Embedding
+
+    torch.optim.Adam = Adam
+    torch.optim.AdamW = AdamW
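Since autocasts generalizes the old single-dtype check to a list of source dtypes, the patched Embedding.forward can accept int8/uint8/int16 index tensors without any work by the caller. A small standalone usage sketch, assuming autocasts is imported from this wrapper module; the tensor sizes are arbitrary examples:

    import torch
    # assumes: autocasts imported from the wrapper module above

    idx = torch.tensor([1, 2, 3], dtype=torch.uint8)   # an index dtype F.embedding rejects
    emb = torch.nn.Embedding(256, 16)                  # sizes chosen for illustration

    with autocasts(idx, [torch.int16, torch.int8, torch.uint8], torch.int32) as k:
        out = emb(k)           # k is an int32 copy; the original idx is left untouched
    print(out.shape)           # torch.Size([3, 16])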