CBT_project/core/cylinder.py
Xiao Furen b76f0708f3 1. correct the bounding box and cortical mask
2. make the plot isometric
3. now it should work after tuning the parameters
2026-04-17 00:03:10 +08:00

277 lines
No EOL
8.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import numpy as np
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
def create_coordinate_grid(
    shape: tuple[int, int, int],
    device: torch.device,
    dtype: torch.dtype = torch.float32
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build a dense 3D voxel-index coordinate grid.

    Parameters:
        shape: (Z, Y, X) extent of the volume; only the first three entries are used.
        device: device on which the coordinate tensors are allocated.
        dtype: dtype of the coordinate tensors (float32 by default).

    Returns:
        (grid_z, grid_y, grid_x), each of shape (Z, Y, X), where grid_z[i, j, k] == i,
        grid_y[i, j, k] == j and grid_x[i, j, k] == k ('ij' indexing).
    """
    extents = (shape[0], shape[1], shape[2])
    axes = [torch.arange(extent, device=device, dtype=dtype) for extent in extents]
    grid_z, grid_y, grid_x = torch.meshgrid(*axes, indexing='ij')
    return grid_z, grid_y, grid_x
def snap_to_discrete_values(diameter_raw, length_raw):
    """
    Map continuous values to the closest allowed discrete values.

    Parameters:
        diameter_raw: continuous diameter value proposed by PSO
        length_raw: continuous length value proposed by PSO
    Returns:
        diameter_discrete, length_discrete: the discretized values, chosen from
        ALLOWED_DIAMETERS and ALLOWED_LENGTHS respectively
    """
    def _closest(options, raw_value):
        # min() returns the first minimizer, so ties go to whichever option
        # appears first in `options`.
        return min(options, key=lambda option: abs(option - raw_value))

    return _closest(ALLOWED_DIAMETERS, diameter_raw), _closest(ALLOWED_LENGTHS, length_raw)
def round_down_to_discrete_values(diameter_raw, length_raw):
    """
    Map continuous values to the largest allowed discrete values not exceeding them (floor).

    When every allowed value is larger than the raw value, the smallest
    allowed value is returned instead.

    Parameters:
        diameter_raw: continuous diameter value proposed by PSO
        length_raw: continuous length value proposed by PSO
    Returns:
        diameter_discrete, length_discrete: the discretized values, chosen from
        ALLOWED_DIAMETERS and ALLOWED_LENGTHS respectively
    """
    def _floor_choice(options, raw_value):
        best = max((option for option in options if option <= raw_value), default=None)
        # Nothing is <= raw_value: fall back to the smallest allowed option.
        return min(options) if best is None else best

    return (
        _floor_choice(ALLOWED_DIAMETERS, diameter_raw),
        _floor_choice(ALLOWED_LENGTHS, length_raw),
    )
def generate_cylinder_n_torch(
diameter: float,
length: float,
position_z: float,
position_y: float,
position_x: float,
azimuth: float,
altitude: float,
shape: tuple[int, int, int],
spacing: list[float],
device: torch.device,
grid: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
"""
Generate a "forward" (positive z) cylinder mask in 3D space using PyTorch tensors.
Returns a binary mask (torch.uint8) on the specified device.
"""
if grid is None:
z_t, y_t, x_t = create_coordinate_grid(shape, device)
else:
z_t, y_t, x_t = grid
azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32))
altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32))
# Shift
z_t = z_t - position_z
y_t = y_t - position_y
x_t = x_t - position_x
# Apply rotation (same formula as your NumPy version, but in torch)
x_rot = (
x_t * torch.cos(azimuth_rad_t) * torch.cos(altitude_rad_t)
+ y_t * torch.sin(azimuth_rad_t) * torch.cos(altitude_rad_t)
- z_t * torch.sin(altitude_rad_t)
)
y_rot = -x_t * torch.sin(azimuth_rad_t) + y_t * torch.cos(azimuth_rad_t)
z_rot = (
x_t * torch.cos(azimuth_rad_t) * torch.sin(altitude_rad_t)
+ y_t * torch.sin(azimuth_rad_t) * torch.sin(altitude_rad_t)
+ z_t * torch.cos(altitude_rad_t)
)
# Handle spacing
# You can expand or generalize for more spacing options
if spacing == [1, 1, 1]:
radius = diameter / 2.0
mask = (
(x_rot**2 + y_rot**2 <= radius**2)
& (z_rot >= 0)
& (z_rot <= length)
)
elif spacing == [0.5, 0.5, 0.5]:
radius = (diameter / 2.0) * 2
mask = (
(x_rot**2 + y_rot**2 <= radius**2)
& (z_rot >= 0)
& (z_rot <= length * 2)
)
else:
raise ValueError(f"Unsupported spacing: {spacing}")
# Convert boolean mask to uint8
cylinder_mask = mask.to(torch.uint8)
return cylinder_mask
def generate_cylinder_o_torch(
diameter: float,
length: float,
position_z: float,
position_y: float,
position_x: float,
azimuth: float,
altitude: float,
shape: tuple[int, int, int],
spacing: list[float],
device: torch.device,
grid: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
"""
Generate an "opposite" (negative z) cylinder mask in 3D space using PyTorch tensors.
Returns a binary mask (torch.uint8) on the specified device.
"""
if grid is None:
z_t, y_t, x_t = create_coordinate_grid(shape, device)
else:
z_t, y_t, x_t = grid
# Convert angles to torch
azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32))
altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32))
# Shift
z_t = z_t - position_z
y_t = y_t - position_y
x_t = x_t - position_x
# Apply rotation
x_rot = (
x_t * torch.cos(azimuth_rad_t) * torch.cos(altitude_rad_t)
+ y_t * torch.sin(azimuth_rad_t) * torch.cos(altitude_rad_t)
- z_t * torch.sin(altitude_rad_t)
)
y_rot = -x_t * torch.sin(azimuth_rad_t) + y_t * torch.cos(azimuth_rad_t)
z_rot = (
x_t * torch.cos(azimuth_rad_t) * torch.sin(altitude_rad_t)
+ y_t * torch.sin(azimuth_rad_t) * torch.sin(altitude_rad_t)
+ z_t * torch.cos(altitude_rad_t)
)
# Handle spacing
if spacing == [1, 1, 1]:
radius = diameter / 2.0
mask = (
(x_rot**2 + y_rot**2 <= radius**2)
& (z_rot <= 0)
& (z_rot <= length)
)
elif spacing == [0.5, 0.5, 0.5]:
radius = (diameter / 2.0) * 2
mask = (
(x_rot**2 + y_rot**2 <= radius**2)
& (z_rot <= 0)
& (z_rot >= -length * 2)
)
else:
raise ValueError(f"Unsupported spacing: {spacing}")
cylinder_mask_o = mask.to(torch.uint8)
return cylinder_mask_o
def generate_cylinder_numpy(diameter, length, position_z, position_y, position_x, azimuth, altitude, shape, spacing):
    """
    Generate a "forward" (positive z) cylinder mask with NumPy.

    Uses the same geometry as generate_cylinder_n_torch. The cylinder's base is
    at (position_z, position_y, position_x) in voxel-index coordinates. It
    extends along its rotated local z axis from 0 to length (doubled, together
    with the radius, for spacing (0.5, 0.5, 0.5)).

    Parameters:
        diameter, length: cylinder size.
        position_z, position_y, position_x: base point in voxel-index coordinates.
        azimuth, altitude: orientation angles in degrees.
        shape: (Z, Y, X) size of the output volume.
        spacing: [1, 1, 1] or [0.5, 0.5, 0.5]. Any sequence of those values
            (list, tuple, array) is accepted.

    Returns:
        np.uint8 array of shape (Z, Y, X): 1 inside the cylinder, 0 elsewhere.

    Raises:
        ValueError: if `spacing` is not one of the supported values.
    """
    # Validate spacing up front. Previously an unsupported value crashed later
    # with UnboundLocalError on `cylinder`.
    try:
        spacing_key = tuple(float(s) for s in spacing)
    except (TypeError, ValueError):
        spacing_key = None  # not a numeric sequence; rejected below
    if spacing_key == (1.0, 1.0, 1.0):
        scale = 1.0
    elif spacing_key == (0.5, 0.5, 0.5):
        scale = 2.0  # *2 is for resampling
    else:
        raise ValueError(f"Unsupported spacing: {spacing}")

    azimuth = np.radians(azimuth)
    altitude = np.radians(altitude)

    # Voxel-index grids, translated so the cylinder base is at the origin.
    z, y, x = np.mgrid[0:shape[0], 0:shape[1], 0:shape[2]].astype(np.float64)
    z -= float(position_z)
    y -= float(position_y)
    x -= float(position_x)

    # Rotate into the cylinder's local frame (local z = cylinder axis).
    x_rot = x * np.cos(azimuth) * np.cos(altitude) + y * np.sin(azimuth) * np.cos(altitude) - z * np.sin(altitude)
    y_rot = -x * np.sin(azimuth) + y * np.cos(azimuth)
    z_rot = x * np.cos(azimuth) * np.sin(altitude) + y * np.sin(azimuth) * np.sin(altitude) + z * np.cos(altitude)

    radius = diameter / 2.0 * scale
    cylinder = (x_rot**2 + y_rot**2 <= radius**2) & (z_rot >= 0) & (z_rot <= length * scale)
    return cylinder.astype(np.uint8)
def generate_cylinder_tip_torch(
    diameter, length,
    position_z, position_y, position_x,
    azimuth, altitude,
    shape, spacing, device, grid=None,
    tip_ratio=0.2  # fraction of the length, at the far end, treated as the tip
) -> torch.Tensor:
    """
    Generate a mask covering only the far-end (tip) segment of a forward cylinder.

    Uses the same geometry as generate_cylinder_n_torch, but keeps only voxels
    with total_length * (1 - tip_ratio) <= z_rot <= total_length.

    Parameters:
        diameter, length: cylinder size. For spacing (0.5, 0.5, 0.5), both are
            multiplied by 2 before being compared against grid coordinates.
        position_z, position_y, position_x: base point in voxel-index coordinates.
        azimuth, altitude: orientation angles in degrees.
        shape: (Z, Y, X) volume size; only used when `grid` is None.
        spacing: sequence of per-axis spacings (list, tuple or array). Only
            (0.5, 0.5, 0.5) triggers scaling; any other value is treated as
            unit spacing.
        device: device for the angle tensors (and the grid, if built here).
        grid: optional precomputed (z, y, x) grid from create_coordinate_grid.
        tip_ratio: fraction of the cylinder length kept as the tip (default 0.2).
            Not validated; presumably expected in [0, 1].

    Returns:
        Binary mask (torch.uint8) with the grid's shape.
    """
    if grid is None:
        z_t, y_t, x_t = create_coordinate_grid(shape, device)
    else:
        z_t, y_t, x_t = grid

    # as_tensor (unlike torch.tensor) does not warn when the angle is already a tensor.
    azimuth_rad_t = torch.deg2rad(torch.as_tensor(azimuth, device=device, dtype=torch.float32))
    altitude_rad_t = torch.deg2rad(torch.as_tensor(altitude, device=device, dtype=torch.float32))
    cos_az, sin_az = torch.cos(azimuth_rad_t), torch.sin(azimuth_rad_t)
    cos_alt, sin_alt = torch.cos(altitude_rad_t), torch.sin(altitude_rad_t)

    # Translate so the cylinder base sits at the origin.
    z_t = z_t - position_z
    y_t = y_t - position_y
    x_t = x_t - position_x

    # Rotate into the cylinder's local frame (local z = cylinder axis).
    x_rot = x_t * cos_az * cos_alt + y_t * sin_az * cos_alt - z_t * sin_alt
    y_rot = -x_t * sin_az + y_t * cos_az
    z_rot = x_t * cos_az * sin_alt + y_t * sin_az * sin_alt + z_t * cos_alt

    # Normalize spacing so a tuple/array (0.5, 0.5, 0.5) is recognized. Before,
    # only the list form matched and other forms silently used unit sizes.
    try:
        spacing_key = tuple(float(s) for s in spacing)
    except (TypeError, ValueError):
        spacing_key = None
    # NOTE: unlike the other generators, unsupported spacings fall back to
    # unit-spacing sizes instead of raising (original behavior, kept).
    scale = 2.0 if spacing_key == (0.5, 0.5, 0.5) else 1.0
    radius = (diameter / 2.0) * scale
    total_length = length * scale

    tip_start = total_length * (1 - tip_ratio)  # where the last tip_ratio of the length begins
    mask = (
        (x_rot**2 + y_rot**2 <= radius**2)
        & (z_rot >= tip_start)  # keep only the far end
        & (z_rot <= total_length)
    )
    return mask.to(torch.uint8)