# CBT_project/run_all_cases.py
# 2026-04-10 13:25:27 +08:00
#
# 146 lines
# No EOL
# 6.2 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import traceback
import argparse
import torch
import SimpleITK as sitk
from datetime import datetime
# Core
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch
from core.objective import set_global_context
from core.cylinder import create_coordinate_grid
# Imaging
from imaging.preprocessing import process_single_image, process_dataset
# Visualization
from visualization.res_cyl_to_CT import cyl_and_CT
from visualization.res_cyl_nifti import save_cyl
# Config
from config.device import get_device
from config.constant import *
def build_paths(base_dir: str, patient_id: str, level: str):
    """Return the three NIfTI input paths for one patient / vertebral level.

    The files are looked up inside ``<base_dir>/<patient_id>/`` and are named
    ``<level>_cortical.nii.gz``, ``<level>_binary2.nii.gz`` and
    ``<level>_roi2.nii.gz``.

    Args:
        base_dir: Directory that contains one sub-directory per patient.
        patient_id: Name of the patient's sub-directory.
        level: Level prefix used in the file names (e.g. ``"L5"``).

    Returns:
        A 3-tuple ``(cortical_path, binary_path, roi_path)`` of strings.
        The paths are only composed here; their existence is not checked.
    """
    case_dir = os.path.join(base_dir, patient_id)
    img1 = os.path.join(case_dir, f"{level}_cortical.nii.gz")
    img2 = os.path.join(case_dir, f"{level}_binary2.nii.gz")
    img3 = os.path.join(case_dir, f"{level}_roi2.nii.gz")
    return img1, img2, img3
