diff --git a/README.md b/README.md index 75106d1..ca7942d 100755 --- a/README.md +++ b/README.md @@ -6,16 +6,10 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), utilizing the [EnCodec](https://github.com/facebookresearch/encodec) encoder/decoder. -[Main Repo](https://git.ecker.tech/mrq/vall-e) | [GitHub Mirror](https://github.com/e-c-k-e-r/vall-e/) - > **Note** Development on this is very sporadic. Gomen. ## Requirements -* [`DeepSpeed`](https://github.com/microsoft/DeepSpeed#requirements): - - DeepSpeed training is Linux only. Installation under Windows should ignore trying to install DeepSpeed. - - If your config YAML has the training backend set to `deepspeed`, you will need to have a GPU that DeepSpeed has developed and tested against, as well as a CUDA or ROCm compiler pre-installed to install this package. - * [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/): - For phonemizing text, this repo requires `espeak`/`espeak-ng` installed. - Linux users can consult their package managers on installing `espeak`/`espeak-ng`. @@ -24,7 +18,7 @@ An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), ## Install -Simply run `pip install git+https://git.ecker.tech/mrq/vall-e`. +Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install git+https://github.com/e-c-k-e-r/vall-e`. I've tested this repo under Python versions `3.10.9` and `3.11.3`. @@ -68,7 +62,7 @@ A script to setup a proper environment and train can be invoked with `./scripts/ If you're interested in creating an HDF5 copy of your dataset, simply invoke: `python -m vall_e.data --action='hdf5' yaml='./data/config.yaml'` -5. Train the AR and NAR models using the following scripts: `python -m vall_e.train yaml=./data/config.yaml` +5. Train the model using the following scripts: `python -m vall_e.train yaml=./data/config.yaml` * If distributing your training (for example, multi-GPU), use `deepspeed --module vall_e.train yaml="./data/config.yaml"` You may quit your training any time by just entering `quit` in your CLI. The latest checkpoint will be automatically saved. @@ -93,18 +87,23 @@ You can specify what X and Y labels you want to plot against by passing `--xs to #### Training Under Windows -As training under `deepspeed` and Windows is not supported, under your `config.yaml`, simply change `trainer.backend` to `local` to use the local training backend. +As training under `deepspeed` and Windows is not (easily) supported, under your `config.yaml`, simply change `trainer.backend` to `local` to use the local training backend. -Keep in mind that creature comforts like distributed training or `float16` training cannot be verified as working at the moment. +Keep in mind that creature comforts like distributed training or `float16` training cannot be verified as working at the moment with the local trainer. #### Training on Low-VRAM Cards -During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM) with `trainer.deepspeed.compression_training` enabled with both the AR and NAR at a batch size of 16, albeit I feel this is mostly snakeoil. Better VRAM savings can be had with use of BitsAndBytes and their respective flags (specifically its AdamW implementation). - -VRAM use is also predicated on your dataset; a mix of large and small utterances will cause VRAM usage to spike and can trigger OOM conditions during the backwards pass if you are not careful. +During experimentation, I've found I can comfortably train on a 4070Ti (12GiB VRAM). 
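+
+As a rough reference, a minimal `config.yaml` for local training might look like the sketch below; the keys mirror the `Config` dataclasses in `vall_e/config.py`, while the values themselves are illustrative rather than prescriptive:
+
+```yaml
+model:
+  name: ar+nar        # vanity name; the exported checkpoint directory (e.g. ar+nar-retnet-8) is derived from this
+  arch_type: retnet   # or one of the other architectures listed under "Backend Architectures" below
+  prom_levels: 8
+  resp_levels: 8
+hyperparameters:
+  batch_size: 8       # the first knob to lower when running into OOMs
+trainer:
+  backend: local      # the DeepSpeed-free backend for Windows or low-VRAM setups
+```
+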
However, VRAM use is predicated on your dataset; a mix of large and small utterances will cause VRAM usage to spike and can trigger OOM conditions during the backwards pass if you are not careful. Additionally, under Windows, I managed to finetune the AR on my 2060 (6GiB VRAM) with a batch size of 8 (although, with the card as a secondary GPU). +#### Training Caveats + +Unfortunately, efforts to train a *good* foundational model seem entirely predicated on a good dataset. My dataset might be too fouled with: +* too short utterances: trying to extrapolate longer contexts seems to utterly fall apart from just the `text` being too long. +* too tightly trimmed utterances: there being little to no space at the start and end might harm associating `` and `` tokens with empty utterances. +* a poor phoneme mapping: I naively crafted my own phoneme mapping, where a HuggingFace tokenizer might supply a better token mapping. + #### Backend Architectures As the core of VALL-E makes use of a language model, various LLM architectures can be supported and slotted in. Currently supported: @@ -112,6 +111,8 @@ As the core of VALL-E makes use of a language model, various LLM architectures c * `transformer`: a basic attention-based transformer implementation, with attention heads + feed forwards. * `retnet`: using [TorchScale's RetNet](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py) implementation, a retention-based approach can be used instead. - Its implementation for MoE can also be utilized. +* `retnet-hf`: using [syncdoth/RetNet/](https://github.com/syncdoth/RetNet/) with a HuggingFace-compatible RetNet model + - inferencing cost is about 0.5x, and MoE is not implemented. * `llama`: using HF transformer's LLaMa implementation for its attention-based transformer, boasting RoPE and other improvements. * `mixtral`: using HF transformer's Mixtral implementation for its attention-based transformer, also utilizing its MoE implementation. * `bitnet`: using [this](https://github.com/kyegomez/BitNet/) implementation of BitNet's transformer. @@ -121,11 +122,11 @@ As the core of VALL-E makes use of a language model, various LLM architectures c To export the models, run: `python -m vall_e.export yaml=./data/config.yaml`. -This will export the latest checkpoints, for example, under `./data/ckpt/ar-retnet-2/fp32.pth` and `./data/ckpt/nar-retnet-2/fp32.pth`, to be loaded on any system with PyTorch, and will include additional metadata, such as the symmap used, and training stats. +This will export the latest checkpoints, for example, under `./data/ckpt/ar+nar-retnet-8/fp32.pth`, to be loaded on any system with PyTorch, and will include additional metadata, such as the symmap used, and training stats. ## Synthesis -To synthesize speech, invoke either (if exported the models): `python -m vall_e --ar-ckpt ./models/ar.pt --nar-ckpt ./models/nar.pt` or `python -m vall_e yaml=` +To synthesize speech, invoke either (if you have exported the model): `python -m vall_e --model-ckpt ./data/ckpt/ar+nar-retnet-8/fp32.pth` or `python -m vall_e yaml=` Some additional flags you can pass are: * `--language`: specifies the language for phonemizing the text, and helps guide inferencing when the model is trained against that language. @@ -154,7 +155,6 @@ And some experimental sampling flags you can use too (your mileage will ***defin ## To-Do * train and release a ***good*** model. 
- - the current model seems to require a ***long*** time of training at a very small LR rate to try and cover a wide variety of speakers of varying acoustics. * clean up the README, and document, document, document onto the wiki. * extend to ~~multiple languages ([VALL-E X](https://arxiv.org/abs/2303.03926)) and~~ addditional tasks ([SpeechX](https://arxiv.org/abs/2308.06873)). - training additional tasks needs the SpeechX implementation to be reworked. @@ -164,7 +164,7 @@ And some experimental sampling flags you can use too (your mileage will ***defin + this requires a properly trained AR, however. * work around issues with extending context past what's trained (despite RetNet's retention allegedly being able to defeat this): - "sliding" AR input, such as have the context a fixed length. - + the model may need to be trained for this with a fancy positional embedding injected. Naively sliding the context window while making use of the RetNet implementation's positional embedding doesn't seem fruitful. + + the model may need to be trained for this with a fancy positional embedding injected OR already trained with a sliding context window in mind. Naively sliding the context window while making use of the RetNet implementation's positional embedding doesn't seem fruitful. ## Notices and Citations diff --git a/vall_e/__main__.py b/vall_e/__main__.py index 27faaef..5ccb1cc 100755 --- a/vall_e/__main__.py +++ b/vall_e/__main__.py @@ -13,8 +13,7 @@ def main(): parser.add_argument("--out-path", type=Path, default=None) parser.add_argument("--yaml", type=Path, default=None) - parser.add_argument("--ar-ckpt", type=Path, default=None) - parser.add_argument("--nar-ckpt", type=Path, default=None) + parser.add_argument("--model-ckpt", type=Path, default=None) parser.add_argument("--max-ar-steps", type=int, default=6 * 75) parser.add_argument("--max-nar-levels", type=int, default=7) @@ -41,7 +40,7 @@ def main(): parser.add_argument("--dtype", type=str, default=None) args = parser.parse_args() - tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype, amp=args.amp ) + tts = TTS( config=args.yaml, model_ckpt=args.model_ckpt, device=args.device, dtype=args.dtype, amp=args.amp ) tts.inference( text=args.text, references=args.references, diff --git a/vall_e/config.py b/vall_e/config.py index cef0bec..aa8d075 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -162,6 +162,9 @@ class Dataset: @dataclass() class Model: + _max_levels: int = 0 + _embeddings: str | None = None + name: str = "" # vanity name for the model version: int = 1 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding size: str | dict = "full" # preset string or explicitly defined dimensionality @@ -169,6 +172,7 @@ class Model: prom_levels: int = 8 # RVQ-bin levels this model accepts as an input prompt tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc") langs: int = 1 # defined languages + tones: int = 1 # defined tones experts: int = 1 arch_type: str = "retnet" # or "transformer"" training: bool = True # unneeded now @@ -176,6 +180,13 @@ class Model: p_ar_level: float | str = "auto" # determines odds of selecting the AR (level 0) when training, "auto" for default behavior frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training + def get(self, name=None): + return [ self ] if not name or self.name == name else [] + + @property + def max_levels(self): + 
return self._max_levels if self._max_levels > 0 else self.prom_levels + @property # required for fp8 as the lengths needs to be divisible by 8 def input_alignment(self): @@ -203,7 +214,7 @@ class Model: if self.interleave: name.append("interleaved") else: - name.append(f'{cfg.models.prom_levels}') + name.append(f'{cfg.model.prom_levels}') return "-".join(name) @@ -256,58 +267,6 @@ class Model: def activation_checkpointing(self): return cfg.trainer.activation_checkpointing - -@dataclass() -class Models: - _max_levels: int = 0 - _prom_levels: int = 1 - _embeddings: str | None = None - - _models: list[Model] = field(default_factory=lambda: [ - Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False), - Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, langs=1, experts=1, training=True, interleave=False), - ]) - - def get(self, name=None): - if not name: - return [ Model(**model) for model in self._models ] - - for model in self._models: - if model.name == name: - return model - - raise ValueError - - @property - def ar(self): - return self.get("ar") - - @property - def ar_nar(self): - return self.get("ar+nar") - - @property - def nar(self): - return self.get("nar") - - @property - def prom_levels(self): - prom_levels = self._prom_levels - for model in self._models: - prom_levels = max(prom_levels, model.prom_levels) - return prom_levels - - @property - def tasks(self): - tasks = 1 - for model in self._models: - tasks = max(tasks, model.tasks) - return tasks - - @property - def max_levels(self): - return self._max_levels if self._max_levels > 0 else self.prom_levels - @dataclass() class Hyperparameters: batch_size: int = 8 @@ -568,7 +527,7 @@ class Config(_Config): experimental: bool = False # So I can stop commenting out things when committing dataset: Dataset = field(default_factory=lambda: Dataset) - models: Models = field(default_factory=lambda: Models) + model: Model = field(default_factory=lambda: Model) hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters) evaluation: Evaluation = field(default_factory=lambda: Evaluation) trainer: Trainer = field(default_factory=lambda: Trainer) @@ -617,7 +576,7 @@ class Config(_Config): def format( self ): self.dataset = Dataset(**self.dataset) - self.models = Models(**self.models) + self.model = Model(**self.model) self.hyperparameters = Hyperparameters(**self.hyperparameters) self.evaluation = Evaluation(**self.evaluation) self.trainer = Trainer(**self.trainer) diff --git a/vall_e/data.py b/vall_e/data.py index 4472afe..8984a94 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -29,22 +29,27 @@ from tqdm.auto import tqdm _logger = logging.getLogger(__name__) +# to-do: clean up this symmap mess def get_phone_symmap(): if cfg.dataset.use_hdf5 and 'symmap' in cfg.hdf5: return json.loads( cfg.hdf5['symmap'].asstr()[()] ) - symmap = {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 
'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '”': 179, '“': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, ';ˌ': 184, ':ˈ': 185, '1': 186, 'rˈ': 187, 'qˈ': 188, 'ᵻˌ': 189, 'ä': 190, '̞ˌ': 191, '̞': 192, 'ũˌ': 193, 'ʑˌ': 194, 'ᵝ': 195, 'ɽ': 196, 'ʲˌ': 197, 'ᵝˌ': 198, 'ũ': 199, 'ũˈ': 200, 'äˌ': 201, 'ɕ': 202, 'ɕˌ': 203, 'ɽˌ': 204, 'çˌ': 205, '…ˌ': 206, '̞ˈ': 207, 'äˈ': 208, 'ɽˈ': 209, 'ɸˌ': 210, 'ɴ': 211, 'ɸˈ': 212, 'ɕˈ': 213, 'ɸ': 214, 'ᵝˈ': 215, 'ʲˈ': 216, 'ĩ': 217, 'çˈ': 218, 'ĩˌ': 219, 'oˌ': 220, 'eˈ': 221, 'ʍ': 222, 'eˌ': 223, 'uˌ': 224, 'ʍˌ': 225, 'uˈ': 226, 'oˈ': 227, 'aˈ': 228} - return symmap + return {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178, '”': 179, '“': 180, '“ˈ': 181, '“ˌ': 182, ';ˈ': 183, 
';ˌ': 184, ':ˈ': 185, '1': 186, 'rˈ': 187, 'qˈ': 188, 'ᵻˌ': 189, 'ä': 190, '̞ˌ': 191, '̞': 192, 'ũˌ': 193, 'ʑˌ': 194, 'ᵝ': 195, 'ɽ': 196, 'ʲˌ': 197, 'ᵝˌ': 198, 'ũ': 199, 'ũˈ': 200, 'äˌ': 201, 'ɕ': 202, 'ɕˌ': 203, 'ɽˌ': 204, 'çˌ': 205, '…ˌ': 206, '̞ˈ': 207, 'äˈ': 208, 'ɽˈ': 209, 'ɸˌ': 210, 'ɴ': 211, 'ɸˈ': 212, 'ɕˈ': 213, 'ɸ': 214, 'ᵝˈ': 215, 'ʲˈ': 216, 'ĩ': 217, 'çˈ': 218, 'ĩˌ': 219, 'oˌ': 220, 'eˈ': 221, 'ʍ': 222, 'eˌ': 223, 'uˌ': 224, 'ʍˌ': 225, 'uˈ': 226, 'oˈ': 227, 'aˈ': 228} def get_lang_symmap(): - symmap = { + return { "en": 0, "ja": 1, } + +def get_tone_symmap(): + return { + "neutral": 0, + } return symmap def get_task_symmap(): - symmap = { + return { "": 0, "": 1, "": 2, @@ -54,7 +59,6 @@ def get_task_symmap(): "": 6, "": 7, } - return symmap def _replace_file_extension(path, suffix): return (path.parent / path.name.split(".")[0]).with_suffix(suffix) @@ -237,6 +241,7 @@ class Dataset(_Dataset): self.spkr_symmap = self._get_spkr_symmap() self.spkr_group_symmap = self._get_spkr_group_symmap() self.lang_symmap = self._get_lang_symmap() + self.tone_symmap = self._get_tone_symmap() self.task_symmap = self._get_task_symmap() # assert len(self.phone_symmap) < 256, "Unique token count should be [0,255] to fit within uint8" @@ -309,11 +314,14 @@ class Dataset(_Dataset): def _get_lang_symmap(self): return get_lang_symmap() + def _get_tone_symmap(self): + return get_tone_symmap() + def _get_task_symmap(self): return get_task_symmap() """ - def get_task_token( self, token, levels=cfg.models.max_levels ): + def get_task_token( self, token, levels=cfg.model.max_levels ): if not hasattr(self, "task_symmap"): self.task_symmap = self._get_task_symmap() return torch.Tensor([[ self.task_symmap[f'<{token}>'] for _ in range(levels) ]]).to(dtype=torch.int16) @@ -339,7 +347,7 @@ class Dataset(_Dataset): choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore} choices = [*choices] - # no other utterances, it'd make more sense to prune speakers with only one utterance in the validatoin step + # no other utterances, it'd make more sense to prune speakers with only one utterance in the validation step if len(choices) == 0: choices = [*set(self.paths_by_spkr_name[spkr_name])] """ @@ -622,8 +630,8 @@ class Dataset(_Dataset): """ # trim to fit to requested prom/resps levels - proms = proms[:, :cfg.models.prom_levels] - resps = resps[:, :cfg.models.prom_levels] + proms = proms[:, :cfg.model.prom_levels] + resps = resps[:, :cfg.model.prom_levels] return dict( @@ -928,7 +936,7 @@ if __name__ == "__main__": if task not in cfg.dataset.tasks_list: continue - print(text, task, cfg.models.prom_levels) + print(text, task, cfg.model.prom_levels) print( proms.shape, resps.shape ) tokens = 0 diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index a86af5e..242d852 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -21,7 +21,7 @@ except Exception as e: cfg.inference.use_vocos = False @cache -def _load_encodec_model(device="cuda", levels=cfg.models.max_levels): +def _load_encodec_model(device="cuda", levels=cfg.model.max_levels): # Instantiate a pretrained EnCodec model assert cfg.sample_rate == 24_000 @@ -44,7 +44,7 @@ def _load_encodec_model(device="cuda", levels=cfg.models.max_levels): return model @cache -def _load_vocos_model(device="cuda", levels=cfg.models.max_levels): +def _load_vocos_model(device="cuda", levels=cfg.model.max_levels): assert cfg.sample_rate == 24_000 model = Vocos.from_pretrained("charactr/vocos-encodec-24khz") @@ -66,7 +66,7 @@ def _load_vocos_model(device="cuda", 
levels=cfg.models.max_levels): return model @cache -def _load_model(device="cuda", vocos=cfg.inference.use_vocos, levels=cfg.models.max_levels): +def _load_model(device="cuda", vocos=cfg.inference.use_vocos, levels=cfg.model.max_levels): if vocos: model = _load_vocos_model(device, levels=levels) else: @@ -80,7 +80,7 @@ def unload_model(): @torch.inference_mode() -def decode(codes: Tensor, device="cuda", levels=cfg.models.max_levels): +def decode(codes: Tensor, device="cuda", levels=cfg.model.max_levels): """ Args: codes: (b q t) @@ -117,7 +117,7 @@ def decode(codes: Tensor, device="cuda", levels=cfg.models.max_levels): return wav, model.sample_rate # huh -def decode_to_wave(resps: Tensor, device="cuda", levels=cfg.models.max_levels): +def decode_to_wave(resps: Tensor, device="cuda", levels=cfg.model.max_levels): return decode(resps, device=device, levels=levels) def decode_to_file(resps: Tensor, path: Path, device="cuda"): @@ -131,7 +131,7 @@ def _replace_file_extension(path, suffix): @torch.inference_mode() -def encode(wav: Tensor, sr: int = 24_000, device="cuda", levels=cfg.models.max_levels): +def encode(wav: Tensor, sr: int = 24_000, device="cuda", levels=cfg.model.max_levels): """ Args: wav: (t) @@ -224,7 +224,7 @@ def repeat_extend_audio( qnt, target ): # merges two quantized audios together # I don't know if this works -def merge_audio( *args, device="cpu", scale=[], levels=cfg.models.max_levels ): +def merge_audio( *args, device="cpu", scale=[], levels=cfg.model.max_levels ): qnts = [*args] decoded = [ decode(qnt, device=device, levels=levels)[0] for qnt in qnts ] diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 2ae2fef..31e2551 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -26,7 +26,7 @@ from functools import cache @cache def load_engines(training=True): - models = get_models(cfg.models.get(), training=training) + models = get_models(cfg.model.get(), training=training) engines = dict() for name, model in models.items(): @@ -145,8 +145,8 @@ def load_engines(training=True): engine.freeze(freeze_all=False) # copy embeddings if requested - if cfg.models._embeddings is not None: - embeddings_path = cfg.relpath / cfg.models._embeddings + if cfg.model._embeddings is not None: + embeddings_path = cfg.relpath / cfg.model._embeddings if embeddings_path.exists(): embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device)) diff --git a/vall_e/inference.py b/vall_e/inference.py index 0b72024..e7c32b2 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -19,7 +19,7 @@ if deepspeed_available: import deepspeed class TTS(): - def __init__( self, config=None, ar_ckpt=None, nar_ckpt=None, device=None, amp=None, dtype=None ): + def __init__( self, config=None, model_ckpt=None, device=None, amp=None, dtype=None ): self.loading = True self.input_sample_rate = 24000 @@ -53,7 +53,10 @@ class TTS(): self.symmap = None - def parse( name, model, state ): + if model_ckpt: + state = torch.load(model_ckpt) + self.model = get_models(cfg.model.get(), training=False)[0] + if "userdata" in state and 'symmap' in state['userdata']: self.symmap = state['userdata']['symmap'] elif "symmap" in state: @@ -62,55 +65,26 @@ class TTS(): if "module" in state: state = state['module'] - model.load_state_dict(state) + self.model.load_state_dict(state) if cfg.inference.backend == "local" and deepspeed_available and cfg.trainer.deepspeed.inferencing: - model = deepspeed.init_inference(model=model, mp_size=1, replace_with_kernel_inject=True, 
dtype=dtype if not amp else torch.float32).module - - return model - - if ar_ckpt and nar_ckpt: - self.ar_ckpt = ar_ckpt - self.nar_ckpt = nar_ckpt - - models = get_models(cfg.models.get(), training=False) - - for name, model in models.items(): - if name.startswith("ar"): - state = torch.load(self.ar_ckpt) - self.ar = parse( name, model, state ) - elif name.startswith("nar"): - state = torch.load(self.nar_ckpt) - self.nar = parse( name, model, state ) - - if name.startswith("ar+nar"): - self.nar = self.ar + self.model = deepspeed.init_inference(model=self.model, mp_size=1, replace_with_kernel_inject=True, dtype=dtype if not amp else torch.float32).module else: - self.load_models() + engines = load_engines(training=False) + for name, engine in engines.items(): + self.model = engine.module + break if self.dtype != torch.int8: - self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32) - self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32) + self.model = self.model.to(self.device, dtype=self.dtype if not self.amp else torch.float32) - self.ar.eval() - self.nar.eval() + self.model.eval() if self.symmap is None: self.symmap = get_phone_symmap() self.loading = False - def load_models( self ): - engines = load_engines(training=False) - for name, engine in engines.items(): - if name.startswith("ar"): - self.ar = engine.module - elif name.startswith("nar"): - self.nar = engine.module - - if name.startswith("ar+nar"): - self.nar = self.ar - def encode_text( self, text, language="en" ): # already a tensor, return it if isinstance( text, Tensor ): @@ -193,7 +167,7 @@ class TTS(): lang = to_device(lang, self.device).to(torch.uint8) with torch.autocast("cuda", dtype=self.dtype, enabled=self.amp): - resps_list = self.ar( + resps_list = self.model( text_list=[phns], proms_list=[prom], lang_list=[lang], max_steps=max_ar_steps, max_resp_context=max_ar_context, sampling_temperature=ar_temp, sampling_min_temperature=min_ar_temp, @@ -205,7 +179,7 @@ class TTS(): sampling_mirostat_eta=mirostat_eta, ) resps_list = [r.unsqueeze(-1) for r in resps_list] - resps_list = self.nar( + resps_list = self.model( text_list=[phns], proms_list=[prom], lang_list=[lang], resps_list=resps_list, max_levels=max_nar_levels, sampling_temperature=nar_temp, diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py index 0e4566e..5979e44 100755 --- a/vall_e/models/__init__.py +++ b/vall_e/models/__init__.py @@ -1,19 +1,9 @@ -from .ar import AR -from .nar import NAR from .ar_nar import AR_NAR def get_model(cfg, training=True): - if cfg.name == "ar": - Model = AR - elif cfg.name == "nar": - Model = NAR - elif cfg.name == "ar+nar": - Model = AR_NAR - else: - raise f"invalid model name: {cfg.name}" name = cfg.name - model = Model( + model = AR_NAR( n_tokens=cfg.tokens, d_model=cfg.dim, n_heads=cfg.heads, diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py deleted file mode 100755 index d58ef26..0000000 --- a/vall_e/models/ar.py +++ /dev/null @@ -1,309 +0,0 @@ -from ..config import cfg -from .base import Base, list_to_tensor, Categorical - -import torch -from torch.nn.utils.rnn import pad_sequence - -from einops import rearrange -from torch import Tensor -from tqdm import trange - -class AR(Base): - @property - def causal(self): - return True - - @property - def norm_type(self): - return "ln" - - @property - def arch_type(self) -> str: - if hasattr(self, "config") and self.config: - return self.config.arch_type - return cfg.models.ar.arch_type - - @property - def 
n_prom_levels(self) -> int: - return cfg.models.prom_levels - - @property - def n_resp_levels(self) -> int: - if hasattr(self, "config") and self.config: - return self.config.resp_levels - return cfg.models.ar.resp_levels - - @property - def n_max_levels(self) -> int: - return cfg.models.max_levels - - @property - def n_tasks(self) -> int: - return cfg.models.ar.tasks - - @property - def n_langs(self) -> int: - return cfg.models.ar.langs - - @property - def recurrent_chunk_size(self) -> int: - if cfg.mode == "training": - return 0 - return cfg.inference.recurrent_chunk_size - - """ - @property - def rotary_embedding_base(self) -> float: - if hasattr(self, "config") and self.config: - return self.config.rotary_embedding_base - return cfg.models.ar.rotary_embedding_base - """ - - @property - def interleave(self) -> bool: - if hasattr(self, "config") and self.config: - return self.config.interleave - return False - - @property - def monolithic(self) -> bool: - return False - - @property - def version(self) -> int: - if hasattr(self, "config") and self.config: - return self.config.version - return cfg.models.ar.version - - def _prune(self, l: Tensor): - indices = (l == self.stop_token).nonzero() - if len(indices) == 0: - return l - return l[: indices.min().item()] - - def _interleave( self, codes ): - if not self.interleave: - return codes - - return codes.flatten() - - def _deinterleave( self, codes, length = 0 ): - if not self.interleave: - return codes - - return torch.unflatten( codes[:codes.shape[0] // self.n_prom_levels * self.n_prom_levels], 0, ( codes.shape[0] // self.n_prom_levels, self.n_prom_levels ) ) - - @staticmethod - def _unsqueeze_list(x_list, axis=-1): - return [x.unsqueeze(dim=axis) for x in x_list] - - def forward( - self, - text_list: list[Tensor], - proms_list: list[Tensor], - resps_list: list[Tensor] | None = None, - lang_list: list[Tensor] | None = None, - max_steps: int = 1000, - max_resp_context: int = -1, - - sampling_temperature: float = 1.0, - sampling_min_temperature: float = -1.0, - sampling_top_k: int = -100, - sampling_top_p: float = 1.0, - sampling_repetition_penalty: float = 1.0, - sampling_repetition_penalty_decay: float = 0.0, - sampling_length_penalty: float = 0.0, - sampling_beam_width: int = 0, - - sampling_mirostat_tau: float = 0.0, - sampling_mirostat_eta: float = 0.1, - ): - if resps_list is not None: - if self.interleave: - resps_list = [self._interleave(r) for r in resps_list] - else: - resps_list = [r[..., 0] for r in resps_list] # guarantees we only have the first levels - - return super().forward( - text_list=text_list, - proms_list=proms_list, - resps_list=self._unsqueeze_list(resps_list), - targ_list=resps_list, - lang_list=lang_list, - quant_levels=None, - ) - - device = text_list[0].device - batch_size = len(text_list) - - sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in text_list ] - stopped = torch.zeros(batch_size, device=device).bool() - - recurrent_state = {} if cfg.inference.recurrent_forward else None - mirostat = [ - {"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0} - ] * batch_size if sampling_mirostat_tau > 0.0 else None - - sampling_beam_width_use_logs = True - scores = [ 1.0 ] * sampling_beam_width - - if self.interleave: - max_steps *= self.n_prom_levels - - # get next in sequence - for n in trange(max_steps // max(1, self.recurrent_chunk_size)): - if max_resp_context > 0: - resps_list = 
self._unsqueeze_list([ sequence[-max_resp_context:] for sequence in sequence_list ] ) - else: - resps_list = self._unsqueeze_list(sequence_list) - - logits = super().forward( - text_list=text_list, - proms_list=proms_list, - resps_list=resps_list, - - state=recurrent_state - ) - - r = super().sample( - logits=logits, - resps_list=resps_list, - - temperature=sampling_temperature, - min_temperature=sampling_min_temperature, - top_p=sampling_top_p, - top_k=sampling_top_k, - repetition_penalty=sampling_repetition_penalty, - repetition_penalty_decay=sampling_repetition_penalty_decay, - length_penalty=sampling_length_penalty, - beam_width=sampling_beam_width, - - mirostat=mirostat, - ) - - if mirostat is not None: - # r is the state - mirostat = r - # extract token from state - r = [ state["token"] for state in mirostat ] - # we do it here because the sampler will already expand our logits list - elif sampling_beam_width > 0: - # expand tuple - r, s = r - # first step, expand batch - if batch_size == 1: - batch_size *= sampling_beam_width - text_list = text_list * sampling_beam_width - proms_list = proms_list * sampling_beam_width - sequence_list = sequence_list * sampling_beam_width - stopped = torch.zeros(batch_size, device=device).bool() - - # update scores - if sampling_beam_width_use_logs: - scores = [ (math.log(scores[i]) if scores[i] > 0 else 0) + math.log(score) for i, score in enumerate(s) ] - else: - scores = [ scores[i] * score for i, score in enumerate(s) ] - - # append tokens - for i, ri in enumerate(r): - if self.stop_token in ri: - stopped[i] = True - sequence_list[i] = torch.cat([sequence_list[i], ri.to(device)]) - - # stop token found - stopped |= r == self.stop_token - if stopped.all().item(): - break - - # pick the best scoring candidate - # desu this is always going to be candidate 0 - if sampling_beam_width and len(scores) > 0: - best_idx, best_score = (0, 0) - for idx, score in enumerate(scores): - if best_score > score: - best_idx, best_score = idx, score - - sequence_list = [sequence_list[best_idx]] - - if self.interleave: - sequence_list = [self._deinterleave(r) for r in sequence_list] - return [self._prune(r) for r in sequence_list] - - -def example_usage(): - cfg.trainer.backend = "local" - from functools import partial - - from einops import repeat - - from ..emb.qnt import decode_to_file - from ..engines import Engine - from tqdm import tqdm - from ..utils import wrapper as ml - - device = "cuda" - x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels) - symmap = {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 
'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178} - def tokenize(content, lang_marker="en"): - split = content.split(" ") - phones = [f""] + [ " " if not p else p for p in split ] + [f""] - return torch.tensor([*map(symmap.get, phones)]).to() - - qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device) - - text_list = [ - #torch.tensor([1, 2, 3], device=device), - tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device), - ] - proms_list = [ - #x8(torch.tensor([1, 2, 3], device=device)), - qnt.to(device), - ] - resps_list = [ - qnt.to(device), - ] - - text_list = text_list[:1] - proms_list = proms_list[:1] - resps_list = resps_list[:1] - - kwargs = { - 'n_tokens': 1024, - 'd_model': 1024, - 'n_heads': 16, - 'n_layers': 24, - } - - """ - try: - kwargs['config'] = cfg.models.ar - except Exception as e: - pass - """ - - model = AR(**kwargs).to(device) - steps = 500 - optimizer = ml.Prodigy(model.parameters(), lr=1.0) - engine = Engine(model=model, optimizer=optimizer) - - def sample( name, steps=600 ): - engine.eval() - out = engine(text_list, proms_list, max_steps=steps) - for i, o in enumerate(out): - wav, sr = decode_to_file(o, f"data/ar.{i}.{name}.wav", device=device) - - def train(): - engine.train() - t = trange(steps) - for i in t: - stats = {"step": i} - stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list) - - tqdm.write(f"{stats}") - - sample("init", 75) - train() - sample("final") - -if __name__ == "__main__": - example_usage() diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 0ad074e..fe0fa5e 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -25,29 +25,33 @@ class AR_NAR(Base): def arch_type(self) -> str: if hasattr(self, "config") and self.config: return self.config.arch_type - return cfg.models.ar_nar.arch_type + return cfg.model.arch_type @property def n_prom_levels(self) -> int: - return cfg.models.prom_levels + return cfg.model.prom_levels @property def n_resp_levels(self) -> int: if hasattr(self, "config") and self.config: return self.config.resp_levels - return cfg.models.ar_nar.resp_levels + return cfg.model.resp_levels @property def n_max_levels(self) -> int: - return cfg.models.max_levels + return cfg.model.max_levels @property def n_tasks(self) -> int: - return cfg.models.ar_nar.tasks - + return cfg.model.tasks + @property def n_langs(self) -> int: - return cfg.models.ar_nar.langs + return cfg.model.langs + + @property + def n_tones(self) -> int: + return cfg.model.tones @property def recurrent_chunk_size(self) -> int: @@ -58,7 +62,7 @@ class AR_NAR(Base): def rotary_embedding_base(self) -> float: if hasattr(self, "config") and self.config: return self.config.rotary_embedding_base - return cfg.models.ar_nar.rotary_embedding_base + 
return cfg.model.rotary_embedding_base """ @property @@ -73,7 +77,7 @@ class AR_NAR(Base): def version(self) -> int: if hasattr(self, "config") and self.config: return self.config.version - return cfg.models.ar_nar.version + return cfg.model.version def _prune(self, l: Tensor): indices = (l == self.stop_token).nonzero() @@ -92,6 +96,7 @@ class AR_NAR(Base): resps_list: list[Tensor] | None = None, lang_list: list[Tensor] | None = None, + tone_list: list[Tensor] | None = None, max_steps: int = 1000, max_levels: int = 0, @@ -134,10 +139,10 @@ class AR_NAR(Base): else: quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) """ - if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None: + if cfg.model.p_ar_level == "auto" or cfg.model.p_ar_level is None: quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR) else: - quant_levels = torch.Tensor([ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ]) + quant_levels = torch.Tensor([ 0 if random.random() < cfg.model.p_ar_level else random.randint(1, self.n_resp_levels) for _ in range(batch_size) ]) """ targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target) @@ -162,6 +167,7 @@ class AR_NAR(Base): resps_list=resps_list, targ_list=targ_list, lang_list=lang_list, + tone_list=tone_list, quant_levels=quant_levels, ) # is NAR @@ -182,6 +188,7 @@ class AR_NAR(Base): proms_list=proms_list, resps_list=prev_list, lang_list=lang_list, + tone_list=tone_list, quant_levels=quant_levels, ) @@ -234,6 +241,7 @@ class AR_NAR(Base): proms_list=proms_list, resps_list=resps_list, lang_list=lang_list, + tone_list=tone_list, state=recurrent_state ) else: @@ -242,6 +250,7 @@ class AR_NAR(Base): proms_list=proms_list, resps_list=resps_list, lang_list=lang_list, + tone_list=tone_list, state=recurrent_state ) @@ -312,14 +321,14 @@ def example_usage(): import re device = "cuda" - x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels) + x8 = partial(repeat, pattern="t -> t l", l=cfg.model.prom_levels) symmap = {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 
'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178} def tokenize(content, lang_marker="en"): split = content.split(" ") phones = [f""] + [ " " if not p else p for p in split ] + [f""] return torch.tensor([*map(symmap.get, phones)]) - qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device) + qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.model.prom_levels].to(device) cfg.hyperparameters.gradient_accumulation_steps = 1 @@ -359,7 +368,7 @@ def example_usage(): """ try: - kwargs['config'] = cfg.models.ar_nar + kwargs['config'] = cfg.model except Exception as e: pass """ @@ -374,8 +383,8 @@ def example_usage(): # copy embeddings if requested """ - if cfg.models._embeddings is not None: - embeddings_path = cfg.relpath / cfg.models._embeddings + if cfg.model._embeddings is not None: + embeddings_path = cfg.relpath / cfg.model._embeddings if embeddings_path.exists(): embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device)) diff --git a/vall_e/models/base.py b/vall_e/models/base.py index cb514a2..1103e7f 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -262,11 +262,15 @@ class Base(nn.Module): @property def n_langs(self) -> int: raise NotImplementedError - + @property def n_tasks(self) -> int: raise NotImplementedError + @property + def n_tones(self) -> int: + raise NotImplementedError + @property def recurrent_chunk_size(self) -> int: raise NotImplementedError @@ -343,6 +347,7 @@ class Base(nn.Module): self.text_emb = Embedding(n_tokens, d_model) self.langs_emb = None + self.tones_emb = None self.tasks_emb = None if self.version == 1: # legacy @@ -359,6 +364,9 @@ class Base(nn.Module): if self.version >= 3: self.langs_emb = Embedding(self.n_langs, d_model) if self.n_langs > 0 else None self.tasks_emb = Embedding(self.n_tasks, d_model) if self.n_tasks > 0 else None + + if self.version >= 4: + self.tones_emb = Embedding(self.n_tones, d_model) if self.n_tones > 0 else None self.sep = nn.Parameter(torch.randn(d_model)) @@ -522,53 +530,15 @@ class Base(nn.Module): ignore_index=self.ignore_index, ) - def forward( + def _forward( self, - text_list: list[Tensor], - proms_list: list[Tensor], - resps_list: list[Tensor], - targ_list: list[Tensor] | None = None, - - lang_list: list[Tensor] | None = None, - - quant_levels: Tensor | None = None, - state: dict | list | None = None, + inputs, + mask = None, + state = None, ): - batch_size = len(text_list) - - if self.langs_emb is None: - lang_list = None - - x_list = self._samplewise_merge_tensors( - self.text_emb(text_list), - self.langs_emb(lang_list) if lang_list is not None else None, - self.proms_emb(proms_list), - self.resps_emb(resps_list, quant_levels), - sep=self.sep, - ) - - - x, m = list_to_tensor(x_list) + x = inputs + m = mask.squeeze(-1).int() aux_loss = None - - device = x.device - - # pad our input and mask, but retain the original length by doing it after - if self.l_padding and x.shape[1] % self.l_padding != 0: - # pad input - shape = list(x.shape) - shape[1] = self.l_padding - shape[1] % self.l_padding - - padding = 
torch.zeros(shape, dtype=x.dtype, device=x.device) - x = torch.cat([x, padding], dim=1) - - # pad mask - shape[2] = 1 - padding = torch.zeros(shape, dtype=x.dtype, device=x.device) - m = torch.cat([m, padding], dim=1) - - # for simplicity - mask = m.squeeze(-1).int() """ # Broken @@ -587,7 +557,7 @@ class Base(nn.Module): xi = x[:, n, :].unsqueeze(1) kwargs = dict( - attention_mask=mask, + attention_mask=m, inputs_embeds=xi, past_key_values=state, use_cache=True, @@ -603,9 +573,9 @@ class Base(nn.Module): """ # HF transformer derived model - if self.arch_type == "llama" or self.arch_type == "mistral" or self.arch_type == "mixtral": + if self.arch_type in ["llama", "mistral", "mixtral"]: kwargs = dict( - attention_mask=mask, + attention_mask=m, inputs_embeds=x, past_key_values=state, use_cache=True, @@ -632,7 +602,7 @@ class Base(nn.Module): x = self.sin_emb.add_pe(x) # pass our inputs through the transformer for block in self.blocks: - x = block(x, mask, l) + x = block(x, m, l) elif self.arch_type == "retnet": # pass our inputs through the RetNet x, _ = self.model(x, incremental_state=state, token_embeddings=x, features_only=True) @@ -642,7 +612,7 @@ class Base(nn.Module): first = state is None or len(state) == 0 kwargs = dict( - attention_mask=mask, + attention_mask=m, inputs_embeds=x if first else x[:, -1, :].unsqueeze(1), past_key_values=None if first else state, use_cache=True, @@ -659,8 +629,76 @@ class Base(nn.Module): x = self.model(x) # output projection layer with masking + x = self.classifier(x) * mask - x = self.classifier(x) * m + return x, state, aux_loss + + def forward( + self, + text_list: list[Tensor], + proms_list: list[Tensor], + resps_list: list[Tensor], + targ_list: list[Tensor] | None = None, + + lang_list: list[Tensor] | None = None, + tone_list: list[Tensor] | None = None, + + quant_levels: Tensor | None = None, + state: dict | list | None = None, + ): + device = text_list[0].device + batch_size = len(text_list) + + # silently ignore languages if model does not have it + if self.langs_emb is None: + lang_list = None + # inject default language + elif lang_list is None: + lang_list = [ torch.Tensor([ 0 ]).to(dtype=torch.uint8, device=device) for _ in range(batch_size) ] + + # silently ignore tones if model does not have it + if self.tones_emb is None: + tone_list = None + # inject default tone + elif tone_list is None: + tone_list = [ torch.Tensor([ 0 ]).to(dtype=torch.uint8, device=device) for _ in range(batch_size) ] + + """ + # Typical sequence format + # To-do: integrate tasks again + + """ + x_list = self._samplewise_merge_tensors( + self.text_emb(text_list), + self.langs_emb(lang_list) if lang_list is not None else None, + self.proms_emb(proms_list), + self.tones_emb(tone_list) if tone_list is not None else None, + self.resps_emb(resps_list, quant_levels), + sep=self.sep, + ) + + x, m = list_to_tensor(x_list) + + # pad our input and mask, but retain the original length by doing it after + if self.l_padding and x.shape[1] % self.l_padding != 0: + # pad input + shape = list(x.shape) + shape[1] = self.l_padding - shape[1] % self.l_padding + + padding = torch.zeros(shape, dtype=x.dtype, device=x.device) + x = torch.cat([x, padding], dim=1) + + # pad mask + shape[2] = 1 + padding = torch.zeros(shape, dtype=x.dtype, device=x.device) + m = torch.cat([m, padding], dim=1) + + + x, state, aux_loss = self._forward( + inputs=x, + mask=m, + state=state, + ) # Remove padding logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ] @@ -790,7 +828,7 @@ def example_usage(): 
from .nar import NAR device = "cuda" - x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels) + x8 = partial(repeat, pattern="t -> t l", l=cfg.model.prom_levels) symmap = {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178} def tokenize(content, lang_marker="en"): split = content.split(" ") @@ -812,7 +850,7 @@ def example_usage(): train = True - qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device) + qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.model.prom_levels].to(device) text_list = [ tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device), #tokenize("ˌ ɔ n ɡˌ o ʊ ɪ ŋ hˈ o ʊ m ð ə tˈ uː f ɹˈ ɛ n d z fˈ a ʊ n d ɐ lˈ ɛ ɾ ɚ f ɹ ʌ m ˈ æ θ o ʊ z , hˌ uː d ɪ zˈ a ɪ ɚ d ðˌ ɛ m t ə mˈ iː t hˌ ɪ m æ t ð ə ɡ ɹˈ æ n d t ʃˈ ɑː ɹ l ɪ mˌ æ ɡ n i ɔ n ð ə fˈ ɑː l o ʊ ɪ ŋ dˈ e ɪ .").to(device), diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py deleted file mode 100755 index bb18f3a..0000000 --- a/vall_e/models/nar.py +++ /dev/null @@ -1,235 +0,0 @@ -from ..config import cfg -from .base import Base - -import torch - -from torch import Tensor -from tqdm import trange - -class NAR(Base): - @property - def causal(self): - return False - - @property - def arch_type(self) -> str: - if hasattr(self, "config") and self.config: - return self.config.arch_type - return cfg.models.nar.arch_type - - @property - def norm_type(self): - return "ln" if self.n_resp_levels == 1 else "adaln" - - @property - def n_prom_levels(self) -> int: - return cfg.models.prom_levels - - @property - def n_resp_levels(self) -> int: - if hasattr(self, "config") and self.config: - return self.config.resp_levels - return cfg.models.nar.resp_levels - - @property - def n_max_levels(self) -> int: - return cfg.models.max_levels - - @property - def n_tasks(self) -> int: - 
return cfg.models.nar.tasks - - @property - def n_langs(self) -> int: - return cfg.models.nar.langs - - @property - def version(self) -> int: - if hasattr(self, "config") and self.config: - return self.config.version - return cfg.models.nar.version - - @property - def recurrent_chunk_size(self) -> int: - return 0 - - """ - @property - def rotary_embedding_base(self) -> float: - if hasattr(self, "config") and self.config: - return self.config.rotary_embedding_base - return cfg.models.nar.rotary_embedding_base - """ - - @property - def interleave(self) -> bool: - return False - - @property - def monolithic(self) -> bool: - return False - - def forward( - self, - text_list: list[Tensor], - proms_list: list[Tensor], - resps_list: list[Tensor], - lang_list: list[Tensor] | None = None, - max_levels: int = 0, - sampling_temperature: float = 0.2, - sampling_min_temperature: float = -1.0, - sampling_top_k: int = -100, - sampling_top_p: float = 1.0, - sampling_repetition_penalty: float = 1.0, - sampling_repetition_penalty_decay: float = 0.0, - sampling_length_penalty: float = 0.0, # unused - sampling_beam_width: int = 0, # unused - sampling_mirostat_tau: float = 0.0, # unused - ): - """ - Args: - text_list: [t] * b - proms_list: [t' l] * b, l=8 - resps_list: [t'' l] * b, l=1 or 8, 1 for testing and 8 for training. - Returns: - [t'' l], l=8 if testing. empty list will be returned during training. - """ - - n_levels_set = {r.shape[-1] for r in resps_list} - - if len(n_levels_set) > 1: - raise ValueError(f"Please give only one level, got {n_levels_set}.") - - n_levels = next(iter(n_levels_set)) - - device = text_list[0].device - - if n_levels == self.n_resp_levels + 1: - assert resps_list is not None - - quant_levels = torch.randint(0, self.n_resp_levels, (len(resps_list),)) - - prev_list = [o[..., : l + 1] for o, l in zip(resps_list, quant_levels)] - targ_list = [o[..., l + 1] for o, l in zip(resps_list, quant_levels)] - - #quant_levels = quant_levels.to(device=device) - - logits = super().forward( - text_list=text_list, - proms_list=proms_list, - resps_list=prev_list, - targ_list=targ_list, - lang_list=lang_list, - quant_levels=quant_levels, - ) - - prev_list = [] - else: - prev_list = resps_list - if max_levels == 0: - max_levels = self.n_resp_levels - - while True: - level = prev_list[0].shape[-1] - 1 - - if level >= max_levels: # min(max_levels, self.n_resp_levels): # commented out to experiment with exceeding trained levels - break - - quant_levels = torch.full((len(text_list),), level, device=device) - - logits = super().forward( - text_list=text_list, - proms_list=proms_list, - resps_list=prev_list, - lang_list=lang_list, - quant_levels=quant_levels, - ) - - resps_list = super().sample( - logits=logits, - resps_list=prev_list, - quant_levels=quant_levels, - - temperature=sampling_temperature, - min_temperature=sampling_min_temperature, - top_p=sampling_top_p, - top_k=sampling_top_k, - repetition_penalty=sampling_repetition_penalty, - repetition_penalty_decay=sampling_repetition_penalty_decay, - #length_penalty=sampling_length_penalty, - #beam_width=sampling_beam_width, - #mirostat_tau=sampling_mirostat_tau, - #mirostat_state=mirostat_state, - ) - - prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ] - - return prev_list - -def example_usage(): - cfg.trainer.backend = "local" - from functools import partial - - from einops import repeat - - from ..emb.qnt import decode_to_file - from ..engines import Engine - from tqdm import tqdm - from 
..utils import wrapper as ml - - device = "cuda" - x8 = partial(repeat, pattern="t -> t l", l=cfg.models.prom_levels) - symmap = {'': 1, '': 2, ' ': 3, '.': 4, ',': 5, '!': 6, '?': 7, 'p': 7, 'iː': 8, 'ɚ': 9, 'ˌ': 10, 'dˌ': 11, 'mˌ': 12, 'd': 13, 'ɹ': 14, 'tˈ': 15, 'pˌ': 16, 'uː': 17, 'l': 18, 'æ': 19, 'ɛ': 20, 'ɪ': 21, 'j': 22, 'ʊ': 23, 't': 24, 'n': 25, 'v': 26, 'a': 27, 'o': 28, 'ŋ': 29, 'w': 30, 'ʌ': 31, 'hˈ': 32, 'ɡˈ': 33, 'ə': 34, 'θˈ': 35, 'dˈ': 36, 'wˌ': 37, 'h': 38, 'z': 39, 'k': 40, 'ð': 41, 'ɡˌ': 42, 'ˈ': 43, 'fˈ': 44, 'i': 45, 's': 46, 'ʃ': 47, 'wˈ': 48, 'ðˈ': 49, 'ɹˈ': 50, 'lˈ': 51, 'ɡ': 52, 'oː': 53, 'mˈ': 54, 'e': 55, 'ɑː': 56, 'nˈ': 57, 'm': 58, 'θˌ': 59, 'sˈ': 60, 'f': 61, 'ɔː': 62, 'hˌ': 63, 'b': 64, 'jˈ': 65, 'ɐ': 66, 'ʒˈ': 67, 'θ': 68, 'bˈ': 69, 'ɾ': 70, 'ɜː': 71, 'ʌˈ': 72, 'ʃˌ': 73, 'bˌ': 74, 'kˈ': 75, 'ɔ': 76, 'zˈ': 77, 'ᵻ': 78, 'kˌ': 79, 'vˈ': 80, 'fˌ': 81, 'ʒ': 82, 'ʃˈ': 83, 'ɹˌ': 84, 'tˌ': 85, 'pˈ': 86, 'ðˌ': 87, 'sˌ': 88, 'nˌ': 89, 'lˌ': 90, '̩': 91, 'ʔ': 92, 'vˌ': 93, 'ɪˈ': 94, '"': 95, 'ɪˌ': 96, 'ʒˌ': 97, 'uːˌ': 98, 'ʊˈ': 99, 'jˌ': 100, 'uːˈ': 101, 'iːˈ': 102, 'zˌ': 103, '.ˈ': 104, '…': 105, 'ŋˌ': 106, 'ɐˌ': 107, '—ˈ': 108, 'iˌ': 109, 'iːˌ': 110, 'ɛː': 111, ')': 112, ')ˈ': 113, '(': 114, 'u': 115, '-': 116, 'ɖˈ': 117, 'iˈ': 118, 'ʰˈ': 119, 'ɟˈ': 120, '̃': 121, 'eː': 122, 'ɾˈ': 123, 'r': 124, 'ʰ': 125, '-ˌ': 126, 'ɫ': 127, 'q': 128, '—': 129, 'ʊˌ': 130, 'aː': 131, 'cˈ': 132, '…ˈ': 133, 'c': 134, 'ɳ': 135, 'ɐˈ': 136, 'x': 137, 'ʔˌ': 138, '.ˌ': 139, 'ɑ': 140, '?ˈ': 141, '̩ˈ': 142, '"ˈ': 143, ',ˈ': 144, 'ŋˈ': 145, 'əˌ': 146, '!ˈ': 147, '"ˌ': 148, '?ˌ': 149, ',ˌ': 150, '—ˌ': 151, '̩ˌ': 152, 'əˈ': 153, '!ˌ': 154, 'ɬ': 155, 'ʲ': 156, '¡': 157, 'ɯ': 158, 'qˌ': 159, 'ʑ': 160, 'ʑˈ': 161, '¿': 162, 'ɑːˈ': 163, 'iːː': 164, 'ɛˈ': 165, '¡ˈ': 166, 'æˈ': 167, 'ç': 168, 'ɾˌ': 169, 'ᵻˈ': 170, 'xˈ': 171, 'ɔːˈ': 172, ';': 173, 'ɬˌ': 174, ':': 175, 'ʔˈ': 176, 'ɑːˌ': 177, 'ɬˈ': 178} - def tokenize(content, lang_marker="en"): - split = content.split(" ") - phones = [f""] + [ " " if not p else p for p in split ] + [f""] - return torch.tensor([*map(symmap.get, phones)]).to() - - # to-do: unmangle this and the resp shit - qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device) - - text_list = [ - #torch.tensor([1, 2, 3], device=device), - tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device), - ] - - proms_list = [ - x8(torch.tensor([2, 3], device=device)), - ] - - resps_list = [ - qnt.to(device), - ] - - kwargs = { - 'n_tokens': 1024, - 'd_model': 1024, - 'n_heads': 16, - 'n_layers': 12, - } - model = NAR(**kwargs).to(device) - steps = 500 - optimizer = ml.Prodigy(model.parameters(), lr=1.0) - engine = Engine(model=model, optimizer=optimizer) - - def sample( name ): - engine.eval() - codes = engine( text_list, proms_list, resps_list=[r[..., 0].unsqueeze(-1) for r in resps_list], sampling_temperature=0.2 ) - decode_to_file( codes[0], f"data/nar.{name}.wav", device ) - - def train(): - engine.train() - t = trange(steps) - for i in t: - stats = {"step": i} - stats |= engine.traverse(text_list=text_list, proms_list=proms_list, resps_list=resps_list) - - tqdm.write(f"{stats}") - - sample("init") - train() - sample("final") - - -if __name__ == "__main__": - example_usage() diff --git a/vall_e/plot.py b/vall_e/plot.py index 294a256..0c1b61c 100644 --- a/vall_e/plot.py +++ b/vall_e/plot.py @@ -109,7 +109,7 @@ if __name__ == "__main__": path = cfg.relpath / "logs" paths = path.rglob(f"./*/{args.filename}") - args.models = [ model for model in 
cfg.models.get() if model.training and (args.model == "*" or model.name in args.model) ] + args.models = [ model for model in cfg.model.get() if model.training and (args.model == "*" or model.name in args.model) ] if args.ys == "": args.ys = ["loss"] diff --git a/vall_e/webui.py b/vall_e/webui.py index a613ef3..580c331 100644 --- a/vall_e/webui.py +++ b/vall_e/webui.py @@ -54,14 +54,13 @@ def init_tts(restart=False): parser = argparse.ArgumentParser(allow_abbrev=False) parser.add_argument("--yaml", type=Path, default=os.environ.get('VALLE_YAML', None)) # os environ so it can be specified in a HuggingFace Space too - parser.add_argument("--ar-ckpt", type=Path, default=None) - parser.add_argument("--nar-ckpt", type=Path, default=None) + parser.add_argument("--model-ckpt", type=Path, default=None) parser.add_argument("--device", type=str, default="cuda") parser.add_argument("--amp", action="store_true") parser.add_argument("--dtype", type=str, default="auto") args, unknown = parser.parse_known_args() - tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp ) + tts = TTS( config=args.yaml, model_ckpt=args.model_ckpt, device=args.device, dtype=args.dtype if args.dtype != "auto" else None, amp=args.amp ) return tts @gradio_wrapper(inputs=layout["inference"]["inputs"].keys())
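
For reference, the single-checkpoint flow that `__main__.py` and `webui.py` now share can also be driven directly from Python. A minimal sketch follows; only the `TTS( config=..., model_ckpt=..., device=..., dtype=..., amp=... )` constructor and the `text`/`references` keywords appear verbatim in this diff, while the checkpoint path, the `out_path` keyword (mirroring the CLI's `--out-path` flag), and the list form of `references` are assumptions:

```python
# Minimal sketch, assuming a checkpoint exported via `python -m vall_e.export yaml=./data/config.yaml`
# and a default model config that matches it (otherwise pass the training YAML via `config=`).
from vall_e.inference import TTS

tts = TTS( config=None, model_ckpt="./data/ckpt/ar+nar-retnet-8/fp32.pth", device="cuda" )

tts.inference(
    text="Hello world.",
    references=["./reference.wav"],   # assumed to accept a list of reference clips
    out_path="./output.wav",          # assumed keyword, mirroring --out-path
)
```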