deprecate sole AR/NAR model by only keeping the AR+NAR (the beauty of no one using this is that I can break compat as much as I want), add tone token for when I classify my dataset with tone/emotion in the future, some other things

This commit is contained in:
mrq 2024-04-15 19:54:32 -05:00
parent d69a00e389
commit 545162195b
14 changed files with 196 additions and 764 deletions

View File

@ -6,16 +6,10 @@
An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), utilizing the [EnCodec](https://github.com/facebookresearch/encodec) encoder/decoder. An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), utilizing the [EnCodec](https://github.com/facebookresearch/encodec) encoder/decoder.
[Main Repo](https://git.ecker.tech/mrq/vall-e) | [GitHub Mirror](https://github.com/e-c-k-e-r/vall-e/)
> **Note** Development on this is very sporadic. Gomen. > **Note** Development on this is very sporadic. Gomen.
## Requirements ## Requirements
* [`DeepSpeed`](https://github.com/microsoft/DeepSpeed#requirements):
- DeepSpeed training is Linux only. Installation under Windows should ignore trying to install DeepSpeed.
- If your config YAML has the training backend set to `deepspeed`, you will need to have a GPU that DeepSpeed has developed and tested against, as well as a CUDA or ROCm compiler pre-installed to install this package.
* [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/): * [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/):
- For phonemizing text, this repo requires `espeak`/`espeak-ng` installed. - For phonemizing text, this repo requires `espeak`/`espeak-ng` installed.
- Linux users can consult their package managers on installing `espeak`/`espeak-ng`. - Linux users can consult their package managers on installing `espeak`/`espeak-ng`.
@ -24,7 +18,7 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/),
## Install ## Install
Simply run `pip install git+https://git.ecker.tech/mrq/vall-e`. Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install git+https://github.com/e-c-k-e-r/vall-e`.
I've tested this repo under Python versions `3.10.9` and `3.11.3`. I've tested this repo under Python versions `3.10.9` and `3.11.3`.
@ -68,7 +62,7 @@ A script to setup a proper environment and train can be invoked with `./scripts/
If you're interested in creating an HDF5 copy of your dataset, simply invoke: `python -m vall_e.data --action='hdf5' yaml='./data/config.yaml'` If you're interested in creating an HDF5 copy of your dataset, simply invoke: `python -m vall_e.data --action='hdf5' yaml='./data/config.yaml'`
5. Train the AR and NAR models using the following scripts: `python -m vall_e.train yaml=./data/config.yaml` 5. Train the model using the following scripts: `python -m vall_e.train yaml=./data/config.yaml`
* If distributing your training (for example, multi-GPU), use `deepspeed --module vall_e.train yaml="./data/config.yaml"` * If distributing your training (for example, multi-GPU), use `deepspeed --module vall_e.train yaml="./data/config.yaml"`
You may quit your training any time by just entering `quit` in your CLI. The latest checkpoint will be automatically saved. You may quit your training any time by just entering `quit` in your CLI. The latest checkpoint will be automatically saved.
@ -93,18 +87,23 @@ You can specify what X and Y labels you want to plot against by passing `--xs to
#### Training Under Windows #### Training Under Windows
As training under `deepspeed` and Windows is not supported, under your `config.yaml`, simply change `trainer.backend` to `local` to use the local training backend. As training under `deepspeed` and Windows is not (easily) supported, under your `config.yaml`, simply change `trainer.backend` to `local` to use the local training backend.
Keep in mind that creature comforts like distributed training or `float16` training cannot be verified as working at the moment. Keep in mind that creature comforts like distributed training or `float16` training cannot be verified as working at the moment with the local trainer.
#### Training on Low-VRAM Cards #### Training on Low-VRAM Cards
During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM) with `trainer.deepspeed.compression_training` enabled with both the AR and NAR at a batch size of 16, albeit I feel this is mostly snakeoil. Better VRAM savings can be had with use of BitsAndBytes and their respective flags (specifically its AdamW implementation). During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM). Howver, VRAM use is predicated on your dataset; a mix of large and small utterances will cause VRAM usage to spike and can trigger OOM conditions during the backwards pass if you are not careful.
VRAM use is also predicated on your dataset; a mix of large and small utterances will cause VRAM usage to spike and can trigger OOM conditions during the backwards pass if you are not careful.
Additionally, under Windows, I managed to finetune the AR on my 2060 (6GiB VRAM) with a batch size of 8 (although, with the card as a secondary GPU). Additionally, under Windows, I managed to finetune the AR on my 2060 (6GiB VRAM) with a batch size of 8 (although, with the card as a secondary GPU).
#### Training Caveats
Unfortunately, efforts to train a *good* foundational model seems entirely predicated on a good dataset. My dataset might be too fouled with:
* too short utterances: trying to extrapolate longer contexts seems to utterly fall apart from just the `text` being too long.
* too tightly trimmed utterances: there being little to no space at the start and end might harm associating `<s>` and `</s>` tokens with empty utterances.
* a poorly mapped phoneme mapping: I naively crafted my own phoneme mapping, where a HuggingFace tokenizer might supply a better token mapping.
#### Backend Architectures #### Backend Architectures
As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported: As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported:
@ -112,6 +111,8 @@ As the core of VALL-E makes use of a language model, various LLM architectures c
* `transformer`: a basic attention-based transformer implementation, with attention heads + feed forwards. * `transformer`: a basic attention-based transformer implementation, with attention heads + feed forwards.
* `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead. * `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead.
- Its implementation for MoE can also be utilized. - Its implementation for MoE can also be utilized.
* `retnet-hf`: using [syncdoth/RetNet/](https://github.com/syncdoth/RetNet/) with a HuggingFace-compatible RetNet model
- inferencing cost is about 0.5x, and MoE is not implemented.
* `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements. * `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements.
* `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation. * `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation.
* `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer. * `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer.
@ -121,11 +122,11 @@ As the core of VALL-E makes use of a language model, various LLM architectures c
To export the models, run: `python -m vall_e.export yaml=./data/config.yaml`. To export the models, run: `python -m vall_e.export yaml=./data/config.yaml`.
This will export the latest checkpoints, for example, under `./data/ckpt/ar-retnet-2/fp32.pth` and `./data/ckpt/nar-retnet-2/fp32.pth`, to be loaded on any system with PyTorch, and will include additional metadata, such as the symmap used, and training stats. This will export the latest checkpoints, for example, under `./data/ckpt/ar+nar-retnet-8/fp32.pth`, to be loaded on any system with PyTorch, and will include additional metadata, such as the symmap used, and training stats.
## Synthesis ## Synthesis
To synthesize speech, invoke either (if exported the models): `python -m vall_e <text> <ref_path> <out_path> --ar-ckpt ./models/ar.pt --nar-ckpt ./models/nar.pt` or `python -m vall_e <text> <ref_path> <out_path> yaml=<yaml_path>` To synthesize speech, invoke either (if exported the models): `python -m vall_e <text> <ref_path> <out_path> --model-ckpt ./data/ckpt/ar+nar-retnet-8/fp32.pth` or `python -m vall_e <text> <ref_path> <out_path> yaml=<yaml_path>`
Some additional flags you can pass are: Some additional flags you can pass are:
* `--language`: specifies the language for phonemizing the text, and helps guide inferencing when the model is trained against that language. * `--language`: specifies the language for phonemizing the text, and helps guide inferencing when the model is trained against that language.
@ -154,7 +155,6 @@ And some experimental sampling flags you can use too (your mileage will ***defin
## To-Do ## To-Do
* train and release a ***good*** model. * train and release a ***good*** model.
- the current model seems to require a ***long*** time of training at a very small LR rate to try and cover a wide variety of speakers of varying acoustics.
* clean up the README, and document, document, document onto the wiki. * clean up the README, and document, document, document onto the wiki.
* extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ addditional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)). * extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ addditional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)).
- training additional tasks needs the SpeechX implementation to be reworked. - training additional tasks needs the SpeechX implementation to be reworked.
@ -164,7 +164,7 @@ And some experimental sampling flags you can use too (your mileage will ***defin
+ this requires a properly trained AR, however. + this requires a properly trained AR, however.
* work around issues with extending context past what's trained (despite RetNet's retention allegedly being able to defeat this): * work around issues with extending context past what's trained (despite RetNet's retention allegedly being able to defeat this):
- "sliding" AR input, such as have the context a fixed length. - "sliding" AR input, such as have the context a fixed length.
+ the model may need to be trained for this with a fancy positional embedding injected. Naively sliding the context window while making use of the RetNet implementation's positional embedding doesn't seem fruitful. + the model may need to be trained for this with a fancy positional embedding injected OR already trained with a sliding context window in mind. Naively sliding the context window while making use of the RetNet implementation's positional embedding doesn't seem fruitful.
## Notices and Citations ## Notices and Citations

