moved prints to use logger, edited readme (fused_attn doesn't seem stable for training)

This commit is contained in:
mrq 2024-08-29 13:27:16 -05:00
parent d423bc03c2
commit 32287710a2
21 changed files with 111 additions and 87 deletions

View File

@ -161,7 +161,7 @@ For audio backends:
* `xformers`: [facebookresearch/xformers](https://github.com/facebookresearch/xformers/)'s memory efficient attention * `xformers`: [facebookresearch/xformers](https://github.com/facebookresearch/xformers/)'s memory efficient attention
* `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper) * `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper)
* `flash_attn_v100`: uses [ZRayZzz/flash-attention-v100](https://github.com/ZRayZzz/flash-attention-v100/)'s Flash Attention for Volta (but doesn't work currently) * `flash_attn_v100`: uses [ZRayZzz/flash-attention-v100](https://github.com/ZRayZzz/flash-attention-v100/)'s Flash Attention for Volta (but doesn't work currently)
* `fused_attn`: uses an implementation using `triton` (only tested on my 7900XTX / Navi3 / gfx1100) * `fused_attn`: uses an implementation using `triton` (tested on my 7900XTX and V100s), but seems to introduce errors after a while when used for training
* `transformers` Llama\*Attention implementations: * `transformers` Llama\*Attention implementations:
* `eager`: default `LlamaAttention` * `eager`: default `LlamaAttention`
* `sdpa`: integrated `LlamaSdpaAttention` attention model * `sdpa`: integrated `LlamaSdpaAttention` attention model

View File

@ -9,6 +9,7 @@ import time
import argparse import argparse
import yaml import yaml
import random import random
import logging
import torch import torch
import numpy as np import numpy as np
@ -163,7 +164,8 @@ class Dataset:
sample_order: str = "interleaved" # duration sample_order: str = "interleaved" # duration
sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batched, 0 to disable sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batched, 0 to disable
# for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough # for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough
sample_shuffle: bool = True # # for a full sized model with 24GiB of VRAM for Encodec, 380 seconds is 80% VRAM consumed (but it might be limited by batch size)
sample_shuffle: bool = True # i swear this is spiking the loss when sample_order = duration + sample_max_duration_batch > 0
tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against
reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
@ -364,6 +366,7 @@ class Model:
return dict(include=include, exclude=exclude) return dict(include=include, exclude=exclude)
# should be renamed to Adapters
@dataclass() @dataclass()
class LoRA: class LoRA:
name: str = "lora" # vanity name name: str = "lora" # vanity name
@ -638,9 +641,6 @@ class Trainer:
def scale_loss(self): def scale_loss(self):
# currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways) # currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
return self.dtype == torch.float16 return self.dtype == torch.float16
"""
"""
@dataclass() @dataclass()
class Inference: class Inference:
@ -670,7 +670,6 @@ class Inference:
return torch.float8_e4m3fn return torch.float8_e4m3fn
return torch.float32 return torch.float32
# should be renamed to optimizations
@dataclass() @dataclass()
class Optimizations: class Optimizations:
injects: bool = False # overwrites default torch classes (not recommended) injects: bool = False # overwrites default torch classes (not recommended)
@ -755,6 +754,7 @@ class Config(BaseConfig):
return self.models[0] if len(self.models) > 0 else None return self.models[0] if len(self.models) > 0 else None
# should be renamed to adapters
@property @property
def lora(self): def lora(self):
for i, lora in enumerate(self.loras): for i, lora in enumerate(self.loras):
@ -795,7 +795,7 @@ class Config(BaseConfig):
try: try:
self.hdf5 = h5py.File(f'{self.rel_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset self.hdf5 = h5py.File(f'{self.rel_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
except Exception as e: except Exception as e:
print("Error while opening HDF5 file:", f'{self.rel_path}/{self.dataset.hdf5_name}', str(e)) _logger.warning(f"Error while opening HDF5 file: {self.rel_path}/{self.dataset.hdf5_name}: {str(e)}")
self.dataset.use_hdf5 = False self.dataset.use_hdf5 = False
# to-do: prune unused keys # to-do: prune unused keys
@ -923,7 +923,7 @@ class Config(BaseConfig):
cfg.tokenizer = NaiveTokenizer() cfg.tokenizer = NaiveTokenizer()
except Exception as e: except Exception as e:
cfg.tokenizer = NaiveTokenizer() cfg.tokenizer = NaiveTokenizer()
print("Error while parsing tokenizer:", e) _logger.warning(f"Error while parsing tokenizer: {str(e)}")
pass pass
@ -960,6 +960,7 @@ class NaiveTokenizer:
# tokenize # tokenize
return [*map(symmap.get, phones)] return [*map(symmap.get, phones)]
_logger = logging.getLogger(__name__)
cfg = Config.from_cli() cfg = Config.from_cli()
@ -967,7 +968,7 @@ cfg = Config.from_cli()
try: try:
cfg.format() cfg.format()
except Exception as e: except Exception as e:
print("Error while parsing config YAML:") _logger.error(f"Error while parsing config YAML: {str(e)}")
raise e # throw an error because I'm tired of silent errors messing things up for me raise e # throw an error because I'm tired of silent errors messing things up for me
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -495,7 +495,6 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
def _get_hdf5_path(path): def _get_hdf5_path(path):
# to-do: better validation # to-do: better validation
#print(path)
return str(path) return str(path)
def _get_hdf5_paths( data_dir, type="training", validate=False ): def _get_hdf5_paths( data_dir, type="training", validate=False ):
@ -1543,12 +1542,6 @@ if __name__ == "__main__":
cfg.dataset.workers = 1 cfg.dataset.workers = 1
class LoggerOveride:
def info(self, *args):
print(*args)
_logger = LoggerOveride()
if args.action == "hdf5": if args.action == "hdf5":
create_dataset_hdf5() create_dataset_hdf5()
elif args.action == "list-dataset": elif args.action == "list-dataset":
@ -1559,7 +1552,7 @@ if __name__ == "__main__":
continue continue
dataset.append(f'{group}/{name}') dataset.append(f'{group}/{name}')
print(json.dumps(dataset)) _logger.info(json.dumps(dataset))
elif args.action == "metadata": elif args.action == "metadata":
create_dataset_metadata() create_dataset_metadata()
elif args.action == "sample": elif args.action == "sample":
@ -1581,17 +1574,17 @@ if __name__ == "__main__":
try: try:
decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" ) decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
except Exception as e: except Exception as e:
print(f"Error while decoding prom {k}.{i}.{j}.wav:", str(e)) _logger.info(f"Error while decoding prom {k}.{i}.{j}.wav: {str(e)}")
try: try:
decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" ) decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
except Exception as e: except Exception as e:
print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e)) _logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}")
v[i]['proms'][j] = v[i]['proms'][j].shape v[i]['proms'][j] = v[i]['proms'][j].shape
v[i]['resps'][j] = v[i]['resps'][j].shape v[i]['resps'][j] = v[i]['resps'][j].shape
for k, v in samples.items(): for k, v in samples.items():
for i in range(len(v)): for i in range(len(v)):
print(f'{k}[{i}]:', v[i]) _logger.info(f'{k}[{i}]: {v[i]}')
elif args.action == "validate": elif args.action == "validate":
train_dl, subtrain_dl, val_dl = create_train_val_dataloader() train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
@ -1610,11 +1603,11 @@ if __name__ == "__main__":
phone = phonemes[i] phone = phonemes[i]
print( batch['text'], batch['metadata']['phonemes'] ) _logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" )
missing |= set([phone]) missing |= set([phone])
print( "Missing tokens:", missing ) _logger.info( f"Missing tokens: {missing}" )
elif args.action == "tasks": elif args.action == "tasks":
@ -1628,13 +1621,13 @@ if __name__ == "__main__":
if task not in cfg.dataset.tasks_list: if task not in cfg.dataset.tasks_list:
continue continue
print(text, task, cfg.model.resp_levels) _logger.info( f'{text} {task} {cfg.model.resp_levels}')
print( proms.shape, resps.shape ) _logger.info( f'{proms.shape} {resps.shape}' )
tokens = 0 tokens = 0
tokens += sum([ text.shape[0] for text in batch["text"] ]) tokens += sum([ text.shape[0] for text in batch["text"] ])
tokens += sum([ resps.shape[0] for resps in batch["resps"] ]) tokens += sum([ resps.shape[0] for resps in batch["resps"] ])
print( tokens ) _logger.info( f'{tokens}' )
decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" ) decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" ) decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )

