moved prints to use logger, edited readme (fused_attn doesnt seem stable for training)

This commit is contained in:
mrq 2024-08-29 13:27:16 -05:00
parent d423bc03c2
commit 32287710a2
21 changed files with 111 additions and 87 deletions

View File

@ -161,7 +161,7 @@ For audio backends:
* `xformers`: [facebookresearch/xformers](https://github.com/facebookresearch/xformers/)'s memory efficient attention
* `flash_attn`: uses the available `flash_attn` package (including `flash_attn==1.0.9` through a funny wrapper)
* `flash_attn_v100`: uses [ZRayZzz/flash-attention-v100](https://github.com/ZRayZzz/flash-attention-v100/)'s Flash Attention for Volta (but doesn't work currently)
* `fused_attn`: uses an implementation using `triton` (only tested on my 7900XTX / Navi3 / gfx1100)
* `fused_attn`: uses a `triton`-based fused attention implementation (tested on my 7900XTX and V100s), but seems to introduce errors after training with it for a while
* `transformers` Llama\*Attention implementations:
* `eager`: default `LlamaAttention`
* `sdpa`: integrated `LlamaSdpaAttention` attention model
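For reference, availability of most of these backends is probed at import time with try/except; a minimal sketch of that probing pattern (the list name and fallbacks here are illustrative, not the repo's exact API):

```python
import torch

# probe which attention backends can actually be imported on this machine
available = []
try:
    from transformers.utils import is_flash_attn_2_available
    if is_flash_attn_2_available():
        available.append("flash_attention_2")
except Exception:
    pass
try:
    from xformers.ops.fmha import memory_efficient_attention  # noqa: F401
    available.append("xformers")
except Exception:
    pass
# sdpa ships with torch >= 2.0; eager always works
available.append("sdpa" if hasattr(torch.nn.functional, "scaled_dot_product_attention") else "eager")
```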

View File

@ -9,6 +9,7 @@ import time
import argparse
import yaml
import random
import logging
import torch
import numpy as np
@ -163,7 +164,8 @@ class Dataset:
sample_order: str = "interleaved" # duration
sample_max_duration_batch: float = 0.0 # total number of seconds of utterances per batched, 0 to disable
# for a full sized model with 12GiB of VRAM for Encodec, 120 seconds is just enough
sample_shuffle: bool = True #
# for a full sized model with 24GiB of VRAM for Encodec, 380 seconds is 80% VRAM consumed (but it might be limited by batch size)
sample_shuffle: bool = True # I swear this is spiking the loss when sample_order = "duration" and sample_max_duration_batch > 0
tasks_list: list[str] = field(default_factory=lambda: ["tts"]) # list of tasks to train against
reencode_on_concat: bool = False # whether to concat audio by decode => concat => encode, or naively concat codes
@ -364,6 +366,7 @@ class Model:
return dict(include=include, exclude=exclude)
# should be renamed to Adapters
@dataclass()
class LoRA:
name: str = "lora" # vanity name
@ -638,9 +641,6 @@ class Trainer:
def scale_loss(self):
# currently cannot feasibly apply loss scaling with DeepSpeed backend (it can handle it itself anyways)
return self.dtype == torch.float16
"""
"""
@dataclass()
class Inference:
@ -670,7 +670,6 @@ class Inference:
return torch.float8_e4m3fn
return torch.float32
# should be renamed to optimizations
@dataclass()
class Optimizations:
injects: bool = False # overwrites default torch classes (not recommended)
@ -755,6 +754,7 @@ class Config(BaseConfig):
return self.models[0] if len(self.models) > 0 else None
# should be renamed to adapters
@property
def lora(self):
for i, lora in enumerate(self.loras):
@ -795,7 +795,7 @@ class Config(BaseConfig):
try:
self.hdf5 = h5py.File(f'{self.rel_path}/{self.dataset.hdf5_name}', 'a' if write else self.dataset.hdf5_flag) # to-do, have an easy to set flag that determines if training or creating the dataset
except Exception as e:
print("Error while opening HDF5 file:", f'{self.rel_path}/{self.dataset.hdf5_name}', str(e))
_logger.warning(f"Error while opening HDF5 file: {self.rel_path}/{self.dataset.hdf5_name}: {str(e)}")
self.dataset.use_hdf5 = False
# to-do: prune unused keys
@ -923,7 +923,7 @@ class Config(BaseConfig):
cfg.tokenizer = NaiveTokenizer()
except Exception as e:
cfg.tokenizer = NaiveTokenizer()
print("Error while parsing tokenizer:", e)
_logger.warning(f"Error while parsing tokenizer: {str(e)}")
pass
@ -960,6 +960,7 @@ class NaiveTokenizer:
# tokenize
return [*map(symmap.get, phones)]
_logger = logging.getLogger(__name__)
cfg = Config.from_cli()
@ -967,7 +968,7 @@ cfg = Config.from_cli()
try:
cfg.format()
except Exception as e:
print("Error while parsing config YAML:")
_logger.error(f"Error while parsing config YAML: {str(e)}")
raise e # throw an error because I'm tired of silent errors messing things up for me
if __name__ == "__main__":
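The change above follows the usual stdlib pattern: each module grabs a named logger and only emits, while the entry point configures handlers once. A minimal self-contained sketch of that pattern (the handler setup here is illustrative, not this repo's):

```python
import logging

_logger = logging.getLogger(__name__)  # per-module logger, named after the module

def load_file(path):
    try:
        with open(path) as f:
            return f.read()
    except Exception as e:
        # recoverable problems go to warning(); fatal ones to error() before re-raising
        _logger.warning(f"Error while opening file: {path}: {str(e)}")
        return None

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)  # handlers are configured once, at the entry point
    load_file("does-not-exist.txt")
```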

View File

@ -495,7 +495,6 @@ def _load_paths_from_metadata(group_name, type="training", validate=False):
def _get_hdf5_path(path):
# to-do: better validation
#print(path)
return str(path)
def _get_hdf5_paths( data_dir, type="training", validate=False ):
@ -1543,12 +1542,6 @@ if __name__ == "__main__":
cfg.dataset.workers = 1
class LoggerOveride:
def info(self, *args):
print(*args)
_logger = LoggerOveride()
if args.action == "hdf5":
create_dataset_hdf5()
elif args.action == "list-dataset":
@ -1559,7 +1552,7 @@ if __name__ == "__main__":
continue
dataset.append(f'{group}/{name}')
print(json.dumps(dataset))
_logger.info(json.dumps(dataset))
elif args.action == "metadata":
create_dataset_metadata()
elif args.action == "sample":
@ -1581,17 +1574,17 @@ if __name__ == "__main__":
try:
decode_to_file( v[i]['proms'][j], f"./data/sample-test/{k}.{i}.{j}.proms.wav", device="cpu" )
except Exception as e:
print(f"Error while decoding prom {k}.{i}.{j}.wav:", str(e))
_logger.info(f"Error while decoding prom {k}.{i}.{j}.wav: {str(e)}")
try:
decode_to_file( v[i]['resps'][j], f"./data/sample-test/{k}.{i}.{j}.resps.wav", device="cpu" )
except Exception as e:
print(f"Error while decoding resp {k}.{i}.{j}.wav:", str(e))
_logger.info(f"Error while decoding resp {k}.{i}.{j}.wav: {str(e)}")
v[i]['proms'][j] = v[i]['proms'][j].shape
v[i]['resps'][j] = v[i]['resps'][j].shape
for k, v in samples.items():
for i in range(len(v)):
print(f'{k}[{i}]:', v[i])
_logger.info(f'{k}[{i}]: {v[i]}')
elif args.action == "validate":
train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
@ -1610,11 +1603,11 @@ if __name__ == "__main__":
phone = phonemes[i]
print( batch['text'], batch['metadata']['phonemes'] )
_logger.info( f"{batch['text']}: {batch['metadata']['phonemes']}" )
missing |= set([phone])
print( "Missing tokens:", missing )
_logger.info( f"Missing tokens: {missing}" )
elif args.action == "tasks":
@ -1628,13 +1621,13 @@ if __name__ == "__main__":
if task not in cfg.dataset.tasks_list:
continue
print(text, task, cfg.model.resp_levels)
print( proms.shape, resps.shape )
_logger.info( f'{text} {task} {cfg.model.resp_levels}')
_logger.info( f'{proms.shape} {resps.shape}' )
tokens = 0
tokens += sum([ text.shape[0] for text in batch["text"] ])
tokens += sum([ resps.shape[0] for resps in batch["resps"] ])
print( tokens )
_logger.info( f'{tokens}' )
decode_to_file( proms, f"./data/{task}.proms.wav", device="cpu" )
decode_to_file( resps, f"./data/{task}.resps.wav", device="cpu" )

View File

@ -18,6 +18,9 @@ Will also generate samples from a provided datset, if requested.
import argparse
import base64
import random
import logging
_logger = logging.getLogger(__name__)
from pathlib import Path
@ -117,9 +120,9 @@ def main():
samples_dirs["dataset"] = args.demo_dir / "dataset"
print("Loading dataloader...")
_logger.info("Loading dataloader...")
dataloader = create_train_dataloader()
print("Loaded dataloader.")
_logger.info("Loaded dataloader.")
num = args.dataset_samples if args.dataset_samples else cfg.evaluation.size

View File

@ -9,6 +9,9 @@ import argparse
import torch
import torchaudio
import numpy as np
import logging
_logger = logging.getLogger(__name__)
from tqdm.auto import tqdm
from pathlib import Path
@ -78,7 +81,7 @@ def process_jobs( jobs, speaker_id="", raise_exceptions=True ):
try:
process_job( outpath, waveform, sample_rate, text, language )
except Exception as e:
print(f"Failed to quantize: {outpath}:", e)
_logger.error(f"Failed to quantize: {outpath}: {str(e)}")
if raise_exceptions:
raise e
continue
@ -128,7 +131,7 @@ def process(
for group_name in sorted(os.listdir(f'./{input_audio}/')):
if not os.path.isdir(f'./{input_audio}/{group_name}/'):
print("Is not dir:", f'./{input_audio}/{group_name}/')
_logger.warning(f'Is not dir: ./{input_audio}/{group_name}/')
continue
if group_name in ignore_groups:
@ -138,7 +141,7 @@ def process(
for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {group_name}"):
if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'):
print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}')
_logger.warning(f'Is not dir: ./{input_audio}/{group_name}/{speaker_id}')
continue
if speaker_id in ignore_speakers:

View File

@ -6,6 +6,9 @@ import math
import torch
import torchaudio
import numpy as np
import logging
_logger = logging.getLogger(__name__)
from functools import cache
from pathlib import Path
@ -203,7 +206,7 @@ try:
except Exception as e:
cfg.inference.use_dac = False
print(str(e))
_logger.warning(str(e))
# uses https://github.com/facebookresearch/AudioDec/
# I have set up a pip-ify'd version with the caveat of having to manually handle downloading the checkpoints with a wget + unzip
@ -213,7 +216,7 @@ try:
from audiodec.utils.audiodec import AudioDec, assign_model as _audiodec_assign_model
except Exception as e:
cfg.inference.use_audiodec = False
print(str(e))
_logger.warning(str(e))
"""
@cache
@ -747,8 +750,8 @@ if __name__ == "__main__":
if args.print:
torch.set_printoptions(profile="full")
print( "Metadata:", artifact['metadata'] )
print( "Codes:", codes.shape, codes )
_logger.info(f"Metadata: {artifact['metadata']}" )
_logger.info(f"Codes: {codes.shape}, {codes}" )
# encode
else:
args.output = args.input.with_suffix(audio_extension) if not args.output else args.output.with_suffix(audio_extension)

View File

@ -17,6 +17,9 @@ from ..models.lora import apply_lora, lora_load_state_dict
import torch
import re
import logging
_logger = logging.getLogger(__name__)
deepspeed_available = False
try:
@ -55,7 +58,7 @@ def load_engines(training=True, **model_kwargs):
checkpoint_path = pick_path( checkpoint_path.parent / tag / f"state.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
print("Checkpoint missing, but weights found:", load_path)
_logger.warning("Checkpoint missing, but weights found:", load_path)
loads_state_dict = True
# load state early
@ -64,7 +67,7 @@ def load_engines(training=True, **model_kwargs):
# check if config is defined in state, and re-initialize the model
if "config" in state and False:
print("Model config definition in weights, re-loading...")
_logger.warning("Model config definition in weights, re-loading...")
config_state = state["config"]
model = get_model( config=cfg.model.__class__( *config_state ), training=training )
@ -201,7 +204,7 @@ def load_engines(training=True, **model_kwargs):
if cfg.lora is not None:
lora_path = pick_path( cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}", *[ f'.{format}' for format in cfg.supported_weights_formats] )
if lora_path.exists():
print( "Loaded LoRA state dict:", lora_path )
_logger.info( "Loaded LoRA state dict:", lora_path )
state = torch_load(lora_path, device=cfg.device)
state = state['lora' if 'lora' in state else 'module']
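A note on these calls: unlike `print`, `logging` does not join extra positional arguments onto the message; they are treated as `%`-style formatting arguments, so a call like `_logger.info("Loaded LoRA state dict:", lora_path)` triggers an internal formatting error instead of printing the path. Hence the f-string form. A small illustration (the path is made up):

```python
import logging
logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)

lora_path = "./ckpt/lora/fp32.sft"  # hypothetical path, for illustration only
_logger.info("Loaded LoRA state dict:", lora_path)    # wrong: no %s placeholder -> "--- Logging error ---" on stderr
_logger.info("Loaded LoRA state dict: %s", lora_path)  # ok: lazy %-style formatting
_logger.info(f"Loaded LoRA state dict: {lora_path}")   # ok: eager f-string, the style used in this commit
```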

View File

@ -367,7 +367,7 @@ class Engines(dict[str, Engine]):
state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
torch_save(state_dict, save_path)
print(f"Exported {name} to {save_path}")
_logger.info(f"Exported {name} to {save_path}")
def save_checkpoint(self, tag=None):
if not tag:
@ -385,7 +385,7 @@ class Engines(dict[str, Engine]):
try:
engine.save_checkpoint(save_dir, tag=tag)
except Exception as e:
print(f'Failed to save checkpoint for engine {name}:', str(e))
_logger.warning(f'Failed to save checkpoint for engine {name}: {str(e)}')
# might be better to prune before saving for safety, but [:0] returns an empty list, but I could do [:-cfg.trainer.keep_last_checkpoints - 1 if cfg.trainer.keep_last_checkpoints > 1 else None]
if cfg.trainer.keep_last_checkpoints > 0 and is_global_leader():
@ -395,7 +395,7 @@ class Engines(dict[str, Engine]):
for d in checkpoints:
if not d.is_dir() or not d.exists():
continue
print("Removing", d)
_logger.info("Removing", d)
for p in d.iterdir():
p.unlink()
d.rmdir()
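The slicing mentioned in the comment above boils down to how slice bounds behave on lists; a quick illustration (names and values made up):

```python
checkpoints = ["ckpt-100", "ckpt-200", "ckpt-300", "ckpt-400"]  # sorted oldest -> newest
keep_last = 2
print(checkpoints[:-keep_last])  # ['ckpt-100', 'ckpt-200'] -> candidates for removal
print(checkpoints[:0])           # [] -> a bound of 0 selects nothing,
                                 # which is why the expression in the comment falls back to None for small values
```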
@ -490,7 +490,7 @@ class Engines(dict[str, Engine]):
res = feeder( engine=engine, batch=batch )
break
except RuntimeError as e:
print("Forward", str(e))
_logger.error("Forward", str(e))
if "out of memory" not in str(e):
self.save_checkpoint()
@ -532,7 +532,7 @@ class Engines(dict[str, Engine]):
try:
engine.backward(loss)
except RuntimeError as e:
print("Backwards:", str(e))
_logger.error("Backwards:", str(e))
if "out of memory" not in str(e):
self.save_checkpoint()

View File

@ -117,7 +117,7 @@ class Engine(DeepSpeedEngine):
else:
self.optimizer.set_lr(lr)
except Exception as e:
print(str(e))
_logger.warning(str(e))
# we'll just have to live with the LoRA weights living within our main weights
# they're easy to extract anyways

View File

@ -2,6 +2,9 @@ import torch
import torchaudio
import soundfile
import time
import logging
_logger = logging.getLogger(__name__)
from torch import Tensor
from einops import rearrange
@ -31,14 +34,13 @@ class TTS():
def load_config( self, config=None, device=None, amp=None, dtype=None, attention=None ):
if config:
print("Loading YAML:", config)
_logger.info(f"Loading YAML: {config}")
cfg.load_yaml( config )
try:
cfg.format( training=False )
cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
except Exception as e:
print("Error while parsing config YAML:")
raise e # throw an error because I'm tired of silent errors messing things up for me
if amp is None:
@ -73,7 +75,7 @@ class TTS():
self.engines.eval()
self.symmap = get_phone_symmap()
print("Loaded model")
_logger.info("Loaded model")
def encode_text( self, text, language="en" ):
# already a tensor, return it

View File

@ -1,3 +1,6 @@
import logging
_logger = logging.getLogger(__name__)
def get_model(config, training=True, **model_kwargs):
name = config.name
@ -53,7 +56,7 @@ def get_model(config, training=True, **model_kwargs):
**model_kwargs
)
print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
_logger.info(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
return model

View File

@ -15,6 +15,9 @@ import math
from einops import rearrange
from torch import Tensor
from tqdm import trange
import logging
_logger = logging.getLogger(__name__)
from ..emb.qnt import trim, encode_as_embedding
@ -379,7 +382,7 @@ def example_usage():
else:
raise ValueError(f"Unrecognized optimizer: {optimizer}")
print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
optimizer = optimizer(model.parameters(), lr=learning_rate)
@ -392,7 +395,7 @@ def example_usage():
scheduler = None
if scheduler is not None:
print("Scheduler:", scheduler)
_logger.info(f"Scheduler: {scheduler}")
optimizer = scheduler( model.parameters(), lr = learning_rate )
if cfg.optimizations.replace and cfg.optimizations.linear:
@ -425,7 +428,7 @@ def example_usage():
}, f"./data/{cfg.model.arch_type}.pth" )
"""
print(f"AR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
_logger.info(f"AR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.no_grad()
def sample_data(task=None):

View File

@ -17,6 +17,9 @@ import math
from einops import rearrange
from torch import Tensor
from tqdm import trange
import logging
_logger = logging.getLogger(__name__)
from ..emb.qnt import trim, encode_as_embedding
@ -434,7 +437,7 @@ def example_usage():
else:
raise ValueError(f"Unrecognized optimizer: {optimizer}")
print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
optimizer = optimizer(model.parameters(), lr=learning_rate)
@ -447,7 +450,7 @@ def example_usage():
scheduler = None
if scheduler is not None:
print("Scheduler:", scheduler)
_logger.info(f"Scheduler: {scheduler}")
optimizer = scheduler( model.parameters(), lr = learning_rate )
if cfg.optimizations.replace and cfg.optimizations.linear:
@ -480,7 +483,7 @@ def example_usage():
}, f"./data/{cfg.model.arch_type}.pth" )
"""
print(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
_logger.info(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.no_grad()
def sample_data(task=None):

View File

@ -2,6 +2,8 @@
import math
import torch
import logging
from typing import Literal, overload, Optional, Tuple
from torch import Tensor, nn
@ -10,6 +12,8 @@ from transformers.cache_utils import Cache
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
_logger = logging.getLogger(__name__)
AVAILABLE_ATTENTIONS = []
try:
@ -18,7 +22,7 @@ try:
if is_flash_attn_2_available():
AVAILABLE_ATTENTIONS.append("flash_attention_2")
except Exception as e:
print("Error while querying for `flash_attention_2` support", e)
_logger.warning(f"Error while querying for `flash_attention_2` support: {str(e)}")
try:
from .attention.fused import attention as _fused_attention
@ -27,7 +31,7 @@ try:
AVAILABLE_ATTENTIONS.append("fused_attn")
except Exception as e:
print("Error while querying for `fused_attn` support", e)
_logger.warning(f"Error while querying for `fused_attn` support: {str(e)}")
is_rocm = any("AMD" in torch.cuda.get_device_properties(i).name for i in range(torch.cuda.device_count()))
@ -99,7 +103,7 @@ try:
has_flash_attn_with_paged = True
except Exception as e:
raise e
print("Error while querying for `flash_attn` support", e)
_logger.warning(f"Error while querying for `flash_attn` support: {str(e)}")
try:
from xformers.ops.fmha import memory_efficient_attention
@ -107,7 +111,7 @@ try:
AVAILABLE_ATTENTIONS.append("xformers")
except Exception as e:
print("Error while importing `xformers`", e)
_logger.warning(f"Error while importing `xformers`: {str(e)}")
# to-do: find a better way to query for if there's available kernels since these return true regardless
if torch.backends.cuda.flash_sdp_enabled():
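One hedged way to address the to-do above is to probe empirically: run a tiny attention call with only the kernel of interest enabled and see whether it throws. A sketch under that assumption (`torch.backends.cuda.sdp_kernel` exists in torch 2.x but is deprecated in newer releases):

```python
import torch
import torch.nn.functional as F

def flash_sdpa_usable(device="cuda", dtype=torch.float16):
    """Return True only if the flash SDPA kernel actually runs on this device."""
    if not torch.cuda.is_available():
        return False
    try:
        q = k = v = torch.zeros(1, 2, 8, 64, device=device, dtype=dtype)
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            F.scaled_dot_product_attention(q, k, v)
        return True
    except Exception:
        return False
```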

View File

@ -20,6 +20,9 @@ from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, Mult
import random
import math
import logging
_logger = logging.getLogger(__name__)
from einops import rearrange
from tqdm import trange
@ -502,7 +505,7 @@ def example_usage():
else:
raise ValueError(f"Unrecognized optimizer: {optimizer}")
print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
optimizer = optimizer(model.parameters(), lr=learning_rate)
@ -515,7 +518,7 @@ def example_usage():
scheduler = None
if scheduler is not None:
print("Scheduler:", scheduler)
_logger.info(f"Scheduler: {scheduler}")
optimizer = scheduler( model.parameters(), lr = learning_rate )
if cfg.optimizations.replace and cfg.optimizations.linear:
@ -532,7 +535,7 @@ def example_usage():
}, f"./data/{cfg.model.arch_type}.pth" )
"""
print(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
_logger.info(f"{LlmArchClass} parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.inference_mode()
def sample( name, steps=cfg.model.max_levels*cfg.dataset.frames_per_second*6 ):

View File

@ -19,6 +19,9 @@ from torch import Tensor
from tqdm import trange
from ..emb.qnt import trim
import logging
_logger = logging.getLogger(__name__)
class NAR(Base):
def forward(
@ -361,7 +364,7 @@ def example_usage():
else:
raise ValueError(f"Unrecognized optimizer: {optimizer}")
print("Optimizer:", optimizer, "\tLearning rate:", learning_rate)
_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
optimizer = optimizer(model.parameters(), lr=learning_rate)
@ -374,7 +377,7 @@ def example_usage():
scheduler = None
if scheduler is not None:
print("Scheduler:", scheduler)
_logger.info(f"Scheduler: {scheduler}")
optimizer = scheduler( model.parameters(), lr = learning_rate )
if cfg.optimizations.replace and cfg.optimizations.linear:
@ -391,7 +394,7 @@ def example_usage():
}, f"./data/{cfg.model.arch_type}.pth" )
"""
print(f"NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
_logger.info(f"NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.inference_mode()
def sample( name, steps=1000 ):

View File

@ -133,10 +133,7 @@ def run_eval(engines, eval_name, dl):
}
#engines_stats['epoch'] = iteration * cfg.hyperparameters.gradient_accumulation_steps / len(dl)
if cfg.trainer.no_logger:
tqdm.write(f"Validation Metrics: {json.dumps(engines_stats)}.")
else:
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
_logger.info(f"Validation Metrics: {json.dumps(engines_stats)}.")
def train():
@ -160,8 +157,8 @@ def train():
run_eval(engines, "subtrain", subtrain_dl)
run_eval(engines, "val", val_dl)
except Exception as e:
print("Error occurred while performing eval:", str(e))
print(traceback.format_exc())
_logger.warning(f"Error occurred while performing eval: {str(e)}")
_logger.warning(traceback.format_exc())
engines.train()
qnt.unload_model()

View File

@ -20,7 +20,6 @@ def get_free_port():
_distributed_initialized = False
def init_distributed( fn, *args, **kwargs ):
#print("Initializing distributed...")
torch.cuda.set_device(local_rank())
fn(*args, **kwargs)
_distributed_initialized = True
@ -29,8 +28,6 @@ def distributed_initialized():
return _distributed_initialized
def cleanup_distributed():
#if not _distributed_initialized:
# return
dist.barrier()
dist.destroy_process_group()

View File

@ -177,10 +177,7 @@ def train(
except Exception as e:
metrics = str(stats)
if cfg.trainer.no_logger:
tqdm.write(f"Training Metrics: {truncate_json(metrics)}.")
else:
_logger.info(f"Training Metrics: {truncate_json(metrics)}.")
_logger.info(f"Training Metrics: {truncate_json(metrics)}.")
command = _non_blocking_input()
@ -220,9 +217,9 @@ def train(
rate = float(command.split(" ")[-1])
try:
engines.set_lr(rate)
print("Updating LR to:", rate)
_logger.info(f"Updating LR to: {rate}")
except Exception as e:
print("Failed to set LR rate to:", rate, str(e))
_logger.warning(f"Failed to set LR rate to: {rate}, {str(e)}")
if "export" in command:
train_dl.dataset.save_state_dict()

View File

@ -14,6 +14,9 @@ import random
import time
import psutil
import math
import logging
_logger = logging.getLogger(__name__)
from coloredlogs import ColoredFormatter
from logging import StreamHandler
@ -296,7 +299,7 @@ def replace_linear( model, klass, target=torch.nn.Linear, verbose=False ):
)
if verbose:
print(f"Replacing {name}.{k} to", klass)
_logger.info(f"Replacing {name}.{k} to: {klass}")
return model
@ -330,7 +333,7 @@ def replace_embedding( model, klass, target=torch.nn.Embedding, verbose=False ):
)
if verbose:
print(f"Replacing {name}.{k} to", klass)
_logger.info(f"Replacing {name}.{k} to: {klass}")
return model
@ -360,7 +363,7 @@ def replace_attention( model, klass, target, mode="math", verbose=False ):
)
if verbose:
print(f"Replacing {name}.{k} to", klass)
_logger.info(f"Replacing {name}.{k} to: {klass}")
return model
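The `replace_*` helpers in this file share one shape: walk the module tree, swap matching children in place, and optionally log each swap. A generic sketch of that pattern (the repo's real helpers also copy weights over and forward extra constructor kwargs):

```python
import logging
import torch

_logger = logging.getLogger(__name__)

def replace_modules(model, klass, target=torch.nn.Linear, verbose=False):
    # recurse depth-first so nested containers get rewritten too
    for name, module in model.named_children():
        replace_modules(module, klass, target=target, verbose=verbose)
        if isinstance(module, target):
            setattr(model, name, klass(module.in_features, module.out_features))
            if verbose:
                _logger.info(f"Replacing {name} with: {klass}")
    return model
```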
@ -491,7 +494,7 @@ def get_model_offload_policy(module, policy=None):
# does not fit in budget, increase device index
else:
device_index += 1
print(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")
_logger.info(f"Over budget for device: {device['name']}, shifting to next device: {name}, {size / (1024 ** 2)}MiB")
# to-do: check that all modules are exhausted
assert module_index >= len(modules)
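The policy being logged here greedily assigns modules to devices in order, spilling to the next device once the current budget is exceeded; a stripped-down sketch of that idea (module names, sizes, and budgets are made up):

```python
def assign_modules(module_sizes, budgets):
    """Greedy assignment: fill device 0 until over budget, then spill to device 1, and so on."""
    assignment, device_index, used = {}, 0, 0
    for name, size in module_sizes:
        if used + size > budgets[device_index] and device_index + 1 < len(budgets):
            device_index += 1  # over budget for this device, shift to the next one
            used = 0
        assignment[name] = device_index
        used += size
    return assignment

# e.g. two devices with roughly 8 GiB and 24 GiB usable (sizes in MiB)
print(assign_modules([("emb", 2_000), ("blocks", 9_000), ("head", 1_000)], [8_000, 24_000]))
```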
@ -528,9 +531,9 @@ def offload_model( model, policy=None ):
if not not [*module.named_children()]:
continue
try:
print( name, next(module.parameters()).device )
_logger.info( f"{name}: {next(module.parameters()).device}" )
except Exception as e:
print( name, "?" )
_logger.info( name, "?" )
pass
"""

View File

@ -3,9 +3,12 @@ from contextlib import contextmanager
import math
import torch
import torch.nn.functional as F
import logging
from ..config import cfg
_logger = logging.getLogger(__name__)
Embedding = torch.nn.Embedding
Linear = torch.nn.Linear
@ -95,7 +98,7 @@ if cfg.optimizations.tensorrt:
import torch_tensorrt
AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
except Exception as e:
print('Error while importing TensorRT:', str(e))
_logger.warning(f'Error while importing TensorRT: {str(e)}')
pass
def compile_model(model, backend="auto"):
@ -111,14 +114,14 @@ def compile_model(model, backend="auto"):
try:
from prodigyopt import Prodigy
except Exception as e:
print('Error while importing Prodigyopt:', str(e))
_logger.warning(f'Error while importing Prodigyopt: {str(e)}')
pass
# https://github.com/facebookresearch/schedule_free/
try:
import schedulefree
except Exception as e:
print('Error while importing Schedule_Free:', str(e))
_logger.warning(f'Error while importing Schedule_Free: {str(e)}')
pass
# backwards compat