"""Multi-GPU batch driver for CBT trajectory optimisation.

Every patient folder found under the base directory is assigned to exactly one
worker by taking its index in the sorted folder list modulo ``--total_gpus``.
Launch one process per worker, each with a distinct ``--gpu`` value.
"""
import os
import traceback
import argparse
from datetime import datetime

import torch
import SimpleITK as sitk

# Core
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch
from core.objective import set_global_context
from core.cylinder import create_coordinate_grid

# Imaging
from imaging.preprocessing import process_single_image, process_dataset

# Visualization
from visualization.res_cyl_to_CT import cyl_and_CT
from visualization.res_cyl_nifti import save_cyl

# Config
from config.device import get_device
from config.constant import *


def build_paths(base_dir: str, patient_id: str, level: str) -> tuple:
    """Return the three input image paths for one patient / level.

    Args:
        base_dir: Directory that contains one sub-folder per patient.
        patient_id: Name of the patient sub-folder.
        level: Level label used as the file-name prefix (e.g. ``"L5"``).

    Returns:
        A 3-tuple of paths inside ``base_dir/patient_id``, in this order:
        ``{level}_cortical.nii.gz``, ``{level}_binary2.nii.gz``,
        ``{level}_roi2.nii.gz``.
    """
    case_dir = os.path.join(base_dir, patient_id)
    img1 = os.path.join(case_dir, f"{level}_cortical.nii.gz")
    img2 = os.path.join(case_dir, f"{level}_binary2.nii.gz")
    img3 = os.path.join(case_dir, f"{level}_roi2.nii.gz")
    return img1, img2, img3


def _parse_args() -> argparse.Namespace:
    """Parse and validate the command-line arguments (GPU sharding settings)."""
    parser = argparse.ArgumentParser(description="Multi-GPU CBT Batch Processing")
    parser.add_argument('--gpu', type=int, required=True,
                        help='指定這支程式要用的 GPU ID (0, 1, 2, 3)')
    parser.add_argument('--total_gpus', type=int, default=4,
                        help='總共開啟的 GPU 數量')
    args = parser.parse_args()

    # The patient split is `index % total_gpus == gpu`; an out-of-range value
    # would either divide by zero or silently select no patients at all.
    if args.total_gpus < 1:
        parser.error("--total_gpus must be >= 1")
    if not 0 <= args.gpu < args.total_gpus:
        parser.error(
            f"--gpu must satisfy 0 <= gpu < total_gpus "
            f"(got gpu={args.gpu}, total_gpus={args.total_gpus})"
        )
    return args


def _select_device(gpu: int) -> torch.device:
    """Bind this process to ``cuda:<gpu>`` or fall back to CPU without CUDA."""
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{gpu}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')
        print("⚠️ 找不到 CUDA,將使用 CPU")
    return device


def _print_summary(gpu: int, succeeded: list, skipped: list, failed: list) -> None:
    """Print the per-worker summary and the list of failed cases, if any."""
    print("\n" + "="*40)
    print(f"🏆 ==== GPU {gpu} 處理總結 ====")
    print("="*40)
    print(f"✅ Succeeded (成功): {len(succeeded)} 節段")
    print(f"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段")
    print(f"❌ Failed (執行錯誤): {len(failed)} 節段")

    if failed:
        print(f"\n⚠️ GPU {gpu} 失敗清單:")
        for fail_label, err_msg in failed:
            print(f" - {fail_label}: {err_msg}")