View File

@ -13,8 +13,7 @@ def main():
parser.add_argument("--out-path", type=Path, default=None) parser.add_argument("--out-path", type=Path, default=None)
parser.add_argument("--yaml", type=Path, default=None) parser.add_argument("--yaml", type=Path, default=None)
parser.add_argument("--ar-ckpt", type=Path, default=None) parser.add_argument("--model-ckpt", type=Path, default=None)
parser.add_argument("--nar-ckpt", type=Path, default=None)
parser.add_argument("--max-ar-steps", type=int, default=6 * 75) parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
parser.add_argument("--max-nar-levels", type=int, default=7) parser.add_argument("--max-nar-levels", type=int, default=7)
@ -41,7 +40,7 @@ def main():
parser.add_argument("--dtype", type=str, default=None) parser.add_argument("--dtype", type=str, default=None)
args = parser.parse_args() args = parser.parse_args()
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp ) tts = TTS( config=args.yaml, model_ckpt=args.model_ckpt, device=args.device, dtype=args.dtype, amp=args.amp )
tts.inference( tts.inference(
text=args.text, text=args.text,
references=args.references, references=args.references,

View File

@ -162,6 +162,9 @@ class Dataset:
@dataclass() @dataclass()
class Model: class Model:
_max_levels: int = 0
_embeddings: str | None = None
name: str = "" # vanity name for the model name: str = "" # vanity name for the model
version: int = 1 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding version: int = 1 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding
size: str | dict = "full" # preset string or explicitly defined dimensionality size: str | dict = "full" # preset string or explicitly defined dimensionality
@ -169,6 +172,7 @@ class Model:
prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc") tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
langs: int = 1 # defined languages langs: int = 1 # defined languages
tones: int = 1 # defined tones
experts: int = 1 experts: int = 1
arch_type: str = "retnet" # or "transformer"" arch_type: str = "retnet" # or "transformer""
training: bool = True # unneeded now training: bool = True # unneeded now
@ -176,6 +180,13 @@ class Model:
p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
def get(self, name=None):
return [ self ] if not name or self.name == name else []
@property
def max_levels(self):
return self._max_levels if self._max_levels > 0 else self.prom_levels
@property @property
# required for fp8 as the lengths needs to be divisible by 8 # required for fp8 as the lengths needs to be divisible by 8
def input_alignment(self): def input_alignment(self):
@ -203,7 +214,7 @@ class Model:
if self.interleave: if self.interleave:
name.append("interleaved") name.append("interleaved")
else: else:
name.append(f'{cfg.models.prom_levels}') name.append(f'{cfg.model.prom_levels}')
return "-".join(name) return "-".join(name)
@ -256,58 +267,6 @@ class Model:
def activation_checkpointing(self): def activation_checkpointing(self):
return cfg.trainer.activation_checkpointing return cfg.trainer.activation_checkpointing
@dataclass()
class Models:
_max_levels: int = 0
_prom_levels: int = 1
_embeddings: str | None = None
_models: list[Model] = field(default_factory=lambda: [
Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False),
Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False),
])
def get(self, name=None):
if not name:
return [ Model(**model) for model in self._models ]
for model in self._models:
if model.name == name:
return model
raise ValueError
@property
def ar(self):
return self.get("ar")
@property
def ar_nar(self):
return self.get("ar+nar")
@property
def nar(self):
return self.get("nar")
@property
def prom_levels(self):
prom_levels = self._prom_levels
for model in self._models:
prom_levels = max(prom_levels, model.prom_levels)
return prom_levels
@property
def tasks(self):
tasks = 1
for model in self._models:
tasks = max(tasks, model.tasks)
return tasks
@property
def max_levels(self):
return self._max_levels if self._max_levels > 0 else self.prom_levels
@dataclass() @dataclass()
class Hyperparameters: class Hyperparameters:
batch_size: int = 8 batch_size: int = 8
@ -568,7 +527,7 @@ class Config(_Config):
experimental: bool = False # So I can stop commenting out things when committing experimental: bool = False # So I can stop commenting out things when committing
dataset: Dataset = field(default_factory=lambda: Dataset) dataset: Dataset = field(default_factory=lambda: Dataset)
models: Models = field(default_factory=lambda: Models) model: Model = field(default_factory=lambda: Model)
hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters) hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
evaluation: Evaluation = field(default_factory=lambda: Evaluation) evaluation: Evaluation = field(default_factory=lambda: Evaluation)
trainer: Trainer = field(default_factory=lambda: Trainer) trainer: Trainer = field(default_factory=lambda: Trainer)
@ -617,7 +576,7 @@ class Config(_Config):
def format( self ): def format( self ):
self.dataset = Dataset(**self.dataset) self.dataset = Dataset(**self.dataset)
self.models = Models(**self.models) self.model = Model(**self.model)
self.hyperparameters = Hyperparameters(**self.hyperparameters) self.hyperparameters = Hyperparameters(**self.hyperparameters)
self.evaluation = Evaluation(**self.evaluation) self.evaluation = Evaluation(**self.evaluation)
self.trainer = Trainer(**self.trainer) self.trainer = Trainer(**self.trainer)

