Fixes needed to run with a CPU-only PyTorch build
This commit is contained in:
parent
93987ea5d6
commit
5cb28a210e
|
@ -10,7 +10,7 @@ import time
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from functools import cached_property
|
from functools import cached_property, cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
@ -340,6 +340,9 @@ class Config(_Config):
|
||||||
inference: Inference = field(default_factory=lambda: Inference)
|
inference: Inference = field(default_factory=lambda: Inference)
|
||||||
bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
|
bitsandbytes: BitsAndBytes = field(default_factory=lambda: BitsAndBytes)
|
||||||
|
|
||||||
|
def get_device(self):
|
||||||
|
return torch.cuda.current_device() if self.device == "cuda" else self.device
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cache_dir(self):
|
def cache_dir(self):
|
||||||
return ".cache" / self.relpath
|
return ".cache" / self.relpath
|
||||||
|
|
|
@ -4,7 +4,7 @@ import copy
|
||||||
# import h5py
|
# import h5py
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import numpy as np
|
#import numpy as np
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import torch
|
import torch
|
||||||
|
@ -111,7 +111,7 @@ def collate_fn(samples: list[dict]):
|
||||||
|
|
||||||
def _seed_worker(worker_id):
|
def _seed_worker(worker_id):
|
||||||
worker_seed = torch.initial_seed() % 2**32
|
worker_seed = torch.initial_seed() % 2**32
|
||||||
np.random.seed(worker_seed)
|
#np.random.seed(worker_seed)
|
||||||
random.seed(worker_seed)
|
random.seed(worker_seed)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,10 @@ from .base import TrainFeeder
|
||||||
_logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if not distributed_initialized() and cfg.trainer.backend == "local":
|
if not distributed_initialized() and cfg.trainer.backend == "local":
|
||||||
init_distributed(torch.distributed.init_process_group)
|
def _nop():
|
||||||
|
...
|
||||||
|
fn = _nop if cfg.device == "cpu" else torch.distributed.init_process_group
|
||||||
|
init_distributed(fn)
|
||||||
|
|
||||||
# A very naive engine implementation using barebones PyTorch
|
# A very naive engine implementation using barebones PyTorch
|
||||||
# to-do: implement lr_sheduling
|
# to-do: implement lr_sheduling
|
||||||
|
@ -276,7 +279,7 @@ class Engines(dict[str, Engine]):
|
||||||
stats.update(flatten_dict({ name.split("-")[0]: stat }))
|
stats.update(flatten_dict({ name.split("-")[0]: stat }))
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
def step(self, batch, feeder: TrainFeeder = default_feeder, device=torch.cuda.current_device()):
|
def step(self, batch, feeder: TrainFeeder = default_feeder, device=cfg.get_device()):
|
||||||
total_elapsed_time = 0
|
total_elapsed_time = 0
|
||||||
|
|
||||||
stats: Any = dict()
|
stats: Any = dict()
|
||||||
|
|
|
@ -7,13 +7,13 @@ from .export import load_models
|
||||||
from .data import get_symmap, _get_symbols
|
from .data import get_symmap, _get_symbols
|
||||||
|
|
||||||
class Classifier():
|
class Classifier():
|
||||||
def __init__( self, width=300, height=80, config=None, ckpt=None, device="cuda", dtype="float32" ):
|
def __init__( self, width=300, height=80, config=None, ckpt=None, device=cfg.get_device(), dtype="float32" ):
|
||||||
self.loading = True
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
if config:
|
if config:
|
||||||
cfg.load_yaml( config )
|
cfg.load_yaml( config )
|
||||||
|
|
||||||
|
self.loading = True
|
||||||
|
self.device = device
|
||||||
|
|
||||||
if ckpt:
|
if ckpt:
|
||||||
self.load_model_from_ckpt( ckpt )
|
self.load_model_from_ckpt( ckpt )
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -57,7 +57,7 @@ class Model(nn.Module):
|
||||||
self,
|
self,
|
||||||
|
|
||||||
image,
|
image,
|
||||||
text = None,
|
text = None, #
|
||||||
|
|
||||||
sampling_temperature: float = 1.0,
|
sampling_temperature: float = 1.0,
|
||||||
):
|
):
|
||||||
|
|
|
@ -6,7 +6,6 @@ import humanize
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import numpy as np
|
|
||||||
import random
|
import random
|
||||||
import selectors
|
import selectors
|
||||||
import sys
|
import sys
|
||||||
|
@ -173,7 +172,7 @@ def logger(data):
|
||||||
def seed(seed):
|
def seed(seed):
|
||||||
# Set up random seeds, after fork()
|
# Set up random seeds, after fork()
|
||||||
random.seed(seed + global_rank())
|
random.seed(seed + global_rank())
|
||||||
np.random.seed(seed + global_rank())
|
#np.random.seed(seed + global_rank())
|
||||||
torch.manual_seed(seed + global_rank())
|
torch.manual_seed(seed + global_rank())
|
||||||
|
|
||||||
|
|
||||||
|
|
5
setup.py
5
setup.py
|
@ -24,7 +24,7 @@ def write_version(version_core, pre_release=True):
|
||||||
return version
|
return version
|
||||||
|
|
||||||
|
|
||||||
with open("README.md", "r") as f:
|
with open("README.md", "r", encoding="utf-8") as f:
|
||||||
long_description = f.read()
|
long_description = f.read()
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
|
@ -41,15 +41,12 @@ setup(
|
||||||
"coloredlogs>=15.0.1",
|
"coloredlogs>=15.0.1",
|
||||||
"diskcache>=5.4.0",
|
"diskcache>=5.4.0",
|
||||||
"einops>=0.6.0",
|
"einops>=0.6.0",
|
||||||
"matplotlib>=3.6.0",
|
|
||||||
"numpy==1.23.0",
|
|
||||||
"omegaconf==2.0.6",
|
"omegaconf==2.0.6",
|
||||||
"tqdm>=4.64.1",
|
"tqdm>=4.64.1",
|
||||||
"humanize>=4.4.0",
|
"humanize>=4.4.0",
|
||||||
|
|
||||||
"pandas>=1.5.0",
|
"pandas>=1.5.0",
|
||||||
"torch>=1.13.0",
|
"torch>=1.13.0",
|
||||||
"torchaudio>=0.13.0",
|
|
||||||
"torchmetrics",
|
"torchmetrics",
|
||||||
],
|
],
|
||||||
url="https://git.ecker.tech/mrq/resnet-classifier",
|
url="https://git.ecker.tech/mrq/resnet-classifier",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user