"""Patch dataset built from IMPAX JPEG studies.

Expected layout under the dataset root, as derived from the path handling
below: ``<root>/<patient>/<study>_100/<image>.jpg`` with sibling study
directories whose names end in ``_90`` and ``_AN``.  A ``_TXT`` sibling
directory is created on demand and filled with 1-bit PNG masks computed from
the ``_100`` images.
"""
import os
import random

import numpy as np
import torch
from PIL import Image, ImageFilter, ImageMath
from scipy import ndimage

# Edge length, in pixels, of the square patches served by the dataset.
PATCH_SIZE = 256

# Example dataset roots:
# JPGDIR = '/media/nfs/SRS/IMPAX/'
# JPGDIR = '/shares/Public/IMPAX/'


def img_frombytes(data):
    """Build a 1-bit (mode ``'1'``) PIL image from a 2-D array of 0/1 values.

    Args:
        data: 2-D ``uint8`` (or bool) array indexed as ``[row, column]``.

    Returns:
        A ``PIL.Image.Image`` in mode ``'1'`` of size ``(columns, rows)``.
    """
    # PIL sizes are (width, height); numpy shapes are (rows, columns).
    size = data.shape[::-1]
    # Pack eight pixels per byte along each row; every row is padded to a
    # whole number of bytes.
    databytes = np.packbits(data, axis=1)
    return Image.frombytes(mode='1', size=size, data=databytes)


def getpatch(width, height, patch_size=PATCH_SIZE):
    """Pick a random patch origin inside a ``width`` x ``height`` image.

    The origin is drawn on a grid of ``patch_size`` steps.  A cell that would
    run past the right/bottom edge is pulled back so the patch ends exactly
    at the edge.

    If a dimension is smaller than ``patch_size`` the origin along it is 0,
    so the caller gets the whole (short) extent.  Previously the origin went
    negative there, which made numpy slicing wrap around and select the wrong
    region.

    Args:
        width: Image width in pixels (must be >= 1).
        height: Image height in pixels (must be >= 1).
        patch_size: Patch edge length in pixels; defaults to ``PATCH_SIZE``.

    Returns:
        Tuple ``(w, h)``: the left and top coordinates of the patch.

    Raises:
        ValueError: If ``width`` or ``height`` is less than 1 (raised by
            ``random.randint`` on an empty range).
    """
    def _origin(extent):
        # Snap a uniformly drawn pixel down to the patch grid.
        start = random.randint(0, extent - 1) // patch_size * patch_size
        # Keep the patch inside the image, but never start before pixel 0.
        return max(0, min(start, extent - patch_size))

    # Draw width first, then height, to preserve the random-number order.
    w = _origin(width)
    h = _origin(height)
    return w, h


def _text_mask(gray):
    """Compute the binary mask that is stored in the ``_TXT`` directory.

    Args:
        gray: 2-D ``uint8`` array holding a grayscale image.

    Returns:
        2-D ``uint8`` array of 0/1 values with the same shape as ``gray``.
    """
    # Candidate pixels: gray level within [0xCB, 0xCD].
    # NOTE(review): presumably the gray level of burned-in overlay text
    # after JPEG compression -- confirm.
    candidate = np.logical_and(gray >= 0xCB, gray <= 0xCD)
    near_black = gray <= 0x01

    # Shift "candidate or near-black" one pixel up and one pixel left, so
    # position (y, x) holds the value of its lower-right diagonal neighbour.
    # np.roll wraps around at the image border.
    neighbour = np.logical_or(candidate, near_black)
    neighbour = np.roll(neighbour, -1, 0)
    neighbour = np.roll(neighbour, -1, 1)

    # Keep candidates whose lower-right neighbour is a candidate or near-black.
    masked = np.logical_and(candidate, neighbour).astype('uint8')

    # rank=-2 takes the second-largest value in each 3x3 window, which for a
    # 0/1 image is 1 only when at least two pixels in the window are set.
    # The minimum with the mask therefore drops isolated single pixels.
    second_largest = ndimage.rank_filter(masked, rank=-2, size=3)
    return np.minimum(masked, second_largest)


class IMPAXDataset(object):
    """Random-patch dataset over paired ``_90`` / ``_AN`` / ``_TXT`` images.

    Attributes:
        ST_100, ST_90, ST_AN, ST_TXT: Parallel lists of file paths; index
            ``i`` in each list refers to the same source image.
        MAXSHAPE, MAXSIZE: ``(width, height)`` and pixel count of the largest
            image processed during construction.
        MINSHAPE, MINSIZE: Same for the smallest image processed.
        gets: Counter initialised to 0 (not updated by this class).
    """

    def __init__(self, JPGDIR):
        """Index every ``*_100`` study under ``JPGDIR`` and create missing masks.

        For each image in a ``*_100`` study directory the matching ``_90``,
        ``_AN`` and ``_TXT`` paths are derived by string replacement.  If the
        ``_TXT`` mask does not exist yet it is computed and saved as a PNG.
        Shape statistics are collected only for images whose mask is generated
        in this run, and are printed at the end when present.

        NOTE: the sibling paths come from ``str.replace('_100', ...)``, which
        rewrites every occurrence of ``_100`` in the path, not just the study
        suffix.

        Args:
            JPGDIR: Root directory containing one sub-directory per patient.
        """
        self.ST_90 = []
        self.ST_100 = []
        self.ST_AN = []
        self.ST_TXT = []
        self.MAXSHAPE = None
        self.MAXSIZE = 0
        self.MINSHAPE = None
        self.MINSIZE = 9999 * 9999
        self.gets = 0

        # Sorted traversal keeps the four path lists aligned and reproducible.
        for pid in sorted(os.listdir(JPGDIR)):
            PATDIR = os.path.join(JPGDIR, pid)
            if not os.path.isdir(PATDIR):
                # Stray files in the root would make os.listdir raise.
                continue
            for study in sorted(os.listdir(PATDIR)):
                if not study.endswith('_100'):
                    continue
                ST100_DIR = os.path.join(PATDIR, study)
                TXT_DIR = ST100_DIR.replace('_100', '_TXT')
                os.makedirs(TXT_DIR, exist_ok=True)
                for jpg in sorted(os.listdir(ST100_DIR)):
                    jpg_path = os.path.join(ST100_DIR, jpg)
                    txt_path = (jpg_path.replace('_100', '_TXT')
                                .replace('.jpg', '.png'))
                    self.ST_100.append(jpg_path)
                    self.ST_90.append(jpg_path.replace('_100', '_90'))
                    self.ST_AN.append(jpg_path.replace('_100', '_AN'))
                    self.ST_TXT.append(txt_path)

                    if os.path.isfile(txt_path):
                        # Mask already generated by an earlier run.
                        continue

                    # convert() returns an independent, fully loaded image,
                    # so the source file can be closed straight away.
                    with Image.open(jpg_path) as src:
                        img = src.convert('L')
                    self._track_shape(img.size)

                    mask = _text_mask(np.array(img))
                    img_frombytes(mask).save(txt_path)

        if self.MINSHAPE:
            print(self.MINSHAPE)
        if self.MAXSHAPE:
            print(self.MAXSHAPE)

    def _track_shape(self, shape):
        """Update the min/max shape statistics with one ``(width, height)``."""
        width, height = shape
        size = width * height
        if self.MAXSIZE < size:
            self.MAXSIZE = size
            self.MAXSHAPE = width, height
        if self.MINSIZE > size:
            self.MINSIZE = size
            self.MINSHAPE = width, height

    def __getitem__(self, idx):
        """Return one random patch pair for image ``idx``.

        The same randomly chosen window is cut from the ``_90`` image, the
        ``_AN`` image and the ``_TXT`` mask.  The window origin is chosen from
        the size of the ``_90`` image.
        NOTE(review): assumes all three images share that size -- confirm.

        Args:
            idx: Index into the path lists.

        Returns:
            Tuple ``(inp, target)`` of float tensors: ``inp`` is the ``_90``
            patch with a leading channel axis; ``target`` stacks the ``_AN``
            patch and the ``_TXT`` patch along a new leading axis.
        """
        # Load as 8-bit grayscale; close each file as soon as it is decoded.
        with Image.open(self.ST_90[idx]) as src:
            st_90 = src.convert('L')
        with Image.open(self.ST_AN[idx]) as src:
            st_AN = src.convert('L')
        with Image.open(self.ST_TXT[idx]) as src:
            st_TX = src.convert('L')

        width, height = st_90.size
        w, h = getpatch(width, height)
        rows = slice(h, h + PATCH_SIZE)
        cols = slice(w, w + PATCH_SIZE)

        s2_90 = np.array(st_90)[np.newaxis, rows, cols]
        s2_AN = np.array(st_AN)[rows, cols]
        s2_TX = np.array(st_TX)[rows, cols]
        s2_AN_TX = np.stack((s2_AN, s2_TX))

        return (torch.from_numpy(s2_90).float(),
                torch.from_numpy(s2_AN_TX).float())

    def __len__(self):
        """Return the number of indexed source images."""
        return len(self.ST_100)