# adm18/IMPAX/dataset.py
# Snapshot: 2025-09-16 13:20:19 +08:00 (199 lines, 6.4 KiB, Python)

import os
import random
from PIL import Image, ImageFilter, ImageMath
from scipy import ndimage
import numpy as np
import torch
# Side length in pixels of the square patches cropped by getpatch /
# IMPAXDataset; patch origins are also snapped to multiples of this value.
PATCH_SIZE: int = 256
# JPGDIR = '/media/nfs/SRS/IMPAX/'
# JPGDIR = '/shares/Public/IMPAX/'
def img_frombytes(data):
    """Build a 1-bit PIL image from a 2-D array.

    Each row of *data* is bit-packed with ``np.packbits`` along axis 1: every
    non-zero element becomes a set bit, and each row is zero-padded to a whole
    byte. The packed rows are decoded by ``Image.frombytes`` in mode ``'1'``.
    The resulting image size is ``(number of columns, number of rows)``.
    """
    width_height = tuple(reversed(data.shape))
    packed_rows = np.packbits(data, axis=1)
    return Image.frombytes(mode='1', size=width_height, data=packed_rows)
def getpatch(width, height, patch_size=None):
    """Pick a random top-left corner for a square patch inside an image.

    On each axis a position is drawn uniformly from ``[0, extent)`` and
    snapped down to a multiple of *patch_size*. It is then clamped so the
    patch fits inside the image.

    If the image is smaller than the patch along an axis, the offset on that
    axis is 0. Slicing then yields the whole (shorter) extent. A negative
    offset would make NumPy slicing count from the end of the array and
    return a truncated, misplaced region.

    Args:
        width: image width in pixels (must be >= 1).
        height: image height in pixels (must be >= 1).
        patch_size: side length of the patch; defaults to module PATCH_SIZE.

    Returns:
        ``(w, h)``: column and row offset of the patch's top-left corner.
    """
    if patch_size is None:
        patch_size = PATCH_SIZE
    # Draw width first, then height, so the RNG stream matches earlier runs.
    w = random.randint(0, width - 1) // patch_size * patch_size
    w = max(0, min(w, width - patch_size))
    h = random.randint(0, height - 1) // patch_size * patch_size
    h = max(0, min(h, height - patch_size))
    return w, h
