tweaks, including exporting on save/quit

mrq 2023-08-23 16:43:03 -05:00
parent d106598403
commit 4585824cd3
5 changed files with 42 additions and 7 deletions

View File

@@ -372,6 +372,10 @@ class Trainer:
 save_on_oom: bool = True
 save_on_quit: bool = True
+export_on_save: bool = False
+export_on_quit: bool = False
 save_frequency: int = 100
 keep_last_checkpoints: int = 0
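
The two new flags tie weight export to the existing save/quit handling in the training loop (see the train.py hunk further down). A minimal sketch of enabling them from code; the import path is an assumption, the attribute names come from the diff above:

from vall_e.config import cfg  # hypothetical module path for the global config singleton

cfg.trainer.export_on_save = True   # a "save" command also exports the weights
cfg.trainer.export_on_quit = True   # a "quit" command also exports the weights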

View File

@@ -10,7 +10,7 @@ import random
 import torch
 from .config import cfg
-from .emb.qnt import trim_random, repeat_extend_audio, merge_audio, decode_to_file
+from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
 from collections import defaultdict
 from functools import cache, cached_property
@@ -253,7 +253,7 @@ class Dataset(_Dataset):
 qnt = _load_quants(path)
 if cfg.dataset.prompt_duration > 0 and trim_length < qnt.shape[0]:
-    qnt = trim_random( qnt, trim_length )
+    qnt = trim( qnt, trim_length )
 prom_list.append(qnt)
 prom_length += qnt.shape[0]
@@ -264,7 +264,7 @@ class Dataset(_Dataset):
 prom = torch.cat(prom_list)
 if cfg.dataset.prompt_duration > 0 and trim_length < prom.shape[0]:
-    prom = trim_random( prom, trim_length )
+    prom = trim( prom, trim_length )
 return prom

View File

@@ -180,7 +180,19 @@ def encode_from_file(path, device="cuda"):
 # Helper Functions
+# trims from the start, up to `target`
+def trim( qnt, target ):
+    length = max( qnt.shape[0], qnt.shape[1] )
+    start = 0
+    end = start + target
+    if end >= length:
+        start = length - target
+        end = length
+    return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
 # trims a random piece of audio, up to `target`
 # to-do: try and align to EnCodec window
 def trim_random( qnt, target ):
     length = max( qnt.shape[0], qnt.shape[1] )
     start = int(length * random.random())
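
For reference, a quick usage sketch of the two helpers. The tensor shape and codebook count are assumptions; the 75-frames-per-second figure matches the `75 * cfg.dataset.prompt_duration` call in the inference changes below.

import torch

qnt = torch.zeros( 600, 8, dtype=torch.int16 )  # assumed [frames, codebooks] EnCodec codes, roughly 8 seconds at 75 fps

head  = trim( qnt, 75 * 3 )         # deterministic: the first ~3 seconds, clamped to the clip length
piece = trim_random( qnt, 75 * 3 )  # up to ~3 seconds starting at a random offset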

View File

@@ -7,7 +7,7 @@ from einops import rearrange
 from pathlib import Path
 from .emb import g2p, qnt
-from .emb.qnt import trim_random
+from .emb.qnt import trim, trim_random
 from .utils import to_device
 from .config import cfg
from .config import cfg
@@ -25,12 +25,14 @@ class TTS():
 if config:
     cfg.load_yaml( config )
     cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
     try:
         cfg.format()
     except Exception as e:
         pass
+self.symmap = None
 if ar_ckpt and nar_ckpt:
     self.ar_ckpt = ar_ckpt
     self.nar_ckpt = nar_ckpt
@@ -40,6 +42,8 @@ class TTS():
 if name.startswith("ar"):
     self.ar = model
     state = torch.load(self.ar_ckpt)
+    if "symmap" in state:
+        self.symmap = state['symmap']
     if "module" in state:
         state = state['module']
     self.ar.load_state_dict(state)
@@ -47,6 +51,8 @@ class TTS():
 elif name.startswith("nar"):
     self.nar = model
     state = torch.load(self.nar_ckpt)
+    if "symmap" in state:
+        self.symmap = state['symmap']
     if "module" in state:
         state = state['module']
     self.nar.load_state_dict(state)
@@ -54,7 +60,9 @@ class TTS():
 else:
     self.load_models()
-self.symmap = get_phone_symmap()
+if self.symmap is None:
+    self.symmap = get_phone_symmap()
 self.ar.eval()
 self.nar.eval()
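
Taken together with the export change in train.py, the loader now prefers a phoneme symbol map stored inside the checkpoint and only falls back to `get_phone_symmap()` when none is found. A rough sketch of the checkpoint layout it expects; the file name is hypothetical:

import torch

state = torch.load( "ar.pth" )          # hypothetical path to an exported AR checkpoint
symmap = state.get( "symmap" )          # present on checkpoints exported after this change
weights = state.get( "module", state )  # some checkpoints nest the weights under "module"
# model.load_state_dict( weights ); fall back to get_phone_symmap() if symmap is None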
@@ -78,7 +86,7 @@ class TTS():
     phones = [ " " if not p else p for p in content ]
     return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
-def encode_audio( self, paths, trim=True ):
+def encode_audio( self, paths, should_trim=True ):
     # already a tensor, return it
     if isinstance( paths, Tensor ):
         return paths
@@ -90,7 +98,7 @@ class TTS():
 # merge inputs
 res = torch.cat([qnt.encode_from_file( path )[0].t().to(torch.int16) for path in paths])
-if trim:
+if should_trim:
     res = trim_random( res, int( 75 * cfg.dataset.prompt_duration ) )
 return res
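
The keyword argument was presumably renamed so it no longer shadows the newly imported `trim` helper; callers just pass `should_trim` instead. A hypothetical call:

tts = TTS( config="config.yaml" )  # hypothetical instantiation
prom = tts.encode_audio( ["reference.wav"], should_trim=True )  # keeps a random window of cfg.dataset.prompt_duration seconds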

View File

@@ -324,14 +324,24 @@ def train(
 save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
 saving_commands = ["save"]
+export_commands = ["export"]
 if cfg.trainer.save_on_quit:
     saving_commands.append("quit")
+if cfg.trainer.export_on_quit:
+    export_commands.append("quit")
+if cfg.trainer.export_on_save:
+    export_commands.append("save")
 if engines.global_step != last_save_step:
     if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
         engines.save_checkpoint()
         last_save_step = engines.global_step
+        if command in export_commands and is_global_leader():
+            engines.export(userdata={"symmap": get_phone_symmap()})
 if engines.global_step != last_eval_step:
     if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
@@ -343,4 +353,5 @@ def train(
last_eval_step = engines.global_step
if command in ["quit"]:
if cfg.export_on_quit:
return