256 lines
No EOL
8 KiB
Python
256 lines
No EOL
8 KiB
Python
import torch
|
|
import numpy as np
|
|
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
|
|
|
|
def create_coordinate_grid(
    shape: tuple[int, int, int],
    device: torch.device,
    dtype: torch.dtype = torch.float32
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build a fixed, dense 3D voxel-index coordinate grid.

    Parameters:
        shape: (Z, Y, X) extent of the grid; only the first three entries are read.
        device: device on which the coordinate tensors are allocated.
        dtype: element type of the coordinate tensors (float32 by default).

    Returns:
        (z_t, y_t, x_t), each of shape (Z, Y, X). ``z_t[k, j, i] == k``,
        ``y_t[k, j, i] == j`` and ``x_t[k, j, i] == i`` ('ij' indexing).
    """
    sizes = (shape[0], shape[1], shape[2])
    axes = [torch.arange(n, device=device, dtype=dtype) for n in sizes]
    zz, yy, xx = torch.meshgrid(*axes, indexing='ij')
    return zz, yy, xx
|
|
|
|
def snap_to_discrete_values(diameter_raw, length_raw):
    """Map continuous values to the closest allowed discrete values.

    Parameters:
        diameter_raw: continuous diameter value proposed by the PSO optimizer.
        length_raw: continuous length value proposed by the PSO optimizer.

    Returns:
        (diameter_discrete, length_discrete): the entries of ALLOWED_DIAMETERS
        and ALLOWED_LENGTHS nearest to the raw inputs. On an exact tie,
        ``min`` keeps the candidate that appears first in the allowed list.
    """
    def _closest(candidates, target):
        # Smallest absolute distance wins; ties resolve to the earliest candidate.
        return min(candidates, key=lambda option: abs(option - target))

    snapped_diameter = _closest(ALLOWED_DIAMETERS, diameter_raw)
    snapped_length = _closest(ALLOWED_LENGTHS, length_raw)
    return snapped_diameter, snapped_length
|
|
|
|
|
|
def generate_cylinder_n_torch(
|
|
diameter: float,
|
|
length: float,
|
|
position_z: float,
|
|
position_y: float,
|
|
position_x: float,
|
|
azimuth: float,
|
|
altitude: float,
|
|
shape: tuple[int, int, int],
|
|
spacing: list[float],
|
|
device: torch.device,
|
|
grid: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None
|
|
) -> torch.Tensor:
|
|
|
|
"""
|
|
Generate a "forward" (positive z) cylinder mask in 3D space using PyTorch tensors.
|
|
Returns a binary mask (torch.uint8) on the specified device.
|
|
"""
|
|
|
|
if grid is None:
|
|
z_t, y_t, x_t = create_coordinate_grid(shape, device)
|
|
else:
|
|
z_t, y_t, x_t = grid
|
|
|
|
azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32))
|
|
altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32))
|
|
|
|
# Shift
|
|
z_t = z_t - position_z
|
|
y_t = y_t - position_y
|
|
x_t = x_t - position_x
|
|
|
|
# Apply rotation (same formula as your NumPy version, but in torch)
|
|
x_rot = (
|
|
x_t * torch.cos(azimuth_rad_t) * torch.cos(altitude_rad_t)
|
|
+ y_t * torch.sin(azimuth_rad_t) * torch.cos(altitude_rad_t)
|
|
- z_t * torch.sin(altitude_rad_t)
|
|
)
|
|
y_rot = -x_t * torch.sin(azimuth_rad_t) + y_t * torch.cos(azimuth_rad_t)
|
|
z_rot = (
|
|
x_t * torch.cos(azimuth_rad_t) * torch.sin(altitude_rad_t)
|
|
+ y_t * torch.sin(azimuth_rad_t) * torch.sin(altitude_rad_t)
|
|
+ z_t * torch.cos(altitude_rad_t)
|
|
)
|
|
|
|
# Handle spacing
|
|
# You can expand or generalize for more spacing options
|
|
if spacing == [1, 1, 1]:
|
|
radius = diameter / 2.0
|
|
mask = (
|
|
(x_rot**2 + y_rot**2 <= radius**2)
|
|
& (z_rot >= 0)
|
|
& (z_rot <= length)
|
|
)
|
|
elif spacing == [0.5, 0.5, 0.5]:
|
|
radius = (diameter / 2.0) * 2
|
|
mask = (
|
|
(x_rot**2 + y_rot**2 <= radius**2)
|
|
& (z_rot >= 0)
|
|
& (z_rot <= length * 2)
|
|
)
|
|
else:
|
|
raise ValueError(f"Unsupported spacing: {spacing}")
|
|
|
|
# Convert boolean mask to uint8
|
|
cylinder_mask = mask.to(torch.uint8)
|
|
|
|
return cylinder_mask
|
|
|
|
def generate_cylinder_o_torch(
|
|
diameter: float,
|
|
length: float,
|
|
position_z: float,
|
|
position_y: float,
|
|
position_x: float,
|
|
azimuth: float,
|
|
altitude: float,
|
|
shape: tuple[int, int, int],
|
|
spacing: list[float],
|
|
device: torch.device,
|
|
grid: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None
|
|
) -> torch.Tensor:
|
|
"""
|
|
Generate an "opposite" (negative z) cylinder mask in 3D space using PyTorch tensors.
|
|
Returns a binary mask (torch.uint8) on the specified device.
|
|
"""
|
|
|
|
if grid is None:
|
|
z_t, y_t, x_t = create_coordinate_grid(shape, device)
|
|
else:
|
|
z_t, y_t, x_t = grid
|
|
|
|
# Convert angles to torch
|
|
azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32))
|
|
altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32))
|
|
|
|
# Shift
|
|
z_t = z_t - position_z
|
|
y_t = y_t - position_y
|
|
x_t = x_t - position_x
|
|
|
|
# Apply rotation
|
|
x_rot = (
|
|
x_t * torch.cos(azimuth_rad_t) * torch.cos(altitude_rad_t)
|
|
+ y_t * torch.sin(azimuth_rad_t) * torch.cos(altitude_rad_t)
|
|
- z_t * torch.sin(altitude_rad_t)
|
|
)
|
|
y_rot = -x_t * torch.sin(azimuth_rad_t) + y_t * torch.cos(azimuth_rad_t)
|
|
z_rot = (
|
|
x_t * torch.cos(azimuth_rad_t) * torch.sin(altitude_rad_t)
|
|
+ y_t * torch.sin(azimuth_rad_t) * torch.sin(altitude_rad_t)
|
|
+ z_t * torch.cos(altitude_rad_t)
|
|
)
|
|
|
|
# Handle spacing
|
|
if spacing == [1, 1, 1]:
|
|
radius = diameter / 2.0
|
|
mask = (
|
|
(x_rot**2 + y_rot**2 <= radius**2)
|
|
& (z_rot <= 0)
|
|
& (z_rot <= length)
|
|
)
|
|
elif spacing == [0.5, 0.5, 0.5]:
|
|
radius = (diameter / 2.0) * 2
|
|
mask = (
|
|
(x_rot**2 + y_rot**2 <= radius**2)
|
|
& (z_rot <= 0)
|
|
& (z_rot >= -length * 2)
|
|
)
|
|
else:
|
|
raise ValueError(f"Unsupported spacing: {spacing}")
|
|
|
|
cylinder_mask_o = mask.to(torch.uint8)
|
|
|
|
return cylinder_mask_o
|
|
|
|
def generate_cylinder_numpy(diameter, length, position_z, position_y, position_x, azimuth, altitude, shape, spacing):
    """NumPy implementation of the "forward" (positive z) cylinder mask.

    The grid is shifted so that (position_z, position_y, position_x) is the
    origin. It is then rotated by ``azimuth`` in the x-y plane and by
    ``altitude`` in the resulting x-z plane (both in degrees). A voxel is
    inside when its rotated radial distance is <= diameter / 2 and
    0 <= z_rot <= length.

    Parameters:
        diameter, length: cylinder size. Both are multiplied by 2 for spacing
            0.5 (resampling).
        position_z, position_y, position_x: anchor of the cylinder, in
            grid-index coordinates.
        azimuth, altitude: orientation angles in degrees.
        shape: (Z, Y, X) grid shape.
        spacing: voxel spacing. Only [1, 1, 1] and [0.5, 0.5, 0.5] are
            supported. Any sequence with those values is accepted.

    Returns:
        np.uint8 array of ``shape`` with 1 inside the cylinder and 0 elsewhere.

    Raises:
        ValueError: if ``spacing`` is not supported. Previously this case
            crashed with UnboundLocalError.
    """
    # Normalize spacing so tuples / arrays compare equal to the supported values.
    try:
        spacing_key = tuple(float(s) for s in spacing)
    except (TypeError, ValueError):
        spacing_key = None

    if spacing_key == (1.0, 1.0, 1.0):
        scale = 1.0
    elif spacing_key == (0.5, 0.5, 0.5):
        scale = 2.0  # *2 is for resampling
    else:
        raise ValueError(f"Unsupported spacing: {spacing}")

    azimuth_rad = np.radians(azimuth)
    altitude_rad = np.radians(altitude)
    cos_az, sin_az = np.cos(azimuth_rad), np.sin(azimuth_rad)
    cos_alt, sin_alt = np.cos(altitude_rad), np.sin(altitude_rad)

    # Dense float64 index grid; equivalent to np.mgrid[...].astype(np.float64).
    z, y, x = np.indices((shape[0], shape[1], shape[2]), dtype=np.float64)
    z -= float(position_z)
    y -= float(position_y)
    x -= float(position_x)

    x_rot = x * cos_az * cos_alt + y * sin_az * cos_alt - z * sin_alt
    y_rot = -x * sin_az + y * cos_az
    z_rot = x * cos_az * sin_alt + y * sin_az * sin_alt + z * cos_alt

    radius = diameter / 2.0 * scale
    cylinder = (x_rot**2 + y_rot**2 <= radius**2) & (z_rot >= 0) & (z_rot <= length * scale)

    return cylinder.astype(np.uint8)
|
|
|
|
def generate_cylinder_tip_torch(
|
|
diameter, length,
|
|
position_z, position_y, position_x,
|
|
azimuth, altitude,
|
|
shape, spacing, device, grid=None,
|
|
tip_ratio=0.2 # 取末端 20% 當尖端
|
|
) -> torch.Tensor:
|
|
"""只生成圓柱末端的 mask"""
|
|
|
|
if grid is None:
|
|
z_t, y_t, x_t = create_coordinate_grid(shape, device)
|
|
else:
|
|
z_t, y_t, x_t = grid
|
|
|
|
azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32))
|
|
altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32))
|
|
|
|
z_t = z_t - position_z
|
|
y_t = y_t - position_y
|
|
x_t = x_t - position_x
|
|
|
|
x_rot = (
|
|
x_t * torch.cos(azimuth_rad_t) * torch.cos(altitude_rad_t)
|
|
+ y_t * torch.sin(azimuth_rad_t) * torch.cos(altitude_rad_t)
|
|
- z_t * torch.sin(altitude_rad_t)
|
|
)
|
|
y_rot = -x_t * torch.sin(azimuth_rad_t) + y_t * torch.cos(azimuth_rad_t)
|
|
z_rot = (
|
|
x_t * torch.cos(azimuth_rad_t) * torch.sin(altitude_rad_t)
|
|
+ y_t * torch.sin(azimuth_rad_t) * torch.sin(altitude_rad_t)
|
|
+ z_t * torch.cos(altitude_rad_t)
|
|
)
|
|
|
|
if spacing == [0.5, 0.5, 0.5]:
|
|
radius = (diameter / 2.0) * 2
|
|
total_length = length * 2
|
|
else:
|
|
radius = diameter / 2.0
|
|
total_length = length
|
|
|
|
tip_start = total_length * (1 - tip_ratio) # 末端 20% 開始的位置
|
|
|
|
mask = (
|
|
(x_rot**2 + y_rot**2 <= radius**2)
|
|
& (z_rot >= tip_start) # 只取末端
|
|
& (z_rot <= total_length)
|
|
)
|
|
|
|
return mask.to(torch.uint8) |