View File

@ -29,22 +29,27 @@ from tqdm.auto import tqdm
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
# to-do: clean up this symmap mess
def get_phone_symmap(): def get_phone_symmap():
if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5: if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5:
return json.loads( cfg.hdf5['symmap'].asstr()[()] ) return json.loads( cfg.hdf5['symmap'].asstr()[()] )
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '': 179, '': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, '': 184, ':ˈ': 185, '1': 186, 'rˈ': 187, 'qˈ': 188, 'ᵻˌ': 189, 'ä': 190, '̞ˌ': 191, '̞': 192, 'ũˌ': 193, 'ʑˌ': 194, '': 195, 'ɽ': 196, 'ʲˌ': 197, 'ᵝˌ': 198, 'ũ': 199, 'ũˈ': 200, 'äˌ': 201, 'ɕ': 202, 'ɕˌ': 203, 'ɽˌ': 204, 'çˌ': 205, '…ˌ': 206, '̞ˈ': 207, 'äˈ': 208, 'ɽˈ': 209, 'ɸˌ': 210, 'ɴ': 211, 'ɸˈ': 212, 'ɕˈ': 213, 'ɸ': 214, 'ᵝˈ': 215, 'ʲˈ': 216, 'ĩ': 217, 'çˈ': 218, 'ĩˌ': 219, '': 220, 'eˈ': 221, 'ʍ': 222, '': 223, '': 224, 'ʍˌ': 225, 'uˈ': 226, 'oˈ': 227, 'aˈ': 228} return {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '': 179, '': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, '': 184, ':ˈ': 185, '1': 186, 'rˈ': 187, 'qˈ': 188, 'ᵻˌ': 189, 'ä': 190, '̞ˌ': 191, '̞': 192, 'ũˌ': 193, 'ʑˌ': 194, '': 195, 'ɽ': 196, 'ʲˌ': 197, 'ᵝˌ': 198, 'ũ': 199, 'ũˈ': 200, 'äˌ': 201, 'ɕ': 202, 'ɕˌ': 203, 'ɽˌ': 204, 'çˌ': 205, '…ˌ': 206, '̞ˈ': 207, 'äˈ': 208, 'ɽˈ': 209, 'ɸˌ': 210, 'ɴ': 211, 'ɸˈ': 212, 'ɕˈ': 213, 'ɸ': 214, 'ᵝˈ': 215, 'ʲˈ': 216, 'ĩ': 217, 'çˈ': 218, 'ĩˌ': 219, '': 220, 'eˈ': 221, 'ʍ': 222, '': 223, '': 224, 'ʍˌ': 225, 'uˈ': 226, 'oˈ': 227, 'aˈ': 228}
return symmap
def get_lang_symmap(): def get_lang_symmap():
symmap = { return {
"en": 0, "en": 0,
"ja": 1, "ja": 1,
} }
def get_tone_symmap():
return {
"neutral": 0,
}
return symmap return symmap
def get_task_symmap(): def get_task_symmap():
symmap = { return {
"<tts>": 0, "<tts>": 0,
"<tts-c>": 1, "<tts-c>": 1,
"<ns>": 2, "<ns>": 2,
@ -54,7 +59,6 @@ def get_task_symmap():
"<mask>": 6, "<mask>": 6,
"<eoe>": 7, "<eoe>": 7,
} }
return symmap
def _replace_file_extension(path, suffix): def _replace_file_extension(path, suffix):
return (path.parent / path.name.split(".")[0]).with_suffix(suffix) return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
@ -237,6 +241,7 @@ class Dataset(_Dataset):
self.spkr_symmap = self._get_spkr_symmap() self.spkr_symmap = self._get_spkr_symmap()
self.spkr_group_symmap = self._get_spkr_group_symmap() self.spkr_group_symmap = self._get_spkr_group_symmap()
self.lang_symmap = self._get_lang_symmap() self.lang_symmap = self._get_lang_symmap()
self.tone_symmap = self._get_tone_symmap()
self.task_symmap = self._get_task_symmap() self.task_symmap = self._get_task_symmap()
# assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8" # assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8"
@ -309,11 +314,14 @@ class Dataset(_Dataset):
def _get_lang_symmap(self): def _get_lang_symmap(self):
return get_lang_symmap() return get_lang_symmap()
def _get_tone_symmap(self):
return get_tone_symmap()
def _get_task_symmap(self): def _get_task_symmap(self):
return get_task_symmap() return get_task_symmap()
""" """
def get_task_token( self, token, levels=cfg.models.max_levels ): def get_task_token( self, token, levels=cfg.model.max_levels ):
if not hasattr(self, "task_symmap"): if not hasattr(self, "task_symmap"):
self.task_symmap = self._get_task_symmap() self.task_symmap = self._get_task_symmap()
return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16) return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16)
@ -339,7 +347,7 @@ class Dataset(_Dataset):
choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore} choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
choices = [*choices] choices = [*choices]
# no other utterances, it'd make more sense to prune speakers with only one utterance in the validatoin step # no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step
if len(choices) == 0: if len(choices) == 0:
choices = [*set(self.paths_by_spkr_name[spkr_name])] choices = [*set(self.paths_by_spkr_name[spkr_name])]
""" """
@ -622,8 +630,8 @@ class Dataset(_Dataset):
""" """
# trim to fit to requested prom/resps levels # trim to fit to requested prom/resps levels
proms = proms[:, :cfg.models.prom_levels] proms = proms[:, :cfg.model.prom_levels]
resps = resps[:, :cfg.models.prom_levels] resps = resps[:, :cfg.model.prom_levels]
return dict( return dict(
@ -928,7 +936,7 @@ if __name__ == "__main__":
if task not in cfg.dataset.tasks_list: if task not in cfg.dataset.tasks_list:
continue continue
print(text, task, cfg.models.prom_levels) print(text, task, cfg.model.prom_levels)
print( proms.shape, resps.shape ) print( proms.shape, resps.shape )
tokens = 0 tokens = 0

View File

@ -21,7 +21,7 @@ except Exception as e:
cfg.inference.use_vocos = False cfg.inference.use_vocos = False
@cache @cache
def _load_encodec_model(device="cuda", levels=cfg.models.max_levels): def _load_encodec_model(device="cuda", levels=cfg.model.max_levels):
# Instantiate a pretrained EnCodec model # Instantiate a pretrained EnCodec model
assert cfg.sample_rate == 24_000 assert cfg.sample_rate == 24_000
@ -44,7 +44,7 @@ def _load_encodec_model(device="cuda", levels=cfg.models.max_levels):
return model return model
@cache @cache
def _load_vocos_model(device="cuda", levels=cfg.models.max_levels): def _load_vocos_model(device="cuda", levels=cfg.model.max_levels):
assert cfg.sample_rate == 24_000 assert cfg.sample_rate == 24_000
model = Vocos.from_pretrained("charactr/vocos-encodec-24khz") model = Vocos.from_pretrained("charactr/vocos-encodec-24khz")
@ -66,7 +66,7 @@ def _load_vocos_model(device="cuda", levels=cfg.models.max_levels):
return model return model
@cache @cache
def _load_model(device="cuda", vocos=cfg.inference.use_vocos, levels=cfg.models.max_levels): def _load_model(device="cuda", vocos=cfg.inference.use_vocos, levels=cfg.model.max_levels):
if vocos: if vocos:
model = _load_vocos_model(device, levels=levels) model = _load_vocos_model(device, levels=levels)
else: else:
@ -80,7 +80,7 @@ def unload_model():
@torch.inference_mode() @torch.inference_mode()
def decode(codes: Tensor, device="cuda", levels=cfg.models.max_levels): def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels):
""" """
Args: Args:
codes: (b q t) codes: (b q t)
@ -117,7 +117,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.models.max_levels):
return wav, model.sample_rate return wav, model.sample_rate
# huh # huh
def decode_to_wave(resps: Tensor, device="cuda", levels=cfg.models.max_levels): def decode_to_wave(resps: Tensor, device="cuda", levels=cfg.model.max_levels):
return decode(resps, device=device, levels=levels) return decode(resps, device=device, levels=levels)
def decode_to_file(resps: Tensor, path: Path, device="cuda"): def decode_to_file(resps: Tensor, path: Path, device="cuda"):
@ -131,7 +131,7 @@ def _replace_file_extension(path, suffix):
@torch.inference_mode() @torch.inference_mode()
def encode(wav: Tensor, sr: int = 24_000, device="cuda", levels=cfg.models.max_levels): def encode(wav: Tensor, sr: int = 24_000, device="cuda", levels=cfg.model.max_levels):
""" """
Args: Args:
wav: (t) wav: (t)
@ -224,7 +224,7 @@ def repeat_extend_audio( qnt, target ):
# merges two quantized audios together # merges two quantized audios together
# I don't know if this works # I don't know if this works
def merge_audio( *args, device="cpu", scale=[], levels=cfg.models.max_levels ): def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ):
qnts = [*args] qnts = [*args]
decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ] decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ]

