import csv
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import torch

from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values
from core.intersection import center_line_intersections_torch
from core.scoring import cl_score_torch, compute_overlap_ratio_from_cylinder_mask
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
from utils.helpers import save_with_unique_name

# Column order of the results file appended to by res_plt_2_torch.
_CSV_HEADERS = [
    'Label', 'Side', 'Diameter', 'Length', 'Swarm_Size', 'Max_Iter',
    'Position_XYZ', 'Raw_Azimuth', 'Azimuth_Diff', 'Raw_Altitude', 'Altitude_Diff',
    'Intersections', 'Best_Loss', 'Overlap_Cortical', 'Overlap_Bone',
    'Cortical_Bone_Ratio', 'User_Azimuth', 'User_Altitude', 'Total_Time',
]

# One entry per 3-D subplot: (subplot code, view_init kwargs or None for the
# default camera, whether to draw the legend on this panel).
_PANELS = (
    (221, None, False),
    (222, {'elev': 90, 'azim': -90, 'roll': 0}, True),
    (223, {'elev': 0, 'azim': 90, 'roll': 0}, False),
    (224, {'elev': 0, 'azim': 0, 'roll': 0}, False),
)


@dataclass
class _ScrewResult:
    """Re-evaluated state of one side's (left or right) best screw trajectory."""

    side: str                 # 'L' or 'R'
    diameter: float
    length: float
    position: list[float]     # first five entries unpacked as (z, y, x, az, alt) by the generators
    cyl_n: torch.Tensor       # mask from generate_cylinder_n_torch
    cyl_o: torch.Tensor       # mask from generate_cylinder_o_torch
    line_mask: torch.Tensor   # center-line mask from center_line_intersections_torch
    intersections: Any        # as returned by center_line_intersections_torch
    loss: Any                 # as returned by cl_score_torch
    overlap_cortical: float   # % of cyl_n voxels that are 1 in cortical_tensor
    overlap_vertebral: float  # % of cyl_n voxels that are 1 in spine_tensor
    cb_ratio: float           # overlap_cortical / overlap_vertebral (0.0 when undefined)


def _percent(part: float, whole: float) -> float:
    """Return ``part / whole * 100``, or 0.0 when *whole* is zero (e.g. an empty cylinder)."""
    return (part / whole) * 100 if whole else 0.0


def _ratio(numerator: float, denominator: float) -> float:
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _user_angles(position, azi: float, alt: float) -> tuple[float, float]:
    """Return the user-facing (azimuth, altitude): ``90 - raw angle - image reference angle``."""
    return 90 - position[3] - azi, 90 - position[4] - alt


def _evaluate_screw(
    side: str,
    diameter: float,
    length: float,
    position,
    spine_tensor: torch.Tensor,
    cortical_tensor: torch.Tensor,
    image_shape,
    spacing,
    device: torch.device,
    grid,
) -> _ScrewResult:
    """Regenerate one side's cylinders and center line, score them, and measure overlap."""
    pose = position[:5]
    cyl_n = generate_cylinder_n_torch(diameter, length, *pose, image_shape, spacing, device, grid)
    cyl_o = generate_cylinder_o_torch(diameter, length, *pose, image_shape, spacing, device, grid)
    intersections, line_mask = center_line_intersections_torch(
        *pose, int(length), spine_tensor, spacing, device
    )
    loss = cl_score_torch(cortical_tensor, spine_tensor, cyl_n, cyl_o, intersections)

    cyl_points = torch.sum(cyl_n).item()
    in_cortical = ((cortical_tensor == 1) & (cyl_n == 1)).sum().item()
    in_bone = ((spine_tensor == 1) & (cyl_n == 1)).sum().item()
    # Guarded: an empty cylinder or zero bone overlap used to raise ZeroDivisionError.
    overlap_cortical = _percent(in_cortical, cyl_points)
    overlap_vertebral = _percent(in_bone, cyl_points)

    return _ScrewResult(
        side=side,
        diameter=diameter,
        length=length,
        position=position,
        cyl_n=cyl_n,
        cyl_o=cyl_o,
        line_mask=line_mask,
        intersections=intersections,
        loss=loss,
        overlap_cortical=overlap_cortical,
        overlap_vertebral=overlap_vertebral,
        cb_ratio=_ratio(overlap_cortical, overlap_vertebral),
    )


def _csv_row(
    label_str: str,
    screw: _ScrewResult,
    swarm_size: int,
    max_iter: int,
    azi: float,
    alt: float,
    total_time: float,
) -> list:
    """Build one results row (matching _CSV_HEADERS) for *screw*."""
    p = screw.position
    user_azimuth, user_altitude = _user_angles(p, azi, alt)
    return [
        label_str, screw.side, screw.diameter, screw.length, swarm_size, max_iter,
        f"({p[0]:.2f}, {p[1]:.2f}, {p[2]:.2f})",
        f"{p[3]:.2f}",
        f"{p[3]-azi:.2f}",
        f"{p[4]:.2f}",
        f"{p[4]-alt:.2f}",
        screw.intersections,
        f"{screw.loss:.2f}",
        f"{screw.overlap_cortical:.2f}",
        f"{screw.overlap_vertebral:.2f}",
        f"{screw.cb_ratio:.2f}",
        f"{user_azimuth:.2f}",
        f"{user_altitude:.2f}",
        f"{total_time:.2f}",
    ]


