404 lines
15 KiB
Python
404 lines
15 KiB
Python
import torch
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import os
|
||
from datetime import datetime
|
||
import csv
|
||
|
||
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values
|
||
from core.intersection import center_line_intersections_torch
|
||
from core.scoring import cl_score_torch, compute_overlap_ratio_from_cylinder_mask, cl_score_torch_xfr
|
||
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
|
||
from utils.helpers import save_with_unique_name
|
||
|
||
def set_axes_equal_3d(ax):
    """Force equal scaling on the x, y and z axes of a 3D plot.

    Every axis is re-centred on the midpoint of its current limits and given
    a half-width equal to half of the largest current span among the three
    axes, so that spheres are drawn as spheres and cubes as cubes. When the
    axes object has no ``set_box_aspect`` (older matplotlib), that final step
    is silently skipped.
    """
    current = (ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d())

    centres = [np.mean(lim) for lim in current]
    half_width = 0.5 * max(abs(lim[1] - lim[0]) for lim in current)

    setters = (ax.set_xlim3d, ax.set_ylim3d, ax.set_zlim3d)
    for apply_limits, centre in zip(setters, centres):
        apply_limits([centre - half_width, centre + half_width])

    # set_box_aspect only exists on newer matplotlib releases.
    try:
        ax.set_box_aspect([1, 1, 1])
    except AttributeError:
        pass
|
||
|
||
# Column order of output.csv; _csv_row must produce values in this order.
_CSV_HEADERS = [
    'Label', 'Side', 'Diameter', 'Length', 'Swarm_Size', 'Max_Iter',
    'Position_XYZ', 'Raw_Azimuth', 'Azimuth_Diff', 'Raw_Altitude', 'Altitude_Diff',
    'Intersections', 'Best_Loss', 'cyl_points', 'Overlap_Cortical', 'Overlap_Bone',
    'Cortical_Bone_Ratio', 'User_Azimuth', 'User_Altitude', 'Total_Time'
]

# 2x2 subplot layout: (subplot position, (elev, azim) camera or None for the
# default camera, whether to draw the legend on that panel).
_PANELS = (
    (221, None, False),
    (222, (90, -90), True),
    (223, (0, 90), False),
    (224, (0, 0), False),
)


def _percent(part: float, whole: float) -> float:
    """Return ``part / whole * 100``, or 0.0 when ``whole`` is zero."""
    return (part / whole) * 100 if whole else 0.0


def _ratio(numerator: float, denominator: float) -> float:
    """Return ``numerator / denominator``, or 0.0 when ``denominator`` is zero."""
    return numerator / denominator if denominator else 0.0


def _fmt_xyz(pos) -> str:
    """Format ``pos[0:3]`` (stored as z, y, x) as an ``"(x, y, z)"`` string."""
    return f"({pos[2]:.2f}, {pos[1]:.2f}, {pos[0]:.2f})"


def _build_side(side, diameter, length, pos, score_fn, spine_tensor,
                cortical_tensor, image_shape, spacing, device, grid):
    """Generate one side's cylinder masks, centre line and score as a dict."""
    geometry = pos[:5]  # (z, y, x, azimuth, altitude); extra entries ignored
    cyl = generate_cylinder_n_torch(
        diameter, length, *geometry, image_shape, spacing, device, grid
    )
    cyl_o = generate_cylinder_o_torch(
        diameter, length, *geometry, image_shape, spacing, device, grid
    )
    intersections, line_mask = center_line_intersections_torch(
        *geometry, int(length), spine_tensor, spacing, device
    )
    loss = score_fn(cortical_tensor, spine_tensor, cyl, cyl_o, intersections)
    return {
        'side': side,
        'diameter': diameter,
        'length': length,
        'pos': pos,
        'cyl': cyl,
        'cyl_o': cyl_o,
        'intersections': intersections,
        'line_mask': line_mask,
        'loss': loss,
    }


def _overlap_metrics(cortical_tensor, spine_tensor, cyl):
    """Return (cyl_points, cortical %, bone %, cortical/bone ratio) for ``cyl``.

    Percentages are relative to the number of cylinder voxels. Each division
    falls back to 0.0 when its denominator is zero, so an empty cylinder or
    one that misses the vertebra does not raise ZeroDivisionError.
    """
    cyl_points = torch.sum(cyl).item()
    in_cyl = cyl == 1
    overlap_cortical = ((cortical_tensor == 1) & in_cyl).sum().item()
    overlap_bone = ((spine_tensor == 1) & in_cyl).sum().item()
    pct_cortical = _percent(overlap_cortical, cyl_points)
    pct_bone = _percent(overlap_bone, cyl_points)
    return cyl_points, pct_cortical, pct_bone, _ratio(pct_cortical, pct_bone)


def _mask_xyz(mask: torch.Tensor):
    """Move ``mask`` to CPU and return the (x, y, z) coordinates where it equals 1.

    The mask is indexed (z, y, x), so ``np.where``'s output is reordered into
    matplotlib's x, y, z scatter arguments.
    """
    z, y, x = np.where(mask.cpu().numpy() == 1)
    return x, y, z


