adjustments

This commit is contained in:
mrq 2023-08-02 18:36:26 -05:00
parent 0f9b81de75
commit f6597e2dfe
5 changed files with 63 additions and 33 deletions

View File

@ -259,6 +259,19 @@ class Trainer:
zero_optimization_level: int = 0
use_compression_training: bool = False
@dataclass()
class Inference:
use_vocos: bool = True
@dataclass()
class BitsAndBytes:
enabled: bool = False
injects: bool = False
linear: bool = False
embedding: bool = False
@dataclass()
class Config(_Config):
device: str = "cuda"
@ -268,8 +281,9 @@ class Config(_Config):
hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
evaluation: Evaluation = field(default_factory=lambda: Evaluation)
trainer: Trainer = field(default_factory=lambda: Trainer)
inference: Inference = field(default_factory=lambda: Inference)
bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
use_vocos: bool = True
@property
def sample_rate(self):
@ -438,6 +452,8 @@ cfg.models = Models(**cfg.models)
cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
cfg.evaluation = Evaluation(**cfg.evaluation)
cfg.trainer = Trainer(**cfg.trainer)
cfg.inference = Inference(**cfg.inference)
cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
# cached_property stopped working...
if cfg.dataset.use_hdf5:

View File

@ -15,12 +15,10 @@ from einops import rearrange
from torch import Tensor
from tqdm import tqdm
USE_VOCOS = False
try:
from vocos import Vocos
USE_VOCOS = True
except Exception as e:
USE_VOCOS = False
cfg.inference.use_vocos = False
@cache
def _load_encodec_model(device="cuda"):
@ -35,9 +33,11 @@ def _load_encodec_model(device="cuda"):
elif cfg.models.levels == 8:
bandwidth_id = 6.0
model = EncodecModel.encodec_model_24khz()
model = EncodecModel.encodec_model_24khz().to(device)
model.set_target_bandwidth(bandwidth_id)
model.to(device)
model.bandwidth_id = bandwidth_id
model.sample_rate = cfg.sample_rate
model.backend = "encodec"
return model
@ -58,11 +58,12 @@ def _load_vocos_model(device="cuda"):
model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
model.sample_rate = cfg.sample_rate
model.backend = "vocos"
return model
@cache
def _load_model(device="cuda", vocos=USE_VOCOS):
def _load_model(device="cuda", vocos=cfg.inference.use_vocos):
if vocos:
model = _load_vocos_model(device)
else:
@ -99,7 +100,7 @@ def decode(codes: Tensor, device="cuda"):
codes = codes.to(torch.int32)
kwargs = {}
if USE_VOCOS:
if model.backend == "vocos":
x = model.codes_to_features(codes[0])
kwargs['bandwidth_id'] = model.bandwidth_id
else:
@ -107,7 +108,7 @@ def decode(codes: Tensor, device="cuda"):
wav = model.decode(x, **kwargs)
if not USE_VOCOS:
if model.backend == "encodec":
wav = wav[0]
return wav, model.sample_rate

View File

@ -16,8 +16,6 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
from .retnet import RetNetDecoder, RetNetConfig
from .transformer import SinusoidalEmbedding, Block as TransformerBlock
from ..utils import wrapper as ml
def _create_mask(l, device):
"""1 is valid region and 0 is invalid."""
seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)

View File

@ -5,7 +5,7 @@
# to-do: replace this
# to-do: swap out deepspeed
from .config import Config
from ..config import Config
from .distributed import fix_unset_envs
from .utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device

View File

@ -1,21 +1,19 @@
# to-do: re-introduce bitsandbytes support
from contextlib import contextmanager
import torch
import torch.nn.functional as F
from ..config import cfg
Embedding = torch.nn.Embedding
Linear = torch.nn.Linear
"""
if cfg.bitsandbytes:
if cfg.bitsandbytes.enabled:
import bitsandbytes as bnb
if cfg.bitsandbytes_linear:
if cfg.bitsandbytes.linear:
Linear = bnb.nn.Linear8bitLt
if cfg.bitsandbytes_embedding:
if cfg.bitsandbytes.embedding:
Embedding = bnb.nn.StableEmbedding
Embedding.forward = lambda self, input: ( self.norm(F.embedding(
input,
@ -26,28 +24,15 @@ if cfg.bitsandbytes:
self.scale_grad_by_freq,
self.sparse,
)).to(self.weight.dtype) )
"""
Adam = torch.optim.Adam
AdamW = torch.optim.AdamW
"""
if cfg.bitsandbytes:
if cfg.bitsandbytes.enabled:
import bitsandbytes as bnb
Adam = bnb.optim.Adam
AdamW = bnb.optim.AdamW
"""
# handles temporarily upcasting 'index tensors' so torch will stop bitching
def autocast_forward( func ):
def wrapper( self, input, *args, **kwargs ):
if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8:
input = input.to(torch.int32)
return func( self, input, *args, **kwargs )
return wrapper
Embedding.forward = autocast_forward(Embedding.forward)
# handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
@contextmanager
@ -57,4 +42,34 @@ def autocast(input, from_dtype, to_dtype):
yield input
input = input.to(from_dtype)
else:
yield input
yield input
@contextmanager
def autocasts(input, from_dtype, to_dtype):
if input.dtype in from_dtype:
from_dtype = input.dtype
input = input.to(to_dtype)
yield input
input = input.to(from_dtype)
else:
yield input
# handles temporarily upcasting 'index tensors' so torch will stop bitching
def autocast_forward( func ):
def wrapper( self, input, *args, **kwargs ):
with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k:
return func( self, k, *args, **kwargs )
"""
if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8:
return func( self, input.to(torch.int32), *args, **kwargs )
return func( self, input, *args, **kwargs )
"""
return wrapper
Embedding.forward = autocast_forward(Embedding.forward)
if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
torch.nn.Linear = Linear
torch.nn.Embedding = Embedding
torch.optim.Adam = Adam
torch.optim.AdamW = AdamW