deprecate the separate AR and NAR models by keeping only the AR+NAR (the beauty of no one using this is that I can break compat as much as I want); add a tone token for when I classify my dataset with tone/emotion in the future; some other things

This commit is contained in:
mrq 2024-04-15 19:54:32 -05:00
parent d69a00e389
commit 545162195b
14 changed files with 196 additions and 764 deletions

View File

@ -6,16 +6,10 @@
An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), utilizing the [EnCodec](https://github.com/facebookresearch/encodec) encoder/decoder.
[Main Repo](https://git.ecker.tech/mrq/vall-e) | [GitHub Mirror](https://github.com/e-c-k-e-r/vall-e/)
> **Note** Development on this is very sporadic. Gomen.
## Requirements
* [`DeepSpeed`](https://github.com/microsoft/DeepSpeed#requirements):
- DeepSpeed training is Linux only. Installation under Windows should skip installing DeepSpeed.
- If your config YAML has the training backend set to `deepspeed`, you will need to have a GPU that DeepSpeed has developed and tested against, as well as a CUDA or ROCm compiler pre-installed to install this package.
* [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/):
- For phonemizing text, this repo requires `espeak`/`espeak-ng` installed.
- Linux users can consult their package managers on installing `espeak`/`espeak-ng`.
@ -24,7 +18,7 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
## Install
Simply run `pip install git+https://git.ecker.tech/mrq/vall-e`.
Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install git+https://github.com/e-c-k-e-r/vall-e`.
I've tested this repo under Python versions `3.10.9` and `3.11.3`.
@ -68,7 +62,7 @@ A script to setup a proper environment and train can be invoked with `./scripts/
If you're interested in creating an HDF5 copy of your dataset, simply invoke: `python -m vall_e.data --action='hdf5' yaml='./data/config.yaml'`
5. Train the AR and NAR models using the following scripts: `python -m vall_e.train yaml=./data/config.yaml`
5. Train the model using the following script: `python -m vall_e.train yaml=./data/config.yaml`
* If distributing your training (for example, multi-GPU), use `deepspeed --module vall_e.train yaml="./data/config.yaml"`
You may quit your training any time by just entering `quit` in your CLI. The latest checkpoint will be automatically saved.
@ -93,18 +87,23 @@ You can specify what X and Y labels you want to plot against by passing `--xs to
#### Training Under Windows
As training under `deepspeed` and Windows is not supported, under your `config.yaml`, simply change `trainer.backend` to `local` to use the local training backend.
As training with `deepspeed` under Windows is not (easily) supported, simply change `trainer.backend` to `local` in your `config.yaml` to use the local training backend.
Keep in mind that creature comforts like distributed training or `float16` training cannot be verified as working at the moment.
Keep in mind that creature comforts like distributed training or `float16` training cannot be verified as working at the moment with the local trainer.
#### Training on Low-VRAM Cards
During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM) with `trainer.deepspeed.compression_training` enabled with both the AR and NAR at a batch size of 16, albeit I feel this is mostly snakeoil. Better VRAM savings can be had with use of BitsAndBytes and their respective flags (specifically its AdamW implementation).
VRAM use is also predicated on your dataset; a mix of large and small utterances will cause VRAM usage to spike and can trigger OOM conditions during the backwards pass if you are not careful.
During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM). However, VRAM use is predicated on your dataset; a mix of large and small utterances will cause VRAM usage to spike and can trigger OOM conditions during the backwards pass if you are not careful.
Additionally, under Windows, I managed to finetune the AR on my 2060 (6GiB VRAM) with a batch size of 8 (although with the card as a secondary GPU).
#### Training Caveats
Unfortunately, efforts to train a *good* foundational model seem entirely predicated on a good dataset. My dataset might be too fouled by:
* utterances that are too short: trying to extrapolate longer contexts seems to utterly fall apart from just the `text` being too long.
* utterances that are too tightly trimmed: having little to no space at the start and end might harm associating the `<s>` and `</s>` tokens with empty utterances.
* a poor phoneme mapping: I naively crafted my own phoneme mapping, where a HuggingFace tokenizer might supply a better token mapping.
#### Backend Architectures
As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported:
@ -112,6 +111,8 @@ As the core of VALL-E makes use of a language model, various LLM architectures c
* `transformer`: a basic attention-based transformer implementation, with attention heads + feed forwards.
* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
- Its implementation for MoE can also be utilized.
* `retnet-hf`: using [syncdoth/RetNet/](https://github.com/syncdoth/RetNet/) with a HuggingFace-compatible RetNet model.
- inferencing cost is about 0.5x, and MoE is not implemented.
* `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
* `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
* `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
@ -121,11 +122,11 @@ As the core of VALL-E makes use of a language model, various LLM architectures c
To export the models, run: `python -m vall_e.export yaml=./data/config.yaml`.
This will export the latest checkpoints, for example, under `./data/ckpt/ar-retnet-2/fp32.pth` and `./data/ckpt/nar-retnet-2/fp32.pth`, to be loaded on any system with PyTorch, and will include additional metadata, such as the symmap used, and training stats.
This will export the latest checkpoints, for example, under `./data/ckpt/ar+nar-retnet-8/fp32.pth`, to be loaded on any system with PyTorch, and will include additional metadata, such as the symmap used, and training stats.
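If you want to poke at an exported checkpoint from Python, here's a rough sketch (not part of the repo); the `module`/`userdata`/`symmap` keys mirror what the inference code reads, and the path is just the example above:

```python
# sketch only: inspect an exported ar+nar checkpoint and its embedded metadata
import torch

state = torch.load("./data/ckpt/ar+nar-retnet-8/fp32.pth", map_location="cpu")
weights = state["module"] if "module" in state else state              # model weights
symmap = state.get("userdata", {}).get("symmap", state.get("symmap"))  # phoneme symmap, if stored
print(f"{len(weights)} tensors, {len(symmap) if symmap else 0} symmap entries")
```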
## Synthesis
To synthesize speech, invoke either (if exported the models): `python -m vall_e <text> <ref_path> <out_path> --ar-ckpt ./models/ar.pt --nar-ckpt ./models/nar.pt` or `python -m vall_e <text> <ref_path> <out_path> yaml=<yaml_path>`
To synthesize speech, invoke either (if you have exported the model): `python -m vall_e <text> <ref_path> <out_path> --model-ckpt ./data/ckpt/ar+nar-retnet-8/fp32.pth` or `python -m vall_e <text> <ref_path> <out_path> yaml=<yaml_path>`
Some additional flags you can pass are:
* `--language`: specifies the language for phonemizing the text, and helps guide inferencing when the model is trained against that language.
@ -154,7 +155,6 @@ And some experimental sampling flags you can use too (your mileage will ***defin
## To-Do
* train and release a ***good*** model.
- the current model seems to require a ***long*** time of training at a very small learning rate to try and cover a wide variety of speakers of varying acoustics.
* clean up the README, and document, document, document onto the wiki.
* extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ additional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
- training additional tasks needs the SpeechX implementation to be reworked.
@ -164,7 +164,7 @@ And some experimental sampling flags you can use too (your mileage will ***defin
+ this requires a properly trained AR, however.
* work around issues with extending context past what's trained (despite RetNet's retention allegedly being able to defeat this):
- "sliding" AR input, such as have the context a fixed length.
+ the model may need to be trained for this with a fancy positional embedding injected. Naively sliding the context window while making use of the RetNet implementation's positional embedding doesn't seem fruitful.
+ the model may need to be trained for this with a fancy positional embedding injected OR already trained with a sliding context window in mind. Naively sliding the context window while making use of the RetNet implementation's positional embedding doesn't seem fruitful.
## Notices and Citations

View File

@ -13,8 +13,7 @@ def main():
parser.add_argument("--out-path", type=Path, default=None)
parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--ar-ckpt", type=Path, default=None)
parser.add_argument("--nar-ckpt", type=Path, default=None)
parser.add_argument("--model-ckpt", type=Path, default=None)
parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
parser.add_argument("--max-nar-levels", type=int, default=7)
@ -41,7 +40,7 @@ def main():
parser.add_argument("--dtype", type=str, default=None)
args = parser.parse_args()
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
tts = TTS( config=args.yaml, model_ckpt=args.model_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
tts.inference(
text=args.text,
references=args.references,

View File

@ -162,6 +162,9 @@ class Dataset:
@dataclass()
class Model:
_max_levels: int = 0
_embeddings: str | None = None
name: str = "" # vanity name for the model
version: int = 1 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding
size: str | dict = "full" # preset string or explicitly defined dimensionality
@ -169,6 +172,7 @@ class Model:
prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
langs: int = 1 # defined languages
tones: int = 1 # defined tones
experts: int = 1
arch_type: str = "retnet" # or "transformer"
training: bool = True # unneeded now
@ -176,6 +180,13 @@ class Model:
p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
def get(self, name=None):
return [ self ] if not name or self.name == name else []
@property
def max_levels(self):
return self._max_levels if self._max_levels > 0 else self.prom_levels
@property
# required for fp8 as the lengths needs to be divisible by 8
def input_alignment(self):
@ -203,7 +214,7 @@ class Model:
if self.interleave:
name.append("interleaved")
else:
name.append(f'{cfg.models.prom_levels}')
name.append(f'{cfg.model.prom_levels}')
return "-".join(name)
@ -256,58 +267,6 @@ class Model:
def activation_checkpointing(self):
return cfg.trainer.activation_checkpointing
@dataclass()
class Models:
_max_levels: int = 0
_prom_levels: int = 1
_embeddings: str | None = None
_models: list[Model] = field(default_factory=lambda: [
Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False),
Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False),
])
def get(self, name=None):
if not name:
return [ Model(**model) for model in self._models ]
for model in self._models:
if model.name == name:
return model
raise ValueError
@property
def ar(self):
return self.get("ar")
@property
def ar_nar(self):
return self.get("ar+nar")
@property
def nar(self):
return self.get("nar")
@property
def prom_levels(self):
prom_levels = self._prom_levels
for model in self._models:
prom_levels = max(prom_levels, model.prom_levels)
return prom_levels
@property
def tasks(self):
tasks = 1
for model in self._models:
tasks = max(tasks, model.tasks)
return tasks
@property
def max_levels(self):
return self._max_levels if self._max_levels > 0 else self.prom_levels
@dataclass()
class Hyperparameters:
batch_size: int = 8
@ -568,7 +527,7 @@ class Config(_Config):
experimental: bool = False # So I can stop commenting out things when committing
dataset: Dataset = field(default_factory=lambda: Dataset)
models: Models = field(default_factory=lambda: Models)
model: Model = field(default_factory=lambda: Model)
hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
evaluation: Evaluation = field(default_factory=lambda: Evaluation)
trainer: Trainer = field(default_factory=lambda: Trainer)
@ -617,7 +576,7 @@ class Config(_Config):
def format( self ):
self.dataset = Dataset(**self.dataset)
self.models = Models(**self.models)
self.model = Model(**self.model)
self.hyperparameters = Hyperparameters(**self.hyperparameters)
self.evaluation = Evaluation(**self.evaluation)
self.trainer = Trainer(**self.trainer)
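For reference, a hedged illustration of the corresponding YAML-side change: the config file now carries a single `model:` mapping (rebuilt via `Model(**...)` above) rather than a `models:` list. Field names come from the dataclass; the values are illustrative only:

```python
# sketch: what format() now receives for the singular `model` key
cfg_dict = {
    "model": { "name": "ar+nar", "arch_type": "retnet", "prom_levels": 8, "tasks": 8, "langs": 1, "tones": 1 },
}
model_cfg = Model(**cfg_dict["model"])   # replaces the old Models(**...) handling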

View File

@ -29,22 +29,27 @@ from tqdm.auto import tqdm
_logger = logging.getLogger(__name__)
# to-do: clean up this symmap mess
def get_phone_symmap():
if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
return json.loads( cfg.hdf5['symmap'].asstr()[()] )
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '': 179, '': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, '': 184, ':ˈ': 185, '1': 186, 'rˈ': 187, 'qˈ': 188, 'ᵻˌ': 189, 'ä': 190, '̞ˌ': 191, '̞': 192, 'ũˌ': 193, 'ʑˌ': 194, '': 195, 'ɽ': 196, 'ʲˌ': 197, 'ᵝˌ': 198, 'ũ': 199, 'ũˈ': 200, 'äˌ': 201, 'ɕ': 202, 'ɕˌ': 203, 'ɽˌ': 204, 'çˌ': 205, '…ˌ': 206, '̞ˈ': 207, 'äˈ': 208, 'ɽˈ': 209, 'ɸˌ': 210, 'ɴ': 211, 'ɸˈ': 212, 'ɕˈ': 213, 'ɸ': 214, 'ᵝˈ': 215, 'ʲˈ': 216, 'ĩ': 217, 'çˈ': 218, 'ĩˌ': 219, '': 220, 'eˈ': 221, 'ʍ': 222, '': 223, '': 224, 'ʍˌ': 225, 'uˈ': 226, 'oˈ': 227, 'aˈ': 228}
return symmap
return {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '': 179, '': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, '': 184, ':ˈ': 185, '1': 186, 'rˈ': 187, 'qˈ': 188, 'ᵻˌ': 189, 'ä': 190, '̞ˌ': 191, '̞': 192, 'ũˌ': 193, 'ʑˌ': 194, '': 195, 'ɽ': 196, 'ʲˌ': 197, 'ᵝˌ': 198, 'ũ': 199, 'ũˈ': 200, 'äˌ': 201, 'ɕ': 202, 'ɕˌ': 203, 'ɽˌ': 204, 'çˌ': 205, '…ˌ': 206, '̞ˈ': 207, 'äˈ': 208, 'ɽˈ': 209, 'ɸˌ': 210, 'ɴ': 211, 'ɸˈ': 212, 'ɕˈ': 213, 'ɸ': 214, 'ᵝˈ': 215, 'ʲˈ': 216, 'ĩ': 217, 'çˈ': 218, 'ĩˌ': 219, '': 220, 'eˈ': 221, 'ʍ': 222, '': 223, '': 224, 'ʍˌ': 225, 'uˈ': 226, 'oˈ': 227, 'aˈ': 228}
def get_lang_symmap():
symmap = {
return {
"en": 0,
"ja": 1,
}
def get_tone_symmap():
return {
"neutral": 0,
}
return symmap
def get_task_symmap():
symmap = {
return {
"<tts>": 0,
"<tts-c>": 1,
"<ns>": 2,
@ -54,7 +59,6 @@ def get_task_symmap():
"<mask>": 6,
"<eoe>": 7,
}
return symmap
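Since the tone symmap mirrors how languages are mapped, a small sketch of how a tone label would be turned into the token tensor the model consumes (the default "neutral"/0 matches what the model injects when no tone is given):

```python
# sketch: map a dataset-side tone label to the uint8 token tensor the model expects
import torch

tone_symmap = get_tone_symmap()                                    # {"neutral": 0} for now
tone = torch.tensor([ tone_symmap["neutral"] ], dtype=torch.uint8)
# per sample, this would be passed into the model's forward() as tone_list=[tone]
```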
def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
@ -237,6 +241,7 @@ class Dataset(_Dataset):
self.spkr_symmap = self._get_spkr_symmap()
self.spkr_group_symmap = self._get_spkr_group_symmap()
self.lang_symmap = self._get_lang_symmap()
self.tone_symmap = self._get_tone_symmap()
self.task_symmap = self._get_task_symmap()
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
@ -309,11 +314,14 @@ class Dataset(_Dataset):
def _get_lang_symmap(self):
return get_lang_symmap()
def _get_tone_symmap(self):
return get_tone_symmap()
def _get_task_symmap(self):
return get_task_symmap()
"""
def get_task_token( self, token, levels=cfg.models.max_levels ):
def get_task_token( self, token, levels=cfg.model.max_levels ):
if not hasattr(self, "task_symmap"):
self.task_symmap = self._get_task_symmap()
return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
@ -339,7 +347,7 @@ class Dataset(_Dataset):
choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
choices = [*choices]
# no other utterances, it'd make more sense to prune speakers with only one utterance in the validatoin step
# no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step
if len(choices) == 0:
choices = [*set(self.paths_by_spkr_name[spkr_name])]
"""
@ -622,8 +630,8 @@ class Dataset(_Dataset):
"""
# trim to fit to requested prom/resps levels
proms = proms[:, :cfg.models.prom_levels]
resps = resps[:, :cfg.models.prom_levels]
proms = proms[:, :cfg.model.prom_levels]
resps = resps[:, :cfg.model.prom_levels]
return dict(
@ -928,7 +936,7 @@ if __name__ == "__main__":
if task not in cfg.dataset.tasks_list:
continue
print(text, task, cfg.models.prom_levels)
print(text, task, cfg.model.prom_levels)
print( proms.shape, resps.shape )
tokens = 0

View File

@ -21,7 +21,7 @@ except Exception as e:
cfg.inference.use_vocos = False
@cache
def _load_encodec_model(device="cuda", levels=cfg.models.max_levels):
def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
# Instantiate a pretrained EnCodec model
assert cfg.sample_rate == 24_000
@ -44,7 +44,7 @@ def _load_encodec_model(device="cuda", levels=cfg.models.max_levels):
return model
@cache
def _load_vocos_model(device="cuda", levels=cfg.models.max_levels):
def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
assert cfg.sample_rate == 24_000
model = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
@ -66,7 +66,7 @@ def _load_vocos_model(device="cuda", levels=cfg.models.max_levels):
return model
@cache
def _load_model(device="cuda", vocos=cfg.inference.use_vocos, levels=cfg.models.max_levels):
def _load_model(device="cuda", vocos=cfg.inference.use_vocos, levels=cfg.model.max_levels):
if vocos:
model = _load_vocos_model(device, levels=levels)
else:
@ -80,7 +80,7 @@ def unload_model():
@torch.inference_mode()
def decode(codes: Tensor, device="cuda", levels=cfg.models.max_levels):
def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels):
"""
Args:
codes: (b q t)
@ -117,7 +117,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.models.max_levels):
return wav, model.sample_rate
# huh
def decode_to_wave(resps: Tensor, device="cuda", levels=cfg.models.max_levels):
def decode_to_wave(resps: Tensor, device="cuda", levels=cfg.model.max_levels):
return decode(resps, device=device, levels=levels)
def decode_to_file(resps: Tensor, path: Path, device="cuda"):
@ -131,7 +131,7 @@ def _replace_file_extension(path, suffix):
@torch.inference_mode()
def encode(wav: Tensor, sr: int = 24_000, device="cuda", levels=cfg.models.max_levels):
def encode(wav: Tensor, sr: int = 24_000, device="cuda", levels=cfg.model.max_levels):
"""
Args:
wav: (t)
@ -224,7 +224,7 @@ def repeat_extend_audio( qnt, target ):
# merges two quantized audios together
# I don't know if this works
def merge_audio( *args, device="cpu", scale=[], levels=cfg.models.max_levels ):
def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
qnts = [*args]
decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]
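A rough round-trip sketch with the helpers above, assuming the shapes in their docstrings ((t) waveform in, (b q t) codes) hold; the exact squeezes/reshapes may need adjusting:

```python
# sketch: quantize a reference clip and decode it back
import torchaudio

wav, sr = torchaudio.load("reference.wav")       # (channels, t)
codes = encode(wav[0], sr, device="cuda")         # EnCodec RVQ codes
wav_out, out_sr = decode(codes, device="cuda")    # waveform + sample rate
torchaudio.save("roundtrip.wav", wav_out.cpu().reshape(1, -1), out_sr)
```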

View File

@ -26,7 +26,7 @@ from functools import cache
@cache
def load_engines(training=True):
models = get_models(cfg.models.get(), training=training)
models = get_models(cfg.model.get(), training=training)
engines = dict()
for name, model in models.items():
@ -145,8 +145,8 @@ def load_engines(training=True):
engine.freeze(freeze_all=False)
# copy embeddings if requested
if cfg.models._embeddings is not None:
embeddings_path = cfg.relpath / cfg.models._embeddings
if cfg.model._embeddings is not None:
embeddings_path = cfg.relpath / cfg.model._embeddings
if embeddings_path.exists():
embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))

View File

@ -19,7 +19,7 @@ if deepspeed_available:
import deepspeed
class TTS():
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device=None, amp=None, dtype=None ):
def __init__( self, config=None, model_ckpt=None, device=None, amp=None, dtype=None ):
self.loading = True
self.input_sample_rate = 24000
@ -53,7 +53,10 @@ class TTS():
self.symmap = None
def parse( name, model, state ):
if model_ckpt:
state = torch.load(model_ckpt)
self.model = get_models(cfg.model.get(), training=False)[0]
if "userdata" in state and 'symmap' in state['userdata']:
self.symmap = state['userdata']['symmap']
elif "symmap" in state:
@ -62,55 +65,26 @@ class TTS():
if "module" in state:
state = state['module']
model.load_state_dict(state)
self.model.load_state_dict(state)
if cfg.inference.backend == "local" and deepspeed_available and cfg.trainer.deepspeed.inferencing:
model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
return model
if ar_ckpt and nar_ckpt:
self.ar_ckpt = ar_ckpt
self.nar_ckpt = nar_ckpt
models = get_models(cfg.models.get(), training=False)
for name, model in models.items():
if name.startswith("ar"):
state = torch.load(self.ar_ckpt)
self.ar = parse( name, model, state )
elif name.startswith("nar"):
state = torch.load(self.nar_ckpt)
self.nar = parse( name, model, state )
if name.startswith("ar+nar"):
self.nar = self.ar
self.model = deepspeed.init_inference(model=self.model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
else:
self.load_models()
engines = load_engines(training=False)
for name, engine in engines.items():
self.model = engine.module
break
if self.dtype != torch.int8:
self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.model = self.model.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.ar.eval()
self.nar.eval()
self.model.eval()
if self.symmap is None:
self.symmap = get_phone_symmap()
self.loading = False
def load_models( self ):
engines = load_engines(training=False)
for name, engine in engines.items():
if name.startswith("ar"):
self.ar = engine.module
elif name.startswith("nar"):
self.nar = engine.module
if name.startswith("ar+nar"):
self.nar = self.ar
def encode_text( self, text, language="en" ):
# already a tensor, return it
if isinstance( text, Tensor ):
@ -193,7 +167,7 @@ class TTS():
lang = to_device(lang, self.device).to(torch.uint8)
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
resps_list = self.ar(
resps_list = self.model(
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
sampling_temperature=ar_temp,
sampling_min_temperature=min_ar_temp,
@ -205,7 +179,7 @@ class TTS():
sampling_mirostat_eta=mirostat_eta,
)
resps_list = [r.unsqueeze(-1) for r in resps_list]
resps_list = self.nar(
resps_list = self.model(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
max_levels=max_nar_levels,
sampling_temperature=nar_temp,
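For reference, a hedged example of driving the single-model TTS wrapper from Python rather than the CLI; the argument names mirror the flags in `__main__.py`, and the actual `inference()` signature and import path may differ slightly:

```python
# sketch: load the merged ar+nar checkpoint and synthesize
from vall_e.inference import TTS   # import path assumed from the repo layout

tts = TTS( config="./data/config.yaml", model_ckpt="./data/ckpt/ar+nar-retnet-8/fp32.pth", device="cuda", dtype="float16", amp=False )
tts.inference( text="hello world", references=["./reference.wav"], out_path="./output.wav" )
```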

View File

@ -1,19 +1,9 @@
from .ar import AR
from .nar import NAR
from .ar_nar import AR_NAR
def get_model(cfg, training=True):
if cfg.name == "ar":
Model = AR
elif cfg.name == "nar":
Model = NAR
elif cfg.name == "ar+nar":
Model = AR_NAR
else:
raise f"invalid model name: {cfg.name}"
name = cfg.name
model = Model(
model = AR_NAR(
n_tokens=cfg.tokens,
d_model=cfg.dim,
n_heads=cfg.heads,
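With the split AR/NAR classes gone, model construction always resolves to `AR_NAR`; a small sketch of what that means in practice:

```python
# sketch: every configured model name now resolves to an AR_NAR instance
from vall_e.config import cfg
from vall_e.models import get_model   # import paths assumed from the repo layout

model = get_model( cfg.model, training=False )
print( type(model).__name__ )   # "AR_NAR"
```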

View File

@ -1,309 +0,0 @@
from ..config import cfg
from .base import Base, list_to_tensor, Categorical
import torch
from torch.nn.utils.rnn import pad_sequence
from einops import rearrange
from torch import Tensor
from tqdm import trange
class AR(Base):
@property
def causal(self):
return True
@property
def norm_type(self):
return "ln"
@property
def arch_type(self) -> str:
if hasattr(self, "config") and self.config:
return self.config.arch_type
return cfg.models.ar.arch_type
@property
def n_prom_levels(self) -> int:
return cfg.models.prom_levels
@property
def n_resp_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.resp_levels
return cfg.models.ar.resp_levels
@property
def n_max_levels(self) -> int:
return cfg.models.max_levels
@property
def n_tasks(self) -> int:
return cfg.models.ar.tasks
@property
def n_langs(self) -> int:
return cfg.models.ar.langs
@property
def recurrent_chunk_size(self) -> int:
if cfg.mode == "training":
return 0
return cfg.inference.recurrent_chunk_size
"""
@property
def rotary_embedding_base(self) -> float:
if hasattr(self, "config") and self.config:
return self.config.rotary_embedding_base
return cfg.models.ar.rotary_embedding_base
"""
@property
def interleave(self) -> bool:
if hasattr(self, "config") and self.config:
return self.config.interleave
return False
@property
def monolithic(self) -> bool:
return False
@property
def version(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.version
return cfg.models.ar.version
def _prune(self, l: Tensor):
indices = (l == self.stop_token).nonzero()
if len(indices) == 0:
return l
return l[: indices.min().item()]
def _interleave( self, codes ):
if not self.interleave:
return codes
return codes.flatten()
def _deinterleave( self, codes, length = 0 ):
if not self.interleave:
return codes
return torch.unflatten( codes[:codes.shape[0] // self.n_prom_levels * self.n_prom_levels], 0, ( codes.shape[0] // self.n_prom_levels, self.n_prom_levels ) )
@staticmethod
def _unsqueeze_list(x_list, axis=-1):
return [x.unsqueeze(dim=axis) for x in x_list]
def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
max_steps: int = 1000,
max_resp_context: int = -1,
sampling_temperature: float = 1.0,
sampling_min_temperature: float = -1.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_repetition_penalty: float = 1.0,
sampling_repetition_penalty_decay: float = 0.0,
sampling_length_penalty: float = 0.0,
sampling_beam_width: int = 0,
sampling_mirostat_tau: float = 0.0,
sampling_mirostat_eta: float = 0.1,
):
if resps_list is not None:
if self.interleave:
resps_list = [self._interleave(r) for r in resps_list]
else:
resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
return super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=self._unsqueeze_list(resps_list),
targ_list=resps_list,
lang_list=lang_list,
quant_levels=None,
)
device = text_list[0].device
batch_size = len(text_list)
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
stopped = torch.zeros(batch_size, device=device).bool()
recurrent_state = {} if cfg.inference.recurrent_forward else None
mirostat = [
{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
] * batch_size if sampling_mirostat_tau > 0.0 else None
sampling_beam_width_use_logs = True
scores = [ 1.0 ] * sampling_beam_width
if self.interleave:
max_steps *= self.n_prom_levels
# get next in sequence
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
if max_resp_context > 0:
resps_list = self._unsqueeze_list([ sequence[-max_resp_context:] for sequence in sequence_list ] )
else:
resps_list = self._unsqueeze_list(sequence_list)
logits = super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
state=recurrent_state
)
r = super().sample(
logits=logits,
resps_list=resps_list,
temperature=sampling_temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
repetition_penalty=sampling_repetition_penalty,
repetition_penalty_decay=sampling_repetition_penalty_decay,
length_penalty=sampling_length_penalty,
beam_width=sampling_beam_width,
mirostat=mirostat,
)
if mirostat is not None:
# r is the state
mirostat = r
# extract token from state
r = [ state["token"] for state in mirostat ]
# we do it here because the sampler will already expand our logits list
elif sampling_beam_width > 0:
# expand tuple
r, s = r
# first step, expand batch
if batch_size == 1:
batch_size *= sampling_beam_width
text_list = text_list * sampling_beam_width
proms_list = proms_list * sampling_beam_width
sequence_list = sequence_list * sampling_beam_width
stopped = torch.zeros(batch_size, device=device).bool()
# update scores
if sampling_beam_width_use_logs:
scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ]
else:
scores = [ scores[i] * score for i, score in enumerate(s) ]
# append tokens
for i, ri in enumerate(r):
if self.stop_token in ri:
stopped[i] = True
sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
# stop token found
stopped |= r == self.stop_token
if stopped.all().item():
break
# pick the best scoring candidate
# desu this is always going to be candidate 0
if sampling_beam_width and len(scores) > 0:
best_idx, best_score = (0, 0)
for idx, score in enumerate(scores):
if best_score > score:
best_idx, best_score = idx, score
sequence_list = [sequence_list[best_idx]]
if self.interleave:
sequence_list = [self._deinterleave(r) for r in sequence_list]
return [self._prune(r) for r in sequence_list]
def example_usage():
cfg.trainer.backend = "local"
from functools import partial
from einops import repeat
from ..emb.qnt import decode_to_file
from ..engines import Engine
from tqdm import tqdm
from ..utils import wrapper as ml
device = "cuda"
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"):
split = content.split(" ")
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
return torch.tensor([*map(symmap.get, phones)]).to()
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device)
text_list = [
#torch.tensor([1, 2, 3], device=device),
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
]
proms_list = [
#x8(torch.tensor([1, 2, 3], device=device)),
qnt.to(device),
]
resps_list = [
qnt.to(device),
]
text_list = text_list[:1]
proms_list = proms_list[:1]
resps_list = resps_list[:1]
kwargs = {
'n_tokens': 1024,
'd_model': 1024,
'n_heads': 16,
'n_layers': 24,
}
"""
try:
kwargs['config'] = cfg.models.ar
except Exception as e:
pass
"""
model = AR(**kwargs).to(device)
steps = 500
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
engine = Engine(model=model, optimizer=optimizer)
def sample( name, steps=600 ):
engine.eval()
out = engine(text_list, proms_list, max_steps=steps)
for i, o in enumerate(out):
wav, sr = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
def train():
engine.train()
t = trange(steps)
for i in t:
stats = {"step": i}
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
tqdm.write(f"{stats}")
sample("init", 75)
train()
sample("final")
if __name__ == "__main__":
example_usage()

View File

@ -25,29 +25,33 @@ class AR_NAR(Base):
def arch_type(self) -> str:
if hasattr(self, "config") and self.config:
return self.config.arch_type
return cfg.models.ar_nar.arch_type
return cfg.model.arch_type
@property
def n_prom_levels(self) -> int:
return cfg.models.prom_levels
return cfg.model.prom_levels
@property
def n_resp_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.resp_levels
return cfg.models.ar_nar.resp_levels
return cfg.model.resp_levels
@property
def n_max_levels(self) -> int:
return cfg.models.max_levels
return cfg.model.max_levels
@property
def n_tasks(self) -> int:
return cfg.models.ar_nar.tasks
return cfg.model.tasks
@property
def n_langs(self) -> int:
return cfg.models.ar_nar.langs
return cfg.model.langs
@property
def n_tones(self) -> int:
return cfg.model.tones
@property
def recurrent_chunk_size(self) -> int:
@ -58,7 +62,7 @@ class AR_NAR(Base):
def rotary_embedding_base(self) -> float:
if hasattr(self, "config") and self.config:
return self.config.rotary_embedding_base
return cfg.models.ar_nar.rotary_embedding_base
return cfg.model.rotary_embedding_base
"""
@property
@ -73,7 +77,7 @@ class AR_NAR(Base):
def version(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.version
return cfg.models.ar_nar.version
return cfg.model.version
def _prune(self, l: Tensor):
indices = (l == self.stop_token).nonzero()
@ -92,6 +96,7 @@ class AR_NAR(Base):
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
max_steps: int = 1000,
max_levels: int = 0,
@ -134,10 +139,10 @@ class AR_NAR(Base):
else:
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
"""
if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None:
if cfg.model.p_ar_level == "auto" or cfg.model.p_ar_level is None:
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
else:
quant_levels = torch.Tensor([ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ])
quant_levels = torch.Tensor([ 0 if random.random() < cfg.model.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ])
"""
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
@ -162,6 +167,7 @@ class AR_NAR(Base):
resps_list=resps_list,
targ_list=targ_list,
lang_list=lang_list,
tone_list=tone_list,
quant_levels=quant_levels,
)
# is NAR
@ -182,6 +188,7 @@ class AR_NAR(Base):
proms_list=proms_list,
resps_list=prev_list,
lang_list=lang_list,
tone_list=tone_list,
quant_levels=quant_levels,
)
@ -234,6 +241,7 @@ class AR_NAR(Base):
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
state=recurrent_state
)
else:
@ -242,6 +250,7 @@ class AR_NAR(Base):
proms_list=proms_list,
resps_list=resps_list,
lang_list=lang_list,
tone_list=tone_list,
state=recurrent_state
)
@ -312,14 +321,14 @@ def example_usage():
import re
device = "cuda"
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
x8 = partial(repeat, pattern="t -> t l", l=cfg.model.prom_levels)
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"):
split = content.split(" ")
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
return torch.tensor([*map(symmap.get, phones)])
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device)
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.model.prom_levels].to(device)
cfg.hyperparameters.gradient_accumulation_steps = 1
@ -359,7 +368,7 @@ def example_usage():
"""
try:
kwargs['config'] = cfg.models.ar_nar
kwargs['config'] = cfg.model
except Exception as e:
pass
"""
@ -374,8 +383,8 @@ def example_usage():
# copy embeddings if requested
"""
if cfg.models._embeddings is not None:
embeddings_path = cfg.relpath / cfg.models._embeddings
if cfg.model._embeddings is not None:
embeddings_path = cfg.relpath / cfg.model._embeddings
if embeddings_path.exists():
embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
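Putting the pieces together, the merged model is still invoked twice at inference time (mirroring `TTS.inference` above): once without `resps_list` for the AR pass, then once with the sampled level-0 codes for the NAR passes. A hedged sketch with placeholder tensors (`phns`, `prom`, `lang`, `tone` are stand-ins), including the new optional `tone_list`:

```python
# sketch: AR pass (sample RVQ level 0), then NAR pass (fill the remaining levels)
resps = model( text_list=[phns], proms_list=[prom], lang_list=[lang], tone_list=[tone], max_steps=450 )
resps = [ r.unsqueeze(-1) for r in resps ]   # add the level dimension the NAR path expects
resps = model( text_list=[phns], proms_list=[prom], lang_list=[lang], tone_list=[tone], resps_list=resps, max_levels=7 )
```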

View File

@ -262,11 +262,15 @@ class Base(nn.Module):
@property
def n_langs(self) -> int:
raise NotImplementedError
@property
def n_tasks(self) -> int:
raise NotImplementedError
@property
def n_tones(self) -> int:
raise NotImplementedError
@property
def recurrent_chunk_size(self) -> int:
raise NotImplementedError
@ -343,6 +347,7 @@ class Base(nn.Module):
self.text_emb = Embedding(n_tokens, d_model)
self.langs_emb = None
self.tones_emb = None
self.tasks_emb = None
if self.version == 1: # legacy
@ -359,6 +364,9 @@ class Base(nn.Module):
if self.version >= 3:
self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None
self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None
if self.version >= 4:
self.tones_emb = Embedding(self.n_tones, d_model) if self.n_tones > 0 else None
self.sep = nn.Parameter(torch.randn(d_model))
@ -522,53 +530,15 @@ class Base(nn.Module):
ignore_index=self.ignore_index,
)
def forward(
def _forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor],
targ_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
quant_levels: Tensor | None = None,
state: dict | list | None = None,
inputs,
mask = None,
state = None,
):
batch_size = len(text_list)
if self.langs_emb is None:
lang_list = None
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.langs_emb(lang_list) if lang_list is not None else None,
self.proms_emb(proms_list),
self.resps_emb(resps_list, quant_levels),
sep=self.sep,
)
x, m = list_to_tensor(x_list)
x = inputs
m = mask.squeeze(-1).int()
aux_loss = None
device = x.device
# pad our input and mask, but retain the original length by doing it after
if self.l_padding and x.shape[1] % self.l_padding != 0:
# pad input
shape = list(x.shape)
shape[1] = self.l_padding - shape[1] % self.l_padding
padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
x = torch.cat([x, padding], dim=1)
# pad mask
shape[2] = 1
padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
m = torch.cat([m, padding], dim=1)
# for simplicity
mask = m.squeeze(-1).int()
"""
# Broken
@ -587,7 +557,7 @@ class Base(nn.Module):
xi = x[:, n, :].unsqueeze(1)
kwargs = dict(
attention_mask=mask,
attention_mask=m,
inputs_embeds=xi,
past_key_values=state,
use_cache=True,
@ -603,9 +573,9 @@ class Base(nn.Module):
"""
# HF transformer derived model
if self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral":
if self.arch_type in ["llama", "mistral", "mixtral"]:
kwargs = dict(
attention_mask=mask,
attention_mask=m,
inputs_embeds=x,
past_key_values=state,
use_cache=True,
@ -632,7 +602,7 @@ class Base(nn.Module):
x = self.sin_emb.add_pe(x)
# pass our inputs through the transformer
for block in self.blocks:
x = block(x, mask, l)
x = block(x, m, l)
elif self.arch_type == "retnet":
# pass our inputs through the RetNet
x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
@ -642,7 +612,7 @@ class Base(nn.Module):
first = state is None or len(state) == 0
kwargs = dict(
attention_mask=mask,
attention_mask=m,
inputs_embeds=x if first else x[:, -1, :].unsqueeze(1),
past_key_values=None if first else state,
use_cache=True,
@ -659,8 +629,76 @@ class Base(nn.Module):
x = self.model(x)
# output projection layer with masking
x = self.classifier(x) * mask
x = self.classifier(x) * m
return x, state, aux_loss
def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor],
targ_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
quant_levels: Tensor | None = None,
state: dict | list | None = None,
):
device = text_list[0].device
batch_size = len(text_list)
# silently ignore languages if model does not have it
if self.langs_emb is None:
lang_list = None
# inject default language
elif lang_list is None:
lang_list = [ torch.Tensor([ 0 ]).to(dtype=torch.uint8, device=device) for _ in range(batch_size) ]
# silently ignore tones if model does not have it
if self.tones_emb is None:
tone_list = None
# inject default tone
elif tone_list is None:
tone_list = [ torch.Tensor([ 0 ]).to(dtype=torch.uint8, device=device) for _ in range(batch_size) ]
"""
# Typical sequence format
# To-do: integrate tasks again
<s><text></s><sep><lang><sep><prom><sep><tone><sep><resp><stop>
"""
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.langs_emb(lang_list) if lang_list is not None else None,
self.proms_emb(proms_list),
self.tones_emb(tone_list) if tone_list is not None else None,
self.resps_emb(resps_list, quant_levels),
sep=self.sep,
)
x, m = list_to_tensor(x_list)
# pad our input and mask, but retain the original length by doing it after
if self.l_padding and x.shape[1] % self.l_padding != 0:
# pad input
shape = list(x.shape)
shape[1] = self.l_padding - shape[1] % self.l_padding
padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
x = torch.cat([x, padding], dim=1)
# pad mask
shape[2] = 1
padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
m = torch.cat([m, padding], dim=1)
x, state, aux_loss = self._forward(
inputs=x,
mask=m,
state=state,
)
# Remove padding
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
@ -790,7 +828,7 @@ def example_usage():
from .nar import NAR
device = "cuda"
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
x8 = partial(repeat, pattern="t -> t l", l=cfg.model.prom_levels)
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"):
split = content.split(" ")
@ -812,7 +850,7 @@ def example_usage():
train = True
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device)
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.model.prom_levels].to(device)
text_list = [
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
#tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device),

View File

@ -1,235 +0,0 @@
from ..config import cfg
from .base import Base
import torch
from torch import Tensor
from tqdm import trange
class NAR(Base):
@property
def causal(self):
return False
@property
def arch_type(self) -> str:
if hasattr(self, "config") and self.config:
return self.config.arch_type
return cfg.models.nar.arch_type
@property
def norm_type(self):
return "ln" if self.n_resp_levels == 1 else "adaln"
@property
def n_prom_levels(self) -> int:
return cfg.models.prom_levels
@property
def n_resp_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.resp_levels
return cfg.models.nar.resp_levels
@property
def n_max_levels(self) -> int:
return cfg.models.max_levels
@property
def n_tasks(self) -> int:
return cfg.models.nar.tasks
@property
def n_langs(self) -> int:
return cfg.models.nar.langs
@property
def version(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.version
return cfg.models.nar.version
@property
def recurrent_chunk_size(self) -> int:
return 0
"""
@property
def rotary_embedding_base(self) -> float:
if hasattr(self, "config") and self.config:
return self.config.rotary_embedding_base
return cfg.models.nar.rotary_embedding_base
"""
@property
def interleave(self) -> bool:
return False
@property
def monolithic(self) -> bool:
return False
def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor],
lang_list: list[Tensor] | None = None,
max_levels: int = 0,
sampling_temperature: float = 0.2,
sampling_min_temperature: float = -1.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_repetition_penalty: float = 1.0,
sampling_repetition_penalty_decay: float = 0.0,
sampling_length_penalty: float = 0.0, # unused
sampling_beam_width: int = 0, # unused
sampling_mirostat_tau: float = 0.0, # unused
):
"""
Args:
text_list: [t] * b
proms_list: [t' l] * b, l=8
resps_list: [t'' l] * b, l=1 or 8, 1 for testing and 8 for training.
Returns:
[t'' l], l=8 if testing. empty list will be returned during training.
"""
n_levels_set = {r.shape[-1] for r in resps_list}
if len(n_levels_set) > 1:
raise ValueError(f"Please give only one level, got {n_levels_set}.")
n_levels = next(iter(n_levels_set))
device = text_list[0].device
if n_levels == self.n_resp_levels + 1:
assert resps_list is not None
quant_levels = torch.randint(0, self.n_resp_levels, (len(resps_list),))
prev_list = [o[..., : l + 1] for o, l in zip(resps_list, quant_levels)]
targ_list = [o[..., l + 1] for o, l in zip(resps_list, quant_levels)]
#quant_levels = quant_levels.to(device=device)
logits = super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=prev_list,
targ_list=targ_list,
lang_list=lang_list,
quant_levels=quant_levels,
)
prev_list = []
else:
prev_list = resps_list
if max_levels == 0:
max_levels = self.n_resp_levels
while True:
level = prev_list[0].shape[-1] - 1
if level >= max_levels: # min(max_levels, self.n_resp_levels): # commented out to experiment with exceeding trained levels
break
quant_levels = torch.full((len(text_list),), level, device=device)
logits = super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=prev_list,
lang_list=lang_list,
quant_levels=quant_levels,
)
resps_list = super().sample(
logits=logits,
resps_list=prev_list,
quant_levels=quant_levels,
temperature=sampling_temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
repetition_penalty=sampling_repetition_penalty,
repetition_penalty_decay=sampling_repetition_penalty_decay,
#length_penalty=sampling_length_penalty,
#beam_width=sampling_beam_width,
#mirostat_tau=sampling_mirostat_tau,
#mirostat_state=mirostat_state,
)
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
return prev_list
def example_usage():
cfg.trainer.backend = "local"
from functools import partial
from einops import repeat
from ..emb.qnt import decode_to_file
from ..engines import Engine
from tqdm import tqdm
from ..utils import wrapper as ml
device = "cuda"
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"):
split = content.split(" ")
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
return torch.tensor([*map(symmap.get, phones)]).to()
# to-do: unmangle this and the resp shit
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device)
text_list = [
#torch.tensor([1, 2, 3], device=device),
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
]
proms_list = [
x8(torch.tensor([2, 3], device=device)),
]
resps_list = [
qnt.to(device),
]
kwargs = {
'n_tokens': 1024,
'd_model': 1024,
'n_heads': 16,
'n_layers': 12,
}
model = NAR(**kwargs).to(device)
steps = 500
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
engine = Engine(model=model, optimizer=optimizer)
def sample( name ):
engine.eval()
codes = engine( text_list, proms_list, resps_list=[r[..., 0].unsqueeze(-1) for r in resps_list], sampling_temperature=0.2 )
decode_to_file( codes[0], f"data/nar.{name}.wav", device )
def train():
engine.train()
t = trange(steps)
for i in t:
stats = {"step": i}
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
tqdm.write(f"{stats}")
sample("init")
train()
sample("final")
if __name__ == "__main__":
example_usage()

View File

@ -109,7 +109,7 @@ if __name__ == "__main__":
path = cfg.relpath / "logs"
paths = path.rglob(f"./*/{args.filename}")
args.models = [ model for model in cfg.models.get() if model.training and (args.model == "*" or model.name in args.model) ]
args.models = [ model for model in cfg.model.get() if model.training and (args.model == "*" or model.name in args.model) ]
if args.ys == "":
args.ys = ["loss"]

View File

@ -54,14 +54,13 @@ def init_tts(restart=False):
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--ar-ckpt", type=Path, default=None)
parser.add_argument("--nar-ckpt", type=Path, default=None)
parser.add_argument("--model-ckpt", type=Path, default=None)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default="auto")
args, unknown = parser.parse_known_args()
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
tts = TTS( config=args.yaml, model_ckpt=args.model_ckpt, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
return tts
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys())