import torch
import numpy as np
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS


def create_coordinate_grid(
    shape: tuple[int, int, int],
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build a fixed 3D voxel-index coordinate grid.

    Parameters:
        shape: (Z, Y, X) size of the volume.
        device: device on which the grid tensors are allocated.
        dtype: dtype of the returned coordinate tensors.

    Returns:
        z_t, y_t, x_t: tensors of shape (Z, Y, X) holding each voxel's z, y
        and x index respectively ('ij' indexing). Passing this tuple as the
        ``grid`` argument of the cylinder generators avoids rebuilding it on
        every call.
    """
    z_axis = torch.arange(shape[0], device=device, dtype=dtype)
    y_axis = torch.arange(shape[1], device=device, dtype=dtype)
    x_axis = torch.arange(shape[2], device=device, dtype=dtype)
    z_t, y_t, x_t = torch.meshgrid(z_axis, y_axis, x_axis, indexing="ij")
    return z_t, y_t, x_t


def snap_to_discrete_values(diameter_raw, length_raw):
    """Map continuous values to the nearest allowed discrete values.

    Parameters:
        diameter_raw: continuous diameter value proposed by PSO.
        length_raw: continuous length value proposed by PSO.

    Returns:
        diameter_discrete, length_discrete: the closest entries of
        ALLOWED_DIAMETERS and ALLOWED_LENGTHS. On a tie, the candidate that
        appears first in the allowed sequence wins (``min`` semantics).
    """
    diameter_discrete = min(ALLOWED_DIAMETERS, key=lambda v: abs(v - diameter_raw))
    length_discrete = min(ALLOWED_LENGTHS, key=lambda v: abs(v - length_raw))
    return diameter_discrete, length_discrete


# Supported isotropic spacings -> factor applied to diameter/length before they
# are compared against voxel-index distances. The 0.5 spacing uses a factor of
# 2 (the original code comments this as "for resampling").
_SPACING_SCALE = {
    (1.0, 1.0, 1.0): 1.0,
    (0.5, 0.5, 0.5): 2.0,
}


def _spacing_key(spacing):
    """Normalize a spacing sequence (list, tuple, array) to a tuple of floats, or None."""
    try:
        return tuple(float(s) for s in spacing)
    except (TypeError, ValueError):
        return None


def _spacing_scale(spacing) -> float:
    """Return the size scale factor for a supported spacing; raise ValueError otherwise."""
    scale = _SPACING_SCALE.get(_spacing_key(spacing))
    if scale is None:
        raise ValueError(f"Unsupported spacing: {spacing}")
    return scale


def _cylinder_frame_torch(
    position_z, position_y, position_x, azimuth, altitude, shape, device, grid
):
    """Rotate every voxel's offset from the cylinder origin into the cylinder frame.

    Returns (x_rot, y_rot, z_rot). z_rot is the component along the cylinder
    axis. With altitude == 0 the axis is the +z voxel axis. x_rot and y_rot are
    the perpendicular components. Angles are in degrees.
    """
    if grid is None:
        z_t, y_t, x_t = create_coordinate_grid(shape, device)
    else:
        z_t, y_t, x_t = grid

    azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32))
    altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32))
    cos_az, sin_az = torch.cos(azimuth_rad_t), torch.sin(azimuth_rad_t)
    cos_alt, sin_alt = torch.cos(altitude_rad_t), torch.sin(altitude_rad_t)

    # Offsets from the cylinder origin (in voxel indices).
    dz = z_t - position_z
    dy = y_t - position_y
    dx = x_t - position_x

    x_rot = dx * cos_az * cos_alt + dy * sin_az * cos_alt - dz * sin_alt
    y_rot = -dx * sin_az + dy * cos_az
    z_rot = dx * cos_az * sin_alt + dy * sin_az * sin_alt + dz * cos_alt
    return x_rot, y_rot, z_rot


def generate_cylinder_n_torch(
    diameter: float,
    length: float,
    position_z: float,
    position_y: float,
    position_x: float,
    azimuth: float,
    altitude: float,
    shape: tuple[int, int, int],
    spacing: list[float],
    device: torch.device,
    grid: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> torch.Tensor:
    """Generate a "forward" (positive axial direction) cylinder mask in 3D.

    Positions are voxel indices. Diameter and length are multiplied by 2 for
    spacing 0.5, so they are presumably in physical units (e.g. mm). TODO
    confirm. Angles are in degrees. ``spacing`` may be any 3-element sequence
    equal to (1, 1, 1) or (0.5, 0.5, 0.5).

    Returns:
        torch.uint8 mask of shape ``shape`` on ``device``. A voxel is 1 when it
        lies within ``radius`` of the axis and 0 <= axial offset <= length
        (both scaled for spacing).

    Raises:
        ValueError: if ``spacing`` is not a supported spacing.
    """
    scale = _spacing_scale(spacing)
    x_rot, y_rot, z_rot = _cylinder_frame_torch(
        position_z, position_y, position_x, azimuth, altitude, shape, device, grid
    )
    radius = diameter / 2.0 * scale
    mask = (
        (x_rot**2 + y_rot**2 <= radius**2)
        & (z_rot >= 0)
        & (z_rot <= length * scale)
    )
    return mask.to(torch.uint8)


def generate_cylinder_o_torch(
    diameter: float,
    length: float,
    position_z: float,
    position_y: float,
    position_x: float,
    azimuth: float,
    altitude: float,
    shape: tuple[int, int, int],
    spacing: list[float],
    device: torch.device,
    grid: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> torch.Tensor:
    """Generate an "opposite" (negative axial direction) cylinder mask in 3D.

    Same conventions as ``generate_cylinder_n_torch``, but the cylinder extends
    from the origin towards negative axial offsets: -length <= z_rot <= 0
    (length scaled for spacing).

    Returns:
        torch.uint8 mask of shape ``shape`` on ``device``.

    Raises:
        ValueError: if ``spacing`` is not a supported spacing.
    """
    scale = _spacing_scale(spacing)
    x_rot, y_rot, z_rot = _cylinder_frame_torch(
        position_z, position_y, position_x, azimuth, altitude, shape, device, grid
    )
    radius = diameter / 2.0 * scale
    # Bounded at BOTH ends. The previous 1 mm branch used `z_rot <= length`,
    # which is always true here, so that cylinder had no far end.
    mask = (
        (x_rot**2 + y_rot**2 <= radius**2)
        & (z_rot <= 0)
        & (z_rot >= -length * scale)
    )
    return mask.to(torch.uint8)


def generate_cylinder_numpy(diameter, length, position_z, position_y, position_x,
                            azimuth, altitude, shape, spacing):
    """NumPy (CPU, float64) counterpart of ``generate_cylinder_n_torch``.

    Returns:
        np.uint8 array of shape ``shape``. A voxel is 1 inside the
        forward-direction cylinder and 0 elsewhere.

    Raises:
        ValueError: if ``spacing`` is not a supported spacing. Previously this
        crashed with UnboundLocalError.
    """
    scale = _spacing_scale(spacing)
    azimuth = np.radians(azimuth)
    altitude = np.radians(altitude)

    z, y, x = np.mgrid[0:shape[0], 0:shape[1], 0:shape[2]].astype(np.float64)
    z -= float(position_z)
    y -= float(position_y)
    x -= float(position_x)

    x_rot = x * np.cos(azimuth) * np.cos(altitude) + y * np.sin(azimuth) * np.cos(altitude) - z * np.sin(altitude)
    y_rot = -x * np.sin(azimuth) + y * np.cos(azimuth)
    z_rot = x * np.cos(azimuth) * np.sin(altitude) + y * np.sin(azimuth) * np.sin(altitude) + z * np.cos(altitude)

    radius = diameter / 2.0 * scale
    cylinder = (x_rot**2 + y_rot**2 <= radius**2) & (z_rot >= 0) & (z_rot <= length * scale)
    return cylinder.astype(np.uint8)


def generate_cylinder_tip_torch(
    diameter,
    length,
    position_z,
    position_y,
    position_x,
    azimuth,
    altitude,
    shape,
    spacing,
    device,
    grid=None,
    tip_ratio=0.2,  # fraction of the length, measured from the far end, kept as the tip
) -> torch.Tensor:
    """Generate a mask of only the far-end segment (the "tip") of a forward cylinder.

    The tip is the part of the ``generate_cylinder_n_torch`` cylinder whose axial
    offset lies in [total_length * (1 - tip_ratio), total_length].

    Unlike the other generators, an unrecognized spacing does not raise. It
    falls back to scale 1 (1x1x1 behavior).

    Returns:
        torch.uint8 mask of shape ``shape`` on ``device``.
    """
    # Lenient lookup: an unsupported spacing means scale 1.0 instead of ValueError.
    scale = _SPACING_SCALE.get(_spacing_key(spacing), 1.0)
    x_rot, y_rot, z_rot = _cylinder_frame_torch(
        position_z, position_y, position_x, azimuth, altitude, shape, device, grid
    )

    radius = diameter / 2.0 * scale
    total_length = length * scale
    tip_start = total_length * (1 - tip_ratio)  # axial offset where the tip begins

    mask = (
        (x_rot**2 + y_rot**2 <= radius**2)
        & (z_rot >= tip_start)
        & (z_rot <= total_length)
    )
    return mask.to(torch.uint8)