View File

@ -18,6 +18,9 @@ Will also generate samples from a provided dataset, if requested.
import argparse import argparse
import base64 import base64
import random import random
import logging
_logger = logging.getLogger(__name__)
from pathlib import Path from pathlib import Path
@ -117,9 +120,9 @@ def main():
samples_dirs["dataset"] = args.demo_dir / "dataset" samples_dirs["dataset"] = args.demo_dir / "dataset"
print("Loading dataloader...") _logger.info("Loading dataloader...")
dataloader = create_train_dataloader() dataloader = create_train_dataloader()
print("Loaded dataloader.") _logger.info("Loaded dataloader.")
num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size

View File

@ -9,6 +9,9 @@ import argparse
import torch import torch
import torchaudio import torchaudio
import numpy as np import numpy as np
import logging
_logger = logging.getLogger(__name__)
from tqdm.auto import tqdm from tqdm.auto import tqdm
from pathlib import Path from pathlib import Path
@ -78,7 +81,7 @@ def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
try: try:
process_job( outpath, waveform, sample_rate, text, language ) process_job( outpath, waveform, sample_rate, text, language )
except Exception as e: except Exception as e:
print(f"Failed to quantize: {outpath}:", e) _logger.error(f"Failed to quantize: {outpath}: {str(e)}")
if raise_exceptions: if raise_exceptions:
raise e raise e
continue continue
@ -128,7 +131,7 @@ def process(
for group_name in sorted(os.listdir(f'./{input_audio}/')): for group_name in sorted(os.listdir(f'./{input_audio}/')):
if not os.path.isdir(f'./{input_audio}/{group_name}/'): if not os.path.isdir(f'./{input_audio}/{group_name}/'):
print("Is not dir:", f'./{input_audio}/{group_name}/') _logger.warning(f'Is not dir:" /{input_audio}/{group_name}/')
continue continue
if group_name in ignore_groups: if group_name in ignore_groups:
@ -138,7 +141,7 @@ def process(
for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {group_name}"): for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {group_name}"):
if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'): if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'):
print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}') _logger.warning(f'Is not dir: ./{input_audio}/{group_name}/{speaker_id}')
continue continue
if speaker_id in ignore_speakers: if speaker_id in ignore_speakers:

