cleanup, use deepspeed inferencing pathway if requested

parent 26fbb92ec6
commit 893a610fad
@@ -460,6 +460,7 @@ class Trainer:
 
 @dataclass()
 class Inference:
+	backend: str = "local"
 	weight_dtype: str = "float32"
 	amp: bool = False
@@ -492,6 +493,7 @@ class BitsAndBytes:
 class Config(_Config):
 	device: str = "cuda"
 	mode: str = "training" # "inferencing"
+	experimental: bool = False # So I can stop commenting out things when committing
 
 	dataset: Dataset = field(default_factory=lambda: Dataset)
 	models: Models = field(default_factory=lambda: Models)
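
Note on the two config additions above: `cfg.inference.backend` selects the engine wrapper for non-training models independently of `cfg.trainer.backend`, while `cfg.mode` and `experimental` gate the alternate code paths in the hunks below. A minimal sketch of opting into the new pathway, using only field names visible in this diff (the "deepspeed" value comes from the backend checks in load_engines() further down):

	cfg.mode = "inferencing"
	cfg.inference.backend = "deepspeed"  # default is "local"
	cfg.inference.weight_dtype = "float16"
	cfg.inference.amp = False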
@@ -291,8 +291,10 @@ class Dataset(_Dataset):
 
 		# shuffle it up a bit
 		prom_length = 0
-		#trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
-		trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
+		if cfg.experimental:
+			trim_length = random.randint(75 * 3, 75 * 9) # [3 seconds, 9 seconds]
+		else:
+			trim_length = int(cfg.dataset.prompt_duration * 75) + random.randint(-75, 75)
 
 		for _ in range(cfg.dataset.max_prompts):
 			path = random.choice(choices)
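
The constant 75 in the hunk above is the quantizer frame rate: the original comment labels 75 * 3 to 75 * 9 frames as "[3 seconds, 9 seconds]", so `prompt_duration * 75` converts seconds to frames. A quick standalone check, assuming a hypothetical 3-second cfg.dataset.prompt_duration:

	import random

	FRAMES_PER_SECOND = 75  # per the "[3 seconds, 9 seconds]" comment in the diff
	prompt_duration = 3.0   # hypothetical value for cfg.dataset.prompt_duration
	trim_length = int(prompt_duration * FRAMES_PER_SECOND) + random.randint(-75, 75)
	# trim_length lands in [150, 300] frames, i.e. a 2 to 4 second prompt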
@@ -9,3 +9,111 @@ elif cfg.trainer.backend == "local":
 	from .base import Engine
 
 from .base import Engines, TrainFeeder, default_feeder, Engine as _Engine
+
+from ..models import get_models
+from ..utils import wrapper as ml
+import torch
+
+deepspeed_available = False
+try:
+	import deepspeed
+	deepspeed_available = True
+except Exception as e:
+	pass
+
+def load_engines():
+	models = get_models(cfg.models.get())
+	engines = dict()
+
+	for name, model in models.items():
+		optimizer = None
+		lr_scheduler = None
+
+		inferencing = cfg.mode == "inferencing" or not model._cfg.training
+		backend = cfg.inference.backend if inferencing else cfg.trainer.backend
+		dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
+		amp = cfg.inference.amp if inferencing else cfg.trainer.amp
+		loads_state_dict = cfg.trainer.load_state_dict or inferencing
+
+		engine_class = _Engine if backend == "local" or inferencing else Engine
+
+		if inferencing:
+			model._cfg.training = False
+
+		if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
+			optimizer_class = None
+			params = {
+				"lr": cfg.hyperparameters.learning_rate,
+			}
+			if cfg.hyperparameters.optimizer.lower() == "adamw":
+				params["betas"] = (0.9, 0.96)
+				params["eps"] = 1e-07
+				params["weight_decay"] = 0.01
+
+				optimizer_class = ml.AdamW
+			elif cfg.hyperparameters.optimizer.lower() == "sgd":
+				optimizer_class = ml.SGD
+			elif cfg.hyperparameters.optimizer.lower() == "prodigy":
+				optimizer_class = ml.Prodigy
+			else:
+				raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
+
+			params.update(cfg.hyperparameters.optimizer_params)
+			optimizer = optimizer_class(
+				[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
+				**params,
+			)
+
+		# set up our LR scheduler here
+
+		if inferencing:
+			optimizer = None
+			lr_scheduler = None
+
+		# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
+		if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists():
+			print("DeepSpeed checkpoint missing, but weights found.")
+			loads_state_dict = True
+
+		stats = None
+		if loads_state_dict:
+			load_path = cfg.ckpt_dir / name / "fp32.pth"
+			state = torch.load(load_path, map_location=torch.device(cfg.device))
+
+			# state dict is not just the module, extract the extra trainer details
+			if "stats" in state:
+				stats = state["stats"]
+
+			if "module" in state:
+				state = state["module"]
+
+			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
+
+		# deepspeed inferencing
+		if backend == "local" and inferencing and deepspeed_available: #and sys.platform.startswith("win"):
+			engine_class = _Engine
+			model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
+
+		# use base engine if requested
+		engines[name] = engine_class(
+			model=model,
+			optimizer=optimizer,
+			lr_scheduler=lr_scheduler,
+
+			_cfg=model._cfg,
+			stats=stats
+		)
+
+	engines = Engines(engines)
+	engines.setup()
+
+	if not cfg.trainer.load_state_dict:
+		engines.load_checkpoint()
+
+	# freeze requested params
+	for name, engine in engines.items():
+		engine.freeze(freeze_all=False)
+
+	#do_gc()
+
+	return engines
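
Two things worth noting in the new load_engines(). First, DeepSpeed availability is probed once at import time with a try/except, so the rest of the module branches on a plain boolean instead of re-attempting the import. Second, deepspeed.init_inference() returns an inference engine whose .module attribute is the original model with supported layers swapped for fused kernels, which is why the call above immediately unwraps it and still hands the result to the plain _Engine. A self-contained sketch of the same probe-then-wrap pattern (toy torch module; the keyword arguments mirror the call in the diff):

	import torch

	try:
		import deepspeed
		deepspeed_available = True
	except Exception:
		deepspeed_available = False

	model = torch.nn.Linear(8, 8).eval()
	if deepspeed_available:
		model = deepspeed.init_inference(
			model=model,
			mp_size=1,                       # no tensor parallelism
			replace_with_kernel_inject=True, # inject fused inference kernels where supported
			dtype=torch.float16,
		).module  # unwrap back to a plain nn.Module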
@@ -3,7 +3,7 @@ import argparse
 import torch
 
 from .data import get_phone_symmap
-from .train import load_engines
+from .engines import load_engines
 from .config import cfg
 
 def main():
@@ -12,18 +12,11 @@ from .utils import to_device
 
 from .config import cfg
 from .models import get_models
-from .train import load_engines
+from .engines import load_engines, deepspeed_available
 from .data import get_phone_symmap, _load_quants
 
-use_deepspeed_inference = False
-# to-do: integrate this for windows
-"""
-try:
-	import deepspeed
-	use_deepspeed_inference = True
-except Exception as e:
-	pass
-"""
+if deepspeed_available:
+	import deepspeed
 
 class TTS():
 	def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device=None, amp=None, dtype=None ):
@@ -48,9 +41,9 @@ class TTS():
 		if device is None:
 			device = cfg.device
 
-		cfg.mode = "inferencing"
 		cfg.device = device
-		cfg.trainer.load_state_dict = True
+		cfg.mode = "inferencing"
+		cfg.trainer.backend = cfg.inference.backend
 		cfg.trainer.weight_dtype = dtype
 		cfg.inference.weight_dtype = dtype
 
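
With the reordering above, callers no longer toggle cfg.trainer.load_state_dict themselves; setting cfg.mode = "inferencing" (and mirroring cfg.inference.backend into cfg.trainer.backend) is what routes load_engines() down the inference path. A hypothetical usage sketch (the import path is assumed, not shown in this diff; the keyword names come from the __init__ signature above):

	from vall_e.inference import TTS  # assumed module path

	tts = TTS( config="./config.yaml", device="cuda", amp=False, dtype="float32" )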
@@ -70,6 +63,10 @@ class TTS():
 			state = state['module']
 
 		model.load_state_dict(state)
+
+		if deepspeed_available:
+			model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
+
 		return model
 
 	if ar_ckpt and nar_ckpt:
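
The dtype=dtype if not amp else torch.float32 guard above recurs throughout this commit: when AMP is enabled the weights stay float32 and autocast downcasts per-op, so DeepSpeed is only handed a reduced dtype when AMP is off. The same decision in isolation, with hypothetical values:

	import torch

	dtype, amp = torch.float16, True
	ds_dtype = dtype if not amp else torch.float32  # float32 here; autocast handles casting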
@@ -94,12 +91,8 @@ class TTS():
 		self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
 		self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
 
-		if use_deepspeed_inference:
-			self.ar = deepspeed.init_inference(model=self.ar, mp_size=1, replace_with_kernel_inject=True, dtype=self.dtype if not self.amp else torch.float32).module.eval()
-			self.nar = deepspeed.init_inference(model=self.nar, mp_size=1, replace_with_kernel_inject=True, dtype=self.dtype if not self.amp else torch.float32).module.eval()
-		else:
-			self.ar.eval()
-			self.nar.eval()
+		self.ar.eval()
+		self.nar.eval()
 
 		if self.symmap is None:
 			self.symmap = get_phone_symmap()
@@ -5,7 +5,6 @@ from .data import create_train_val_dataloader
 from .emb import qnt
 
 from .utils import setup_logging, to_device, trainer, flatten_dict, do_gc
-from .utils.trainer import load_engines
 
 import auraloss
 import json
@@ -28,8 +28,7 @@ from .distributed import (
 	local_leader_only,
 )
 
-from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder
-from ..models import get_models
+from ..engines import _Engine, Engine, Engines, TrainFeeder, default_feeder, load_engines
 
 from .utils import to_device, do_gc
 from ..utils import wrapper as ml
@@ -38,86 +37,6 @@ from ..data import get_phone_symmap # should decouple from this trainer script
 _logger = logging.getLogger(__name__)
 _command: str
 
-def load_engines():
-	models = get_models(cfg.models.get())
-	engines = dict()
-
-	for name, model in models.items():
-		optimizer = None
-		lr_scheduler = None
-
-		if cfg.trainer.backend == "local" or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
-			optimizer_class = None
-			params = {
-				"lr": cfg.hyperparameters.learning_rate,
-			}
-			if cfg.hyperparameters.optimizer.lower() == "adamw":
-				params["betas"] = (0.9, 0.96)
-				params["eps"] = 1e-07
-				params["weight_decay"] = 0.01
-
-				optimizer_class = ml.AdamW
-			elif cfg.hyperparameters.optimizer.lower() == "sgd":
-				optimizer = ml.SGD
-			elif cfg.hyperparameters.optimizer.lower() == "prodigy":
-				optimizer_class = ml.Prodigy
-			else:
-				raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
-
-			params.update(cfg.hyperparameters.optimizer_params)
-			optimizer = optimizer_class(
-				[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
-				**params,
-			)
-
-		# set up our LR scheduler here
-
-		if not model._cfg.training:
-			optimizer = None
-			lr_scheduler = None
-
-		# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
-		if not cfg.trainer.load_state_dict and cfg.trainer.backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists():
-			print("DeepSpeed checkpoint missing, but weights found.")
-			cfg.trainer.load_state_dict = True
-
-		stats = None
-		if cfg.trainer.load_state_dict or not model._cfg.training:
-			load_path = cfg.ckpt_dir / name / "fp32.pth"
-			state = torch.load(load_path, map_location=torch.device(cfg.device))
-
-			# state dict is not just the module, extract the extra trainer details
-			if "stats" in state:
-				stats = state["stats"]
-
-			if "module" in state:
-				state = state["module"]
-
-			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
-
-		# use base engine because DeepSpeed memory leaks if it's a non-training model
-		engines[name] = (Engine if model._cfg.training else _Engine)(
-			model=model,
-			optimizer=optimizer,
-			lr_scheduler=lr_scheduler,
-
-			_cfg=model._cfg,
-			stats=stats
-		)
-
-	engines = Engines(engines)
-	engines.setup()
-
-	if not cfg.trainer.load_state_dict:
-		engines.load_checkpoint()
-
-	# freeze requested params
-	for name, engine in engines.items():
-		engine.freeze(freeze_all=False)
-
-	do_gc()
-
-	return engines
-
 class EvalFn(Protocol):
 	def __call__(self, *, engines: Engines):