# CBT_project/core/optimizer_ori.py
# Repository-viewer metadata (previously pasted here as plain text, which is
# not valid Python): timestamp 2026-04-10 13:25:27 +08:00; 604 lines, no
# trailing newline (EOL), 27 KiB, Python.
#
# Viewer warning: this file contains ambiguous Unicode characters that might
# be confused with other characters (presumably in the Chinese comments and
# messages). If that is intentional, the warning can be safely ignored.

import time
from datetime import datetime
import SimpleITK as sitk
import torch
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
from core.objective import objective_function
from pyswarm import pso
import core.objective # <--- 加入這行,讓我們可以直接操作 objective 模組
from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid
from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok
from config.constant import OVERLAP_THRESH
from visualization.res_plot_3d import res_plt_2_torch
def run_pso_torch(
    label_str: str,
    image1_path: str,
    image2_path: str,
    image3_path: str,
    folder: str,
    swarm_size: int,
    max_iter: int,
    spacing: list,
    CBT: bool,
    device: torch.device,
    optimize_size: bool = True,
    grid=None,
    *,
    fixed_diameter: float = 4.5,
    fixed_length: float = 45,
    max_retries: int = 10,
):
    """Optimize the left and right screw trajectories with PSO (pyswarm).

    Each side is optimized independently with ``pso``. If the first run is
    not acceptable (loss > 0 or overlap below ``OVERLAP_THRESH``), PSO is
    restarted up to ``max_retries`` times. An acceptable candidate (per
    ``is_solution_ok``) is always preferred; while none has been found, the
    lowest-loss candidate is kept as a fallback.

    Args:
        label_str: Label forwarded to ``res_plt_2_torch``.
        image1_path: Volume loaded into the ``cortical_tensor`` global.
        image2_path: Volume loaded into the ``spine_tensor`` global; also used
            to estimate the azimuth and superior tilt that offset the angle
            bounds.
        image3_path: Volume loaded into the ``spine_roi_tensor`` global.
        folder: Currently unused (the plot call passes the hard-coded
            ``'Output'``); kept for interface compatibility.
        swarm_size: PSO swarm size.
        max_iter: PSO iterations per run.
        spacing: Forwarded to cylinder generation and plotting.
        CBT: Selects the CBT angle bounds; otherwise the alternative set.
        device: Torch device the volumes are moved to.
        optimize_size: If true, diameter and length are optimized as well
            (7 parameters, snapped to the allowed discrete values); otherwise
            they are fixed to ``fixed_diameter`` / ``fixed_length``.
        grid: Forwarded to ``generate_cylinder_n_torch`` and ``res_plt_2_torch``.
        fixed_diameter: Diameter (mm) used when ``optimize_size`` is false.
        fixed_length: Length (mm) used when ``optimize_size`` is false.
        max_retries: Maximum number of additional PSO runs per side.

    Returns:
        ``(best_position_l, best_loss_l, best_position_r, best_loss_r, total_time)``.
        Positions are ``[z, y, x, azimuth, altitude]`` plus the snapped
        ``[diameter, length]`` in optimize-size mode, or the raw PSO position
        in fixed-size mode.
    """
    start_time = time.perf_counter()

    # Published as module globals. objective_function only receives the
    # parameter vector, so shared state presumably reaches it this way.
    # NOTE(review): objective_function is imported from core.objective, not
    # defined here -- confirm it really sees these names.
    global image1_array, image2_array, image2_shape, image3_array
    global diameter, length  # only assigned in fixed-size mode
    global spine_tensor, cortical_tensor, spine_roi_tensor

    # Load the three volumes and move them to `device` as uint8 tensors.
    image1_array = sitk.GetArrayFromImage(sitk.ReadImage(image1_path))
    image2_array = sitk.GetArrayFromImage(sitk.ReadImage(image2_path))
    image3_array = sitk.GetArrayFromImage(sitk.ReadImage(image3_path))
    image2_shape = image2_array.shape
    image_shape = image2_shape
    cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
    spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
    spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)

    # Vertebra orientation measured on image2; it offsets the angle bounds.
    azi = azimuth_rotation(image2_path)
    tilt = analyze_vertebral_tilt_contour(
        image2_path, edge_type='superior', show_plot=False, debug=False
    )
    alt = tilt['superior']['tilt_angle_deg']

    # Position bounds (array-index space) are the same for both trajectory
    # types. The left side searches the lower x range and the right side the
    # upper x range, leaving a central gap of about 20% of the width.
    z_bounds = (0, image_shape[0] - 1)
    y_bounds = (image_shape[1] / 5, image_shape[1] / 2 - 1)
    x_bounds_left = (0, image_shape[2] / 2 - image_shape[2] / 10 - 1)
    x_bounds_right = (image_shape[2] / 2 + image_shape[2] / 10, image_shape[2] - 1)
    # Angle ranges (presumably degrees; the tilt comes from 'tilt_angle_deg')
    # depend on the trajectory type.
    if CBT:
        azimuth_bounds_l = (95 - azi, 145 - azi)
        azimuth_bounds_r = (50 - azi, 85 - azi)
        altitude_bounds = (60 - alt, 75 - alt)
    else:
        azimuth_bounds_l = (60 - azi, 90 - azi)
        azimuth_bounds_r = (90 - azi, 120 - azi)
        altitude_bounds = (65 - alt, 80 - alt)

    if optimize_size:
        # Mode 1: diameter/length are searched continuously over the allowed
        # range and snapped to the allowed discrete sizes afterwards.
        print("=== 最佳化模式:最佳化位置、角度、直徑和長度 ===")
        size_lb = [min(ALLOWED_DIAMETERS), min(ALLOWED_LENGTHS)]
        size_ub = [max(ALLOWED_DIAMETERS), max(ALLOWED_LENGTHS)]
    else:
        # Mode 2 (backward compatible): fixed screw size, only pose is optimized.
        print("=== 固定尺寸模式:最佳化位置和角度 ===")
        diameter = fixed_diameter
        length = fixed_length
        size_lb, size_ub = [], []

    # Parameter vector: [z, y, x, azimuth, altitude] (+ [diameter, length]).
    lb_l = [z_bounds[0], y_bounds[0], x_bounds_left[0],
            azimuth_bounds_l[0], altitude_bounds[0]] + size_lb
    ub_l = [z_bounds[1], y_bounds[1], x_bounds_left[1],
            azimuth_bounds_l[1], altitude_bounds[1]] + size_ub
    lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0],
            azimuth_bounds_r[0], altitude_bounds[0]] + size_lb
    ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1],
            azimuth_bounds_r[1], altitude_bounds[1]] + size_ub

    def _evaluate(position):
        """Return (overlap, diameter, length, candidate) for a raw PSO position."""
        if optimize_size:
            d_mm, len_mm = snap_to_discrete_values(position[5], position[6])
            pose = position[:5]
        else:
            d_mm, len_mm = diameter, length
            pose = position
        cyl_mask = generate_cylinder_n_torch(
            d_mm, len_mm,
            pose[0], pose[1], pose[2],
            pose[3], pose[4],
            image_shape, spacing, device, grid,
        )
        overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
        candidate = (list(position[:5]) + [d_mm, len_mm]) if optimize_size else position
        return overlap, d_mm, len_mm, candidate

    def _needs_retry(loss, overlap):
        # Acceptance: loss must be <= 0 and overlap must reach OVERLAP_THRESH.
        return loss > 0 or overlap < OVERLAP_THRESH

    def _optimize_side(tag, header, lb, ub):
        """Run PSO for one side (plus retries); return (best_position, best_loss)."""
        print(header)
        position, loss = pso(objective_function, lb, ub, swarmsize=swarm_size, maxiter=max_iter)
        overlap, d_mm, len_mm, best_position = _evaluate(position)
        print(f"[{tag}] overlap: {overlap*100:.1f}%")
        if optimize_size:
            print(f"[{tag}] Position: {position[:5]}")
            print(f"[{tag}] Diameter: {d_mm} mm (raw: {position[5]:.2f})")
            print(f"[{tag}] Length: {len_mm} mm (raw: {position[6]:.2f})")
        else:
            print(f"[{tag}] Position: {position}")
        print(f"[{tag}] Loss: {loss}\n")
        best_loss, best_overlap = loss, overlap

        retry = 0
        while _needs_retry(best_loss, best_overlap) and retry < max_retries:
            retry += 1
            position, loss = pso(objective_function, lb, ub, swarmsize=swarm_size, maxiter=max_iter)
            overlap, _, _, candidate = _evaluate(position)
            prefix = f"[{tag}][retry {retry}]"
            candidate_ok = is_solution_ok(loss, overlap, OVERLAP_THRESH)
            best_ok = is_solution_ok(best_loss, best_overlap, OVERLAP_THRESH)
            if candidate_ok and (not best_ok or loss < best_loss):
                # Prefer acceptable solutions; among them, the lowest loss.
                best_position, best_loss, best_overlap = candidate, loss, overlap
                print(f"{prefix} ✅ ok | loss={loss:.4f}, overlap={overlap*100:.1f}%")
            elif not best_ok and loss < best_loss:
                # Nothing acceptable yet: keep the lowest-loss candidate as a fallback.
                best_position, best_loss, best_overlap = candidate, loss, overlap
                print(f"{prefix} ⚠️ not ok | loss improved={loss:.4f}, overlap={overlap*100:.1f}%")
            else:
                print(f"{prefix} ❌ no improve | loss={loss:.4f}, overlap={overlap*100:.1f}%")
        return best_position, best_loss

    best_position_l, best_loss_l = _optimize_side("LEFT", "\n=== 左側 ===", lb_l, ub_l)
    best_position_r, best_loss_r = _optimize_side("RIGHT", "\n=== 右側 ===", lb_r, ub_r)

    total_time = time.perf_counter() - start_time

    # Final screw sizes: the snapped values in optimize mode, else the fixed size.
    if optimize_size:
        final_diameter_l, final_length_l = best_position_l[5], best_position_l[6]
        final_diameter_r, final_length_r = best_position_r[5], best_position_r[6]
        print("\n=== 最終結果 ===")
        print(f"Left - Diameter: {final_diameter_l} mm, Length: {final_length_l} mm")
        print(f"Right - Diameter: {final_diameter_r} mm, Length: {final_length_r} mm")
    else:
        final_diameter_l = final_diameter_r = diameter
        final_length_l = final_length_r = length

    res_plt_2_torch(
        spine_tensor,
        cortical_tensor,
        image_shape,
        image2_path,
        'Output',
        label_str,
        final_diameter_l,
        final_length_l,
        final_diameter_r,
        final_length_r,
        best_position_l,
        best_position_r,
        swarm_size,
        max_iter,
        total_time,
        spacing,
        CBT,
        device,
        grid)
    return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time