View File

@ -6,6 +6,9 @@ import math
import torch import torch
import torchaudio import torchaudio
import numpy as np import numpy as np
import logging
_logger = logging.getLogger(__name__)
from functools import cache from functools import cache
from pathlib import Path from pathlib import Path
@ -203,7 +206,7 @@ try:
except Exception as e: except Exception as e:
cfg.inference.use_dac = False cfg.inference.use_dac = False
print(str(e)) _logger.warning(str(e))
# uses https://github.com/facebookresearch/AudioDec/ # uses https://github.com/facebookresearch/AudioDec/
# I have set up a pip-ify'd version with the caveat of having to manually handle downloading the checkpoints with a wget + unzip # I have set up a pip-ify'd version with the caveat of having to manually handle downloading the checkpoints with a wget + unzip
@ -213,7 +216,7 @@ try:
from audiodec.utils.audiodec import AudioDec, assign_model as _audiodec_assign_model from audiodec.utils.audiodec import AudioDec, assign_model as _audiodec_assign_model
except Exception as e: except Exception as e:
cfg.inference.use_audiodec = False cfg.inference.use_audiodec = False
print(str(e)) _logger.warning(str(e))
""" """
@cache @cache
@ -747,8 +750,8 @@ if __name__ == "__main__":
if args.print: if args.print:
torch.set_printoptions(profile="full") torch.set_printoptions(profile="full")
print( "Metadata:", artifact['metadata'] ) _logger.info(f"Metadata: {artifact['metadata']}" )
print( "Codes:", codes.shape, codes ) _logger.info(f"Codes: {codes.shape}, {codes}" )
# encode # encode
else: else:
args.output = args.input.with_suffix(audio_extension) if not args.output else args.output.with_suffix(audio_extension) args.output = args.input.with_suffix(audio_extension) if not args.output else args.output.with_suffix(audio_extension)

