cleanup, putting some thoughts in comments before I forget about them
This commit is contained in: parent 3cfc8a96bb · commit 880b4ecd1b
@@ -8,6 +8,8 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),

> **Note** Development on this is very sporadic. Gomen.

> **Note** Compatibility for existing models may break at any time while I feverishly try and work out the best way to crank out a model. Gomen.

## Requirements

* [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/):

@@ -24,12 +26,14 @@ I've tested this repo under Python versions `3.10.9` and `3.11.3`.

## Try Me

To quickly try it out, you can run `python -m vall_e.models.ar_nar yaml="./data/config.yaml"`

To quickly try it out, you can run `python -m vall_e.models.ar_nar yaml="./data/config.yaml"`.

Each model file has a barebones trainer and inference routine.

A small trainer will overfit a provided utterance to ensure a model configuration works.
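
As a loose illustration of what that sanity check does, the sketch below overfits a stand-in model on a single sample in plain PyTorch; it is not the repo's trainer, just the idea behind it, and every name in it is a placeholder.

```python
import torch
import torch.nn.functional as F

# Placeholder model and "utterance"; the real quick test is
# `python -m vall_e.models.ar_nar yaml="./data/config.yaml"`.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
utterance = torch.randn(1, 64)  # stand-in for the one provided utterance

for step in range(200):
    optimizer.zero_grad()
    loss = F.mse_loss(model(utterance), utterance)
    loss.backward()
    optimizer.step()

# if the configuration is sane, the loss on this single sample should collapse toward zero
print(loss.item())
```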

## Pre-Trained Model

> **Note** Pre-Trained weights aren't up to par until I finally nail the best training methodologies and model code. Gomen.

My pre-trained weights can be acquired from [here](https://huggingface.co/ecker/vall-e).

A script to set up a proper environment and download the weights can be invoked with `./scripts/setup.sh`.

@@ -44,14 +48,14 @@ Training is very dependent on:

### Pre-Processed Dataset

> **Note** The provided dataset needs to be reprocessed to better suit a new training dataset format. Gomen.

A "libre" dataset utilizing EnCodec quantized audio can be found [here](https://huggingface.co/ecker/vall-e) under `data.tar.gz`.

A script to set up a proper environment and train can be invoked with `./scripts/setup-training.sh`.

### Leverage Your Own Dataset

> **Note** Preparing a dataset is a bit messy.

If you already have a dataset you want, for example your own large corpus, or for finetuning, you can use your own dataset instead.

0. Set up a `venv` with `https://github.com/m-bain/whisperX/`.

@@ -79,10 +83,11 @@ If you already have a dataset you want, for example your own large corpus, or fo

Two dataset formats are supported:
* the standard way:
  - for Encodec/Vocos audio backends, data is stored under `./training/data/{group}/{speaker}/{id}.phn.txt` and `./training/data/{group}/{speaker}/{id}.qnt.pt`
  - for the Descript-Audio-Codec audio backend, data is stored under `./training/data/{group}/{speaker}/{id}.json` and `./training/data/{group}/{speaker}/{id}.dac`
  - for Encodec/Vocos audio backends, data is stored under `./training/data/{group}/{speaker}/{id}.enc` as a NumPy file
  - for the Descript-Audio-Codec audio backend, data is stored under `./training/data/{group}/{speaker}/{id}.dac` as a NumPy file
  - it is *highly* recommended to generate metadata to speed up dataset pre-load with `python3 -m vall_e.data yaml="./training/config.yaml" --action=metadata`
* using an HDF5 dataset:
  - you can convert from the standard way with the following command: `python3 -m vall_e.data yaml="./training/config.yaml"`
  - you can convert from the standard way with the following command: `python3 -m vall_e.data yaml="./training/config.yaml"` (metadata for dataset pre-load is generated alongside HDF5 creation)
  - this will shove everything into a single HDF5 file and store some metadata alongside (for now, the symbol map generated, and text/audio lengths)
  - be sure to also define `use_hdf5` in your config YAML (a short sketch of walking the standard layout follows this list)
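
For reference, here is a small sketch of walking the standard layout described above; the directory scheme and file extensions come from the bullets, while the helper itself (`list_standard_dataset`) is hypothetical and not part of the repo.

```python
from pathlib import Path

def list_standard_dataset(root="./training/data"):
    """Pair {group}/{speaker}/{id}.phn.txt with its {id}.qnt.pt (Encodec/Vocos layout)."""
    pairs = []
    for phn in Path(root).glob("*/*/*.phn.txt"):
        qnt = phn.with_name(phn.name.replace(".phn.txt", ".qnt.pt"))
        if qnt.exists():
            group, speaker = phn.parent.parent.name, phn.parent.name
            pairs.append((group, speaker, phn.name[:-len(".phn.txt")]))
    return pairs

print(len(list_standard_dataset()), "utterances found")
```

The `.enc`/`.dac` NumPy layout mentioned above follows the same `{group}/{speaker}/{id}` convention, just with different extensions.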

@@ -98,7 +103,6 @@ You can enter `save` to save the state at any time, or `quit` to save and quit t

The `lr` command will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.

### Plotting Metrics

Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 -m vall_e.plot yaml="./training/config.yaml"`

@@ -107,35 +111,41 @@ You can specify what X and Y labels you want to plot against by passing `--xs to

### Notices

If you're training under `float16`, it is recommended to use the `local` backend with `amp` enabled. There's something really funky with `deepspeed` as a backend that's causing issues with training.
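
For context, `amp` here presumably refers to ordinary PyTorch automatic mixed precision; the generic pattern looks like the sketch below (placeholder model, optimizer, and batch, not the repo's trainer internals).

```python
import torch

scaler = torch.cuda.amp.GradScaler()  # scales the loss to avoid float16 gradient underflow

def train_step(model, optimizer, batch):
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(batch)  # assumes the model returns a scalar loss
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()
```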

#### Training Under Windows

As training under `deepspeed` and Windows is not (easily) supported, under your `config.yaml`, simply change `trainer.backend` to `local` to use the local training backend.

Keep in mind that creature comforts like distributed training or `float16` training cannot be verified as working at the moment with the local trainer.

Creature comforts like `float16`, `amp`, and multi-GPU training *should* work, but extensive testing still needs to be done to ensure it all functions.

#### Training Caveats

Unfortunately, efforts to train a *good* foundational model seem entirely predicated on a good dataset. My dataset might be too fouled with:
* too-short utterances: trying to extrapolate longer contexts seems to utterly fall apart from just the `text` being too long.
  + It might help to, instead, initially train with smaller utterances, train for two epochs, then increase each sample's length.
    - This does seem to help speed up the model "learning" better.
* too-tightly-trimmed utterances: there being little to no space at the start and end might harm associating `<s>` and `</s>` tokens with empty utterances.
* a poorly mapped phoneme mapping: I naively crafted my own phoneme mapping, where a HuggingFace tokenizer might supply a better token mapping.
  + This seems remedied by settling for using a HuggingFace tokenizer to handle everything.
* having a unified AR and NAR model might sound too convenient, but each task may lobotomize the other, due to the nature of things.
  + This *might* be remedied with better sequence formatting.

#### Backend Architectures

As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported LLM architectures (a small sketch of wiring up such a backbone follows this list):

* `llama`: using HF transformers' LLaMA implementation for its attention-based transformer, boasting RoPE and other improvements.
  + I aim to utilize this for the foundational model, as I get to leverage a bunch of things tailored for LLaMA (and converting to them is rather easy).
* `mixtral`: using HF transformers' Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
* `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
  - Setting `cfg.optimizers.bitnet=True` will make use of BitNet's linear implementation.
* `transformer`: a basic attention-based transformer implementation, with attention heads + feed-forwards.
* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
  - Its implementation for MoE can also be utilized.
* `retnet-hf`: using [syncdoth/RetNet/](https://github.com/syncdoth/RetNet) with a HuggingFace-compatible RetNet model
* `retnet-hf`: using [syncdoth/RetNet](https://github.com/syncdoth/RetNet) with a HuggingFace-compatible RetNet model
  - has an inference penalty, and MoE is not implemented.
* `mamba`: using [state-spaces/mamba](https://github.com/state-spaces/mamba) (needs to mature)
  - ***really hard*** to have a unified AR and NAR model
  - inference penalty makes it a really hard sell, despite the loss already being a low 3 after a short amount of samples processed
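
To make "slotting in" concrete, the sketch below builds a bare LLaMA-style backbone with HF transformers' `LlamaConfig`/`LlamaModel`; the sizes echo defaults seen elsewhere in this commit, but this is an illustration, not the repo's actual wiring.

```python
from transformers import LlamaConfig, LlamaModel

# Illustrative sizes; the repo drives these from its YAML config instead.
config = LlamaConfig(
    vocab_size=1024 + 1,       # e.g. audio tokens plus a stop token
    hidden_size=1024,          # d_model
    num_hidden_layers=12,      # n_layers
    num_attention_heads=16,    # n_heads
    intermediate_size=4 * 1024,
)
backbone = LlamaModel(config)  # attention-based decoder stack with RoPE
```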

For audio backends (a short EnCodec encoding sketch follows this list):

@@ -144,7 +154,7 @@ For audio backends:

  - encoding audio will use the `encodec` backend automagically, as there's no EnCodec encoder under `vocos`
* [`descript-audio-codec`](https://github.com/descriptinc/descript-audio-codec): boasts better compression and quality
  - **Note** models using `descript-audio-codec` at 24KHz + 8kbps will NOT converge in any manner.
  - **Note** models using `descript-audio-codec` at 44KHz + 8kbps stops improving after a while.
  - **Note** models using `descript-audio-codec` at 44KHz + 8kbps seems harder to model its "language", but despite the loss being rather high, it sounds fine.
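
Since the default path quantizes audio with EnCodec, here is a hedged sketch of that encoding step using the `encodec` package directly; the repo wraps this in its own utilities, so treat this as background rather than the project's API (the input filename is a placeholder).

```python
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio

# 24 kHz EnCodec at 6 kbps, which corresponds to 8 codebooks (RVQ levels)
model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)

wav, sr = torchaudio.load("utterance.wav")  # placeholder path
wav = convert_audio(wav, sr, model.sample_rate, model.channels).unsqueeze(0)

with torch.no_grad():
    frames = model.encode(wav)
codes = torch.cat([code for code, _ in frames], dim=-1)
print(codes.shape)  # [batch, n_codebooks, time] -- the RVQ levels referred to above
```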

`llama`-based models also support different attention backends:
* `math`: torch SDPA's `math` implementation

@@ -155,6 +165,8 @@ For audio backends:

* `sdpa`: integrated `LlamaSdpaAttention` attention model
* `flash_attention_2`: integrated `LlamaFlashAttention2` attention model

The wide support for various backends is solely while I try and figure out which is the "best" for a core foundation model.
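
For reference, these options map onto HF transformers' `attn_implementation` switch; a generic sketch using the public loading API (placeholder checkpoint, not this repo's config plumbing) would be:

```python
from transformers import AutoModelForCausalLM

# "eager", "sdpa", or "flash_attention_2" (the latter needs the flash-attn package and a supported GPU)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    attn_implementation="sdpa",
    torch_dtype="auto",
)
```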

## Export

To export the models, run: `python -m vall_e.export yaml=./training/config.yaml`.

@@ -5,7 +5,8 @@ def get_model(cfg, training=True):
	if not cfg.experimental:
		from .ar_nar import AR_NAR
		model = AR_NAR(
			n_tokens=cfg.tokens,
			n_text_tokens=cfg.text_tokens,
			n_audio_tokens=cfg.audio_tokens,
			d_model=cfg.dim,
			n_heads=cfg.heads,
			n_layers=cfg.layers,

@@ -22,6 +23,9 @@ def get_model(cfg, training=True):
	else:
		from .experimental import Model as Experimental
		model = Experimental(
			n_text_tokens=cfg.text_tokens,
			n_audio_tokens=cfg.audio_tokens,

			d_model=cfg.dim,
			n_layers=cfg.layers,
			n_heads=cfg.heads,

@@ -181,6 +181,11 @@ class AR_NAR(Base):
		# is NAR
		if max_levels == 0:
			max_levels = self.n_resp_levels - 1

		# expand if given a raw 1D tensor
		for i, resp in enumerate(resps_list):
			if resp.dim() == 1:
				resps_list[i] = resp.unsqueeze(-1)

		prev_list = resps_list

@@ -377,7 +382,9 @@ def example_usage():

	# retnet-full is the only configuration with BitNet's BitLinear that converges despite the grad_norm saying otherwise
	kwargs = {
		'n_tokens': 1024,
		'n_text_tokens': 256,
		'n_audio_tokens': 1024,

		'd_model': 1024, # 256, # 1024, # 1536
		'n_heads': 16, # 4, # 16, # 24
		'n_layers': 12, # 32

@@ -475,8 +482,7 @@ def example_usage():
	else:
		resps_list = [ qnt[:, 0].to( device ) ]

	if cfg.model.max_levels > 1:
		resps_list = [r.unsqueeze(-1) for r in resps_list]
	if "nar" in cfg.model.capabilities:
		resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )

	for i, o in enumerate(resps_list):

@@ -458,7 +458,7 @@ class Base(nn.Module):
	def stop_token(self):
		if not self.causal:
			raise ValueError("Not using stop token!")
		return self.n_tokens
		return self.n_audio_tokens

	@property
	def ignore_index(self):

@@ -471,7 +471,10 @@ class Base(nn.Module):

	def __init__(
		self,
		n_tokens: int = 1024,

		n_text_tokens: int = 256,
		n_audio_tokens: int = 1024,

		d_model: int = 512,
		n_heads: int = 8,
		n_layers: int = 12,

@@ -489,7 +492,9 @@ class Base(nn.Module):
		self.hyper_config = config
		self.gradient_checkpointing = self.hyper_config.gradient_checkpointing if self.hyper_config is not None else True

		self.n_tokens = n_tokens
		self.n_text_tokens = n_text_tokens
		self.n_audio_tokens = n_audio_tokens

		self.d_model = d_model
		self.n_heads = n_heads
		self.n_layers = n_layers

@@ -498,10 +503,10 @@ class Base(nn.Module):
		self.l_padding = l_padding

		# +1 to include the stop token
		n_prom_tokens = n_tokens
		n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
		n_prom_tokens = n_audio_tokens
		n_resp_tokens = n_audio_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop

		self.text_emb = Embedding(n_tokens, d_model)
		self.text_emb = Embedding(n_text_tokens, d_model)
		self.langs_emb = None
		self.tones_emb = None
		self.tasks_emb = None

@@ -518,24 +523,28 @@ class Base(nn.Module):
			levels=self.n_prom_levels if self.version > 3 else None,
			sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True,
		)
		# [1025] + [1024] * 8
		# [1024 + STOP] + [1024] * 8
		self.resps_emb = AudioEmbedding(
			[n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model,
			levels=self.n_resp_levels if self.version > 3 else None,
			sums=self.hyper_config.audio_embedding_sums if self.hyper_config is not None else True
		)

		# useless since I actually removed using these with the input processing overhaul...
		if self.version >= 3:
			self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None
			self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None

		# never actually got added... I kept forgetting to classify all my audio for speaker's tone
		if self.version >= 4:
			self.tones_emb = Embedding(self.n_tones, d_model) if self.n_tones > 0 else None

		# mamba requires this if a model does both AR and NAR tasks
		# this *might* help for AR and NAR tasks since we explicitly specify the current RVQ level for a sequence, rather than having it "encoded" in the embeddings
		# this ***might*** let me also unify the proms_emb and resps_embedding
		if self.version >= 5:
			self.rvq_level_emb = Embedding(self.n_resp_levels, d_model)

		# this would be nicer to be a stop token or live inside an embedding
		self.sep = nn.Parameter(torch.randn(d_model))

		# ick, there has to be a better way
@@ -943,7 +952,6 @@ class Base(nn.Module):
		target_list = []
		for batch_index, batch in enumerate(inputs):
			target = []
			quant_level = quant_levels[batch_index] if quant_levels is not None else None
			for name, input in batch:
				if name == "prom":
					target.append( torch.full_like(input[..., 0], self.ignore_index) )

@@ -960,19 +968,36 @@ class Base(nn.Module):
				logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
				target_list[i] = target_list[i][..., 1:] # predicts token n + 1

			target = torch.cat( target_list )
			inputs = torch.cat( logits )
			# see comments for the split-loss calc cross_entropy call
			if False:
				target = torch.cat( target_list )
				inputs = torch.cat( logits )
				self.loss = dict(
					# "nll" was in the original implementation and should actually just be called something else
					nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
				)
				self.stats = dict(
					acc = self.accuracy_metric( inputs, target ),
					# precision = self.precision_metric( inputs, target ),
				)
			else:
				self.loss = dict(
					nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / len(batch)
				)
				self.stats = dict(
					acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / len(batch)
				)

			self.loss = dict(
				# "nll" was in the original implementation and should actually just be called something else
				nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
			)
			self.stats = dict(
				acc = self.accuracy_metric( inputs, target ),
				# precision = self.precision_metric( inputs, target ),
			)
			return

		"""
		# considerations:
		#  * split losses does not maintain the entire sequence
		#  * the first token is ignored for all pieces, rather than just the first text token (which is always provided)
		#    + the other way at least should keep it intact this way
		#    + extra logic might be required to instead offset from the end for the resp, rather than fit snuggly
		#    + this might just be a spook since the odds the very first token of the AR mattering is slim (although I swear I hear a very brief audio pop sometimes)
		"""
		self.loss = dict()
		self.stats = dict(acc = dict())
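
Outside the diff, the tradeoff this hunk toggles between (one concatenated cross-entropy call versus a per-sequence average) can be shown in isolation; the sketch below is standalone PyTorch with toy tensors, not the repo's loss code.

```python
import torch
import torch.nn.functional as F

# Toy stand-ins: two sequences of different lengths over a 1024-token vocabulary.
logits  = [torch.randn(7, 1024), torch.randn(11, 1024)]
targets = [torch.randint(0, 1024, (7,)), torch.randint(0, 1024, (11,))]

# Option A: concatenate and make a single cross_entropy call.
# Fewer kernel launches, but torch.cat allocates fresh tensors, and every token
# is weighted equally regardless of which sequence it came from.
loss_cat = F.cross_entropy(torch.cat(logits), torch.cat(targets))

# Option B: one cross_entropy per sequence, then average.
# No concatenation, and it leaves the door open to weighting per sequence
# (e.g. per RVQ level), at the cost of weighting sequences rather than tokens.
loss_split = sum(F.cross_entropy(l, t) for l, t in zip(logits, targets)) / len(logits)
```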

@@ -998,6 +1023,7 @@ class Base(nn.Module):
			it += seq_len + 1 # +1 to incorporate the separator

			# for the AR, shift sequence so that it predicts the next token
			# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
			if quant_level is None or quant_level == 0:
				logit = logit[..., :-1, :] # get all but the final logit
				input = input[..., 1:] # shift sequence to the right by one

@@ -1008,8 +1034,9 @@ class Base(nn.Module):
					"logits": [],
				}

			info[name]["targets"].append( input ) # input.contiguous()
			info[name]["logits"].append( logit ) # logit.contiguous()
			# modeling_llama.py has some comment about requiring .contiguous() but I feel it's a spook since that incurs a memory allocation
			info[name]["targets"].append( input.long() )
			info[name]["logits"].append( logit )

		for name, batch in info.items():
			loss_factor = self.loss_factor(name)

@@ -1019,15 +1046,20 @@ class Base(nn.Module):
			if loss_factor == 0.0:
				continue

			targets = torch.cat( batch["targets"] ).long()
			inputs = torch.cat( batch["logits"] )

			self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
			try:
				# "faster" if cross_entropy has speedups for processing an entire batch, but torch.cat allocates new tensors
				# to-do: set this to a var
				if False:
					targets = torch.cat( batch["targets"] ).long()
					inputs = torch.cat( batch["logits"] )
					self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
					self.stats["acc"][name] = self.accuracy_metric( inputs, targets )
			except Exception as e:
				print( name, inputs.shape, targets.shape, e )
				pass
				# probably consumes less memory due to not having to allocate memory
				# this method also opens the way to scale loss per RVQ level (although it shouldn't really be needed)
				else:
					self.loss[name] = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( batch["targets"], batch["logits"] ) ]) / len(batch)
					self.stats["acc"][name] = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( batch["targets"], batch["logits"] ) ] ) / len(batch)

			# accuracy sometimes breaks for mamba

			# to-do: compute loss per individual batch to scale per RVQ level
			"""

@@ -1049,9 +1081,8 @@ class Base(nn.Module):

		# yes, there's a better way.
		training = False
		for b_i in range(len(inputs)):
			for i in range(len(inputs[b_i])):
				name, input = inputs[b_i][i]
		for batch_index, batch in enumerate(inputs):
			for name, input in batch:
				if name == "targ":
					training = True

@@ -1115,13 +1146,16 @@ class Base(nn.Module):
	):
		if min_temperature < 0:
			min_temperature = temperature

		# (NAR) return the entire generated response
		# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
		if quant_levels is not None:
			logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
		# (AR chunkwise) return the last chunkwise piece
		elif self.causal and self.recurrent_chunk_size > 0:
			logits = [ logit[-l:] for logit, l in zip(logits, self.recurrent_chunk_size) ]
		# (AR) return just the last code
		# Recurrent decoding relies on the last token in the logits, because each token predicts the next token in the sequence (obviously)
		else:
			logits = [ logit[-1:] for logit in logits ]

@@ -95,6 +95,10 @@ else:
	class Model(LlmArchClass):
		def __init__(
			self,

			n_text_tokens = 256,
			n_audio_tokens = 1024,

			d_model=1024,
			n_layers=12,
			n_heads=16,

@@ -107,7 +111,7 @@ class Model(LlmArchClass):
			hf_attention = config.attention if config is not None else None
			gradient_checkpointing = config.gradient_checkpointing if config is not None else True
			# text_tokens + rvq levels + [audio tokens * codebooks] (prom) + [audio tokens * codebooks] (resp) + stop
			vocab_size = 256 + cfg.model.max_levels + (1024 * cfg.model.max_levels) + (1024 * cfg.model.max_levels) + 1
			vocab_size = n_text_tokens + cfg.model.max_levels + (n_audio_tokens * cfg.model.max_levels) + (n_audio_tokens * cfg.model.max_levels) + 1

			if SELECTED_ARCH == "llama":
				super().__init__(config=LlamaConfig(
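
As a quick check on the `vocab_size` arithmetic above: with the defaults shown here (`n_text_tokens = 256`, `n_audio_tokens = 1024`) and an assumed `max_levels = 8`, the flattened vocabulary works out to:

```python
# Illustrative only: max_levels = 8 is an assumed value, not read from the repo's config.
n_text_tokens, n_audio_tokens, max_levels = 256, 1024, 8
vocab_size = n_text_tokens + max_levels + (n_audio_tokens * max_levels) + (n_audio_tokens * max_levels) + 1
print(vocab_size)  # 256 + 8 + 8192 + 8192 + 1 == 16649
```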

@@ -174,7 +174,6 @@ def run_eval(engines, eval_name, dl):
			resps_list = [ resp[:, 0] for resp in batch["resps"] ]

			if "nar" in engine.hyper_config.capabilities:
				resps_list = [ r.unsqueeze(-1) for r in resps_list ]
				resps_list = engine(text_list=batch["text"], proms_list=batch["proms"], lang_list=batch["lang"], resps_list=resps_list, sampling_temperature=cfg.evaluation.nar_temperature)

			process( name, batch, resps_list )

@@ -164,18 +164,6 @@ def train(
			stats = engines.step(batch=batch, feeder=train_feeder)
			stats['epoch'] = engines.global_samples / (len(train_dl.dataset.paths) * world_size())

			"""
			stats['batch'] = {
				'size': len(batch['text']),
				'id': batch['spkr_id'],
				'index': [ index for index in batch['index'] ],
				'text_len': [ text.shape[0] for text in batch['text'] ],
				'prom_len': [ prom.shape[0] for prom in batch['proms'] ],
				'resp_len': [ resp.shape[0] for resp in batch['resps'] ],
			}
			"""

			elapsed_time = stats.get("elapsed_time", 0)
			try:
				metrics = json.dumps(stats)