CBT_project/core/cylinder.py
2026-04-10 13:25:27 +08:00

256 lines
No EOL
8 KiB
Python

import torch
import numpy as np
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
def create_coordinate_grid(
    shape: tuple[int, int, int],
    device: torch.device,
    dtype: torch.dtype = torch.float32
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build a dense 3D voxel-index coordinate grid.

    Parameters:
        shape: (Z, Y, X) size of the volume; only the first three entries
            are read.
        device: device on which the coordinate tensors are allocated.
        dtype: element dtype of the coordinate tensors (default float32).

    Returns:
        (z_t, y_t, x_t): three tensors of shape (Z, Y, X) holding, for every
        voxel, its z, y and x index respectively ("ij" / matrix indexing).
    """
    axes = [
        torch.arange(shape[axis], device=device, dtype=dtype)
        for axis in range(3)
    ]
    grid_z, grid_y, grid_x = torch.meshgrid(*axes, indexing="ij")
    return grid_z, grid_y, grid_x
def snap_to_discrete_values(diameter_raw, length_raw, *, allowed_diameters=None, allowed_lengths=None):
    """
    Map continuous values to the nearest allowed discrete values.

    Parameters:
        diameter_raw: continuous diameter value given by PSO.
        length_raw: continuous length value given by PSO.
        allowed_diameters: optional candidate diameters; when omitted the
            module-level ``ALLOWED_DIAMETERS`` (from config.constant) is used.
        allowed_lengths: optional candidate lengths; when omitted the
            module-level ``ALLOWED_LENGTHS`` (from config.constant) is used.

    Returns:
        (diameter_discrete, length_discrete): the candidate closest to each
        raw value. On a tie, the candidate that comes first wins (``min``
        keeps the first minimal element).

    Raises:
        ValueError: if a candidate collection is empty.
    """
    def _nearest(value, candidates, label):
        # Materialize so numpy arrays / generators work and emptiness is testable.
        pool = list(candidates)
        if not pool:
            raise ValueError(f"No allowed {label} values to snap to")
        return min(pool, key=lambda candidate: abs(candidate - value))

    diameters = ALLOWED_DIAMETERS if allowed_diameters is None else allowed_diameters
    lengths = ALLOWED_LENGTHS if allowed_lengths is None else allowed_lengths
    diameter_discrete = _nearest(diameter_raw, diameters, "diameter")
    length_discrete = _nearest(length_raw, lengths, "length")
    return diameter_discrete, length_discrete
def generate_cylinder_n_torch(
diameter: float,
length: float,
position_z: float,
position_y: float,
position_x: float,
azimuth: float,
altitude: float,
shape: tuple[int, int, int],
spacing: list[float],
device: torch.device,
grid: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
"""
Generate a "forward" (positive z) cylinder mask in 3D space using PyTorch tensors.
Returns a binary mask (torch.uint8) on the specified device.
"""
if grid is None:
z_t, y_t, x_t = create_coordinate_grid(shape, device)
else:
z_t, y_t, x_t = grid
azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32))
altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32))
# Shift
z_t = z_t - position_z
y_t = y_t - position_y
x_t = x_t - position_x
# Apply rotation (same formula as your NumPy version, but in torch)
x_rot = (
x_t * torch.cos(azimuth_rad_t) * torch.cos(altitude_rad_t)
+ y_t * torch.sin(azimuth_rad_t) * torch.cos(altitude_rad_t)
- z_t * torch.sin(altitude_rad_t)
)
y_rot = -x_t * torch.sin(azimuth_rad_t) + y_t * torch.cos(azimuth_rad_t)
z_rot = (
x_t * torch.cos(azimuth_rad_t) * torch.sin(altitude_rad_t)
+ y_t * torch.sin(azimuth_rad_t) * torch.sin(altitude_rad_t)
+ z_t * torch.cos(altitude_rad_t)
)
# Handle spacing
# You can expand or generalize for more spacing options
if spacing == [1, 1, 1]:
radius = diameter / 2.0
mask = (
(x_rot**2 + y_rot**2 <= radius**2)
& (z_rot >= 0)
& (z_rot <= length)
)
elif spacing == [0.5, 0.5, 0.5]:
radius = (diameter / 2.0) * 2
mask = (
(x_rot**2 + y_rot**2 <= radius**2)
& (z_rot >= 0)
& (z_rot <= length * 2)
)
else:
raise ValueError(f"Unsupported spacing: {spacing}")
# Convert boolean mask to uint8
cylinder_mask = mask.to(torch.uint8)
return cylinder_mask
def generate_cylinder_o_torch(
diameter: float,
length: float,
position_z: float,
position_y: float,
position_x: float,
azimuth: float,
altitude: float,
shape: tuple[int, int, int],
spacing: list[float],
device: torch.device,
grid: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
"""
Generate an "opposite" (negative z) cylinder mask in 3D space using PyTorch tensors.
Returns a binary mask (torch.uint8) on the specified device.
"""
if grid is None:
z_t, y_t, x_t = create_coordinate_grid(shape, device)
else:
z_t, y_t, x_t = grid
# Convert angles to torch
azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32))
altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32))
# Shift
z_t = z_t - position_z
y_t = y_t - position_y
x_t = x_t - position_x
# Apply rotation
x_rot = (
x_t * torch.cos(azimuth_rad_t) * torch.cos(altitude_rad_t)
+ y_t * torch.sin(azimuth_rad_t) * torch.cos(altitude_rad_t)
- z_t * torch.sin(altitude_rad_t)
)
y_rot = -x_t * torch.sin(azimuth_rad_t) + y_t * torch.cos(azimuth_rad_t)
z_rot = (
x_t * torch.cos(azimuth_rad_t) * torch.sin(altitude_rad_t)
+ y_t * torch.sin(azimuth_rad_t) * torch.sin(altitude_rad_t)
+ z_t * torch.cos(altitude_rad_t)
)
# Handle spacing
if spacing == [1, 1, 1]:
radius = diameter / 2.0
mask = (
(x_rot**2 + y_rot**2 <= radius**2)
& (z_rot <= 0)
& (z_rot <= length)
)
elif spacing == [0.5, 0.5, 0.5]:
radius = (diameter / 2.0) * 2
mask = (
(x_rot**2 + y_rot**2 <= radius**2)
& (z_rot <= 0)
& (z_rot >= -length * 2)
)
else:
raise ValueError(f"Unsupported spacing: {spacing}")
cylinder_mask_o = mask.to(torch.uint8)
return cylinder_mask_o
def generate_cylinder_numpy(diameter, length, position_z, position_y, position_x, azimuth, altitude, shape, spacing):
    """
    NumPy implementation of the "forward" (positive local z) cylinder mask.

    The cylinder base is centred at (position_z, position_y, position_x) in
    voxel-index coordinates and extends from local z = 0 to local z = length
    along an axis oriented by ``azimuth`` and ``altitude`` (degrees). The
    rotation is the same formula as the torch generators in this module.

    Parameters:
        diameter, length: cylinder size in the units of ``spacing``.
        position_z, position_y, position_x: base centre, in voxel indices.
        azimuth, altitude: orientation angles in degrees.
        shape: (Z, Y, X) volume shape.
        spacing: isotropic voxel spacing, either 1 or 0.5 per axis (any
            sequence type).

    Returns:
        np.uint8 array of ``shape``; 1 inside the cylinder, 0 elsewhere.

    Raises:
        ValueError: if ``spacing`` is not a supported value (previously this
            crashed with UnboundLocalError).
    """
    try:
        spacing_key = [float(s) for s in spacing]
    except (TypeError, ValueError):
        spacing_key = None
    if spacing_key == [1.0, 1.0, 1.0]:
        voxels_per_unit = 1
    elif spacing_key == [0.5, 0.5, 0.5]:
        voxels_per_unit = 2  # *2 is for resampling to half-unit voxels
    else:
        raise ValueError(f"Unsupported spacing: {spacing}")

    azimuth_rad = np.radians(azimuth)
    altitude_rad = np.radians(altitude)
    cos_az, sin_az = np.cos(azimuth_rad), np.sin(azimuth_rad)
    cos_alt, sin_alt = np.cos(altitude_rad), np.sin(altitude_rad)

    z, y, x = np.mgrid[0:shape[0], 0:shape[1], 0:shape[2]].astype(np.float64)
    z -= float(position_z)
    y -= float(position_y)
    x -= float(position_x)

    x_rot = x * cos_az * cos_alt + y * sin_az * cos_alt - z * sin_alt
    y_rot = -x * sin_az + y * cos_az
    z_rot = x * cos_az * sin_alt + y * sin_az * sin_alt + z * cos_alt

    radius = diameter / 2.0 * voxels_per_unit
    axial_length = length * voxels_per_unit
    inside = (x_rot**2 + y_rot**2 <= radius**2) & (z_rot >= 0) & (z_rot <= axial_length)
    return inside.astype(np.uint8)
def generate_cylinder_tip_torch(
    diameter, length,
    position_z, position_y, position_x,
    azimuth, altitude,
    shape, spacing, device, grid=None,
    tip_ratio=0.2  # use the last 20% of the cylinder as the tip
) -> torch.Tensor:
    """
    Generate a mask of only the distal end (tip) of a forward cylinder.

    The full cylinder runs from local z = 0 to local z = length (in voxels
    after spacing scaling); only the slab
    ``length * (1 - tip_ratio) <= z <= length`` is kept.

    Parameters:
        diameter, length: cylinder size in the units of ``spacing``.
        position_z, position_y, position_x: base centre, in voxel indices.
        azimuth, altitude: orientation angles in degrees.
        shape: (Z, Y, X) volume shape; only used when ``grid`` is None.
        spacing: a value equal to 0.5 on every axis (any sequence type) doubles
            sizes in voxels; every other value is treated as 1.0 spacing.
            NOTE(review): unlike the other generators this does not raise for
            unsupported spacings — confirm that leniency is intended.
        device: device for the angle tensors (and the grid if built here).
        grid: optional precomputed (z, y, x) coordinate grid.
        tip_ratio: fraction of the cylinder length, measured from the far
            end, that counts as the tip.

    Returns:
        Binary mask (torch.uint8) with the grid's shape; 1 inside the tip.
    """
    # Normalize so a tuple / ndarray (0.5, 0.5, 0.5) is recognized instead of
    # silently being scaled as 1.0 spacing.
    try:
        spacing_key = [float(s) for s in spacing]
    except (TypeError, ValueError):
        spacing_key = None
    voxels_per_unit = 2 if spacing_key == [0.5, 0.5, 0.5] else 1

    if grid is None:
        z_t, y_t, x_t = create_coordinate_grid(shape, device)
    else:
        z_t, y_t, x_t = grid

    azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32))
    altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32))
    cos_az, sin_az = torch.cos(azimuth_rad_t), torch.sin(azimuth_rad_t)
    cos_alt, sin_alt = torch.cos(altitude_rad_t), torch.sin(altitude_rad_t)

    # Move the origin to the cylinder base, then rotate into the local frame.
    z_t = z_t - position_z
    y_t = y_t - position_y
    x_t = x_t - position_x
    x_rot = x_t * cos_az * cos_alt + y_t * sin_az * cos_alt - z_t * sin_alt
    y_rot = -x_t * sin_az + y_t * cos_az
    z_rot = x_t * cos_az * sin_alt + y_t * sin_az * sin_alt + z_t * cos_alt

    radius = (diameter / 2.0) * voxels_per_unit
    total_length = length * voxels_per_unit
    tip_start = total_length * (1 - tip_ratio)  # where the last tip_ratio of the length begins

    mask = (
        (x_rot**2 + y_rot**2 <= radius**2)
        & (z_rot >= tip_start)  # keep only the far end
        & (z_rot <= total_length)
    )
    return mask.to(torch.uint8)