View File

@ -26,7 +26,7 @@ from functools import cache
@cache @cache
def load_engines(training=True): def load_engines(training=True):
models = get_models(cfg.models.get(), training=training) models = get_models(cfg.model.get(), training=training)
engines = dict() engines = dict()
for name, model in models.items(): for name, model in models.items():
@ -145,8 +145,8 @@ def load_engines(training=True):
engine.freeze(freeze_all=False) engine.freeze(freeze_all=False)
# copy embeddings if requested # copy embeddings if requested
if cfg.models._embeddings is not None: if cfg.model._embeddings is not None:
embeddings_path = cfg.relpath / cfg.models._embeddings embeddings_path = cfg.relpath / cfg.model._embeddings
if embeddings_path.exists(): if embeddings_path.exists():
embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device)) embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))

View File

@ -19,7 +19,7 @@ if deepspeed_available:
import deepspeed import deepspeed
class TTS(): class TTS():
def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device=None, amp=None, dtype=None ): def __init__( self, config=None, model_ckpt=None, device=None, amp=None, dtype=None ):
self.loading = True self.loading = True
self.input_sample_rate = 24000 self.input_sample_rate = 24000
@ -53,7 +53,10 @@ class TTS():
self.symmap = None self.symmap = None
def parse( name, model, state ): if model_ckpt:
state = torch.load(model_ckpt)
self.model = get_models(cfg.model.get(), training=False)[0]
if "userdata" in state and 'symmap' in state['userdata']: if "userdata" in state and 'symmap' in state['userdata']:
self.symmap = state['userdata']['symmap'] self.symmap = state['userdata']['symmap']
elif "symmap" in state: elif "symmap" in state:
@ -62,55 +65,26 @@ class TTS():
if "module" in state: if "module" in state:
state = state['module'] state = state['module']
model.load_state_dict(state) self.model.load_state_dict(state)
if cfg.inference.backend == "local" and deepspeed_available and cfg.trainer.deepspeed.inferencing: if cfg.inference.backend == "local" and deepspeed_available and cfg.trainer.deepspeed.inferencing:
model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module self.model = deepspeed.init_inference(model=self.model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module
return model
if ar_ckpt and nar_ckpt:
self.ar_ckpt = ar_ckpt
self.nar_ckpt = nar_ckpt
models = get_models(cfg.models.get(), training=False)
for name, model in models.items():
if name.startswith("ar"):
state = torch.load(self.ar_ckpt)
self.ar = parse( name, model, state )
elif name.startswith("nar"):
state = torch.load(self.nar_ckpt)
self.nar = parse( name, model, state )
if name.startswith("ar+nar"):
self.nar = self.ar
else: else:
self.load_models() engines = load_engines(training=False)
for name, engine in engines.items():
self.model = engine.module
break
if self.dtype != torch.int8: if self.dtype != torch.int8:
self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32) self.model = self.model.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.ar.eval() self.model.eval()
self.nar.eval()
if self.symmap is None: if self.symmap is None:
self.symmap = get_phone_symmap() self.symmap = get_phone_symmap()
self.loading = False self.loading = False
def load_models( self ):
engines = load_engines(training=False)
for name, engine in engines.items():
if name.startswith("ar"):
self.ar = engine.module
elif name.startswith("nar"):
self.nar = engine.module
if name.startswith("ar+nar"):
self.nar = self.ar
def encode_text( self, text, language="en" ): def encode_text( self, text, language="en" ):
# already a tensor, return it # already a tensor, return it
if isinstance( text, Tensor ): if isinstance( text, Tensor ):
@ -193,7 +167,7 @@ class TTS():
lang = to_device(lang, self.device).to(torch.uint8) lang = to_device(lang, self.device).to(torch.uint8)
with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp): with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp):
resps_list = self.ar( resps_list = self.model(
text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context, text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context,
sampling_temperature=ar_temp, sampling_temperature=ar_temp,
sampling_min_temperature=min_ar_temp, sampling_min_temperature=min_ar_temp,
@ -205,7 +179,7 @@ class TTS():
sampling_mirostat_eta=mirostat_eta, sampling_mirostat_eta=mirostat_eta,
) )
resps_list = [r.unsqueeze(-1) for r in resps_list] resps_list = [r.unsqueeze(-1) for r in resps_list]
resps_list = self.nar( resps_list = self.model(
text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list,
max_levels=max_nar_levels, max_levels=max_nar_levels,
sampling_temperature=nar_temp, sampling_temperature=nar_temp,

View File

@ -1,19 +1,9 @@
from .ar import AR
from .nar import NAR
from .ar_nar import AR_NAR from .ar_nar import AR_NAR
def get_model(cfg, training=True): def get_model(cfg, training=True):
if cfg.name == "ar":
Model = AR
elif cfg.name == "nar":
Model = NAR
elif cfg.name == "ar+nar":
Model = AR_NAR
else:
raise f"invalid model name: {cfg.name}"
name = cfg.name name = cfg.name
model = Model( model = AR_NAR(
n_tokens=cfg.tokens, n_tokens=cfg.tokens,
d_model=cfg.dim, d_model=cfg.dim,
n_heads=cfg.heads, n_heads=cfg.heads,

View File

@ -1,309 +0,0 @@
from ..config import cfg
from .base import Base, list_to_tensor, Categorical
import torch
from torch.nn.utils.rnn import pad_sequence
from einops import rearrange
from torch import Tensor
from tqdm import trange
class AR(Base):
@property
def causal(self):
return True
@property
def norm_type(self):
return "ln"
@property
def arch_type(self) -> str:
if hasattr(self, "config") and self.config:
return self.config.arch_type
return cfg.models.ar.arch_type
@property
def n_prom_levels(self) -> int:
return cfg.models.prom_levels
@property
def n_resp_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.resp_levels
return cfg.models.ar.resp_levels
@property
def n_max_levels(self) -> int:
return cfg.models.max_levels
@property
def n_tasks(self) -> int:
return cfg.models.ar.tasks
@property
def n_langs(self) -> int:
return cfg.models.ar.langs
@property
def recurrent_chunk_size(self) -> int:
if cfg.mode == "training":
return 0
return cfg.inference.recurrent_chunk_size
"""
@property
def rotary_embedding_base(self) -> float:
if hasattr(self, "config") and self.config:
return self.config.rotary_embedding_base
return cfg.models.ar.rotary_embedding_base
"""
@property
def interleave(self) -> bool:
if hasattr(self, "config") and self.config:
return self.config.interleave
return False
@property
def monolithic(self) -> bool:
return False
@property
def version(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.version
return cfg.models.ar.version
def _prune(self, l: Tensor):
indices = (l == self.stop_token).nonzero()
if len(indices) == 0:
return l
return l[: indices.min().item()]
def _interleave( self, codes ):
if not self.interleave:
return codes
return codes.flatten()
def _deinterleave( self, codes, length = 0 ):
if not self.interleave:
return codes
return torch.unflatten( codes[:codes.shape[0] // self.n_prom_levels * self.n_prom_levels], 0, ( codes.shape[0] // self.n_prom_levels, self.n_prom_levels ) )
@staticmethod
def _unsqueeze_list(x_list, axis=-1):
return [x.unsqueeze(dim=axis) for x in x_list]
def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
max_steps: int = 1000,
max_resp_context: int = -1,
sampling_temperature: float = 1.0,
sampling_min_temperature: float = -1.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_repetition_penalty: float = 1.0,
sampling_repetition_penalty_decay: float = 0.0,
sampling_length_penalty: float = 0.0,
sampling_beam_width: int = 0,
sampling_mirostat_tau: float = 0.0,
sampling_mirostat_eta: float = 0.1,
):
if resps_list is not None:
if self.interleave:
resps_list = [self._interleave(r) for r in resps_list]
else:
resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels
return super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=self._unsqueeze_list(resps_list),
targ_list=resps_list,
lang_list=lang_list,
quant_levels=None,
)
device = text_list[0].device
batch_size = len(text_list)
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ]
stopped = torch.zeros(batch_size, device=device).bool()
recurrent_state = {} if cfg.inference.recurrent_forward else None
mirostat = [
{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
] * batch_size if sampling_mirostat_tau > 0.0 else None
sampling_beam_width_use_logs = True
scores = [ 1.0 ] * sampling_beam_width
if self.interleave:
max_steps *= self.n_prom_levels
# get next in sequence
for n in trange(max_steps // max(1, self.recurrent_chunk_size)):
if max_resp_context > 0:
resps_list = self._unsqueeze_list([ sequence[-max_resp_context:] for sequence in sequence_list ] )
else:
resps_list = self._unsqueeze_list(sequence_list)
logits = super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=resps_list,
state=recurrent_state
)
r = super().sample(
logits=logits,
resps_list=resps_list,
temperature=sampling_temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
repetition_penalty=sampling_repetition_penalty,
repetition_penalty_decay=sampling_repetition_penalty_decay,
length_penalty=sampling_length_penalty,
beam_width=sampling_beam_width,
mirostat=mirostat,
)
if mirostat is not None:
# r is the state
mirostat = r
# extract token from state
r = [ state["token"] for state in mirostat ]
# we do it here because the sampler will already expand our logits list
elif sampling_beam_width > 0:
# expand tuple
r, s = r
# first step, expand batch
if batch_size == 1:
batch_size *= sampling_beam_width
text_list = text_list * sampling_beam_width
proms_list = proms_list * sampling_beam_width
sequence_list = sequence_list * sampling_beam_width
stopped = torch.zeros(batch_size, device=device).bool()
# update scores
if sampling_beam_width_use_logs:
scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ]
else:
scores = [ scores[i] * score for i, score in enumerate(s) ]
# append tokens
for i, ri in enumerate(r):
if self.stop_token in ri:
stopped[i] = True
sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)])
# stop token found
stopped |= r == self.stop_token
if stopped.all().item():
break
# pick the best scoring candidate
# desu this is always going to be candidate 0
if sampling_beam_width and len(scores) > 0:
best_idx, best_score = (0, 0)
for idx, score in enumerate(scores):
if best_score > score:
best_idx, best_score = idx, score
sequence_list = [sequence_list[best_idx]]
if self.interleave:
sequence_list = [self._deinterleave(r) for r in sequence_list]
return [self._prune(r) for r in sequence_list]
def example_usage():
cfg.trainer.backend = "local"
from functools import partial
from einops import repeat
from ..emb.qnt import decode_to_file
from ..engines import Engine
from tqdm import tqdm
from ..utils import wrapper as ml
device = "cuda"
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"):
split = content.split(" ")
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
return torch.tensor([*map(symmap.get, phones)]).to()
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device)
text_list = [
#torch.tensor([1, 2, 3], device=device),
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
]
proms_list = [
#x8(torch.tensor([1, 2, 3], device=device)),
qnt.to(device),
]
resps_list = [
qnt.to(device),
]
text_list = text_list[:1]
proms_list = proms_list[:1]
resps_list = resps_list[:1]
kwargs = {
'n_tokens': 1024,
'd_model': 1024,
'n_heads': 16,
'n_layers': 24,
}
"""
try:
kwargs['config'] = cfg.models.ar
except Exception as e:
pass
"""
model = AR(**kwargs).to(device)
steps = 500
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
engine = Engine(model=model, optimizer=optimizer)
def sample( name, steps=600 ):
engine.eval()
out = engine(text_list, proms_list, max_steps=steps)
for i, o in enumerate(out):
wav, sr = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device)
def train():
engine.train()
t = trange(steps)
for i in t:
stats = {"step": i}
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
tqdm.write(f"{stats}")
sample("init", 75)
train()
sample("final")
if __name__ == "__main__":
example_usage()

