157 lines
6.8 KiB
Python
157 lines
6.8 KiB
Python
|
|
import os
|
||
|
|
import time
|
||
|
|
import pandas as pd
|
||
|
|
import torch
|
||
|
|
import argparse
|
||
|
|
import SimpleITK as sitk
|
||
|
|
import traceback
|
||
|
|
import core.objective
|
||
|
|
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch
|
||
|
|
from core.cylinder import generate_cylinder_n_torch, create_coordinate_grid
|
||
|
|
|
||
|
|
def _load_eval_volumes(cortical_path, binary_path, device):
    """Read the cortical and binary masks once and move them to ``device``.

    Both volumes are converted to ``torch.uint8`` tensors so the overlap
    scoring can compare them voxel by voxel against the cylinder masks.

    Returns:
        (cortical_eval, spine_eval, image_shape), where ``image_shape`` is the
        array shape of the binary volume (used for cylinder generation).

    Raises:
        ValueError: if the two volumes do not have the same array shape, since
            the voxel-wise overlap would otherwise fail or silently broadcast.
    """
    img_c = sitk.GetArrayFromImage(sitk.ReadImage(cortical_path))
    img_b = sitk.GetArrayFromImage(sitk.ReadImage(binary_path))
    if img_c.shape != img_b.shape:
        raise ValueError(
            f"Cortical volume shape {img_c.shape} does not match "
            f"binary volume shape {img_b.shape}"
        )
    cortical_eval = torch.from_numpy(img_c).to(device=device, dtype=torch.uint8)
    spine_eval = torch.from_numpy(img_b).to(device=device, dtype=torch.uint8)
    return cortical_eval, spine_eval, img_b.shape


def _score_placement(best_pos, optimize_size, image_shape, spacing, device, grid,
                     cortical_eval, spine_eval):
    """Re-render one optimized cylinder and score its overlap with the masks.

    ``best_pos[0:5]`` are passed positionally to ``generate_cylinder_n_torch``
    after the diameter and length. When ``optimize_size`` is true the diameter
    and length are taken from ``best_pos[5]`` and ``best_pos[6]``; otherwise the
    fixed values (4.5, 45) are used.

    Returns:
        (cortical_score, bone_score, c_b_ratio): the percentage of cylinder
        voxels overlapping the cortical mask and the binary mask, and the ratio
        cortical_score / bone_score. Every value is 0 when its denominator is 0.
    """
    diameter, length = (best_pos[5], best_pos[6]) if optimize_size else (4.5, 45)
    cyl = generate_cylinder_n_torch(
        diameter, length, best_pos[0], best_pos[1], best_pos[2], best_pos[3], best_pos[4],
        image_shape, spacing, device, grid=grid,
    )

    # NOTE(review): the voxel count uses torch.sum, which assumes the cylinder
    # is a 0/1 mask -- confirm against generate_cylinder_n_torch.
    cyl_points = torch.sum(cyl).item()
    in_cylinder = cyl == 1
    overlap_c = ((cortical_eval == 1) & in_cylinder).sum().item()
    overlap_b = ((spine_eval == 1) & in_cylinder).sum().item()

    score_c = (overlap_c / cyl_points * 100) if cyl_points > 0 else 0
    score_b = (overlap_b / cyl_points * 100) if cyl_points > 0 else 0
    ratio = (score_c / score_b) if score_b > 0 else 0
    return score_c, score_b, ratio