View File

@ -17,6 +17,9 @@ from ..models.lora import apply_lora, lora_load_state_dict
import torch import torch
import re import re
import logging
_logger = logging.getLogger(__name__)
deepspeed_available = False deepspeed_available = False
try: try:
@ -55,7 +58,7 @@ def load_engines(training=True, **model_kwargs):
checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists(): if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
print("Checkpoint missing, but weights found:", load_path) _logger.warning("Checkpoint missing, but weights found:", load_path)
loads_state_dict = True loads_state_dict = True
# load state early # load state early
@ -64,7 +67,7 @@ def load_engines(training=True, **model_kwargs):
# check if config is defined in state, and re-initialize the model # check if config is defined in state, and re-initialize the model
if "config" in state and False: if "config" in state and False:
print("Model config definition in weights, re-loading...") _logger.warning("Model config definition in weights, re-loading...")
config_state = state["config"] config_state = state["config"]
model = get_model( config=cfg.model.__class__( *config_state ), training=training ) model = get_model( config=cfg.model.__class__( *config_state ), training=training )
@ -201,7 +204,7 @@ def load_engines(training=True, **model_kwargs):
if cfg.lora is not None: if cfg.lora is not None:
lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] ) lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
if lora_path.exists(): if lora_path.exists():
print( "Loaded LoRA state dict:", lora_path ) _logger.info( "Loaded LoRA state dict:", lora_path )
state = torch_load(lora_path, device=cfg.device) state = torch_load(lora_path, device=cfg.device)
state = state['lora' if 'lora' in state else 'module'] state = state['lora' if 'lora' in state else 'module']

View File

@ -367,7 +367,7 @@ class Engines(dict[str, Engine]):
state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path ) state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
torch_save(state_dict, save_path) torch_save(state_dict, save_path)
print(f"Exported {name} to {save_path}") _logger.info(f"Exported {name} to {save_path}")
def save_checkpoint(self, tag=None): def save_checkpoint(self, tag=None):
if not tag: if not tag:
@ -385,7 +385,7 @@ class Engines(dict[str, Engine]):
try: try:
engine.save_checkpoint(save_dir, tag=tag) engine.save_checkpoint(save_dir, tag=tag)
except Exception as e: except Exception as e:
print(f'Failed to save checkpoint for engine {name}:', str(e)) _logger.warning(f'Failed to save checkpoint for engine {name}:', str(e))
# might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None] # might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader(): if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
@ -395,7 +395,7 @@ class Engines(dict[str, Engine]):
for d in checkpoints: for d in checkpoints:
if not d.is_dir() or not d.exists(): if not d.is_dir() or not d.exists():
continue continue
print("Removing", d) _logger.info("Removing", d)
for p in d.iterdir(): for p in d.iterdir():
p.unlink() p.unlink()
d.rmdir() d.rmdir()
@ -490,7 +490,7 @@ class Engines(dict[str, Engine]):
res = feeder( engine=engine, batch=batch ) res = feeder( engine=engine, batch=batch )
break break
except RuntimeError as e: except RuntimeError as e:
print("Forward", str(e)) _logger.error("Forward", str(e))
if "out of memory" not in str(e): if "out of memory" not in str(e):
self.save_checkpoint() self.save_checkpoint()
@ -532,7 +532,7 @@ class Engines(dict[str, Engine]):
try: try:
engine.backward(loss) engine.backward(loss)
except RuntimeError as e: except RuntimeError as e:
print("Backwards:", str(e)) _logger.error("Backwards:", str(e))
if "out of memory" not in str(e): if "out of memory" not in str(e):
self.save_checkpoint() self.save_checkpoint()

View File