View File

@ -25,29 +25,33 @@ class AR_NAR(Base):
def arch_type(self) -> str: def arch_type(self) -> str:
if hasattr(self, "config") and self.config: if hasattr(self, "config") and self.config:
return self.config.arch_type return self.config.arch_type
return cfg.models.ar_nar.arch_type return cfg.model.arch_type
@property @property
def n_prom_levels(self) -> int: def n_prom_levels(self) -> int:
return cfg.models.prom_levels return cfg.model.prom_levels
@property @property
def n_resp_levels(self) -> int: def n_resp_levels(self) -> int:
if hasattr(self, "config") and self.config: if hasattr(self, "config") and self.config:
return self.config.resp_levels return self.config.resp_levels
return cfg.models.ar_nar.resp_levels return cfg.model.resp_levels
@property @property
def n_max_levels(self) -> int: def n_max_levels(self) -> int:
return cfg.models.max_levels return cfg.model.max_levels
@property @property
def n_tasks(self) -> int: def n_tasks(self) -> int:
return cfg.models.ar_nar.tasks return cfg.model.tasks
@property @property
def n_langs(self) -> int: def n_langs(self) -> int:
return cfg.models.ar_nar.langs return cfg.model.langs
@property
def n_tones(self) -> int:
return cfg.model.tones
@property @property
def recurrent_chunk_size(self) -> int: def recurrent_chunk_size(self) -> int:
@ -58,7 +62,7 @@ class AR_NAR(Base):
def rotary_embedding_base(self) -> float: def rotary_embedding_base(self) -> float:
if hasattr(self, "config") and self.config: if hasattr(self, "config") and self.config:
return self.config.rotary_embedding_base return self.config.rotary_embedding_base
return cfg.models.ar_nar.rotary_embedding_base return cfg.model.rotary_embedding_base
""" """
@property @property
@ -73,7 +77,7 @@ class AR_NAR(Base):
def version(self) -> int: def version(self) -> int:
if hasattr(self, "config") and self.config: if hasattr(self, "config") and self.config:
return self.config.version return self.config.version
return cfg.models.ar_nar.version return cfg.model.version
def _prune(self, l: Tensor): def _prune(self, l: Tensor):
indices = (l == self.stop_token).nonzero() indices = (l == self.stop_token).nonzero()
@ -92,6 +96,7 @@ class AR_NAR(Base):
resps_list: list[Tensor] | None = None, resps_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None, lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
max_steps: int = 1000, max_steps: int = 1000,
max_levels: int = 0, max_levels: int = 0,
@ -134,10 +139,10 @@ class AR_NAR(Base):
else: else:
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
""" """
if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None: if cfg.model.p_ar_level == "auto" or cfg.model.p_ar_level is None:
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
else: else:
quant_levels = torch.Tensor([ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ]) quant_levels = torch.Tensor([ 0 if random.random() < cfg.model.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ])
""" """
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target) targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
@ -162,6 +167,7 @@ class AR_NAR(Base):
resps_list=resps_list, resps_list=resps_list,
targ_list=targ_list, targ_list=targ_list,
lang_list=lang_list, lang_list=lang_list,
tone_list=tone_list,
quant_levels=quant_levels, quant_levels=quant_levels,
) )
# is NAR # is NAR
@ -182,6 +188,7 @@ class AR_NAR(Base):
proms_list=proms_list, proms_list=proms_list,
resps_list=prev_list, resps_list=prev_list,
lang_list=lang_list, lang_list=lang_list,
tone_list=tone_list,
quant_levels=quant_levels, quant_levels=quant_levels,
) )
@ -234,6 +241,7 @@ class AR_NAR(Base):
proms_list=proms_list, proms_list=proms_list,
resps_list=resps_list, resps_list=resps_list,
lang_list=lang_list, lang_list=lang_list,
tone_list=tone_list,
state=recurrent_state state=recurrent_state
) )
else: else:
@ -242,6 +250,7 @@ class AR_NAR(Base):
proms_list=proms_list, proms_list=proms_list,
resps_list=resps_list, resps_list=resps_list,
lang_list=lang_list, lang_list=lang_list,
tone_list=tone_list,
state=recurrent_state state=recurrent_state
) )
@ -312,14 +321,14 @@ def example_usage():
import re import re
device = "cuda" device = "cuda"
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels) x8 = partial(repeat, pattern="t -> t l", l=cfg.model.prom_levels)
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178} symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"): def tokenize(content, lang_marker="en"):
split = content.split(" ") split = content.split(" ")
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"] phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
return torch.tensor([*map(symmap.get, phones)]) return torch.tensor([*map(symmap.get, phones)])
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device) qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.model.prom_levels].to(device)
cfg.hyperparameters.gradient_accumulation_steps = 1 cfg.hyperparameters.gradient_accumulation_steps = 1
@ -359,7 +368,7 @@ def example_usage():
""" """
try: try:
kwargs['config'] = cfg.models.ar_nar kwargs['config'] = cfg.model
except Exception as e: except Exception as e:
pass pass
""" """
@ -374,8 +383,8 @@ def example_usage():
# copy embeddings if requested # copy embeddings if requested
""" """
if cfg.models._embeddings is not None: if cfg.model._embeddings is not None:
embeddings_path = cfg.relpath / cfg.models._embeddings embeddings_path = cfg.relpath / cfg.model._embeddings
if embeddings_path.exists(): if embeddings_path.exists():
embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device)) embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))