import time
import numpy as np
import SimpleITK as sitk
import torch
from scipy.optimize import differential_evolution
from scipy.optimize import minimize
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
from core.objective import objective_function
from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid
from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok
from config.constant import OVERLAP_THRESH
from visualization.res_plot_3d import res_plt_2_torch
def run_de_torch(
    label_str: str,
    image1_path: str,
    image2_path: str,
    image3_path: str,
    folder: str,
    swarm_size: int,
    max_iter: int,
    spacing: list,
    CBT: bool,
    device: torch.device,
    optimize_size: bool = True,
    grid=None,
    *,
    fixed_diameter: float = 4.5,
    fixed_length: float = 45,
    max_retries: int = 10,
):
    """Optimize the left and right screw trajectories with Differential Evolution.

    Each side is optimized independently with
    ``scipy.optimize.differential_evolution``. If the first run is not
    acceptable (loss > 0 or overlap below ``OVERLAP_THRESH``), DE is rerun up
    to ``max_retries`` times. An acceptable candidate (per ``is_solution_ok``)
    is always preferred; while none has been found, the lowest-loss candidate
    is kept as a fallback.

    Args:
        label_str: Label forwarded to ``res_plt_2_torch``.
        image1_path: Volume loaded into the ``cortical_tensor`` global.
        image2_path: Volume loaded into the ``spine_tensor`` global; also used
            to estimate the azimuth and superior tilt that offset the angle
            bounds.
        image3_path: Volume loaded into the ``spine_roi_tensor`` global.
        folder: Currently unused (the plot call passes the hard-coded
            ``'Output'``); kept for interface compatibility.
        swarm_size: Approximate total population size; converted to DE's
            per-parameter ``popsize``.
        max_iter: DE ``maxiter`` per run.
        spacing: Forwarded to cylinder generation and plotting.
        CBT: Selects the CBT angle bounds; otherwise the alternative set.
        device: Torch device the volumes are moved to.
        optimize_size: If true, diameter and length are optimized as well
            (snapped to the allowed discrete values); otherwise they are fixed
            to ``fixed_diameter`` / ``fixed_length``.
        grid: Forwarded to ``generate_cylinder_n_torch`` and ``res_plt_2_torch``.
        fixed_diameter: Diameter (mm) used when ``optimize_size`` is false.
        fixed_length: Length (mm) used when ``optimize_size`` is false.
        max_retries: Maximum number of additional DE runs per side.

    Returns:
        ``(best_position_l, best_loss_l, best_position_r, best_loss_r, total_time)``
        with positions as lists ``[z, y, x, azimuth, altitude]`` (+ snapped
        ``[diameter, length]`` in optimize-size mode).
    """
    start_time = time.perf_counter()

    # Published as module globals. objective_function only receives the
    # parameter vector, so shared state presumably reaches it this way.
    # NOTE(review): objective_function is imported from core.objective --
    # confirm it really sees these names.
    global image1_array, image2_array, image2_shape, image3_array
    global diameter, length  # only assigned in fixed-size mode
    global spine_tensor, cortical_tensor, spine_roi_tensor

    # Load the three volumes and move them to `device` as uint8 tensors.
    image1_array = sitk.GetArrayFromImage(sitk.ReadImage(image1_path))
    image2_array = sitk.GetArrayFromImage(sitk.ReadImage(image2_path))
    image3_array = sitk.GetArrayFromImage(sitk.ReadImage(image3_path))
    image2_shape = image2_array.shape
    image_shape = image2_shape
    cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
    spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
    spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)

    # Vertebra orientation measured on image2; it offsets the angle bounds.
    azi = azimuth_rotation(image2_path)
    tilt = analyze_vertebral_tilt_contour(
        image2_path, edge_type='superior', show_plot=False, debug=False
    )
    alt = tilt['superior']['tilt_angle_deg']

    # Position bounds (array-index space) are shared by both trajectory types.
    z_bounds = (0, image_shape[0] - 1)
    y_bounds = (image_shape[1] / 5, image_shape[1] / 2 - 1)
    x_bounds_left = (0, image_shape[2] / 2 - image_shape[2] / 10 - 1)
    x_bounds_right = (image_shape[2] / 2 + image_shape[2] / 10, image_shape[2] - 1)
    if CBT:
        azimuth_bounds_l = (95 - azi, 145 - azi)
        azimuth_bounds_r = (50 - azi, 85 - azi)
        altitude_bounds = (60 - alt, 75 - alt)
    else:
        azimuth_bounds_l = (60 - azi, 90 - azi)
        azimuth_bounds_r = (90 - azi, 120 - azi)
        altitude_bounds = (65 - alt, 80 - alt)

    if optimize_size:
        print("=== DE 最佳化模式:最佳化位置、角度、直徑和長度 ===")
        size_bounds = [
            (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)),
            (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS)),
        ]
    else:
        print("=== DE 固定尺寸模式:最佳化位置和角度 ===")
        diameter = fixed_diameter
        length = fixed_length
        size_bounds = []

    # Parameter vector: [z, y, x, azimuth, altitude] (+ [diameter, length]).
    bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds] + size_bounds
    bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds] + size_bounds

    # SciPy's DE population is popsize * number_of_parameters, so divide to
    # keep roughly `swarm_size` candidates (comparable with PSO).
    de_popsize = max(1, swarm_size // len(bounds_l))

    def _evaluate(position):
        """Return (overlap, diameter, length, candidate) for a raw DE position."""
        if optimize_size:
            d_mm, len_mm = snap_to_discrete_values(position[5], position[6])
            pose = position[:5]
        else:
            d_mm, len_mm = diameter, length
            pose = position
        cyl_mask = generate_cylinder_n_torch(
            d_mm, len_mm, pose[0], pose[1], pose[2], pose[3], pose[4],
            image_shape, spacing, device, grid,
        )
        overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
        candidate = (list(position[:5]) + [d_mm, len_mm]) if optimize_size else list(position)
        return overlap, d_mm, len_mm, candidate

    def _run_once(bounds):
        result = differential_evolution(objective_function, bounds, popsize=de_popsize, maxiter=max_iter)
        return result.x, result.fun

    def _needs_retry(loss, overlap):
        # Acceptance: loss must be <= 0 and overlap must reach OVERLAP_THRESH.
        return loss > 0 or overlap < OVERLAP_THRESH

    def _optimize_side(tag, header, bounds):
        """Run DE for one side (plus retries); return (best_position, best_loss)."""
        print(header)
        position, loss = _run_once(bounds)
        overlap, d_mm, len_mm, best_position = _evaluate(position)
        print(f"[{tag}] overlap: {overlap*100:.1f}%")
        if optimize_size:
            print(f"[{tag}] Position: {position[:5]}")
            print(f"[{tag}] Diameter: {d_mm} mm (raw: {position[5]:.2f})")
            print(f"[{tag}] Length: {len_mm} mm (raw: {position[6]:.2f})")
        else:
            print(f"[{tag}] Position: {position}")
        print(f"[{tag}] Loss: {loss}\n")
        best_loss, best_overlap = loss, overlap

        retry = 0
        while _needs_retry(best_loss, best_overlap) and retry < max_retries:
            retry += 1
            position, loss = _run_once(bounds)
            overlap, _, _, candidate = _evaluate(position)
            prefix = f"[{tag}][retry {retry}]"
            candidate_ok = is_solution_ok(loss, overlap, OVERLAP_THRESH)
            best_ok = is_solution_ok(best_loss, best_overlap, OVERLAP_THRESH)
            if candidate_ok and (not best_ok or loss < best_loss):
                # Prefer acceptable solutions; among them, the lowest loss.
                best_position, best_loss, best_overlap = candidate, loss, overlap
                print(f"{prefix} ✅ ok | loss={loss:.4f}, overlap={overlap*100:.1f}%")
            elif not best_ok and loss < best_loss:
                # Nothing acceptable yet: keep the lowest-loss candidate as a fallback.
                best_position, best_loss, best_overlap = candidate, loss, overlap
                print(f"{prefix} ⚠️ not ok | loss improved={loss:.4f}, overlap={overlap*100:.1f}%")
            else:
                print(f"{prefix} ❌ no improve | loss={loss:.4f}, overlap={overlap*100:.1f}%")
        return best_position, best_loss

    best_position_l, best_loss_l = _optimize_side("LEFT", "\n=== 左側 (DE) ===", bounds_l)
    best_position_r, best_loss_r = _optimize_side("RIGHT", "\n=== 右側 (DE) ===", bounds_r)

    total_time = time.perf_counter() - start_time

    final_diameter_l = best_position_l[5] if optimize_size else diameter
    final_length_l = best_position_l[6] if optimize_size else length
    final_diameter_r = best_position_r[5] if optimize_size else diameter
    final_length_r = best_position_r[6] if optimize_size else length
    res_plt_2_torch(
        spine_tensor, cortical_tensor, image_shape, image2_path, 'Output', label_str,
        final_diameter_l, final_length_l, final_diameter_r, final_length_r,
        best_position_l, best_position_r, swarm_size, max_iter, total_time, spacing, CBT, device, grid
    )
    return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time
def run_nm_torch(
    label_str: str,
    image1_path: str,
    image2_path: str,
    image3_path: str,
    folder: str,
    swarm_size: int,  # Not used by Nelder-Mead; kept for a uniform interface (only forwarded to the plot).
    max_iter: int,
    spacing: list,
    CBT: bool,
    device: torch.device,
    optimize_size: bool = True,
    grid=None,
    *,
    fixed_diameter: float = 4.5,
    fixed_length: float = 45,
    max_retries: int = 10,
):
    """Optimize the left and right screw trajectories with bounded Nelder-Mead.

    Each side runs ``scipy.optimize.minimize(method='Nelder-Mead')`` from a
    start point drawn uniformly inside the bounds (numpy's global RNG). If the
    result is not acceptable (loss > 0 or overlap below ``OVERLAP_THRESH``),
    the search is restarted from a fresh random point up to ``max_retries``
    times. An acceptable candidate (per ``is_solution_ok``) is always
    preferred; while none has been found, the lowest-loss candidate is kept as
    a fallback.

    Args:
        label_str: Label forwarded to ``res_plt_2_torch``.
        image1_path: Volume loaded into the ``cortical_tensor`` global.
        image2_path: Volume loaded into the ``spine_tensor`` global; also used
            to estimate the azimuth and superior tilt that offset the angle
            bounds.
        image3_path: Volume loaded into the ``spine_roi_tensor`` global.
        folder: Currently unused (the plot call passes the hard-coded
            ``'Output'``); kept for interface compatibility.
        swarm_size: Unused by the optimizer; only forwarded to the plot.
        max_iter: Nelder-Mead ``maxiter`` option per run.
        spacing: Forwarded to cylinder generation and plotting.
        CBT: Selects the CBT angle bounds; otherwise the alternative set.
        device: Torch device the volumes are moved to.
        optimize_size: If true, diameter and length are optimized as well
            (snapped to the allowed discrete values); otherwise they are fixed
            to ``fixed_diameter`` / ``fixed_length``.
        grid: Forwarded to ``generate_cylinder_n_torch`` and ``res_plt_2_torch``.
        fixed_diameter: Diameter (mm) used when ``optimize_size`` is false.
        fixed_length: Length (mm) used when ``optimize_size`` is false.
        max_retries: Maximum number of random restarts per side.

    Returns:
        ``(best_position_l, best_loss_l, best_position_r, best_loss_r, total_time)``
        with positions as lists ``[z, y, x, azimuth, altitude]`` (+ snapped
        ``[diameter, length]`` in optimize-size mode).
    """
    start_time = time.perf_counter()

    # Published as module globals. objective_function only receives the
    # parameter vector, so shared state presumably reaches it this way.
    # NOTE(review): objective_function is imported from core.objective --
    # confirm it really sees these names.
    global image1_array, image2_array, image2_shape, image3_array
    global diameter, length  # only assigned in fixed-size mode
    global spine_tensor, cortical_tensor, spine_roi_tensor

    # Load the three volumes and move them to `device` as uint8 tensors.
    image1_array = sitk.GetArrayFromImage(sitk.ReadImage(image1_path))
    image2_array = sitk.GetArrayFromImage(sitk.ReadImage(image2_path))
    image3_array = sitk.GetArrayFromImage(sitk.ReadImage(image3_path))
    image2_shape = image2_array.shape
    image_shape = image2_shape
    cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
    spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
    spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)

    # Vertebra orientation measured on image2; it offsets the angle bounds.
    azi = azimuth_rotation(image2_path)
    tilt = analyze_vertebral_tilt_contour(
        image2_path, edge_type='superior', show_plot=False, debug=False
    )
    alt = tilt['superior']['tilt_angle_deg']

    # Position bounds (array-index space) are shared by both trajectory types.
    z_bounds = (0, image_shape[0] - 1)
    y_bounds = (image_shape[1] / 5, image_shape[1] / 2 - 1)
    x_bounds_left = (0, image_shape[2] / 2 - image_shape[2] / 10 - 1)
    x_bounds_right = (image_shape[2] / 2 + image_shape[2] / 10, image_shape[2] - 1)
    if CBT:
        azimuth_bounds_l = (95 - azi, 145 - azi)
        azimuth_bounds_r = (50 - azi, 85 - azi)
        altitude_bounds = (60 - alt, 75 - alt)
    else:
        azimuth_bounds_l = (60 - azi, 90 - azi)
        azimuth_bounds_r = (90 - azi, 120 - azi)
        altitude_bounds = (65 - alt, 80 - alt)

    if optimize_size:
        print("=== NM 最佳化模式 ===")
        size_bounds = [
            (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)),
            (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS)),
        ]
    else:
        print("=== NM 固定尺寸模式 ===")
        diameter, length = fixed_diameter, fixed_length
        size_bounds = []

    # Parameter vector: [z, y, x, azimuth, altitude] (+ [diameter, length]).
    bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds] + size_bounds
    bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds] + size_bounds

    def _evaluate(position):
        """Return (overlap, diameter, length, candidate) for a raw NM position."""
        if optimize_size:
            d_mm, len_mm = snap_to_discrete_values(position[5], position[6])
            pose = position[:5]
        else:
            d_mm, len_mm = diameter, length
            pose = position
        cyl_mask = generate_cylinder_n_torch(
            d_mm, len_mm, pose[0], pose[1], pose[2], pose[3], pose[4],
            image_shape, spacing, device, grid,
        )
        overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
        candidate = (list(position[:5]) + [d_mm, len_mm]) if optimize_size else list(position)
        return overlap, d_mm, len_mm, candidate

    def _run_once(bounds):
        # Nelder-Mead is a local method, so every run starts from a fresh
        # uniform random point inside the bounds.
        x0 = [np.random.uniform(low, high) for low, high in bounds]
        result = minimize(objective_function, x0, method='Nelder-Mead', bounds=bounds,
                          options={'maxiter': max_iter})
        return result.x, result.fun

    def _needs_retry(loss, overlap):
        # Acceptance: loss must be <= 0 and overlap must reach OVERLAP_THRESH.
        return loss > 0 or overlap < OVERLAP_THRESH

    def _optimize_side(tag, header, bounds):
        """Run NM for one side (plus restarts); return (best_position, best_loss)."""
        print(header)
        position, loss = _run_once(bounds)
        overlap, d_mm, len_mm, best_position = _evaluate(position)
        print(f"[{tag}] overlap: {overlap*100:.1f}%")
        if optimize_size:
            print(f"[{tag}] Position: {position[:5]}")
            print(f"[{tag}] Diameter: {d_mm} mm (raw: {position[5]:.2f})")
            print(f"[{tag}] Length: {len_mm} mm (raw: {position[6]:.2f})")
        else:
            print(f"[{tag}] Position: {position}")
        print(f"[{tag}] Loss: {loss}\n")
        best_loss, best_overlap = loss, overlap

        retry = 0
        while _needs_retry(best_loss, best_overlap) and retry < max_retries:
            retry += 1
            position, loss = _run_once(bounds)
            overlap, _, _, candidate = _evaluate(position)
            prefix = f"[{tag}][retry {retry}]"
            candidate_ok = is_solution_ok(loss, overlap, OVERLAP_THRESH)
            best_ok = is_solution_ok(best_loss, best_overlap, OVERLAP_THRESH)
            if candidate_ok and (not best_ok or loss < best_loss):
                # Prefer acceptable solutions; among them, the lowest loss.
                best_position, best_loss, best_overlap = candidate, loss, overlap
                print(f"{prefix} ✅ ok | loss={loss:.4f}, overlap={overlap*100:.1f}%")
            elif not best_ok and loss < best_loss:
                # Nothing acceptable yet: keep the lowest-loss candidate as a fallback.
                best_position, best_loss, best_overlap = candidate, loss, overlap
                print(f"{prefix} ⚠️ not ok | loss improved={loss:.4f}, overlap={overlap*100:.1f}%")
            else:
                print(f"{prefix} ❌ no improve | loss={loss:.4f}, overlap={overlap*100:.1f}%")
        return best_position, best_loss

    best_position_l, best_loss_l = _optimize_side("LEFT", "\n=== 左側 (Nelder-Mead) ===", bounds_l)
    best_position_r, best_loss_r = _optimize_side("RIGHT", "\n=== 右側 (Nelder-Mead) ===", bounds_r)

    total_time = time.perf_counter() - start_time

    final_diameter_l = best_position_l[5] if optimize_size else diameter
    final_length_l = best_position_l[6] if optimize_size else length
    final_diameter_r = best_position_r[5] if optimize_size else diameter
    final_length_r = best_position_r[6] if optimize_size else length
    res_plt_2_torch(
        spine_tensor, cortical_tensor, image_shape, image2_path, 'Output', label_str,
        final_diameter_l, final_length_l, final_diameter_r, final_length_r,
        best_position_l, best_position_r, swarm_size, max_iter, total_time, spacing, CBT, device, grid
    )
    return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time