146 lines
6.2 KiB
Python
146 lines
6.2 KiB
Python
|
|
import os
|
|||
|
|
import traceback
|
|||
|
|
import argparse
|
|||
|
|
import torch
|
|||
|
|
import SimpleITK as sitk
|
|||
|
|
from datetime import datetime
|
|||
|
|
# Core
|
|||
|
|
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch
|
|||
|
|
from core.objective import set_global_context
|
|||
|
|
from core.cylinder import create_coordinate_grid
|
|||
|
|
|
|||
|
|
# Imaging
|
|||
|
|
from imaging.preprocessing import process_single_image, process_dataset
|
|||
|
|
|
|||
|
|
# Visualization
|
|||
|
|
from visualization.res_cyl_to_CT import cyl_and_CT
|
|||
|
|
from visualization.res_cyl_nifti import save_cyl
|
|||
|
|
|
|||
|
|
# Config
|
|||
|
|
from config.device import get_device
|
|||
|
|
from config.constant import *
|
|||
|
|
|
|||
|
|
def build_paths(base_dir: str, patient_id: str, level: str):
    """Return the three NIfTI input paths for one patient and one level.

    Every path lives in ``<base_dir>/<patient_id>/`` and is named
    ``<level>_<suffix>.nii.gz``. The result is the tuple
    ``(cortical, binary2, roi2)``, in that order.
    Existence of the files is not checked here.
    """
    patient_dir = os.path.join(base_dir, patient_id)
    suffixes = ("cortical", "binary2", "roi2")
    return tuple(
        os.path.join(patient_dir, f"{level}_{suffix}.nii.gz") for suffix in suffixes
    )
