diff --git a/data/config.yaml b/data/config.yaml
index f9181ab..5aa94a7 100644
--- a/data/config.yaml
+++ b/data/config.yaml
@@ -71,7 +71,7 @@ trainer:
   backend: deepspeed
 
   deepspeed:
-    inferencing: False
+    inferencing: True
     zero_optimization_level: 0
     use_compression_training: False
@@ -80,7 +80,7 @@ trainer:
   load_webui: False
 
 inference:
-  backend: deepspeed
+  backend: local
   normalize: False
 
 # some steps break under blanket (B)FP16 + AMP
diff --git a/tortoise_tts/config.py b/tortoise_tts/config.py
index 485f70a..3cb4aa4 100755
--- a/tortoise_tts/config.py
+++ b/tortoise_tts/config.py
@@ -22,9 +22,11 @@ from .tokenizer import VoiceBpeTokenizer
 # Yuck
 from transformers import PreTrainedTokenizerFast
 
+DEFAULT_YAML = Path(__file__).parent.parent / 'data/config.yaml'
+
 @dataclass()
 class BaseConfig:
-	yaml_path: str | None = None
+	yaml_path: str | None = DEFAULT_YAML
 
 	@property
 	def cfg_path(self):
@@ -554,7 +556,7 @@ class Config(BaseConfig):
 
 	@cached_property
 	def diskcache(self):
-		if self.yaml_path is not None and self.dataset.cache:
+		if self.yaml_path is not None and self.yaml_path != DEFAULT_YAML and self.dataset.cache:
 			return diskcache.Cache(self.cache_dir).memoize
 		return lambda: lambda x: x
diff --git a/tortoise_tts/inference.py b/tortoise_tts/inference.py
index d117542..e19c431 100755
--- a/tortoise_tts/inference.py
+++ b/tortoise_tts/inference.py
@@ -10,7 +10,7 @@ from .emb.mel import encode_from_files as encode_mel, trim, trim_random
 from .utils import to_device
 from .utils import wrapper as ml
 
-from .config import cfg
+from .config import cfg, DEFAULT_YAML
 from .models import get_models, load_model
 from .engines import load_engines, deepspeed_available
 from .data import get_phone_symmap, tokenize
@@ -31,6 +31,9 @@ class TTS():
 		self.input_sample_rate = 24000
 		self.output_sample_rate = 24000
 
+		if config is None:
+			config = DEFAULT_YAML
+
 		if config:
 			cfg.load_yaml( config )
diff --git a/tortoise_tts/models/unified_voice.py b/tortoise_tts/models/unified_voice.py
index ad37413..c1d6096 100644
--- a/tortoise_tts/models/unified_voice.py
+++ b/tortoise_tts/models/unified_voice.py
@@ -92,12 +92,15 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         self.lm_head = nn.Sequential(norm, linear)
         self.kv_cache = kv_cache
-
-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
         self.cached_mel_emb = None
 
+        # Model parallel
+        """
+        self.model_parallel = False
+        self.device_map = None
+        """
+
+    """
     def parallelize(self, device_map=None):
         self.device_map = (
             get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
@@ -115,6 +118,7 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         self.lm_head = self.lm_head.to("cpu")
         self.model_parallel = False
         torch.cuda.empty_cache()
+    """
 
     def get_output_embeddings(self):
         return self.lm_head
@@ -212,9 +216,11 @@ class GPT2InferenceModel(GPT2PreTrainedModel):
         hidden_states = transformer_outputs[0]
 
         # Set device for model parallelism
+        """
         if self.model_parallel:
             torch.cuda.set_device(self.transformer.first_device)
             hidden_states = hidden_states.to(self.lm_head.weight.device)
+        """
 
         lm_logits = self.lm_head(hidden_states)
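
Taken together, these hunks let the package run without an explicit config: `BaseConfig.yaml_path` now defaults to the bundled `data/config.yaml`, `TTS(config=None)` falls back to that same `DEFAULT_YAML`, and the reworked `diskcache` guard keeps the disk cache disabled when only the default YAML is in play. A minimal usage sketch under those assumptions (only the `TTS(config=...)` constructor and `DEFAULT_YAML` shown in the hunks above are taken from the source; everything else is illustrative):

```python
from tortoise_tts.inference import TTS
from tortoise_tts.config import DEFAULT_YAML

# No explicit config: __init__ now substitutes DEFAULT_YAML before the
# `if config:` branch, so cfg.load_yaml() is always reached.
tts = TTS()

# An explicit path still wins, and (unlike DEFAULT_YAML) keeps diskcache
# eligible when dataset.cache is enabled in that YAML.
tts = TTS(config="path/to/config.yaml")

# DEFAULT_YAML is a pathlib.Path resolved relative to the installed package.
assert DEFAULT_YAML.name == "config.yaml"
```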