added sampling by speaker group name (might be better to de-emphasize the LibriVox/Audiobooks that are in large numbers, and emphasize the smaller pools), log cleanup
This commit is contained in:
parent
a539f6889f
commit
09cda7d3f9
96
scripts/parse_ppp.py
Normal file
96
scripts/parse_ppp.py
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from pathlib import Path
|
||||||
|
from vall_e.emb.g2p import encode as valle_phonemize
|
||||||
|
from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
|
||||||
|
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
|
target = "in"
|
||||||
|
|
||||||
|
audio_map = {}
|
||||||
|
text_map = {}
|
||||||
|
|
||||||
|
data = {}
|
||||||
|
|
||||||
|
for season in os.listdir(f"./{target}/"):
|
||||||
|
if not os.path.isdir(f"./{target}/{season}/"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for episode in os.listdir(f"./{target}/{season}/"):
|
||||||
|
if not os.path.isdir(f"./{target}/{season}/{episode}/"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for filename in os.listdir(f"./{target}/{season}/{episode}/"):
|
||||||
|
path = f'./{target}/{season}/{episode}/{filename}'
|
||||||
|
attrs = filename.split("_")
|
||||||
|
timestamp = f'{attrs[0]}h{attrs[1]}m{attrs[2]}s'
|
||||||
|
|
||||||
|
key = f'{episode}_{timestamp}'
|
||||||
|
|
||||||
|
if filename[-5:] == ".flac":
|
||||||
|
name = attrs[3]
|
||||||
|
emotion = attrs[4]
|
||||||
|
quality = attrs[5]
|
||||||
|
|
||||||
|
audio_map[key] = {
|
||||||
|
"path": path,
|
||||||
|
'episode': episode,
|
||||||
|
"name": name,
|
||||||
|
"emotion": emotion,
|
||||||
|
"quality": quality,
|
||||||
|
"timestamp": timestamp,
|
||||||
|
}
|
||||||
|
|
||||||
|
elif filename[-4:] == ".txt":
|
||||||
|
text_map[key] = open(path, encoding="utf-8").read()
|
||||||
|
txts = {}
|
||||||
|
wavs = []
|
||||||
|
|
||||||
|
for key, entry in audio_map.items():
|
||||||
|
path = entry['path']
|
||||||
|
name = entry['name']
|
||||||
|
emotion = entry['emotion']
|
||||||
|
quality = entry['quality']
|
||||||
|
episode = entry['episode']
|
||||||
|
path = entry['path']
|
||||||
|
timestamp = entry['timestamp']
|
||||||
|
transcription = text_map[key]
|
||||||
|
if name not in data:
|
||||||
|
data[name] = {}
|
||||||
|
os.makedirs(f'./training/{name}/', exist_ok=True)
|
||||||
|
os.makedirs(f'./voices/{name}/', exist_ok=True)
|
||||||
|
|
||||||
|
key = f'{episode}_{timestamp}.flac'
|
||||||
|
os.rename(path, f'./voices/{name}/{key}')
|
||||||
|
|
||||||
|
data[name][key] = {
|
||||||
|
"segments": [],
|
||||||
|
"language": "en",
|
||||||
|
"text": transcription,
|
||||||
|
"misc": {
|
||||||
|
"emotion": emotion,
|
||||||
|
"quality": quality,
|
||||||
|
"timestamp": timestamp,
|
||||||
|
"episode": episode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
path = f'./voices/{name}/{key}'
|
||||||
|
txts[path] = transcription
|
||||||
|
wavs.append(Path(path))
|
||||||
|
|
||||||
|
for name in data.keys():
|
||||||
|
open(f"./training/{name}/whisper.json", "w", encoding="utf-8").write( json.dumps( data[name], indent='\t' ) )
|
||||||
|
|
||||||
|
for key, text in tqdm(txts.items(), desc="Phonemizing..."):
|
||||||
|
path = Path(key)
|
||||||
|
phones = valle_phonemize(text)
|
||||||
|
open(_replace_file_extension(path, ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones))
|
||||||
|
|
||||||
|
for path in tqdm(wavs, desc="Quantizing..."):
|
||||||
|
qnt = valle_quantize(path, device=device)
|
||||||
|
torch.save(qnt.cpu(), _replace_file_extension(path, ".qnt.pt"))
|
|
@ -92,11 +92,12 @@ def _load_paths_from_metadata(data_dir, type="training", validate=False):
|
||||||
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
return cfg.dataset.min_duration <= duration and duration <= cfg.dataset.max_duration and cfg.dataset.min_phones <= phones and phones <= cfg.dataset.max_phones
|
||||||
|
|
||||||
metadata_path = data_dir / "metadata.json"
|
metadata_path = data_dir / "metadata.json"
|
||||||
if not cfg.dataset.use_metadata or not metadata_path.exists():
|
metadata = {}
|
||||||
return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )
|
if cfg.dataset.use_metadata and metadata_path.exists():
|
||||||
|
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
|
||||||
|
|
||||||
speaker = cfg.get_spkr( data_dir / "dummy" )
|
if len(metadata) == 0:
|
||||||
metadata = json.loads(open( metadata_path, "r", encoding="utf-8" ).read())
|
return _fn( data_dir, type if cfg.dataset.use_hdf5 else ".qnt.pt", validate )
|
||||||
|
|
||||||
def key( dir, id ):
|
def key( dir, id ):
|
||||||
if not cfg.dataset.use_hdf5:
|
if not cfg.dataset.use_hdf5:
|
||||||
|
@ -193,6 +194,7 @@ class Dataset(_Dataset):
|
||||||
if len(self.dataset) == 0:
|
if len(self.dataset) == 0:
|
||||||
self.dataset = cfg.dataset.training
|
self.dataset = cfg.dataset.training
|
||||||
|
|
||||||
|
# dict of paths keyed by speaker names
|
||||||
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
|
self.paths_by_spkr_name = _load_paths(self.dataset, self.dataset_type)
|
||||||
|
|
||||||
# cull speakers if they do not have enough utterances
|
# cull speakers if they do not have enough utterances
|
||||||
|
@ -205,6 +207,23 @@ class Dataset(_Dataset):
|
||||||
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
|
self.paths = list(itertools.chain.from_iterable(self.paths_by_spkr_name.values()))
|
||||||
|
|
||||||
self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
|
self.samplers = { name: Sampler( paths, keep_all=True ) for name, paths in self.paths_by_spkr_name.items() }
|
||||||
|
|
||||||
|
# dict of speakers keyed by speaker group
|
||||||
|
self.spkrs_by_spkr_group = {}
|
||||||
|
for data_dir in self.dataset:
|
||||||
|
spkr = cfg.get_spkr( data_dir / "dummy" )
|
||||||
|
spkr_group = cfg.get_spkr_group( data_dir / "dummy" )
|
||||||
|
|
||||||
|
if len(self.paths_by_spkr_name[spkr]) < cfg.dataset.min_utterances:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if spkr_group not in self.spkrs_by_spkr_group:
|
||||||
|
self.spkrs_by_spkr_group[spkr_group] = []
|
||||||
|
|
||||||
|
self.spkrs_by_spkr_group[spkr_group].append( spkr )
|
||||||
|
|
||||||
|
self.spkr_groups = list(self.spkrs_by_spkr_group.keys())
|
||||||
|
self.spkr_samplers = { name: Sampler( [*set(speakers)], keep_all=True ) for name, speakers in self.spkrs_by_spkr_group.items() }
|
||||||
|
|
||||||
if cfg.dataset.sample_type == "path":
|
if cfg.dataset.sample_type == "path":
|
||||||
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
|
self.paths = [*_interleaved_reorder(self.paths, self.get_speaker)]
|
||||||
|
@ -214,14 +233,15 @@ class Dataset(_Dataset):
|
||||||
|
|
||||||
self.phone_symmap = phone_symmap or self._get_phone_symmap()
|
self.phone_symmap = phone_symmap or self._get_phone_symmap()
|
||||||
self.spkr_symmap = self._get_spkr_symmap()
|
self.spkr_symmap = self._get_spkr_symmap()
|
||||||
|
self.spkr_group_symmap = self._get_spkr_group_symmap()
|
||||||
self.lang_symmap = self._get_lang_symmap()
|
self.lang_symmap = self._get_lang_symmap()
|
||||||
self.task_symmap = self._get_task_symmap()
|
self.task_symmap = self._get_task_symmap()
|
||||||
|
|
||||||
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
|
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
|
||||||
self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
|
self.text_dtype = torch.uint8 if len(self.phone_symmap) < 256 else torch.int16
|
||||||
|
|
||||||
if len(self.paths) == 0 and training:
|
if len(self.paths) == 0:
|
||||||
raise ValueError("No valid path is found for training.")
|
raise ValueError(f"No valid path is found for {self.dataset_type}")
|
||||||
|
|
||||||
#self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
|
#self.duration = _total_durations[self.dataset_type] if self.dataset_type in _total_durations else 0
|
||||||
self.duration = _calculate_durations(self.dataset_type)
|
self.duration = _calculate_durations(self.dataset_type)
|
||||||
|
@ -281,6 +301,9 @@ class Dataset(_Dataset):
|
||||||
def _get_spkr_symmap(self):
|
def _get_spkr_symmap(self):
|
||||||
return {s: i for i, s in enumerate(self.spkrs)}
|
return {s: i for i, s in enumerate(self.spkrs)}
|
||||||
|
|
||||||
|
def _get_spkr_group_symmap(self):
|
||||||
|
return {s: i for i, s in enumerate(self.spkr_groups)}
|
||||||
|
|
||||||
def _get_lang_symmap(self):
|
def _get_lang_symmap(self):
|
||||||
return get_lang_symmap()
|
return get_lang_symmap()
|
||||||
|
|
||||||
|
@ -358,14 +381,31 @@ class Dataset(_Dataset):
|
||||||
return prom
|
return prom
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
if cfg.dataset.sample_type == "speaker":
|
if cfg.dataset.sample_type == "group":
|
||||||
|
spkr_group = self.spkr_groups[index]
|
||||||
|
spkr_group_id = self.spkr_group_symmap[spkr_group]
|
||||||
|
spkr_name = self.spkr_samplers[spkr_group].sample()
|
||||||
|
if spkr_name in self.spkr_symmap:
|
||||||
|
spkr_id = self.spkr_symmap[spkr_name]
|
||||||
|
else:
|
||||||
|
spkr_id = -1
|
||||||
|
try:
|
||||||
|
path = self.samplers[spkr_name].sample()
|
||||||
|
except Exception as e:
|
||||||
|
print( "ERROR", spkr_group, spkr_name )
|
||||||
|
raise e
|
||||||
|
elif cfg.dataset.sample_type == "speaker":
|
||||||
spkr_name = self.spkrs[index]
|
spkr_name = self.spkrs[index]
|
||||||
spkr_id = self.spkr_symmap[spkr_name]
|
spkr_id = self.spkr_symmap[spkr_name]
|
||||||
path = self.samplers[spkr_name].sample()
|
path = self.samplers[spkr_name].sample()
|
||||||
|
spkr_group = self.get_speaker_group(path)
|
||||||
|
spkr_group_id = self.spkr_group_symmap[spkr_group]
|
||||||
else:
|
else:
|
||||||
path = self.paths[index]
|
path = self.paths[index]
|
||||||
spkr_name = self.get_speaker(path)
|
spkr_name = self.get_speaker(path)
|
||||||
spkr_id = self.spkr_symmap[spkr_name]
|
spkr_id = self.spkr_symmap[spkr_name]
|
||||||
|
spkr_group = self.get_speaker_group(path)
|
||||||
|
spkr_group_id = self.spkr_group_symmap[spkr_group]
|
||||||
|
|
||||||
if cfg.dataset.use_hdf5:
|
if cfg.dataset.use_hdf5:
|
||||||
key = _get_hdf5_path(path)
|
key = _get_hdf5_path(path)
|
||||||
|
@ -381,7 +421,6 @@ class Dataset(_Dataset):
|
||||||
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
|
text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]).to(self.text_dtype)
|
||||||
resps = _load_quants(path)
|
resps = _load_quants(path)
|
||||||
|
|
||||||
spkr_group = self.get_speaker_group(path)
|
|
||||||
lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
|
lang = torch.tensor([ self.lang_symmap[ self.get_language(spkr_group) ]]).to(torch.uint8)
|
||||||
|
|
||||||
# append additional prompts in an attempt to artifically increase lengths / offer new data
|
# append additional prompts in an attempt to artifically increase lengths / offer new data
|
||||||
|
@ -611,6 +650,8 @@ class Dataset(_Dataset):
|
||||||
self.training = value
|
self.training = value
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
if cfg.dataset.sample_type == "group":
|
||||||
|
return min(len(self.spkr_groups), self._head or len(self.spkr_groups))
|
||||||
if cfg.dataset.sample_type == "speaker":
|
if cfg.dataset.sample_type == "speaker":
|
||||||
return min(len(self.spkrs), self._head or len(self.spkrs))
|
return min(len(self.spkrs), self._head or len(self.spkrs))
|
||||||
return min(len(self.paths), self._head or len(self.paths))
|
return min(len(self.paths), self._head or len(self.paths))
|
||||||
|
@ -679,6 +720,7 @@ def create_train_val_dataloader():
|
||||||
|
|
||||||
_logger.info(str(train_dataset.phone_symmap))
|
_logger.info(str(train_dataset.phone_symmap))
|
||||||
_logger.info(str(train_dataset.spkr_symmap))
|
_logger.info(str(train_dataset.spkr_symmap))
|
||||||
|
_logger.info(str(train_dataset.spkr_group_symmap))
|
||||||
|
|
||||||
_logger.info(f"#samples (train): {len(train_dataset)}.")
|
_logger.info(f"#samples (train): {len(train_dataset)}.")
|
||||||
_logger.info(f"#samples (val): {len(val_dataset)}.")
|
_logger.info(f"#samples (val): {len(val_dataset)}.")
|
||||||
|
@ -707,6 +749,10 @@ def create_dataset_metadata():
|
||||||
|
|
||||||
metadata = {}
|
metadata = {}
|
||||||
for path in tqdm(paths, desc="Parsing paths"):
|
for path in tqdm(paths, desc="Parsing paths"):
|
||||||
|
if isinstance(path, str):
|
||||||
|
print("str:", path)
|
||||||
|
path = Path(path)
|
||||||
|
|
||||||
speaker = cfg.get_spkr(path)
|
speaker = cfg.get_spkr(path)
|
||||||
if speaker not in metadata:
|
if speaker not in metadata:
|
||||||
metadata[speaker] = {}
|
metadata[speaker] = {}
|
||||||
|
|
|
@ -171,7 +171,7 @@ def encode_from_file(path, device="cuda"):
|
||||||
return encode_from_files( path, device )
|
return encode_from_files( path, device )
|
||||||
else:
|
else:
|
||||||
path = str(path)
|
path = str(path)
|
||||||
wav, sr = torchaudio.load(path, format=path[-3:])
|
wav, sr = torchaudio.load(path)
|
||||||
|
|
||||||
if wav.shape[0] == 2:
|
if wav.shape[0] == 2:
|
||||||
wav = wav[:1]
|
wav = wav[:1]
|
||||||
|
|
|
@ -469,8 +469,9 @@ class Engines(dict[str, Engine]):
|
||||||
|
|
||||||
self._update()
|
self._update()
|
||||||
|
|
||||||
stats["elapsed_time"] = total_elapsed_time
|
if len(self.keys()) > 1:
|
||||||
stats["global_step"] = self.global_step
|
stats["elapsed_time"] = total_elapsed_time
|
||||||
#stats["micro_step"] = self.micro_step
|
|
||||||
|
stats["it"] = self.global_step
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
|
@ -146,13 +146,13 @@ class AR_NAR(Base):
|
||||||
quant_levels=quant_levels,
|
quant_levels=quant_levels,
|
||||||
)
|
)
|
||||||
# is NAR
|
# is NAR
|
||||||
prev_list = resps_list
|
|
||||||
if max_levels == 0:
|
if max_levels == 0:
|
||||||
max_levels = self.n_resp_levels
|
max_levels = self.n_resp_levels
|
||||||
|
|
||||||
|
prev_list = resps_list
|
||||||
|
|
||||||
while True:
|
for n in trange( max_levels, desc="NAR" ):
|
||||||
level = prev_list[0].shape[-1]
|
level = prev_list[0].shape[-1]
|
||||||
|
|
||||||
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
|
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -195,14 +195,13 @@ class AR_NAR(Base):
|
||||||
{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
|
{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
|
||||||
] * batch_size if sampling_mirostat_tau > 0.0 else None
|
] * batch_size if sampling_mirostat_tau > 0.0 else None
|
||||||
|
|
||||||
sampling_beam_width_use_logs = True
|
|
||||||
scores = [ 1.0 ] * sampling_beam_width
|
scores = [ 1.0 ] * sampling_beam_width
|
||||||
|
|
||||||
if self.interleave:
|
if self.interleave:
|
||||||
max_steps *= self.n_prom_levels
|
max_steps *= self.n_prom_levels
|
||||||
|
|
||||||
# get next in sequence
|
# get next in sequence
|
||||||
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
|
for n in trange(max_steps // max(1, self.recurrent_chunk_size), desc="AR"):
|
||||||
# experimental rolling response to avoid too-long perplexity hits despite RetNet allegedly fixing this.
|
# experimental rolling response to avoid too-long perplexity hits despite RetNet allegedly fixing this.
|
||||||
# UNTESTED. In theory it would be better to also adjust the text, but there's no way of correlating text to segment of audio without something like wav2vec2
|
# UNTESTED. In theory it would be better to also adjust the text, but there's no way of correlating text to segment of audio without something like wav2vec2
|
||||||
if max_resp_context > 0:
|
if max_resp_context > 0:
|
||||||
|
@ -245,17 +244,13 @@ class AR_NAR(Base):
|
||||||
r, s = r
|
r, s = r
|
||||||
# first step, expand batch
|
# first step, expand batch
|
||||||
if batch_size == 1:
|
if batch_size == 1:
|
||||||
batch_size *= sampling_beam_width
|
batch_size = sampling_beam_width
|
||||||
text_list = text_list * sampling_beam_width
|
text_list = text_list * sampling_beam_width
|
||||||
proms_list = proms_list * sampling_beam_width
|
proms_list = proms_list * sampling_beam_width
|
||||||
sequence_list = sequence_list * sampling_beam_width
|
sequence_list = sequence_list * sampling_beam_width
|
||||||
stopped = torch.zeros(batch_size, device=device).bool()
|
stopped = torch.zeros(batch_size, device=device).bool()
|
||||||
|
|
||||||
# update scores
|
scores = [ scores[i] + score for i, score in enumerate(s) ]
|
||||||
if sampling_beam_width_use_logs:
|
|
||||||
scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ]
|
|
||||||
else:
|
|
||||||
scores = [ scores[i] * score for i, score in enumerate(s) ]
|
|
||||||
|
|
||||||
# append tokens
|
# append tokens
|
||||||
for i, ri in enumerate(r):
|
for i, ri in enumerate(r):
|
||||||
|
@ -270,13 +265,8 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
# pick the best scoring candidate
|
# pick the best scoring candidate
|
||||||
# desu this is always going to be candidate 0
|
# desu this is always going to be candidate 0
|
||||||
if sampling_beam_width and len(scores) > 0:
|
if sampling_beam_width:
|
||||||
best_idx, best_score = (0, 0)
|
sequence_list = [ sequence_list[0] ]
|
||||||
for idx, score in enumerate(scores):
|
|
||||||
if best_score > score:
|
|
||||||
best_idx, best_score = idx, score
|
|
||||||
|
|
||||||
sequence_list = [sequence_list[best_idx]]
|
|
||||||
|
|
||||||
return [self._prune(r) for r in sequence_list]
|
return [self._prune(r) for r in sequence_list]
|
||||||
|
|
||||||
|
|
|
@ -101,7 +101,7 @@ def _non_blocking_input():
|
||||||
|
|
||||||
def _make_infinite_epochs(dl):
|
def _make_infinite_epochs(dl):
|
||||||
while True:
|
while True:
|
||||||
_logger.info("New epoch starts.")
|
#_logger.info("New epoch starts.")
|
||||||
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
|
yield from tqdm(dl, "Epoch progress", dynamic_ncols=True)
|
||||||
|
|
||||||
|
|
||||||
|
@ -158,10 +158,9 @@ def train(
|
||||||
|
|
||||||
#batch = to_device(batch, torch.cuda.current_device())
|
#batch = to_device(batch, torch.cuda.current_device())
|
||||||
stats = engines.step(batch=batch, feeder=train_feeder)
|
stats = engines.step(batch=batch, feeder=train_feeder)
|
||||||
|
|
||||||
stats['it'] = stats['global_step']
|
|
||||||
stats['epoch'] = engines.global_samples / len(train_dl.dataset.paths)
|
stats['epoch'] = engines.global_samples / len(train_dl.dataset.paths)
|
||||||
|
|
||||||
|
"""
|
||||||
stats['batch'] = {
|
stats['batch'] = {
|
||||||
'size': len(batch['text']),
|
'size': len(batch['text']),
|
||||||
'id': batch['spkr_id'],
|
'id': batch['spkr_id'],
|
||||||
|
@ -170,8 +169,8 @@ def train(
|
||||||
'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
|
'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
|
||||||
'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
|
'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
|
||||||
}
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
del stats['global_step']
|
|
||||||
|
|
||||||
elapsed_time = stats.get("elapsed_time", 0)
|
elapsed_time = stats.get("elapsed_time", 0)
|
||||||
_logger.info(f"Training Metrics: {json.dumps(stats)}.")
|
_logger.info(f"Training Metrics: {json.dumps(stats)}.")
|
||||||
|
|
Loading…
Reference in New Issue
Block a user