View File

@ -267,6 +267,10 @@ class Base(nn.Module):
def n_tasks(self) -> int: def n_tasks(self) -> int:
raise NotImplementedError raise NotImplementedError
@property
def n_tones(self) -> int:
raise NotImplementedError
@property @property
def recurrent_chunk_size(self) -> int: def recurrent_chunk_size(self) -> int:
raise NotImplementedError raise NotImplementedError
@ -343,6 +347,7 @@ class Base(nn.Module):
self.text_emb = Embedding(n_tokens, d_model) self.text_emb = Embedding(n_tokens, d_model)
self.langs_emb = None self.langs_emb = None
self.tones_emb = None
self.tasks_emb = None self.tasks_emb = None
if self.version == 1: # legacy if self.version == 1: # legacy
@ -360,6 +365,9 @@ class Base(nn.Module):
self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None
self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None
if self.version >= 4:
self.tones_emb = Embedding(self.n_tones, d_model) if self.n_tones > 0 else None
self.sep = nn.Parameter(torch.randn(d_model)) self.sep = nn.Parameter(torch.randn(d_model))
if self.arch_type == "transformer": if self.arch_type == "transformer":
@ -522,54 +530,16 @@ class Base(nn.Module):
ignore_index=self.ignore_index, ignore_index=self.ignore_index,
) )
def forward( def _forward(
self, self,
text_list: list[Tensor], inputs,
proms_list: list[Tensor], mask = None,
resps_list: list[Tensor], state = None,
targ_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
quant_levels: Tensor | None = None,
state: dict | list | None = None,
): ):
batch_size = len(text_list) x = inputs
m = mask.squeeze(-1).int()
if self.langs_emb is None:
lang_list = None
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.langs_emb(lang_list) if lang_list is not None else None,
self.proms_emb(proms_list),
self.resps_emb(resps_list, quant_levels),
sep=self.sep,
)
x, m = list_to_tensor(x_list)
aux_loss = None aux_loss = None
device = x.device
# pad our input and mask, but retain the original length by doing it after
if self.l_padding and x.shape[1] % self.l_padding != 0:
# pad input
shape = list(x.shape)
shape[1] = self.l_padding - shape[1] % self.l_padding
padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
x = torch.cat([x, padding], dim=1)
# pad mask
shape[2] = 1
padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
m = torch.cat([m, padding], dim=1)
# for simplicity
mask = m.squeeze(-1).int()
""" """
# Broken # Broken
if state is not None and (self.arch_type == "retnet" or self.arch_type == "retnet-hf"): if state is not None and (self.arch_type == "retnet" or self.arch_type == "retnet-hf"):
@ -587,7 +557,7 @@ class Base(nn.Module):
xi = x[:, n, :].unsqueeze(1) xi = x[:, n, :].unsqueeze(1)
kwargs = dict( kwargs = dict(
attention_mask=mask, attention_mask=m,
inputs_embeds=xi, inputs_embeds=xi,
past_key_values=state, past_key_values=state,
use_cache=True, use_cache=True,
@ -603,9 +573,9 @@ class Base(nn.Module):
""" """
# HF transformer derived model # HF transformer derived model
if self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral": if self.arch_type in ["llama", "mistral", "mixtral"]:
kwargs = dict( kwargs = dict(
attention_mask=mask, attention_mask=m,
inputs_embeds=x, inputs_embeds=x,
past_key_values=state, past_key_values=state,
use_cache=True, use_cache=True,
@ -632,7 +602,7 @@ class Base(nn.Module):
x = self.sin_emb.add_pe(x) x = self.sin_emb.add_pe(x)
# pass our inputs through the transformer # pass our inputs through the transformer
for block in self.blocks: for block in self.blocks:
x = block(x, mask, l) x = block(x, m, l)
elif self.arch_type == "retnet": elif self.arch_type == "retnet":
# pass our inputs through the RetNet # pass our inputs through the RetNet
x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True) x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True)
@ -642,7 +612,7 @@ class Base(nn.Module):
first = state is None or len(state) == 0 first = state is None or len(state) == 0
kwargs = dict( kwargs = dict(
attention_mask=mask, attention_mask=m,
inputs_embeds=x if first else x[:, -1, :].unsqueeze(1), inputs_embeds=x if first else x[:, -1, :].unsqueeze(1),
past_key_values=None if first else state, past_key_values=None if first else state,
use_cache=True, use_cache=True,
@ -659,8 +629,76 @@ class Base(nn.Module):
x = self.model(x) x = self.model(x)
# output projection layer with masking # output projection layer with masking
x = self.classifier(x) * mask
x = self.classifier(x) * m return x, state, aux_loss
def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor],
targ_list: list[Tensor] | None = None,
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
quant_levels: Tensor | None = None,
state: dict | list | None = None,
):
device = text_list[0].device
batch_size = len(text_list)
# silently ignore languages if model does not have it
if self.langs_emb is None:
lang_list = None
# inject default language
elif lang_list is None:
lang_list = [ torch.Tensor([ 0 ]).to(dtype=torch.uint8, device=device) for _ in range(batch_size) ]
# silently ignore tones if model does not have it
if self.tones_emb is None:
tone_list = None
# inject default tone
elif tone_list is None:
tone_list = [ torch.Tensor([ 0 ]).to(dtype=torch.uint8, device=device) for _ in range(batch_size) ]
"""
# Typical sequence format
# To-do: integrate tasks again
<s><text></s><sep><lang><sep><prom><sep><tone><sep><resp><stop>
"""
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
self.langs_emb(lang_list) if lang_list is not None else None,
self.proms_emb(proms_list),
self.tones_emb(tone_list) if tone_list is not None else None,
self.resps_emb(resps_list, quant_levels),
sep=self.sep,
)
x, m = list_to_tensor(x_list)
# pad our input and mask, but retain the original length by doing it after
if self.l_padding and x.shape[1] % self.l_padding != 0:
# pad input
shape = list(x.shape)
shape[1] = self.l_padding - shape[1] % self.l_padding
padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
x = torch.cat([x, padding], dim=1)
# pad mask
shape[2] = 1
padding = torch.zeros(shape, dtype=x.dtype, device=x.device)
m = torch.cat([m, padding], dim=1)
x, state, aux_loss = self._forward(
inputs=x,
mask=m,
state=state,
)
# Remove padding # Remove padding
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ] logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
@ -790,7 +828,7 @@ def example_usage():
from .nar import NAR from .nar import NAR
device = "cuda" device = "cuda"
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels) x8 = partial(repeat, pattern="t -> t l", l=cfg.model.prom_levels)
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178} symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"): def tokenize(content, lang_marker="en"):
split = content.split(" ") split = content.split(" ")
@ -812,7 +850,7 @@ def example_usage():
train = True train = True
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device) qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.model.prom_levels].to(device)
text_list = [ text_list = [
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device), tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
#tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device), #tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device),

View File

@ -1,235 +0,0 @@
from ..config import cfg
from .base import Base
import torch
from torch import Tensor
from tqdm import trange
class NAR(Base):
@property
def causal(self):
return False
@property
def arch_type(self) -> str:
if hasattr(self, "config") and self.config:
return self.config.arch_type
return cfg.models.nar.arch_type
@property
def norm_type(self):
return "ln" if self.n_resp_levels == 1 else "adaln"
@property
def n_prom_levels(self) -> int:
return cfg.models.prom_levels
@property
def n_resp_levels(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.resp_levels
return cfg.models.nar.resp_levels
@property
def n_max_levels(self) -> int:
return cfg.models.max_levels
@property
def n_tasks(self) -> int:
return cfg.models.nar.tasks
@property
def n_langs(self) -> int:
return cfg.models.nar.langs
@property
def version(self) -> int:
if hasattr(self, "config") and self.config:
return self.config.version
return cfg.models.nar.version
@property
def recurrent_chunk_size(self) -> int:
return 0
"""
@property
def rotary_embedding_base(self) -> float:
if hasattr(self, "config") and self.config:
return self.config.rotary_embedding_base
return cfg.models.nar.rotary_embedding_base
"""
@property
def interleave(self) -> bool:
return False
@property
def monolithic(self) -> bool:
return False
def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor],
lang_list: list[Tensor] | None = None,
max_levels: int = 0,
sampling_temperature: float = 0.2,
sampling_min_temperature: float = -1.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_repetition_penalty: float = 1.0,
sampling_repetition_penalty_decay: float = 0.0,
sampling_length_penalty: float = 0.0, # unused
sampling_beam_width: int = 0, # unused
sampling_mirostat_tau: float = 0.0, # unused
):
"""
Args:
text_list: [t] * b
proms_list: [t' l] * b, l=8
resps_list: [t'' l] * b, l=1 or 8, 1 for testing and 8 for training.
Returns:
[t'' l], l=8 if testing. empty list will be returned during training.
"""
n_levels_set = {r.shape[-1] for r in resps_list}
if len(n_levels_set) > 1:
raise ValueError(f"Please give only one level, got {n_levels_set}.")
n_levels = next(iter(n_levels_set))
device = text_list[0].device
if n_levels == self.n_resp_levels + 1:
assert resps_list is not None
quant_levels = torch.randint(0, self.n_resp_levels, (len(resps_list),))
prev_list = [o[..., : l + 1] for o, l in zip(resps_list, quant_levels)]
targ_list = [o[..., l + 1] for o, l in zip(resps_list, quant_levels)]
#quant_levels = quant_levels.to(device=device)
logits = super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=prev_list,
targ_list=targ_list,
lang_list=lang_list,
quant_levels=quant_levels,
)
prev_list = []
else:
prev_list = resps_list
if max_levels == 0:
max_levels = self.n_resp_levels
while True:
level = prev_list[0].shape[-1] - 1
if level >= max_levels: # min(max_levels, self.n_resp_levels): # commented out to experiment with exceeding trained levels
break
quant_levels = torch.full((len(text_list),), level, device=device)
logits = super().forward(
text_list=text_list,
proms_list=proms_list,
resps_list=prev_list,
lang_list=lang_list,
quant_levels=quant_levels,
)
resps_list = super().sample(
logits=logits,
resps_list=prev_list,
quant_levels=quant_levels,
temperature=sampling_temperature,
min_temperature=sampling_min_temperature,
top_p=sampling_top_p,
top_k=sampling_top_k,
repetition_penalty=sampling_repetition_penalty,
repetition_penalty_decay=sampling_repetition_penalty_decay,
#length_penalty=sampling_length_penalty,
#beam_width=sampling_beam_width,
#mirostat_tau=sampling_mirostat_tau,
#mirostat_state=mirostat_state,
)
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
return prev_list
def example_usage():
cfg.trainer.backend = "local"
from functools import partial
from einops import repeat
from ..emb.qnt import decode_to_file
from ..engines import Engine
from tqdm import tqdm
from ..utils import wrapper as ml
device = "cuda"
x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels)
symmap = {'<s>': 1, '</s>': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, '': 11, '': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, '': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, '': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, '': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, '': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, '': 78, '': 79, 'vˈ': 80, '': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, '': 85, 'pˈ': 86, 'ðˌ': 87, '': 88, '': 89, '': 90, '̩': 91, 'ʔ': 92, '': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, '': 100, 'uːˈ': 101, 'iːˈ': 102, '': 103, '.ˈ': 104, '': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, '': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '': 126, 'ɫ': 127, 'q': 128, '': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '': 149, '': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, '': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178}
def tokenize(content, lang_marker="en"):
split = content.split(" ")
phones = [f"<s>"] + [ " " if not p else p for p in split ] + [f"</s>"]
return torch.tensor([*map(symmap.get, phones)]).to()
# to-do: unmangle this and the resp shit
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device)
text_list = [
#torch.tensor([1, 2, 3], device=device),
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
]
proms_list = [
x8(torch.tensor([2, 3], device=device)),
]
resps_list = [
qnt.to(device),
]
kwargs = {
'n_tokens': 1024,
'd_model': 1024,
'n_heads': 16,
'n_layers': 12,
}
model = NAR(**kwargs).to(device)
steps = 500
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
engine = Engine(model=model, optimizer=optimizer)
def sample( name ):
engine.eval()
codes = engine( text_list, proms_list, resps_list=[r[..., 0].unsqueeze(-1) for r in resps_list], sampling_temperature=0.2 )
decode_to_file( codes[0], f"data/nar.{name}.wav", device )
def train():
engine.train()
t = trange(steps)
for i in t:
stats = {"step": i}
stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list)
tqdm.write(f"{stats}")
sample("init")
train()
sample("final")
if __name__ == "__main__":
example_usage()

