CBT_project/visualization/res_plot_3d.py
2026-04-10 13:25:27 +08:00

366 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import numpy as np
import matplotlib.pyplot as plt
import os
from datetime import datetime
import csv
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values
from core.intersection import center_line_intersections_torch
from core.scoring import cl_score_torch, compute_overlap_ratio_from_cylinder_mask
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
from utils.helpers import save_with_unique_name
# Column order of output.csv; every row written by _csv_row follows it.
_CSV_HEADERS = [
    'Label', 'Side', 'Diameter', 'Length', 'Swarm_Size', 'Max_Iter',
    'Position_XYZ', 'Raw_Azimuth', 'Azimuth_Diff', 'Raw_Altitude', 'Altitude_Diff',
    'Intersections', 'Best_Loss', 'Overlap_Cortical', 'Overlap_Bone',
    'Cortical_Bone_Ratio', 'User_Azimuth', 'User_Altitude', 'Total_Time'
]

# (elev, azim, roll) for subplots 221-224; None keeps matplotlib's default view.
_SUBPLOT_VIEWS = (None, (90, -90, 0), (0, 90, 0), (0, 0, 0))


def _safe_div(numerator: float, denominator: float) -> float:
    """Return ``numerator / denominator``, or ``0.0`` when the denominator is zero.

    Keeps an empty cylinder mask (or a cylinder with no bone overlap) from
    raising ``ZeroDivisionError`` while computing percentages and ratios.
    """
    return numerator / denominator if denominator else 0.0


