fixes that a CPU-only pytorch needed

master
mrq 2023-08-06 15:14:05 +07:00
parent 93987ea5d6
commit 5cb28a210e
7 changed files with 18 additions and 16 deletions

@@ -10,7 +10,7 @@ import time
from dataclasses import asdict, dataclass
from dataclasses import dataclass, field
from functools import cached_property
from functools import cached_property, cache
from pathlib import Path
from omegaconf import OmegaConf
@@ -340,6 +340,9 @@ class Config(_Config):
inference: Inference = field(default_factory=lambda: Inference)
bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
def get_device(self):
return torch.cuda.current_device() if self.device == "cuda" else self.device
@property
def cache_dir(self):
return ".cache" / self.relpath

@@ -4,7 +4,7 @@ import copy
# import h5py
import json
import logging
import numpy as np
#import numpy as np
import os
import random
import torch
@@ -111,7 +111,7 @@ def collate_fn(samples: list[dict]):
def _seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
#np.random.seed(worker_seed)
random.seed(worker_seed)

@@ -45,7 +45,10 @@ from .base import TrainFeeder
_logger = logging.getLogger(__name__)
if not distributed_initialized() and cfg.trainer.backend == "local":
init_distributed(torch.distributed.init_process_group)
def _nop():
...
fn = _nop if cfg.device == "cpu" else torch.distributed.init_process_group
init_distributed(fn)
# A very naive engine implementation using barebones PyTorch
# to-do: implement lr_scheduling
@@ -276,7 +279,7 @@ class Engines(dict[str, Engine]):
stats.update(flatten_dict({ name.split("-")[0]: stat }))
return stats
def step(self, batch, feeder: TrainFeeder = default_feeder, device=torch.cuda.current_device()):
def step(self, batch, feeder: TrainFeeder = default_feeder, device=cfg.get_device()):
total_elapsed_time = 0
stats: Any = dict()

@@ -7,13 +7,13 @@ from .export import load_models
from .data import get_symmap, _get_symbols
class Classifier():
def __init__( self, width=300, height=80, config=None, ckpt=None, device="cuda", dtype="float32" ):
self.loading = True
self.device = device
def __init__( self, width=300, height=80, config=None, ckpt=None, device=cfg.get_device(), dtype="float32" ):
if config:
cfg.load_yaml( config )
self.loading = True
self.device = device
if ckpt:
self.load_model_from_ckpt( ckpt )
else:

@@ -57,7 +57,7 @@ class Model(nn.Module):
self,
image,
text = None,
text = None, #
sampling_temperature: float = 1.0,
):

@@ -6,7 +6,6 @@ import humanize
import json
import os
import logging
import numpy as np
import random
import selectors
import sys
@@ -173,7 +172,7 @@ def logger(data):
def seed(seed):
# Set up random seeds, after fork()
random.seed(seed + global_rank())
np.random.seed(seed + global_rank())
#np.random.seed(seed + global_rank())
torch.manual_seed(seed + global_rank())

@@ -24,7 +24,7 @@ def write_version(version_core, pre_release=True):
return version
with open("README.md", "r") as f:
with open("README.md", "r", encoding="utf-8") as f:
long_description = f.read()
setup(
@@ -41,15 +41,12 @@ setup(
"coloredlogs>=15.0.1",
"diskcache>=5.4.0",
"einops>=0.6.0",
"matplotlib>=3.6.0",
"numpy==1.23.0",
"omegaconf==2.0.6",
"tqdm>=4.64.1",
"humanize>=4.4.0",
"pandas>=1.5.0",
"torch>=1.13.0",
"torchaudio>=0.13.0",
"torchmetrics",
],
url="https://git.ecker.tech/mrq/resnet-classifier",