@ -117,7 +117,7 @@ class Engine(DeepSpeedEngine):
else: else:
self.optimizer.set_lr(lr) self.optimizer.set_lr(lr)
except Exception as e: except Exception as e:
print(str(e)) _logger.warning(str(e))
# we'll just have to live with the LoRA weights living within our main weights # we'll just have to live with the LoRA weights living within our main weights
# they're easy to extract anyways # they're easy to extract anyways

View File

@ -2,6 +2,9 @@ import torch
import torchaudio import torchaudio
import soundfile import soundfile
import time import time
import logging
_logger = logging.getLogger(__name__)
from torch import Tensor from torch import Tensor
from einops import rearrange from einops import rearrange
@ -31,14 +34,13 @@ class TTS():
def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ): def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
if config: if config:
print("Loading YAML:", config) _logger.info(f"Loading YAML: {config}")
cfg.load_yaml( config ) cfg.load_yaml( config )
try: try:
cfg.format( training=False ) cfg.format( training=False )
cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
except Exception as e: except Exception as e:
print("Error while parsing config YAML:")
raise e # throw an error because I'm tired of silent errors messing things up for me raise e # throw an error because I'm tired of silent errors messing things up for me
if amp is None: if amp is None:
@ -73,7 +75,7 @@ class TTS():
self.engines.eval() self.engines.eval()
self.symmap = get_phone_symmap() self.symmap = get_phone_symmap()
print("Loaded model") _logger.info("Loaded model")
def encode_text( self, text, language="en" ): def encode_text( self, text, language="en" ):
# already a tensor, return it # already a tensor, return it

View File

@ -1,3 +1,6 @@
import logging
_logger = logging.getLogger(__name__)
def get_model(config, training=True, **model_kwargs): def get_model(config, training=True, **model_kwargs):
name = config.name name = config.name
@ -53,7 +56,7 @@ def get_model(config, training=True, **model_kwargs):
**model_kwargs **model_kwargs
) )
print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters") _logger.info(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
return model return model

View File

