moved prints to use logger, edited readme (fused_attn doesn't seem stable for training)
commit 32287710a2
parent d423bc03c2
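The change is mechanical but worth spelling out: each touched module gains a module-level logger, and its bare `print` calls become leveled log calls. A minimal sketch of the pattern, assuming a `basicConfig` call somewhere at startup (the diff below only shows the per-module half):

```python
import logging

# per-module logger, named after the module (the pattern every hunk below adds)
_logger = logging.getLogger(__name__)

# assumption: some entry point configures a handler once; without this, records
# below WARNING are dropped by the stdlib's last-resort handler
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(levelname)s %(message)s")

_logger.info("Loaded model")                       # was: print("Loaded model")
_logger.warning("Error while opening HDF5 file")   # error paths get warning/error levels
```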
@@ -161,7 +161,7 @@ For audio backends:
 * `xformers`: [facebookresearch/xformers](https://github.com/facebookresearch/xformers/)'s memory efficient attention
 * `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper)
 * `flash_attn_v100`: uses [ZRayZzz/flash-attention-v100](https://github.com/ZRayZzz/flash-attention-v100/)'s Flash Attention for Volta (but doesn't work currently)
-* `fused_attn`: uses an implementation using `triton` (only tested on my 7900XTX / Navi3 / gfx1100)
+* `fused_attn`: uses an implementation using `triton` (tested on my 7900XTX and V100s), but seems to introduce errors when used to train after a while
 * `transformers` Llama\*Attention implementations:
     * `eager`: default `LlamaAttention`
     * `sdpa`: integrated `LlamaSdpaAttention` attention model
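These backends are probed at import time further down in this commit, each probe appending to `AVAILABLE_ATTENTIONS` inside a `try`/`except`. A hedged sketch of resolving a preference order against that probe result; the `pick_attention` helper is illustrative and not part of the repo:

```python
AVAILABLE_ATTENTIONS = []  # filled by the import-time try/except probes shown below

def pick_attention(preferred=("flash_attention_2", "xformers", "fused_attn")):
    # illustrative only: fall through to the first backend that probed successfully;
    # per the README note above, fused_attn may be better left out of the list for training
    for backend in preferred:
        if backend in AVAILABLE_ATTENTIONS:
            return backend
    return "eager"  # transformers' default LlamaAttention
```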
@@ -9,6 +9,7 @@ import time
 import argparse
 import yaml
 import random
+import logging

 import torch
 import numpy as np
@@ -163,7 +164,8 @@ class Dataset:
     sample_order: str = "interleaved" # duration
     sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batched, 0 to disable
     # for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough
-    sample_shuffle: bool = True #
+    # for a full sized model with 24GiB of VRAM for Encodec, 380 seconds is 80% VRAM consumed (but it might be limited by batch size)
+    sample_shuffle: bool = True # i swear this is spiking the loss when sample_order = duration + sample_max_duration_batch > 0

     tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against
     reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
@@ -364,6 +366,7 @@ class Model:

         return dict(include=include, exclude=exclude)

+# should be renamed to Adapters
 @dataclass()
 class LoRA:
     name: str = "lora" # vanity name
@@ -638,9 +641,6 @@ class Trainer:
     def scale_loss(self):
         # currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
         return self.dtype == torch.float16
-    """
-    """
-

 @dataclass()
 class Inference:
@@ -670,7 +670,6 @@ class Inference:
         return torch.float8_e4m3fn
         return torch.float32
-
 # should be renamed to optimizations
 @dataclass()
 class Optimizations:
     injects: bool = False # overwrites default torch classes (not recommended)
@@ -755,6 +754,7 @@ class Config(BaseConfig):

         return self.models[0] if len(self.models) > 0 else None

+    # should be renamed to adapters
     @property
     def lora(self):
         for i, lora in enumerate(self.loras):
@@ -795,7 +795,7 @@ class Config(BaseConfig):
         try:
             self.hdf5 = h5py.File(f'{self.rel_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
         except Exception as e:
-            print("Error while opening HDF5 file:", f'{self.rel_path}/{self.dataset.hdf5_name}', str(e))
+            _logger.warning(f"Error while opening HDF5 file: {self.rel_path}/{self.dataset.hdf5_name}: {str(e)}")
             self.dataset.use_hdf5 = False

         # to-do: prune unused keys
@@ -923,7 +923,7 @@ class Config(BaseConfig):
     cfg.tokenizer = NaiveTokenizer()
 except Exception as e:
     cfg.tokenizer = NaiveTokenizer()
-    print("Error while parsing tokenizer:", e)
+    _logger.warning(f"Error while parsing tokenizer: {str(e)}")
     pass


@@ -960,6 +960,7 @@ class NaiveTokenizer:
         # tokenize
         return [*map(symmap.get, phones)]

+_logger = logging.getLogger(__name__)

 cfg = Config.from_cli()

@@ -967,7 +968,7 @@ cfg = Config.from_cli()
 try:
     cfg.format()
 except Exception as e:
-    print("Error while parsing config YAML:")
+    _logger.error(f"Error while parsing config YAML: {str(e)}")
     raise e # throw an error because I'm tired of silent errors messing things up for me

 if __name__ == "__main__":

@@ -495,7 +495,6 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):

 def _get_hdf5_path(path):
     # to-do: better validation
-    #print(path)
     return str(path)

 def _get_hdf5_paths( data_dir, type="training", validate=False ):
@@ -1543,12 +1542,6 @@ if __name__ == "__main__":

     cfg.dataset.workers = 1

-    class LoggerOveride:
-        def info(self, *args):
-            print(*args)
-
-    _logger = LoggerOveride()
-
     if args.action == "hdf5":
         create_dataset_hdf5()
     elif args.action == "list-dataset":
@@ -1559,7 +1552,7 @@ if __name__ == "__main__":
                 continue
             dataset.append(f'{group}/{name}')

-        print(json.dumps(dataset))
+        _logger.info(json.dumps(dataset))
     elif args.action == "metadata":
         create_dataset_metadata()
     elif args.action == "sample":
@@ -1581,17 +1574,17 @@ if __name__ == "__main__":
                 try:
                     decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
                 except Exception as e:
-                    print(f"Error while decoding prom {k}.{i}.{j}.wav:", str(e))
+                    _logger.info(f"Error while decoding prom {k}.{i}.{j}.wav: {str(e)}")
                 try:
                     decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
                 except Exception as e:
-                    print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e))
+                    _logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}")
                 v[i]['proms'][j] = v[i]['proms'][j].shape
                 v[i]['resps'][j] = v[i]['resps'][j].shape

         for k, v in samples.items():
             for i in range(len(v)):
-                print(f'{k}[{i}]:', v[i])
+                _logger.info(f'{k}[{i}]: {v[i]}')
     elif args.action == "validate":
         train_dl, subtrain_dl, val_dl = create_train_val_dataloader()

@@ -1610,11 +1603,11 @@ if __name__ == "__main__":

             phone = phonemes[i]

-            print( batch['text'], batch['metadata']['phonemes'] )
+            _logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" )

             missing |= set([phone])

-        print( "Missing tokens:", missing )
+        _logger.info( f"Missing tokens: {missing}" )


     elif args.action == "tasks":
@@ -1628,13 +1621,13 @@ if __name__ == "__main__":
             if task not in cfg.dataset.tasks_list:
                 continue

-            print(text, task, cfg.model.resp_levels)
-            print( proms.shape, resps.shape )
+            _logger.info( f'{text} {task} {cfg.model.resp_levels}')
+            _logger.info( f'{proms.shape} {resps.shape}' )

             tokens = 0
             tokens += sum([ text.shape[0] for text in batch["text"] ])
             tokens += sum([ resps.shape[0] for resps in batch["resps"] ])
-            print( tokens )
+            _logger.info( f'{tokens}' )

             decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
             decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )

@@ -18,6 +18,9 @@ Will also generate samples from a provided datset, if requested.
 import argparse
 import base64
 import random
+import logging
+
+_logger = logging.getLogger(__name__)

 from pathlib import Path

@@ -117,9 +120,9 @@ def main():

     samples_dirs["dataset"] = args.demo_dir / "dataset"

-    print("Loading dataloader...")
+    _logger.info("Loading dataloader...")
     dataloader = create_train_dataloader()
-    print("Loaded dataloader.")
+    _logger.info("Loaded dataloader.")

     num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size

@@ -9,6 +9,9 @@ import argparse
 import torch
 import torchaudio
 import numpy as np
+import logging
+
+_logger = logging.getLogger(__name__)

 from tqdm.auto import tqdm
 from pathlib import Path
@@ -78,7 +81,7 @@ def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
         try:
             process_job( outpath, waveform, sample_rate, text, language )
         except Exception as e:
-            print(f"Failed to quantize: {outpath}:", e)
+            _logger.error(f"Failed to quantize: {outpath}: {str(e)}")
             if raise_exceptions:
                 raise e
             continue
@@ -128,7 +131,7 @@ def process(

     for group_name in sorted(os.listdir(f'./{input_audio}/')):
         if not os.path.isdir(f'./{input_audio}/{group_name}/'):
-            print("Is not dir:", f'./{input_audio}/{group_name}/')
+            _logger.warning(f'Is not dir:" /{input_audio}/{group_name}/')
             continue

         if group_name in ignore_groups:
@@ -138,7 +141,7 @@ def process(

         for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {group_name}"):
             if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'):
-                print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}')
+                _logger.warning(f'Is not dir: ./{input_audio}/{group_name}/{speaker_id}')
                 continue

             if speaker_id in ignore_speakers:

@@ -6,6 +6,9 @@ import math
 import torch
 import torchaudio
 import numpy as np
+import logging
+
+_logger = logging.getLogger(__name__)

 from functools import cache
 from pathlib import Path
@@ -203,7 +206,7 @@ try:

 except Exception as e:
     cfg.inference.use_dac = False
-    print(str(e))
+    _logger.warning(str(e))

 # uses https://github.com/facebookresearch/AudioDec/
 # I have set up a pip-ify'd version with the caveat of having to manually handle downloading the checkpoints with a wget + unzip
@@ -213,7 +216,7 @@ try:
     from audiodec.utils.audiodec import AudioDec, assign_model as _audiodec_assign_model
 except Exception as e:
     cfg.inference.use_audiodec = False
-    print(str(e))
+    _logger.warning(str(e))
 """

 @cache
@@ -747,8 +750,8 @@ if __name__ == "__main__":
     if args.print:
         torch.set_printoptions(profile="full")

-        print( "Metadata:", artifact['metadata'] )
-        print( "Codes:", codes.shape, codes )
+        _logger.info(f"Metadata: {artifact['metadata']}" )
+        _logger.info(f"Codes: {codes.shape}, {codes}" )
     # encode
     else:
         args.output = args.input.with_suffix(audio_extension) if not args.output else args.output.with_suffix(audio_extension)

@@ -17,6 +17,9 @@ from ..models.lora import apply_lora, lora_load_state_dict

 import torch
 import re
+import logging
+
+_logger = logging.getLogger(__name__)

 deepspeed_available = False
 try:
@@ -55,7 +58,7 @@ def load_engines(training=True, **model_kwargs):
             checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )

         if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
-            print("Checkpoint missing, but weights found:", load_path)
+            _logger.warning("Checkpoint missing, but weights found:", load_path)
             loads_state_dict = True

         # load state early
@@ -64,7 +67,7 @@ def load_engines(training=True, **model_kwargs):

         # check if config is defined in state, and re-initialize the model
         if "config" in state and False:
-            print("Model config definition in weights, re-loading...")
+            _logger.warning("Model config definition in weights, re-loading...")
             config_state = state["config"]
             model = get_model( config=cfg.model.__class__( *config_state ), training=training )

@@ -201,7 +204,7 @@ def load_engines(training=True, **model_kwargs):
     if cfg.lora is not None:
         lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
         if lora_path.exists():
-            print( "Loaded LoRA state dict:", lora_path )
+            _logger.info( "Loaded LoRA state dict:", lora_path )

             state = torch_load(lora_path, device=cfg.device)
             state = state['lora' if 'lora' in state else 'module']

@@ -367,7 +367,7 @@ class Engines(dict[str, Engine]):
                 state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )

             torch_save(state_dict, save_path)
-            print(f"Exported {name} to {save_path}")
+            _logger.info(f"Exported {name} to {save_path}")

     def save_checkpoint(self, tag=None):
         if not tag:
@@ -385,7 +385,7 @@ class Engines(dict[str, Engine]):
             try:
                 engine.save_checkpoint(save_dir, tag=tag)
             except Exception as e:
-                print(f'Failed to save checkpoint for engine {name}:', str(e))
+                _logger.warning(f'Failed to save checkpoint for engine {name}:', str(e))

         # might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
         if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
@@ -395,7 +395,7 @@ class Engines(dict[str, Engine]):
             for d in checkpoints:
                 if not d.is_dir() or not d.exists():
                     continue
-                print("Removing", d)
+                _logger.info("Removing", d)
                 for p in d.iterdir():
                     p.unlink()
                 d.rmdir()
@@ -490,7 +490,7 @@ class Engines(dict[str, Engine]):
                 res = feeder( engine=engine, batch=batch )
                 break
             except RuntimeError as e:
-                print("Forward", str(e))
+                _logger.error("Forward", str(e))

                 if "out of memory" not in str(e):
                     self.save_checkpoint()
@@ -532,7 +532,7 @@ class Engines(dict[str, Engine]):
         try:
             engine.backward(loss)
         except RuntimeError as e:
-            print("Backwards:", str(e))
+            _logger.error("Backwards:", str(e))

             if "out of memory" not in str(e):
                 self.save_checkpoint()

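A few of the converted calls above keep `print`-style comma arguments (for example `_logger.info("Removing", d)` here, and `_logger.warning("Checkpoint missing, but weights found:", load_path)` in the previous file). With the stdlib logger, positional arguments after the message are `%`-format arguments, so those calls produce a `--- Logging error ---` report instead of the intended message. A short sketch of the three variants, with an illustrative path value:

```python
import logging

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)

path = "ckpt/old"  # illustrative value
_logger.info("Removing %s", path)  # lazy %-style: formatted only if the record is emitted
_logger.info(f"Removing {path}")   # eager f-string: the style most of this commit uses
_logger.info("Removing", path)     # print-style carry-over: the extra arg is treated as a
                                   # %-format argument and triggers a logging error report
```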
@@ -117,7 +117,7 @@ class Engine(DeepSpeedEngine):
             else:
                 self.optimizer.set_lr(lr)
         except Exception as e:
-            print(str(e))
+            _logger.warning(str(e))

     # we'll just have to live with the LoRA weights living within our main weights
     # they're easy to extract anyways

@@ -2,6 +2,9 @@ import torch
 import torchaudio
 import soundfile
 import time
+import logging
+
+_logger = logging.getLogger(__name__)

 from torch import Tensor
 from einops import rearrange
@@ -31,14 +34,13 @@ class TTS():

     def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
         if config:
-            print("Loading YAML:", config)
+            _logger.info(f"Loading YAML: {config}")
             cfg.load_yaml( config )

         try:
             cfg.format( training=False )
             cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
         except Exception as e:
-            print("Error while parsing config YAML:")
             raise e # throw an error because I'm tired of silent errors messing things up for me

         if amp is None:
@@ -73,7 +75,7 @@ class TTS():

         self.engines.eval()
         self.symmap = get_phone_symmap()
-        print("Loaded model")
+        _logger.info("Loaded model")

     def encode_text( self, text, language="en" ):
         # already a tensor, return it

@@ -1,3 +1,6 @@
+import logging
+
+_logger = logging.getLogger(__name__)

 def get_model(config, training=True, **model_kwargs):
     name = config.name
@@ -53,7 +56,7 @@ def get_model(config, training=True, **model_kwargs):
         **model_kwargs
     )

-    print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
+    _logger.info(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")

     return model

@@ -15,6 +15,9 @@ import math
 from einops import rearrange
 from torch import Tensor
 from tqdm import trange
+import logging
+
+_logger = logging.getLogger(__name__)

 from ..emb.qnt import trim, encode_as_embedding

@@ -379,7 +382,7 @@ def example_usage():
     else:
         raise ValueError(f"Unrecognized optimizer: {optimizer}")

-    print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
+    _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")

     optimizer = optimizer(model.parameters(), lr=learning_rate)

@@ -392,7 +395,7 @@ def example_usage():
         scheduler = None

     if scheduler is not None:
-        print("Scheduler:", scheduler)
+        _logger.info(f"Scheduler: {scheduler}")
         optimizer = scheduler( model.parameters(), lr = learning_rate )

     if cfg.optimizations.replace and cfg.optimizations.linear:
@@ -425,7 +428,7 @@ def example_usage():
     }, f"./data/{cfg.model.arch_type}.pth" )
     """

-    print(f"AR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+    _logger.info(f"AR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

     @torch.no_grad()
     def sample_data(task=None):

@@ -17,6 +17,9 @@ import math
 from einops import rearrange
 from torch import Tensor
 from tqdm import trange
+import logging
+
+_logger = logging.getLogger(__name__)

 from ..emb.qnt import trim, encode_as_embedding

@@ -434,7 +437,7 @@ def example_usage():
     else:
         raise ValueError(f"Unrecognized optimizer: {optimizer}")

-    print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
+    _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")

     optimizer = optimizer(model.parameters(), lr=learning_rate)

@@ -447,7 +450,7 @@ def example_usage():
         scheduler = None

     if scheduler is not None:
-        print("Scheduler:", scheduler)
+        _logger.info(f"Scheduler: {scheduler}")
         optimizer = scheduler( model.parameters(), lr = learning_rate )

     if cfg.optimizations.replace and cfg.optimizations.linear:
@@ -480,7 +483,7 @@ def example_usage():
     }, f"./data/{cfg.model.arch_type}.pth" )
     """

-    print(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+    _logger.info(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

     @torch.no_grad()
     def sample_data(task=None):

@@ -2,6 +2,8 @@

 import math
 import torch
+import logging
+
 from typing import Literal, overload, Optional, Tuple

 from torch import Tensor, nn
@@ -10,6 +12,8 @@ from transformers.cache_utils import Cache
 from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
 from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv

+_logger = logging.getLogger(__name__)
+
 AVAILABLE_ATTENTIONS = []

 try:
@@ -18,7 +22,7 @@ try:
     if is_flash_attn_2_available():
         AVAILABLE_ATTENTIONS.append("flash_attention_2")
 except Exception as e:
-    print("Error while querying for `flash_attention_2` support", e)
+    _logger.warning(f"Error while querying for `flash_attention_2` support: {str(e)}")

 try:
     from .attention.fused import attention as _fused_attention
@@ -27,7 +31,7 @@ try:

     AVAILABLE_ATTENTIONS.append("fused_attn")
 except Exception as e:
-    print("Error while querying for `fused_attn` support", e)
+    _logger.warning(f"Error while querying for `fused_attn` support: {str(e)}")


 is_rocm = any("AMD" in torch.cuda.get_device_properties(i).name for i in range(torch.cuda.device_count()))
@@ -99,7 +103,7 @@ try:
     has_flash_attn_with_paged = True
 except Exception as e:
     raise e
-    print("Error while querying for `flash_attn` support", e)
+    _logger.warning(f"Error while querying for `flash_attn` support: {str(e)}")

 try:
     from xformers.ops.fmha import memory_efficient_attention
@@ -107,7 +111,7 @@ try:

     AVAILABLE_ATTENTIONS.append("xformers")
 except Exception as e:
-    print("Error while importing `xformers`", e)
+    _logger.warning(f"Error while importing `xformers`: {str(e)}")

 # to-do: find a better way to query for if there's available kernels since these return true regardless
 if torch.backends.cuda.flash_sdp_enabled():

@@ -20,6 +20,9 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult

 import random
 import math
+import logging
+
+_logger = logging.getLogger(__name__)

 from einops import rearrange
 from tqdm import trange
@@ -502,7 +505,7 @@ def example_usage():
     else:
         raise ValueError(f"Unrecognized optimizer: {optimizer}")

-    print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
+    _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")

     optimizer = optimizer(model.parameters(), lr=learning_rate)

@@ -515,7 +518,7 @@ def example_usage():
         scheduler = None

     if scheduler is not None:
-        print("Scheduler:", scheduler)
+        _logger.info(f"Scheduler: {scheduler}")
         optimizer = scheduler( model.parameters(), lr = learning_rate )

     if cfg.optimizations.replace and cfg.optimizations.linear:
@@ -532,7 +535,7 @@ def example_usage():
     }, f"./data/{cfg.model.arch_type}.pth" )
     """

-    print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+    _logger.info(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

     @torch.inference_mode()
     def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):

@@ -19,6 +19,9 @@ from torch import Tensor
 from tqdm import trange

 from ..emb.qnt import trim
+import logging
+
+_logger = logging.getLogger(__name__)

 class NAR(Base):
     def forward(
@@ -361,7 +364,7 @@ def example_usage():
     else:
         raise ValueError(f"Unrecognized optimizer: {optimizer}")

-    print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
+    _logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")

     optimizer = optimizer(model.parameters(), lr=learning_rate)

@@ -374,7 +377,7 @@ def example_usage():
         scheduler = None

     if scheduler is not None:
-        print("Scheduler:", scheduler)
+        _logger.info(f"Scheduler: {scheduler}")
         optimizer = scheduler( model.parameters(), lr = learning_rate )

     if cfg.optimizations.replace and cfg.optimizations.linear:
@@ -391,7 +394,7 @@ def example_usage():
     }, f"./data/{cfg.model.arch_type}.pth" )
     """

-    print(f"NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
+    _logger.info(f"NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

     @torch.inference_mode()
     def sample( name, steps=1000 ):

@@ -133,10 +133,7 @@ def run_eval(engines, eval_name, dl):
     }
     #engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)

-    if cfg.trainer.no_logger:
-        tqdm.write(f"Validation Metrics: {json.dumps(engines_stats)}.")
-    else:
-        _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
+    _logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")


 def train():
@@ -160,8 +157,8 @@ def train():
             run_eval(engines, "subtrain", subtrain_dl)
             run_eval(engines, "val", val_dl)
         except Exception as e:
-            print("Error occurred while performing eval:", str(e))
-            print(traceback.format_exc())
+            _logger.warning(f"Error occurred while performing eval: {str(e)}")
+            _logger.warning(traceback.format_exc())

         engines.train()
         qnt.unload_model()

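The eval handler above logs `str(e)` and `traceback.format_exc()` as two separate warnings. For what it's worth, stdlib logging can capture the active traceback itself; a sketch of the one-call equivalent using `Logger.exception` (the failing `run_eval` stub here is illustrative):

```python
import logging

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)

def run_eval():
    raise RuntimeError("CUDA out of memory")  # stand-in failure for the sketch

try:
    run_eval()
except Exception:
    # logs at ERROR and appends the current traceback automatically,
    # matching str(e) + traceback.format_exc() in the hunk above
    _logger.exception("Error occurred while performing eval")
```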
@@ -20,7 +20,6 @@ def get_free_port():

 _distributed_initialized = False
 def init_distributed( fn, *args, **kwargs ):
-    #print("Initializing distributed...")
     torch.cuda.set_device(local_rank())
     fn(*args, **kwargs)
     _distributed_initialized = True
@@ -29,8 +28,6 @@ def distributed_initialized():
     return _distributed_initialized

 def cleanup_distributed():
-    #if not _distributed_initialized:
-    #    return
     dist.barrier()
     dist.destroy_process_group()

@@ -177,10 +177,7 @@ def train(
         except Exception as e:
             metrics = str(stats)

-        if cfg.trainer.no_logger:
-            tqdm.write(f"Training Metrics: {truncate_json(metrics)}.")
-        else:
-            _logger.info(f"Training Metrics: {truncate_json(metrics)}.")
+        _logger.info(f"Training Metrics: {truncate_json(metrics)}.")

         command = _non_blocking_input()

@@ -220,9 +217,9 @@ def train(
             rate = float(command.split(" ")[-1])
             try:
                 engines.set_lr(rate)
-                print("Updating LR to:", rate)
+                _logger.info(f"Updating LR to: {rate}")
             except Exception as e:
-                print("Failed to set LR rate to:", rate, str(e))
+                _logger.warning(f"Failed to set LR rate to: {rate}, {str(e)}")

         if "export" in command:
             train_dl.dataset.save_state_dict()

@@ -14,6 +14,9 @@ import random
 import time
 import psutil
 import math
+import logging
+
+_logger = logging.getLogger(__name__)

 from coloredlogs import ColoredFormatter
 from logging import StreamHandler

@@ -296,7 +299,7 @@ def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ):
         )

         if verbose:
-            print(f"Replacing {name}.{k} to", klass)
+            _logger.info(f"Replacing {name}.{k} to: {klass}")

     return model

@@ -330,7 +333,7 @@ def replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ):
         )

         if verbose:
-            print(f"Replacing {name}.{k} to", klass)
+            _logger.info(f"Replacing {name}.{k} to: {klass}")

     return model

@@ -360,7 +363,7 @@ def replace_attention( model, klass, target, mode="math", verbose=False ):
         )

         if verbose:
-            print(f"Replacing {name}.{k} to", klass)
+            _logger.info(f"Replacing {name}.{k} to: {klass}")

     return model

@@ -491,7 +494,7 @@ def get_model_offload_policy(module, policy=None):
         # does not fit in budget, increase device index
         else:
             device_index += 1
-            print(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")
+            _logger.info(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")

     # to-do: check that all modules are exhausted
     assert module_index >= len(modules)
@@ -528,9 +531,9 @@ def offload_model( model, policy=None ):
         if not not [*module.named_children()]:
             continue
         try:
-            print( name, next(module.parameters()).device )
+            _logger.info( name, next(module.parameters()).device )
         except Exception as e:
-            print( name, "?" )
+            _logger.info( name, "?" )
             pass
     """

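This file already imports `ColoredFormatter` from `coloredlogs` alongside `StreamHandler`, which is presumably where the log output gets its formatting. A hedged sketch of how such a handler could be wired up; the exact format string is an assumption, and the repo's actual setup may differ:

```python
import logging
from logging import StreamHandler
from coloredlogs import ColoredFormatter  # same imports utils.py already has

# assumption: one-time root-logger setup so every module-level _logger inherits it
handler = StreamHandler()
handler.setFormatter(ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
logging.getLogger().addHandler(handler)
logging.getLogger().setLevel(logging.INFO)
```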
@@ -3,9 +3,12 @@ from contextlib import contextmanager
 import math
 import torch
 import torch.nn.functional as F
+import logging

 from ..config import cfg

+_logger = logging.getLogger(__name__)
+
 Embedding = torch.nn.Embedding
 Linear = torch.nn.Linear

@@ -95,7 +98,7 @@ if cfg.optimizations.tensorrt:
     import torch_tensorrt
     AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
 except Exception as e:
-    print('Error while importing TensorRT:', str(e))
+    _logger.warning(f'Error while importing TensorRT: {str(e)}')
     pass

 def compile_model(model, backend="auto"):
@@ -111,14 +114,14 @@ def compile_model(model, backend="auto"):
 try:
     from prodigyopt import Prodigy
 except Exception as e:
-    print('Error while importing Prodigyopt:', str(e))
+    _logger.warning(f'Error while importing Prodigyopt: {str(e)}')
     pass

 # https://github.com/facebookresearch/schedule_free/
 try:
     import schedulefree
 except Exception as e:
-    print('Error while importing Schedule_Free:', str(e))
+    _logger.warning(f'Error while importing Schedule_Free: {str(e)}')
     pass

 # backwards compat