DL-Art-School/dlas/models/image_generation/srflow/thops.py

53 lines
1.3 KiB
Python
Raw Normal View History

import torch
def sum(tensor, dim=None, keepdim=False):
    """Sum ``tensor`` over one or more dimensions.

    Args:
        tensor: Input tensor.
        dim: ``None`` to reduce over every element, a single ``int``, or an
            iterable of ints. Negative indices count from the end, as in
            ``torch.sum``. Repeated dimensions are reduced only once.
        keepdim: If True, reduced dimensions are retained with size 1.
            Ignored when ``dim`` is None.

    Returns:
        The reduced tensor (a 0-d tensor when ``dim`` is None).

    Raises:
        IndexError: If any entry of ``dim`` is out of range for ``tensor``.
    """
    if dim is None:
        # Reduce over all dimensions; keepdim is intentionally ignored here.
        return torch.sum(tensor)
    if isinstance(dim, int):
        dim = [dim]
    # torch accepts dim 0 / -1 on a 0-d tensor, so treat it as 1-d for range checks.
    ndim = max(tensor.dim(), 1)
    normalized = []
    for d in dim:
        if not -ndim <= d < ndim:
            raise IndexError(
                f"Dimension out of range (expected to be in range of "
                f"[{-ndim}, {ndim - 1}], but got {d})"
            )
        normalized.append(d % ndim)
    # Negative indices must be normalised BEFORE sorting: the squeeze offsets
    # below assume ascending, non-negative, unique positions.
    dims = sorted(set(normalized))
    for d in dims:
        tensor = tensor.sum(dim=d, keepdim=True)
    if not keepdim:
        # Each squeeze removes one axis, shifting every later axis left by one.
        for i, d in enumerate(dims):
            tensor = tensor.squeeze(d - i)
    return tensor
def mean(tensor, dim=None, keepdim=False):
    """Average ``tensor`` over one or more dimensions.

    Args:
        tensor: Input tensor (must have a dtype ``Tensor.mean`` supports).
        dim: ``None`` to average every element, a single ``int``, or an
            iterable of ints. Negative indices count from the end, as in
            ``torch.mean``. Repeated dimensions are reduced only once.
        keepdim: If True, reduced dimensions are retained with size 1.
            Ignored when ``dim`` is None.

    Returns:
        The reduced tensor (a 0-d tensor when ``dim`` is None).

    Raises:
        IndexError: If any entry of ``dim`` is out of range for ``tensor``.
    """
    if dim is None:
        # Average over all dimensions; keepdim is intentionally ignored here.
        return torch.mean(tensor)
    if isinstance(dim, int):
        dim = [dim]
    # torch accepts dim 0 / -1 on a 0-d tensor, so treat it as 1-d for range checks.
    ndim = max(tensor.dim(), 1)
    normalized = []
    for d in dim:
        if not -ndim <= d < ndim:
            raise IndexError(
                f"Dimension out of range (expected to be in range of "
                f"[{-ndim}, {ndim - 1}], but got {d})"
            )
        normalized.append(d % ndim)
    # Negative indices must be normalised BEFORE sorting: the squeeze offsets
    # below assume ascending, non-negative, unique positions.
    dims = sorted(set(normalized))
    for d in dims:
        tensor = tensor.mean(dim=d, keepdim=True)
    if not keepdim:
        # Each squeeze removes one axis, shifting every later axis left by one.
        for i, d in enumerate(dims):
            tensor = tensor.squeeze(d - i)
    return tensor
def split_feature(tensor, type="split"):
    """Split ``tensor`` into two halves along dimension 1.

    Args:
        tensor: Tensor with at least two dimensions; dimension 1 is split.
        type: ``"split"`` takes the first ``C // 2`` channels and the rest;
            ``"cross"`` takes the even-indexed and odd-indexed channels.

    Returns:
        A ``(first, second)`` tuple of tensor views.

    Raises:
        ValueError: If ``type`` is neither ``"split"`` nor ``"cross"``.
    """
    C = tensor.size(1)
    if type == "split":
        return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...]
    if type == "cross":
        return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
    # Previously this fell through and returned None, failing later at unpacking.
    raise ValueError(f"split_feature: unknown type {type!r}; expected 'split' or 'cross'")
def cat_feature(tensor_a, tensor_b):
    """Concatenate two tensors along dimension 1 (``tensor_a`` first)."""
    pieces = [tensor_a, tensor_b]
    return torch.cat(pieces, 1)
def pixels(tensor):
    """Return the product of the sizes of dimensions 2 and 3 as an ``int``."""
    height = tensor.shape[2]
    width = tensor.shape[3]
    return int(height * width)