adjustments
commit f6597e2dfe (parent 0f9b81de75)
@@ -259,6 +259,19 @@ class Trainer:
 	zero_optimization_level: int = 0
 	use_compression_training: bool = False
 
+@dataclass()
+class Inference:
+	use_vocos: bool = True
+
+@dataclass()
+class BitsAndBytes:
+	enabled: bool = False
+	injects: bool = False
+
+	linear: bool = False
+	embedding: bool = False
+
 @dataclass()
 class Config(_Config):
 	device: str = "cuda"
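The new Inference and BitsAndBytes blocks follow the file's existing nested-dataclass pattern: each feature area gets its own @dataclass, and Config composes them as fields. A minimal, self-contained sketch of that pattern (using default_factory=Inference to produce an instance, rather than the file's lambda style):

from dataclasses import dataclass, field

@dataclass
class Inference:
	use_vocos: bool = True

@dataclass
class BitsAndBytes:
	enabled: bool = False
	injects: bool = False
	linear: bool = False
	embedding: bool = False

@dataclass
class Config:
	device: str = "cuda"
	inference: Inference = field(default_factory=Inference)
	bitsandbytes: BitsAndBytes = field(default_factory=BitsAndBytes)

cfg = Config()
assert cfg.inference.use_vocos and not cfg.bitsandbytes.enabled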
@@ -268,8 +281,9 @@ class Config(_Config):
 	hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
 	evaluation: Evaluation = field(default_factory=lambda: Evaluation)
 	trainer: Trainer = field(default_factory=lambda: Trainer)
+	inference: Inference = field(default_factory=lambda: Inference)
+	bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
 
-	use_vocos: bool = True
 
 	@property
 	def sample_rate(self):
@@ -438,6 +452,8 @@ cfg.models = Models(**cfg.models)
 cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
 cfg.evaluation = Evaluation(**cfg.evaluation)
 cfg.trainer = Trainer(**cfg.trainer)
+cfg.inference = Inference(**cfg.inference)
+cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
 
 # cached_property stopped working...
 if cfg.dataset.use_hdf5:
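The new cfg.inference / cfg.bitsandbytes lines extend the rehydration step already used for the other sections: values parsed from a config file arrive as plain dicts, and each section is rebuilt into its dataclass so attribute access works. A sketch, with the dict standing in for a parsed config section:

from dataclasses import dataclass

@dataclass
class BitsAndBytes:
	enabled: bool = False
	injects: bool = False

section = {"enabled": True}             # as parsed from a config file
bitsandbytes = BitsAndBytes(**section)
print(bitsandbytes.enabled, bitsandbytes.injects)  # True False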
@@ -15,12 +15,10 @@ from einops import rearrange
 from torch import Tensor
 from tqdm import tqdm
 
-USE_VOCOS = False
 try:
 	from vocos import Vocos
-	USE_VOCOS = True
 except Exception as e:
-	USE_VOCOS = False
+	cfg.inference.use_vocos = False
 
 @cache
 def _load_encodec_model(device="cuda"):
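The try/except above replaces the module-level USE_VOCOS flag with the config flag: if vocos cannot be imported, inference falls back to the EnCodec decoder by flipping cfg.inference.use_vocos. A standalone sketch of the guard, with a SimpleNamespace standing in for the project's cfg object:

from types import SimpleNamespace

cfg = SimpleNamespace(inference=SimpleNamespace(use_vocos=True))  # stand-in config

try:
	from vocos import Vocos  # optional vocoder backend
except Exception:
	cfg.inference.use_vocos = False  # degrade gracefully to EnCodec decoding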
@@ -35,9 +33,11 @@ def _load_encodec_model(device="cuda"):
 	elif cfg.models.levels == 8:
 		bandwidth_id = 6.0
 
-	model = EncodecModel.encodec_model_24khz()
+	model = EncodecModel.encodec_model_24khz().to(device)
 	model.set_target_bandwidth(bandwidth_id)
 	model.to(device)
+	model.bandwidth_id = bandwidth_id
 	model.sample_rate = cfg.sample_rate
+	model.backend = "encodec"
 
 	return model
@@ -58,11 +58,12 @@ def _load_vocos_model(device="cuda"):
 	model.bandwidth_id = torch.tensor([bandwidth_id], device=device)
 	model.sample_rate = cfg.sample_rate
+	model.backend = "vocos"
 
 	return model
 
 @cache
-def _load_model(device="cuda", vocos=USE_VOCOS):
+def _load_model(device="cuda", vocos=cfg.inference.use_vocos):
 	if vocos:
 		model = _load_vocos_model(device)
 	else:
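Both loaders now tag the returned model with a backend attribute ("encodec" or "vocos"), so downstream code can branch on the model it was actually handed instead of consulting a process-wide flag. The shape of that pattern, with stand-in model classes:

class EncodecStub: pass
class VocosStub: pass

def load_model(vocos: bool):
	model = VocosStub() if vocos else EncodecStub()
	model.backend = "vocos" if vocos else "encodec"  # carried by the model itself
	model.sample_rate = 24_000
	return model

model = load_model(vocos=False)
assert model.backend == "encodec"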
@@ -99,7 +100,7 @@ def decode(codes: Tensor, device="cuda"):
 	codes = codes.to(torch.int32)
 
 	kwargs = {}
-	if USE_VOCOS:
+	if model.backend == "vocos":
 		x = model.codes_to_features(codes[0])
 		kwargs['bandwidth_id'] = model.bandwidth_id
 	else:
@@ -107,7 +108,7 @@ def decode(codes: Tensor, device="cuda"):
 
 	wav = model.decode(x, **kwargs)
 
-	if not USE_VOCOS:
+	if model.backend == "encodec":
 		wav = wav[0]
 
 	return wav, model.sample_rate
@@ -16,8 +16,6 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
 from .retnet import RetNetDecoder, RetNetConfig
 from .transformer import SinusoidalEmbedding, Block as TransformerBlock
 
-from ..utils import wrapper as ml
-
 def _create_mask(l, device):
 	"""1 is valid region and 0 is invalid."""
 	seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
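_create_mask is cut off here, but the line it opens with is the standard length-mask idiom: broadcast an arange over time against per-sequence lengths. A hypothetical completion of that idiom (not necessarily the file's exact body):

import torch

def create_mask(lengths, device="cpu"):
	seq = torch.arange(max(lengths), device=device).unsqueeze(0)  # (1 t)
	stop = torch.tensor(lengths, device=device).unsqueeze(1)      # (b 1)
	return (seq < stop).float()  # 1 is valid region and 0 is invalid

print(create_mask([2, 4]))
# tensor([[1., 1., 0., 0.],
#         [1., 1., 1., 1.]])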
@@ -5,7 +5,7 @@
 # to-do: replace this
 # to-do: swap out deepspeed
 
-from .config import Config
+from ..config import Config
 from .distributed import fix_unset_envs
 from .utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
 
@@ -1,21 +1,19 @@
 # to-do: re-introduce bitsandbytes support
 
 from contextlib import contextmanager
 
 import torch
 import torch.nn.functional as F
 from ..config import cfg
 
 Embedding = torch.nn.Embedding
 Linear = torch.nn.Linear
 
 """
-if cfg.bitsandbytes:
+if cfg.bitsandbytes.enabled:
 	import bitsandbytes as bnb
 
-	if cfg.bitsandbytes_linear:
+	if cfg.bitsandbytes.linear:
 		Linear = bnb.nn.Linear8bitLt
 
-	if cfg.bitsandbytes_embedding:
+	if cfg.bitsandbytes.embedding:
 		Embedding = bnb.nn.StableEmbedding
 		Embedding.forward = lambda self, input: ( self.norm(F.embedding(
 			input,
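The commented-out block ends mid-statement here; it patches StableEmbedding.forward so the normalized lookup is cast back to the weight dtype (bitsandbytes' StableEmbedding applies its LayerNorm in float32, which would otherwise leak float32 activations into a half-precision model). A hedged sketch of that forward-patch, with a stand-in class in place of bnb.nn.StableEmbedding:

import torch
import torch.nn.functional as F

class StableEmbeddingLike(torch.nn.Embedding):  # stand-in for bnb.nn.StableEmbedding
	def __init__(self, *args, **kwargs):
		super().__init__(*args, **kwargs)
		self.norm = torch.nn.LayerNorm(self.embedding_dim)

def patched_forward(self, input):
	# look up, normalize, then cast back to the weight dtype
	return self.norm(F.embedding(input, self.weight)).to(self.weight.dtype)

StableEmbeddingLike.forward = patched_forward
emb = StableEmbeddingLike(10, 4)
out = emb(torch.tensor([1, 2]))  # shape (2, 4), dtype matches emb.weight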
@@ -26,28 +24,15 @@ if cfg.bitsandbytes:
 			self.scale_grad_by_freq,
 			self.sparse,
 		)).to(self.weight.dtype) )
 """
 
 Adam = torch.optim.Adam
 AdamW = torch.optim.AdamW
 
 """
-if cfg.bitsandbytes:
+if cfg.bitsandbytes.enabled:
 	import bitsandbytes as bnb
 
 	Adam = bnb.optim.Adam
 	AdamW = bnb.optim.AdamW
 """
 
-# handles temporarily upcasting 'index tensors' so torch will stop bitching
-def autocast_forward( func ):
-	def wrapper( self, input, *args, **kwargs ):
-		if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8:
-			input = input.to(torch.int32)
-
-		return func( self, input, *args, **kwargs )
-	return wrapper
-Embedding.forward = autocast_forward(Embedding.forward)
-
 # handles generically converting to a specific tensor type and converting back (implemented solely for bfloat16)
 @contextmanager
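The Adam/AdamW assignments keep the alias-swap pattern this file is built around: bind torch's defaults at import time, and (once bitsandbytes support is re-introduced) rebind the same names to the 8-bit variants, so the rest of the codebase imports Linear/Adam from here unchanged. The shape of the pattern, with a plain flag standing in for cfg.bitsandbytes.enabled:

import torch

Linear = torch.nn.Linear
Adam = torch.optim.Adam

ENABLE_BNB = False  # stand-in for cfg.bitsandbytes.enabled
if ENABLE_BNB:
	import bitsandbytes as bnb
	Linear = bnb.nn.Linear8bitLt
	Adam = bnb.optim.Adam

layer = Linear(4, 4)                    # whichever implementation is bound
opt = Adam(layer.parameters(), lr=1e-4)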
@@ -57,4 +42,34 @@ def autocast(input, from_dtype, to_dtype):
 		yield input
 		input = input.to(from_dtype)
 	else:
 		yield input
+
+@contextmanager
+def autocasts(input, from_dtype, to_dtype):
+	if input.dtype in from_dtype:
+		from_dtype = input.dtype
+		input = input.to(to_dtype)
+		yield input
+		input = input.to(from_dtype)
+	else:
+		yield input
+
+# handles temporarily upcasting 'index tensors' so torch will stop bitching
+def autocast_forward( func ):
+	def wrapper( self, input, *args, **kwargs ):
+		with autocasts( input, [torch.int16, torch.int8, torch.uint8], torch.int32 ) as k:
+			return func( self, k, *args, **kwargs )
+		"""
+		if input.dtype == torch.int16 or input.dtype == torch.int8 or input.dtype == torch.uint8:
+			return func( self, input.to(torch.int32), *args, **kwargs )
+		return func( self, input, *args, **kwargs )
+		"""
+	return wrapper
+Embedding.forward = autocast_forward(Embedding.forward)
+
+if cfg.bitsandbytes.injects and cfg.bitsandbytes.enabled:
+	torch.nn.Linear = Linear
+	torch.nn.Embedding = Embedding
+
+	torch.optim.Adam = Adam
+	torch.optim.AdamW = AdamW
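The new autocasts context manager generalizes the old inline dtype check: an index tensor in any of the listed integer dtypes is upcast to int32 for the duration of the call, since torch's embedding lookup rejects int8/int16/uint8 indices; the final injects block then monkey-patches the (possibly swapped) classes back into torch itself, so code constructing torch.nn.Linear directly also picks them up. A self-contained sketch of the upcasting idea, following the diff's own int32 convention:

from contextlib import contextmanager
import torch

@contextmanager
def autocasts(input, from_dtypes, to_dtype):
	if input.dtype in from_dtypes:
		yield input.to(to_dtype)  # upcast for the duration of the block
	else:
		yield input               # already an accepted dtype

idx = torch.tensor([1, 2, 3], dtype=torch.uint8)
emb = torch.nn.Embedding(10, 4)
with autocasts(idx, [torch.int16, torch.int8, torch.uint8], torch.int32) as k:
	out = emb(k)  # int32 indices are accepted where uint8 would raise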