def _append_csv_rows(csv_path: str, rows: list[list]) -> None:
    """Append *rows* to *csv_path*, writing the header first if the file is new or empty."""
    needs_header = not os.path.isfile(csv_path) or os.path.getsize(csv_path) == 0
    with open(csv_path, 'a', newline='') as csvfile:
        writer = csv.writer(csvfile)
        if needs_header:
            writer.writerow(_CSV_HEADERS)
        writer.writerows(rows)


def _mask_points(mask: torch.Tensor) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return (x, y, z) coordinates of voxels equal to 1, treating mask axis 0 as z."""
    z, y, x = np.where(mask.cpu().numpy() == 1)
    return x, y, z


def _draw_panels(
    fig: plt.Figure,
    left: _ScrewResult,
    right: _ScrewResult,
    spine_tensor: torch.Tensor,
) -> None:
    """Scatter the same layers (center lines, cylinders, spine) into four 3-D views."""
    # Layer order matters: later layers are drawn on top of earlier ones.
    layers = [
        (_mask_points(left.line_mask), {'c': 'r', 'marker': 'o', 's': 1}),
        (_mask_points(right.line_mask), {'c': 'r', 'marker': 'o', 's': 1}),
        (_mask_points(left.cyl_n), {'c': 'darkcyan', 'marker': 'o', 'label': 'Cylinder(L)'}),
        (_mask_points(left.cyl_o), {'c': 'pink', 'marker': 'o'}),
        (_mask_points(right.cyl_n), {'c': 'blue', 'marker': 'o', 'label': 'Cylinder(R)'}),
        (_mask_points(right.cyl_o), {'c': 'pink', 'marker': 'o'}),
        (_mask_points(spine_tensor), {'c': 'lightblue', 'marker': '+', 'alpha': 0.04, 'label': 'Spine'}),
    ]
    for subplot, view, with_legend in _PANELS:
        ax = fig.add_subplot(subplot, projection='3d')
        if view is not None:
            ax.view_init(**view)
        for (xs, ys, zs), style in layers:
            ax.scatter(xs, ys, zs, **style)
        ax.set_xlabel('X-axis')
        ax.set_ylabel('Y-axis')
        ax.set_zlabel('Z-axis')
        if with_legend:
            ax.legend()


def _annotate_figure(
    fig: plt.Figure,
    label_str: str,
    left: _ScrewResult,
    right: _ScrewResult,
    swarm_size: int,
    max_iter: int,
    total_time: float,
    azi: float,
    alt: float,
) -> None:
    """Add the title, run-parameter line, and per-side result lines to *fig*."""
    fig.text(0.5, 0.98, f'{label_str} Best Position', ha='center', fontsize=15)
    fig.text(
        0.5, 0.44,
        f'L: Diameter = {left.diameter} mm, {left.length} mm, '
        f'R: Diameter = {right.diameter} mm, {right.length} mm, '
        f'Swarm size = {swarm_size}, Iteration = {max_iter}, Total time = {total_time:.2f} s',
        ha='center', fontsize=12
    )
    for screw, name, y in ((left, 'Left', 0.03), (right, 'Right', 0.01)):
        p = screw.position
        user_azimuth, user_altitude = _user_angles(p, azi, alt)
        fig.text(
            0.5, y,
            f'{name} : Position = ({p[0]:.2f}, {p[1]:.2f}, {p[2]:.2f}), '
            f'Azimuth = {user_azimuth:.2f}, Altitude = {user_altitude:.2f}, '
            f'Intersection = {screw.intersections}, Score = {screw.overlap_cortical:.2f} / '
            f'{screw.overlap_vertebral:.2f} / {screw.cb_ratio:.2f}',
            ha='center', fontsize=9
        )


def res_plt_2_torch(
    spine_tensor: torch.Tensor,
    cortical_tensor: torch.Tensor,
    image_shape: tuple[int, int, int],
    image2_path: str,
    base_folder: str,
    label_str: str,
    diameter_l: float,
    length_l: float,
    diameter_r: float,
    length_r: float,
    best_position_l: list[float],
    best_position_r: list[float],
    swarm_size: int,
    max_iter: int,
    total_time: float,
    spacing: list[float],
    CBT: bool,
    device: torch.device,
    grid=None
) -> None:
    """Re-evaluate the best left/right screw poses, log them to CSV and save a 4-view 3-D figure.

    For each side, the function regenerates both cylinder masks (``generate_cylinder_n_torch``
    and ``generate_cylinder_o_torch``) and the center line, then scores them with
    ``cl_score_torch``. It also computes the percentage of ``cyl_n`` voxels that lie in
    ``cortical_tensor`` and in ``spine_tensor``, plus their ratio. These are reported as 0.0
    when undefined, e.g. for an empty cylinder.

    Reference angles come from ``image2_path`` via ``azimuth_rotation`` and the superior-edge
    tilt. Results are written into ``<base_folder>/<YYYYMMDD>/<patient_id>/``, where
    ``patient_id`` is the name of ``image2_path``'s parent directory:

    * One row per side is appended to ``output.csv`` in that folder. This is best-effort:
      failures are printed and do not stop the figure from being saved.
    * The figure is saved under a name from ``save_with_unique_name``.

    Args:
        spine_tensor: Binary spine mask (voxels equal to 1 count as bone).
        cortical_tensor: Binary cortical-bone mask, same grid as ``spine_tensor``.
        image_shape: Volume shape forwarded to the cylinder generators.
        image2_path: 2-D image used for the reference azimuth/tilt and the patient id.
        base_folder: Root output directory.
        label_str: Label written to the CSV and the figure title.
        diameter_l, length_l, diameter_r, length_r: Screw sizes (shown as mm in the figure).
        best_position_l, best_position_r: Poses; entries 0-4 are passed to the generators,
            and entries 3 and 4 are treated as azimuth and altitude.
        swarm_size, max_iter, total_time: Optimizer run info, reported only.
        spacing: Voxel spacing forwarded to the generators.
        CBT: Truthy selects the 'CBT' tag for the saved figure name, otherwise 'TPS'.
        device: Torch device forwarded to the generators.
        grid: Optional grid forwarded to the cylinder generators.
    """
    left = _evaluate_screw('L', diameter_l, length_l, best_position_l, spine_tensor,
                           cortical_tensor, image_shape, spacing, device, grid)
    right = _evaluate_screw('R', diameter_r, length_r, best_position_r, spine_tensor,
                            cortical_tensor, image_shape, spacing, device, grid)

    azi = azimuth_rotation(image2_path)
    tilt = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
    alt = tilt['superior']['tilt_angle_deg']

    # Resolve the output folder once, so the CSV and the figure always land in the same
    # directory even if the run crosses midnight.
    date_str = datetime.now().strftime("%Y%m%d")
    patient_id = os.path.basename(os.path.dirname(image2_path))
    output_folder = os.path.join(base_folder, date_str, patient_id)
    os.makedirs(output_folder, exist_ok=True)

    csv_path = os.path.join(output_folder, 'output.csv')
    try:
        # Build both rows before writing, so a formatting error cannot leave half a record.
        rows = [_csv_row(label_str, screw, swarm_size, max_iter, azi, alt, total_time)
                for screw in (left, right)]
        _append_csv_rows(csv_path, rows)
        print(f"[CSV Saved] {csv_path}")
    except Exception as e:  # best-effort: a CSV failure must not block saving the figure
        print(f"[Error] Failed to write CSV: {e}")

    fig = plt.figure(figsize=(12, 12))
    try:
        _draw_panels(fig, left, right, spine_tensor)
        _annotate_figure(fig, label_str, left, right, swarm_size, max_iter, total_time, azi, alt)
        fig.tight_layout()

        way = 'CBT' if CBT else 'TPS'
        path = save_with_unique_name(output_folder, label_str, way, diameter_l, length_l,
                                     diameter_r, length_r, swarm_size, max_iter)
        fig.savefig(path, dpi=200, bbox_inches="tight")
        print("[Saved figure]", path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def eval_overlap_from_position(
    pos,
    optimize_size: bool,
    spine_tensor: torch.Tensor,
    image_shape,
    spacing,
    device: torch.device,
    grid=None,
    fixed_diameter: float | None = None,
    fixed_length: float | None = None,
):
    """Build a cylinder mask from a particle position and return its overlap ratio with the spine.

    Args:
        pos: Particle position. The pose is unpacked as ``(z, y, x, az, alt)``.
            When ``optimize_size`` is True, ``pos[:5]`` is the pose. ``pos[5]`` and
            ``pos[6]`` are the raw diameter and length, snapped through
            ``snap_to_discrete_values``. When False, ``pos`` itself must have exactly
            those five pose entries.
        optimize_size: Whether the size is taken from ``pos`` or from the fixed values.
        spine_tensor: Mask passed to ``compute_overlap_ratio_from_cylinder_mask``.
        image_shape, spacing, device, grid: Forwarded to ``generate_cylinder_n_torch``.
        fixed_diameter, fixed_length: Required when ``optimize_size`` is False.

    Returns:
        ``(overlap, d, L)``: the overlap ratio from ``compute_overlap_ratio_from_cylinder_mask``
        and the diameter/length actually used.

    Raises:
        ValueError: if ``optimize_size`` is False and either fixed size is missing.
    """
    if optimize_size:
        d, L = snap_to_discrete_values(pos[5], pos[6])
        params_5 = pos[:5]
    else:
        if fixed_diameter is None or fixed_length is None:
            raise ValueError("fixed_diameter and fixed_length must be provided when optimize_size=False")
        d, L = fixed_diameter, fixed_length
        params_5 = pos

    z, y, x, az, alt = params_5
    cyl_mask = generate_cylinder_n_torch(
        d, L, z, y, x, az, alt, image_shape, spacing, device=device, grid=grid
    )
    overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
    return overlap, d, L