class IMPAXDataset(object):
    """Random-patch dataset over an IMPAX export directory.

    Expected layout (derived from the path handling in ``__init__``):
    ``JPGDIR/<pid>/<study>_100/<file>.jpg``, with sibling ``<study>_90`` and
    ``<study>_AN`` directories holding same-named images. For every ``_100``
    image a binary mask is cached as ``<study>_TXT/<file>.png``. The mask is
    generated on first construction and reused afterwards.

    Each item is a random crop taken at the same location from three images:
    the ``_90`` image (input) and the ``_AN`` image stacked with the ``_TXT``
    mask (target). Both are returned as float tensors of raw 0-255 values.
    """

    def __init__(self, JPGDIR):
        """Index every ``*_100`` study image under *JPGDIR* and build missing masks.

        Args:
            JPGDIR: root directory containing one sub-directory per patient.

        Side effects:
            Creates ``*_TXT`` directories and writes any missing mask PNGs.
            Prints the smallest and largest ``(width, height)`` among the
            images whose mask was generated in this call. Images with an
            already cached mask are skipped and not counted.
        """
        self.ST_90 = []
        self.ST_100 = []
        self.ST_AN = []
        self.ST_TXT = []
        # Size statistics, only over images whose mask is generated here.
        self.MAXSHAPE = None
        self.MAXSIZE = 0
        self.MINSHAPE = None
        self.MINSIZE = 9999 * 9999
        self.gets = 0
        for pid in sorted(os.listdir(JPGDIR)):
            PATDIR = os.path.join(JPGDIR, pid)
            for study in sorted(os.listdir(PATDIR)):
                if not study.endswith('_100'):
                    continue
                ST100_DIR = os.path.join(PATDIR, study)
                # NOTE(review): str.replace rewrites *every* '_100' in the path
                # (including JPGDIR, pid and the file name), and '.jpg'
                # anywhere. This is fine for the expected layout but fragile.
                # Confirm that file names never contain these substrings.
                TXT_DIR = ST100_DIR.replace('_100', '_TXT')
                os.makedirs(TXT_DIR, exist_ok=True)
                for jpg in sorted(os.listdir(ST100_DIR)):
                    jpg_path = os.path.join(ST100_DIR, jpg)
                    txt_path = jpg_path.replace('_100', '_TXT').replace('.jpg', '.png')
                    self.ST_100.append(jpg_path)
                    self.ST_90.append(jpg_path.replace('_100', '_90'))
                    self.ST_AN.append(jpg_path.replace('_100', '_AN'))
                    self.ST_TXT.append(txt_path)
                    if os.path.isfile(txt_path):
                        continue  # mask already cached; skip the expensive work
                    with Image.open(jpg_path) as src:
                        img = src.convert('L')
                    self._update_shape_stats(*img.size)
                    img_frombytes(self._text_mask(np.array(img))).save(txt_path)
        if self.MINSHAPE:
            print(self.MINSHAPE)
        if self.MAXSHAPE:
            print(self.MAXSHAPE)

    def _update_shape_stats(self, width, height):
        """Track the largest and smallest image (by pixel count) seen so far."""
        size = width * height
        if self.MAXSIZE < size:
            self.MAXSIZE = size
            self.MAXSHAPE = width, height
        # The MINSHAPE check makes the first image count even if it exceeds
        # the initial MINSIZE sentinel.
        if self.MINSHAPE is None or self.MINSIZE > size:
            self.MINSIZE = size
            self.MINSHAPE = width, height

    @staticmethod
    def _text_mask(gray):
        """Return a uint8 0/1 mask of candidate overlay pixels in *gray*.

        A pixel is kept only when all of the following hold:

        * its value is in the 0xCB-0xCD grey band;
        * the pixel one row down and one column right is either in that band
          or nearly black (<= 0x01);
        * at least two pixels of its 3x3 window survive the first two tests.
          Isolated hits are therefore dropped. At the border, the filter's
          default 'reflect' mode can mirror a pixel into its own window.

        NOTE(review): the grey-band / dark-diagonal pattern presumably targets
        burned-in annotation text with a drop shadow. Confirm with the data.
        """
        grey = np.logical_and(gray >= 0xCB, gray <= 0xCD)
        grey_or_dark = np.logical_or(grey, gray <= 0x01)
        # After both rolls, diag[i, j] holds grey_or_dark[i + 1, j + 1].
        # np.roll wraps, so the last row/column compare against the first.
        diag = np.roll(grey_or_dark, -1, 0)
        diag = np.roll(diag, -1, 1)
        masked = np.logical_and(grey, diag).astype('uint8')
        # The second-largest value in a 3x3 window is 1 only when >= 2 pixels
        # in it are set, so taking the minimum removes isolated pixels.
        filtered = ndimage.rank_filter(masked, rank=-2, size=3)
        return np.minimum(masked, filtered)

    @staticmethod
    def _load_gray(path):
        """Load *path* as an 8-bit greyscale numpy array, closing the file."""
        with Image.open(path) as im:
            return np.array(im.convert('L'))

    def __getitem__(self, idx):
        """Return a random, spatially aligned patch pair for sample *idx*.

        Returns:
            ``(input, target)``:

            * ``input`` is the ``_90`` patch, shape ``(1, h, w)``.
            * ``target`` stacks the ``_AN`` patch and the ``_TXT`` mask patch,
              shape ``(2, h, w)``.

            ``h`` and ``w`` equal PATCH_SIZE when the image is at least that
            large. Both tensors are float32 with raw 0-255 values, not
            normalised. The patch location comes from the ``_90`` image size.
        """
        st_90 = self._load_gray(self.ST_90[idx])
        st_AN = self._load_gray(self.ST_AN[idx])
        st_TX = self._load_gray(self.ST_TXT[idx])
        height, width = st_90.shape
        w, h = getpatch(width, height)
        window = (slice(h, h + PATCH_SIZE), slice(w, w + PATCH_SIZE))
        s2_90 = st_90[window][np.newaxis]
        s2_AN_TX = np.stack((st_AN[window], st_TX[window]))
        return torch.from_numpy(s2_90).float(), torch.from_numpy(s2_AN_TX).float()

    def __len__(self):
        """Number of indexed ``_100`` images (one sample per image)."""
        return len(self.ST_100)