CBT_project/run_all_cases.py

146 lines
6.2 KiB
Python
Raw Normal View History

2026-04-10 05:25:27 +00:00
import os
import traceback
import argparse
import torch
import SimpleITK as sitk
from datetime import datetime
# Core
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch
from core.objective import set_global_context
from core.cylinder import create_coordinate_grid
# Imaging
from imaging.preprocessing import process_single_image, process_dataset
# Visualization
from visualization.res_cyl_to_CT import cyl_and_CT
from visualization.res_cyl_nifti import save_cyl
# Config
from config.device import get_device
from config.constant import *
def build_paths(base_dir: str, patient_id: str, level: str):
    """Return the three NIfTI input paths for one patient/level.

    The files are looked up inside ``<base_dir>/<patient_id>/`` and are named
    ``<level>_cortical.nii.gz``, ``<level>_binary2.nii.gz`` and
    ``<level>_roi2.nii.gz``.

    Args:
        base_dir: Root directory that contains one sub-directory per patient.
        patient_id: Name of the patient's sub-directory.
        level: Level prefix used in the file names (e.g. ``"L5"``).

    Returns:
        A 3-tuple ``(cortical_path, binary_path, roi_path)`` of path strings.
        The paths are only constructed here; existence is not checked.
    """
    patient_folder = os.path.join(base_dir, patient_id)
    filenames = (
        f"{level}_cortical.nii.gz",
        f"{level}_binary2.nii.gz",
        f"{level}_roi2.nii.gz",
    )
    cortical_path, binary_path, roi_path = (
        os.path.join(patient_folder, fname) for fname in filenames
    )
    return cortical_path, binary_path, roi_path
