154 lines
No EOL
4.4 KiB
Python
154 lines
No EOL
4.4 KiB
Python
import torch
|
|
import numpy as np
|
|
|
|
def bresenham3d(x1: float, y1: float, z1: float, x2: float, y2: float, z2: float) -> list[tuple[int, int, int]]:
    """
    Rasterize the segment between two 3D points with Bresenham's algorithm (CPU).

    Both endpoints are first snapped to the integer grid with Python's built-in
    ``round`` (which rounds halves to even). The axis with the largest extent
    drives the walk (ties go to x, then y, otherwise z). Each of the two other
    axes advances according to its own integer decision variable.

    The walk is sequential and lines are usually short, so it stays in plain
    Python rather than being moved to the GPU.

    Returns:
        ``(x, y, z)`` grid points, starting at the rounded start point and
        adding one point per unit step along the driving axis until that axis
        reaches the rounded end coordinate.
    """
    start = (round(x1), round(y1), round(z1))
    stop = (round(x2), round(y2), round(z2))

    extent = [abs(b - a) for a, b in zip(start, stop)]
    step = [1 if b > a else -1 for a, b in zip(start, stop)]

    # Choose the driving (major) axis: x wins ties, then y, otherwise z.
    if extent[0] >= extent[1] and extent[0] >= extent[2]:
        major = 0
    elif extent[1] >= extent[0] and extent[1] >= extent[2]:
        major = 1
    else:
        major = 2
    minors = [axis for axis in range(3) if axis != major]

    # One decision variable per minor axis; the two are updated independently.
    error = {axis: 2 * extent[axis] - extent[major] for axis in minors}

    cursor = list(start)
    path = [tuple(cursor)]
    while cursor[major] != stop[major]:
        cursor[major] += step[major]
        for axis in minors:
            if error[axis] >= 0:
                cursor[axis] += step[axis]
                error[axis] -= 2 * extent[major]
            error[axis] += 2 * extent[axis]
        path.append(tuple(cursor))
    return path
|
|
|
|
|
|
def center_line_intersections_torch(
    position_z: float,
    position_y: float,
    position_x: float,
    azimuth: float,
    altitude: float,
    length: float,
    image_tensor: torch.Tensor,
    spacing: list[float],
    device: torch.device,
) -> tuple[int, torch.Tensor]:
    """
    Count value transitions along a 3D center line through ``image_tensor``.

    The line passes through ``(position_z, position_y, position_x)`` and extends
    ``length`` units to either side along the direction given by ``azimuth``
    and ``altitude``. It is rasterized on the CPU with ``bresenham3d``. The
    voxels it visits inside the volume are marked in a mask, and their values
    are gathered in line order with a single batched index operation.

    Args:
        position_z, position_y, position_x: Line center, in voxel index
            coordinates of ``image_tensor``.
        azimuth: Angle in degrees within the x-y plane; 0 points along +x.
        altitude: Angle in degrees measured from the +z axis.
        length: Half-length of the line in the units of ``spacing``. It is
            divided by the voxel size to obtain a length in voxels.
        image_tensor: Volume indexed as ``(z, y, x)``. Sampled values are
            truncated toward zero to int32 before counting.
        spacing: Voxel size. Must be isotropic (three equal, positive
            entries). Lists, tuples and 1-D arrays are all accepted.
        device: Device on which the returned mask is allocated.

    Returns:
        ``(intersections, line_mask_torch)``, where ``line_mask_torch`` is a
        uint8 tensor shaped like ``image_tensor`` holding 1 at every in-bounds
        voxel on the line.

    Raises:
        ValueError: If ``spacing`` is not three equal, positive values.
    """
    # Normalize spacing so that tuples and numpy arrays (e.g. from image
    # headers) work. Comparing those to a list literal never matches, or
    # raises for arrays.
    try:
        voxel_sizes = [float(s) for s in spacing]
    except (TypeError, ValueError):
        raise ValueError(f"Unsupported spacing: {spacing}") from None
    if (
        len(voxel_sizes) != 3
        or not voxel_sizes[0] > 0
        or any(s != voxel_sizes[0] for s in voxel_sizes)
    ):
        raise ValueError(f"Unsupported spacing: {spacing}")
    # Half-length in voxels: spacing 1.0 -> length, 0.5 -> 2 * length, etc.
    length_vox = length / voxel_sizes[0]

    azimuth_rad = np.radians(azimuth)
    altitude_rad = np.radians(altitude)

    # Unit direction in (z, y, x) order; altitude is the polar angle from +z.
    direction = np.array([
        np.cos(altitude_rad),
        np.sin(altitude_rad) * np.sin(azimuth_rad),
        np.sin(altitude_rad) * np.cos(azimuth_rad),
    ])
    center = np.array([position_z, position_y, position_x], dtype=np.float64)

    # Rounded integer endpoints on either side of the center.
    end_point = [int(v) for v in np.round(center + length_vox * direction)]
    start_point = [int(v) for v in np.round(center - length_vox * direction)]

    # Bresenham on the CPU; coordinates go in and come out in (z, y, x) order.
    line_points = bresenham3d(*start_point, *end_point)

    shape = image_tensor.shape  # (z, y, x)
    line_mask_torch = torch.zeros(shape, device=device, dtype=torch.uint8)

    # Drop points outside the volume while keeping their order along the line.
    points = torch.tensor(line_points, dtype=torch.long)
    upper = torch.tensor([shape[0], shape[1], shape[2]], dtype=torch.long)
    points = points[((points >= 0) & (points < upper)).all(dim=1)]

    # One scatter and one gather, instead of a Python loop that launches a
    # kernel and forces a host sync (.item()) for every voxel.
    mask_idx = points.to(line_mask_torch.device)
    line_mask_torch[mask_idx[:, 0], mask_idx[:, 1], mask_idx[:, 2]] = 1
    img_idx = points.to(image_tensor.device)
    sampled = image_tensor[img_idx[:, 0], img_idx[:, 1], img_idx[:, 2]]
    # Cast in torch first (truncates toward zero). This also handles dtypes
    # NumPy cannot represent (e.g. bfloat16) and tensors that require grad.
    point_values_np = sampled.detach().to(torch.int64).cpu().numpy().astype(np.int32)

    # Number of unit steps between consecutive samples, minus twice the count
    # of |second difference| == 2, plus twice the count of |third difference| == 4.
    # NOTE(review): presumably tuned for 0/1 label masks so that single-voxel
    # spikes are not counted as crossings; confirm the intended label values.
    d1 = np.diff(point_values_np)
    d2 = np.diff(d1)
    d3 = np.diff(d2)
    intersections = int(
        np.count_nonzero(np.abs(d1) == 1)
        - 2 * np.count_nonzero(np.abs(d2) == 2)
        + 2 * np.count_nonzero(np.abs(d3) == 4)
    )

    return intersections, line_mask_torch