# File viewer metadata (not part of the program): 109 lines, no EOL, 3.8 KiB, Python
import os
|
|
import traceback
|
|
import argparse
|
|
import torch
|
|
import SimpleITK as sitk
|
|
from datetime import datetime
|
|
# Core
|
|
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch
|
|
from core.objective import set_global_context
|
|
from core.cylinder import create_coordinate_grid
|
|
|
|
# Imaging
|
|
from imaging.preprocessing import process_single_image, process_dataset
|
|
|
|
# Visualization
|
|
from visualization.res_cyl_to_CT import cyl_and_CT
|
|
from visualization.res_cyl_nifti import save_cyl
|
|
|
|
# Config
|
|
from config.device import get_device
|
|
from config.constant import *
|
|
|
|
def _select_device(parser, gpu_index):
    """Return the torch device to run on, making it the current CUDA device.

    Falls back to the CPU when CUDA is unavailable. When CUDA is available
    but ``gpu_index`` is outside ``range(torch.cuda.device_count())``, the
    problem is reported through ``parser.error`` (which exits the process)
    instead of surfacing later as a raw runtime error from torch.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')

    gpu_count = torch.cuda.device_count()
    if not 0 <= gpu_index < gpu_count:
        parser.error(
            f"--gpu {gpu_index} is out of range; {gpu_count} CUDA device(s) visible"
        )

    device = torch.device(f'cuda:{gpu_index}')
    torch.cuda.set_device(device)
    return device


def _optimize_case(image1, image2, image3, *, folder, spacing, cbt,
                   optimize_size, device):
    """Run ``run_nm_torch`` on one image triple and export the result.

    ``image2`` is read only to obtain the array shape used to build the
    coordinate grid. The best left/right positions returned by the optimizer
    are then handed to ``save_cyl`` and ``cyl_and_CT``, both writing under
    the ``'Output'`` folder. Any exception propagates to the caller.
    """
    binary_image = sitk.ReadImage(image2)
    image_shape = sitk.GetArrayFromImage(binary_image).shape
    grid = create_coordinate_grid(image_shape, device)

    # The two loss values and the elapsed time are returned but not used here.
    best_pos_l, _loss_l, best_pos_r, _loss_r, _elapsed = run_nm_torch(
        'NM',
        image1_path=image1,
        image2_path=image2,
        image3_path=image3,
        folder=folder,
        swarm_size=70,
        max_iter=100,
        spacing=spacing,
        CBT=cbt,
        device=device,  # remember to pass the device through
        optimize_size=optimize_size,
        grid=grid,
    )

    cylinder_l, cylinder_r = save_cyl(
        best_pos_l, best_pos_r, spacing, image3,
        output_base='Output', CBT=cbt,
    )
    # NOTE(review): elements 5 and 6 of each position vector are forwarded
    # as-is; their meaning is defined by cyl_and_CT - confirm there.
    cyl_and_CT(
        best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6],
        image3, cylinder_l, cylinder_r,
        base_folder='Output', CBT=cbt,
    )


def _print_summary(gpu_index, succeeded, skipped, failed):
    """Print the per-GPU counts of succeeded, skipped and failed cases."""
    print("\n" + "=" * 40)
    print(f"🏆 ==== GPU {gpu_index} 處理總結 ====")
    print("=" * 40)
    print(f"✅ Succeeded (成功): {len(succeeded)} 節段")
    print(f"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段")
    print(f"❌ Failed (執行錯誤): {len(failed)} 節段")

    if failed:
        print(f"\n⚠️ GPU {gpu_index} 失敗清單:")
        for fail_label, err_msg in failed:
            print(f" - {fail_label}: {err_msg}")


def main(argv=None):
    """Process every configured patient/level and print a summary.

    Args:
        argv: Command-line arguments to parse; ``None`` means ``sys.argv[1:]``
            (the argparse default).

    Returns:
        A ``(succeeded, skipped, failed)`` tuple of lists. ``succeeded`` holds
        case labels, ``skipped`` holds ``(label, missing_paths)`` pairs and
        ``failed`` holds ``(label, error_message)`` pairs.
    """
    # 1. Command-line arguments: the GPU index lets separate invocations
    #    each run on a different GPU.
    parser = argparse.ArgumentParser(description="DE CBT Processing")
    parser.add_argument('--gpu', type=int, default=0)
    args = parser.parse_args(argv)

    device = _select_device(parser, args.gpu)

    base_dir = "/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0{}"
    patients = [121, 141, 151, 161, 171, 181, 191, 201, 211, 221]

    cbt = True
    optimize_size = True
    spacing = [0.5, 0.5, 0.5]
    levels = ["L5"]
    date = datetime.now().strftime("%Y%m%d")

    failed = []
    skipped = []
    succeeded = []

    for p_id in patients:
        current_dir = base_dir.format(p_id)

        for level in levels:
            # Include the patient id so log lines and the failure list
            # identify the case, not just the (shared) vertebral level.
            label_str = f"{p_id}-{level}"

            image1 = os.path.join(current_dir, f"{level}_cortical.nii.gz")
            image2 = os.path.join(current_dir, f"{level}_binary2.nii.gz")
            image3 = os.path.join(current_dir, f"{level}_roi2.nii.gz")

            missing = [
                path for path in (image1, image2, image3)
                if not os.path.exists(path)
            ]
            if missing:
                print(f"[SKIP] ⏭️ {label_str} (缺檔)")
                skipped.append((label_str, missing))
                continue

            print("\n=========================================")
            print(f"🔥 [GPU {args.gpu}] === Running {label_str} ===")
            print("=========================================")

            # Broad catch on purpose: one failing case must not abort the
            # batch; the error is printed with its traceback and recorded.
            try:
                print(f"👉 [GPU {args.gpu}] 執行 CBT 軌跡最佳化...")
                _optimize_case(
                    image1, image2, image3,
                    folder=date,
                    spacing=spacing,
                    cbt=cbt,
                    optimize_size=optimize_size,
                    device=device,
                )
            except Exception as e:
                print(f"[FAIL] ❌ [GPU {args.gpu}] {label_str} 發生錯誤: {e}")
                traceback.print_exc()
                failed.append((label_str, str(e)))
            else:
                # Record the success so the summary count is meaningful.
                succeeded.append(label_str)

    # Per-GPU summary
    _print_summary(args.gpu, succeeded, skipped, failed)
    return succeeded, skipped, failed


if __name__ == "__main__":
    main()