def main() -> None:
    """Run the CBT optimisation for every patient assigned to this worker."""
    # === 1. Command-line arguments (GPU sharding settings) ===
    args = _parse_args()

    # === 2. Bind the GPU ===
    device = _select_device(args.gpu)
    print(f"\n🚀 啟動批次任務 | 分配至 GPU: [{args.gpu}] | 總共 {args.total_gpus} 個節點協同運算")

    # === 3. Basic parameters ===
    base_dir = "/home/cyrou/CBT/Seg/Resample/standardized/"
    spacing = [0.5, 0.5, 0.5]
    levels = ["L5"]
    date = datetime.now().strftime("%Y%m%d")

    failed = []
    skipped = []
    succeeded = []

    # Discover every patient folder; sorting keeps the split identical
    # across all worker processes.
    all_patients = sorted(
        d for d in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, d))
    )

    # === 4. Modulo split: keep only the patients that belong to this GPU ===
    # Equivalent to `i % total_gpus == gpu` because 0 <= gpu < total_gpus
    # has been validated above.
    my_patients = all_patients[args.gpu::args.total_gpus]

    print(f"📂 總資料庫有 {len(all_patients)} 個病人。")
    print(f"🎯 本 GPU (ID: {args.gpu}) 被分配到 {len(my_patients)} 個病人,準備開始處理...")

    for patient_id in my_patients:
        for level in levels:
            label_str = f"{patient_id}_{level}"
            IMAGE1, IMAGE2, IMAGE3 = build_paths(base_dir, patient_id, level)

            missing = [p for p in [IMAGE1, IMAGE2, IMAGE3] if not os.path.exists(p)]
            if missing:
                print(f"[SKIP] ⏭️ {label_str} (缺檔)")
                skipped.append((label_str, missing))
                continue

            print("\n=========================================")
            print(f"🔥 [GPU {args.gpu}] === Running {label_str} ===")
            print("=========================================")

            try:
                # --- CBT = True ---
                print(f"👉 [GPU {args.gpu}] 執行 CBT 軌跡最佳化...")
                # The binary image is read here only to obtain the array
                # shape used to build the coordinate grid.
                binary_image = sitk.ReadImage(IMAGE2)
                binary_array = sitk.GetArrayFromImage(binary_image)
                image_shape = binary_array.shape
                grid = create_coordinate_grid(image_shape, device)

                best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_pso_torch(
                    label_str,
                    image1_path=IMAGE1,
                    image2_path=IMAGE2,
                    image3_path=IMAGE3,
                    folder=date,
                    swarm_size=70,
                    max_iter=100,
                    spacing=spacing,
                    CBT=True,
                    device=device,  # remember to pass the device in
                    optimize_size=True,
                    grid=grid
                )

                cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3,
                                                  output_base='Output', CBT=True)
                # NOTE(review): the meaning of elements 5 and 6 of the
                # position vectors is defined by cyl_and_CT — confirm there.
                cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6],
                           IMAGE3, cylinder_L, cylinder_R, base_folder='Output', CBT=True)

                # --- CBT = False ---
                # NOTE(review): the traditional-trajectory (TT) pass below is
                # disabled. If it is re-enabled, add total_time_tt to the
                # time recorded in the success branch.
                # print(f"👉 [GPU {args.gpu}] 執行傳統軌跡 (TT) 最佳化...")
                # best_pos_l_tt, best_loss_l_tt, best_pos_r_tt, best_loss_r_tt, total_time_tt = run_pso_torch(
                #     label_str,
                #     image1_path=IMAGE2,
                #     image2_path=IMAGE2,
                #     image3_path=IMAGE3,
                #     folder=date,
                #     swarm_size=70,
                #     max_iter=100,
                #     spacing=spacing,
                #     CBT=False,
                #     device=device,  # remember to pass the device in
                #     optimize_size=True,
                #     grid=grid
                # )
                # cylinder_L_tt, cylinder_R_tt = save_cyl(best_pos_l_tt, best_pos_r_tt, spacing, IMAGE3,
                #                                         output_base='Output', CBT=False)
                # cyl_and_CT(best_pos_l_tt[5], best_pos_r_tt[5], best_pos_l_tt[6], best_pos_r_tt[6],
                #            IMAGE3, cylinder_L_tt, cylinder_R_tt, base_folder='Output', CBT=False)

            except Exception as e:
                # Top-level batch boundary: log the failure and continue
                # with the next case instead of aborting the whole run.
                print(f"[FAIL] ❌ [GPU {args.gpu}] {label_str} 發生錯誤: {e}")
                traceback.print_exc()
                failed.append((label_str, str(e)))
            else:
                # Record the success so the summary count is meaningful;
                # previously nothing was ever appended to `succeeded`.
                succeeded.append((label_str, best_loss_l, best_loss_r, total_time))
                print(f"[OK] ✅ [GPU {args.gpu}] {label_str} 完成!總耗時={total_time:.2f}s")

    # --- Per-GPU summary ---
    _print_summary(args.gpu, succeeded, skipped, failed)


if __name__ == "__main__":
    main()