diff --git a/README.md b/README.md index 87fad22..6191eb6 100755 --- a/README.md +++ b/README.md @@ -161,7 +161,7 @@ For audio backends: * `xformers`: [facebookresearch/xformers](https://github.com/facebookresearch/xformers/)'s memory efficient attention * `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper) * `flash_attn_v100`: uses [ZRayZzz/flash-attention-v100](https://github.com/ZRayZzz/flash-attention-v100/)'s Flash Attention for Volta (but doesn't work currently) - * `fused_attn`: uses an implementation using `triton` (only tested on my 7900XTX / Navi3 / gfx1100) + * `fused_attn`: uses a `triton`-based implementation (tested on my 7900XTX and V100s), but seems to introduce errors after training with it for a while * `transformers` Llama\*Attention implementations: * `eager`: default `LlamaAttention` * `sdpa`: integrated `LlamaSdpaAttention` attention model diff --git a/vall_e/config.py b/vall_e/config.py index 2968992..bc7b3c9 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -9,6 +9,7 @@ import time import argparse import yaml import random +import logging import torch import numpy as np @@ -163,7 +164,8 @@ class Dataset: sample_order: str = "interleaved" # duration sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batched, 0 to disable # for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough - sample_shuffle: bool = True # + # for a full sized model with 24GiB of VRAM for Encodec, 380 seconds consumes about 80% of VRAM (but it might be limited by batch size) + sample_shuffle: bool = True # i swear this is spiking the loss when sample_order = duration + sample_max_duration_batch > 0 tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes @@ -364,6 +366,7 @@ class Model: return dict(include=include, exclude=exclude) +# should be renamed to Adapters @dataclass() class LoRA: name: str = "lora" # vanity name @@ -638,9 +641,6 @@ class Trainer: def scale_loss(self): # currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways) return self.dtype == torch.float16 - """ - """ - @dataclass() class Inference: @@ -670,7 +670,6 @@ class Inference: return torch.float8_e4m3fn return torch.float32 -# should be renamed to optimizations @dataclass() class Optimizations: injects: bool = False # overwrites default torch classes (not recommended) @@ -755,6 +754,7 @@ class Config(BaseConfig): return self.models[0] if len(self.models) > 0 else None + # should be renamed to adapters @property def lora(self): for i, lora in enumerate(self.loras): @@ -795,7 +795,7 @@ class Config(BaseConfig): try: self.hdf5 = h5py.File(f'{self.rel_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset except Exception as e: - print("Error while opening HDF5 file:", f'{self.rel_path}/{self.dataset.hdf5_name}', str(e)) + _logger.warning(f"Error while opening HDF5 file: {self.rel_path}/{self.dataset.hdf5_name}: {str(e)}") self.dataset.use_hdf5 = False # to-do: prune unused keys @@ -923,7 +923,7 @@ class Config(BaseConfig): cfg.tokenizer = NaiveTokenizer() except Exception as e: cfg.tokenizer = NaiveTokenizer() - print("Error while parsing tokenizer:", e) + _logger.warning(f"Error while parsing tokenizer: {str(e)}")
pass @@ -960,6 +960,7 @@ class NaiveTokenizer: # tokenize return [*map(symmap.get, phones)] +_logger = logging.getLogger(__name__) cfg = Config.from_cli() @@ -967,7 +968,7 @@ cfg = Config.from_cli() try: cfg.format() except Exception as e: - print("Error while parsing config YAML:") + _logger.error(f"Error while parsing config YAML: {str(e)}") raise e # throw an error because I'm tired of silent errors messing things up for me if __name__ == "__main__": diff --git a/vall_e/data.py b/vall_e/data.py index ed9ae76..8f0d3a4 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -495,7 +495,6 @@ def _load_paths_from_metadata(group_name, type="training", validate=False): def _get_hdf5_path(path): # to-do: better validation - #print(path) return str(path) def _get_hdf5_paths( data_dir, type="training", validate=False ): @@ -1543,12 +1542,6 @@ if __name__ == "__main__": cfg.dataset.workers = 1 - class LoggerOveride: - def info(self, *args): - print(*args) - - _logger = LoggerOveride() - if args.action == "hdf5": create_dataset_hdf5() elif args.action == "list-dataset": @@ -1559,7 +1552,7 @@ if __name__ == "__main__": continue dataset.append(f'{group}/{name}') - print(json.dumps(dataset)) + _logger.info(json.dumps(dataset)) elif args.action == "metadata": create_dataset_metadata() elif args.action == "sample": @@ -1581,17 +1574,17 @@ if __name__ == "__main__": try: decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" ) except Exception as e: - print(f"Error while decoding prom {k}.{i}.{j}.wav:", str(e)) + _logger.info(f"Error while decoding prom {k}.{i}.{j}.wav: {str(e)}") try: decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" ) except Exception as e: - print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e)) + _logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}") v[i]['proms'][j] = v[i]['proms'][j].shape v[i]['resps'][j] = v[i]['resps'][j].shape for k, v in samples.items(): for i in range(len(v)): - print(f'{k}[{i}]:', v[i]) + _logger.info(f'{k}[{i}]: {v[i]}') elif args.action == "validate": train_dl, subtrain_dl, val_dl = create_train_val_dataloader() @@ -1610,11 +1603,11 @@ if __name__ == "__main__": phone = phonemes[i] - print( batch['text'], batch['metadata']['phonemes'] ) + _logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" ) missing |= set([phone]) - print( "Missing tokens:", missing ) + _logger.info( f"Missing tokens: {missing}" ) elif args.action == "tasks": @@ -1628,13 +1621,13 @@ if __name__ == "__main__": if task not in cfg.dataset.tasks_list: continue - print(text, task, cfg.model.resp_levels) - print( proms.shape, resps.shape ) + _logger.info( f'{text} {task} {cfg.model.resp_levels}') + _logger.info( f'{proms.shape} {resps.shape}' ) tokens = 0 tokens += sum([ text.shape[0] for text in batch["text"] ]) tokens += sum([ resps.shape[0] for resps in batch["resps"] ]) - print( tokens ) + _logger.info( f'{tokens}' ) decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" ) decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" ) diff --git a/vall_e/demo.py b/vall_e/demo.py index 21e1ea1..5e62b36 100644 --- a/vall_e/demo.py +++ b/vall_e/demo.py @@ -18,6 +18,9 @@ Will also generate samples from a provided datset, if requested. 
import argparse import base64 import random +import logging + +_logger = logging.getLogger(__name__) from pathlib import Path @@ -117,9 +120,9 @@ def main(): samples_dirs["dataset"] = args.demo_dir / "dataset" - print("Loading dataloader...") + _logger.info("Loading dataloader...") dataloader = create_train_dataloader() - print("Loaded dataloader.") + _logger.info("Loaded dataloader.") num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size diff --git a/vall_e/emb/process.py b/vall_e/emb/process.py index 0f8970d..82723b3 100644 --- a/vall_e/emb/process.py +++ b/vall_e/emb/process.py @@ -9,6 +9,9 @@ import argparse import torch import torchaudio import numpy as np +import logging + +_logger = logging.getLogger(__name__) from tqdm.auto import tqdm from pathlib import Path @@ -78,7 +81,7 @@ def process_jobs( jobs, speaker_id="", raise_exceptions=True ): try: process_job( outpath, waveform, sample_rate, text, language ) except Exception as e: - print(f"Failed to quantize: {outpath}:", e) + _logger.error(f"Failed to quantize: {outpath}: {str(e)}") if raise_exceptions: raise e continue @@ -128,7 +131,7 @@ def process( for group_name in sorted(os.listdir(f'./{input_audio}/')): if not os.path.isdir(f'./{input_audio}/{group_name}/'): - print("Is not dir:", f'./{input_audio}/{group_name}/') + _logger.warning(f'Is not dir: ./{input_audio}/{group_name}/') continue if group_name in ignore_groups: @@ -138,7 +141,7 @@ def process( for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {group_name}"): if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'): - print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}') + _logger.warning(f'Is not dir: ./{input_audio}/{group_name}/{speaker_id}') continue if speaker_id in ignore_speakers: diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index 30c94fc..f84c8e5 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -6,6 +6,9 @@ import math import torch import torchaudio import numpy as np +import logging + +_logger = logging.getLogger(__name__) from functools import cache from pathlib import Path @@ -203,7 +206,7 @@ try: except Exception as e: cfg.inference.use_dac = False - print(str(e)) + _logger.warning(str(e)) # uses https://github.com/facebookresearch/AudioDec/ # I have set up a pip-ify'd version with the caveat of having to manually handle downloading the checkpoints with a wget + unzip @@ -213,7 +216,7 @@ try: from audiodec.utils.audiodec import AudioDec, assign_model as _audiodec_assign_model except Exception as e: cfg.inference.use_audiodec = False - print(str(e)) + _logger.warning(str(e)) """ @cache @@ -747,8 +750,8 @@ if __name__ == "__main__": if args.print: torch.set_printoptions(profile="full") - print( "Metadata:", artifact['metadata'] ) - print( "Codes:", codes.shape, codes ) + _logger.info(f"Metadata: {artifact['metadata']}" ) + _logger.info(f"Codes: {codes.shape}, {codes}" ) # encode else: args.output = args.input.with_suffix(audio_extension) if not args.output else args.output.with_suffix(audio_extension) diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index e012a87..2173dca 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -17,6 +17,9 @@ from ..models.lora import apply_lora, lora_load_state_dict import torch import re +import logging + +_logger = logging.getLogger(__name__) deepspeed_available = False try: @@ -55,7 +58,7 @@ def
load_engines(training=True, **model_kwargs): checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) if not loads_state_dict and not checkpoint_path.exists() and load_path.exists(): - print("Checkpoint missing, but weights found:", load_path) + _logger.warning(f"Checkpoint missing, but weights found: {load_path}") loads_state_dict = True # load state early @@ -64,7 +67,7 @@ def load_engines(training=True, **model_kwargs): # check if config is defined in state, and re-initialize the model if "config" in state and False: - print("Model config definition in weights, re-loading...") + _logger.warning("Model config definition in weights, re-loading...") config_state = state["config"] model = get_model( config=cfg.model.__class__( *config_state ), training=training ) @@ -201,7 +204,7 @@ def load_engines(training=True, **model_kwargs): if cfg.lora is not None: lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) if lora_path.exists(): - print( "Loaded LoRA state dict:", lora_path ) + _logger.info( f"Loaded LoRA state dict: {lora_path}" ) state = torch_load(lora_path, device=cfg.device) state = state['lora' if 'lora' in state else 'module'] diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py index c62cfc0..4d80552 100755 --- a/vall_e/engines/base.py +++ b/vall_e/engines/base.py @@ -367,7 +367,7 @@ class Engines(dict[str, Engine]): state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path ) torch_save(state_dict, save_path) - print(f"Exported {name} to {save_path}") + _logger.info(f"Exported {name} to {save_path}") def save_checkpoint(self, tag=None): if not tag: @@ -385,7 +385,7 @@ class Engines(dict[str, Engine]): try: engine.save_checkpoint(save_dir, tag=tag) except Exception as e: - print(f'Failed to save checkpoint for engine {name}:', str(e)) + _logger.warning(f'Failed to save checkpoint for engine {name}: {str(e)}') # might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None] if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader(): @@ -395,7 +395,7 @@ class Engines(dict[str, Engine]): for d in checkpoints: if not d.is_dir() or not d.exists(): continue - print("Removing", d) + _logger.info(f"Removing {d}") for p in d.iterdir(): p.unlink() d.rmdir() @@ -490,7 +490,7 @@ class Engines(dict[str, Engine]): res = feeder( engine=engine, batch=batch ) break except RuntimeError as e: - print("Forward", str(e)) + _logger.error(f"Forward: {str(e)}") if "out of memory" not in str(e): self.save_checkpoint() @@ -532,7 +532,7 @@ class Engines(dict[str, Engine]): try: engine.backward(loss) except RuntimeError as e: - print("Backwards:", str(e)) + _logger.error(f"Backwards: {str(e)}") if "out of memory" not in str(e): self.save_checkpoint() diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py index e31da2d..4abe44f 100755 --- a/vall_e/engines/deepspeed.py +++ b/vall_e/engines/deepspeed.py @@ -117,7 +117,7 @@ class Engine(DeepSpeedEngine): else: self.optimizer.set_lr(lr) except Exception as e: - print(str(e)) + _logger.warning(str(e)) # we'll just have to live with the LoRA weights living within our main weights # they're easy to extract anyways diff --git a/vall_e/inference.py b/vall_e/inference.py index 79a91e0..8820531 100755 ---
a/vall_e/inference.py +++ b/vall_e/inference.py @@ -2,6 +2,9 @@ import torch import torchaudio import soundfile import time +import logging + +_logger = logging.getLogger(__name__) from torch import Tensor from einops import rearrange @@ -31,14 +34,13 @@ class TTS(): def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ): if config: - print("Loading YAML:", config) + _logger.info(f"Loading YAML: {config}") cfg.load_yaml( config ) try: cfg.format( training=False ) cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing except Exception as e: - print("Error while parsing config YAML:") raise e # throw an error because I'm tired of silent errors messing things up for me if amp is None: @@ -73,7 +75,7 @@ class TTS(): self.engines.eval() self.symmap = get_phone_symmap() - print("Loaded model") + _logger.info("Loaded model") def encode_text( self, text, language="en" ): # already a tensor, return it diff --git a/vall_e/models/__init__.py b/vall_e/models/__init__.py index 4e37356..c3f807a 100755 --- a/vall_e/models/__init__.py +++ b/vall_e/models/__init__.py @@ -1,3 +1,6 @@ +import logging + +_logger = logging.getLogger(__name__) def get_model(config, training=True, **model_kwargs): name = config.name @@ -53,7 +56,7 @@ def get_model(config, training=True, **model_kwargs): **model_kwargs ) - print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters") + _logger.info(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters") return model diff --git a/vall_e/models/ar.py b/vall_e/models/ar.py index 9a861be..c600015 100644 --- a/vall_e/models/ar.py +++ b/vall_e/models/ar.py @@ -15,6 +15,9 @@ import math from einops import rearrange from torch import Tensor from tqdm import trange +import logging + +_logger = logging.getLogger(__name__) from ..emb.qnt import trim, encode_as_embedding @@ -379,7 +382,7 @@ def example_usage(): else: raise ValueError(f"Unrecognized optimizer: {optimizer}") - print("Optimizer:", optimizer, "\tLearning rate:", learning_rate) + _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}") optimizer = optimizer(model.parameters(), lr=learning_rate) @@ -392,7 +395,7 @@ def example_usage(): scheduler = None if scheduler is not None: - print("Scheduler:", scheduler) + _logger.info(f"Scheduler: {scheduler}") optimizer = scheduler( model.parameters(), lr = learning_rate ) if cfg.optimizations.replace and cfg.optimizations.linear: @@ -425,7 +428,7 @@ def example_usage(): }, f"./data/{cfg.model.arch_type}.pth" ) """ - print(f"AR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") + _logger.info(f"AR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") @torch.no_grad() def sample_data(task=None): diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index da90e17..d1661f7 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -17,6 +17,9 @@ import math from einops import rearrange from torch import Tensor from tqdm import trange +import logging + +_logger = logging.getLogger(__name__) from ..emb.qnt import trim, encode_as_embedding @@ -434,7 +437,7 @@ def example_usage(): else: raise ValueError(f"Unrecognized optimizer: {optimizer}") - print("Optimizer:", optimizer, "\tLearning rate:", 
learning_rate) + _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}") optimizer = optimizer(model.parameters(), lr=learning_rate) @@ -447,7 +450,7 @@ def example_usage(): scheduler = None if scheduler is not None: - print("Scheduler:", scheduler) + _logger.info(f"Scheduler: {scheduler}") optimizer = scheduler( model.parameters(), lr = learning_rate ) if cfg.optimizations.replace and cfg.optimizations.linear: @@ -480,7 +483,7 @@ def example_usage(): }, f"./data/{cfg.model.arch_type}.pth" ) """ - print(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") + _logger.info(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") @torch.no_grad() def sample_data(task=None): diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py index f486e65..1cef053 100644 --- a/vall_e/models/arch/llama.py +++ b/vall_e/models/arch/llama.py @@ -2,6 +2,8 @@ import math import torch +import logging + from typing import Literal, overload, Optional, Tuple from torch import Tensor, nn @@ -10,6 +12,8 @@ from transformers.cache_utils import Cache from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv +_logger = logging.getLogger(__name__) + AVAILABLE_ATTENTIONS = [] try: @@ -18,7 +22,7 @@ try: if is_flash_attn_2_available(): AVAILABLE_ATTENTIONS.append("flash_attention_2") except Exception as e: - print("Error while querying for `flash_attention_2` support", e) + _logger.warning(f"Error while querying for `flash_attention_2` support: {str(e)}") try: from .attention.fused import attention as _fused_attention @@ -27,7 +31,7 @@ try: AVAILABLE_ATTENTIONS.append("fused_attn") except Exception as e: - print("Error while querying for `fused_attn` support", e) + _logger.warning(f"Error while querying for `fused_attn` support: {str(e)}") is_rocm = any("AMD" in torch.cuda.get_device_properties(i).name for i in range(torch.cuda.device_count())) @@ -99,7 +103,7 @@ try: has_flash_attn_with_paged = True except Exception as e: raise e - print("Error while querying for `flash_attn` support", e) + _logger.warning(f"Error while querying for `flash_attn` support: {str(e)}") try: from xformers.ops.fmha import memory_efficient_attention @@ -107,7 +111,7 @@ try: AVAILABLE_ATTENTIONS.append("xformers") except Exception as e: - print("Error while importing `xformers`", e) + _logger.warning(f"Error while importing `xformers`: {str(e)}") # to-do: find a better way to query for if there's available kernels since these return true regardless if torch.backends.cuda.flash_sdp_enabled(): diff --git a/vall_e/models/experimental.py b/vall_e/models/experimental.py index b652b26..1f6b724 100644 --- a/vall_e/models/experimental.py +++ b/vall_e/models/experimental.py @@ -20,6 +20,9 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult import random import math +import logging + +_logger = logging.getLogger(__name__) from einops import rearrange from tqdm import trange @@ -502,7 +505,7 @@ def example_usage(): else: raise ValueError(f"Unrecognized optimizer: {optimizer}") - print("Optimizer:", optimizer, "\tLearning rate:", learning_rate) + _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}") optimizer = optimizer(model.parameters(), lr=learning_rate) @@ -515,7 +518,7 @@ def example_usage(): scheduler = 
None if scheduler is not None: - print("Scheduler:", scheduler) + _logger.info(f"Scheduler: {scheduler}") optimizer = scheduler( model.parameters(), lr = learning_rate ) if cfg.optimizations.replace and cfg.optimizations.linear: @@ -532,7 +535,7 @@ def example_usage(): }, f"./data/{cfg.model.arch_type}.pth" ) """ - print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") + _logger.info(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") @torch.inference_mode() def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ): diff --git a/vall_e/models/nar.py b/vall_e/models/nar.py index b2ccc65..87f0fb4 100644 --- a/vall_e/models/nar.py +++ b/vall_e/models/nar.py @@ -19,6 +19,9 @@ from torch import Tensor from tqdm import trange from ..emb.qnt import trim +import logging + +_logger = logging.getLogger(__name__) class NAR(Base): def forward( @@ -361,7 +364,7 @@ def example_usage(): else: raise ValueError(f"Unrecognized optimizer: {optimizer}") - print("Optimizer:", optimizer, "\tLearning rate:", learning_rate) + _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}") optimizer = optimizer(model.parameters(), lr=learning_rate) @@ -374,7 +377,7 @@ def example_usage(): scheduler = None if scheduler is not None: - print("Scheduler:", scheduler) + _logger.info(f"Scheduler: {scheduler}") optimizer = scheduler( model.parameters(), lr = learning_rate ) if cfg.optimizations.replace and cfg.optimizations.linear: @@ -391,7 +394,7 @@ def example_usage(): }, f"./data/{cfg.model.arch_type}.pth" ) """ - print(f"NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") + _logger.info(f"NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") @torch.inference_mode() def sample( name, steps=1000 ): diff --git a/vall_e/train.py b/vall_e/train.py index 7d3a48e..0a60613 100755 --- a/vall_e/train.py +++ b/vall_e/train.py @@ -133,10 +133,7 @@ def run_eval(engines, eval_name, dl): } #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl) - if cfg.trainer.no_logger: - tqdm.write(f"Validation Metrics: {json.dumps(engines_stats)}.") - else: - _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.") + _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.") def train(): @@ -160,8 +157,8 @@ def train(): run_eval(engines, "subtrain", subtrain_dl) run_eval(engines, "val", val_dl) except Exception as e: - print("Error occurred while performing eval:", str(e)) - print(traceback.format_exc()) + _logger.warning(f"Error occurred while performing eval: {str(e)}") + _logger.warning(traceback.format_exc()) engines.train() qnt.unload_model() diff --git a/vall_e/utils/distributed.py b/vall_e/utils/distributed.py index 244c383..2400fdd 100755 --- a/vall_e/utils/distributed.py +++ b/vall_e/utils/distributed.py @@ -20,7 +20,6 @@ def get_free_port(): _distributed_initialized = False def init_distributed( fn, *args, **kwargs ): - #print("Initializing distributed...") torch.cuda.set_device(local_rank()) fn(*args, **kwargs) _distributed_initialized = True @@ -29,8 +28,6 @@ def distributed_initialized(): return _distributed_initialized def cleanup_distributed(): - #if not _distributed_initialized: - # return dist.barrier() dist.destroy_process_group() diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py index 599ecda..b77c9f9 100755 --- a/vall_e/utils/trainer.py +++ 
b/vall_e/utils/trainer.py @@ -177,10 +177,7 @@ def train( except Exception as e: metrics = str(stats) - if cfg.trainer.no_logger: - tqdm.write(f"Training Metrics: {truncate_json(metrics)}.") - else: - _logger.info(f"Training Metrics: {truncate_json(metrics)}.") + _logger.info(f"Training Metrics: {truncate_json(metrics)}.") command = _non_blocking_input() @@ -220,9 +217,9 @@ def train( rate = float(command.split(" ")[-1]) try: engines.set_lr(rate) - print("Updating LR to:", rate) + _logger.info(f"Updating LR to: {rate}") except Exception as e: - print("Failed to set LR rate to:", rate, str(e)) + _logger.warning(f"Failed to set LR rate to: {rate}, {str(e)}") if "export" in command: train_dl.dataset.save_state_dict() diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py index 5d5193d..a253848 100755 --- a/vall_e/utils/utils.py +++ b/vall_e/utils/utils.py @@ -14,6 +14,9 @@ import random import time import psutil import math +import logging + +_logger = logging.getLogger(__name__) from coloredlogs import ColoredFormatter from logging import StreamHandler @@ -296,7 +299,7 @@ def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ): ) if verbose: - print(f"Replacing {name}.{k} to", klass) + _logger.info(f"Replacing {name}.{k} with: {klass}") return model @@ -330,7 +333,7 @@ def replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ): ) if verbose: - print(f"Replacing {name}.{k} to", klass) + _logger.info(f"Replacing {name}.{k} with: {klass}") return model @@ -360,7 +363,7 @@ def replace_attention( model, klass, target, mode="math", verbose=False ): ) if verbose: - print(f"Replacing {name}.{k} to", klass) + _logger.info(f"Replacing {name}.{k} with: {klass}") return model @@ -491,7 +494,7 @@ def get_model_offload_policy(module, policy=None): # does not fit in budget, increase device index else: device_index += 1 - print(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB") + _logger.info(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB") # to-do: check that all modules are exhausted assert module_index >= len(modules) @@ -528,9 +531,9 @@ def offload_model( model, policy=None ): if not not [*module.named_children()]: continue try: - print( name, next(module.parameters()).device ) + _logger.info( f"{name} {next(module.parameters()).device}" ) except Exception as e: - print( name, "?" ) + _logger.info( f"{name} ?" 
) pass """ diff --git a/vall_e/utils/wrapper.py b/vall_e/utils/wrapper.py index d4fb780..2cbd14b 100755 --- a/vall_e/utils/wrapper.py +++ b/vall_e/utils/wrapper.py @@ -3,9 +3,12 @@ from contextlib import contextmanager import math import torch import torch.nn.functional as F +import logging from ..config import cfg +_logger = logging.getLogger(__name__) + Embedding = torch.nn.Embedding Linear = torch.nn.Linear @@ -95,7 +98,7 @@ if cfg.optimizations.tensorrt: import torch_tensorrt AVAILABLE_COMPILE_BACKENDS.append("tensorrt") except Exception as e: - print('Error while importing TensorRT:', str(e)) + _logger.warning(f'Error while importing TensorRT: {str(e)}') pass def compile_model(model, backend="auto"): @@ -111,14 +114,14 @@ def compile_model(model, backend="auto"): try: from prodigyopt import Prodigy except Exception as e: - print('Error while importing Prodigyopt:', str(e)) + _logger.warning(f'Error while importing Prodigyopt: {str(e)}') pass # https://github.com/facebookresearch/schedule_free/ try: import schedulefree except Exception as e: - print('Error while importing Schedule_Free:', str(e)) + _logger.warning(f'Error while importing Schedule_Free: {str(e)}') pass # backwards compat
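The pattern this patch standardizes on: each module creates its own logger via `logging.getLogger(__name__)` and builds messages eagerly with f-strings, since logger methods treat extra positional arguments as %-style format arguments rather than concatenating them the way `print` does. Below is a minimal sketch of that pattern, assuming handler and level configuration happens once at the application entry point (the existing `coloredlogs` setup in `vall_e/utils/utils.py` would serve that role); the `read_text` helper and the file path are placeholders, not functions from the repo.

import logging

# per-module logger, named after the importing module (e.g. "vall_e.emb.process")
_logger = logging.getLogger(__name__)

def read_text(path):
    try:
        with open(path) as f:
            return f.read()
    except Exception as e:
        # build the message up front with an f-string; passing extra positional
        # arguments (print-style) would be interpreted as %-format arguments
        _logger.warning(f"Error while opening file: {path}: {str(e)}")
        return None

if __name__ == "__main__":
    # assumption: a real entry point would configure handlers (coloredlogs, etc.)
    logging.basicConfig(level=logging.INFO)
    read_text("does-not-exist.txt")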