cleanups and fixes, fix DLAS throwing errors from '''too short of sound files''' by just culling them during transcription
This commit is contained in:
parent
7f2da0f5fb
commit
2feb6da0c0
|
@ -7,8 +7,6 @@ checkpointing_enabled: true
|
||||||
fp16: ${half_p}
|
fp16: ${half_p}
|
||||||
bitsandbytes: ${bitsandbytes}
|
bitsandbytes: ${bitsandbytes}
|
||||||
gpus: ${gpus}
|
gpus: ${gpus}
|
||||||
wandb: false
|
|
||||||
use_tb_logger: true
|
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
train:
|
train:
|
||||||
|
@ -135,8 +133,7 @@ eval:
|
||||||
output_state: gen
|
output_state: gen
|
||||||
|
|
||||||
logger:
|
logger:
|
||||||
print_freq: ${print_rate}
|
|
||||||
save_checkpoint_freq: ${save_rate}
|
save_checkpoint_freq: ${save_rate}
|
||||||
visuals: [gen, mel]
|
visuals: [gen, mel]
|
||||||
visual_debug_rate: ${print_rate}
|
visual_debug_rate: ${save_rate}
|
||||||
is_mel_spectrogram: true
|
is_mel_spectrogram: true
|
|
@ -1 +1 @@
|
||||||
Subproject commit bf94744514e0628c6e6ba21eda76d1fd71fb1252
|
Subproject commit 802c162ce816ac9e824bd82f64f6282019ae15d5
|
197
src/utils.py
197
src/utils.py
|
@ -609,20 +609,10 @@ class TrainingState():
|
||||||
self.open_state = False
|
self.open_state = False
|
||||||
self.training_started = False
|
self.training_started = False
|
||||||
|
|
||||||
self.info = {}
|
self.info = {}
|
||||||
|
|
||||||
self.epoch_rate = ""
|
|
||||||
self.epoch_time_start = 0
|
|
||||||
self.epoch_time_end = 0
|
|
||||||
self.epoch_time_deltas = 0
|
|
||||||
self.epoch_taken = 0
|
|
||||||
|
|
||||||
self.it_rate = ""
|
self.it_rate = ""
|
||||||
self.it_time_start = 0
|
self.it_rates = 0
|
||||||
self.it_time_end = 0
|
|
||||||
self.it_time_deltas = 0
|
|
||||||
self.it_taken = 0
|
|
||||||
self.last_step = 0
|
|
||||||
|
|
||||||
self.eta = "?"
|
self.eta = "?"
|
||||||
self.eta_hhmmss = "?"
|
self.eta_hhmmss = "?"
|
||||||
|
@ -655,11 +645,63 @@ class TrainingState():
|
||||||
print("Spawning process: ", " ".join(self.cmd))
|
print("Spawning process: ", " ".join(self.cmd))
|
||||||
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||||
|
|
||||||
|
def parse_metrics(self, data):
|
||||||
|
if isinstance(data, str):
|
||||||
|
if line.find('INFO: Training Metrics:') >= 0:
|
||||||
|
data = json.loads(line.split("INFO: Training Metrics:")[-1])
|
||||||
|
data['mode'] = "training"
|
||||||
|
elif line.find('INFO: Validation Metrics:') >= 0:
|
||||||
|
data = json.loads(line.split("INFO: Validation Metrics:")[-1])
|
||||||
|
data['mode'] = "validation"
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.info = data
|
||||||
|
if 'epoch' in self.info:
|
||||||
|
self.epoch = int(self.info['epoch'])
|
||||||
|
if 'it' in self.info:
|
||||||
|
self.it = int(self.info['it'])
|
||||||
|
if 'step' in self.info:
|
||||||
|
self.step = int(self.info['step'])
|
||||||
|
if 'steps' in self.info:
|
||||||
|
self.steps = int(self.info['steps'])
|
||||||
|
|
||||||
|
if 'iteration_rate' in self.info:
|
||||||
|
it_rate = self.info['iteration_rate']
|
||||||
|
self.it_rate = f'{"{:.3f}".format(1/it_rate)}it/s' if 0 < it_rate and it_rate < 1 else f'{"{:.3f}".format(it_rate)}s/it'
|
||||||
|
self.it_rates += it_rate
|
||||||
|
|
||||||
|
self.eta = (self.its - self.it) * (self.it_rates / self.its)
|
||||||
|
try:
|
||||||
|
eta = str(timedelta(seconds=int(self.eta)))
|
||||||
|
self.eta_hhmmss = eta
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
|
||||||
|
if self.epochs != self.its:
|
||||||
|
self.metrics['step'].append(f"{self.it}/{self.its}")
|
||||||
|
if self.steps > 1:
|
||||||
|
self.metrics['step'].append(f"{self.step}/{self.steps}")
|
||||||
|
self.metrics['step'] = ", ".join(self.metrics['step'])
|
||||||
|
|
||||||
|
if 'lr' in self.info:
|
||||||
|
self.statistics['lr'].append({'step': self.it, 'value': self.info['lr'], 'type': 'learning_rate'})
|
||||||
|
|
||||||
|
for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']:
|
||||||
|
if k not in self.info:
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.statistics['loss'].append({'step': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
||||||
|
if k == "loss_gpt_total":
|
||||||
|
self.losses.append( self.statistics['loss'][-1] )
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
def load_statistics(self, update=False):
|
def load_statistics(self, update=False):
|
||||||
if not os.path.isdir(f'{self.dataset_dir}/'):
|
if not os.path.isdir(f'{self.dataset_dir}/'):
|
||||||
return
|
return
|
||||||
|
|
||||||
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total', 'val_loss_text_ce', 'val_loss_mel_ce', 'learning_rate_gpt_0']
|
|
||||||
infos = {}
|
infos = {}
|
||||||
highest_step = self.last_info_check_at
|
highest_step = self.last_info_check_at
|
||||||
|
|
||||||
|
@ -677,34 +719,23 @@ class TrainingState():
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if line.find('INFO: Training Metrics:') >= 0:
|
if line.find('INFO: Training Metrics:') >= 0:
|
||||||
data = line.split("INFO: Training Metrics:")[-1]
|
data = json.loads(line.split("INFO: Training Metrics:")[-1])
|
||||||
info = json.loads(data)
|
data['mode'] = "training"
|
||||||
|
|
||||||
step = info['it']
|
|
||||||
if update and step <= self.last_info_check_at:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if 'lr' in info:
|
|
||||||
self.statistics['lr'].append({'step': step, 'value': info['lr'], 'type': 'learning_rate_gpt_0'})
|
|
||||||
if 'loss_text_ce' in info:
|
|
||||||
self.statistics['loss'].append({'step': step, 'value': info['loss_text_ce'], 'type': 'loss_text_ce'})
|
|
||||||
if 'loss_mel_ce' in info:
|
|
||||||
self.statistics['loss'].append({'step': step, 'value': info['loss_mel_ce'], 'type': 'loss_mel_ce'})
|
|
||||||
if 'loss_gpt_total' in info:
|
|
||||||
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'loss_gpt_total'})
|
|
||||||
self.losses.append( self.statistics['loss'][-1] )
|
|
||||||
|
|
||||||
elif line.find('INFO: Validation Metrics:') >= 0:
|
elif line.find('INFO: Validation Metrics:') >= 0:
|
||||||
data = line.split("INFO: Validation Metrics:")[-1]
|
data = json.loads(line.split("INFO: Validation Metrics:")[-1])
|
||||||
|
data['mode'] = "validation"
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
step = info['it']
|
if "it" not in data:
|
||||||
if update and step <= self.last_info_check_at:
|
continue
|
||||||
continue
|
|
||||||
|
|
||||||
if 'loss_text_ce' in info:
|
step = data['it']
|
||||||
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_text_ce'})
|
|
||||||
if 'loss_mel_ce' in info:
|
if update and step <= self.last_info_check_at:
|
||||||
self.statistics['loss'].append({'step': step, 'value': info['loss_gpt_total'], 'type': 'val_loss_mel_ce'})
|
continue
|
||||||
|
|
||||||
|
self.parse_metrics(data)
|
||||||
|
|
||||||
self.last_info_check_at = highest_step
|
self.last_info_check_at = highest_step
|
||||||
|
|
||||||
|
@ -741,10 +772,7 @@ class TrainingState():
|
||||||
# rip out iteration info
|
# rip out iteration info
|
||||||
elif not self.training_started:
|
elif not self.training_started:
|
||||||
if line.find('Start training from epoch') >= 0:
|
if line.find('Start training from epoch') >= 0:
|
||||||
self.it_time_start = time.time()
|
|
||||||
self.epoch_time_start = time.time()
|
|
||||||
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
|
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
|
||||||
should_return = True
|
|
||||||
|
|
||||||
match = re.findall(r'epoch: ([\d,]+)', line)
|
match = re.findall(r'epoch: ([\d,]+)', line)
|
||||||
if match and len(match) > 0:
|
if match and len(match) > 0:
|
||||||
|
@ -754,90 +782,51 @@ class TrainingState():
|
||||||
self.it = int(match[0].replace(",", ""))
|
self.it = int(match[0].replace(",", ""))
|
||||||
|
|
||||||
self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq'])
|
self.checkpoints = int((self.its - self.it) / self.config['logger']['save_checkpoint_freq'])
|
||||||
|
|
||||||
|
should_return = True
|
||||||
else:
|
else:
|
||||||
lapsed = False
|
|
||||||
message = None
|
message = None
|
||||||
|
data = None
|
||||||
|
|
||||||
# INFO: Training Metrics: {"loss_text_ce": 4.308311939239502, "loss_mel_ce": 2.1610655784606934, "loss_gpt_total": 2.204148769378662, "lr": 0.0001, "it": 2, "step": 1, "steps": 1, "epoch": 1, "iteration_rate": 0.10700102965037028}
|
# INFO: Training Metrics: {"loss_text_ce": 4.308311939239502, "loss_mel_ce": 2.1610655784606934, "loss_gpt_total": 2.204148769378662, "lr": 0.0001, "it": 2, "step": 1, "steps": 1, "epoch": 1, "iteration_rate": 0.10700102965037028}
|
||||||
if line.find('INFO: Training Metrics:') >= 0:
|
if line.find('INFO: Training Metrics:') >= 0:
|
||||||
data = line.split("INFO: Training Metrics:")[-1]
|
data = json.loads(line.split("INFO: Training Metrics:")[-1])
|
||||||
self.info = json.loads(data)
|
data['mode'] = "training"
|
||||||
|
|
||||||
if 'epoch' in self.info:
|
|
||||||
self.epoch = int(self.info['epoch'])
|
|
||||||
if 'it' in self.info:
|
|
||||||
self.it = int(self.info['it'])
|
|
||||||
if 'step' in self.info:
|
|
||||||
self.step = int(self.info['step'])
|
|
||||||
if 'steps' in self.info:
|
|
||||||
self.steps = int(self.info['steps'])
|
|
||||||
|
|
||||||
if self.step == self.steps:
|
|
||||||
lapsed = True
|
|
||||||
|
|
||||||
if 'lr' in self.info:
|
|
||||||
self.statistics['lr'].append({'step': self.it, 'value': self.info['lr'], 'type': 'learning_rate_gpt_0'})
|
|
||||||
if 'loss_text_ce' in self.info:
|
|
||||||
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_text_ce'], 'type': 'loss_text_ce'})
|
|
||||||
if 'loss_mel_ce' in self.info:
|
|
||||||
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_mel_ce'], 'type': 'loss_mel_ce'})
|
|
||||||
if 'loss_gpt_total' in self.info:
|
|
||||||
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'loss_gpt_total'})
|
|
||||||
self.losses.append( self.statistics['loss'][-1] )
|
|
||||||
|
|
||||||
if 'iteration_rate' in self.info:
|
|
||||||
it_rate = self.info['iteration_rate']
|
|
||||||
self.it_rate = f'{"{:.3f}".format(it_rate)}s/it' if it_rate >= 1 or it_rate == 0 else f'{"{:.3f}".format(1/it_rate)}it/s'
|
|
||||||
|
|
||||||
self.metrics['step'] = [f"{self.epoch}/{self.epochs}"]
|
|
||||||
if self.epochs != self.its:
|
|
||||||
self.metrics['step'].append(f"{self.it}/{self.its}")
|
|
||||||
if self.steps > 1:
|
|
||||||
self.metrics['step'].append(f"{self.step}/{self.steps}")
|
|
||||||
self.metrics['step'] = ", ".join(self.metrics['step'])
|
|
||||||
|
|
||||||
should_return = True
|
|
||||||
elif line.find('INFO: Validation Metrics:') >= 0:
|
elif line.find('INFO: Validation Metrics:') >= 0:
|
||||||
data = line.split("INFO: Validation Metrics:")[-1]
|
data = json.loads(line.split("INFO: Validation Metrics:")[-1])
|
||||||
|
data['mode'] = "validation"
|
||||||
|
|
||||||
if 'loss_text_ce' in self.info:
|
if data is not None:
|
||||||
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_text_ce'})
|
self.parse_metrics( data )
|
||||||
if 'loss_mel_ce' in self.info:
|
|
||||||
self.statistics['loss'].append({'step': self.it, 'value': self.info['loss_gpt_total'], 'type': 'val_loss_mel_ce'})
|
|
||||||
should_return = True
|
should_return = True
|
||||||
|
|
||||||
if lapsed:
|
if ': nan' in line and not self.nan_detected:
|
||||||
|
self.nan_detected = self.it
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self.step == self.steps and self.steps > 0:
|
||||||
self.epoch_time_end = time.time()
|
self.epoch_time_end = time.time()
|
||||||
self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
|
self.epoch_time_delta = self.epoch_time_end-self.epoch_time_start
|
||||||
self.epoch_time_start = time.time()
|
self.epoch_time_start = time.time()
|
||||||
try:
|
try:
|
||||||
self.epoch_rate = f'{"{:.3f}".format(self.epoch_time_delta)}s/epoch' if self.epoch_time_delta >= 1 or self.epoch_time_delta == 0 else f'{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s' # I doubt anyone will have it/s rates, but its here
|
self.epoch_rate = f'{"{:.3f}".format(1/self.epoch_time_delta)}epoch/s' if 0 < self.epoch_time_delta and self.epoch_time_delta < 1 else f'{"{:.3f}".format(self.epoch_time_delta)}s/epoch'
|
||||||
except Exception as e:
|
|
||||||
pass
|
|
||||||
|
|
||||||
#self.eta = (self.epochs - self.epoch) * self.epoch_time_delta
|
|
||||||
self.epoch_time_deltas = self.epoch_time_deltas + self.epoch_time_delta
|
|
||||||
self.epoch_taken = self.epoch_taken + 1
|
|
||||||
self.eta = (self.epochs - self.epoch) * (self.epoch_time_deltas / self.epoch_taken)
|
|
||||||
try:
|
|
||||||
eta = str(timedelta(seconds=int(self.eta)))
|
|
||||||
self.eta_hhmmss = eta
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
"""
|
||||||
|
|
||||||
self.metrics['rate'] = []
|
self.metrics['rate'] = []
|
||||||
|
"""
|
||||||
if self.epoch_rate:
|
if self.epoch_rate:
|
||||||
self.metrics['rate'].append(self.epoch_rate)
|
self.metrics['rate'].append(self.epoch_rate)
|
||||||
if self.it_rate and self.epoch_rate != self.it_rate:
|
if self.it_rate and self.epoch_rate != self.it_rate:
|
||||||
|
"""
|
||||||
|
if self.it_rate:
|
||||||
self.metrics['rate'].append(self.it_rate)
|
self.metrics['rate'].append(self.it_rate)
|
||||||
self.metrics['rate'] = ", ".join(self.metrics['rate'])
|
self.metrics['rate'] = ", ".join(self.metrics['rate'])
|
||||||
|
|
||||||
eta_hhmmss = "?"
|
eta_hhmmss = self.eta_hhmmss if self.eta_hhmmss else "?"
|
||||||
if self.eta_hhmmss:
|
|
||||||
eta_hhmmss = self.eta_hhmmss
|
|
||||||
|
|
||||||
self.metrics['loss'] = []
|
|
||||||
|
|
||||||
|
self.metrics['loss'] = []
|
||||||
if 'lr' in self.info:
|
if 'lr' in self.info:
|
||||||
self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}')
|
self.metrics['loss'].append(f'LR: {"{:.3e}".format(self.info["lr"])}')
|
||||||
|
|
||||||
|
@ -1126,7 +1115,11 @@ def prepare_dataset( files, outdir, language=None, skip_existings=False, progres
|
||||||
sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
|
sliced_name = basename.replace(".wav", f"_{pad(idx, 4)}.wav")
|
||||||
|
|
||||||
if not torch.any(sliced_waveform < 0):
|
if not torch.any(sliced_waveform < 0):
|
||||||
print(f"Error with {sliced_name}, skipping...")
|
print(f"Sound file is silent: {sliced_name}, skipping...")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if sliced_waveform.shape[-1] < (.6 * sampling_rate):
|
||||||
|
print(f"Sound file is too short: {sliced_name}, skipping...")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
|
torchaudio.save(f"{outdir}/audio/{sliced_name}", sliced_waveform, sampling_rate)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user