tweaks, including exporting on save/quit
parent d106598403
commit 4585824cd3
@@ -372,6 +372,10 @@ class Trainer:
 	save_on_oom: bool = True
 	save_on_quit: bool = True
+
+	export_on_save: bool = False
+	export_on_quit: bool = False
+
 	save_frequency: int = 100
 
 	keep_last_checkpoints: int = 0
 
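The four flags above only take effect in the `def train(` hunk further down, where console commands are matched against per-action command lists. A rough sketch of that mapping, with `TrainerFlags` and `resolve_commands` being hypothetical names used purely for illustration:

from dataclasses import dataclass

@dataclass
class TrainerFlags:  # hypothetical stand-in for the Trainer config fields above
    save_on_quit: bool = True
    export_on_save: bool = False
    export_on_quit: bool = False

def resolve_commands(flags: TrainerFlags):
    # mirrors the logic added to train(): "quit"/"save" are folded into the
    # export command list only when the corresponding flag is enabled
    saving_commands = ["save"]
    export_commands = ["export"]
    if flags.save_on_quit:
        saving_commands.append("quit")
    if flags.export_on_quit:
        export_commands.append("quit")
    if flags.export_on_save:
        export_commands.append("save")
    return saving_commands, export_commands

# e.g. resolve_commands(TrainerFlags(export_on_save=True))
# -> (["save", "quit"], ["export", "save"])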
@@ -10,7 +10,7 @@ import random
 import torch
 
 from .config import cfg
-from .emb.qnt import trim_random, repeat_extend_audio, merge_audio, decode_to_file
+from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
 
 from collections import defaultdict
 from functools import cache, cached_property
@@ -253,7 +253,7 @@ class Dataset(_Dataset):
 				qnt = _load_quants(path)
 
 			if cfg.dataset.prompt_duration > 0 and trim_length < qnt.shape[0]:
-				qnt = trim_random( qnt, trim_length )
+				qnt = trim( qnt, trim_length )
 
 			prom_list.append(qnt)
 			prom_length += qnt.shape[0]
@@ -264,7 +264,7 @@ class Dataset(_Dataset):
 		prom = torch.cat(prom_list)
 
 		if cfg.dataset.prompt_duration > 0 and trim_length < prom.shape[0]:
-			prom = trim_random( prom, trim_length )
+			prom = trim( prom, trim_length )
 
 		return prom
 
@@ -180,7 +180,19 @@ def encode_from_file(path, device="cuda"):
 
 # Helper Functions
 
+# trims from the start, up to `target`
+def trim( qnt, target ):
+	length = max( qnt.shape[0], qnt.shape[1] )
+	start = 0
+	end = start + target
+	if end >= length:
+		start = length - target
+		end = length
+
+	return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
+
 # trims a random piece of audio, up to `target`
+# to-do: try and align to EnCodec window
 def trim_random( qnt, target ):
 	length = max( qnt.shape[0], qnt.shape[1] )
 	start = int(length * random.random())
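For a quick feel of how the new deterministic `trim()` differs from the existing `trim_random()`, here is a self-contained toy run. `trim()` is copied verbatim from the hunk above; the tail of `trim_random()` is reconstructed by analogy, since the hunk only shows its first two lines, and the (frames, codebooks) layout of the fake tensor is an assumption:

import random
import torch

# copied from the hunk above
def trim( qnt, target ):
    length = max( qnt.shape[0], qnt.shape[1] )
    start = 0
    end = start + target
    if end >= length:
        start = length - target
        end = length
    return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]

def trim_random( qnt, target ):
    length = max( qnt.shape[0], qnt.shape[1] )
    start = int(length * random.random())
    # remainder assumed to mirror trim(): clamp the window to the tensor length
    end = start + target
    if end >= length:
        start = length - target
        end = length
    return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]

qnt = torch.arange(10).unsqueeze(1).expand(10, 2)   # 10 frames, 2 codebooks (assumed layout)
print(trim(qnt, 4)[:, 0].tolist())          # always the leading window: [0, 1, 2, 3]
print(trim_random(qnt, 4)[:, 0].tolist())   # a random 4-frame window, e.g. [5, 6, 7, 8]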
@@ -7,7 +7,7 @@ from einops import rearrange
 from pathlib import Path
 
 from .emb import g2p, qnt
-from .emb.qnt import trim_random
+from .emb.qnt import trim, trim_random
 from .utils import to_device
 
 from .config import cfg
@@ -25,12 +25,14 @@ class TTS():
 
 		if config:
 			cfg.load_yaml( config )
+			cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
 
 		try:
 			cfg.format()
 		except Exception as e:
 			pass
 
+		self.symmap = None
 		if ar_ckpt and nar_ckpt:
 			self.ar_ckpt = ar_ckpt
 			self.nar_ckpt = nar_ckpt
@@ -40,6 +42,8 @@ class TTS():
 				if name.startswith("ar"):
 					self.ar = model
 					state = torch.load(self.ar_ckpt)
+					if "symmap" in state:
+						self.symmap = state['symmap']
 					if "module" in state:
 						state = state['module']
 					self.ar.load_state_dict(state)
@@ -47,6 +51,8 @@ class TTS():
 				elif name.startswith("nar"):
 					self.nar = model
 					state = torch.load(self.nar_ckpt)
+					if "symmap" in state:
+						self.symmap = state['symmap']
 					if "module" in state:
 						state = state['module']
 					self.nar.load_state_dict(state)
@@ -54,7 +60,9 @@ class TTS():
 		else:
 			self.load_models()
 
-		self.symmap = get_phone_symmap()
+		if self.symmap is None:
+			self.symmap = get_phone_symmap()
+
 		self.ar.eval()
 		self.nar.eval()
 
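Across the `class TTS()` hunks, the phoneme symbol map now travels with the checkpoint: it starts as None, is overwritten by a 'symmap' entry when the AR/NAR checkpoint carries one, and only falls back to `get_phone_symmap()` otherwise. A condensed sketch of that resolution order; `resolve_symmap` is a hypothetical helper, while the 'symmap' and optional 'module' keys are taken straight from the diff:

import torch

def resolve_symmap(ckpt_path, default_symmap):
    # Mirrors the checkpoint handling added to TTS.__init__: prefer the symbol
    # map stored in the checkpoint, unwrap a 'module' sub-dict if present
    # (e.g. from a wrapped training engine), else use the default phone symmap.
    state = torch.load(ckpt_path, map_location="cpu")
    symmap = state["symmap"] if "symmap" in state else default_symmap
    weights = state["module"] if "module" in state else state
    return symmap, weights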
@@ -78,7 +86,7 @@ class TTS():
 		phones = [ " " if not p else p for p in content ]
 		return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
 
-	def encode_audio( self, paths, trim=True ):
+	def encode_audio( self, paths, should_trim=True ):
 		# already a tensor, return it
 		if isinstance( paths, Tensor ):
 			return paths
@@ -90,7 +98,7 @@ class TTS():
 		# merge inputs
 		res = torch.cat([qnt.encode_from_file( path )[0].t().to(torch.int16) for path in paths])
 
-		if trim:
+		if should_trim:
 			res = trim_random( res, int( 75 * cfg.dataset.prompt_duration ) )
 
 		return res
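The `trim=True` to `should_trim=True` rename presumably keeps the keyword from shadowing the `trim()` helper that this file now imports (see the `from .emb.qnt import trim, trim_random` hunk above). A minimal, self-contained illustration of the clash the rename avoids; none of the names below are the repo's own:

def trim(codes, target):                    # stands in for the imported module-level trim()
    return codes[:target]

def encode_audio_old(codes, trim=True):
    # inside this body `trim` is a bool, so calling the helper above would raise
    # TypeError: 'bool' object is not callable
    return trim(codes, 4) if trim else codes

def encode_audio_new(codes, should_trim=True):
    return trim(codes, 4) if should_trim else codes

print(encode_audio_new(list(range(10))))    # [0, 1, 2, 3]
# encode_audio_old(list(range(10)))         # would fail with the TypeError noted above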
@@ -324,15 +324,25 @@ def train(
 			save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
 
 			saving_commands = ["save"]
+			export_commands = ["export"]
 
 			if cfg.trainer.save_on_quit:
 				saving_commands.append("quit")
 
+			if cfg.trainer.export_on_quit:
+				export_commands.append("quit")
+
+			if cfg.trainer.export_on_save:
+				export_commands.append("save")
+
 			if engines.global_step != last_save_step:
 				if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
 					engines.save_checkpoint()
 					last_save_step = engines.global_step
 
+			if command in export_commands and is_global_leader():
+				engines.export(userdata={"symmap": get_phone_symmap()})
+
 			if engines.global_step != last_eval_step:
 				if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
 					do_gc()
@@ -343,4 +353,5 @@ def train(
 					last_eval_step = engines.global_step
 
 			if command in ["quit"]:
-				return
+				if cfg.export_on_quit:
+					return