added automatically loading the default YAML if --yaml is not provided (although I think it already does this by using defaults); the default YAML will use the local backend + deepspeed inferencing for speedups
parent f4fcc35aa8
commit a5c21d65d2
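The practical effect on the shipped default config (the two YAML hunks below): inference.backend switches from deepspeed to local, and deepspeed.inferencing under trainer: flips to True. A quick way to sanity-check the packaged file, assuming inference: sits at the top level of the YAML alongside trainer: (the diff does not show the file's full nesting, so path and keys here are assumptions):

# hedged sketch; 'data/config.yaml' and the key nesting are taken from the hunks below
import yaml  # pip install pyyaml

with open('data/config.yaml') as f:
    defaults = yaml.safe_load(f)

print(defaults['trainer']['backend'])                   # deepspeed (unchanged)
print(defaults['trainer']['deepspeed']['inferencing'])  # True  (was False)
print(defaults['inference']['backend'])                 # local (was deepspeed)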
@@ -71,7 +71,7 @@ trainer:
   backend: deepspeed
   deepspeed:
-    inferencing: False
+    inferencing: True
     zero_optimization_level: 0
     use_compression_training: False
 
@@ -80,7 +80,7 @@ trainer:
   load_webui: False
 
 inference:
-  backend: deepspeed
+  backend: local
   normalize: False
 
   # some steps break under blanket (B)FP16 + AMP
@@ -22,9 +22,11 @@ from .tokenizer import VoiceBpeTokenizer
 # Yuck
 from transformers import PreTrainedTokenizerFast
 
+DEFAULT_YAML = Path(__file__).parent.parent / 'data/config.yaml'
+
 @dataclass()
 class BaseConfig:
-    yaml_path: str | None = None
+    yaml_path: str | None = DEFAULT_YAML
 
     @property
     def cfg_path(self):
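A minimal, self-contained sketch of what the new default buys (names mirror the hunk above; only dataclasses and pathlib are needed, Python 3.10+ for the | union syntax):

from dataclasses import dataclass
from pathlib import Path

# mirrors config.py: the default config shipped next to the package
DEFAULT_YAML = Path(__file__).parent.parent / 'data/config.yaml'

@dataclass()
class BaseConfig:
    # an omitted --yaml now resolves to the packaged default instead of None
    yaml_path: str | None = DEFAULT_YAML

print(BaseConfig().yaml_path)                     # .../data/config.yaml
print(BaseConfig(yaml_path='my.yaml').yaml_path)  # my.yaml; an explicit --yaml still wins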
@@ -554,7 +556,7 @@ class Config(BaseConfig):
 
     @cached_property
    def diskcache(self):
-        if self.yaml_path is not None and self.dataset.cache:
+        if self.yaml_path is not None and self.yaml_path != DEFAULT_YAML and self.dataset.cache:
             return diskcache.Cache(self.cache_dir).memoize
         return lambda: lambda x: x
 
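The added self.yaml_path != DEFAULT_YAML guard means runs driven by the packaged default skip the on-disk cache. Both branches return a decorator factory, so call sites (presumably applied @cfg.diskcache()-style) do not change; a sketch of the no-op branch, with a hypothetical function standing in for whatever would otherwise be disk-cached:

def noop_memoize():
    # zero-arg factory returning an identity decorator (same shape as diskcache.Cache(...).memoize)
    return lambda x: x

@noop_memoize()
def phonemize(text):   # hypothetical example of a function that would otherwise be memoized
    return text.split()

print(phonemize("hello world"))  # recomputed on every call; nothing is written to disk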
@@ -10,7 +10,7 @@ from .emb.mel import encode_from_files as encode_mel, trim, trim_random
 from .utils import to_device
 from .utils import wrapper as ml
 
-from .config import cfg
+from .config import cfg, DEFAULT_YAML
 from .models import get_models, load_model
 from .engines import load_engines, deepspeed_available
 from .data import get_phone_symmap, tokenize
@@ -31,6 +31,9 @@ class TTS():
         self.input_sample_rate = 24000
         self.output_sample_rate = 24000
 
+        if config is None:
+            config = DEFAULT_YAML
+
         if config:
             cfg.load_yaml( config )
 
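So constructing the TTS wrapper with no config now behaves like passing the packaged default explicitly. The same fallback, isolated as a runnable mimic of the constructor logic above (TTS and cfg live elsewhere in the repo, so stand-ins are used here):

from pathlib import Path

DEFAULT_YAML = Path('data/config.yaml')   # stand-in for config.DEFAULT_YAML

def resolve_config(config=None):
    # mirrors TTS.__init__: an omitted config falls back to the packaged default
    if config is None:
        config = DEFAULT_YAML
    if config:
        print(f"would call cfg.load_yaml({config})")
    return config

resolve_config()            # would call cfg.load_yaml(data/config.yaml)
resolve_config('my.yaml')   # an explicit path still wins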
@@ -92,12 +92,15 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         self.lm_head = nn.Sequential(norm, linear)
 
         self.kv_cache = kv_cache
 
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
         self.cached_mel_emb = None
 
+        # Model parallel
+        """
+        self.model_parallel = False
+        self.device_map = None
+        """
+
+    """
     def parallelize(self, device_map=None):
         self.device_map = (
             get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
@@ -115,6 +118,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         self.lm_head = self.lm_head.to("cpu")
         self.model_parallel = False
         torch.cuda.empty_cache()
+    """
 
     def get_output_embeddings(self):
         return self.lm_head
@@ -212,9 +216,11 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         hidden_states = transformer_outputs[0]
 
         # Set device for model parallelism
+        """
         if self.model_parallel:
             torch.cuda.set_device(self.transformer.first_device)
             hidden_states = hidden_states.to(self.lm_head.weight.device)
+        """
 
         lm_logits = self.lm_head(hidden_states)
 
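The last three hunks neutralize the Hugging Face model-parallel plumbing (parallelize/deparallelize and the model_parallel branch in forward) by wrapping it in bare triple-quoted strings rather than deleting it. A bare string literal in statement position is parsed, evaluated, and discarded, so the wrapped lines become inert while staying easy to restore; a minimal illustration of the pattern (class and attribute names here are only for demonstration):

class Example:
    def __init__(self):
        self.ready = True
        # everything inside the quotes below is just a string expression: evaluated, then discarded
        """
        self.model_parallel = False
        self.device_map = None
        """

e = Example()
print(e.ready)                          # True
print(hasattr(e, "model_parallel"))     # False: the quoted assignments never executed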