# CBT_project/run_de.py
# Last modified: 2026-04-10 13:25:27 +08:00
#
# 109 lines · No EOL · 3.8 KiB · Python

import os
import traceback
import argparse
import torch
import SimpleITK as sitk
from datetime import datetime
# Core
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch
from core.objective import set_global_context
from core.cylinder import create_coordinate_grid
# Imaging
from imaging.preprocessing import process_single_image, process_dataset
# Visualization
from visualization.res_cyl_to_CT import cyl_and_CT
from visualization.res_cyl_nifti import save_cyl
# Config
from config.device import get_device
from config.constant import *
def segment_label(patient_id, level):
    """Return the label that identifies one (patient, level) job in logs/summary.

    Example: ``segment_label(121, "L5")`` -> ``"121/L5"``.
    """
    return f"{patient_id}/{level}"


def level_input_paths(patient_dir, level):
    """Return the three input NIfTI paths for one level of one patient.

    Args:
        patient_dir: Directory holding one patient's standardized volumes.
        level: Level prefix used in the file names (e.g. ``"L5"``).

    Returns:
        A ``(cortical, binary, roi)`` tuple of paths named
        ``<level>_cortical.nii.gz``, ``<level>_binary2.nii.gz`` and
        ``<level>_roi2.nii.gz``. Existence is NOT checked here.
    """
    return tuple(
        os.path.join(patient_dir, f"{level}_{kind}.nii.gz")
        for kind in ("cortical", "binary2", "roi2")
    )


def print_summary(gpu, succeeded, skipped, failed):
    """Print the end-of-run summary for one GPU worker.

    Args:
        gpu: GPU index shown in the headings.
        succeeded: Labels of jobs that finished without raising.
        skipped: ``(label, missing_paths)`` pairs for jobs with missing inputs.
        failed: ``(label, error_message)`` pairs for jobs that raised.
    """
    print("\n" + "=" * 40)
    print(f"🏆 ==== GPU {gpu} 處理總結 ====")
    print("=" * 40)
    print(f"✅ Succeeded (成功): {len(succeeded)} 節段")
    print(f"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段")
    print(f"❌ Failed (執行錯誤): {len(failed)} 節段")
    if skipped:
        print(f"\n⚠️ GPU {gpu} 缺檔清單:")
        for skip_label, missing in skipped:
            print(f" - {skip_label}: {', '.join(missing)}")
    if failed:
        print(f"\n⚠️ GPU {gpu} 失敗清單:")
        for fail_label, err_msg in failed:
            print(f" - {fail_label}: {err_msg}")


if __name__ == "__main__":
    # === 1. Command-line arguments (lets separate processes split work across GPUs) ===
    parser = argparse.ArgumentParser(description="DE CBT Processing")
    parser.add_argument('--gpu', type=int, default=0, help="CUDA device index to run on")
    args = parser.parse_args()

    if torch.cuda.is_available():
        # Fail fast with a readable message instead of a CUDA error inside set_device.
        n_devices = torch.cuda.device_count()
        if not 0 <= args.gpu < n_devices:
            parser.error(f"--gpu {args.gpu} is out of range (found {n_devices} CUDA device(s))")
        device = torch.device(f'cuda:{args.gpu}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')

    # === 2. Run configuration ===
    base_dir = "/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0{}"
    patients = [121, 141, 151, 161, 171, 181, 191, 201, 211, 221]
    CBT = True
    optimize_size = True
    spacing = [0.5, 0.5, 0.5]
    levels = ["L5"]
    # Date stamp passed to the optimizer as its `folder` argument.
    date = datetime.now().strftime("%Y%m%d")

    failed = []
    skipped = []
    succeeded = []

    # === 3. Process every (patient, level) pair ===
    for p_id in patients:
        current_dir = base_dir.format(p_id)
        for level in levels:
            # Include the patient id so skip/fail entries are distinguishable.
            label_str = segment_label(p_id, level)
            IMAGE1, IMAGE2, IMAGE3 = level_input_paths(current_dir, level)
            missing = [p for p in (IMAGE1, IMAGE2, IMAGE3) if not os.path.exists(p)]
            if missing:
                print(f"[SKIP] ⏭️ {label_str} (缺檔)")
                skipped.append((label_str, missing))
                continue

            print("\n=========================================")
            print(f"🔥 [GPU {args.gpu}] === Running {label_str} ===")
            print("=========================================")
            # Top-level boundary of a batch job: any failure is logged and the
            # loop moves on to the next patient/level.
            try:
                print(f"👉 [GPU {args.gpu}] 執行 CBT 軌跡最佳化...")
                binary_image = sitk.ReadImage(IMAGE2)
                binary_array = sitk.GetArrayFromImage(binary_image)
                image_shape = binary_array.shape
                # Coordinate grid sized to the binary volume, built on `device`.
                grid = create_coordinate_grid(image_shape, device)
                # NOTE(review): this script is named/described as DE and imports
                # run_de_torch, but calls run_nm_torch('NM', ...). Confirm which
                # optimizer is intended, and whether run_nm_torch uses swarm_size.
                best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_nm_torch(
                    'NM',
                    image1_path=IMAGE1,
                    image2_path=IMAGE2,
                    image3_path=IMAGE3,
                    folder=date,
                    swarm_size=70,
                    max_iter=100,
                    spacing=spacing,
                    CBT=CBT,
                    device=device,
                    optimize_size=optimize_size,
                    grid=grid,
                )
                cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base='Output', CBT=CBT)
                # NOTE(review): elements 5 and 6 of each best position are passed
                # separately to cyl_and_CT; their meaning is defined in core.optimizer.
                cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6], IMAGE3, cylinder_L, cylinder_R, base_folder='Output', CBT=CBT)
            except Exception as e:
                print(f"[FAIL] ❌ [GPU {args.gpu}] {label_str} 發生錯誤: {e}")
                traceback.print_exc()
                failed.append((label_str, str(e)))
            else:
                # Previously `succeeded` was never filled, so the summary always showed 0.
                succeeded.append(label_str)
                print(f"[OK] ✅ [GPU {args.gpu}] {label_str} 完成 (total_time={total_time})")

    # === 4. Per-GPU summary ===
    print_summary(args.gpu, succeeded, skipped, failed)