if __name__ == "__main__":
    # Batch driver: one process per GPU. Each process takes the patients whose
    # index (in the sorted patient list) is congruent to --gpu modulo
    # --total_gpus, runs the PSO optimisation for every configured level and
    # prints a per-process summary at the end.

    # === 1. Command-line arguments (GPU sharding) ===
    parser = argparse.ArgumentParser(description="Multi-GPU CBT Batch Processing")
    parser.add_argument('--gpu', type=int, required=True, help='指定這支程式要用的 GPU ID (0, 1, 2, 3)')
    parser.add_argument('--total_gpus', type=int, default=4, help='總共開啟的 GPU 數量')
    args = parser.parse_args()

    # Reject shard settings that cannot work: total_gpus < 1 would raise
    # ZeroDivisionError in the modulo below, and a --gpu outside
    # [0, total_gpus) would silently match no patient at all.
    if args.total_gpus < 1:
        parser.error("--total_gpus must be a positive integer")
    if not 0 <= args.gpu < args.total_gpus:
        parser.error(
            f"--gpu must be in the range [0, {args.total_gpus - 1}] "
            f"(got {args.gpu})"
        )

    # === 2. Bind this process to its GPU (fall back to CPU without CUDA) ===
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')
        print("⚠️ 找不到 CUDA將使用 CPU")
    print(f"\n🚀 啟動批次任務 | 分配至 GPU: [{args.gpu}] | 總共 {args.total_gpus} 個節點協同運算")

    # === 3. Basic parameters ===
    base_dir = "/home/cyrou/CBT/Seg/Resample/standardized/"
    spacing = [0.5, 0.5, 0.5]
    levels = ["L5"]
    date = datetime.now().strftime("%Y%m%d")  # passed to the optimizer as `folder`
    failed = []      # (label, error message)
    skipped = []     # (label, list of missing input paths)
    succeeded = []   # (label, best_loss_l, best_loss_r, elapsed time)

    # Discover every patient directory; sorting keeps the sharding
    # deterministic across the cooperating processes.
    all_patients = sorted(
        d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
    )

    # === 4. Modulo sharding: keep only the patients that belong to this GPU ===
    my_patients = [p for i, p in enumerate(all_patients) if i % args.total_gpus == args.gpu]
    print(f"📂 總資料庫有 {len(all_patients)} 個病人。")
    print(f"🎯 本 GPU (ID: {args.gpu}) 被分配到 {len(my_patients)} 個病人,準備開始處理...")

    for patient_id in my_patients:
        for level in levels:
            label_str = f"{patient_id}_{level}"
            IMAGE1, IMAGE2, IMAGE3 = build_paths(base_dir, patient_id, level)

            # Skip the case when any of the three input files is absent.
            missing = [p for p in [IMAGE1, IMAGE2, IMAGE3] if not os.path.exists(p)]
            if missing:
                print(f"[SKIP] ⏭️ {label_str} (缺檔)")
                skipped.append((label_str, missing))
                continue

            print(f"\n=========================================")
            print(f"🔥 [GPU {args.gpu}] === Running {label_str} ===")
            print(f"=========================================")
            # Top-level boundary: one failing case must not abort the batch,
            # so any exception is reported, recorded and the loop continues.
            try:
                # --- CBT = True ---
                print(f"👉 [GPU {args.gpu}] 執行 CBT 軌跡最佳化...")
                # The second image is read here only to obtain the volume
                # shape for the coordinate grid.
                binary_image = sitk.ReadImage(IMAGE2)
                binary_array = sitk.GetArrayFromImage(binary_image)
                image_shape = binary_array.shape
                grid = create_coordinate_grid(image_shape, device)
                best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_pso_torch(
                    label_str,
                    image1_path=IMAGE1,
                    image2_path=IMAGE2,
                    image3_path=IMAGE3,
                    folder=date,
                    swarm_size=70,
                    max_iter=100,
                    spacing=spacing,
                    CBT=True,
                    device=device,
                    optimize_size=True,
                    grid=grid  # coordinate grid built above with `device`
                )
                cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base='Output', CBT=True)
                cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6], IMAGE3, cylinder_L, cylinder_R, base_folder='Output', CBT=True)

                # --- CBT = False (traditional trajectory, TT) — currently disabled ---
                # To re-enable: uncomment this section and add total_time_tt
                # to the elapsed time recorded in `succeeded` below.
                # print(f"👉 [GPU {args.gpu}] 執行傳統軌跡 (TT) 最佳化...")
                # best_pos_l_tt, best_loss_l_tt, best_pos_r_tt, best_loss_r_tt, total_time_tt = run_pso_torch(
                #     label_str,
                #     image1_path=IMAGE2,
                #     image2_path=IMAGE2,
                #     image3_path=IMAGE3,
                #     folder=date,
                #     swarm_size=70,
                #     max_iter=100,
                #     spacing=spacing,
                #     CBT=False,
                #     device=device,
                #     optimize_size=True,
                #     grid=grid
                # )
                # cylinder_L_tt, cylinder_R_tt = save_cyl(best_pos_l_tt, best_pos_r_tt, spacing, IMAGE3, output_base='Output', CBT=False)
                # cyl_and_CT(best_pos_l_tt[5], best_pos_r_tt[5], best_pos_l_tt[6], best_pos_r_tt[6], IMAGE3, cylinder_L_tt, cylinder_R_tt, base_folder='Output', CBT=False)

                # Record the success. Previously the only append lived in the
                # disabled TT section, so the summary always showed 0 successes.
                succeeded.append((label_str, best_loss_l, best_loss_r, total_time))
                print(f"[OK] ✅ [GPU {args.gpu}] {label_str} 完成!總耗時={total_time:.2f}s")
            except Exception as e:
                print(f"[FAIL] ❌ [GPU {args.gpu}] {label_str} 發生錯誤: {e}")
                traceback.print_exc()
                failed.append((label_str, str(e)))

    # --- Per-GPU summary ---
    print("\n" + "="*40)
    print(f"🏆 ==== GPU {args.gpu} 處理總結 ====")
    print("="*40)
    print(f"✅ Succeeded (成功): {len(succeeded)} 節段")
    print(f"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段")
    print(f"❌ Failed (執行錯誤): {len(failed)} 節段")
    if failed:
        print(f"\n⚠️ GPU {args.gpu} 失敗清單:")
        for fail_label, err_msg in failed:
            print(f" - {fail_label}: {err_msg}")