691 lines
No EOL
31 KiB
Python
691 lines
No EOL
31 KiB
Python
import time
|
||
from datetime import datetime
|
||
import SimpleITK as sitk
|
||
import torch
|
||
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
|
||
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
|
||
from core.objective import objective_function
|
||
from pyswarm import pso
|
||
import core.objective # <--- 加入這行,讓我們可以直接操作 objective 模組
|
||
from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid
|
||
from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok
|
||
from config.constant import OVERLAP_THRESH
|
||
from visualization.res_plot_3d import res_plt_2_torch
|
||
|
||
def run_pso_torch(
    label_str: str,
    image1_path: str,
    image2_path: str,
    image3_path: str,
    folder: str,
    swarm_size: int,
    max_iter: int,
    spacing: list,
    CBT: bool,
    device: torch.device,
    optimize_size: bool = True,
    grid=None,
    fixed_diameter: float = 4.5,
    fixed_length: float = 45,
):
    """Optimize left and right cylinder placements with Particle Swarm Optimization.

    Loads three images, injects them (and the run settings) into
    ``core.objective`` so ``objective_function`` can use them, derives search
    bounds from the image shape and the azimuth/tilt measured on
    ``image2_path``, runs ``pyswarm.pso`` once per side, and renders the
    result with ``res_plt_2_torch``.

    If ``optimize_size`` is True, diameter and length are optimized as the
    6th/7th search parameters and snapped to allowed values afterwards.
    If False, they are held at ``fixed_diameter`` / ``fixed_length``
    (backward-compatible fixed-size mode).

    Args:
        label_str: Label forwarded to ``res_plt_2_torch``.
        image1_path: Image loaded into ``cortical_tensor``.
        image2_path: Image loaded into ``spine_tensor``; also used for the
            azimuth/tilt analysis and passed to the plot.
        image3_path: Image loaded into ``spine_roi_tensor``.
        folder: Not used by this function; kept for interface compatibility.
        swarm_size: PSO swarm size.
        max_iter: PSO maximum number of iterations.
        spacing: Voxel spacing forwarded to the cylinder generator and plot.
        CBT: If true, use the CBT azimuth/altitude windows; otherwise the
            alternative windows.
        device: Torch device the image tensors are moved to.
        optimize_size: Whether diameter/length are part of the search.
        grid: Optional coordinate grid forwarded to the cylinder generator.
        fixed_diameter: Diameter (mm) used when ``optimize_size`` is False.
        fixed_length: Length (mm) used when ``optimize_size`` is False.

    Returns:
        ``(best_position_l, best_loss_l, best_position_r, best_loss_r, total_time)``.
        With ``optimize_size`` the positions are ``5 params + [diameter, length]``.
    """
    start_time = time.time()

    # These names are also published as module-level globals of this module.
    global image1_array, image2_array, image2_shape, image3_array
    global diameter, length  # only assigned in fixed-size mode
    global spine_tensor, cortical_tensor, spine_roi_tensor

    # --- Load the images and move them to `device` as uint8 tensors ---------
    image1_array = sitk.GetArrayFromImage(sitk.ReadImage(image1_path))
    image2_array = sitk.GetArrayFromImage(sitk.ReadImage(image2_path))
    image3_array = sitk.GetArrayFromImage(sitk.ReadImage(image3_path))
    image2_shape = image2_array.shape
    image_shape = image2_shape

    cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
    spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
    spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)

    # The fixed size must be assigned BEFORE it is injected into core.objective;
    # assigning it afterwards raises NameError when no module-level value exists.
    if not optimize_size:
        diameter, length = fixed_diameter, fixed_length

    # --- Share state with core.objective ------------------------------------
    # The optimizer calls objective_function with only a position vector, so
    # everything else it may need is injected into core.objective's namespace.
    import core.objective

    shared = {
        "cortical_tensor": cortical_tensor,
        "spine_tensor": spine_tensor,
        "spine_roi_tensor": spine_roi_tensor,
        "image1_array": image1_array,
        "image2_array": image2_array,
        "image3_array": image3_array,
        "image2_shape": image2_shape,  # left unset, this caused a NoneType error
        "image_shape": image_shape,
        "shape": image_shape,
        "spacing": spacing,
        "device": device,
        "grid": grid,
    }
    if not optimize_size:
        shared["diameter"] = diameter
        shared["length"] = length
    for name, value in shared.items():
        setattr(core.objective, name, value)

    # --- Search bounds --------------------------------------------------------
    azi = azimuth_rotation(image2_path)
    tilt_info = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
    alt = tilt_info['superior']['tilt_angle_deg']

    n_z, n_y, n_x = image_shape[0], image_shape[1], image_shape[2]
    z_bounds = (0, n_z - 1)
    y_bounds = (n_y / 5, n_y / 2 - 1)
    # Left/right x windows leave a gap of n_x/10 on either side of the midline.
    x_bounds_left = (0, n_x / 2 - n_x / 10 - 1)
    x_bounds_right = (n_x / 2 + n_x / 10, n_x - 1)
    if CBT:
        azimuth_bounds_l = (95 - azi, 145 - azi)
        azimuth_bounds_r = (50 - azi, 85 - azi)
        altitude_bounds = (60 - alt, 75 - alt)
    else:
        azimuth_bounds_l = (60 - azi, 90 - azi)
        azimuth_bounds_r = (90 - azi, 120 - azi)
        altitude_bounds = (65 - alt, 80 - alt)

    bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds]
    bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds]
    if optimize_size:
        # Mode 1: position, angles, diameter and length are all optimized.
        print("=== 最佳化模式:最佳化位置、角度、直徑和長度 ===")
        # Continuous search ranges; results are snapped to allowed values later.
        size_bounds = [
            (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)),
            (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS)),
        ]
        bounds_l += size_bounds
        bounds_r += size_bounds
    else:
        # Mode 2: fixed diameter/length, only position and angles are optimized.
        print("=== 固定尺寸模式:最佳化位置和角度 ===")

    lb_l, ub_l = [b[0] for b in bounds_l], [b[1] for b in bounds_l]
    lb_r, ub_r = [b[0] for b in bounds_r], [b[1] for b in bounds_r]

    def eval_overlap(pos):
        """Rebuild the cylinder for `pos` and return (overlap, diameter, length)."""
        if optimize_size:
            d, L = snap_to_discrete_values(pos[5], pos[6])
        else:
            d, L = diameter, length
        cyl_mask = generate_cylinder_n_torch(
            d, L, pos[0], pos[1], pos[2], pos[3], pos[4],
            image_shape, spacing, device, grid,
        )
        return compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor), d, L

    def optimize_side(tag, lb, ub):
        """Run PSO once for one side, report it, and return (best_position, loss)."""
        position, loss = pso(objective_function, lb, ub, swarmsize=swarm_size, maxiter=max_iter)
        overlap, d, L = eval_overlap(position)
        print(f"[{tag}] overlap: {overlap*100:.1f}%")
        if optimize_size:
            print(f"[{tag}] Position: {position[:5]}")
            print(f"[{tag}] Diameter: {d} mm (raw: {position[5]:.2f})")
            print(f"[{tag}] Length: {L} mm (raw: {position[6]:.2f})")
            best = list(position[:5]) + [d, L]
        else:
            print(f"[{tag}] Position: {position}")
            best = position
        print(f"[{tag}] Loss: {loss}\n")
        return best, loss

    # NOTE: no retry loop; each side is optimized exactly once.
    print("\n=== 左側 ===")
    best_position_l, best_loss_l = optimize_side("LEFT", lb_l, ub_l)

    print("\n=== 右側 ===")
    best_position_r, best_loss_r = optimize_side("RIGHT", lb_r, ub_r)

    total_time = time.time() - start_time

    # --- Final sizes ---------------------------------------------------------
    if optimize_size:
        final_diameter_l, final_length_l = best_position_l[5], best_position_l[6]
        final_diameter_r, final_length_r = best_position_r[5], best_position_r[6]
        print("\n=== 最終結果 ===")
        print(f"Left - Diameter: {final_diameter_l} mm, Length: {final_length_l} mm")
        print(f"Right - Diameter: {final_diameter_r} mm, Length: {final_length_r} mm")
    else:
        final_diameter_l = final_diameter_r = diameter
        final_length_l = final_length_r = length

    res_plt_2_torch(
        spine_tensor,
        cortical_tensor,
        image_shape,
        image2_path,
        'Output',
        label_str,
        final_diameter_l,
        final_length_l,
        final_diameter_r,
        final_length_r,
        best_position_l,
        best_position_r,
        swarm_size,
        max_iter,
        total_time,
        spacing,
        CBT,
        device,
        grid,
    )

    return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time
