From 5cb28a210e8c774c8810a5de65f86b0de723e25a Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 6 Aug 2023 15:14:05 +0000 Subject: [PATCH] fixes that a CPU-only pytorch needed --- image_classifier/config.py | 5 ++++- image_classifier/data.py | 4 ++-- image_classifier/engines/base.py | 7 +++++-- image_classifier/inference.py | 8 ++++---- image_classifier/models/base.py | 2 +- image_classifier/utils/trainer.py | 3 +-- setup.py | 5 +---- 7 files changed, 18 insertions(+), 16 deletions(-) diff --git a/image_classifier/config.py b/image_classifier/config.py index 2f4cfce..e39a22e 100755 --- a/image_classifier/config.py +++ b/image_classifier/config.py @@ -10,7 +10,7 @@ import time from dataclasses import asdict, dataclass from dataclasses import dataclass, field -from functools import cached_property +from functools import cached_property, cache from pathlib import Path from omegaconf import OmegaConf @@ -340,6 +340,9 @@ class Config(_Config): inference: Inference = field(default_factory=lambda: Inference) bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes) + def get_device(self): + return torch.cuda.current_device() if self.device == "cuda" else self.device + @property def cache_dir(self): return ".cache" / self.relpath diff --git a/image_classifier/data.py b/image_classifier/data.py index 23004f1..51fdf2c 100755 --- a/image_classifier/data.py +++ b/image_classifier/data.py @@ -4,7 +4,7 @@ import copy # import h5py import json import logging -import numpy as np +#import numpy as np import os import random import torch @@ -111,7 +111,7 @@ def collate_fn(samples: list[dict]): def _seed_worker(worker_id): worker_seed = torch.initial_seed() % 2**32 - np.random.seed(worker_seed) + #np.random.seed(worker_seed) random.seed(worker_seed) diff --git a/image_classifier/engines/base.py b/image_classifier/engines/base.py index db246bd..24199ed 100755 --- a/image_classifier/engines/base.py +++ b/image_classifier/engines/base.py @@ -45,7 +45,10 @@ from .base import TrainFeeder _logger = logging.getLogger(__name__) if not distributed_initialized() and cfg.trainer.backend == "local": - init_distributed(torch.distributed.init_process_group) + def _nop(): + ... + fn = _nop if cfg.device == "cpu" else torch.distributed.init_process_group + init_distributed(fn) # A very naive engine implementation using barebones PyTorch # to-do: implement lr_sheduling @@ -276,7 +279,7 @@ class Engines(dict[str, Engine]): stats.update(flatten_dict({ name.split("-")[0]: stat })) return stats - def step(self, batch, feeder: TrainFeeder = default_feeder, device=torch.cuda.current_device()): + def step(self, batch, feeder: TrainFeeder = default_feeder, device=cfg.get_device()): total_elapsed_time = 0 stats: Any = dict() diff --git a/image_classifier/inference.py b/image_classifier/inference.py index 07949f4..0e70703 100755 --- a/image_classifier/inference.py +++ b/image_classifier/inference.py @@ -7,13 +7,13 @@ from .export import load_models from .data import get_symmap, _get_symbols class Classifier(): - def __init__( self, width=300, height=80, config=None, ckpt=None, device="cuda", dtype="float32" ): - self.loading = True - self.device = device - + def __init__( self, width=300, height=80, config=None, ckpt=None, device=cfg.get_device(), dtype="float32" ): if config: cfg.load_yaml( config ) + self.loading = True + self.device = device + if ckpt: self.load_model_from_ckpt( ckpt ) else: diff --git a/image_classifier/models/base.py b/image_classifier/models/base.py index 2f884c6..99b0322 100755 --- a/image_classifier/models/base.py +++ b/image_classifier/models/base.py @@ -57,7 +57,7 @@ class Model(nn.Module): self, image, - text = None, + text = None, # sampling_temperature: float = 1.0, ): diff --git a/image_classifier/utils/trainer.py b/image_classifier/utils/trainer.py index 399b329..88f6358 100755 --- a/image_classifier/utils/trainer.py +++ b/image_classifier/utils/trainer.py @@ -6,7 +6,6 @@ import humanize import json import os import logging -import numpy as np import random import selectors import sys @@ -173,7 +172,7 @@ def logger(data): def seed(seed): # Set up random seeds, after fork() random.seed(seed + global_rank()) - np.random.seed(seed + global_rank()) + #np.random.seed(seed + global_rank()) torch.manual_seed(seed + global_rank()) diff --git a/setup.py b/setup.py index 11d7c4e..c719379 100755 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ def write_version(version_core, pre_release=True): return version -with open("README.md", "r") as f: +with open("README.md", "r", encoding="utf-8") as f: long_description = f.read() setup( @@ -41,15 +41,12 @@ setup( "coloredlogs>=15.0.1", "diskcache>=5.4.0", "einops>=0.6.0", - "matplotlib>=3.6.0", - "numpy==1.23.0", "omegaconf==2.0.6", "tqdm>=4.64.1", "humanize>=4.4.0", "pandas>=1.5.0", "torch>=1.13.0", - "torchaudio>=0.13.0", "torchmetrics", ], url="https://git.ecker.tech/mrq/resnet-classifier",