CBT_project/experiment.ipynb
2026-04-10 13:25:27 +08:00

561 lines
37 KiB
Text
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "52324a94",
"metadata": {},
"outputs": [],
"source": [
"# Core\n",
"from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch\n",
"from core.objective import set_global_context\n",
"from core.cylinder import create_coordinate_grid\n",
"\n",
"# Imaging\n",
"from imaging.preprocessing import process_single_image, process_dataset\n",
"\n",
"# Visualization\n",
"from visualization.res_cyl_to_CT import cyl_and_CT\n",
"from visualization.res_cyl_nifti import save_cyl\n",
"\n",
"# Config\n",
"from config.device import get_device\n",
"from config.constant import *\n",
"\n",
"import SimpleITK as sitk\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e680dfe",
"metadata": {},
"outputs": [],
"source": [
"# Preprocess a single image/label pair (SEGMENTATION: SINGLE DATA).\n",
"# All three paths are machine-specific; edit them before running.\n",
"output_dir = \"/home/cyrou/CBT/Dataset\"\n",
"image_path = \"/home/cyrou/CBT/CTSpine1K/data/colon/1.3.6.1.4.1.9328.50.4.0783.nii.gz\"\n",
"label_path = \"/home/cyrou/CBT/CTSpine1K/label/colon/1.3.6.1.4.1.9328.50.4.0783_seg.nii.gz\"\n",
"\n",
"# To restrict the run to specific labels, add e.g. labels_to_process=[23, 24]\n",
"# to this call (per the original author's note).\n",
"process_single_image(\n",
"    image_path,\n",
"    label_path,\n",
"    output_dir_base=output_dir,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1961cf4c",
"metadata": {},
"outputs": [],
"source": [
"# Preprocess every case under image_dir / label_dir (SEGMENTATION: DATASET).\n",
"# The directories are machine-specific; edit them before running.\n",
"output_dir = \"/home/cyrou/CBT/Dataset\"\n",
"label_dir = \"/home/cyrou/CBT/CTSpine1K/label/colon/\"\n",
"image_dir = \"/home/cyrou/CBT/CTSpine1K/data/colon/\"\n",
"\n",
"# labels_to_process is passed explicitly as None, exactly as before.\n",
"process_dataset(\n",
"    image_dir,\n",
"    label_dir,\n",
"    output_dir,\n",
"    labels_to_process=None,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2bed470b",
"metadata": {},
"outputs": [],
"source": [
"# Inputs for the single-case run (L5 of one patient); paths are machine-specific.\n",
"cortical_path = \"/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0121/L5_cortical.nii.gz\"\n",
"binary_path = \"/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0121/L5_binary.nii.gz\"\n",
"roi_path = \"/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0121/L5_roi2.nii.gz\"\n",
"\n",
"# Optimizer budget, shared by the DE / NM / PSO cells below.\n",
"swarm_size, max_iter = 70, 100\n",
"\n",
"# Compute device; 3 is the index handed to get_device.\n",
"device = get_device(3)\n",
"\n",
"# Spacing list handed to the optimizers (presumably voxel size in mm - confirm)\n",
"# and the CBT trajectory switch.\n",
"spacing = [0.5] * 3\n",
"CBT = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc17e584",
"metadata": {},
"outputs": [],
"source": [
"import SimpleITK as sitk\n",
"import torch\n",
"\n",
"# Read the three volumes, then take the array view of each.\n",
"cortical_image, binary_image, roi_image = (\n",
"    sitk.ReadImage(path) for path in (cortical_path, binary_path, roi_path)\n",
")\n",
"cortical_array, binary_array, roi_array = (\n",
"    sitk.GetArrayFromImage(image)\n",
"    for image in (cortical_image, binary_image, roi_image)\n",
")\n",
"\n",
"# The binary volume defines the working shape for the grid and the objective.\n",
"image_shape = binary_array.shape\n",
"\n",
"# Only the cortical and binary arrays are moved to the compute device.\n",
"cortical_tensor, binary_tensor = (\n",
"    torch.tensor(array, device=device)\n",
"    for array in (cortical_array, binary_array)\n",
")\n",
"\n",
"grid = create_coordinate_grid(image_shape, device)\n",
"\n",
"# Publish everything the objective function reads through its global context.\n",
"optimizer_context = dict(\n",
"    cortical=cortical_tensor,\n",
"    spine=binary_tensor,\n",
"    shape=image_shape,\n",
"    spacing_=spacing,\n",
"    device_=device,\n",
"    grid_=grid,\n",
"    use_tip_penalty=False,\n",
")\n",
"set_global_context(**optimizer_context)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2c3a3e47",
"metadata": {},
"outputs": [],
"source": [
"# Run run_de_torch on the single case configured above.\n",
"de_settings = dict(\n",
"    label_str=\"L5\",\n",
"    image1_path=cortical_path,\n",
"    image2_path=binary_path,\n",
"    image3_path=roi_path,\n",
"    folder=\"Output\",\n",
"    swarm_size=swarm_size,\n",
"    max_iter=max_iter,\n",
"    spacing=spacing,\n",
"    CBT=CBT,\n",
"    device=device,\n",
"    optimize_size=True,\n",
"    grid=grid,\n",
")\n",
"best_l, loss_l, best_r, loss_r, total_time = run_de_torch(**de_settings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "94e1ee75",
"metadata": {},
"outputs": [],
"source": [
"# Run run_nm_torch on the single case configured above.\n",
"nm_settings = dict(\n",
"    label_str=\"L5\",\n",
"    image1_path=cortical_path,\n",
"    image2_path=binary_path,\n",
"    image3_path=roi_path,\n",
"    folder=\"Output\",\n",
"    swarm_size=swarm_size,\n",
"    max_iter=max_iter,\n",
"    spacing=spacing,\n",
"    CBT=CBT,\n",
"    device=device,\n",
"    optimize_size=True,\n",
"    grid=grid,\n",
")\n",
"best_l, loss_l, best_r, loss_r, total_time = run_nm_torch(**nm_settings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "be72d637",
"metadata": {},
"outputs": [],
"source": [
"# Run run_pso_torch on the single case, then export and render both cylinders.\n",
"best_l, loss_l, best_r, loss_r, total_time = run_pso_torch(\n",
"    label_str=\"L5\",\n",
"    image1_path=cortical_path,\n",
"    image2_path=binary_path,\n",
"    image3_path=roi_path,\n",
"    folder=\"Output\",\n",
"    swarm_size=swarm_size,\n",
"    max_iter=max_iter,\n",
"    spacing=spacing,\n",
"    CBT=CBT,\n",
"    device=device,\n",
"    optimize_size=True,\n",
"    grid=grid)\n",
"\n",
"# Use the CBT flag from the config cell (it was hard-coded to True here) so the\n",
"# saved cylinders always match the mode that was optimized and is plotted below.\n",
"cylinder_L, cylinder_R = save_cyl(best_l, best_r, spacing, roi_path, output_base='Output', CBT=CBT)\n",
"\n",
"# Entries 5 and 6 of each result vector are passed on as diameter and length.\n",
"cyl_and_CT(diameter_l=best_l[5],\n",
"           diameter_r=best_r[5],\n",
"           length_l=best_l[6],\n",
"           length_r=best_r[6],\n",
"           roi_image_path=roi_path,\n",
"           cylinder_1=cylinder_L,\n",
"           cylinder_2=cylinder_R,\n",
"           base_folder=\"Output\",\n",
"           CBT=CBT)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9723e65b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using GPU 1: NVIDIA A40\n",
"📂 總共偵測到 784 個病人資料夾,準備開始處理...\n",
"\n",
"=========================================\n",
"🚀 === Running 1.3.6.1.4.1.9328.50.4.0001_L1 ===\n",
"=========================================\n",
"👉 執行 CBT 軌跡最佳化...\n",
"Superior 終板傾斜角度: 6.64°\n",
"斜率: 0.1164, R²: 0.7599\n",
"=== 最佳化模式:最佳化位置、角度、直徑和長度 ===\n",
"\n",
"=== 左側 ===\n",
"Stopping search: maximum iterations reached --> 100\n",
"[LEFT] overlap: 15.3%\n",
"[LEFT] Position: [ 27.47724944 40.25868267 37.64458752 100.44943459 58.1295552 ]\n",
"[LEFT] Diameter: 5.0 mm (raw: 4.98)\n",
"[LEFT] Length: 50 mm (raw: 49.97)\n",
"Stopping search: maximum iterations reached --> 100\n",
"[LEFT][retry 1] ⚠️ not ok | loss improved=-15892000.0000, overlap=8.4%\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[2], line 62\u001b[0m\n\u001b[1;32m 59\u001b[0m image_shape \u001b[38;5;241m=\u001b[39m binary_array\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 61\u001b[0m grid \u001b[38;5;241m=\u001b[39m create_coordinate_grid(image_shape, device)\n\u001b[0;32m---> 62\u001b[0m best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time \u001b[38;5;241m=\u001b[39m \u001b[43mrun_pso_torch\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 63\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabel_str\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 64\u001b[0m \u001b[43m \u001b[49m\u001b[43mimage1_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mIMAGE1\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 65\u001b[0m \u001b[43m \u001b[49m\u001b[43mimage2_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mIMAGE2\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 66\u001b[0m \u001b[43m \u001b[49m\u001b[43mimage3_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mIMAGE3\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 67\u001b[0m \u001b[43m \u001b[49m\u001b[43mfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdate\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 68\u001b[0m \u001b[43m \u001b[49m\u001b[43mswarm_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m70\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 69\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_iter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m100\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 70\u001b[0m \u001b[43m \u001b[49m\u001b[43mspacing\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mspacing\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 71\u001b[0m \u001b[43m \u001b[49m\u001b[43mCBT\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 72\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 73\u001b[0m 
\u001b[43m \u001b[49m\u001b[43moptimize_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 74\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrid\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgrid\u001b[49m\n\u001b[1;32m 75\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 77\u001b[0m cylinder_L, cylinder_R \u001b[38;5;241m=\u001b[39m save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mOutput\u001b[39m\u001b[38;5;124m'\u001b[39m, CBT\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 78\u001b[0m cyl_and_CT(best_pos_l[\u001b[38;5;241m5\u001b[39m], best_pos_r[\u001b[38;5;241m5\u001b[39m], best_pos_l[\u001b[38;5;241m6\u001b[39m], best_pos_r[\u001b[38;5;241m6\u001b[39m], IMAGE3, cylinder_L, cylinder_R, base_folder\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mOutput\u001b[39m\u001b[38;5;124m'\u001b[39m, CBT\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
"File \u001b[0;32m/mnt/pool0/home/cyrou/CBT/CBT_project/core/optimizer.py:194\u001b[0m, in \u001b[0;36mrun_pso_torch\u001b[0;34m(label_str, image1_path, image2_path, image3_path, folder, swarm_size, max_iter, spacing, CBT, device, optimize_size, grid)\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[38;5;66;03m# 左側 retryloss 要 <=0 且 overlap >= 0.5 才算過關\u001b[39;00m\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m (best_loss_l \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m best_overlap_l \u001b[38;5;241m<\u001b[39m OVERLAP_THRESH) \u001b[38;5;129;01mand\u001b[39;00m retries \u001b[38;5;241m<\u001b[39m max_retries:\n\u001b[0;32m--> 194\u001b[0m position_l, loss_l \u001b[38;5;241m=\u001b[39m \u001b[43mpso\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobjective_function\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlb_l\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mub_l\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mswarmsize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mswarm_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmaxiter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_iter\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 195\u001b[0m overlap_l, diameter_l, length_l \u001b[38;5;241m=\u001b[39m eval_overlap_from_position(\n\u001b[1;32m 196\u001b[0m position_l, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mL\u001b[39m\u001b[38;5;124m\"\u001b[39m, optimize_size, spine_tensor, image_shape, spacing\n\u001b[1;32m 197\u001b[0m )\n\u001b[1;32m 199\u001b[0m \u001b[38;5;66;03m# 只要找到更好的 loss或你想用 loss+overlap 綜合排序也行)就更新 best\u001b[39;00m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;66;03m# 安全版本:優先選「合格解」;沒有合格解時才用 loss 最小的當備案\u001b[39;00m\n",
"File \u001b[0;32m~/.conda/envs/envpy/lib/python3.13/site-packages/pyswarm/pso.py:145\u001b[0m, in \u001b[0;36mpso\u001b[0;34m(func, lb, ub, ieqcons, f_ieqcons, args, kwargs, swarmsize, omega, phip, phig, maxiter, minstep, minfunc, debug)\u001b[0m\n\u001b[1;32m 143\u001b[0m x[i, mark1] \u001b[38;5;241m=\u001b[39m lb[mark1]\n\u001b[1;32m 144\u001b[0m x[i, mark2] \u001b[38;5;241m=\u001b[39m ub[mark2]\n\u001b[0;32m--> 145\u001b[0m fx \u001b[38;5;241m=\u001b[39m \u001b[43mobj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;66;03m# Compare particle's best position (if constraints are satisfied)\u001b[39;00m\n\u001b[1;32m 148\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m fx\u001b[38;5;241m<\u001b[39mfp[i] \u001b[38;5;129;01mand\u001b[39;00m is_feasible(x[i, :]):\n",
"File \u001b[0;32m~/.conda/envs/envpy/lib/python3.13/site-packages/pyswarm/pso.py:74\u001b[0m, in \u001b[0;36mpso.<locals>.<lambda>\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 71\u001b[0m vlow \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39mvhigh\n\u001b[1;32m 73\u001b[0m \u001b[38;5;66;03m# Check for constraint function(s) #########################################\u001b[39;00m\n\u001b[0;32m---> 74\u001b[0m obj \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m x: \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m f_ieqcons \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 76\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(ieqcons):\n",
"File \u001b[0;32m/mnt/pool0/home/cyrou/CBT/CBT_project/core/objective.py:126\u001b[0m, in \u001b[0;36mobjective_function\u001b[0;34m(params)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[38;5;66;03m# 將連續值轉換為離散值\u001b[39;00m\n\u001b[1;32m 124\u001b[0m diameter_discrete, length_discrete \u001b[38;5;241m=\u001b[39m snap_to_discrete_values(diameter_raw, length_raw)\n\u001b[0;32m--> 126\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mcylinder_circle_line_intersection_loss_deductions_torch\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 127\u001b[0m \u001b[43m \u001b[49m\u001b[43mdiameter_discrete\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 128\u001b[0m \u001b[43m \u001b[49m\u001b[43mlength_discrete\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 129\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 130\u001b[0m \u001b[43m \u001b[49m\u001b[43mimage2_shape\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 131\u001b[0m \u001b[43m \u001b[49m\u001b[43mcortical_tensor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 132\u001b[0m \u001b[43m \u001b[49m\u001b[43mspine_tensor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 133\u001b[0m \u001b[43m \u001b[49m\u001b[43mspacing\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 134\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\n\u001b[1;32m 135\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n",
"File \u001b[0;32m/mnt/pool0/home/cyrou/CBT/CBT_project/core/objective.py:54\u001b[0m, in \u001b[0;36mcylinder_circle_line_intersection_loss_deductions_torch\u001b[0;34m(diameter, length, params, image_shape, cortical_tensor, spine_tensor, spacing, device)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;124;03mComputes the loss for a given set of cylinder params in PyTorch, \u001b[39;00m\n\u001b[1;32m 50\u001b[0m \u001b[38;5;124;03mreturning a Python float for PSO consumption.\u001b[39;00m\n\u001b[1;32m 51\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 52\u001b[0m position_z, position_y, position_x, azimuth, altitude \u001b[38;5;241m=\u001b[39m params\n\u001b[0;32m---> 54\u001b[0m cyl_fwd \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_cylinder_n_torch\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 55\u001b[0m \u001b[43m \u001b[49m\u001b[43mdiameter\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 56\u001b[0m \u001b[43m \u001b[49m\u001b[43mlength\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 57\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_z\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 58\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_y\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 59\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_x\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 60\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mfloat\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mazimuth\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 61\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mfloat\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43maltitude\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 62\u001b[0m \u001b[43m \u001b[49m\u001b[43mimage_shape\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 63\u001b[0m \u001b[43m \u001b[49m\u001b[43mspacing\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 64\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 65\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrid\u001b[49m\n\u001b[1;32m 66\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 68\u001b[0m cyl_opp \u001b[38;5;241m=\u001b[39m generate_cylinder_o_torch(\n\u001b[1;32m 69\u001b[0m diameter,\n\u001b[1;32m 70\u001b[0m length,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 79\u001b[0m grid\n\u001b[1;32m 80\u001b[0m )\n\u001b[1;32m 83\u001b[0m \u001b[38;5;66;03m# We call the center_line_intersections in Torch mode\u001b[39;00m\n",
"File \u001b[0;32m/mnt/pool0/home/cyrou/CBT/CBT_project/core/cylinder.py:83\u001b[0m, in \u001b[0;36mgenerate_cylinder_n_torch\u001b[0;34m(diameter, length, position_z, position_y, position_x, azimuth, altitude, shape, spacing, device, grid)\u001b[0m\n\u001b[1;32m 75\u001b[0m x_rot \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 76\u001b[0m x_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(azimuth_rad_t) \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(altitude_rad_t)\n\u001b[1;32m 77\u001b[0m \u001b[38;5;241m+\u001b[39m y_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(azimuth_rad_t) \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(altitude_rad_t)\n\u001b[1;32m 78\u001b[0m \u001b[38;5;241m-\u001b[39m z_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(altitude_rad_t)\n\u001b[1;32m 79\u001b[0m )\n\u001b[1;32m 80\u001b[0m y_rot \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39mx_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(azimuth_rad_t) \u001b[38;5;241m+\u001b[39m y_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(azimuth_rad_t)\n\u001b[1;32m 81\u001b[0m z_rot \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 82\u001b[0m x_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(azimuth_rad_t) \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(altitude_rad_t)\n\u001b[0;32m---> 83\u001b[0m \u001b[38;5;241m+\u001b[39m y_t \u001b[38;5;241m*\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msin\u001b[49m\u001b[43m(\u001b[49m\u001b[43mazimuth_rad_t\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(altitude_rad_t)\n\u001b[1;32m 84\u001b[0m \u001b[38;5;241m+\u001b[39m z_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(altitude_rad_t)\n\u001b[1;32m 85\u001b[0m )\n\u001b[1;32m 87\u001b[0m \u001b[38;5;66;03m# Handle spacing\u001b[39;00m\n\u001b[1;32m 
88\u001b[0m \u001b[38;5;66;03m# You can expand or generalize for more spacing options\u001b[39;00m\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m spacing \u001b[38;5;241m==\u001b[39m [\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1\u001b[39m]:\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"import os\n",
"import traceback\n",
"from datetime import datetime\n",
"from core.optimizer import run_pso_torch\n",
"from visualization.res_cyl_nifti import save_cyl\n",
"from visualization.res_cyl_to_CT import cyl_and_CT\n",
"\n",
"device = get_device(1)\n",
"\n",
"def build_paths(base_dir: str, patient_id: str, level: str):\n",
"    \"\"\"Return the three input paths for one vertebral level of one patient.\n",
"\n",
"    The result is the tuple ``(cortical, binary2, roi2)``, i.e. the files\n",
"    ``<level>_cortical.nii.gz``, ``<level>_binary2.nii.gz`` and\n",
"    ``<level>_roi2.nii.gz`` inside ``<base_dir>/<patient_id>``. Only the paths\n",
"    are assembled; existence is not checked here.\n",
"    \"\"\"\n",
"    case_dir = os.path.join(base_dir, patient_id)\n",
"    filenames = (\n",
"        f\"{level}_cortical.nii.gz\",\n",
"        f\"{level}_binary2.nii.gz\",\n",
"        f\"{level}_roi2.nii.gz\",\n",
"    )\n",
"    return tuple(os.path.join(case_dir, name) for name in filenames)\n",
"\n",
"if __name__ == \"__main__\":\n",
"    # Batch configuration.\n",
"    base_dir = \"/home/cyrou/CBT/Seg/Resample/standardized/\"  # root folder with one sub-folder per case\n",
"    levels = [\"L1\", \"L2\", \"L3\", \"L4\", \"L5\"]  # vertebral levels to run\n",
"    spacing = [0.5, 0.5, 0.5]\n",
"    date = datetime.now().strftime(\"%Y%m%d\")\n",
"\n",
"    # Per-level bookkeeping for the final report.\n",
"    succeeded, skipped, failed = [], [], []\n",
"\n",
"    # Keyword arguments that both run_pso_torch calls of a level have in common.\n",
"    pso_common = dict(\n",
"        folder=date,\n",
"        swarm_size=70,\n",
"        max_iter=100,\n",
"        spacing=spacing,\n",
"        device=device,\n",
"        optimize_size=True,\n",
"    )\n",
"\n",
"    # Every sub-directory of base_dir is treated as one patient; sorted so the\n",
"    # processing order is stable between runs.\n",
"    all_patients = sorted(\n",
"        entry for entry in os.listdir(base_dir)\n",
"        if os.path.isdir(os.path.join(base_dir, entry))\n",
"    )\n",
"\n",
"    print(f\"📂 總共偵測到 {len(all_patients)} 個病人資料夾,準備開始處理...\")\n",
"\n",
"    for patient_id in all_patients:\n",
"        for level in levels:\n",
"            label_str = f\"{patient_id}_{level}\"\n",
"            IMAGE1, IMAGE2, IMAGE3 = build_paths(base_dir, patient_id, level)\n",
"\n",
"            # Skip the level, with a one-line notice, when any input file is absent.\n",
"            missing = [path for path in (IMAGE1, IMAGE2, IMAGE3) if not os.path.exists(path)]\n",
"            if missing:\n",
"                print(f\"[SKIP] ⏭️ {label_str} (缺少 {len(missing)} 個檔案)\")\n",
"                skipped.append((label_str, missing))\n",
"                continue\n",
"\n",
"            print(\"\\n=========================================\")\n",
"            print(f\"🚀 === Running {label_str} ===\")\n",
"            print(\"=========================================\")\n",
"\n",
"            try:\n",
"                # --- Stage 1: CBT=True ---\n",
"                print(\"👉 執行 CBT 軌跡最佳化...\")\n",
"                binary_image = sitk.ReadImage(IMAGE2)\n",
"                binary_array = sitk.GetArrayFromImage(binary_image)\n",
"                image_shape = binary_array.shape\n",
"\n",
"                # One coordinate grid per level, reused by both stages.\n",
"                grid = create_coordinate_grid(image_shape, device)\n",
"                best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_pso_torch(\n",
"                    label_str,\n",
"                    image1_path=IMAGE1,\n",
"                    image2_path=IMAGE2,\n",
"                    image3_path=IMAGE3,\n",
"                    CBT=True,\n",
"                    grid=grid,\n",
"                    **pso_common,\n",
"                )\n",
"\n",
"                cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base='Output', CBT=True)\n",
"                cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6], IMAGE3, cylinder_L, cylinder_R, base_folder='Output', CBT=True)\n",
"\n",
"                # --- Stage 2: CBT=False (traditional trajectory, TT) ---\n",
"                print(\"👉 執行傳統軌跡 (TT) 最佳化...\")\n",
"                # NOTE(review): image1_path is IMAGE2 here, as in the original code;\n",
"                # confirm that this is intended for the TT stage.\n",
"                best_pos_l_tt, best_loss_l_tt, best_pos_r_tt, best_loss_r_tt, total_time_tt = run_pso_torch(\n",
"                    label_str,\n",
"                    image1_path=IMAGE2,\n",
"                    image2_path=IMAGE2,\n",
"                    image3_path=IMAGE3,\n",
"                    CBT=False,\n",
"                    grid=grid,\n",
"                    **pso_common,\n",
"                )\n",
"\n",
"                cylinder_L_tt, cylinder_R_tt = save_cyl(best_pos_l_tt, best_pos_r_tt, spacing, IMAGE3, output_base='Output', CBT=False)\n",
"                cyl_and_CT(best_pos_l_tt[5], best_pos_r_tt[5], best_pos_l_tt[6], best_pos_r_tt[6], IMAGE3, cylinder_L_tt, cylinder_R_tt, base_folder='Output', CBT=False)\n",
"\n",
"                # Record success with the stage-1 losses and the summed run time.\n",
"                elapsed = total_time + total_time_tt\n",
"                succeeded.append((label_str, best_loss_l, best_loss_r, elapsed))\n",
"                print(f\"[OK] ✅ {label_str} 完成!總耗時={elapsed:.2f}s\")\n",
"\n",
"            except Exception as exc:\n",
"                # A failing level is logged with its traceback and the batch moves on.\n",
"                print(f\"[FAIL] ❌ {label_str} 發生錯誤: {exc}\")\n",
"                traceback.print_exc()\n",
"                failed.append((label_str, str(exc)))\n",
"\n",
"    # Final report.\n",
"    print(\"\\n\" + \"=\"*40)\n",
"    print(\"🏆 ==== 全自動批次處理總結 ====\")\n",
"    print(\"=\"*40)\n",
"    print(f\"✅ Succeeded (成功): {len(succeeded)} 節段\")\n",
"    print(f\"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段\")\n",
"    print(f\"❌ Failed (執行錯誤): {len(failed)} 節段\")\n",
"\n",
"    if failed:\n",
"        print(\"\\n⚠ 失敗清單:\")\n",
"        for fail_label, err_msg in failed:\n",
"            print(f\"  - {fail_label}: {err_msg}\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "bcf55fef",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"usage: ipykernel_launcher.py [-h] --gpu GPU [--total_gpus TOTAL_GPUS]\n",
"ipykernel_launcher.py: error: the following arguments are required: --gpu\n"
]
},
{
"ename": "SystemExit",
"evalue": "2",
"output_type": "error",
"traceback": [
"An exception has occurred, use %tb to see the full traceback.\n",
"\u001b[0;31mSystemExit\u001b[0m\u001b[0;31m:\u001b[0m 2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/cyrou/.conda/envs/envpy/lib/python3.13/site-packages/IPython/core/interactiveshell.py:3585: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.\n",
" warn(\"To exit: use 'exit', 'quit', or Ctrl-D.\", stacklevel=1)\n"
]
}
],
"source": [
"import os\n",
"import traceback\n",
"import argparse\n",
"import torch\n",
"from datetime import datetime\n",
"\n",
"# Requires run_pso_torch, save_cyl, cyl_and_CT and the other project helpers\n",
"# to be imported already.\n",
"\n",
"def build_paths(base_dir: str, patient_id: str, level: str):\n",
"    \"\"\"Return the three input paths for one vertebral level of one patient.\n",
"\n",
"    The result is the tuple ``(cortical, binary2, roi2)``, i.e. the files\n",
"    ``<level>_cortical.nii.gz``, ``<level>_binary2.nii.gz`` and\n",
"    ``<level>_roi2.nii.gz`` inside ``<base_dir>/<patient_id>``. Only the paths\n",
"    are assembled; existence is not checked here.\n",
"    \"\"\"\n",
"    case_dir = os.path.join(base_dir, patient_id)\n",
"    filenames = (\n",
"        f\"{level}_cortical.nii.gz\",\n",
"        f\"{level}_binary2.nii.gz\",\n",
"        f\"{level}_roi2.nii.gz\",\n",
"    )\n",
"    return tuple(os.path.join(case_dir, name) for name in filenames)\n",
"\n",
"if __name__ == \"__main__\":\n",
"    # === 1. Command-line arguments (how patients are sharded across GPUs) ===\n",
"    parser = argparse.ArgumentParser(description=\"Multi-GPU CBT Batch Processing\")\n",
"    parser.add_argument('--gpu', type=int, required=True, help='指定這支程式要用的 GPU ID (0, 1, 2, 3)')\n",
"    parser.add_argument('--total_gpus', type=int, default=4, help='總共開啟的 GPU 數量')\n",
"\n",
"    # Inside IPython/Jupyter, sys.argv holds the kernel launcher's own flags, so a\n",
"    # plain parse_args() aborts with SystemExit(2). Parse an explicit list there\n",
"    # instead; edit notebook_argv to pick another shard, e.g. [\"--gpu\", \"1\"].\n",
"    # When run as a script the real command line is parsed exactly as before.\n",
"    notebook_argv = [\"--gpu\", \"0\"]\n",
"    try:\n",
"        get_ipython  # provided by IPython; NameError in a plain Python script\n",
"        running_in_ipython = True\n",
"    except NameError:\n",
"        running_in_ipython = False\n",
"    args = parser.parse_args(notebook_argv if running_in_ipython else None)\n",
"\n",
"    # Patient i is handled by the worker whose --gpu equals i % total_gpus, so an\n",
"    # index outside [0, total_gpus) would silently process nothing (and\n",
"    # total_gpus <= 0 would fail on the modulo).\n",
"    if not 0 <= args.gpu < args.total_gpus:\n",
"        parser.error(f\"--gpu must satisfy 0 <= gpu < total_gpus (got gpu={args.gpu}, total_gpus={args.total_gpus})\")\n",
"\n",
"    # === 2. Bind this process to its GPU (CPU fallback when CUDA is missing) ===\n",
"    if torch.cuda.is_available():\n",
"        device = torch.device(f'cuda:{args.gpu}')\n",
"        torch.cuda.set_device(device)\n",
"    else:\n",
"        device = torch.device('cpu')\n",
"        print(\"⚠️ 找不到 CUDA將使用 CPU\")\n",
"\n",
"    print(f\"\\n🚀 啟動批次任務 | 分配至 GPU: [{args.gpu}] | 總共 {args.total_gpus} 個節點協同運算\")\n",
"\n",
"    # === 3. Basic parameters ===\n",
"    base_dir = \"Seg/Resample/standardized\" \n",
"    spacing = [0.5, 0.5, 0.5]\n",
"    levels = [\"L1\", \"L2\", \"L3\", \"L4\", \"L5\"] \n",
"    date = datetime.now().strftime(\"%Y%m%d\")\n",
"\n",
"    # Per-level bookkeeping for this worker's summary.\n",
"    failed = []\n",
"    skipped = []\n",
"    succeeded = []\n",
"\n",
"    # Every sub-directory of base_dir is treated as one patient; sorted so that all\n",
"    # workers see the same order and the shards do not overlap.\n",
"    all_patients = sorted([d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))])\n",
"\n",
"    # === 4. Modulo sharding: keep only the patients that belong to this GPU ===\n",
"    my_patients = [p for i, p in enumerate(all_patients) if i % args.total_gpus == args.gpu]\n",
"\n",
"    print(f\"📂 總資料庫有 {len(all_patients)} 個病人。\")\n",
"    print(f\"🎯 本 GPU (ID: {args.gpu}) 被分配到 {len(my_patients)} 個病人,準備開始處理...\")\n",
"\n",
"    for patient_id in my_patients:\n",
"        for level in levels:\n",
"            label_str = f\"{patient_id}_{level}\"\n",
"            IMAGE1, IMAGE2, IMAGE3 = build_paths(base_dir, patient_id, level)\n",
"\n",
"            # Skip the level when any of its three input files is absent.\n",
"            missing = [p for p in [IMAGE1, IMAGE2, IMAGE3] if not os.path.exists(p)]\n",
"            if missing:\n",
"                print(f\"[SKIP] ⏭️ {label_str} (缺檔)\")\n",
"                skipped.append((label_str, missing))\n",
"                continue\n",
"\n",
"            print(\"\\n=========================================\")\n",
"            print(f\"🔥 [GPU {args.gpu}] === Running {label_str} ===\")\n",
"            print(\"=========================================\")\n",
"\n",
"            try:\n",
"                # --- Stage 1: CBT = True ---\n",
"                print(f\"👉 [GPU {args.gpu}] 執行 CBT 軌跡最佳化...\")\n",
"                binary_image = sitk.ReadImage(IMAGE2)\n",
"                binary_array = sitk.GetArrayFromImage(binary_image)\n",
"                image_shape = binary_array.shape\n",
"\n",
"                # One coordinate grid per level, built on this worker's device and\n",
"                # reused by both stages.\n",
"                grid = create_coordinate_grid(image_shape, device)\n",
"                best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_pso_torch(\n",
"                    label_str,\n",
"                    image1_path=IMAGE1,\n",
"                    image2_path=IMAGE2,\n",
"                    image3_path=IMAGE3,\n",
"                    folder=date, \n",
"                    swarm_size=70,\n",
"                    max_iter=100,\n",
"                    spacing=spacing,\n",
"                    CBT=True,\n",
"                    device=device,\n",
"                    optimize_size=True,\n",
"                    grid=grid\n",
"                )\n",
"\n",
"                cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base='Output', CBT=True)\n",
"                cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6], IMAGE3, cylinder_L, cylinder_R, base_folder='Output', CBT=True)\n",
"\n",
"                # --- Stage 2: CBT = False (traditional trajectory, TT) ---\n",
"                print(f\"👉 [GPU {args.gpu}] 執行傳統軌跡 (TT) 最佳化...\")\n",
"                # NOTE(review): image1_path is IMAGE2 here, as in the original code;\n",
"                # confirm that this is intended for the TT stage.\n",
"                best_pos_l_tt, best_loss_l_tt, best_pos_r_tt, best_loss_r_tt, total_time_tt = run_pso_torch(\n",
"                    label_str,\n",
"                    image1_path=IMAGE2, \n",
"                    image2_path=IMAGE2,\n",
"                    image3_path=IMAGE3,\n",
"                    folder=date, \n",
"                    swarm_size=70,\n",
"                    max_iter=100,\n",
"                    spacing=spacing,\n",
"                    CBT=False,\n",
"                    device=device,\n",
"                    optimize_size=True,\n",
"                    grid=grid\n",
"                )\n",
"\n",
"                cylinder_L_tt, cylinder_R_tt = save_cyl(best_pos_l_tt, best_pos_r_tt, spacing, IMAGE3, output_base='Output', CBT=False)\n",
"                cyl_and_CT(best_pos_l_tt[5], best_pos_r_tt[5], best_pos_l_tt[6], best_pos_r_tt[6], IMAGE3, cylinder_L_tt, cylinder_R_tt, base_folder='Output', CBT=False)\n",
"\n",
"                # Record success with the stage-1 losses and the summed run time.\n",
"                succeeded.append((label_str, best_loss_l, best_loss_r, total_time + total_time_tt))\n",
"                print(f\"[OK] ✅ [GPU {args.gpu}] {label_str} 完成!總耗時={(total_time + total_time_tt):.2f}s\")\n",
"\n",
"            except Exception as e:\n",
"                # A failing level is logged with its traceback and the batch moves on.\n",
"                print(f\"[FAIL] ❌ [GPU {args.gpu}] {label_str} 發生錯誤: {e}\")\n",
"                traceback.print_exc()\n",
"                failed.append((label_str, str(e)))\n",
"\n",
"    # --- Summary for this GPU's shard only ---\n",
"    print(\"\\n\" + \"=\"*40)\n",
"    print(f\"🏆 ==== GPU {args.gpu} 處理總結 ====\")\n",
"    print(\"=\"*40)\n",
"    print(f\"✅ Succeeded (成功): {len(succeeded)} 節段\")\n",
"    print(f\"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段\")\n",
"    print(f\"❌ Failed (執行錯誤): {len(failed)} 節段\")\n",
"\n",
"    if failed:\n",
"        print(f\"\\n⚠ GPU {args.gpu} 失敗清單:\")\n",
"        for fail_label, err_msg in failed:\n",
"            print(f\"  - {fail_label}: {err_msg}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "envpy",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}