|
||
|
||
import time
|
||
import numpy as np
|
||
import SimpleITK as sitk
|
||
import torch
|
||
from scipy.optimize import differential_evolution
|
||
from scipy.optimize import minimize
|
||
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
|
||
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
|
||
from core.objective import objective_function
|
||
from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid
|
||
from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok
|
||
from config.constant import OVERLAP_THRESH
|
||
from visualization.res_plot_3d import res_plt_2_torch
|
||
|
||
def run_de_torch(
    label_str: str,
    image1_path: str,
    image2_path: str,
    image3_path: str,
    folder: str,
    swarm_size: int,
    max_iter: int,
    spacing: list,
    CBT: bool,
    device: torch.device,
    optimize_size: bool = True,
    grid=None,
    fixed_diameter: float = 4.5,
    fixed_length: float = 45,
):
    """Optimize left and right cylinder placements with Differential Evolution.

    Same pipeline as the PSO variant, but each side is optimized with
    ``scipy.optimize.differential_evolution``. The result is rendered with
    ``res_plt_2_torch``.

    Args:
        label_str: Label forwarded to ``res_plt_2_torch``.
        image1_path: Image loaded into ``cortical_tensor``.
        image2_path: Image loaded into ``spine_tensor``; also used for the
            azimuth/tilt analysis and passed to the plot.
        image3_path: Image loaded into ``spine_roi_tensor``.
        folder: Not used by this function; kept for interface compatibility.
        swarm_size: Target total population. It is divided by the number of
            parameters because DE's population is ``popsize * n_params``.
        max_iter: DE maximum number of generations.
        spacing: Voxel spacing forwarded to the cylinder generator and plot.
        CBT: If true, use the CBT azimuth/altitude windows; otherwise the
            alternative windows.
        device: Torch device the image tensors are moved to.
        optimize_size: If True, diameter/length are the 6th/7th search
            parameters (snapped afterwards); if False they are fixed.
        grid: Optional coordinate grid forwarded to the cylinder generator.
        fixed_diameter: Diameter (mm) used when ``optimize_size`` is False.
        fixed_length: Length (mm) used when ``optimize_size`` is False.

    Returns:
        ``(best_position_l, best_loss_l, best_position_r, best_loss_r, total_time)``.
    """
    start_time = time.time()

    # These names are also published as module-level globals of this module.
    global image1_array, image2_array, image2_shape, image3_array
    global diameter, length  # only assigned in fixed-size mode
    global spine_tensor, cortical_tensor, spine_roi_tensor

    # --- Load the images and move them to `device` as uint8 tensors ---------
    image1_array = sitk.GetArrayFromImage(sitk.ReadImage(image1_path))
    image2_array = sitk.GetArrayFromImage(sitk.ReadImage(image2_path))
    image3_array = sitk.GetArrayFromImage(sitk.ReadImage(image3_path))
    image2_shape = image2_array.shape
    image_shape = image2_shape

    cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
    spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
    spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)

    # The fixed size must be assigned BEFORE it is injected into core.objective;
    # assigning it afterwards raises NameError when no module-level value exists.
    if not optimize_size:
        diameter, length = fixed_diameter, fixed_length

    # --- Share state with core.objective ------------------------------------
    # The optimizer calls objective_function with only a position vector, so
    # everything else it may need is injected into core.objective's namespace.
    import core.objective

    shared = {
        "cortical_tensor": cortical_tensor,
        "spine_tensor": spine_tensor,
        "spine_roi_tensor": spine_roi_tensor,
        "image1_array": image1_array,
        "image2_array": image2_array,
        "image3_array": image3_array,
        "image2_shape": image2_shape,  # left unset, this caused a NoneType error
        "image_shape": image_shape,
        "shape": image_shape,
        "spacing": spacing,
        "device": device,
        "grid": grid,
    }
    if not optimize_size:
        shared["diameter"] = diameter
        shared["length"] = length
    for name, value in shared.items():
        setattr(core.objective, name, value)

    # --- Search bounds --------------------------------------------------------
    azi = azimuth_rotation(image2_path)
    tilt_info = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
    alt = tilt_info['superior']['tilt_angle_deg']

    n_z, n_y, n_x = image_shape[0], image_shape[1], image_shape[2]
    z_bounds = (0, n_z - 1)
    y_bounds = (n_y / 5, n_y / 2 - 1)
    # Left/right x windows leave a gap of n_x/10 on either side of the midline.
    x_bounds_left = (0, n_x / 2 - n_x / 10 - 1)
    x_bounds_right = (n_x / 2 + n_x / 10, n_x - 1)
    if CBT:
        azimuth_bounds_l = (95 - azi, 145 - azi)
        azimuth_bounds_r = (50 - azi, 85 - azi)
        altitude_bounds = (60 - alt, 75 - alt)
    else:
        azimuth_bounds_l = (60 - azi, 90 - azi)
        azimuth_bounds_r = (90 - azi, 120 - azi)
        altitude_bounds = (65 - alt, 80 - alt)

    bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds]
    bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds]
    if optimize_size:
        print("=== DE 最佳化模式:最佳化位置、角度、直徑和長度 ===")
        size_bounds = [
            (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)),
            (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS)),
        ]
        bounds_l += size_bounds
        bounds_r += size_bounds
    else:
        print("=== DE 固定尺寸模式:最佳化位置和角度 ===")

    # scipy's DE population size is popsize * len(bounds); divide so the total
    # roughly matches the PSO swarm size for a fair comparison.
    de_popsize = max(1, swarm_size // len(bounds_l))

    def eval_overlap(pos):
        """Rebuild the cylinder for `pos` and return (overlap, diameter, length)."""
        if optimize_size:
            d, L = snap_to_discrete_values(pos[5], pos[6])
        else:
            d, L = diameter, length
        cyl_mask = generate_cylinder_n_torch(
            d, L, pos[0], pos[1], pos[2], pos[3], pos[4],
            image_shape, spacing, device, grid,
        )
        return compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor), d, L

    def optimize_side(tag, bounds):
        """Run DE once for one side and return (best_position, loss)."""
        result = differential_evolution(objective_function, bounds, popsize=de_popsize, maxiter=max_iter)
        overlap, d, L = eval_overlap(result.x)
        position = (list(result.x[:5]) + [d, L]) if optimize_size else list(result.x)
        print(f"[{tag}] overlap: {overlap*100:.1f}%, loss: {result.fun}")
        return position, result.fun

    # NOTE: no retry loop; each side is optimized exactly once.
    print("\n=== 左側 (DE) ===")
    best_position_l, best_loss_l = optimize_side("LEFT", bounds_l)

    print("\n=== 右側 (DE) ===")
    best_position_r, best_loss_r = optimize_side("RIGHT", bounds_r)

    total_time = time.time() - start_time

    if optimize_size:
        final_diameter_l, final_length_l = best_position_l[5], best_position_l[6]
        final_diameter_r, final_length_r = best_position_r[5], best_position_r[6]
    else:
        final_diameter_l = final_diameter_r = diameter
        final_length_l = final_length_r = length

    res_plt_2_torch(
        spine_tensor, cortical_tensor, image_shape, image2_path, 'Output', label_str,
        final_diameter_l, final_length_l, final_diameter_r, final_length_r,
        best_position_l, best_position_r, swarm_size, max_iter, total_time, spacing, CBT, device, grid,
    )

    return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time
