277 lines
No EOL
8.8 KiB
Python
277 lines
No EOL
8.8 KiB
Python
import torch
|
||
import numpy as np
|
||
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
|
||
|
||
def create_coordinate_grid(
    shape: tuple[int, int, int],
    device: torch.device,
    dtype: torch.dtype = torch.float32
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build a fixed, dense 3D voxel index grid.

    Each returned tensor holds, at every voxel, that voxel's index along one axis
    (built with ``torch.meshgrid(..., indexing='ij')``).

    Parameters:
        shape: (Z, Y, X) extent of the grid; only the first three entries are used.
        device: device on which the grid tensors are allocated.
        dtype: element type of the grid tensors (float32 by default).

    Returns:
        (z_t, y_t, x_t), each of shape (Z, Y, X).
    """
    depth, height, width = shape[0], shape[1], shape[2]
    ranges = (
        torch.arange(size, device=device, dtype=dtype)
        for size in (depth, height, width)
    )
    grid_z, grid_y, grid_x = torch.meshgrid(*ranges, indexing='ij')
    return grid_z, grid_y, grid_x
|
||
|
||
def snap_to_discrete_values(diameter_raw, length_raw, allowed_diameters=None, allowed_lengths=None):
    """
    Map continuous values to the nearest allowed discrete values.

    Parameters:
        diameter_raw: continuous diameter value proposed by PSO.
        length_raw: continuous length value proposed by PSO.
        allowed_diameters: candidate diameters; defaults to ``ALLOWED_DIAMETERS``.
        allowed_lengths: candidate lengths; defaults to ``ALLOWED_LENGTHS``.

    Returns:
        (diameter_discrete, length_discrete): the candidates with the smallest absolute
        distance to the raw values. On a tie, the candidate that appears first in the
        candidate sequence wins (``min`` keeps the first minimum).

    Raises:
        ValueError: if a candidate sequence is empty (raised by ``min``).
    """
    # Candidate sets are injectable so callers/tests are not tied to the global config.
    diameters = ALLOWED_DIAMETERS if allowed_diameters is None else allowed_diameters
    lengths = ALLOWED_LENGTHS if allowed_lengths is None else allowed_lengths

    diameter_discrete = min(diameters, key=lambda candidate: abs(candidate - diameter_raw))
    length_discrete = min(lengths, key=lambda candidate: abs(candidate - length_raw))

    return diameter_discrete, length_discrete
|
||
|
||
def round_down_to_discrete_values(diameter_raw, length_raw, allowed_diameters=None, allowed_lengths=None):
    """
    Map continuous values to the largest allowed discrete value that does not exceed
    them (round down).

    Parameters:
        diameter_raw: continuous diameter value proposed by PSO.
        length_raw: continuous length value proposed by PSO.
        allowed_diameters: candidate diameters; defaults to ``ALLOWED_DIAMETERS``.
        allowed_lengths: candidate lengths; defaults to ``ALLOWED_LENGTHS``.

    Returns:
        (diameter_discrete, length_discrete). If every candidate is larger than the raw
        value, the smallest candidate is returned instead.

    Raises:
        ValueError: if a candidate sequence is empty (raised by ``min``).
    """
    # Candidate sets are injectable so callers/tests are not tied to the global config.
    diameters = ALLOWED_DIAMETERS if allowed_diameters is None else allowed_diameters
    lengths = ALLOWED_LENGTHS if allowed_lengths is None else allowed_lengths

    # Largest diameter <= raw; fall back to the smallest diameter if all exceed raw.
    fitting_diameters = [d for d in diameters if d <= diameter_raw]
    diameter_discrete = max(fitting_diameters) if fitting_diameters else min(diameters)

    # Largest length <= raw; fall back to the smallest length if all exceed raw.
    fitting_lengths = [length for length in lengths if length <= length_raw]
    length_discrete = max(fitting_lengths) if fitting_lengths else min(lengths)

    return diameter_discrete, length_discrete
|
||
|
||
|
||
def generate_cylinder_n_torch(
|
||
diameter: float,
|
||
length: float,
|
||
position_z: float,
|
||
position_y: float,
|
||
position_x: float,
|
||
azimuth: float,
|
||
altitude: float,
|
||
shape: tuple[int, int, int],
|
||
spacing: list[float],
|
||
device: torch.device,
|
||
grid: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None
|
||
) -> torch.Tensor:
|
||
|
||
"""
|
||
Generate a "forward" (positive z) cylinder mask in 3D space using PyTorch tensors.
|
||
Returns a binary mask (torch.uint8) on the specified device.
|
||
"""
|
||
|
||
if grid is None:
|
||
z_t, y_t, x_t = create_coordinate_grid(shape, device)
|
||
else:
|
||
z_t, y_t, x_t = grid
|
||
|
||
azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32))
|
||
altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32))
|
||
|
||
# Shift
|
||
z_t = z_t - position_z
|
||
y_t = y_t - position_y
|
||
x_t = x_t - position_x
|
||
|
||
# Apply rotation (same formula as your NumPy version, but in torch)
|
||
x_rot = (
|
||
x_t * torch.cos(azimuth_rad_t) * torch.cos(altitude_rad_t)
|
||
+ y_t * torch.sin(azimuth_rad_t) * torch.cos(altitude_rad_t)
|
||
- z_t * torch.sin(altitude_rad_t)
|
||
)
|
||
y_rot = -x_t * torch.sin(azimuth_rad_t) + y_t * torch.cos(azimuth_rad_t)
|
||
z_rot = (
|
||
x_t * torch.cos(azimuth_rad_t) * torch.sin(altitude_rad_t)
|
||
+ y_t * torch.sin(azimuth_rad_t) * torch.sin(altitude_rad_t)
|
||
+ z_t * torch.cos(altitude_rad_t)
|
||
)
|
||
|
||
# Handle spacing
|
||
# You can expand or generalize for more spacing options
|
||
if spacing == [1, 1, 1]:
|
||
radius = diameter / 2.0
|
||
mask = (
|
||
(x_rot**2 + y_rot**2 <= radius**2)
|
||
& (z_rot >= 0)
|
||
& (z_rot <= length)
|
||
)
|
||
elif spacing == [0.5, 0.5, 0.5]:
|
||
radius = (diameter / 2.0) * 2
|
||
mask = (
|
||
(x_rot**2 + y_rot**2 <= radius**2)
|
||
& (z_rot >= 0)
|
||
& (z_rot <= length * 2)
|
||
)
|
||
else:
|
||
raise ValueError(f"Unsupported spacing: {spacing}")
|
||
|
||
# Convert boolean mask to uint8
|
||
cylinder_mask = mask.to(torch.uint8)
|
||
|
||
return cylinder_mask
|
||
|
||
def generate_cylinder_o_torch(
|
||
diameter: float,
|
||
length: float,
|
||
position_z: float,
|
||
position_y: float,
|
||
position_x: float,
|
||
azimuth: float,
|
||
altitude: float,
|
||
shape: tuple[int, int, int],
|
||
spacing: list[float],
|
||
device: torch.device,
|
||
grid: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None
|
||
) -> torch.Tensor:
|
||
"""
|
||
Generate an "opposite" (negative z) cylinder mask in 3D space using PyTorch tensors.
|
||
Returns a binary mask (torch.uint8) on the specified device.
|
||
"""
|
||
|
||
if grid is None:
|
||
z_t, y_t, x_t = create_coordinate_grid(shape, device)
|
||
else:
|
||
z_t, y_t, x_t = grid
|
||
|
||
# Convert angles to torch
|
||
azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32))
|
||
altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32))
|
||
|
||
# Shift
|
||
z_t = z_t - position_z
|
||
y_t = y_t - position_y
|
||
x_t = x_t - position_x
|
||
|
||
# Apply rotation
|
||
x_rot = (
|
||
x_t * torch.cos(azimuth_rad_t) * torch.cos(altitude_rad_t)
|
||
+ y_t * torch.sin(azimuth_rad_t) * torch.cos(altitude_rad_t)
|
||
- z_t * torch.sin(altitude_rad_t)
|
||
)
|
||
y_rot = -x_t * torch.sin(azimuth_rad_t) + y_t * torch.cos(azimuth_rad_t)
|
||
z_rot = (
|
||
x_t * torch.cos(azimuth_rad_t) * torch.sin(altitude_rad_t)
|
||
+ y_t * torch.sin(azimuth_rad_t) * torch.sin(altitude_rad_t)
|
||
+ z_t * torch.cos(altitude_rad_t)
|
||
)
|
||
|
||
# Handle spacing
|
||
if spacing == [1, 1, 1]:
|
||
radius = diameter / 2.0
|
||
mask = (
|
||
(x_rot**2 + y_rot**2 <= radius**2)
|
||
& (z_rot <= 0)
|
||
& (z_rot <= length)
|
||
)
|
||
elif spacing == [0.5, 0.5, 0.5]:
|
||
radius = (diameter / 2.0) * 2
|
||
mask = (
|
||
(x_rot**2 + y_rot**2 <= radius**2)
|
||
& (z_rot <= 0)
|
||
& (z_rot >= -length * 2)
|
||
)
|
||
else:
|
||
raise ValueError(f"Unsupported spacing: {spacing}")
|
||
|
||
cylinder_mask_o = mask.to(torch.uint8)
|
||
|
||
return cylinder_mask_o
|
||
|
||
def generate_cylinder_numpy(diameter, length, position_z, position_y, position_x, azimuth, altitude, shape, spacing):
    """
    Generate a "forward" (positive z) cylinder mask with NumPy.

    Uses the same shift/rotate/threshold formulas as ``generate_cylinder_n_torch``,
    computed in float64 on the CPU.

    Parameters:
        diameter, length: cylinder size in units of ``spacing``. They are multiplied by 2
            to convert to voxel units when spacing is 0.5.
        position_z, position_y, position_x: voxel coordinates of the base-disk center.
        azimuth, altitude: orientation angles in degrees.
        shape: (Z, Y, X) size of the output volume.
        spacing: isotropic spacing. Only [1, 1, 1] and [0.5, 0.5, 0.5] are supported;
            any sequence type (list, tuple, ndarray) is accepted.

    Returns:
        np.ndarray of dtype uint8 and shape ``shape``, 1 inside the cylinder and 0 elsewhere.

    Raises:
        ValueError: if ``spacing`` is not one of the supported values.
    """
    # Validate spacing up front. An unsupported value must not fall through and leave
    # the mask undefined.
    try:
        spacing_values = [float(s) for s in spacing]
    except TypeError as err:
        raise ValueError(f"Unsupported spacing: {spacing}") from err
    if spacing_values == [1.0, 1.0, 1.0]:
        scale = 1.0
    elif spacing_values == [0.5, 0.5, 0.5]:
        scale = 2.0  # *2 is for resampling
    else:
        raise ValueError(f"Unsupported spacing: {spacing}")

    azimuth_rad = np.radians(azimuth)
    altitude_rad = np.radians(altitude)
    cos_az, sin_az = np.cos(azimuth_rad), np.sin(azimuth_rad)
    cos_alt, sin_alt = np.cos(altitude_rad), np.sin(altitude_rad)

    z, y, x = np.mgrid[0:shape[0], 0:shape[1], 0:shape[2]].astype(np.float64)

    # Shift so the cylinder base center is the origin.
    z -= float(position_z)
    y -= float(position_y)
    x -= float(position_x)

    # Rotate into the cylinder frame; z_rot is the axial coordinate used below.
    x_rot = x * cos_az * cos_alt + y * sin_az * cos_alt - z * sin_alt
    y_rot = -x * sin_az + y * cos_az
    z_rot = x * cos_az * sin_alt + y * sin_az * sin_alt + z * cos_alt

    radius = diameter / 2.0 * scale
    cylinder = (x_rot**2 + y_rot**2 <= radius**2) & (z_rot >= 0) & (z_rot <= length * scale)

    return cylinder.astype(np.uint8)
|
||
|
||
def generate_cylinder_tip_torch(
|
||
diameter, length,
|
||
position_z, position_y, position_x,
|
||
azimuth, altitude,
|
||
shape, spacing, device, grid=None,
|
||
tip_ratio=0.2 # 取末端 20% 當尖端
|
||
) -> torch.Tensor:
|
||
"""只生成圓柱末端的 mask"""
|
||
|
||
if grid is None:
|
||
z_t, y_t, x_t = create_coordinate_grid(shape, device)
|
||
else:
|
||
z_t, y_t, x_t = grid
|
||
|
||
azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32))
|
||
altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32))
|
||
|
||
z_t = z_t - position_z
|
||
y_t = y_t - position_y
|
||
x_t = x_t - position_x
|
||
|
||
x_rot = (
|
||
x_t * torch.cos(azimuth_rad_t) * torch.cos(altitude_rad_t)
|
||
+ y_t * torch.sin(azimuth_rad_t) * torch.cos(altitude_rad_t)
|
||
- z_t * torch.sin(altitude_rad_t)
|
||
)
|
||
y_rot = -x_t * torch.sin(azimuth_rad_t) + y_t * torch.cos(azimuth_rad_t)
|
||
z_rot = (
|
||
x_t * torch.cos(azimuth_rad_t) * torch.sin(altitude_rad_t)
|
||
+ y_t * torch.sin(azimuth_rad_t) * torch.sin(altitude_rad_t)
|
||
+ z_t * torch.cos(altitude_rad_t)
|
||
)
|
||
|
||
if spacing == [0.5, 0.5, 0.5]:
|
||
radius = (diameter / 2.0) * 2
|
||
total_length = length * 2
|
||
else:
|
||
radius = diameter / 2.0
|
||
total_length = length
|
||
|
||
tip_start = total_length * (1 - tip_ratio) # 末端 20% 開始的位置
|
||
|
||
mask = (
|
||
(x_rot**2 + y_rot**2 <= radius**2)
|
||
& (z_rot >= tip_start) # 只取末端
|
||
& (z_rot <= total_length)
|
||
)
|
||
|
||
return mask.to(torch.uint8) |