added automatically loading the default YAML if --yaml is not provided (although I think it already does this through the dataclass defaults); the default YAML will use the local backend + DeepSpeed inferencing for speedups
parent f4fcc35aa8
commit a5c21d65d2
@@ -71,7 +71,7 @@ trainer:
   backend: deepspeed
   deepspeed:
-    inferencing: False
+    inferencing: True
     zero_optimization_level: 0
     use_compression_training: False
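Flipping `inferencing` to True routes the model through DeepSpeed's inference engine rather than its training engine, which is where the commit message's speedups come from. A minimal sketch of what that typically involves, using DeepSpeed's standard API (the wrapper function and the fp16 choice are illustrative assumptions, not code from this repo):

# Illustrative sketch: handing a trained module to DeepSpeed's
# inference engine, roughly what an `inferencing: True` path enables.
import deepspeed
import torch

def wrap_for_inference(model: torch.nn.Module) -> torch.nn.Module:
    # init_inference swaps supported layers for fused inference
    # kernels where it can.
    engine = deepspeed.init_inference(
        model,
        dtype=torch.float16,            # assumption: fp16 for speed
        replace_with_kernel_inject=True,
    )
    return engine.module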
@@ -80,7 +80,7 @@ trainer:
   load_webui: False

 inference:
-  backend: deepspeed
+  backend: local
   normalize: False

   # some steps break under blanket (B)FP16 + AMP
@@ -22,9 +22,11 @@ from .tokenizer import VoiceBpeTokenizer
 # Yuck
 from transformers import PreTrainedTokenizerFast

+DEFAULT_YAML = Path(__file__).parent.parent / 'data/config.yaml'
+
 @dataclass()
 class BaseConfig:
-	yaml_path: str | None = None
+	yaml_path: str | None = DEFAULT_YAML

 	@property
 	def cfg_path(self):
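Because `DEFAULT_YAML` is anchored to `__file__`, the fallback resolves relative to the installed package rather than the current working directory. A quick sketch of the resolution (the directory names are illustrative):

# Illustrative: if the config module lives at repo/pkg/config.py, then
# Path(__file__).parent.parent is repo/, so the default config resolves
# to repo/data/config.yaml no matter where the process is launched from.
from pathlib import Path

DEFAULT_YAML = Path(__file__).parent.parent / 'data/config.yaml'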
@@ -554,7 +556,7 @@ class Config(BaseConfig):

 	@cached_property
 	def diskcache(self):
-		if self.yaml_path is not None and self.dataset.cache:
+		if self.yaml_path is not None and self.yaml_path != DEFAULT_YAML and self.dataset.cache:
 			return diskcache.Cache(self.cache_dir).memoize
 		return lambda: lambda x: x
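The added `self.yaml_path != DEFAULT_YAML` check keeps runs on the stock config from creating a disk cache. Note the shape of the fallback: `diskcache.Cache(...).memoize` is a decorator factory, so the no-cache branch must also be a callable that returns a decorator, hence `lambda: lambda x: x`. A sketch of the pattern (the names here are illustrative):

# Illustrative: both branches expose the same call shape.
def no_op_memoize():
    # Stands in for diskcache.Cache(...).memoize: calling it yields
    # a decorator, but this one returns the function unchanged.
    return lambda fn: fn

@no_op_memoize()
def phonemize(text):
    return text.lower()  # placeholder for real work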
@@ -10,7 +10,7 @@ from .emb.mel import encode_from_files as encode_mel, trim, trim_random
 from .utils import to_device
 from .utils import wrapper as ml

-from .config import cfg
+from .config import cfg, DEFAULT_YAML
 from .models import get_models, load_model
 from .engines import load_engines, deepspeed_available
 from .data import get_phone_symmap, tokenize
@@ -31,6 +31,9 @@ class TTS():
 		self.input_sample_rate = 24000
 		self.output_sample_rate = 24000

+		if config is None:
+			config = DEFAULT_YAML
+
 		if config:
 			cfg.load_yaml( config )

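With that fallback in place, the wrapper can be constructed with no explicit config at all; a hedged usage sketch (only the `config` argument is taken from this diff, anything else is assumed):

# Illustrative usage after this change.
tts = TTS()                     # config=None -> falls back to DEFAULT_YAML
tts = TTS(config='my.yaml')     # an explicit path still wins, as before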
@@ -92,12 +92,15 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
 		self.lm_head = nn.Sequential(norm, linear)
 		self.kv_cache = kv_cache

-		# Model parallel
-		self.model_parallel = False
-		self.device_map = None
 		self.cached_mel_emb = None

+		# Model parallel
+		"""
+		self.model_parallel = False
+		self.device_map = None
+		"""
+
 		"""
 		def parallelize(self, device_map=None):
 			self.device_map = (
 				get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
@@ -115,6 +118,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
 		self.lm_head = self.lm_head.to("cpu")
 		self.model_parallel = False
 		torch.cuda.empty_cache()
+	"""

 	def get_output_embeddings(self):
 		return self.lm_head
@@ -212,9 +216,11 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
 		hidden_states = transformer_outputs[0]

 		# Set device for model parallelism
+		"""
 		if self.model_parallel:
 			torch.cuda.set_device(self.transformer.first_device)
 			hidden_states = hidden_states.to(self.lm_head.weight.device)
+		"""

 		lm_logits = self.lm_head(hidden_states)