def _mask_xyz(mask: torch.Tensor) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return the (x, y, z) coordinates of the voxels of ``mask`` equal to 1.

    The mask axes are unpacked as (z, y, x); the result is reordered for
    matplotlib's ``scatter(x, y, z)``.
    """
    z, y, x = np.where(mask.cpu().numpy() == 1)
    return x, y, z


def _evaluate_side(
    diameter: float,
    length: float,
    position,
    spine_tensor: torch.Tensor,
    cortical_tensor: torch.Tensor,
    image_shape: tuple[int, int, int],
    spacing: list[float],
    device: torch.device,
    grid=None,
) -> dict:
    """Regenerate the masks of one trajectory and compute its scores.

    ``position[:5]`` is forwarded positionally to the cylinder generators and
    to ``center_line_intersections_torch``.

    Returns a dict with keys:
      ``cyl`` / ``cyl_o``: masks from ``generate_cylinder_n_torch`` /
        ``generate_cylinder_o_torch``;
      ``line_mask`` / ``intersections``: outputs of
        ``center_line_intersections_torch``;
      ``loss``: ``cl_score_torch`` result;
      ``overlap_cortical`` / ``overlap_vertebral``: percentage of ``cyl``
        voxels that are 1 in ``cortical_tensor`` / ``spine_tensor``
        (0 when ``cyl`` is empty);
      ``cb_ratio``: cortical / vertebral overlap (0 when the latter is 0).
    """
    pose = position[:5]
    cyl = generate_cylinder_n_torch(diameter, length, *pose, image_shape, spacing, device, grid)
    cyl_o = generate_cylinder_o_torch(diameter, length, *pose, image_shape, spacing, device, grid)
    intersections, line_mask = center_line_intersections_torch(
        *pose, int(length), spine_tensor, spacing, device
    )
    loss = cl_score_torch(cortical_tensor, spine_tensor, cyl, cyl_o, intersections)

    cyl_points = torch.sum(cyl).item()
    cortical_hits = ((cortical_tensor == 1) & (cyl == 1)).sum().item()
    bone_hits = ((spine_tensor == 1) & (cyl == 1)).sum().item()
    overlap_cortical = _safe_div(cortical_hits, cyl_points) * 100
    overlap_vertebral = _safe_div(bone_hits, cyl_points) * 100
    return {
        'cyl': cyl,
        'cyl_o': cyl_o,
        'line_mask': line_mask,
        'intersections': intersections,
        'loss': loss,
        'overlap_cortical': overlap_cortical,
        'overlap_vertebral': overlap_vertebral,
        'cb_ratio': _safe_div(overlap_cortical, overlap_vertebral),
    }


def _draw_views(fig, left: dict, right: dict, spine_tensor: torch.Tensor) -> None:
    """Draw the same 3-D scene into the four 2x2 subplots of ``fig``.

    Each subplot shows both centre lines (red), both cylinders and their
    ``cyl_o`` masks, and the spine voxels; only the second subplot has a legend.
    """
    line_l = _mask_xyz(left['line_mask'])
    line_r = _mask_xyz(right['line_mask'])
    cyl_l = _mask_xyz(left['cyl'])
    cyl_lo = _mask_xyz(left['cyl_o'])
    cyl_r = _mask_xyz(right['cyl'])
    cyl_ro = _mask_xyz(right['cyl_o'])
    spine = _mask_xyz(spine_tensor)

    for index, view in enumerate(_SUBPLOT_VIEWS):
        ax = fig.add_subplot(221 + index, projection='3d')
        if view is not None:
            elev, azim, roll = view
            ax.view_init(elev=elev, azim=azim, roll=roll)
        ax.scatter(*line_l, c='r', marker='o', s=1)
        ax.scatter(*line_r, c='r', marker='o', s=1)
        ax.scatter(*cyl_l, c='darkcyan', marker='o', label='Cylinder(L)')
        ax.scatter(*cyl_lo, c='pink', marker='o')
        ax.scatter(*cyl_r, c='blue', marker='o', label='Cylinder(R)')
        ax.scatter(*cyl_ro, c='pink', marker='o')
        ax.scatter(*spine, c='lightblue', marker='+', alpha=0.04, label='Spine')
        ax.set_xlabel('X-axis')
        ax.set_ylabel('Y-axis')
        ax.set_zlabel('Z-axis')
        if index == 1:
            ax.legend()


def _csv_row(label_str, side, diameter, length, swarm_size, max_iter,
             pos, metrics: dict, azi, alt, total_time) -> list:
    """Build one ``output.csv`` row; column order matches ``_CSV_HEADERS``."""
    return [
        label_str,
        side,
        diameter,
        length,
        swarm_size,
        max_iter,
        f"({pos[0]:.2f}, {pos[1]:.2f}, {pos[2]:.2f})",
        f"{pos[3]:.2f}",
        f"{pos[3]-azi:.2f}",
        f"{pos[4]:.2f}",
        f"{pos[4]-alt:.2f}",
        metrics['intersections'],
        f"{metrics['loss']:.2f}",
        f"{metrics['overlap_cortical']:.2f}",
        f"{metrics['overlap_vertebral']:.2f}",
        f"{metrics['cb_ratio']:.2f}",
        f"{metrics['user_azimuth']:.2f}",
        f"{metrics['user_altitude']:.2f}",
        f"{total_time:.2f}",
    ]


def res_plt_2_torch(
    spine_tensor: torch.Tensor,
    cortical_tensor: torch.Tensor,
    image_shape: tuple[int, int, int],
    image2_path: str,
    base_folder: str,
    label_str: str,
    diameter_l: float,
    length_l: float,
    diameter_r: float,
    length_r: float,
    best_position_l: list[float],
    best_position_r: list[float],
    swarm_size: int,
    max_iter: int,
    total_time: float,
    spacing: list[float],
    CBT: bool,
    device: torch.device,
    grid=None
) -> None:
    """Score, log and plot the best left/right trajectories.

    For each side the cylinder masks and centre line are regenerated with the
    torch helpers from ``best_position_*``. Per the CSV header, indices 0-2
    are the position, 3 is the azimuth and 4 is the altitude. The masks are
    scored against ``cortical_tensor`` and ``spine_tensor``. Overlap
    percentages and the cortical/bone ratio fall back to 0 instead of raising
    when a cylinder is empty or misses bone.

    User-facing angles are ``90 - raw - reference``. The azimuth reference
    comes from ``azimuth_rotation(image2_path)``; the altitude reference is
    the superior ``tilt_angle_deg`` from ``analyze_vertebral_tilt_contour``.

    Side effects:
      * appends one ``L`` and one ``R`` row to
        ``<base_folder>/<YYYYMMDD>/<patient_id>/output.csv``, where
        ``patient_id`` is the name of the directory containing
        ``image2_path``. The header is written only when the file is new.
        CSV errors are printed rather than raised.
      * saves a 2x2 figure of 3-D scatter views into the same folder, using
        the path returned by ``save_with_unique_name``. ``CBT`` selects the
        'CBT' / 'TPS' tag passed to it. The figure is always closed.
    """
    left = _evaluate_side(diameter_l, length_l, best_position_l, spine_tensor,
                          cortical_tensor, image_shape, spacing, device, grid)
    right = _evaluate_side(diameter_r, length_r, best_position_r, spine_tensor,
                           cortical_tensor, image_shape, spacing, device, grid)

    # Reference angles measured on image2_path; the reported "user" angles
    # are expressed relative to them.
    azi = azimuth_rotation(image2_path)
    res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
    alt = res['superior']['tilt_angle_deg']
    for metrics, pos in ((left, best_position_l), (right, best_position_r)):
        metrics['user_azimuth'] = 90 - pos[3] - azi
        metrics['user_altitude'] = 90 - pos[4] - alt

    # Resolve the output folder once so the CSV and the figure always share
    # it, even if the run crosses midnight.
    date_str = datetime.now().strftime("%Y%m%d")
    patient_id = os.path.basename(os.path.dirname(image2_path))
    output_folder = os.path.join(base_folder, date_str, patient_id)
    os.makedirs(output_folder, exist_ok=True)

    # (csv tag, caption name, diameter, length, position, metrics)
    sides = (
        ('L', 'Left', diameter_l, length_l, best_position_l, left),
        ('R', 'Right', diameter_r, length_r, best_position_r, right),
    )

    fig = plt.figure(figsize=(12, 12))
    try:
        _draw_views(fig, left, right, spine_tensor)

        csv_path = os.path.join(output_folder, 'output.csv')
        file_exists = os.path.isfile(csv_path)  # header only for a new file
        try:
            with open(csv_path, 'a', newline='') as csvfile:
                writer = csv.writer(csvfile)
                if not file_exists:
                    writer.writerow(_CSV_HEADERS)
                for tag, _, diameter, length, pos, metrics in sides:
                    writer.writerow(_csv_row(label_str, tag, diameter, length, swarm_size,
                                             max_iter, pos, metrics, azi, alt, total_time))
            print(f"[CSV Saved] {csv_path}")
        except Exception as e:
            # Best effort: a CSV failure is reported but must not prevent the
            # figure from being saved.
            print(f"[Error] Failed to write CSV: {e}")

        fig.text(0.5, 0.98, f'{label_str} Best Position', ha='center', fontsize=15)
        fig.text(
            0.5, 0.44,
            f'L: Diameter = {diameter_l} mm, {length_l} mm, '
            f'R: Diameter = {diameter_r} mm, {length_r} mm, '
            f'Swarm size = {swarm_size}, Iteration = {max_iter}, Total time = {total_time:.2f} s',
            ha='center', fontsize=12
        )
        for y_pos, (_, name, _, _, pos, metrics) in zip((0.03, 0.01), sides):
            fig.text(
                0.5, y_pos,
                f'{name} : Position = ({pos[0]:.2f}, {pos[1]:.2f}, {pos[2]:.2f}), '
                f'Azimuth = {metrics["user_azimuth"]:.2f}, Altitude = {metrics["user_altitude"]:.2f}, '
                f'Intersection = {metrics["intersections"]}, '
                f'Score = {metrics["overlap_cortical"]:.2f} / {metrics["overlap_vertebral"]:.2f} / {metrics["cb_ratio"]:.2f}',
                ha='center', fontsize=9
            )
        fig.tight_layout()

        way = 'CBT' if CBT else 'TPS'
        path = save_with_unique_name(output_folder, label_str, way,
                                     diameter_l, length_l, diameter_r, length_r,
                                     swarm_size, max_iter)
        fig.savefig(path, dpi=200, bbox_inches="tight")
        print("[Saved figure]", path)
    finally:
        # Release the figure even on failure; pyplot otherwise keeps it alive.
        plt.close(fig)
def eval_overlap_from_position(
pos,
optimize_size: bool,
spine_tensor: torch.Tensor,
image_shape,
spacing,
device: torch.device,
grid=None,
fixed_diameter: float | None = None,
fixed_length: float | None = None,
):
"""
根據 position 生成 cylinder mask再算 overlap ratio
"""
if optimize_size:
d, L = snap_to_discrete_values(pos[5], pos[6])
params_5 = pos[:5]
else:
if fixed_diameter is None or fixed_length is None:
raise ValueError("fixed_diameter and fixed_length must be provided when optimize_size=False")
d, L = fixed_diameter, fixed_length
params_5 = pos
z, y, x, az, alt = params_5
cyl_mask = generate_cylinder_n_torch(
d, L,
z, y, x,
az, alt,
image_shape, spacing,
device=device,
grid=grid
)
overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
return overlap, d, L