tweaks, including exporting on save/quit

mrq 2023-08-23 16:43:03 -05:00
parent d106598403
commit 4585824cd3
5 changed files with 42 additions and 7 deletions

View File

@@ -372,6 +372,10 @@ class Trainer:
 	save_on_oom: bool = True
 	save_on_quit: bool = True
+	export_on_save: bool = False
+	export_on_quit: bool = False
 	save_frequency: int = 100
 	keep_last_checkpoints: int = 0
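Both new flags default to off. A minimal sketch of turning them on from code (the cfg.trainer attribute path matches the accesses later in this commit; everything else is illustrative):

# hypothetical: export weights alongside normal checkpointing
cfg.trainer.export_on_save = True   # export every time a checkpoint is written
cfg.trainer.export_on_quit = True   # export once more when the loop receives "quit"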

View File

@@ -10,7 +10,7 @@ import random
 import torch

 from .config import cfg
-from .emb.qnt import trim_random, repeat_extend_audio, merge_audio, decode_to_file
+from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file

 from collections import defaultdict
 from functools import cache, cached_property
@@ -253,7 +253,7 @@ class Dataset(_Dataset):
 			qnt = _load_quants(path)

 			if cfg.dataset.prompt_duration > 0 and trim_length < qnt.shape[0]:
-				qnt = trim_random( qnt, trim_length )
+				qnt = trim( qnt, trim_length )

 			prom_list.append(qnt)
 			prom_length += qnt.shape[0]
@@ -264,7 +264,7 @@ class Dataset(_Dataset):
 		prom = torch.cat(prom_list)

 		if cfg.dataset.prompt_duration > 0 and trim_length < prom.shape[0]:
-			prom = trim_random( prom, trim_length )
+			prom = trim( prom, trim_length )

 		return prom
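The dataset's prompt construction now takes the head of the utterance instead of a random window, so a given clip yields the same prompt every epoch. Rough sketch of the call, assuming trim_length is the prompt duration expressed in EnCodec frames at 75 frames per second (the figure used in the inference code later in this commit):

trim_length = int(75 * cfg.dataset.prompt_duration)  # e.g. 225 frames for a 3 s prompt (assumed)
qnt = trim( qnt, trim_length )           # new: always frames 0..trim_length-1
# qnt = trim_random( qnt, trim_length )  # old: a different window every call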

View File

@@ -180,7 +180,19 @@ def encode_from_file(path, device="cuda"):

 # Helper Functions

+# trims from the start, up to `target`
+def trim( qnt, target ):
+	length = max( qnt.shape[0], qnt.shape[1] )
+	start = 0
+	end = start + target
+	if end >= length:
+		start = length - target
+		end = length
+	return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
+
 # trims a random piece of audio, up to `target`
+# to-do: try and align to EnCodec window
 def trim_random( qnt, target ):
 	length = max( qnt.shape[0], qnt.shape[1] )
 	start = int(length * random.random())
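The new helper slices along whichever dimension is longer (treated as time), so it handles both [frames, codebooks] and [codebooks, frames] layouts. A small sketch, assuming the package is importable as vall_e and the tensors are EnCodec codes at 75 frames per second:

import torch
from vall_e.emb.qnt import trim, trim_random  # assumed module path

codes = torch.zeros(600, 8, dtype=torch.int16)  # ~8 s of codes, 8 codebooks (made-up shape)
head = trim(codes, 225)          # deterministic: frames 0..224 -> shape (225, 8)
rand = trim_random(codes, 225)   # random window of up to 225 frames

head_t = trim(codes.t(), 225)    # (8, 600) layout: slices the time axis -> (8, 225)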

View File

@@ -7,7 +7,7 @@ from einops import rearrange
 from pathlib import Path

 from .emb import g2p, qnt
-from .emb.qnt import trim_random
+from .emb.qnt import trim, trim_random
 from .utils import to_device

 from .config import cfg
@@ -25,12 +25,14 @@ class TTS():
 		if config:
 			cfg.load_yaml( config )
+		cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing

 		try:
 			cfg.format()
 		except Exception as e:
 			pass

+		self.symmap = None
 		if ar_ckpt and nar_ckpt:
 			self.ar_ckpt = ar_ckpt
 			self.nar_ckpt = nar_ckpt
@@ -40,6 +42,8 @@ class TTS():
 				if name.startswith("ar"):
 					self.ar = model
 					state = torch.load(self.ar_ckpt)
+					if "symmap" in state:
+						self.symmap = state['symmap']
 					if "module" in state:
 						state = state['module']
 					self.ar.load_state_dict(state)
@@ -47,6 +51,8 @@ class TTS():
 				elif name.startswith("nar"):
 					self.nar = model
 					state = torch.load(self.nar_ckpt)
+					if "symmap" in state:
+						self.symmap = state['symmap']
 					if "module" in state:
 						state = state['module']
 					self.nar.load_state_dict(state)
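The two hunks above pull the phoneme symbol map out of the checkpoint when it is present, falling back to get_phone_symmap() below. A sketch of the checkpoint layout this implies, inferred from the loading code rather than from the export side:

import torch

state = torch.load("ar.pth")            # hypothetical checkpoint path
symmap = state.get("symmap")            # phone -> token id mapping saved at export time
weights = state.get("module", state)    # weights may be nested under "module" (DeepSpeed-style)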
@@ -54,7 +60,9 @@ class TTS():
 		else:
 			self.load_models()

+		if self.symmap is None:
 			self.symmap = get_phone_symmap()

 		self.ar.eval()
 		self.nar.eval()
@@ -78,7 +86,7 @@ class TTS():
 		phones = [ " " if not p else p for p in content ]
 		return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])

-	def encode_audio( self, paths, trim=True ):
+	def encode_audio( self, paths, should_trim=True ):
 		# already a tensor, return it
 		if isinstance( paths, Tensor ):
 			return paths
@@ -90,7 +98,7 @@ class TTS():
 		# merge inputs
 		res = torch.cat([qnt.encode_from_file( path )[0].t().to(torch.int16) for path in paths])

-		if trim:
+		if should_trim:
 			res = trim_random( res, int( 75 * cfg.dataset.prompt_duration ) )

 		return res

View File

@@ -324,15 +324,25 @@ def train(
 	save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency

 	saving_commands = ["save"]
+	export_commands = ["export"]

 	if cfg.trainer.save_on_quit:
 		saving_commands.append("quit")

+	if cfg.trainer.export_on_quit:
+		export_commands.append("quit")
+
+	if cfg.trainer.export_on_save:
+		export_commands.append("save")
+
 	if engines.global_step != last_save_step:
 		if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
 			engines.save_checkpoint()
 			last_save_step = engines.global_step

+			if command in export_commands and is_global_leader():
+				engines.export(userdata={"symmap": get_phone_symmap()})
+
 	if engines.global_step != last_eval_step:
 		if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
 			do_gc()
@@ -343,4 +353,5 @@ def train(
 			last_eval_step = engines.global_step

 	if command in ["quit"]:
+		if cfg.export_on_quit:
-		return
+			return
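Putting the flags together, a small worked example of the gating added above (flag values chosen for illustration; the list-building mirrors the training-loop code):

# assumed settings: save_on_quit=True, export_on_save=True, export_on_quit=False
save_on_quit, export_on_save, export_on_quit = True, True, False

saving_commands = ["save"] + (["quit"] if save_on_quit else [])
export_commands = ["export"] + (["quit"] if export_on_quit else []) + (["save"] if export_on_save else [])

print(saving_commands)  # ['save', 'quit']
print(export_commands)  # ['export', 'save']
# so a "save" command writes a checkpoint and then exports weights (with the phone
# symmap attached as userdata), while "quit" only writes a checkpoint before returning.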