@ -15,6 +15,9 @@ import math
from einops import rearrange from einops import rearrange
from torch import Tensor from torch import Tensor
from tqdm import trange from tqdm import trange
import logging
_logger = logging.getLogger(__name__)
from ..emb.qnt import trim, encode_as_embedding from ..emb.qnt import trim, encode_as_embedding
@ -379,7 +382,7 @@ def example_usage():
else: else:
raise ValueError(f"Unrecognized optimizer: {optimizer}") raise ValueError(f"Unrecognized optimizer: {optimizer}")
print("Optimizer:", optimizer, "\tLearning rate:", learning_rate) _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
optimizer = optimizer(model.parameters(), lr=learning_rate) optimizer = optimizer(model.parameters(), lr=learning_rate)
@ -392,7 +395,7 @@ def example_usage():
scheduler = None scheduler = None
if scheduler is not None: if scheduler is not None:
print("Scheduler:", scheduler) _logger.info(f"Scheduler: {scheduler}")
optimizer = scheduler( model.parameters(), lr = learning_rate ) optimizer = scheduler( model.parameters(), lr = learning_rate )
if cfg.optimizations.replace and cfg.optimizations.linear: if cfg.optimizations.replace and cfg.optimizations.linear:
@ -425,7 +428,7 @@ def example_usage():
}, f"./data/{cfg.model.arch_type}.pth" ) }, f"./data/{cfg.model.arch_type}.pth" )
""" """
print(f"AR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") _logger.info(f"AR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.no_grad() @torch.no_grad()
def sample_data(task=None): def sample_data(task=None):

View File

@ -17,6 +17,9 @@ import math
from einops import rearrange from einops import rearrange
from torch import Tensor from torch import Tensor
from tqdm import trange from tqdm import trange
import logging
_logger = logging.getLogger(__name__)
from ..emb.qnt import trim, encode_as_embedding from ..emb.qnt import trim, encode_as_embedding
@ -434,7 +437,7 @@ def example_usage():
else: else:
raise ValueError(f"Unrecognized optimizer: {optimizer}") raise ValueError(f"Unrecognized optimizer: {optimizer}")
print("Optimizer:", optimizer, "\tLearning rate:", learning_rate) _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
optimizer = optimizer(model.parameters(), lr=learning_rate) optimizer = optimizer(model.parameters(), lr=learning_rate)
@ -447,7 +450,7 @@ def example_usage():
scheduler = None scheduler = None
if scheduler is not None: if scheduler is not None:
print("Scheduler:", scheduler) _logger.info(f"Scheduler: {scheduler}")
optimizer = scheduler( model.parameters(), lr = learning_rate ) optimizer = scheduler( model.parameters(), lr = learning_rate )
if cfg.optimizations.replace and cfg.optimizations.linear: if cfg.optimizations.replace and cfg.optimizations.linear:
@ -480,7 +483,7 @@ def example_usage():
}, f"./data/{cfg.model.arch_type}.pth" ) }, f"./data/{cfg.model.arch_type}.pth" )
""" """
print(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") _logger.info(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.no_grad() @torch.no_grad()
def sample_data(task=None): def sample_data(task=None):

View File

@ -2,6 +2,8 @@
import math import math
import torch import torch
import logging
from typing import Literal, overload, Optional, Tuple from typing import Literal, overload, Optional, Tuple
from torch import Tensor, nn from torch import Tensor, nn
@ -10,6 +12,8 @@ from transformers.cache_utils import Cache
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
_logger = logging.getLogger(__name__)
AVAILABLE_ATTENTIONS = [] AVAILABLE_ATTENTIONS = []
try: try:
@ -18,7 +22,7 @@ try:
if is_flash_attn_2_available(): if is_flash_attn_2_available():
AVAILABLE_ATTENTIONS.append("flash_attention_2") AVAILABLE_ATTENTIONS.append("flash_attention_2")
except Exception as e: except Exception as e:
print("Error while querying for `flash_attention_2` support", e) _logger.warning(f"Error while querying for `flash_attention_2` support: {str(e)}")
try: try:
from .attention.fused import attention as _fused_attention from .attention.fused import attention as _fused_attention
@ -27,7 +31,7 @@ try:
AVAILABLE_ATTENTIONS.append("fused_attn") AVAILABLE_ATTENTIONS.append("fused_attn")
except Exception as e: except Exception as e:
print("Error while querying for `fused_attn` support", e) _logger.warning(f"Error while querying for `fused_attn` support: {str(e)}")
is_rocm = any("AMD" in torch.cuda.get_device_properties(i).name for i in range(torch.cuda.device_count())) is_rocm = any("AMD" in torch.cuda.get_device_properties(i).name for i in range(torch.cuda.device_count()))
@ -99,7 +103,7 @@ try:
has_flash_attn_with_paged = True has_flash_attn_with_paged = True
except Exception as e: except Exception as e:
raise e raise e
print("Error while querying for `flash_attn` support", e) _logger.warning(f"Error while querying for `flash_attn` support: {str(e)}")
try: try:
from xformers.ops.fmha import memory_efficient_attention from xformers.ops.fmha import memory_efficient_attention
@ -107,7 +111,7 @@ try:
AVAILABLE_ATTENTIONS.append("xformers") AVAILABLE_ATTENTIONS.append("xformers")
except Exception as e: except Exception as e:
print("Error while importing `xformers`", e) _logger.warning(f"Error while importing `xformers`: {str(e)}")
# to-do: find a better way to query for if there's available kernels since these return true regardless # to-do: find a better way to query for if there's available kernels since these return true regardless
if torch.backends.cuda.flash_sdp_enabled(): if torch.backends.cuda.flash_sdp_enabled():

View File

@ -20,6 +20,9 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
import random import random
import math import math
import logging
_logger = logging.getLogger(__name__)
from einops import rearrange from einops import rearrange
from tqdm import trange from tqdm import trange
@ -502,7 +505,7 @@ def example_usage():
else: else:
raise ValueError(f"Unrecognized optimizer: {optimizer}") raise ValueError(f"Unrecognized optimizer: {optimizer}")
print("Optimizer:", optimizer, "\tLearning rate:", learning_rate) _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
optimizer = optimizer(model.parameters(), lr=learning_rate) optimizer = optimizer(model.parameters(), lr=learning_rate)
@ -515,7 +518,7 @@ def example_usage():
scheduler = None scheduler = None
if scheduler is not None: if scheduler is not None:
print("Scheduler:", scheduler) _logger.info(f"Scheduler: {scheduler}")
optimizer = scheduler( model.parameters(), lr = learning_rate ) optimizer = scheduler( model.parameters(), lr = learning_rate )
if cfg.optimizations.replace and cfg.optimizations.linear: if cfg.optimizations.replace and cfg.optimizations.linear:
@ -532,7 +535,7 @@ def example_usage():
}, f"./data/{cfg.model.arch_type}.pth" ) }, f"./data/{cfg.model.arch_type}.pth" )
""" """
print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") _logger.info(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.inference_mode() @torch.inference_mode()
def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ): def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):

