"""Per-GPU batch driver for CBT trajectory optimization.

For every patient folder and vertebral level it checks that the three input
NIfTI volumes exist, runs ``run_nm_torch`` ('NM' method) on them, saves the
resulting cylinders with ``save_cyl`` and overlays them on the image with
``cyl_and_CT``. One process is started per GPU (``--gpu N``) so that several
GPUs can split the patient list; each process prints its own summary.
"""
import os
import traceback
import argparse
import torch
import SimpleITK as sitk
from datetime import datetime

# Core
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch
from core.objective import set_global_context
from core.cylinder import create_coordinate_grid
# Imaging
from imaging.preprocessing import process_single_image, process_dataset
# Visualization
from visualization.res_cyl_to_CT import cyl_and_CT
from visualization.res_cyl_nifti import save_cyl
# Config
from config.device import get_device
from config.constant import *


def _build_parser():
    """Build the command-line parser (GPU split setting)."""
    parser = argparse.ArgumentParser(description="DE CBT Processing")
    parser.add_argument('--gpu', type=int, default=0)
    return parser


def _select_device(gpu_index):
    """Return ``cuda:<gpu_index>`` when CUDA is available, otherwise the CPU.

    The chosen CUDA device is also made the current CUDA device via
    ``torch.cuda.set_device``.

    Raises:
        ValueError: if CUDA is available but ``gpu_index`` is not a visible
            device index.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    n_devices = torch.cuda.device_count()
    if not 0 <= gpu_index < n_devices:
        raise ValueError(
            f"--gpu {gpu_index} is out of range: {n_devices} CUDA device(s) visible"
        )
    device = torch.device(f'cuda:{gpu_index}')
    torch.cuda.set_device(device)
    return device


def _input_paths(case_dir, level):
    """Return the (cortical, binary, roi) volume paths for one level in ``case_dir``."""
    return (
        os.path.join(case_dir, f"{level}_cortical.nii.gz"),
        os.path.join(case_dir, f"{level}_binary2.nii.gz"),
        os.path.join(case_dir, f"{level}_roi2.nii.gz"),
    )


def _run_case(image1, image2, image3, *, device, spacing, date, cbt, optimize_size):
    """Optimize one case and write its cylinder / overlay outputs under 'Output'.

    Returns:
        The tuple returned by ``run_nm_torch``:
        (best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time).
    """
    # The coordinate grid is built from the array shape of the binary volume.
    binary_image = sitk.ReadImage(image2)
    image_shape = sitk.GetArrayFromImage(binary_image).shape
    grid = create_coordinate_grid(image_shape, device)

    result = run_nm_torch(
        'NM',
        image1_path=image1,
        image2_path=image2,
        image3_path=image3,
        folder=date,
        swarm_size=70,
        max_iter=100,
        spacing=spacing,
        CBT=cbt,
        device=device,  # the optimizer must run on this process's device
        optimize_size=optimize_size,
        grid=grid,
    )
    best_pos_l, _, best_pos_r, _, _ = result

    cylinder_L, cylinder_R = save_cyl(
        best_pos_l, best_pos_r, spacing, image3, output_base='Output', CBT=cbt
    )
    # NOTE(review): elements 5 and 6 of each best position are forwarded to
    # cyl_and_CT; their meaning is defined by core.optimizer -- confirm there.
    cyl_and_CT(
        best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6],
        image3, cylinder_L, cylinder_R, base_folder='Output', CBT=cbt,
    )
    return result


def _print_summary(gpu, succeeded, skipped, failed):
    """Print this GPU's success / skip / failure counts and the detail lists."""
    print("\n" + "=" * 40)
    print(f"🏆 ==== GPU {gpu} 處理總結 ====")
    print("=" * 40)
    print(f"✅ Succeeded (成功): {len(succeeded)} 節段")
    print(f"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段")
    print(f"❌ Failed (執行錯誤): {len(failed)} 節段")

    if skipped:
        print(f"\n⏭️ GPU {gpu} skipped (missing files):")
        for skip_label, missing_paths in skipped:
            print(f"  - {skip_label}: {', '.join(missing_paths)}")

    if failed:
        print(f"\n⚠️ GPU {gpu} 失敗清單:")
        for fail_label, err_msg in failed:
            print(f"  - {fail_label}: {err_msg}")


def main():
    """Process every (patient, level) case on the GPU selected by ``--gpu``."""
    parser = _build_parser()
    args = parser.parse_args()
    try:
        device = _select_device(args.gpu)
    except ValueError as err:
        parser.error(str(err))  # prints usage + message and exits with status 2

    base_dir = "/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0{}"
    patients = [121, 141, 151, 161, 171, 181, 191, 201, 211, 221]
    cbt = True
    optimize_size = True
    spacing = [0.5, 0.5, 0.5]
    levels = ["L5"]
    date = datetime.now().strftime("%Y%m%d")

    failed = []     # (label, error message)
    skipped = []    # (label, list of missing paths)
    succeeded = []  # label

    for p_id in patients:
        current_dir = base_dir.format(p_id)
        for level in levels:
            # Include the patient id so summary entries are distinguishable.
            label_str = f"{p_id}/{level}"
            image1, image2, image3 = _input_paths(current_dir, level)

            missing = [p for p in (image1, image2, image3) if not os.path.exists(p)]
            if missing:
                print(f"[SKIP] ⏭️ {label_str} (缺檔)")
                skipped.append((label_str, missing))
                continue

            print("\n=========================================")
            print(f"🔥 [GPU {args.gpu}] === Running {label_str} ===")
            print("=========================================")

            try:
                print(f"👉 [GPU {args.gpu}] 執行 CBT 軌跡最佳化...")
                _, best_loss_l, _, best_loss_r, total_time = _run_case(
                    image1, image2, image3,
                    device=device,
                    spacing=spacing,
                    date=date,
                    cbt=cbt,
                    optimize_size=optimize_size,
                )
            except Exception as e:
                # Batch boundary: record the failure and continue with the next case.
                print(f"[FAIL] ❌ [GPU {args.gpu}] {label_str} 發生錯誤: {e}")
                traceback.print_exc()
                failed.append((label_str, str(e)))
            else:
                succeeded.append(label_str)
                print(
                    f"[OK] ✅ [GPU {args.gpu}] {label_str} "
                    f"loss L={best_loss_l} R={best_loss_r}, time={total_time}"
                )

    _print_summary(args.gpu, succeeded, skipped, failed)


if __name__ == "__main__":
    main()