def _draw_panel(ax, layers, view=None, show_legend=False):
    """Scatter every (coords, style) layer onto a 3D axes and equalise its scale."""
    if view is not None:
        elev, azim = view
        ax.view_init(elev=elev, azim=azim, roll=0)
    for (x, y, z), style in layers:
        ax.scatter(x, y, z, **style)
    ax.set_xlabel('X-axis')
    ax.set_ylabel('Y-axis')
    ax.set_zlabel('Z-axis')
    set_axes_equal_3d(ax)
    if show_legend:
        ax.legend()


def _csv_row(label_str, s, swarm_size, max_iter, azi, alt, total_time):
    """Build one output.csv row for side ``s``, in ``_CSV_HEADERS`` order."""
    pos = s['pos']
    return [
        label_str,
        s['side'],
        s['diameter'],
        s['length'],
        swarm_size,
        max_iter,
        _fmt_xyz(pos),
        f"{pos[3]:.2f}",          # Raw_Azimuth
        f"{pos[3] - azi:.2f}",    # Azimuth_Diff
        f"{pos[4]:.2f}",          # Raw_Altitude
        f"{pos[4] - alt:.2f}",    # Altitude_Diff
        s['intersections'],
        f"{s['loss']:.2f}",
        s['cyl_points'],
        f"{s['overlap_cortical']:.2f}",
        f"{s['overlap_bone']:.2f}",
        f"{s['cb_ratio']:.2f}",
        f"{s['user_azimuth']:.2f}",
        f"{s['user_altitude']:.2f}",
        f"{total_time:.2f}",
    ]


