distributed training works now (hopefully)

mrq 2023-08-13 22:07:45 -05:00
parent 2af09d0bef
commit d7deaf6def
13 changed files with 73 additions and 58 deletions

View File

@@ -8,8 +8,6 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
> **Note** this is highly experimental. While I seem to have audited and tightened things down as much as I can, I'm still trying to produce a decent model out of it. You're free to train your own model if you happen to have the massive compute for it, but it's quite the beast to properly feed. This README won't get much love until I truly nail out a quasi-decent model.
-> **Note** Distributed training seems broken? I'm not really sure how to test it, as my two 6800XTs have been redistributed for now, and the last time I tried using them for this, things weren't good.
> **Note** You can follow along with my pseudo-blog in an issue [here](https://git.ecker.tech/mrq/ai-voice-cloning/issues/152). I currently have a dataset clocking in at 3400+ trimmed hours.
### Requirements
@@ -49,6 +47,8 @@ Training is very dependent on:
#### Leverage Your Own
+> **Note** It is highly recommended to utilize [mrq/ai-voice-cloning](https://git.ecker.tech/mrq/ai-voice-cloning) with `--tts-backend="vall-e"` to handle transcription and dataset preparations.
1. Put your data into a folder, e.g. `./data/custom`. Audio files should be named with the suffix `.wav` and text files with `.txt`.
2. Quantize the data:
@@ -66,6 +66,8 @@ python -m vall_e.emb.g2p ./data/custom
4. Customize your configuration and define the dataset by modifying `./data/config.yml`. Refer to `./vall_e/config.py` for details. If you want to choose between different model presets, check `./vall_e/models/__init__.py`.
+> **Note** Be sure to set `distributed: True` to ensure the `DistributedSampler` is used. In the future, I'll have it automagically detect this.
If you're interested in creating an HDF5 copy of your dataset, simply invoke:

View File

@@ -1,10 +1,8 @@
dataset:
  training: [
  ]
  validation: [
  ]
  speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
@@ -15,30 +13,30 @@ dataset:
  workers: 8
  cache: True
-  phones_range: [4, 512]
+  phones_range: [4, 128]
-  duration_range: [1.0, 24.0]
+  duration_range: [1.0, 12.0]
  random_utterance: 1.0
-  max_prompts: 6
+  max_prompts: 3
-  prompt_duration: 6.0
+  prompt_duration: 3.0
models:
  _models:
    - name: "ar"
-      size: "quarter"
+      size: "full"
      resp_levels: 1
      arch_type: "retnet"
    - name: "nar"
-      size: "quarter"
+      size: "full"
      resp_levels: 1
      arch_type: "retnet"
  prom_levels: 2
hyperparameters:
-  batch_size: 32
+  batch_size: 8
-  gradient_accumulation_steps: 4
+  gradient_accumulation_steps: 16
  gradient_clipping: 100
  optimizer: Adamw
@@ -90,15 +88,18 @@ trainer:
  gc_mode: None # "global_step"
-  weight_dtype: bfloat16
+  weight_dtype: bfloat16 # float16, float32
  backend: deepspeed
  deepspeed:
-    zero_optimization_level: 0
+    zero_optimization_level: 2
    use_compression_training: True
inference:
  use_vocos: True
bitsandbytes:
  enabled: false
+device: cuda
+distributed: False
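A quick sanity check on the hyperparameter change above: the per-step batch shrinks while gradient accumulation grows, so the effective per-GPU batch size is unchanged; under distributed training the global batch then scales with the number of ranks. A back-of-the-envelope sketch (not code from the repo):

```python
# effective per-GPU batch size before and after this commit
old_effective = 32 * 4   # batch_size * gradient_accumulation_steps
new_effective = 8 * 16
assert old_effective == new_effective == 128

# with N ranks each feeding a DistributedSampler shard, the global batch is 128 * N
print(old_effective, new_effective)
```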

View File

@@ -17,7 +17,7 @@ def main():
    args = parser.parse_args()
    tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
-    tts.inference( text=args.text, reference=args.reference, out_path=args.out_path, max_ar_samples=args.max_ar_samples, ar_temp=args.ar_temp, nar_temp=args.nar_temp )
+    tts.inference( text=args.text, reference=args.reference, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp )
if __name__ == "__main__":
    main()

View File

@@ -7,6 +7,8 @@ import subprocess
import sys
import time
+import torch
from dataclasses import asdict, dataclass
from dataclasses import dataclass, field
@@ -379,7 +381,7 @@ class Trainer:
    def dtype(self):
        if self.weight_dtype == "float16":
            return torch.float16
-        if cfg.trainer.weight_dtype == "bfloat16":
+        if self.weight_dtype == "bfloat16":
            return torch.bfloat16
        return torch.float32
@@ -399,6 +401,7 @@ class BitsAndBytes:
@dataclass()
class Config(_Config):
    device: str = "cuda"
+    distributed: bool = False
    dataset: Dataset = field(default_factory=lambda: Dataset)
    models: Models = field(default_factory=lambda: Models)
@@ -433,21 +436,24 @@ class Config(_Config):
cfg = Config.from_cli()
-# OmegaConf doesn't actually coerce the dicts into the @dataclass decorated classes, for some god forsaken reason, so we coerce them ourselves
+# OmegaConf might not coerce the dicts into the @dataclass decorated classes, so we (try to) coerce them ourselves
-cfg.dataset = Dataset(**cfg.dataset)
-cfg.models = Models(**cfg.models)
-cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
-cfg.evaluation = Evaluation(**cfg.evaluation)
-cfg.trainer = Trainer(**cfg.trainer)
-cfg.inference = Inference(**cfg.inference)
-cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
-cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
+try:
+    cfg.dataset = Dataset(**cfg.dataset)
+    cfg.models = Models(**cfg.models)
+    cfg.hyperparameters = Hyperparameters(**cfg.hyperparameters)
+    cfg.evaluation = Evaluation(**cfg.evaluation)
+    cfg.trainer = Trainer(**cfg.trainer)
+    cfg.inference = Inference(**cfg.inference)
+    cfg.bitsandbytes = BitsAndBytes(**cfg.bitsandbytes)
+    cfg.trainer.deepspeed = DeepSpeed(**cfg.trainer.deepspeed)
+except Exception as e:
+    pass
# cached_property stopped working...
if cfg.dataset.use_hdf5:
    try:
-        cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'a')
+        cfg.hdf5 = h5py.File(f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', 'r' if cfg.distributed else 'a')
    except Exception as e:
        print("Error while opening HDF5 file:", f'{cfg.cfg_path}/{cfg.dataset.hdf5_name}', str(e))
        cfg.dataset.use_hdf5 = False
@@ -457,4 +463,4 @@ if not cfg.dataset.use_hdf5:
    cfg.dataset.validation = [ Path(dir) for dir in cfg.dataset.validation ]
if __name__ == "__main__":
    print(cfg)
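The README note above says the `distributed` flag should eventually be auto-detected. Purely as an illustration of one way to do that (not part of this commit), the usual trick is to key off the environment variables that `torchrun` and the DeepSpeed launcher export for each rank:

```python
import os

def detect_distributed() -> bool:
    # torchrun / deepspeed set WORLD_SIZE (and RANK/LOCAL_RANK) per process;
    # a single-process run leaves them unset, so default to 1
    return int(os.environ.get("WORLD_SIZE", "1")) > 1

# e.g. Config.distributed could default to this instead of a hard-coded False
print(detect_distributed())
```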

View File

@@ -20,6 +20,7 @@ from typing import Any
from torch import Tensor
from torch.utils.data import DataLoader, Dataset as _Dataset
+from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
# torch.multiprocessing.set_sharing_strategy("file_system")
@@ -312,13 +313,14 @@ def _create_dataloader(dataset, training):
    return DataLoader(
        dataset=dataset,
        batch_size=cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size,
-        shuffle=True, # training
+        shuffle=False if cfg.distributed else True, # training
        drop_last=training,
        num_workers=cfg.dataset.workers,
        collate_fn=collate_fn,
        persistent_workers=True,
        pin_memory=False, # True,
        worker_init_fn=_seed_worker,
+        sampler=DistributedSampler(dataset) if cfg.distributed else dataset.sampler
    )
def _load_dataset_paths():
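One general `DistributedSampler` caveat worth keeping in mind with the dataloader change above (standard PyTorch behavior, not something this diff addresses): the sampler only reshuffles if `set_epoch()` is called each epoch, otherwise every rank replays the same order. A minimal sketch, assuming the process group has already been initialized by the launcher:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16).float())
sampler = DistributedSampler(dataset, shuffle=True)  # requires torch.distributed to be initialized
loader = DataLoader(dataset, batch_size=4, sampler=sampler, shuffle=False)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reseed the per-rank shuffle; skipping this repeats epoch 0's order
    for (batch,) in loader:
        pass
```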

View File

@ -169,7 +169,8 @@ def encode_from_file(path, device="cuda"):
if isinstance( path, list ): if isinstance( path, list ):
return encode_from_files( path, device ) return encode_from_files( path, device )
else: else:
wav, sr = torchaudio.load(str(path), format=path[-3:]) path = str(path)
wav, sr = torchaudio.load(path, format=path[-3:])
if wav.shape[0] == 2: if wav.shape[0] == 2:
wav = wav[:1] wav = wav[:1]
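The change above exists because `pathlib.Path` objects aren't subscriptable, so the old `path[-3:]` would raise a `TypeError` whenever a `Path` was passed in; converting to `str` first (or using the suffix directly) sidesteps that:

```python
from pathlib import Path

p = Path("audio/sample.wav")
# p[-3:] raises TypeError ("'PosixPath' object is not subscriptable")
s = str(p)
print(s[-3:])        # "wav" -- what the torchaudio.load(format=...) call expects
print(p.suffix[1:])  # "wav" -- an equivalent that avoids slicing altogether
```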

View File

@@ -275,7 +275,7 @@ class Engines(dict[str, Engine]):
            stats.update(flatten_dict({ name.split("-")[0]: stat }))
        return stats
-    def step(self, batch, feeder: TrainFeeder = default_feeder, device=torch.cuda.current_device()):
+    def step(self, batch, feeder: TrainFeeder = default_feeder):
        total_elapsed_time = 0
        stats: Any = dict()
@@ -283,10 +283,9 @@ class Engines(dict[str, Engine]):
        if cfg.trainer.gc_mode == 'step':
            do_gc()
-        batch = to_device(batch, device)
        for name, engine in self.items():
-            #torch.cuda.synchronize()
+            device = engine.device
            if cfg.trainer.gc_mode == 'substep':
                do_gc()
@@ -294,10 +293,9 @@ class Engines(dict[str, Engine]):
            start_time = time.time()
            tries = 4
-            n_ooms = torch.zeros([], device=cfg.device)
+            n_ooms = torch.zeros([], device=device)
-            if cfg.trainer.aggressive_optimizations:
-                batch = to_device(batch, device)
+            batch = to_device(batch, device)
            if not cfg.trainer.check_for_oom:
                res = feeder( engine=engine, batch=batch )
@@ -336,7 +334,7 @@ class Engines(dict[str, Engine]):
            loss, engine_stats = res
            engine_stats |= self.gather_attribute("scalar")
-            n_ooms = torch.zeros([], device=cfg.device)
+            n_ooms = torch.zeros([], device=device)
            if cfg.trainer.aggressive_optimizations:
                batch = to_device(batch, 'cpu')
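A side note on the dropped `device=torch.cuda.current_device()` default in `step()` (my reading of why it matters here, not something the commit states): Python evaluates default arguments once, when the `def` executes, so the old signature froze whatever device was current at import time instead of following each engine's own device in a multi-GPU run. A CPU-only illustration of that evaluation timing:

```python
import torch

def pick_device():
    # stand-in for torch.cuda.current_device(); runs when the def below is evaluated
    print("default evaluated")
    return torch.device("cpu")

def step(batch, device=pick_device()):  # "default evaluated" prints here, exactly once
    return batch.to(device)

step(torch.zeros(2))  # no re-evaluation on calls...
step(torch.ones(2))   # ...so a per-rank device assigned after import would never be picked up
```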

View File

@@ -27,7 +27,9 @@ def main():
        model = models[name]
        outpath = f'{args.path}/{name}.pt'
-        torch.save(model, outpath)
+        torch.save({
+            'module': model.state_dict()
+        }, outpath)
        print(f"Exported {name} to {outpath}")
if __name__ == "__main__":
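With this change the exported `.pt` files are plain dicts holding the weights under a `'module'` key, matching what the trainer and the `TTS` loader below expect. Loading one back looks roughly like this (a minimal sketch, with a hypothetical `model` instance):

```python
import torch

def load_exported(model: torch.nn.Module, path: str) -> torch.nn.Module:
    # checkpoints written by the export script above are {'module': state_dict}
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["module"])
    return model
```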

View File

@@ -9,6 +9,8 @@ from .utils import to_device
from .config import cfg
from .export import load_models
+from .models import get_models
+from .data import get_phone_symmap
class TTS():
    def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device="cuda" ):
@@ -19,11 +21,21 @@ class TTS():
        self.output_sample_rate = 24000
        if ar_ckpt and nar_ckpt:
-            self.load_ar( ar_ckpt )
-            self.load_nar( nar_ckpt )
+            self.ar_ckpt = ar_ckpt
+            self.nar_ckpt = nar_ckpt
+            models = get_models(cfg.models.get())
+            for name, model in models.items():
+                if name.startswith("ar"):
+                    self.ar = model.to(self.device, dtype=torch.float32)
+                    self.ar.load_state_dict(torch.load(self.ar_ckpt)['module'])
+                elif name.startswith("nar"):
+                    self.nar = model.to(self.device, dtype=torch.float32)
+                    self.nar.load_state_dict(torch.load(self.nar_ckpt)['module'])
        else:
            self.load_models( config )
+        self.symmap = get_phone_symmap()
        self.ar.eval()
        self.nar.eval()
@@ -39,24 +51,13 @@ class TTS():
        for name in models:
            model = models[name]
            if name[:2] == "ar":
-                self.ar = model.to(self.device)
+                self.ar = model.to(self.device, dtype=torch.float32)
                self.symmap = self.ar.phone_symmap
            elif name[:3] == "nar":
-                self.nar = model.to(self.device)
+                self.nar = model.to(self.device, dtype=torch.float32)
            else:
                print("Unknown:", name)
-    def load_ar( self, ckpt ):
-        self.ar_ckpt = ckpt
-        self.ar = torch.load(self.ar_ckpt).to(self.device)
-        self.symmap = self.ar.phone_symmap
-    def load_nar( self, ckpt ):
-        self.nar_ckpt = nar_ckpt
-        self.nar = torch.load(self.nar_ckpt).to(self.device)
    def encode_text( self, text, lang_marker="en" ):
        text = g2p.encode(text)
        phones = [f"<{lang_marker}>"] + [ " " if not p else p for p in text ] + [f"</{lang_marker}>"]
@@ -70,7 +71,7 @@ class TTS():
    def inference( self, text, reference, mode="both", max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path="./.tmp.wav" ):
        prom = self.encode_audio( reference )
-        phns = self.encode_text(text)
+        phns = self.encode_text( text )
        prom = to_device(prom, self.device).to(torch.int16)
        phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
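For reference, the `__main__` entry point earlier in this commit drives the class with exactly these keyword arguments; as a standalone sketch (module path and file paths are assumptions, not taken from the diff):

```python
from vall_e.inference import TTS  # module path assumed from the repo layout

tts = TTS(
    config="./data/config.yml",
    ar_ckpt="./models/ar.pt",
    nar_ckpt="./models/nar.pt",
    device="cuda",
)
tts.inference(
    text="hello world",
    reference="./reference.wav",  # prompt audio to clone from
    out_path="./output.wav",
    max_ar_steps=6 * 75,          # renamed from max_ar_samples in this commit
    ar_temp=1.0,
    nar_temp=1.0,
)
```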

View File

@@ -114,7 +114,7 @@ def example_usage():
    from ..engines import Engine
    from tqdm import tqdm
-    device = "cpu"
+    device = "cuda"
    x8 = partial(repeat, pattern="t -> t l", l=2)
    symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
    def tokenize(content, lang_marker="en"):

View File

@@ -272,6 +272,7 @@ class Base(nn.Module):
        """
        batch_size = len(text_list)
        x_list = self._samplewise_merge_tensors(
            self.text_emb(text_list),
            self.proms_emb(proms_list),
@@ -281,6 +282,7 @@ class Base(nn.Module):
        x, m = list_to_tensor(x_list)
        if self.arch_type == "transformer":
            x = self.sin_emb.add_pe(x)
            for block in self.blocks:

View File

@@ -19,7 +19,7 @@ from collections import defaultdict
from tqdm import tqdm
-mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cuda")
+mel_stft_loss = auraloss.freq.MelSTFTLoss(24_000, device="cpu")
def center_crop(x, len):
    start = (x.shape[-1] - len) // 2
@@ -89,10 +89,10 @@ def run_eval(engines, eval_name, dl):
        min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
        ref_audio = ref_audio[..., 0:min_length]
        hyp_audio = hyp_audio[..., 0:min_length]
        try:
            stats['loss'].append(mel_stft_loss(hyp_audio, ref_audio).item())
        except Exception as e:
+            stats['loss'].append(0)
            print(str(e))
    for batch in tqdm(dl):

View File

@@ -80,7 +80,7 @@ def load_engines():
        if cfg.trainer.load_state_dict:
            load_path = cfg.ckpt_dir / name / "fp32.pth"
-            model.load_state_dict(torch.load(load_path))
+            model.load_state_dict(torch.load(load_path)['module'])
        engines[name] = Engine(
            model=model,