View File

@ -109,7 +109,7 @@ if __name__ == "__main__":
path = cfg.relpath / "logs" path = cfg.relpath / "logs"
paths = path.rglob(f"./*/{args.filename}") paths = path.rglob(f"./*/{args.filename}")
args.models = [ model for model in cfg.models.get() if model.training and (args.model == "*" or model.name in args.model) ] args.models = [ model for model in cfg.model.get() if model.training and (args.model == "*" or model.name in args.model) ]
if args.ys == "": if args.ys == "":
args.ys = ["loss"] args.ys = ["loss"]

View File

@ -54,14 +54,13 @@ def init_tts(restart=False):
parser = argparse.ArgumentParser(allow_abbrev=False) parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too
parser.add_argument("--ar-ckpt", type=Path, default=None) parser.add_argument("--model-ckpt", type=Path, default=None)
parser.add_argument("--nar-ckpt", type=Path, default=None)
parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--amp", action="store_true") parser.add_argument("--amp", action="store_true")
parser.add_argument("--dtype", type=str, default="auto") parser.add_argument("--dtype", type=str, default="auto")
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp ) tts = TTS( config=args.yaml, model_ckpt=args.model_ckpt, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp )
return tts return tts
@gradio_wrapper(inputs=layout["inference"]["inputs"].keys()) @gradio_wrapper(inputs=layout["inference"]["inputs"].keys())