|
|||
|
|
|
|||
|
|
def _parse_args(argv=None):
    """Parse and validate the command-line options for one GPU worker.

    ``--gpu`` is both this worker's shard index and its CUDA device id;
    ``--total_gpus`` is the number of workers the patient list is split across.
    Invalid combinations are rejected via ``parser.error`` (exit status 2)
    instead of silently producing an empty shard or dividing by zero.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Multi-GPU CBT Batch Processing")
    parser.add_argument('--gpu', type=int, required=True, help='指定這支程式要用的 GPU ID (0, 1, 2, 3)')
    parser.add_argument('--total_gpus', type=int, default=4, help='總共開啟的 GPU 數量')
    parser.add_argument(
        '--base_dir',
        default="/home/cyrou/CBT/Seg/Resample/standardized/",
        help='Root directory containing one sub-directory per patient',
    )
    parser.add_argument(
        '--run_tt',
        action='store_true',
        help='Also run the traditional-trajectory (CBT=False) optimization',
    )
    args = parser.parse_args(argv)
    if args.total_gpus < 1:
        parser.error("--total_gpus must be >= 1")
    if not 0 <= args.gpu < args.total_gpus:
        parser.error(f"--gpu must be in the range [0, {args.total_gpus - 1}]")
    return args


def _select_device(gpu):
    """Bind this process to CUDA device ``gpu``; fall back to CPU without CUDA.

    Raises:
        SystemExit: if CUDA is available but ``gpu`` is not a visible device id.
    """
    if not torch.cuda.is_available():
        print("⚠️ 找不到 CUDA,將使用 CPU")
        return torch.device('cpu')
    n_devices = torch.cuda.device_count()
    if gpu >= n_devices:
        raise SystemExit(f"--gpu {gpu} is out of range: only {n_devices} CUDA device(s) are visible")
    device = torch.device(f'cuda:{gpu}')
    torch.cuda.set_device(device)
    return device


def _run_optimization(label_str, image1, image2, image3, *, cbt, device, spacing, date, grid):
    """Run one PSO trajectory optimization and export its cylinder outputs.

    Returns:
        ``(best_loss_l, best_loss_r, total_time)`` as reported by ``run_pso_torch``.
    """
    best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_pso_torch(
        label_str,
        image1_path=image1,
        image2_path=image2,
        image3_path=image3,
        folder=date,
        swarm_size=70,
        max_iter=100,
        spacing=spacing,
        CBT=cbt,
        device=device,
        optimize_size=True,
        grid=grid,
    )
    cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, image3, output_base='Output', CBT=cbt)
    # NOTE(review): elements 5 and 6 of the best positions are forwarded to
    # cyl_and_CT; their meaning is defined by run_pso_torch -- confirm there.
    cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6], image3, cylinder_L, cylinder_R, base_folder='Output', CBT=cbt)
    return best_loss_l, best_loss_r, total_time


def main(argv=None):
    """Process this worker's share of the patients and print a summary.

    Patient directories under ``--base_dir`` are sorted. Worker ``--gpu``
    takes every ``--total_gpus``-th one, starting at index ``--gpu``.
    For each patient and level, a case is handled as follows:
      * missing input files -> the case is recorded as skipped;
      * otherwise the CBT optimization runs, plus the TT optimization if
        ``--run_tt`` is given;
      * an exception is printed with its traceback and recorded as failed.

    Returns:
        ``(succeeded, skipped, failed)``:
          * succeeded: ``(label, best_loss_l, best_loss_r, elapsed)`` tuples;
          * skipped: ``(label, missing_paths)`` tuples;
          * failed: ``(label, error_message)`` tuples.
    """
    # === 1. Command-line arguments (GPU sharding setup) ===
    args = _parse_args(argv)

    # === 2. Bind the GPU ===
    device = _select_device(args.gpu)

    print(f"\n🚀 啟動批次任務 | 分配至 GPU: [{args.gpu}] | 總共 {args.total_gpus} 個節點協同運算")

    # === 3. Basic parameters ===
    base_dir = args.base_dir
    spacing = [0.5, 0.5, 0.5]
    levels = ["L5"]
    date = datetime.now().strftime("%Y%m%d")

    failed = []
    skipped = []
    succeeded = []

    # Discover all patient folders dynamically.
    all_patients = sorted(
        d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))
    )

    # === 4. Round-robin split: keep only the patients assigned to this GPU ===
    # Equivalent to selecting indices i with i % total_gpus == gpu.
    my_patients = all_patients[args.gpu::args.total_gpus]

    print(f"📂 總資料庫有 {len(all_patients)} 個病人。")
    print(f"🎯 本 GPU (ID: {args.gpu}) 被分配到 {len(my_patients)} 個病人,準備開始處理...")

    for patient_id in my_patients:
        for level in levels:
            label_str = f"{patient_id}_{level}"
            IMAGE1, IMAGE2, IMAGE3 = build_paths(base_dir, patient_id, level)

            missing = [p for p in (IMAGE1, IMAGE2, IMAGE3) if not os.path.exists(p)]
            if missing:
                print(f"[SKIP] ⏭️ {label_str} (缺檔)")
                skipped.append((label_str, missing))
                continue

            print("\n=========================================")
            print(f"🔥 [GPU {args.gpu}] === Running {label_str} ===")
            print("=========================================")

            # Per-case boundary: one bad case is logged and the batch continues.
            try:
                # --- CBT = True ---
                print(f"👉 [GPU {args.gpu}] 執行 CBT 軌跡最佳化...")
                binary_array = sitk.GetArrayFromImage(sitk.ReadImage(IMAGE2))
                # The coordinate grid is built from the binary image's array shape.
                grid = create_coordinate_grid(binary_array.shape, device)
                best_loss_l, best_loss_r, elapsed = _run_optimization(
                    label_str, IMAGE1, IMAGE2, IMAGE3,
                    cbt=True, device=device, spacing=spacing, date=date, grid=grid,
                )

                # --- CBT = False (opt-in) ---
                if args.run_tt:
                    print(f"👉 [GPU {args.gpu}] 執行傳統軌跡 (TT) 最佳化...")
                    # NOTE(review): the TT draft passed the binary image (IMAGE2)
                    # as image1_path rather than the cortical image; kept as-is.
                    _, _, tt_time = _run_optimization(
                        label_str, IMAGE2, IMAGE2, IMAGE3,
                        cbt=False, device=device, spacing=spacing, date=date, grid=grid,
                    )
                    elapsed = elapsed + tt_time

                print(f"[OK] ✅ [GPU {args.gpu}] {label_str} 完成!總耗時={elapsed:.2f}s")
                succeeded.append((label_str, best_loss_l, best_loss_r, elapsed))

            except Exception as e:
                print(f"[FAIL] ❌ [GPU {args.gpu}] {label_str} 發生錯誤: {e}")
                traceback.print_exc()
                failed.append((label_str, str(e)))

    # --- Per-GPU summary ---
    print("\n" + "="*40)
    print(f"🏆 ==== GPU {args.gpu} 處理總結 ====")
    print("="*40)
    print(f"✅ Succeeded (成功): {len(succeeded)} 節段")
    print(f"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段")
    print(f"❌ Failed (執行錯誤): {len(failed)} 節段")

    if failed:
        print(f"\n⚠️ GPU {args.gpu} 失敗清單:")
        for fail_label, err_msg in failed:
            print(f" - {fail_label}: {err_msg}")

    return succeeded, skipped, failed


if __name__ == "__main__":
    main()
|