View File

@ -19,6 +19,9 @@ from torch import Tensor
from tqdm import trange from tqdm import trange
from ..emb.qnt import trim from ..emb.qnt import trim
import logging
_logger = logging.getLogger(__name__)
class NAR(Base): class NAR(Base):
def forward( def forward(
@ -361,7 +364,7 @@ def example_usage():
else: else:
raise ValueError(f"Unrecognized optimizer: {optimizer}") raise ValueError(f"Unrecognized optimizer: {optimizer}")
print("Optimizer:", optimizer, "\tLearning rate:", learning_rate) _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
optimizer = optimizer(model.parameters(), lr=learning_rate) optimizer = optimizer(model.parameters(), lr=learning_rate)
@ -374,7 +377,7 @@ def example_usage():
scheduler = None scheduler = None
if scheduler is not None: if scheduler is not None:
print("Scheduler:", scheduler) _logger.info(f"Scheduler: {scheduler}")
optimizer = scheduler( model.parameters(), lr = learning_rate ) optimizer = scheduler( model.parameters(), lr = learning_rate )
if cfg.optimizations.replace and cfg.optimizations.linear: if cfg.optimizations.replace and cfg.optimizations.linear:
@ -391,7 +394,7 @@ def example_usage():
}, f"./data/{cfg.model.arch_type}.pth" ) }, f"./data/{cfg.model.arch_type}.pth" )
""" """
print(f"NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") _logger.info(f"NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.inference_mode() @torch.inference_mode()
def sample( name, steps=1000 ): def sample( name, steps=1000 ):

View File

@ -133,9 +133,6 @@ def run_eval(engines, eval_name, dl):
} }
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl) #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
if cfg.trainer.no_logger:
tqdm.write(f"Validation Metrics: {json.dumps(engines_stats)}.")
else:
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.") _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
@ -160,8 +157,8 @@ def train():
run_eval(engines, "subtrain", subtrain_dl) run_eval(engines, "subtrain", subtrain_dl)
run_eval(engines, "val", val_dl) run_eval(engines, "val", val_dl)
except Exception as e: except Exception as e:
print("Error occurred while performing eval:", str(e)) _logger.warning(f"Error occurred while performing eval: {str(e)}")
print(traceback.format_exc()) _logger.warning(traceback.format_exc())
engines.train() engines.train()
qnt.unload_model() qnt.unload_model()

View File

@ -20,7 +20,6 @@ def get_free_port():
_distributed_initialized = False _distributed_initialized = False
def init_distributed( fn, *args, **kwargs ): def init_distributed( fn, *args, **kwargs ):
#print("Initializing distributed...")
torch.cuda.set_device(local_rank()) torch.cuda.set_device(local_rank())
fn(*args, **kwargs) fn(*args, **kwargs)
_distributed_initialized = True _distributed_initialized = True
@ -29,8 +28,6 @@ def distributed_initialized():
return _distributed_initialized return _distributed_initialized
def cleanup_distributed(): def cleanup_distributed():
#if not _distributed_initialized:
# return
dist.barrier() dist.barrier()
dist.destroy_process_group() dist.destroy_process_group()

View File

