CBT_project/core/intersection.py
2026-04-10 13:25:27 +08:00

154 lines
No EOL
4.4 KiB
Python

import torch
import numpy as np
def bresenham3d(x1: float, y1: float, z1: float, x2: float, y2: float, z2: float) -> list[tuple[int, int, int]]:
    """
    Rasterize the segment between two 3D points with Bresenham's algorithm.

    Both endpoints are first rounded with the builtin ``round``. The axis with
    the largest absolute extent drives the walk; ties go to x, then y. Each
    step advances the driving axis by one voxel. Each of the other two axes
    advances when its error term is non-negative.

    This stays in plain Python on the CPU because the walk only produces
    ``max(|dx|, |dy|, |dz|) + 1`` points.

    Returns:
        The visited integer voxels in order, including both rounded endpoints.
    """
    origin = (round(x1), round(y1), round(z1))
    target = (round(x2), round(y2), round(z2))
    extent = [abs(b - a) for a, b in zip(origin, target)]
    stride = [1 if b > a else -1 for a, b in zip(origin, target)]

    # max() returns the first maximal index, giving the x > y > z tie order.
    major = max(range(3), key=extent.__getitem__)
    minors = [axis for axis in range(3) if axis != major]
    error = {axis: 2 * extent[axis] - extent[major] for axis in minors}

    cursor = list(origin)
    path = [tuple(cursor)]
    for _ in range(extent[major]):
        cursor[major] += stride[major]
        for axis in minors:
            if error[axis] >= 0:
                cursor[axis] += stride[axis]
                error[axis] -= 2 * extent[major]
            error[axis] += 2 * extent[axis]
        path.append(tuple(cursor))
    return path
def _length_in_voxels(length: float, spacing) -> float:
    """Convert ``length`` from physical units to voxels on an isotropic grid.

    ``spacing`` must hold three equal, positive, finite voxel sizes. Because
    all entries are equal, their axis order does not matter. Dividing by the
    common voxel size reproduces the previously hard-coded cases exactly:
    ``[1, 1, 1]`` gives ``length`` and ``[0.5, 0.5, 0.5]`` gives ``2 * length``.

    Raises:
        ValueError: if ``spacing`` is not three equal, positive, finite numbers.
    """
    try:
        sizes = [float(s) for s in spacing]
    except (TypeError, ValueError) as err:
        raise ValueError(f"Unsupported spacing: {spacing}") from err
    if len(sizes) != 3 or any(s != sizes[0] for s in sizes):
        raise ValueError(f"Unsupported spacing: {spacing}")
    voxel = sizes[0]
    if not (voxel > 0 and np.isfinite(voxel)):
        raise ValueError(f"Unsupported spacing: {spacing}")
    return length / voxel


def _count_intersections(values: np.ndarray) -> int:
    """Apply the difference-based crossing count to samples taken along a line.

    The count is built from three terms:

    * the number of first differences of magnitude 1,
    * minus twice the number of second differences of magnitude 2,
    * plus twice the number of third differences of magnitude 4.

    For a 0/1 signal, a one-sample blip (0, 1, 0) yields a second difference
    of -2. So these constants appear tuned for binary images.
    NOTE(review): confirm the expected label encoding with the callers.
    """
    d1 = np.diff(values)
    d2 = np.diff(d1)
    d3 = np.diff(d2)
    return int(
        np.count_nonzero(np.abs(d1) == 1)
        - 2 * np.count_nonzero(np.abs(d2) == 2)
        + 2 * np.count_nonzero(np.abs(d3) == 4)
    )


def center_line_intersections_torch(
    position_z: float,
    position_y: float,
    position_x: float,
    azimuth: float,
    altitude: float,
    length: float,
    image_tensor: torch.Tensor,
    spacing: list[float],
    device: torch.device,
) -> tuple[int, torch.Tensor]:
    """
    Count intersections along a 3D center line through ``image_tensor``.

    Line geometry:

    * It is centred at ``(position_z, position_y, position_x)``, which is
      rounded and used directly as voxel indices.
    * Its direction comes from ``azimuth`` and ``altitude``, both in degrees.
      ``altitude`` is the angle from the z axis, since the z component is
      ``cos(altitude)``.
    * It extends ``length`` (converted to voxels via ``spacing``) on each
      side of the centre.

    How it is computed: the voxels are produced by ``bresenham3d`` on the
    CPU. The in-bounds samples are then gathered from ``image_tensor`` with
    a single indexing call, and the crossings are counted with NumPy.

    Args:
        position_z, position_y, position_x: Line centre in voxel coordinates.
        azimuth: Azimuth angle in degrees.
        altitude: Angle from the z axis in degrees.
        length: Half-length of the line before conversion by ``spacing``.
        image_tensor: Volume indexed as ``[z, y, x]``.
        spacing: Isotropic voxel size, e.g. ``[1, 1, 1]`` or ``[0.5, 0.5, 0.5]``.
        device: Device on which the returned mask is allocated.

    Returns:
        ``(intersections, line_mask_torch)``:

        * ``intersections`` is the crossing count, as an ``int``.
        * ``line_mask_torch`` is a uint8 tensor with ``image_tensor``'s shape,
          on ``device``, set to 1 on every in-bounds voxel of the line.

    Raises:
        ValueError: if ``spacing`` is not three equal, positive, finite numbers.
    """
    azimuth_rad = np.radians(azimuth)
    altitude_rad = np.radians(altitude)
    half_length = _length_in_voxels(length, spacing)

    # Unit direction in (z, y, x) order.
    direction = (
        np.cos(altitude_rad),
        np.sin(altitude_rad) * np.sin(azimuth_rad),
        np.sin(altitude_rad) * np.cos(azimuth_rad),
    )
    center = (position_z, position_y, position_x)

    # Both ends of the segment, rounded to integer voxel coordinates.
    start_point = tuple(int(np.round(c - half_length * d)) for c, d in zip(center, direction))
    end_point = tuple(int(np.round(c + half_length * d)) for c, d in zip(center, direction))

    # Bresenham line (CPU).
    line_points = bresenham3d(*start_point, *end_point)

    shape = image_tensor.shape  # (z, y, x)
    line_mask_torch = torch.zeros(shape, device=device, dtype=torch.uint8)

    # Keep only in-bounds voxels; negative indices would otherwise wrap around.
    inside = [
        (z, y, x)
        for z, y, x in line_points
        if 0 <= z < shape[0] and 0 <= y < shape[1] and 0 <= x < shape[2]
    ]
    if not inside:
        return 0, line_mask_torch

    # One indexed write and one indexed read instead of a host sync per voxel.
    mask_idx = torch.tensor(inside, dtype=torch.long, device=line_mask_torch.device).T
    line_mask_torch[mask_idx[0], mask_idx[1], mask_idx[2]] = 1
    image_idx = mask_idx.to(image_tensor.device)
    # .tolist() yields Python scalars, just like the former per-voxel .item().
    samples = image_tensor[image_idx[0], image_idx[1], image_idx[2]].tolist()

    return _count_intersections(np.array(samples, dtype=np.int32)), line_mask_torch