forked from mrq/ai-voice-cloning
rewrote how AIVC gets training metrics (need to clean up later)
This commit is contained in:
parent
df0edacc60
commit
7f2da0f5fb
|
@ -18,6 +18,7 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--mode', type=str, default='none', help='mode')
|
||||
args = parser.parse_args()
|
||||
args.opt = " ".join(args.opt) # absolutely disgusting
|
||||
|
||||
|
@ -77,7 +78,7 @@ def train(yaml, launcher='none'):
|
|||
trainer.rank = torch.distributed.get_rank()
|
||||
torch.cuda.set_device(torch.distributed.get_rank())
|
||||
|
||||
trainer.init(yaml, opt, launcher)
|
||||
trainer.init(yaml, opt, launcher, '')
|
||||
trainer.do_training()
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
195
src/utils.py
195
src/utils.py
|
@ -594,6 +594,9 @@ class TrainingState():
|
|||
|
||||
self.it = 0
|
||||
self.its = self.config['train']['niter']
|
||||
|
||||
self.step = 0
|
||||
self.steps = 1
|
||||
|
||||
self.epoch = 0
|
||||
self.epochs = int(self.its*self.batch_size/self.dataset_size)
|
||||
|
@ -653,13 +656,8 @@ class TrainingState():
|
|||
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||
|
||||
def load_statistics(self, update=False):
|
||||
if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):
|
||||
if not os.path.isdir(f'{self.dataset_dir}/'):
|
||||
return
|
||||
try:
|
||||
from tensorboard.backend.event_processing import event_accumulator
|
||||
use_tensorboard = True
|
||||
except Exception as e:
|
||||
use_tensorboard = False
|
||||
|
||||
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce', 'learning_rate_gpt_0']
|
||||
infos = {}
|
||||
|
@ -669,32 +667,44 @@ class TrainingState():
|
|||
self.statistics['loss'] = []
|
||||
self.statistics['lr'] = []
|
||||
|
||||
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
|
||||
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
|
||||
if update:
|
||||
logs = [logs[-1]]
|
||||
|
||||
for log in logs:
|
||||
ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
|
||||
ea.Reload()
|
||||
with open(log, 'r', encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
scalars = ea.Tags()['scalars']
|
||||
for line in lines:
|
||||
if line.find('INFO: Training Metrics:') >= 0:
|
||||
data = line.split("INFO: Training Metrics:")[-1]
|
||||
info = json.loads(data)
|
||||
|
||||
for key in keys:
|
||||
if key not in scalars:
|
||||
step = info['it']
|
||||
if update and step <= self.last_info_check_at:
|
||||
continue
|
||||
|
||||
try:
|
||||
scalar = ea.Scalars(key)
|
||||
for s in scalar:
|
||||
if update and s.step <= self.last_info_check_at:
|
||||
continue
|
||||
highest_step = max( highest_step, s.step )
|
||||
target = 'lr' if key == "learning_rate_gpt_0" else 'loss'
|
||||
self.statistics[target].append( { "step": s.step, "value": s.value, "type": key } )
|
||||
if key == 'loss_gpt_total':
|
||||
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
|
||||
except Exception as e:
|
||||
pass
|
||||
if 'lr' in info:
|
||||
self.statistics['lr'].append({'step': step, 'value': info['lr'], 'type': 'learning_rate_gpt_0'})
|
||||
if 'loss_text_ce' in info:
|
||||
self.statistics['loss'].append({'step': step, 'value': info['loss_text_ce'], 'type': 'loss_text_ce'})
|
||||
if 'loss_mel_ce' in info:
|
||||
self.statistics['loss'].append({'step': step, 'value': info['loss_mel_ce'], 'type': 'loss_mel_ce'})
|
||||
if 'loss_gpt_total' in info:
|
||||
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'loss_gpt_total'})
|
||||
self.losses.append( self.statistics['loss'][-1] )
|
||||
|
||||
elif line.find('INFO: Validation Metrics:') >= 0:
|
||||
data = line.split("INFO: Validation Metrics:")[-1]
|
||||
|
||||
step = info['it']
|
||||
if update and step <= self.last_info_check_at:
|
||||
continue
|
||||
|
||||
if 'loss_text_ce' in info:
|
||||
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_text_ce'})
|
||||
if 'loss_mel_ce' in info:
|
||||
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_mel_ce'})
|
||||
|
||||
self.last_info_check_at = highest_step
|
||||
|
||||
|
@ -707,9 +717,8 @@ class TrainingState():
|
|||
|
||||
models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
|
||||
states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
|
||||
|
||||
remove_models = models[:-keep]
|
||||
remove_states = states[:-keep]
|
||||
remove_models = models[:-2]
|
||||
remove_states = states[:-2]
|
||||
|
||||
for d in remove_models:
|
||||
path = f'{self.dataset_dir}/models/{d}_gpt.pth'
|
||||
|
@ -727,8 +736,10 @@ class TrainingState():
|
|||
percent = 0
|
||||
message = None
|
||||
|
||||
if line.find('Finished training') >= 0:
|
||||
self.killed = True
|
||||
# rip out iteration info
|
||||
if not self.training_started:
|
||||
elif not self.training_started:
|
||||
if line.find('Start training from epoch') >= 0:
|
||||
self.it_time_start = time.time()
|
||||
self.epoch_time_start = time.time()
|
||||
|
@ -745,83 +756,57 @@ class TrainingState():
|
|||
self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq'])
|
||||
else:
|
||||
lapsed = False
|
||||
|
||||
message = None
|
||||
if line.find('INFO: [epoch:') >= 0:
|
||||
info_line = line.split("INFO:")[-1]
|
||||
# to-do, actually validate this works, and probably kill training when it's found, the model's dead by this point
|
||||
if ': nan' in info_line and not self.nan_detected:
|
||||
self.nan_detected = self.it
|
||||
|
||||
# easily rip out our stats...
|
||||
match = re.findall(r'\b([a-z_0-9]+?)\b: *?([0-9]\.[0-9]+?e[+-]\d+|[\d,]+)\b', info_line)
|
||||
if match and len(match) > 0:
|
||||
for k, v in match:
|
||||
self.info[k] = float(v.replace(",", ""))
|
||||
|
||||
self.load_statistics(update=True)
|
||||
should_return = True
|
||||
# INFO: Training Metrics: {"loss_text_ce": 4.308311939239502, "loss_mel_ce": 2.1610655784606934, "loss_gpt_total": 2.204148769378662, "lr": 0.0001, "it": 2, "step": 1, "steps": 1, "epoch": 1, "iteration_rate": 0.10700102965037028}
|
||||
if line.find('INFO: Training Metrics:') >= 0:
|
||||
data = line.split("INFO: Training Metrics:")[-1]
|
||||
self.info = json.loads(data)
|
||||
|
||||
if 'epoch' in self.info:
|
||||
self.epoch = int(self.info['epoch'])
|
||||
if 'iter' in self.info:
|
||||
self.it = int(self.info['iter'])
|
||||
if 'it' in self.info:
|
||||
self.it = int(self.info['it'])
|
||||
if 'step' in self.info:
|
||||
self.step = int(self.info['step'])
|
||||
if 'steps' in self.info:
|
||||
self.steps = int(self.info['steps'])
|
||||
|
||||
elif line.find('Saving models and training states') >= 0:
|
||||
self.checkpoint = self.checkpoint + 1
|
||||
if self.step == self.steps:
|
||||
lapsed = True
|
||||
|
||||
percent = self.checkpoint / float(self.checkpoints)
|
||||
message = f'[{self.checkpoint}/{self.checkpoints}] Saving checkpoint...'
|
||||
if progress is not None:
|
||||
progress(percent, message)
|
||||
if 'lr' in self.info:
|
||||
self.statistics['lr'].append({'step': self.it, 'value': self.info['lr'], 'type': 'learning_rate_gpt_0'})
|
||||
if 'loss_text_ce' in self.info:
|
||||
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_text_ce'], 'type': 'loss_text_ce'})
|
||||
if 'loss_mel_ce' in self.info:
|
||||
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_mel_ce'], 'type': 'loss_mel_ce'})
|
||||
if 'loss_gpt_total' in self.info:
|
||||
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'loss_gpt_total'})
|
||||
self.losses.append( self.statistics['loss'][-1] )
|
||||
|
||||
if 'iteration_rate' in self.info:
|
||||
it_rate = self.info['iteration_rate']
|
||||
self.it_rate = f'{"{:.3f}".format(it_rate)}s/it' if it_rate >= 1 or it_rate == 0 else f'{"{:.3f}".format(1/it_rate)}it/s'
|
||||
|
||||
print(f'{"{:.3f}".format(percent*100)}% {message}')
|
||||
self.buffer.append(f'{"{:.3f}".format(percent*100)}% {message}')
|
||||
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
|
||||
if self.epochs != self.its:
|
||||
self.metrics['step'].append(f"{self.it}/{self.its}")
|
||||
if self.steps > 1:
|
||||
self.metrics['step'].append(f"{self.step}/{self.steps}")
|
||||
self.metrics['step'] = ", ".join(self.metrics['step'])
|
||||
|
||||
self.cleanup_old(keep=keep_x_past_checkpoints)
|
||||
should_return = True
|
||||
elif line.find('INFO: Validation Metrics:') >= 0:
|
||||
data = line.split("INFO: Validation Metrics:")[-1]
|
||||
|
||||
if line.find('%|') > 0:
|
||||
match = re.findall(r'(\d+)%\|(.+?)\| (\d+|\?)\/(\d+|\?) \[(.+?)<(.+?), +(.+?)\]', line)
|
||||
if match and len(match) > 0:
|
||||
match = match[0]
|
||||
per_cent = int(match[0])/100.0
|
||||
progressbar = match[1]
|
||||
step = int(match[2])
|
||||
steps = int(match[3])
|
||||
elapsed = match[4]
|
||||
until = match[5]
|
||||
rate = match[6]
|
||||
|
||||
last_step = self.last_step
|
||||
self.last_step = step
|
||||
if last_step < step:
|
||||
self.it = self.it + (step - last_step)
|
||||
|
||||
if last_step == step and step == steps:
|
||||
lapsed = True
|
||||
|
||||
self.it_time_end = time.time()
|
||||
self.it_time_delta = self.it_time_end-self.it_time_start
|
||||
self.it_time_start = time.time()
|
||||
self.it_taken = self.it_taken + 1
|
||||
if self.it_time_delta:
|
||||
try:
|
||||
rate = f'{"{:.3f}".format(self.it_time_delta)}s/it' if self.it_time_delta >= 1 or self.it_time_delta == 0 else f'{"{:.3f}".format(1/self.it_time_delta)}it/s'
|
||||
self.it_rate = rate
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
|
||||
if self.epochs != self.its:
|
||||
self.metrics['step'].append(f"{self.it}/{self.its}")
|
||||
if steps > 1:
|
||||
self.metrics['step'].append(f"{step}/{steps}")
|
||||
self.metrics['step'] = ", ".join(self.metrics['step'])
|
||||
if 'loss_text_ce' in self.info:
|
||||
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_text_ce'})
|
||||
if 'loss_mel_ce' in self.info:
|
||||
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_mel_ce'})
|
||||
should_return = True
|
||||
|
||||
if lapsed:
|
||||
self.epoch = self.epoch + 1
|
||||
self.it = int(self.epoch * (self.dataset_size / self.batch_size))
|
||||
|
||||
self.epoch_time_end = time.time()
|
||||
self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
|
||||
self.epoch_time_start = time.time()
|
||||
|
@ -850,24 +835,16 @@ class TrainingState():
|
|||
eta_hhmmss = "?"
|
||||
if self.eta_hhmmss:
|
||||
eta_hhmmss = self.eta_hhmmss
|
||||
else:
|
||||
try:
|
||||
eta = (self.its - self.it) * (self.it_time_deltas / self.it_taken)
|
||||
eta = str(timedelta(seconds=int(eta)))
|
||||
eta_hhmmss = eta
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
self.metrics['loss'] = []
|
||||
|
||||
if 'learning_rate_gpt_0' in self.info:
|
||||
self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["learning_rate_gpt_0"])}')
|
||||
if 'lr' in self.info:
|
||||
self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}')
|
||||
|
||||
if len(self.losses) > 0:
|
||||
self.metrics['loss'].append(f'Loss: {"{:.3f}".format(self.losses[-1]["value"])}')
|
||||
|
||||
if len(self.losses) >= 2:
|
||||
# """riemann sum""" but not really as this is for derivatives and not integrals
|
||||
deriv = 0
|
||||
accum_length = len(self.losses)//2 # i *guess* this is fine when you think about it
|
||||
loss_value = self.losses[-1]["value"]
|
||||
|
@ -1296,10 +1273,6 @@ def optimize_training_settings( **kwargs ):
|
|||
|
||||
iterations = calc_iterations(epochs=settings['epochs'], lines=lines, batch_size=settings['batch_size'])
|
||||
|
||||
if settings['epochs'] < settings['print_rate']:
|
||||
settings['print_rate'] = settings['epochs']
|
||||
messages.append(f"Print rate is too small for the given iteration step, clamping print rate to: {settings['print_rate']}")
|
||||
|
||||
if settings['epochs'] < settings['save_rate']:
|
||||
settings['save_rate'] = settings['epochs']
|
||||
messages.append(f"Save rate is too small for the given iteration step, clamping save rate to: {settings['save_rate']}")
|
||||
|
@ -1355,14 +1328,11 @@ def save_training_settings( **kwargs ):
|
|||
|
||||
iterations_per_epoch = settings['iterations'] / settings['epochs']
|
||||
|
||||
settings['print_rate'] = int(settings['print_rate'] * iterations_per_epoch)
|
||||
settings['save_rate'] = int(settings['save_rate'] * iterations_per_epoch)
|
||||
settings['validation_rate'] = int(settings['validation_rate'] * iterations_per_epoch)
|
||||
|
||||
iterations_per_epoch = int(iterations_per_epoch)
|
||||
|
||||
if settings['print_rate'] < 1:
|
||||
settings['print_rate'] = 1
|
||||
if settings['save_rate'] < 1:
|
||||
settings['save_rate'] = 1
|
||||
if settings['validation_rate'] < 1:
|
||||
|
@ -1858,6 +1828,11 @@ def import_generate_settings(file="./config/generate.json"):
|
|||
res.update(settings)
|
||||
return res
|
||||
|
||||
def reset_generation_settings():
|
||||
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps({}, indent='\t') )
|
||||
return import_generate_settings()
|
||||
|
||||
def read_generate_settings(file, read_latents=True):
|
||||
j = None
|
||||
latents = None
|
||||
|
|
12
src/webui.py
12
src/webui.py
|
@ -152,14 +152,11 @@ def import_generate_settings_proxy( file=None ):
|
|||
res = []
|
||||
for k in GENERATE_SETTINGS_ARGS:
|
||||
res.append(settings[k] if k in settings else None)
|
||||
|
||||
print(GENERATE_SETTINGS_ARGS)
|
||||
print(settings)
|
||||
print(res)
|
||||
return tuple(res)
|
||||
|
||||
def reset_generation_settings_proxy():
|
||||
with open(f'./config/generate.json', 'w', encoding="utf-8") as f:
|
||||
f.write(json.dumps({}, indent='\t') )
|
||||
return import_generate_settings_proxy()
|
||||
|
||||
def compute_latents_proxy(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
|
||||
compute_latents( voice=voice, voice_latents_chunks=voice_latents_chunks, progress=progress )
|
||||
return voice
|
||||
|
@ -442,7 +439,6 @@ def setup_gradio():
|
|||
TRAINING_SETTINGS["batch_size"] = gr.Number(label="Batch Size", value=128, precision=0)
|
||||
TRAINING_SETTINGS["gradient_accumulation_size"] = gr.Number(label="Gradient Accumulation Size", value=4, precision=0)
|
||||
with gr.Row():
|
||||
TRAINING_SETTINGS["print_rate"] = gr.Number(label="Print Frequency (in epochs)", value=5, precision=0)
|
||||
TRAINING_SETTINGS["save_rate"] = gr.Number(label="Save Frequency (in epochs)", value=5, precision=0)
|
||||
TRAINING_SETTINGS["validation_rate"] = gr.Number(label="Validation Frequency (in epochs)", value=5, precision=0)
|
||||
|
||||
|
@ -665,7 +661,7 @@ def setup_gradio():
|
|||
)
|
||||
|
||||
reset_generation_settings_button.click(
|
||||
fn=reset_generation_settings_proxy,
|
||||
fn=reset_generation_settings,
|
||||
inputs=None,
|
||||
outputs=generate_settings
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user