def res_plt_2_torch(
    spine_tensor: torch.Tensor,
    cortical_tensor: torch.Tensor,
    image_shape: tuple[int, int, int],
    image2_path: str,
    base_folder: str,
    label_str: str,
    diameter_l: float,
    length_l: float,
    diameter_r: float,
    length_r: float,
    best_position_l: list[float],
    best_position_r: list[float],
    swarm_size: int,
    max_iter: int,
    total_time: float,
    spacing: list[float],
    CBT: bool,
    device: torch.device,
    grid=None
) -> None:
    """Plot, log and save the best left/right trajectories for one vertebra.

    For each side, the function builds the cylinder masks with
    ``generate_cylinder_n_torch`` / ``generate_cylinder_o_torch``, computes
    the centre-line intersections and a score, and derives overlap
    percentages. It then does three things:

    * draws a 2x2 grid of 3D scatter views (default camera, elev=90, and two
      elev=0 views) of the centre lines, cylinders and spine voxels;
    * appends one CSV row per side to
      ``<base_folder>/<YYYYMMDD>/<patient_id>/output.csv``, writing the
      header when the file does not exist yet. ``patient_id`` is the name of
      ``image2_path``'s parent directory. A failed CSV write is reported and
      does not stop the figure from being saved;
    * saves the figure via ``save_with_unique_name`` into the same folder.

    Args:
        spine_tensor: Vertebra mask; voxels equal to 1 count as bone.
            Presumably a (z, y, x) volume matching ``image_shape`` -- confirm.
        cortical_tensor: Cortical mask; voxels equal to 1 count as cortical.
        image_shape: Volume shape forwarded to the cylinder generators.
        image2_path: Image used for the azimuth and superior-edge tilt
            reference angles; its parent directory names the patient folder.
        base_folder: Root output folder.
        label_str: Label used in the CSV rows, figure title and file name.
        diameter_l, length_l, diameter_r, length_r: Per-side cylinder sizes
            (shown in mm on the figure).
        best_position_l, best_position_r: ``(z, y, x, azimuth, altitude)``
            per side; any trailing entries are ignored.
        swarm_size, max_iter, total_time: Optimiser settings and runtime,
            logged only.
        spacing: Voxel spacing forwarded to the geometry helpers.
        CBT: Truthy selects the ``'CBT'`` file-name tag, otherwise ``'TPS'``.
        device: Torch device for the cylinder/intersection helpers.
        grid: Optional grid forwarded to the cylinder generators.
    """
    common = (spine_tensor, cortical_tensor, image_shape, spacing, device, grid)
    left = _build_side('L', diameter_l, length_l, best_position_l,
                       cl_score_torch, *common)
    # NOTE(review): the right side is scored with cl_score_torch_xfr while the
    # left uses cl_score_torch (a cl_score_torch call for the right side was
    # previously commented out). Kept as-is; confirm the asymmetry is intended.
    right = _build_side('R', diameter_r, length_r, best_position_r,
                        cl_score_torch_xfr, *common)
    sides = (left, right)

    # Reference angles measured on the image; the "user" angles below are the
    # trajectory angles expressed relative to them.
    azi = azimuth_rotation(image2_path)
    tilt = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
    alt = tilt['superior']['tilt_angle_deg']

    for s in sides:
        (s['cyl_points'], s['overlap_cortical'],
         s['overlap_bone'], s['cb_ratio']) = _overlap_metrics(cortical_tensor, spine_tensor, s['cyl'])
        s['user_azimuth'] = 90 - s['pos'][3] - azi
        s['user_altitude'] = 90 - s['pos'][4] - alt

    # Same layers, in the same draw order, on every panel.
    layers = [
        (_mask_xyz(left['line_mask']), dict(c='r', marker='o', s=1)),
        (_mask_xyz(right['line_mask']), dict(c='r', marker='o', s=1)),
        (_mask_xyz(left['cyl']), dict(c='darkcyan', marker='o', label='Cylinder(L)')),
        (_mask_xyz(left['cyl_o']), dict(c='pink', marker='o')),
        (_mask_xyz(right['cyl']), dict(c='blue', marker='o', label='Cylinder(R)')),
        (_mask_xyz(right['cyl_o']), dict(c='pink', marker='o')),
        (_mask_xyz(spine_tensor), dict(c='lightblue', marker='+', alpha=0.04, label='Spine')),
    ]

    fig = plt.figure(figsize=(12, 12))
    try:
        for position, view, show_legend in _PANELS:
            ax = fig.add_subplot(position, projection='3d')
            _draw_panel(ax, layers, view, show_legend)

        # Compute the output folder once so the CSV and the figure always land
        # in the same date folder, even if the run crosses midnight.
        date_str = datetime.now().strftime("%Y%m%d")
        patient_id = os.path.basename(os.path.dirname(image2_path))
        output_folder = os.path.join(base_folder, date_str, patient_id)
        os.makedirs(output_folder, exist_ok=True)

        csv_path = os.path.join(output_folder, 'output.csv')
        # Only write the header when creating a new file.
        write_header = not os.path.isfile(csv_path)
        try:
            rows = [_csv_row(label_str, s, swarm_size, max_iter, azi, alt, total_time)
                    for s in sides]
            with open(csv_path, 'a', newline='') as csvfile:
                writer = csv.writer(csvfile)
                if write_header:
                    writer.writerow(_CSV_HEADERS)
                writer.writerows(rows)
            print(f"[CSV Saved] {csv_path}")
        except Exception as e:
            # Best-effort: a CSV failure must not prevent saving the figure.
            print(f"[Error] Failed to write CSV: {e}")

        fig.text(0.5, 0.98, f'{label_str} Best Position', ha='center', fontsize=15)
        fig.text(
            0.5, 0.44,
            f'L: Diameter = {diameter_l} mm, {length_l} mm, '
            f'R: Diameter = {diameter_r} mm, {length_r} mm, '
            f'Swarm size = {swarm_size}, Iteration = {max_iter}, Total time = {total_time:.2f} s',
            ha='center', fontsize=12
        )
        for y_pos, name, s in ((0.03, 'Left', left), (0.01, 'Right', right)):
            fig.text(
                0.5, y_pos,
                f'{name} : Position = {_fmt_xyz(s["pos"])}, '
                f'Azimuth = {s["user_azimuth"]:.2f}, Altitude = {s["user_altitude"]:.2f}, '
                f'Intersection = {s["intersections"]}, '
                f'Score = {s["overlap_cortical"]:.2f} / {s["overlap_bone"]:.2f} / {s["cb_ratio"]:.2f}',
                ha='center', fontsize=9
            )

        fig.tight_layout()

        way = 'CBT' if CBT else 'TPS'
        path = save_with_unique_name(output_folder, label_str, way,
                                     diameter_l, length_l, diameter_r, length_r,
                                     swarm_size, max_iter)
        fig.savefig(path, dpi=200, bbox_inches="tight")
        print("[Saved figure]", path)
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)
|
||
|
||
|
||
def eval_overlap_from_position(
|
||
pos,
|
||
optimize_size: bool,
|
||
spine_tensor: torch.Tensor,
|
||
image_shape,
|
||
spacing,
|
||
device: torch.device,
|
||
grid=None,
|
||
fixed_diameter: float | None = None,
|
||
fixed_length: float | None = None,
|
||
):
|
||
"""
|
||
根據 position 生成 cylinder mask,再算 overlap ratio
|
||
"""
|
||
|
||
if optimize_size:
|
||
d, L = snap_to_discrete_values(pos[5], pos[6])
|
||
params_5 = pos[:5]
|
||
else:
|
||
if fixed_diameter is None or fixed_length is None:
|
||
raise ValueError("fixed_diameter and fixed_length must be provided when optimize_size=False")
|
||
d, L = fixed_diameter, fixed_length
|
||
params_5 = pos
|
||
|
||
z, y, x, az, alt = params_5
|
||
|
||
cyl_mask = generate_cylinder_n_torch(
|
||
d, L,
|
||
z, y, x,
|
||
az, alt,
|
||
image_shape, spacing,
|
||
device=device,
|
||
grid=grid
|
||
)
|
||
|
||
overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
|
||
return overlap, d, L
|
||
|
||
|