|
||
|
||
def run_nm_torch(
    label_str: str,
    image1_path: str,
    image2_path: str,
    image3_path: str,
    folder: str,
    swarm_size: int,  # NM does not use swarm_size; kept for a uniform interface
    max_iter: int,
    spacing: list,
    CBT: bool,
    device: torch.device,
    optimize_size: bool = True,
    grid=None,
    fixed_diameter: float = 4.5,
    fixed_length: float = 45,
):
    """Optimize left and right cylinder placements with bounded Nelder-Mead.

    Same pipeline as the PSO/DE variants. Each side starts from a uniformly
    random point inside its bounds. It is restarted from a fresh random point
    (up to 10 retries) while the best loss is > 0 or the best overlap is below
    ``OVERLAP_THRESH``. The result is rendered with ``res_plt_2_torch``.

    Args:
        label_str: Label forwarded to ``res_plt_2_torch``.
        image1_path: Image loaded into ``cortical_tensor``.
        image2_path: Image loaded into ``spine_tensor``; also used for the
            azimuth/tilt analysis and passed to the plot.
        image3_path: Image loaded into ``spine_roi_tensor``.
        folder: Not used by this function; kept for interface compatibility.
        swarm_size: Unused by Nelder-Mead; only forwarded to the plot.
        max_iter: ``maxiter`` option for each Nelder-Mead run.
        spacing: Voxel spacing forwarded to the cylinder generator and plot.
        CBT: If true, use the CBT azimuth/altitude windows; otherwise the
            alternative windows.
        device: Torch device the image tensors are moved to.
        optimize_size: If True, diameter/length are the 6th/7th search
            parameters (snapped afterwards); if False they are fixed.
        grid: Optional coordinate grid forwarded to the cylinder generator.
        fixed_diameter: Diameter (mm) used when ``optimize_size`` is False.
        fixed_length: Length (mm) used when ``optimize_size`` is False.

    Returns:
        ``(best_position_l, best_loss_l, best_position_r, best_loss_r, total_time)``.
    """
    start_time = time.time()

    # These names are also published as module-level globals of this module.
    global image1_array, image2_array, image2_shape, image3_array
    global diameter, length  # only assigned in fixed-size mode
    global spine_tensor, cortical_tensor, spine_roi_tensor

    # --- Load the images and move them to `device` as uint8 tensors ---------
    image1_array = sitk.GetArrayFromImage(sitk.ReadImage(image1_path))
    image2_array = sitk.GetArrayFromImage(sitk.ReadImage(image2_path))
    image3_array = sitk.GetArrayFromImage(sitk.ReadImage(image3_path))
    image2_shape = image2_array.shape
    image_shape = image2_shape

    cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
    spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
    spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)

    # The fixed size must be assigned BEFORE it is injected into core.objective;
    # assigning it afterwards raises NameError when no module-level value exists.
    if not optimize_size:
        diameter, length = fixed_diameter, fixed_length

    # --- Share state with core.objective ------------------------------------
    # The optimizer calls objective_function with only a position vector, so
    # everything else it may need is injected into core.objective's namespace.
    import core.objective

    shared = {
        "cortical_tensor": cortical_tensor,
        "spine_tensor": spine_tensor,
        "spine_roi_tensor": spine_roi_tensor,
        "image1_array": image1_array,
        "image2_array": image2_array,
        "image3_array": image3_array,
        "image2_shape": image2_shape,  # left unset, this caused a NoneType error
        "image_shape": image_shape,
        "shape": image_shape,
        "spacing": spacing,
        "device": device,
        "grid": grid,
    }
    if not optimize_size:
        shared["diameter"] = diameter
        shared["length"] = length
    for name, value in shared.items():
        setattr(core.objective, name, value)

    # --- Search bounds --------------------------------------------------------
    azi = azimuth_rotation(image2_path)
    tilt_info = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
    alt = tilt_info['superior']['tilt_angle_deg']

    n_z, n_y, n_x = image_shape[0], image_shape[1], image_shape[2]
    z_bounds = (0, n_z - 1)
    y_bounds = (n_y / 5, n_y / 2 - 1)
    # Left/right x windows leave a gap of n_x/10 on either side of the midline.
    x_bounds_left = (0, n_x / 2 - n_x / 10 - 1)
    x_bounds_right = (n_x / 2 + n_x / 10, n_x - 1)
    if CBT:
        azimuth_bounds_l = (95 - azi, 145 - azi)
        azimuth_bounds_r = (50 - azi, 85 - azi)
        altitude_bounds = (60 - alt, 75 - alt)
    else:
        azimuth_bounds_l = (60 - azi, 90 - azi)
        azimuth_bounds_r = (90 - azi, 120 - azi)
        altitude_bounds = (65 - alt, 80 - alt)

    bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds]
    bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds]
    if optimize_size:
        print("=== NM 最佳化模式 ===")
        size_bounds = [
            (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)),
            (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS)),
        ]
        bounds_l += size_bounds
        bounds_r += size_bounds
    else:
        print("=== NM 固定尺寸模式 ===")

    max_retries = 10

    def eval_overlap(pos):
        """Rebuild the cylinder for `pos` and return (overlap, diameter, length)."""
        if optimize_size:
            d, L = snap_to_discrete_values(pos[5], pos[6])
        else:
            d, L = diameter, length
        cyl_mask = generate_cylinder_n_torch(
            d, L, pos[0], pos[1], pos[2], pos[3], pos[4],
            image_shape, spacing, device, grid,
        )
        return compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor), d, L

    def get_random_x0(bounds):
        """Draw a uniformly random start point inside `bounds`."""
        return [np.random.uniform(lo, hi) for lo, hi in bounds]

    def run_once(bounds):
        """One Nelder-Mead run from a fresh random start -> (position, loss, overlap)."""
        x0 = get_random_x0(bounds)
        result = minimize(objective_function, x0, method='Nelder-Mead', bounds=bounds, options={'maxiter': max_iter})
        overlap, d, L = eval_overlap(result.x)
        position = (list(result.x[:5]) + [d, L]) if optimize_size else list(result.x)
        return position, result.fun, overlap

    def optimize_side(tag, bounds):
        """Optimize one side with random restarts and return (best_position, best_loss)."""
        best_pos, best_loss, best_overlap = run_once(bounds)
        retries = 0
        while (best_loss > 0 or best_overlap < OVERLAP_THRESH) and retries < max_retries:
            pos, loss, overlap = run_once(bounds)
            cand_ok = is_solution_ok(loss, overlap, OVERLAP_THRESH)
            best_ok = is_solution_ok(best_loss, best_overlap, OVERLAP_THRESH)
            # Prefer an acceptable solution; among equals (both ok, or both
            # not ok) keep the one with the lower loss.
            if (cand_ok and (not best_ok or loss < best_loss)) or (not best_ok and loss < best_loss):
                best_pos, best_loss, best_overlap = pos, loss, overlap
            retries += 1
        print(f"[{tag}] overlap: {best_overlap*100:.1f}%, loss: {best_loss}")
        return best_pos, best_loss

    print("\n=== 左側 (Nelder-Mead) ===")
    best_position_l, best_loss_l = optimize_side("LEFT", bounds_l)

    print("\n=== 右側 (Nelder-Mead) ===")
    best_position_r, best_loss_r = optimize_side("RIGHT", bounds_r)

    total_time = time.time() - start_time

    if optimize_size:
        final_diameter_l, final_length_l = best_position_l[5], best_position_l[6]
        final_diameter_r, final_length_r = best_position_r[5], best_position_r[6]
    else:
        final_diameter_l = final_diameter_r = diameter
        final_length_l = final_length_r = length

    res_plt_2_torch(
        spine_tensor, cortical_tensor, image_shape, image2_path, 'Output', label_str,
        final_diameter_l, final_length_l, final_diameter_r, final_length_r,
        best_position_l, best_position_r, swarm_size, max_iter, total_time, spacing, CBT, device, grid,
    )

    return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time