tweaks, including exporting on save/quit
commit 4585824cd3
parent d106598403
@@ -372,6 +372,10 @@ class Trainer:
 	save_on_oom: bool = True
 	save_on_quit: bool = True
+
+	export_on_save: bool = False
+	export_on_quit: bool = False
+
 	save_frequency: int = 100
 
 	keep_last_checkpoints: int = 0
 
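Note: both new flags default to off; a minimal sketch of enabling them programmatically (the field names come from the Trainer dataclass above, while the `vall_e.config` import path is an assumption, inferred from the relative `.config` imports below):

from vall_e.config import cfg  # import path assumed

cfg.trainer.export_on_save = True  # also export weights whenever a checkpoint is saved
cfg.trainer.export_on_quit = True  # export one last time when the loop is told to quit
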
@@ -10,7 +10,7 @@ import random
 import torch
 
 from .config import cfg
-from .emb.qnt import trim_random, repeat_extend_audio, merge_audio, decode_to_file
+from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
 
 from collections import defaultdict
 from functools import cache, cached_property
 
@@ -253,7 +253,7 @@ class Dataset(_Dataset):
 			qnt = _load_quants(path)
 
 			if cfg.dataset.prompt_duration > 0 and trim_length < qnt.shape[0]:
-				qnt = trim_random( qnt, trim_length )
+				qnt = trim( qnt, trim_length )
 
 			prom_list.append(qnt)
 			prom_length += qnt.shape[0]
 
@@ -264,7 +264,7 @@ class Dataset(_Dataset):
 		prom = torch.cat(prom_list)
 
 		if cfg.dataset.prompt_duration > 0 and trim_length < prom.shape[0]:
-			prom = trim_random( prom, trim_length )
+			prom = trim( prom, trim_length )
 
 		return prom
 
@@ -180,7 +180,19 @@ def encode_from_file(path, device="cuda"):
 
 # Helper Functions
 
+# trims from the start, up to `target`
+def trim( qnt, target ):
+	length = max( qnt.shape[0], qnt.shape[1] )
+	start = 0
+	end = start + target
+	if end >= length:
+		start = length - target
+		end = length
+
+	return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
+
 # trims a random piece of audio, up to `target`
+# to-do: try and align to EnCodec window
 def trim_random( qnt, target ):
 	length = max( qnt.shape[0], qnt.shape[1] )
 	start = int(length * random.random())
 
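Note: a quick sanity check of the new `trim` against the existing `trim_random`; the shapes are illustrative (EnCodec-style codes at 75 frames/second, matching the `75 * cfg.dataset.prompt_duration` factor used in inference), and the import root is an assumption:

import torch
from vall_e.emb.qnt import trim, trim_random  # package root assumed

qnt = torch.zeros(1000, 8, dtype=torch.int16)   # ~13.3s of codes: 1000 frames x 8 quant levels
print( trim(qnt, 75 * 3).shape )        # torch.Size([225, 8]): deterministic first 3 seconds
print( trim(qnt.t(), 75 * 3).shape )    # torch.Size([8, 225]): the transposed layout is handled too
print( trim_random(qnt, 75 * 3).shape ) # a random window, up to 225 frames long
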
@@ -7,7 +7,7 @@ from einops import rearrange
 from pathlib import Path
 
 from .emb import g2p, qnt
-from .emb.qnt import trim_random
+from .emb.qnt import trim, trim_random
 from .utils import to_device
 
 from .config import cfg
 
@@ -25,12 +25,14 @@ class TTS():
 		if config:
 			cfg.load_yaml( config )
 			cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
 
 		try:
 			cfg.format()
 		except Exception as e:
 			pass
 
+		self.symmap = None
+
 		if ar_ckpt and nar_ckpt:
 			self.ar_ckpt = ar_ckpt
 			self.nar_ckpt = nar_ckpt
 
@@ -40,6 +42,8 @@ class TTS():
 			if name.startswith("ar"):
 				self.ar = model
 				state = torch.load(self.ar_ckpt)
+				if "symmap" in state:
+					self.symmap = state['symmap']
 				if "module" in state:
 					state = state['module']
 				self.ar.load_state_dict(state)
 
@@ -47,6 +51,8 @@ class TTS():
 			elif name.startswith("nar"):
 				self.nar = model
 				state = torch.load(self.nar_ckpt)
+				if "symmap" in state:
+					self.symmap = state['symmap']
 				if "module" in state:
 					state = state['module']
 				self.nar.load_state_dict(state)
 
@@ -54,7 +60,9 @@ class TTS():
 		else:
 			self.load_models()
 
-		self.symmap = get_phone_symmap()
+		if self.symmap is None:
+			self.symmap = get_phone_symmap()
+
 		self.ar.eval()
 		self.nar.eval()
 
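Note: the net effect of the three symmap hunks is "prefer the phoneme map stored in the checkpoint, else fall back to the default". A rough sketch; the checkpoint layout is inferred from the keys the hunks check, and the path and import root are assumptions:

import torch
from vall_e.data import get_phone_symmap  # import root assumed

state = torch.load("ar.pt")            # hypothetical checkpoint path
symmap = state.get("symmap")           # stored alongside the weights by the exporter
weights = state.get("module", state)   # unwrap a DeepSpeed-style "module" wrapper if present
if symmap is None:
	symmap = get_phone_symmap()        # default map, as in the constructor fallback above
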
@@ -78,7 +86,7 @@ class TTS():
 		phones = [ " " if not p else p for p in content ]
 		return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
 
-	def encode_audio( self, paths, trim=True ):
+	def encode_audio( self, paths, should_trim=True ):
 		# already a tensor, return it
 		if isinstance( paths, Tensor ):
 			return paths
 
@@ -90,7 +98,7 @@ class TTS():
 		# merge inputs
 		res = torch.cat([qnt.encode_from_file( path )[0].t().to(torch.int16) for path in paths])
 
-		if trim:
+		if should_trim:
 			res = trim_random( res, int( 75 * cfg.dataset.prompt_duration ) )
 
 		return res
 
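Note: renaming the parameter from `trim=` to `should_trim=` stops it from shadowing the newly imported `trim` helper inside `encode_audio`. Callers only change if they passed the flag by keyword; a hypothetical example, assuming `tts` is a constructed `TTS()` instance:

prom = tts.encode_audio( ["ref_a.wav", "ref_b.wav"], should_trim=True )  # paths are illustrative
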
@@ -324,14 +324,24 @@ def train(
 	save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
 
 	saving_commands = ["save"]
+	export_commands = ["export"]
 
 	if cfg.trainer.save_on_quit:
 		saving_commands.append("quit")
 
+	if cfg.trainer.export_on_quit:
+		export_commands.append("quit")
+
+	if cfg.trainer.export_on_save:
+		export_commands.append("save")
+
 	if engines.global_step != last_save_step:
 		if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
 			engines.save_checkpoint()
 			last_save_step = engines.global_step
 
+	if command in export_commands and is_global_leader():
+		engines.export(userdata={"symmap": get_phone_symmap()})
+
 	if engines.global_step != last_eval_step:
 		if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
 
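Note: tracing the routing above with both opt-in flags enabled, a "quit" command both saves and exports; a minimal sketch with the command lists built by hand:

saving_commands = ["save", "quit"]    # "quit" appended because cfg.trainer.save_on_quit defaults to True
export_commands = ["export", "quit"]  # "quit" appended because cfg.trainer.export_on_quit was enabled

command = "quit"
assert command in saving_commands     # -> engines.save_checkpoint()
assert command in export_commands     # -> engines.export(...), on the global leader only
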
@@ -343,4 +353,5 @@ def train(
 			last_eval_step = engines.global_step
 
 	if command in ["quit"]:
+		if cfg.export_on_quit:
 		return