if __name__ == "__main__":
    # Batch driver: each process is started with its own --gpu index, takes
    # every `total_gpus`-th patient directory under `base_dir`, and runs the
    # CBT PSO optimisation followed by the cylinder / CT export for each case.

    # === 1. Command-line arguments (how the work is split across GPUs) ===
    parser = argparse.ArgumentParser(description="Multi-GPU CBT Batch Processing")
    parser.add_argument('--gpu', type=int, required=True, help='指定這支程式要用的 GPU ID (0, 1, 2, 3)')
    parser.add_argument('--total_gpus', type=int, default=4, help='總共開啟的 GPU 數量')
    args = parser.parse_args()

    # Reject shard settings that cannot work: total_gpus == 0 would raise
    # ZeroDivisionError in the modulo below, and a --gpu outside
    # [0, total_gpus) would silently be assigned no patients at all.
    if args.total_gpus < 1:
        parser.error("--total_gpus must be at least 1")
    if not 0 <= args.gpu < args.total_gpus:
        parser.error(
            f"--gpu must be in the range 0..{args.total_gpus - 1} (got {args.gpu})"
        )

    # === 2. Bind this process to its GPU (fall back to CPU without CUDA) ===
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')
        print("⚠️ 找不到 CUDA將使用 CPU")
    print(f"\n🚀 啟動批次任務 | 分配至 GPU: [{args.gpu}] | 總共 {args.total_gpus} 個節點協同運算")

    # === 3. Basic parameters ===
    base_dir = "/home/cyrou/CBT/Seg/Resample/standardized/"
    spacing = [0.5, 0.5, 0.5]
    levels = ["L5"]
    date = datetime.now().strftime("%Y%m%d")  # passed to the optimiser as `folder`

    failed = []      # (label, error message)
    skipped = []     # (label, list of missing input paths)
    succeeded = []   # (label, left loss, right loss, elapsed time)

    # Discover every patient directory; sorted so that all processes see the
    # same ordering and the modulo split below is consistent between them.
    all_patients = sorted(
        d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
    )

    # === 4. Modulo split: keep only the patients that belong to this GPU ===
    my_patients = [p for i, p in enumerate(all_patients) if i % args.total_gpus == args.gpu]
    print(f"📂 總資料庫有 {len(all_patients)} 個病人。")
    print(f"🎯 本 GPU (ID: {args.gpu}) 被分配到 {len(my_patients)} 個病人,準備開始處理...")

    for patient_id in my_patients:
        for level in levels:
            label_str = f"{patient_id}_{level}"
            IMAGE1, IMAGE2, IMAGE3 = build_paths(base_dir, patient_id, level)

            # Skip the case when any of the three input volumes is absent.
            missing = [p for p in [IMAGE1, IMAGE2, IMAGE3] if not os.path.exists(p)]
            if missing:
                print(f"[SKIP] ⏭️ {label_str} (缺檔)")
                skipped.append((label_str, missing))
                continue

            print("\n=========================================")
            print(f"🔥 [GPU {args.gpu}] === Running {label_str} ===")
            print("=========================================")

            # One failing case must not abort the whole batch, so everything
            # for this case runs under a broad handler that records the error.
            try:
                # --- CBT = True ---
                print(f"👉 [GPU {args.gpu}] 執行 CBT 軌跡最佳化...")

                # The coordinate grid is built from the array shape of the
                # second input volume (IMAGE2) on the selected device.
                binary_image = sitk.ReadImage(IMAGE2)
                binary_array = sitk.GetArrayFromImage(binary_image)
                image_shape = binary_array.shape
                grid = create_coordinate_grid(image_shape, device)

                best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_pso_torch(
                    label_str,
                    image1_path=IMAGE1,
                    image2_path=IMAGE2,
                    image3_path=IMAGE3,
                    folder=date,
                    swarm_size=70,
                    max_iter=100,
                    spacing=spacing,
                    CBT=True,
                    device=device,
                    optimize_size=True,
                    grid=grid,  # coordinate grid created above for this case
                )

                cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base='Output', CBT=True)
                # Elements 5 and 6 of each best position are forwarded as
                # separate arguments (their meaning is defined by cyl_and_CT).
                cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6], IMAGE3, cylinder_L, cylinder_R, base_folder='Output', CBT=True)

                # --- CBT = False (traditional trajectory, currently disabled) ---
                # NOTE(review): image1_path is IMAGE2 here, unlike the CBT call
                # above - confirm that is intended before re-enabling. When this
                # section is re-enabled, add total_time_tt to the elapsed time
                # recorded in the success bookkeeping below.
                # print(f"👉 [GPU {args.gpu}] 執行傳統軌跡 (TT) 最佳化...")
                # best_pos_l_tt, best_loss_l_tt, best_pos_r_tt, best_loss_r_tt, total_time_tt = run_pso_torch(
                #     label_str,
                #     image1_path=IMAGE2,
                #     image2_path=IMAGE2,
                #     image3_path=IMAGE3,
                #     folder=date,
                #     swarm_size=70,
                #     max_iter=100,
                #     spacing=spacing,
                #     CBT=False,
                #     device=device,
                #     optimize_size=True,
                #     grid=grid,
                # )
                # cylinder_L_tt, cylinder_R_tt = save_cyl(best_pos_l_tt, best_pos_r_tt, spacing, IMAGE3, output_base='Output', CBT=False)
                # cyl_and_CT(best_pos_l_tt[5], best_pos_r_tt[5], best_pos_l_tt[6], best_pos_r_tt[6], IMAGE3, cylinder_L_tt, cylinder_R_tt, base_folder='Output', CBT=False)

                # Success bookkeeping. This previously lived only inside the
                # disabled TT section, so the summary always showed 0 successes.
                succeeded.append((label_str, best_loss_l, best_loss_r, total_time))
                print(f"[OK] ✅ [GPU {args.gpu}] {label_str} 完成!總耗時={total_time:.2f}s")
            except Exception as e:
                print(f"[FAIL] ❌ [GPU {args.gpu}] {label_str} 發生錯誤: {e}")
                traceback.print_exc()
                failed.append((label_str, str(e)))

    # --- Per-GPU summary ---
    print("\n" + "="*40)
    print(f"🏆 ==== GPU {args.gpu} 處理總結 ====")
    print("="*40)
    print(f"✅ Succeeded (成功): {len(succeeded)} 節段")
    print(f"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段")
    print(f"❌ Failed (執行錯誤): {len(failed)} 節段")
    if failed:
        print(f"\n⚠️ GPU {args.gpu} 失敗清單:")
        for fail_label, err_msg in failed:
            print(f" - {fail_label}: {err_msg}")