cleanup, use deepspeed inferencing pathway if requested

mrq 2023-10-09 15:24:04 -05:00
parent 26fbb92ec6
commit 893a610fad
7 changed files with 127 additions and 104 deletions
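For orientation, here is a minimal sketch (not part of the commit) of the DeepSpeed inferencing pathway being opted into below: an eval-mode torch module is handed to deepspeed.init_inference() and the kernel-injected engine's .module is unwrapped afterwards, mirroring what load_engines() now does. The torch.nn.Linear model is a stand-in; any module works.

import torch
import deepspeed

model = torch.nn.Linear(16, 16).eval()           # stand-in for an AR/NAR model
model = deepspeed.init_inference(
    model=model,
    mp_size=1,                                   # single device, no tensor parallelism
    replace_with_kernel_inject=True,             # swap in fused inference kernels where supported
    dtype=torch.float16,                         # analogous to cfg.inference.weight_dtype
).module                                         # unwrap back to a plain torch module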

View File

@@ -460,6 +460,7 @@ class Trainer:
@dataclass()
class Inference:
    backend: str = "local"
    weight_dtype: str = "float32"
    amp: bool = False
@@ -492,6 +493,7 @@ class BitsAndBytes:
class Config(_Config):
    device: str = "cuda"
    mode: str = "training" # "inferencing"
    experimental: bool = False # So I can stop commenting out things when committing

    dataset: Dataset = field(default_factory=lambda: Dataset)
    models: Models = field(default_factory=lambda: Models)

View File

@@ -291,7 +291,9 @@ class Dataset(_Dataset):
        # shuffle it up a bit
        prom_length = 0
        #trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
        if cfg.experimental:
            trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
        else:
            trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)

        for _ in range(cfg.dataset.max_prompts):
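
As a worked example of the non-experimental branch (the comments above imply 75 quantized frames per second): with a prompt_duration of 3.0 seconds, the prompt is trimmed to 225 frames plus or minus 75, i.e. between 2 and 4 seconds. A hypothetical standalone version:

import random

prompt_duration = 3.0  # stand-in for cfg.dataset.prompt_duration, in seconds
trim_length = int(prompt_duration * 75) + random.randint(-75, 75)
assert 150 <= trim_length <= 300  # 2 to 4 seconds at 75 frames/second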

View File

@@ -9,3 +9,111 @@ elif cfg.trainer.backend == "local":
from .base import Engine

from .base import Engines, TrainFeeder, default_feeder, Engine as _Engine

from ..models import get_models
from ..utils import wrapper as ml
import torch

deepspeed_available = False
try:
    import deepspeed
    deepspeed_available = True
except Exception as e:
    pass

def load_engines():
    models = get_models(cfg.models.get())
    engines = dict()

    for name, model in models.items():
        optimizer = None
        lr_scheduler = None

        inferencing = cfg.mode == "inferencing" or not model._cfg.training
        backend = cfg.inference.backend if inferencing else cfg.trainer.backend
        dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
        amp = cfg.inference.amp if inferencing else cfg.trainer.amp
        loads_state_dict = cfg.trainer.load_state_dict or inferencing

        engine_class = _Engine if backend == "local" or inferencing else Engine

        if inferencing:
            model._cfg.training = False

        if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
            optimizer_class = None
            params = {
                "lr": cfg.hyperparameters.learning_rate,
            }
            if cfg.hyperparameters.optimizer.lower() == "adamw":
                params["betas"] = (0.9, 0.96)
                params["eps"] = 1e-07
                params["weight_decay"] = 0.01

                optimizer_class = ml.AdamW
            elif cfg.hyperparameters.optimizer.lower() == "sgd":
                optimizer_class = ml.SGD
            elif cfg.hyperparameters.optimizer.lower() == "prodigy":
                optimizer_class = ml.Prodigy
            else:
                raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')

            params.update(cfg.hyperparameters.optimizer_params)
            optimizer = optimizer_class(
                [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
                **params,
            )

        # set up our LR scheduler here

        if inferencing:
            optimizer = None
            lr_scheduler = None

        # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
        if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists():
            print("DeepSpeed checkpoint missing, but weights found.")
            loads_state_dict = True

        stats = None
        if loads_state_dict:
            load_path = cfg.ckpt_dir / name / "fp32.pth"
            state = torch.load(load_path, map_location=torch.device(cfg.device))

            # state dict is not just the module, extract the extra trainer details
            if "stats" in state:
                stats = state["stats"]
            if "module" in state:
                state = state["module"]

            model.load_state_dict(state, strict=cfg.trainer.strict_loading)

        # deepspeed inferencing
        if backend == "local" and inferencing and deepspeed_available: #and sys.platform.startswith("win"):
            engine_class = _Engine
            model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module

        # use base engine if requested
        engines[name] = engine_class(
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            _cfg=model._cfg,
            stats=stats
        )

    engines = Engines(engines)
    engines.setup()

    if not cfg.trainer.load_state_dict:
        engines.load_checkpoint()

    # freeze requested params
    for name, engine in engines.items():
        engine.freeze(freeze_all=False)

    #do_gc()

    return engines
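
A hypothetical call site (assuming the package root is vall_e) showing that the inferencing pathway above is selected purely through configuration, with no separate loader:

from vall_e.config import cfg
from vall_e.engines import load_engines

cfg.mode = "inferencing"             # marks every model as non-training
cfg.trainer.load_state_dict = True   # load fp32.pth rather than a DeepSpeed checkpoint
engines = load_engines()             # _Engine-wrapped, kernel-injected if deepspeed imports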

View File

@@ -3,7 +3,7 @@ import argparse
import torch
from .data import get_phone_symmap
from .train import load_engines
from .engines import load_engines
from .config import cfg
def main():

View File

@@ -12,18 +12,11 @@ from .utils import to_device
from .config import cfg
from .models import get_models
from .train import load_engines
from .engines import load_engines, deepspeed_available
from .data import get_phone_symmap, _load_quants

use_deepspeed_inference = False
# to-do: integrate this for windows
"""
try:
    if deepspeed_available:
        import deepspeed
        use_deepspeed_inference = True
except Exception as e:
    pass
"""

class TTS():
    def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device=None, amp=None, dtype=None ):
@@ -48,9 +41,9 @@ class TTS():
        if device is None:
            device = cfg.device

        cfg.mode = "inferencing"
        cfg.device = device
        cfg.trainer.load_state_dict = True
        cfg.mode = "inferencing"
        cfg.trainer.backend = cfg.inference.backend
        cfg.trainer.weight_dtype = dtype
        cfg.inference.weight_dtype = dtype
@@ -70,6 +63,10 @@ class TTS():
                state = state['module']
            model.load_state_dict(state)

            if deepspeed_available:
                model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module

            return model

        if ar_ckpt and nar_ckpt:
@@ -94,10 +91,6 @@ class TTS():
        self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
        self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)

        if use_deepspeed_inference:
            self.ar = deepspeed.init_inference(model=self.ar, mp_size=1, replace_with_kernel_inject=True, dtype=self.dtype if not self.amp else torch.float32).module.eval()
            self.nar = deepspeed.init_inference(model=self.nar, mp_size=1, replace_with_kernel_inject=True, dtype=self.dtype if not self.amp else torch.float32).module.eval()
        else:
            self.ar.eval()
            self.nar.eval()
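
The expression dtype if not amp else torch.float32 recurs throughout this file; a small sketch of the rule it encodes, namely that weights stay float32 under AMP so autocast can downcast per-op, while without AMP the weights themselves are cast:

import torch

amp = True
dtype = torch.float16

weight_dtype = dtype if not amp else torch.float32
model = torch.nn.Linear(8, 8).to(dtype=weight_dtype)
print(next(model.parameters()).dtype)  # torch.float32 when amp=True, torch.float16 otherwise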

View File

@@ -5,7 +5,6 @@ from .data import create_train_val_dataloader
from .emb import qnt
from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
from .utils.trainer import load_engines
import auraloss
import json

View File

@@ -28,8 +28,7 @@ from .distributed import (
    local_leader_only,
)

from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder
from ..models import get_models
from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder, load_engines

from .utils import to_device, do_gc
from ..utils import wrapper as ml
@@ -38,86 +37,6 @@ from ..data import get_phone_symmap # should decouple from this trainer script

_logger = logging.getLogger(__name__)

_command: str

def load_engines():
    models = get_models(cfg.models.get())
    engines = dict()

    for name, model in models.items():
        optimizer = None
        lr_scheduler = None

        if cfg.trainer.backend == "local" or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
            optimizer_class = None
            params = {
                "lr": cfg.hyperparameters.learning_rate,
            }
            if cfg.hyperparameters.optimizer.lower() == "adamw":
                params["betas"] = (0.9, 0.96)
                params["eps"] = 1e-07
                params["weight_decay"] = 0.01

                optimizer_class = ml.AdamW
            elif cfg.hyperparameters.optimizer.lower() == "sgd":
                optimizer_class = ml.SGD
            elif cfg.hyperparameters.optimizer.lower() == "prodigy":
                optimizer_class = ml.Prodigy
            else:
                raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')

            params.update(cfg.hyperparameters.optimizer_params)
            optimizer = optimizer_class(
                [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
                **params,
            )

        # set up our LR scheduler here

        if not model._cfg.training:
            optimizer = None
            lr_scheduler = None

        # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
        if not cfg.trainer.load_state_dict and cfg.trainer.backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists():
            print("DeepSpeed checkpoint missing, but weights found.")
            cfg.trainer.load_state_dict = True

        stats = None
        if cfg.trainer.load_state_dict or not model._cfg.training:
            load_path = cfg.ckpt_dir / name / "fp32.pth"
            state = torch.load(load_path, map_location=torch.device(cfg.device))

            # state dict is not just the module, extract the extra trainer details
            if "stats" in state:
                stats = state["stats"]
            if "module" in state:
                state = state["module"]

            model.load_state_dict(state, strict=cfg.trainer.strict_loading)

        # use base engine because DeepSpeed memory leaks if it's a non-training model
        engines[name] = (Engine if model._cfg.training else _Engine)(
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            _cfg=model._cfg,
            stats=stats
        )

    engines = Engines(engines)
    engines.setup()

    if not cfg.trainer.load_state_dict:
        engines.load_checkpoint()

    # freeze requested params
    for name, engine in engines.items():
        engine.freeze(freeze_all=False)

    do_gc()

    return engines

class EvalFn(Protocol):
    def __call__(self, *, engines: Engines):