def main():
    """Benchmark one optimizer (PSO, DE or NM) on the L5 CBT placement task.

    For each patient whose cortical / binary / ROI volumes exist, the selected
    optimizer is run ``--runs`` times. Each run's best left and right
    placements are re-rendered as cylinders and scored against the cortical
    and binary masks, independently of the optimizer's own loss. One row per
    successful run goes to ``optimization_benchmark_results_<METHOD>.csv``.
    That file is rewritten after every patient, so completed work survives a
    later crash or interrupt.
    """
    parser = argparse.ArgumentParser(description="CBT Optimization Benchmark")
    parser.add_argument('--method', type=str, required=True, choices=['PSO', 'DE', 'NM'])
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--runs', type=int, default=3,
                        help="Independent optimizer runs per patient (default: 3).")
    parser.add_argument('--patients', type=int, nargs='+',
                        default=[121, 141, 151, 161, 171, 181, 191, 201, 211, 221],
                        help="Patient numbers substituted into --base-dir.")
    parser.add_argument('--base-dir',
                        default="/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0{}",
                        help="Patient folder template; '{}' is replaced by the patient number.")
    args = parser.parse_args()
    if args.runs < 1:
        parser.error("--runs must be >= 1")

    if torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')

    print(f"\n🚀 啟動任務: 演算法 [{args.method}] | 運行於 [{device}]")

    CBT = True
    optimize_size = True
    spacing = [0.5, 0.5, 0.5]

    # method name -> (optimizer function, swarm/population size, iteration budget)
    optimizers = {
        'PSO': (run_pso_torch, 70, 100),
        'DE': (run_de_torch, 70, 100),
        'NM': (run_nm_torch, 70, 7000),
    }
    method_func, swarm, iters = optimizers[args.method]

    csv_filename = f"optimization_benchmark_results_{args.method}.csv"
    results_data = []

    for p_id in args.patients:
        current_dir = args.base_dir.format(p_id)

        cortical_path = os.path.join(current_dir, "L5_cortical.nii.gz")
        binary_path = os.path.join(current_dir, "L5_binary2.nii.gz")
        roi_path = os.path.join(current_dir, "L5_roi2.nii.gz")

        if not all(os.path.exists(p) for p in (cortical_path, binary_path, roi_path)):
            print(f"⚠️ 找不到病人 000{p_id} 的影像檔案,跳過此病人。")
            continue

        print("\n=========================================")
        print(f"📍 開始處理病人 ID: 000{p_id} (L5) - {args.method}")
        print("=========================================")

        # The evaluation volumes and coordinate grid do not change between runs,
        # so they are built once per patient. A failure here skips the patient
        # instead of aborting the whole benchmark.
        try:
            cortical_eval, spine_eval, image_shape = _load_eval_volumes(
                cortical_path, binary_path, device
            )
            grid = create_coordinate_grid(image_shape, device)
        except Exception as e:
            print(f"❌ 病人 000{p_id} 影像載入失敗,跳過此病人: {e}")
            traceback.print_exc()
            continue

        for run_idx in range(1, args.runs + 1):
            label_str = f"Patient_000{p_id}_{args.method}_Run{run_idx}"
            print(f"\n---> 執行 {args.method} (第 {run_idx}/{args.runs} 次)")

            try:
                # 1. Run the optimizer: best left/right parameter vectors, their
                #    losses, and the total run time.
                best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = method_func(
                    label_str=label_str,
                    image1_path=cortical_path,
                    image2_path=binary_path,
                    image3_path=roi_path,
                    folder="Output",
                    swarm_size=swarm,
                    max_iter=iters,
                    spacing=spacing,
                    CBT=CBT,
                    device=device,
                    optimize_size=optimize_size,
                    grid=grid,
                )

                # 2. Score both sides outside the optimizer, from the returned positions.
                score_c_l, score_b_l, cb_ratio_l = _score_placement(
                    best_pos_l, optimize_size, image_shape, spacing, device, grid,
                    cortical_eval, spine_eval,
                )
                score_c_r, score_b_r, cb_ratio_r = _score_placement(
                    best_pos_r, optimize_size, image_shape, spacing, device, grid,
                    cortical_eval, spine_eval,
                )

                # 3. Record the run.
                results_data.append({
                    "Patient_ID": f"000{p_id}",
                    "Method": args.method,
                    "Run": run_idx,
                    "Total_Time_sec": total_time,
                    "Left_Cortical_Score": score_c_l,
                    "Left_Bone_Score": score_b_l,
                    "Left_C_B_Ratio": cb_ratio_l,
                    "Right_Cortical_Score": score_c_r,
                    "Right_Bone_Score": score_b_r,
                    "Right_C_B_Ratio": cb_ratio_r,
                    "Left_Loss": best_loss_l,
                    "Right_Loss": best_loss_r,
                })

                print(f"✅ 完成 | 耗時: {total_time:.2f}s | "
                      f"左C/B: {cb_ratio_l:.3f} | 右C/B: {cb_ratio_r:.3f}")

            except Exception as e:
                # Best-effort: a failed run is reported and the benchmark continues.
                print(f"❌ {args.method} Run {run_idx} 發生錯誤: {e}")
                traceback.print_exc()

        # Checkpoint after every patient so a crash or interrupt later in the
        # benchmark does not discard results that were already computed.
        if results_data:
            pd.DataFrame(results_data).to_csv(csv_filename, index=False)

    if results_data:
        print(f"\n🎉 {args.method} 測試完成!結果已儲存至: {csv_filename}")
    else:
        print(f"\n⚠️ {args.method} 沒有任何成功的執行結果,未輸出 CSV。")


if __name__ == "__main__":
    main()
|