cleanup, use deepspeed inferencing pathway if requested
This commit is contained in:
parent
26fbb92ec6
commit
893a610fad
|
@ -460,6 +460,7 @@ class Trainer:
|
|||
|
||||
@dataclass()
|
||||
class Inference:
|
||||
backend: str = "local"
|
||||
weight_dtype: str = "float32"
|
||||
amp: bool = False
|
||||
|
||||
|
@ -492,6 +493,7 @@ class BitsAndBytes:
|
|||
class Config(_Config):
|
||||
device: str = "cuda"
|
||||
mode: str = "training" # "inferencing"
|
||||
experimental: bool = False # So I can stop commenting out things when committing
|
||||
|
||||
dataset: Dataset = field(default_factory=lambda: Dataset)
|
||||
models: Models = field(default_factory=lambda: Models)
|
||||
|
|
|
@ -291,8 +291,10 @@ class Dataset(_Dataset):
|
|||
|
||||
# shuffle it up a bit
|
||||
prom_length = 0
|
||||
#trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
|
||||
trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
|
||||
if cfg.experimental:
|
||||
trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
|
||||
else:
|
||||
trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
|
||||
|
||||
for _ in range(cfg.dataset.max_prompts):
|
||||
path = random.choice(choices)
|
||||
|
|
|
@ -9,3 +9,111 @@ elif cfg.trainer.backend == "local":
|
|||
from .base import Engine
|
||||
|
||||
from .base import Engines, TrainFeeder, default_feeder, Engine as _Engine
|
||||
|
||||
from ..models import get_models
|
||||
from ..utils import wrapper as ml
|
||||
import torch
|
||||
|
||||
deepspeed_available = False
|
||||
try:
|
||||
import deepspeed
|
||||
deepspeed_available = True
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def load_engines():
|
||||
models = get_models(cfg.models.get())
|
||||
engines = dict()
|
||||
|
||||
for name, model in models.items():
|
||||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
||||
inferencing = cfg.mode == "inferencing" or not model._cfg.training
|
||||
backend = cfg.inference.backend if inferencing else cfg.trainer.backend
|
||||
dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
|
||||
amp = cfg.inference.amp if inferencing else cfg.trainer.amp
|
||||
loads_state_dict = cfg.trainer.load_state_dict or inferencing
|
||||
|
||||
engine_class = _Engine if backend == "local" or inferencing else Engine
|
||||
|
||||
if inferencing:
|
||||
model._cfg.training = False
|
||||
|
||||
if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
|
||||
optimizer_class = None
|
||||
params = {
|
||||
"lr": cfg.hyperparameters.learning_rate,
|
||||
}
|
||||
if cfg.hyperparameters.optimizer.lower() == "adamw":
|
||||
params["betas"] = (0.9, 0.96)
|
||||
params["eps"] = 1e-07
|
||||
params["weight_decay"] = 0.01
|
||||
|
||||
optimizer_class = ml.AdamW
|
||||
elif cfg.hyperparameters.optimizer.lower() == "sgd":
|
||||
optimizer = ml.SGD
|
||||
elif cfg.hyperparameters.optimizer.lower() == "prodigy":
|
||||
optimizer_class = ml.Prodigy
|
||||
else:
|
||||
raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
|
||||
|
||||
params.update(cfg.hyperparameters.optimizer_params)
|
||||
optimizer = optimizer_class(
|
||||
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
|
||||
**params,
|
||||
)
|
||||
|
||||
# set up our LR scheduler here
|
||||
|
||||
if inferencing:
|
||||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
||||
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
|
||||
if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists():
|
||||
print("DeepSpeed checkpoint missing, but weights found.")
|
||||
loads_state_dict = True
|
||||
|
||||
stats = None
|
||||
if loads_state_dict:
|
||||
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
||||
state = torch.load(load_path, map_location=torch.device(cfg.device))
|
||||
|
||||
# state dict is not just the module, extract the extra trainer details
|
||||
if "stats" in state:
|
||||
stats = state["stats"]
|
||||
|
||||
if "module" in state:
|
||||
state = state["module"]
|
||||
|
||||
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
||||
|
||||
# deepspeed inferencing
|
||||
if backend == "local" and inferencing and deepspeed_available: #and sys.platform.startswith("win"):
|
||||
engine_class = _Engine
|
||||
model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
|
||||
|
||||
# use base engine if requested
|
||||
engines[name] = engine_class(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
|
||||
_cfg=model._cfg,
|
||||
stats=stats
|
||||
)
|
||||
|
||||
engines = Engines(engines)
|
||||
engines.setup()
|
||||
|
||||
if not cfg.trainer.load_state_dict:
|
||||
engines.load_checkpoint()
|
||||
|
||||
# freeze requested params
|
||||
for name, engine in engines.items():
|
||||
engine.freeze(freeze_all=False)
|
||||
|
||||
#do_gc()
|
||||
|
||||
return engines
|
|
@ -3,7 +3,7 @@ import argparse
|
|||
import torch
|
||||
|
||||
from .data import get_phone_symmap
|
||||
from .train import load_engines
|
||||
from .engines import load_engines
|
||||
from .config import cfg
|
||||
|
||||
def main():
|
||||
|
|
|
@ -12,18 +12,11 @@ from .utils import to_device
|
|||
|
||||
from .config import cfg
|
||||
from .models import get_models
|
||||
from .train import load_engines
|
||||
from .engines import load_engines, deepspeed_available
|
||||
from .data import get_phone_symmap, _load_quants
|
||||
|
||||
use_deepspeed_inference = False
|
||||
# to-do: integrate this for windows
|
||||
"""
|
||||
try:
|
||||
if deepspeed_available:
|
||||
import deepspeed
|
||||
use_deepspeed_inference = True
|
||||
except Exception as e:
|
||||
pass
|
||||
"""
|
||||
|
||||
class TTS():
|
||||
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device=None, amp=None, dtype=None ):
|
||||
|
@ -48,9 +41,9 @@ class TTS():
|
|||
if device is None:
|
||||
device = cfg.device
|
||||
|
||||
cfg.mode = "inferencing"
|
||||
cfg.device = device
|
||||
cfg.trainer.load_state_dict = True
|
||||
cfg.mode = "inferencing"
|
||||
cfg.trainer.backend = cfg.inference.backend
|
||||
cfg.trainer.weight_dtype = dtype
|
||||
cfg.inference.weight_dtype = dtype
|
||||
|
||||
|
@ -70,6 +63,10 @@ class TTS():
|
|||
state = state['module']
|
||||
|
||||
model.load_state_dict(state)
|
||||
|
||||
if deepspeed_available:
|
||||
model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
|
||||
|
||||
return model
|
||||
|
||||
if ar_ckpt and nar_ckpt:
|
||||
|
@ -94,12 +91,8 @@ class TTS():
|
|||
self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
||||
self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
||||
|
||||
if use_deepspeed_inference:
|
||||
self.ar = deepspeed.init_inference(model=self.ar, mp_size=1, replace_with_kernel_inject=True, dtype=self.dtype if not self.amp else torch.float32).module.eval()
|
||||
self.nar = deepspeed.init_inference(model=self.nar, mp_size=1, replace_with_kernel_inject=True, dtype=self.dtype if not self.amp else torch.float32).module.eval()
|
||||
else:
|
||||
self.ar.eval()
|
||||
self.nar.eval()
|
||||
self.ar.eval()
|
||||
self.nar.eval()
|
||||
|
||||
if self.symmap is None:
|
||||
self.symmap = get_phone_symmap()
|
||||
|
|
|
@ -5,7 +5,6 @@ from .data import create_train_val_dataloader
|
|||
from .emb import qnt
|
||||
|
||||
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
|
||||
from .utils.trainer import load_engines
|
||||
|
||||
import auraloss
|
||||
import json
|
||||
|
|
|
@ -28,8 +28,7 @@ from .distributed import (
|
|||
local_leader_only,
|
||||
)
|
||||
|
||||
from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder
|
||||
from ..models import get_models
|
||||
from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder, load_engines
|
||||
|
||||
from .utils import to_device, do_gc
|
||||
from ..utils import wrapper as ml
|
||||
|
@ -38,86 +37,6 @@ from ..data import get_phone_symmap # should decouple from this trainer script
|
|||
_logger = logging.getLogger(__name__)
|
||||
_command: str
|
||||
|
||||
def load_engines():
|
||||
models = get_models(cfg.models.get())
|
||||
engines = dict()
|
||||
|
||||
for name, model in models.items():
|
||||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
||||
if cfg.trainer.backend == "local" or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
|
||||
optimizer_class = None
|
||||
params = {
|
||||
"lr": cfg.hyperparameters.learning_rate,
|
||||
}
|
||||
if cfg.hyperparameters.optimizer.lower() == "adamw":
|
||||
params["betas"] = (0.9, 0.96)
|
||||
params["eps"] = 1e-07
|
||||
params["weight_decay"] = 0.01
|
||||
|
||||
optimizer_class = ml.AdamW
|
||||
elif cfg.hyperparameters.optimizer.lower() == "sgd":
|
||||
optimizer = ml.SGD
|
||||
elif cfg.hyperparameters.optimizer.lower() == "prodigy":
|
||||
optimizer_class = ml.Prodigy
|
||||
else:
|
||||
raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
|
||||
|
||||
params.update(cfg.hyperparameters.optimizer_params)
|
||||
optimizer = optimizer_class(
|
||||
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
|
||||
**params,
|
||||
)
|
||||
|
||||
# set up our LR scheduler here
|
||||
|
||||
if not model._cfg.training:
|
||||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
||||
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
|
||||
if not cfg.trainer.load_state_dict and cfg.trainer.backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists():
|
||||
print("DeepSpeed checkpoint missing, but weights found.")
|
||||
cfg.trainer.load_state_dict = True
|
||||
|
||||
stats = None
|
||||
if cfg.trainer.load_state_dict or not model._cfg.training:
|
||||
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
||||
state = torch.load(load_path, map_location=torch.device(cfg.device))
|
||||
|
||||
# state dict is not just the module, extract the extra trainer details
|
||||
if "stats" in state:
|
||||
stats = state["stats"]
|
||||
|
||||
if "module" in state:
|
||||
state = state["module"]
|
||||
|
||||
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
||||
|
||||
# use base engine because DeepSpeed memory leaks if it's a non-training model
|
||||
engines[name] = (Engine if model._cfg.training else _Engine)(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
|
||||
_cfg=model._cfg,
|
||||
stats=stats
|
||||
)
|
||||
|
||||
engines = Engines(engines)
|
||||
engines.setup()
|
||||
|
||||
if not cfg.trainer.load_state_dict:
|
||||
engines.load_checkpoint()
|
||||
|
||||
# freeze requested params
|
||||
for name, engine in engines.items():
|
||||
engine.freeze(freeze_all=False)
|
||||
|
||||
do_gc()
|
||||
|
||||
return engines
|
||||
|
||||
class EvalFn(Protocol):
|
||||
def __call__(self, *, engines: Engines):
|
||||
|
|
Loading…
Reference in New Issue
Block a user