@ -177,9 +177,6 @@ def train(
except Exception as e: except Exception as e:
metrics = str(stats) metrics = str(stats)
if cfg.trainer.no_logger:
tqdm.write(f"Training Metrics: {truncate_json(metrics)}.")
else:
_logger.info(f"Training Metrics: {truncate_json(metrics)}.") _logger.info(f"Training Metrics: {truncate_json(metrics)}.")
command = _non_blocking_input() command = _non_blocking_input()
@ -220,9 +217,9 @@ def train(
rate = float(command.split(" ")[-1]) rate = float(command.split(" ")[-1])
try: try:
engines.set_lr(rate) engines.set_lr(rate)
print("Updating LR to:", rate) _logger.info(f"Updating LR to: {rate}")
except Exception as e: except Exception as e:
print("Failed to set LR rate to:", rate, str(e)) _logger.warning(f"Failed to set LR rate to: {rate}, {str(e)}")
if "export" in command: if "export" in command:
train_dl.dataset.save_state_dict() train_dl.dataset.save_state_dict()

View File

@ -14,6 +14,9 @@ import random
import time import time
import psutil import psutil
import math import math
import logging
_logger = logging.getLogger(__name__)
from coloredlogs import ColoredFormatter from coloredlogs import ColoredFormatter
from logging import StreamHandler from logging import StreamHandler
@ -296,7 +299,7 @@ def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ):
) )
if verbose: if verbose:
print(f"Replacing {name}.{k} to", klass) _logger.info(f"Replacing {name}.{k} to: {klass}")
return model return model
@ -330,7 +333,7 @@ def replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ):
) )
if verbose: if verbose:
print(f"Replacing {name}.{k} to", klass) _logger.info(f"Replacing {name}.{k} to: {klass}")
return model return model
@ -360,7 +363,7 @@ def replace_attention( model, klass, target, mode="math", verbose=False ):
) )
if verbose: if verbose:
print(f"Replacing {name}.{k} to", klass) _logger.info(f"Replacing {name}.{k} to: {klass}")
return model return model
@ -491,7 +494,7 @@ def get_model_offload_policy(module, policy=None):
# does not fit in budget, increase device index # does not fit in budget, increase device index
else: else:
device_index += 1 device_index += 1
print(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB") _logger.info(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")
# to-do: check that all modules are exhausted # to-do: check that all modules are exhausted
assert module_index >= len(modules) assert module_index >= len(modules)
@ -528,9 +531,9 @@ def offload_model( model, policy=None ):
if not not [*module.named_children()]: if not not [*module.named_children()]:
continue continue
try: try:
print( name, next(module.parameters()).device ) _logger.info( name, next(module.parameters()).device )
except Exception as e: except Exception as e:
print( name, "?" ) _logger.info( name, "?" )
pass pass
""" """

View File

@ -3,9 +3,12 @@ from contextlib import contextmanager
import math import math
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import logging
from ..config import cfg from ..config import cfg
_logger = logging.getLogger(__name__)
Embedding = torch.nn.Embedding Embedding = torch.nn.Embedding
Linear = torch.nn.Linear Linear = torch.nn.Linear
@ -95,7 +98,7 @@ if cfg.optimizations.tensorrt:
import torch_tensorrt import torch_tensorrt
AVAILABLE_COMPILE_BACKENDS.append("tensorrt") AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
except Exception as e: except Exception as e:
print('Error while importing TensorRT:', str(e)) _logger.warning(f'Error while importing TensorRT: {str(e)}')
pass pass
def compile_model(model, backend="auto"): def compile_model(model, backend="auto"):
@ -111,14 +114,14 @@ def compile_model(model, backend="auto"):
try: try:
from prodigyopt import Prodigy from prodigyopt import Prodigy
except Exception as e: except Exception as e:
print('Error while importing Prodigyopt:', str(e)) _logger.warning(f'Error while importing Prodigyopt: {str(e)}')
pass pass
# https://github.com/facebookresearch/schedule_free/ # https://github.com/facebookresearch/schedule_free/
try: try:
import schedulefree import schedulefree
except Exception as e: except Exception as e:
print('Error while importing Schedule_Free:', str(e)) _logger.warning(f'Error while importing Schedule_Free: {str(e)}')
pass pass
# backwards compat # backwards compat