1. correct the bounding box and cortical mask

2. make the plot isometric
3. now it should work after tuning the parameters
This commit is contained in:
Xiao Furen 2026-04-17 00:03:10 +08:00
parent f84d8963da
commit b76f0708f3
11 changed files with 2524 additions and 54 deletions

View file

@ -12,7 +12,17 @@ LABEL_MAP = {
15: "T8", 16: "T9", 17: "T10", 18: "T11", 19: "T12",
20: "L1", 21: "L2", 22: "L3", 23: "L4", 24: "L5"
}
ALLOWED_DIAMETERS = [3.5, 4.0, 4.5, 5.0]
ALLOWED_LENGTHS = [35, 40, 45, 50]
ALLOWED_DIAMETERS = [
3.5,
4.0,
4.5,
5.0,
]
ALLOWED_LENGTHS = [
35,
40,
45,
50,
]
OVERLAP_THRESH = 0.50
DEFAULT_SPACING = [0.5, 0.5, 0.5]

View file

@ -38,6 +38,27 @@ def snap_to_discrete_values(diameter_raw, length_raw):
return diameter_discrete, length_discrete
def round_down_to_discrete_values(diameter_raw, length_raw):
"""
將連續值映射到小於等於該值的最接近允許離散值 (向下取整)
Parameters:
diameter_raw: PSO 給的連續直徑值
length_raw: PSO 給的連續長度值
Returns:
diameter_discrete, length_discrete: 離散化後的值
"""
# 找小於等於 raw 的最接近 diameter如果都大於 raw 則取最小值
valid_diameters = [d for d in ALLOWED_DIAMETERS if d <= diameter_raw]
diameter_discrete = max(valid_diameters) if valid_diameters else min(ALLOWED_DIAMETERS)
# 找小於等於 raw 的最接近 length如果都大於 raw 則取最小值
valid_lengths = [l for l in ALLOWED_LENGTHS if l <= length_raw]
length_discrete = max(valid_lengths) if valid_lengths else min(ALLOWED_LENGTHS)
return diameter_discrete, length_discrete
def generate_cylinder_n_torch(
diameter: float,

View file

@ -1,8 +1,13 @@
import torch
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values, generate_cylinder_tip_torch
import random
from scipy.ndimage import map_coordinates
import numpy as np
import torch
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values, generate_cylinder_tip_torch
from core.intersection import center_line_intersections_torch
from core.scoring import cl_score_torch
from core.scoring import cl_score_torch, cl_score_torch_xfr
# Global variables (used in objective_function)
image1_array = None # cortical_nii.gz
@ -102,7 +107,8 @@ def cylinder_circle_line_intersection_loss_deductions_torch(
image_shape, spacing, device, grid
)
loss_value = cl_score_torch(
# loss_value = cl_score_torch(
loss_value = cl_score_torch_xfr(
cortical_tensor, spine_tensor,
cyl_fwd, cyl_opp, intersections,
cylinder_tip_torch=cyl_tip
@ -110,6 +116,42 @@ def cylinder_circle_line_intersection_loss_deductions_torch(
return loss_value
def objective_function_xfr(params: list[float], y_indices) -> float:
"""
Wrapper for the PSO objective function, calling our Torch-based loss function.
Now params includes diameter and length at the end.
params = [position_z, position_y, position_x, azimuth, altitude, diameter_raw, length_raw]
"""
# position_params = params[:5] # [z, y, x, azimuth, altitude]
# diameter_raw = params[5]
# length_raw = params[6]
z, x, azimuth, altitude, diameter_raw, length_raw = params
y = y_indices[round(z), round(x)] #+ random.uniform(-0.5, 0.5)
# coords = np.array([[z], [x]])
# result = map_coordinates(y_indices, coords, order=1)
# y= result[0]
position_params = [z, y, x, azimuth, altitude]
# 將連續值轉換為離散值
# diameter_discrete, length_discrete = snap_to_discrete_values(diameter_raw, length_raw)
diameter_discrete, length_discrete = diameter_raw, length_raw
loss = cylinder_circle_line_intersection_loss_deductions_torch(
diameter_discrete,
length_discrete,
position_params,
image2_shape,
cortical_tensor,
spine_tensor,
spacing,
device
)
return loss
def objective_function(params: list[float]) -> float:
"""
Wrapper for the PSO objective function, calling our Torch-based loss function.
@ -122,6 +164,7 @@ def objective_function(params: list[float]) -> float:
# 將連續值轉換為離散值
diameter_discrete, length_discrete = snap_to_discrete_values(diameter_raw, length_raw)
# diameter_discrete, length_discrete = diameter_raw, length_raw
loss = cylinder_circle_line_intersection_loss_deductions_torch(
diameter_discrete,

View file

@ -4,15 +4,37 @@ import SimpleITK as sitk
import torch
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
from core.objective import objective_function
from core.objective import objective_function, objective_function_xfr
from pyswarm import pso
import core.objective # <--- 加入這行,讓我們可以直接操作 objective 模組
from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid
from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid, round_down_to_discrete_values
from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok
from config.constant import OVERLAP_THRESH
from visualization.res_plot_3d import res_plt_2_torch
def run_pso_torch(
def get_first_nonzero_y(arr):
# 1. Create a boolean mask where elements are non-zero
mask = arr != 0
# 2. Find the index of the first True value along the Y axis (axis=1)
y_indices = np.argmax(mask, axis=1)
# 3. Edge Case Handling: If a whole (x, z) column is zero, argmax returns 0.
# We need to distinguish this from an actual non-zero value at index 0.
has_nonzero = np.any(mask, axis=1)
# 4. Replace indices where there were no non-zeros with a sentinel value (e.g., -1)
y_indices = np.where(has_nonzero, y_indices, -1)
y_indices = np.where(y_indices > arr.shape[1] * .4, -1, y_indices)
return y_indices.astype(np.float32)
def constraint_y(x, y_indices):
return y_indices[round(x[0]), round(x[1])]
def run_pso_torch_xfr(
label_str: str,
image1_path: str,
image2_path: str,
@ -24,7 +46,9 @@ def run_pso_torch(
CBT: bool,
device: torch.device,
optimize_size: bool = True,
grid=None
grid=None,
debug=False,
omega = 0.9,
):
"""
Main function to run PSO.
@ -86,15 +110,51 @@ def run_pso_torch(
res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
alt = res['superior']['tilt_angle_deg']
y_indices = get_first_nonzero_y(image2_array)
# flat_min_index = np.argmin(y_indices)
# z_border, x_border = np.unravel_index(flat_min_index, y_indices.shape)
x_with_nonzero = np.where(np.any(image2_array[:,9,:] != 0, axis=0))[0]
x1 = x_with_nonzero[0]
x2 = x_with_nonzero[-1]
x_mid = (x1+x2)/2
x1 = x_mid-9
x2 = x_mid+9
z_sum = np.sum(image2_array, axis=(1, 2))
z_with_nonzero = np.where(z_sum > 0)[0]
z1 = z_with_nonzero[0]
z2 = z_with_nonzero[-1]
# print(x1,x2)
# exit()
# print(x_border, z_border)
# exit()
# import sys
# import numpy
# numpy.set_printoptions(threshold=sys.maxsize)
# print(y_indices)
# exit()
# 設定基本的 bounds
if CBT == True:
z_bounds = (0, image_shape[0] - 1)
y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
z_bounds = (0, image_shape[0] / 2)
# z_bounds = (z1, (z1+z2)/2)
x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
azimuth_bounds_l = ((95-azi), (145-azi))
azimuth_bounds_r = ((50-azi), (85-azi))
altitude_bounds = ((60-alt), (75-alt))
x_bounds_right = (x2, image_shape[2]*.9)
x_bounds_left = (image_shape[2]*.1, x1)
# azimuth_bounds_l = ((95-azi), (145-azi))
# azimuth_bounds_r = ((50-azi), (85-azi))
# altitude_bounds = ((60-alt), (75-alt))
azimuth_bounds_l = ((95-azi), (105-azi))
azimuth_bounds_r = ((75-azi), (85-azi))
altitude_bounds = ((55-alt), (70-alt))
else:
z_bounds = (0, image_shape[0] - 1)
y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
@ -112,7 +172,8 @@ def run_pso_torch(
side: "L" or "R" 只是方便 debug
"""
if optimize_size:
d, L = snap_to_discrete_values(pos[5], pos[6])
# d, L = snap_to_discrete_values(pos[5], pos[6])
d, L = round_down_to_discrete_values(pos[5], pos[6])
params_5 = pos[:5]
else:
d, L = diameter, length
@ -133,18 +194,18 @@ def run_pso_torch(
print("=== 最佳化模式:最佳化位置、角度、直徑和長度 ===")
# 設定 diameter 和 length 的 bounds連續範圍
diameter_bounds = (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS))
length_bounds = (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS))
diameter_bounds = (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)*1.01)
length_bounds = (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS)*1.01)
# bounds 現在有 7 個參數
lb_l = [z_bounds[0], y_bounds[0], x_bounds_left[0], azimuth_bounds_l[0],
lb_l = [z_bounds[0], x_bounds_left[0], azimuth_bounds_l[0],
altitude_bounds[0], diameter_bounds[0], length_bounds[0]]
ub_l = [z_bounds[1], y_bounds[1], x_bounds_left[1], azimuth_bounds_l[1],
ub_l = [z_bounds[1], x_bounds_left[1], azimuth_bounds_l[1],
altitude_bounds[1], diameter_bounds[1], length_bounds[1]]
lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0],
lb_r = [z_bounds[0], x_bounds_right[0], azimuth_bounds_r[0],
altitude_bounds[0], diameter_bounds[0], length_bounds[0]]
ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1],
ub_r = [z_bounds[1], x_bounds_right[1], azimuth_bounds_r[1],
altitude_bounds[1], diameter_bounds[1], length_bounds[1]]
else:
@ -160,6 +221,12 @@ def run_pso_torch(
lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0], altitude_bounds[0]]
ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1], altitude_bounds[1]]
if debug:
print(lb_l)
print(ub_l)
print(lb_r)
print(ub_r)
best_loss_l = float('inf')
best_loss_r = float('inf')
best_position_l = None
@ -167,7 +234,18 @@ def run_pso_torch(
# Left side optimization
print("\n=== 左側 ===")
position_l, loss_l = pso(objective_function, lb_l, ub_l, swarmsize=swarm_size, maxiter=max_iter)
kwargs = {'y_indices': y_indices}
position_l, loss_l = pso(objective_function_xfr, lb_l, ub_l,
# ieqcons=[constraint_y],
kwargs=kwargs,
swarmsize=swarm_size,
omega = omega,
maxiter=max_iter, debug=debug)
z, x, azimuth, altitude, diameter, length = position_l
y = y_indices[round(z), round(x)]
position_l = z, y, x, azimuth, altitude, diameter, length
overlap_l, diameter_l, length_l = eval_overlap_from_position(
position_l, "L", optimize_size, spine_tensor, image_shape, spacing
@ -178,6 +256,7 @@ def run_pso_torch(
print(f"[LEFT] Position: {position_l[:5]}")
print(f"[LEFT] Diameter: {diameter_l} mm (raw: {position_l[5]:.2f})")
print(f"[LEFT] Length: {length_l} mm (raw: {position_l[6]:.2f})")
print(f"[LEFT] Loss: {loss_l}\n")
best_position_l = list(position_l[:5]) + [diameter_l, length_l]
else:
print(f"[LEFT] Position: {position_l}")
@ -221,14 +300,25 @@ def run_pso_torch(
# Right side optimization
print("\n=== 右側 ===")
position_r, loss_r = pso(objective_function, lb_r, ub_r, swarmsize=swarm_size, maxiter=max_iter)
position_r, loss_r = pso(objective_function_xfr, lb_r, ub_r,
# ieqcons=[constraint_y],
kwargs=kwargs,
swarmsize=swarm_size,
omega = omega,
maxiter=max_iter, debug=debug)
z, x, azimuth, altitude, diameter, length = position_r
y = y_indices[round(z), round(x)]
position_r = z, y, x, azimuth, altitude, diameter, length
overlap_r, diameter_r, length_r = eval_overlap_from_position(
position_r, "R", optimize_size, spine_tensor, image_shape, spacing
)
print(f"[RIGHT] overlap: {overlap_r*100:.1f}%")
if optimize_size:
diameter_r, length_r = snap_to_discrete_values(position_r[5], position_r[6])
# diameter_r, length_r = snap_to_discrete_values(position_r[5], position_r[6])
diameter_r, length_r = round_down_to_discrete_values(position_r[5], position_r[6])
print(f"[RIGHT] Position: {position_r[:5]}")
print(f"[RIGHT] Diameter: {diameter_r} mm (raw: {position_r[5]:.2f})")
print(f"[RIGHT] Length: {length_r} mm (raw: {position_r[6]:.2f})")
@ -300,7 +390,332 @@ def run_pso_torch(
cortical_tensor,
image_shape,
image2_path,
'Output',
folder,
label_str,
final_diameter_l,
final_length_l,
final_diameter_r,
final_length_r,
best_position_l,
best_position_r,
swarm_size,
max_iter,
total_time,
spacing,
CBT,
device,
grid)
return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time
def run_pso_torch(
label_str: str,
image1_path: str,
image2_path: str,
image3_path: str,
folder: str,
swarm_size: int,
max_iter: int,
spacing: list,
CBT: bool,
device: torch.device,
optimize_size: bool = True,
grid=None,
debug=True,
):
"""
Main function to run PSO.
如果 optimize_size=Truediameter length 也會被最佳化
如果 optimize_size=False使用預設值向後兼容
"""
start_time = time.time()
# Use global references
global image1_array, image2_array, image2_shape, image3_array
global diameter, length # 這些現在只用於非最佳化模式
global spine_tensor, cortical_tensor, spine_roi_tensor
# Load images
image1 = sitk.ReadImage(image1_path)
image2 = sitk.ReadImage(image2_path)
image3 = sitk.ReadImage(image3_path)
image1_array = sitk.GetArrayFromImage(image1)
image2_array = sitk.GetArrayFromImage(image2)
image3_array = sitk.GetArrayFromImage(image3)
image2_shape = image2_array.shape
image_shape = image2_shape
# Move arrays to torch
cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)
# ================= [跨檔案注入變數:終極防呆版] =================
import core.objective
# 1. 注入 Tensors
core.objective.cortical_tensor = cortical_tensor
core.objective.spine_tensor = spine_tensor
core.objective.spine_roi_tensor = spine_roi_tensor
# 2. 注入 Arrays (以防 objective 裡面偷偷用到 Numpy 陣列)
core.objective.image1_array = image1_array
core.objective.image2_array = image2_array
core.objective.image3_array = image3_array
# 3. 注入 Shapes (這就是導致這次 NoneType 報錯的真兇!)
core.objective.image2_shape = image2_shape # <--- 解除警報的最關鍵一行
core.objective.image_shape = image_shape
core.objective.shape = image_shape
# 4. 注入環境變數
core.objective.spacing = spacing
core.objective.device = device
core.objective.grid = grid
# 5. 注入尺寸參數 (兼容固定尺寸模式)
if not optimize_size:
core.objective.diameter = diameter
core.objective.length = length
# ==============================================================
azi = azimuth_rotation(image2_path)
res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
alt = res['superior']['tilt_angle_deg']
# 設定基本的 bounds
if CBT == True:
z_bounds = (0, image_shape[0] - 1)
y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
azimuth_bounds_l = ((95-azi), (145-azi))
azimuth_bounds_r = ((50-azi), (85-azi))
altitude_bounds = ((60-alt), (75-alt))
# xfr
# z_bounds = (0, image_shape[0] - 1)
y_bounds = (0, image_shape[1]/2)
# x_bounds_right = (image_shape[2]/2, image_shape[2] - 1)
# x_bounds_left = (0, image_shape[2]/2)
# azimuth_bounds_l = (90, 135)
# azimuth_bounds_r = (45, 90)
# altitude_bounds = (0, 90)
else:
z_bounds = (0, image_shape[0] - 1)
y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
azimuth_bounds_l = (60-azi, 90-azi)
azimuth_bounds_r = (90-azi, 120-azi)
altitude_bounds = (65-alt, 80-alt)
def eval_overlap_from_position(pos, side: str, optimize_size: bool,
spine_tensor: torch.Tensor,
image_shape, spacing):
"""
根據 PSO 給的 position 生成 cylinder mask再算 overlap ratio
side: "L" or "R" 只是方便 debug
"""
if optimize_size:
# d, L = snap_to_discrete_values(pos[5], pos[6])
d, L = round_down_to_discrete_values(pos[5], pos[6])
params_5 = pos[:5]
else:
d, L = diameter, length
params_5 = pos
cyl_mask = generate_cylinder_n_torch(
d, L,
params_5[0], params_5[1], params_5[2],
params_5[3], params_5[4],
image_shape, spacing, device, grid
)
overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
return overlap, d, L
if optimize_size:
# 模式 1優化 diameter 和 length
print("=== 最佳化模式:最佳化位置、角度、直徑和長度 ===")
# 設定 diameter 和 length 的 bounds連續範圍
diameter_bounds = (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS))
length_bounds = (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS))
# bounds 現在有 7 個參數
lb_l = [z_bounds[0], y_bounds[0], x_bounds_left[0], azimuth_bounds_l[0],
altitude_bounds[0], diameter_bounds[0], length_bounds[0]]
ub_l = [z_bounds[1], y_bounds[1], x_bounds_left[1], azimuth_bounds_l[1],
altitude_bounds[1], diameter_bounds[1], length_bounds[1]]
lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0],
altitude_bounds[0], diameter_bounds[0], length_bounds[0]]
ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1],
altitude_bounds[1], diameter_bounds[1], length_bounds[1]]
else:
# 模式 2固定 diameter 和 length向後兼容
print("=== 固定尺寸模式:最佳化位置和角度 ===")
# 使用預設值(需要在調用時提供)
diameter = 4.5 # 或從參數傳入
length = 45 # 或從參數傳入
lb_l = [z_bounds[0], y_bounds[0], x_bounds_left[0], azimuth_bounds_l[0], altitude_bounds[0]]
ub_l = [z_bounds[1], y_bounds[1], x_bounds_left[1], azimuth_bounds_l[1], altitude_bounds[1]]
lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0], altitude_bounds[0]]
ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1], altitude_bounds[1]]
if debug:
print(lb_l)
print(ub_l)
print(lb_r)
print(ub_r)
best_loss_l = float('inf')
best_loss_r = float('inf')
best_position_l = None
best_position_r = None
# Left side optimization
print("\n=== 左側 ===")
position_l, loss_l = pso(objective_function, lb_l, ub_l, swarmsize=swarm_size, maxiter=max_iter, debug=debug)
overlap_l, diameter_l, length_l = eval_overlap_from_position(
position_l, "L", optimize_size, spine_tensor, image_shape, spacing
)
print(f"[LEFT] overlap: {overlap_l*100:.1f}%")
if optimize_size:
print(f"[LEFT] Position: {position_l[:5]}")
print(f"[LEFT] Diameter: {diameter_l} mm (raw: {position_l[5]:.2f})")
print(f"[LEFT] Length: {length_l} mm (raw: {position_l[6]:.2f})")
best_position_l = list(position_l[:5]) + [diameter_l, length_l]
else:
print(f"[LEFT] Position: {position_l}")
best_position_l = position_l
best_loss_l = loss_l
best_overlap_l = overlap_l # 新增
# max_retries = 0
# retries = 0
# 左側 retryloss 要 <=0 且 overlap >= 0.5 才算過關
# while (best_loss_l > 0 or best_overlap_l < OVERLAP_THRESH) and retries < max_retries:
# position_l, loss_l = pso(objective_function, lb_l, ub_l, swarmsize=swarm_size, maxiter=max_iter)
# overlap_l, diameter_l, length_l = eval_overlap_from_position(
# position_l, "L", optimize_size, spine_tensor, image_shape, spacing
# )
# 只要找到更好的 loss或你想用 loss+overlap 綜合排序也行)就更新 best
# 安全版本:優先選「合格解」;沒有合格解時才用 loss 最小的當備案
# candidate_pos = (list(position_l[:5]) + [diameter_l, length_l]) if optimize_size else position_l
# candidate_ok = is_solution_ok(loss_l, overlap_l, OVERLAP_THRESH)
# best_ok = is_solution_ok(best_loss_l, best_overlap_l, OVERLAP_THRESH)
# if candidate_ok and (not best_ok or loss_l < best_loss_l):
# best_position_l = candidate_pos
# best_loss_l = loss_l
# best_overlap_l = overlap_l
# print(f"[LEFT][retry {retries+1}] ✅ ok | loss={loss_l:.4f}, overlap={overlap_l*100:.1f}%")
# elif (not best_ok) and (loss_l < best_loss_l):
# best 還不合格時,先用更小 loss 的當暫存(至少越來越好)
# best_position_l = candidate_pos
# best_loss_l = loss_l
# best_overlap_l = overlap_l
# print(f"[LEFT][retry {retries+1}] ⚠️ not ok | loss improved={loss_l:.4f}, overlap={overlap_l*100:.1f}%")
# else:
# print(f"[LEFT][retry {retries+1}] ❌ no improve | loss={loss_l:.4f}, overlap={overlap_l*100:.1f}%")
# retries += 1
# Right side optimization
print("\n=== 右側 ===")
position_r, loss_r = pso(objective_function, lb_r, ub_r, swarmsize=swarm_size, maxiter=max_iter, debug=debug)
overlap_r, diameter_r, length_r = eval_overlap_from_position(
position_r, "R", optimize_size, spine_tensor, image_shape, spacing
)
print(f"[RIGHT] overlap: {overlap_r*100:.1f}%")
if optimize_size:
# diameter_r, length_r = snap_to_discrete_values(position_r[5], position_r[6])
diameter_r, length_r = round_down_to_discrete_values(position_r[5], position_r[6])
print(f"[RIGHT] Position: {position_r[:5]}")
print(f"[RIGHT] Diameter: {diameter_r} mm (raw: {position_r[5]:.2f})")
print(f"[RIGHT] Length: {length_r} mm (raw: {position_r[6]:.2f})")
print(f"[RIGHT] Loss: {loss_r}\n")
best_position_r = list(position_r[:5]) + [diameter_r, length_r]
else:
print(f"[RIGHT] Position: {position_r}")
print(f"[RIGHT] Loss: {loss_r}\n")
best_position_r = position_r
best_loss_r = loss_r
best_overlap_r = overlap_r
# 如果需要 retryloss > 0
# max_retries = 10
# retries = 0
# while (best_loss_r > 0 or best_overlap_r < OVERLAP_THRESH) and retries < max_retries:
# position_r, loss_r = pso(objective_function, lb_r, ub_r, swarmsize=swarm_size, maxiter=max_iter)
# overlap_r, diameter_r, length_r = eval_overlap_from_position(
# position_r, "R", optimize_size, spine_tensor, image_shape, spacing
# )
# 只要找到更好的 loss或你想用 loss+overlap 綜合排序也行)就更新 best
# 這裡給你一個更安全的版本:優先選「合格解」;沒有合格解時才用 loss 最小的當備案
# candidate_pos = (list(position_r[:5]) + [diameter_r, length_r]) if optimize_size else position_r
# candidate_ok = is_solution_ok(loss_r, overlap_r, OVERLAP_THRESH)
# best_ok = is_solution_ok(best_loss_r, best_overlap_r, OVERLAP_THRESH)
# if candidate_ok and (not best_ok or loss_r < best_loss_r):
# best_position_r = candidate_pos
# best_loss_r = loss_r
# best_overlap_r = overlap_r
# print(f"[RIGHT][retry {retries+1}] ✅ ok | loss={loss_r:.4f}, overlap={overlap_r*100:.1f}%")
# elif (not best_ok) and (loss_r < best_loss_r):
# best 還不合格時,先用更小 loss 的當暫存(至少越來越好)
# best_position_r = candidate_pos
# best_loss_r = loss_r
# best_overlap_r = overlap_r
# print(f"[RIGHT][retry {retries+1}] ⚠️ not ok | loss improved={loss_r:.4f}, overlap={overlap_r*100:.1f}%")
# else:
# print(f"[RIGHT][retry {retries+1}] ❌ no improve | loss={loss_r:.4f}, overlap={overlap_r*100:.1f}%")
# retries += 1
end_time = time.time()
total_time = end_time - start_time
# 提取最終的 diameter 和 length
if optimize_size:
final_diameter_l = best_position_l[5]
final_length_l = best_position_l[6]
final_diameter_r = best_position_r[5]
final_length_r = best_position_r[6]
print(f"\n=== 最終結果 ===")
print(f"Left - Diameter: {final_diameter_l} mm, Length: {final_length_l} mm")
print(f"Right - Diameter: {final_diameter_r} mm, Length: {final_length_r} mm")
else:
final_diameter_l = diameter
final_length_l = length
final_diameter_r = diameter
final_length_r = length
res_plt_2_torch(
spine_tensor,
cortical_tensor,
image_shape,
image2_path,
folder,
label_str,
final_diameter_l,
final_length_l,
@ -420,7 +835,8 @@ def run_de_torch(
def eval_overlap_from_position(pos, side: str, optimize_size: bool, spine_tensor: torch.Tensor, image_shape, spacing):
if optimize_size:
d, L = snap_to_discrete_values(pos[5], pos[6])
# d, L = snap_to_discrete_values(pos[5], pos[6])
d, L = round_down_to_discrete_values(pos[5], pos[6])
params_5 = pos[:5]
else:
d, L = diameter, length
@ -598,7 +1014,8 @@ def run_nm_torch(
def eval_overlap_from_position(pos, side: str, optimize_size: bool, spine_tensor: torch.Tensor, image_shape, spacing):
if optimize_size:
d, L = snap_to_discrete_values(pos[5], pos[6])
# d, L = snap_to_discrete_values(pos[5], pos[6])
d, L = round_down_to_discrete_values(pos[5], pos[6])
params_5 = pos[:5]
else:
d, L = diameter, length

View file

@ -2,6 +2,120 @@ import torch
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_tip_torch
from config.constant import OVERLAP_THRESH
def cl_score_torch_xfr(
cortical_tensor: torch.Tensor,
spine_tensor: torch.Tensor,
cylinder_torch: torch.Tensor,
cylinder_o_torch: torch.Tensor,
intersections: int,
diameter: float = None,
length: float = None,
cylinder_tip_torch: torch.Tensor = None # 新增:尖端 mask
) -> float:
"""
漸進式評分優先確保找到骨頭再改善細節
"""
cyl_total = cylinder_torch.sum().item()
overlap = ((cortical_tensor == 1) & (cylinder_torch == 1)).sum().item()
null_vox = ((cortical_tensor == 0) & (cylinder_torch == 1)).sum().item()
null_vox2 = ((spine_tensor == 1) & (cylinder_o_torch == 1)).sum().item()
in_bone= ((spine_tensor == 1) & (cylinder_torch == 1)).sum().item()
not_in_bone= ((spine_tensor == 0) & (cylinder_torch == 1)).sum().item()
# if cyl_total == 0:
# return float(1000*1000)
# return float(1e9) # 極差的情況
overlap_ratio = overlap / cyl_total
# if cyl_total < 1000:
# return float((1000 - cyl_total)*10000)
score = cyl_total
# if in_bone == 0:
# return float(not_in_bone*200)
score += overlap*30
score += in_bone*10
score -= not_in_bone*1000
score -= null_vox2*1000
return float(-score)
# === 階段 1首要目標是找到骨頭overlap > 0 ===
if overlap == 0:
# 完全沒有 overlap 是最糟糕的情況
score -= 500000 # 超大懲罰
# 如果連 spine 都沒穿過,更糟
if intersections == 0:
score -= 500000
return float(-score)
# === 階段 2有找到骨頭了開始改善品質 ===
# 1. Overlap 獎勵(非線性,鼓勵快速提升)
if overlap_ratio < 0.1:
# 0-10%:每增加 1% 給大量獎勵(鼓勵探索)
score += overlap * 5000 # 很高的單位獎勵
elif overlap_ratio < 0.3:
# 10-30%:中等獎勵
score += overlap * 3000
elif overlap_ratio < 0.5:
# 30-50%:正常獎勵
score += overlap * 2000
else:
# 50%+:獎勵 + 額外比例獎勵
score += overlap * 2000
score += (overlap_ratio - 0.5) * 100000 # 超過 50% 額外大獎
# 2. Intersection 控制(稍微放寬)
if intersections == 1:
score += 20000 # 完美
elif intersections == 0:
score -= 200000 # 嚴重錯誤(但比完全沒 overlap 好)
elif intersections == 2:
score -= 10000 # 可接受但不理想
else:
score -= intersections * 15000
# 3. Null voxel 懲罰(漸進式)
null_ratio = null_vox / cyl_total
if overlap_ratio < 0.2:
# 如果 overlap 還很少,對 null voxel 寬容一點
score -= null_vox * 300
elif overlap_ratio < 0.4:
score -= null_vox * 600
else:
# overlap 夠高了,開始嚴格要求
if null_ratio > 0.5:
score -= null_vox * 1500
else:
score -= null_vox * 800
# 4. 反向圓柱懲罰
score -= null_vox2 * 1000
# 5. 尺寸合理性(放寬)
if diameter is not None and length is not None:
if diameter < 2.5 or diameter > 6.0: # 放寬從 (3.0, 5.5) 到 (2.5, 6.0)
score -= 3000
if length < 25 or length > 60: # 放寬從 (30, 55) 到 (25, 60)
score -= 3000
# 6. 尖端 breach 懲罰
if cylinder_tip_torch is not None:
tip_total = cylinder_tip_torch.sum().item()
if tip_total > 0:
tip_breach = ((cortical_tensor == 0) & (cylinder_tip_torch == 1)).sum().item()
tip_breach_ratio = tip_breach / tip_total
if tip_breach_ratio > 0:
score -= tip_breach * 5000 # 尖端出界懲罰要比一般 null_vox 重很多
return float(-score)
def cl_score_torch(
cortical_tensor: torch.Tensor,
spine_tensor: torch.Tensor,

View file

@ -102,6 +102,8 @@ def process_dataset(image_dir, label_dir, output_dir, labels_to_process=None):
try:
result = process_single_image(image_path, label_path, output_dir_base=output_dir)
# print(result)
# exit()
except Exception as e:
print(f"[{idx}/{total_files}] Error processing {file_name}: {e}")
file_summary["note"] = f"Error: {e}"

View file

@ -22,38 +22,71 @@ def seg_bone(n, name, resampled_sitk_img, resampled_sitk_lbl, output_base=None,
label_name = label_map[n]
lssif = sitk.LabelShapeStatisticsImageFilter()
lssif.Execute(resampled_sitk_lbl)
# lssif = sitk.LabelShapeStatisticsImageFilter()
# lssif.Execute(resampled_sitk_lbl)
if not lssif.HasLabel(n):
raise RuntimeError(f"Label {n} not found")
# if not lssif.HasLabel(n):
# raise RuntimeError(f"Label {n} not found")
bbox2 = lssif.GetBoundingBox(n)
# bbox2 = lssif.GetBoundingBox(n)
# 1. 提取標籤 n 的二值遮罩 (將標籤 n 設為 1其餘為 0)
binary_mask = sitk.BinaryThreshold(resampled_sitk_lbl, n, n, 1, 0)
# 2. 獲取所有連通區域
# 連通區域濾波器會將 binary_mask 中的不同物體標記為 1, 2, 3...
cc_image = sitk.ConnectedComponent(binary_mask)
# 3. 根據區域大小(像素/體積)重新標記
# RelabelComponent 會按大小排序,最大的物體標籤會被設為 1
relabeled_cc = sitk.RelabelComponent(cc_image, sortByObjectSize=True)
largest_mask = relabeled_cc == 1
# 4. 計算形狀統計信息
shape_stats = sitk.LabelShapeStatisticsImageFilter()
shape_stats.Execute(relabeled_cc)
# 檢查是否有找到任何組件
if shape_stats.GetNumberOfLabels() < 1:
return None
# 5. 獲取最大組件(標籤為 1的邊界框
# 格式通常為 [x_start, y_start, z_start, x_size, y_size, z_size]
bbox2 = shape_stats.GetBoundingBox(1)
roi = sitk.RegionOfInterest(resampled_sitk_img, bbox2[3:], bbox2[:3])
label2 = sitk.RegionOfInterest(resampled_sitk_lbl, bbox2[3:], bbox2[:3])
roi_path = os.path.join(output_base, f"{label_name}_roi.nii.gz")
sitk.WriteImage(roi, roi_path)
binary = sitk.BinaryThreshold(label2, lowerThreshold=n, upperThreshold=n, outsideValue=0, insideValue=1)
# label2 = sitk.RegionOfInterest(resampled_sitk_lbl, bbox2[3:], bbox2[:3])
# binary = sitk.BinaryThreshold(label2, lowerThreshold=n, upperThreshold=n, outsideValue=0, insideValue=1)
binary = sitk.RegionOfInterest(largest_mask, bbox2[3:], bbox2[:3])
binary_path = os.path.join(output_base, f"{label_name}_binary.nii.gz")
sitk.WriteImage(binary, binary_path)
roi_pixel_type = roi.GetPixelID()
binary_cast = sitk.Cast(binary, roi_pixel_type)
roi2 = roi * binary_cast
# roi_pixel_type = roi.GetPixelID()
# binary_cast = sitk.Cast(binary, roi_pixel_type)
# roi2 = roi * binary_cast
roi2 = sitk.Mask(roi, binary)
roi2_path = os.path.join(output_base, f"{label_name}_roi2.nii.gz")
sitk.WriteImage(roi2, roi2_path)
lsif = sitk.LabelStatisticsImageFilter()
label2_int = sitk.Cast(label2, sitk.sitkUInt16)
lsif.Execute(roi2, label2_int)
labels_in_roi = lsif.GetLabels()
if n in labels_in_roi:
roi_hu = sitk.GetArrayFromImage(roi2)
threshold = np.percentile(roi_hu, 60)
else:
threshold = lsif.GetMedian(labels_in_roi[0])
# lsif = sitk.LabelStatisticsImageFilter()
# label2_int = sitk.Cast(label2, sitk.sitkUInt16)
# lsif.Execute(roi2, label2_int)
# labels_in_roi = lsif.GetLabels()
# if n in labels_in_roi:
# roi_hu = sitk.GetArrayFromImage(roi2)
# threshold = np.percentile(roi_hu, 60)
# else:
# threshold = lsif.GetMedian(labels_in_roi[0])
stats = sitk.LabelStatisticsImageFilter()
stats.UseHistogramsOn() # Required for median calculation
stats.Execute(roi, binary)
# Get the median for the region where mask == 1
threshold = stats.GetMedian(1)
cortical = sitk.BinaryThreshold(roi2, lowerThreshold=threshold, upperThreshold=10000, outsideValue=0, insideValue=1)
cortical_path = os.path.join(output_base, f"{label_name}_cortical.nii.gz")

1625
progress.json Normal file

File diff suppressed because it is too large Load diff

View file

@ -7,10 +7,37 @@ import csv
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values
from core.intersection import center_line_intersections_torch
from core.scoring import cl_score_torch, compute_overlap_ratio_from_cylinder_mask
from core.scoring import cl_score_torch, compute_overlap_ratio_from_cylinder_mask, cl_score_torch_xfr
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
from utils.helpers import save_with_unique_name
def set_axes_equal_3d(ax):
"""
Make axes of 3D plot have equal scale so that spheres appear as spheres,
cubes as cubes, etc.
"""
x_limits = ax.get_xlim3d()
y_limits = ax.get_ylim3d()
z_limits = ax.get_zlim3d()
x_range = abs(x_limits[1] - x_limits[0])
x_middle = np.mean(x_limits)
y_range = abs(y_limits[1] - y_limits[0])
y_middle = np.mean(y_limits)
z_range = abs(z_limits[1] - z_limits[0])
z_middle = np.mean(z_limits)
plot_radius = 0.5*max([x_range, y_range, z_range])
ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius])
ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius])
ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius])
try:
ax.set_box_aspect([1, 1, 1])
except AttributeError:
pass
def res_plt_2_torch(
spine_tensor: torch.Tensor,
cortical_tensor: torch.Tensor,
@ -114,7 +141,8 @@ def res_plt_2_torch(
spacing,
device
)
loss_r = cl_score_torch(cortical_tensor, spine_tensor, cyl_r, cyl_ro, intersections_r)
# loss_r = cl_score_torch(cortical_tensor, spine_tensor, cyl_r, cyl_ro, intersections_r)
loss_r = cl_score_torch_xfr(cortical_tensor, spine_tensor, cyl_r, cyl_ro, intersections_r)
azi = azimuth_rotation(image2_path)
res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
@ -150,6 +178,7 @@ def res_plt_2_torch(
ax1.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o')
ax1.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine')
ax1.set_xlabel('X-axis'); ax1.set_ylabel('Y-axis'); ax1.set_zlabel('Z-axis')
set_axes_equal_3d(ax1)
ax2 = fig.add_subplot(222, projection='3d')
ax2.view_init(elev=90, azim=-90, roll=0)
@ -161,6 +190,7 @@ def res_plt_2_torch(
ax2.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o')
ax2.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine')
ax2.set_xlabel('X-axis'); ax2.set_ylabel('Y-axis'); ax2.set_zlabel('Z-axis')
set_axes_equal_3d(ax2)
ax2.legend()
ax3 = fig.add_subplot(223, projection='3d')
@ -173,6 +203,7 @@ def res_plt_2_torch(
ax3.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o')
ax3.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine')
ax3.set_xlabel('X-axis'); ax3.set_ylabel('Y-axis'); ax3.set_zlabel('Z-axis')
set_axes_equal_3d(ax3)
ax4 = fig.add_subplot(224, projection='3d')
ax4.view_init(elev=0, azim=0, roll=0)
@ -184,6 +215,7 @@ def res_plt_2_torch(
ax4.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o')
ax4.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine')
ax4.set_xlabel('X-axis'); ax4.set_ylabel('Y-axis'); ax4.set_zlabel('Z-axis')
set_axes_equal_3d(ax4)
cyl_points_l = torch.sum(cyl_l).item()
cyl_points_r = torch.sum(cyl_r).item()
@ -217,7 +249,7 @@ def res_plt_2_torch(
headers = [
'Label', 'Side', 'Diameter', 'Length', 'Swarm_Size', 'Max_Iter',
'Position_XYZ', 'Raw_Azimuth', 'Azimuth_Diff', 'Raw_Altitude', 'Altitude_Diff',
'Intersections', 'Best_Loss', 'Overlap_Cortical', 'Overlap_Bone',
'Intersections', 'Best_Loss', 'cyl_points', 'Overlap_Cortical', 'Overlap_Bone',
'Cortical_Bone_Ratio', 'User_Azimuth', 'User_Altitude', 'Total_Time'
]
@ -237,13 +269,15 @@ def res_plt_2_torch(
length_l,
swarm_size,
max_iter,
f"({best_position_l[0]:.2f}, {best_position_l[1]:.2f}, {best_position_l[2]:.2f})",
# f"({best_position_l[0]:.2f}, {best_position_l[1]:.2f}, {best_position_l[2]:.2f})",
f"({best_position_l[2]:.2f}, {best_position_l[1]:.2f}, {best_position_l[0]:.2f})",
f"{best_position_l[3]:.2f}",
f"{best_position_l[3]-azi:.2f}",
f"{best_position_l[4]:.2f}",
f"{best_position_l[4]-alt:.2f}",
intersections_l,
f"{loss_l:.2f}",
cyl_points_l,
f"{overlap_cortical_l:.2f}",
f"{overlap_vertebral_l:.2f}",
f"{(overlap_cortical_l/overlap_vertebral_l if overlap_vertebral_l!=0 else 0):.2f}",
@ -260,13 +294,15 @@ def res_plt_2_torch(
length_r,
swarm_size,
max_iter,
f"({best_position_r[0]:.2f}, {best_position_r[1]:.2f}, {best_position_r[2]:.2f})",
# f"({best_position_r[0]:.2f}, {best_position_r[1]:.2f}, {best_position_r[2]:.2f})",
f"({best_position_r[2]:.2f}, {best_position_r[1]:.2f}, {best_position_r[0]:.2f})",
f"{best_position_r[3]:.2f}",
f"{best_position_r[3]-azi:.2f}",
f"{best_position_r[4]:.2f}",
f"{best_position_r[4]-alt:.2f}",
intersections_r,
f"{loss_r:.2f}",
cyl_points_r,
f"{overlap_cortical_r:.2f}",
f"{overlap_vertebral_r:.2f}",
f"{(overlap_cortical_r/overlap_vertebral_r if overlap_vertebral_r!=0 else 0):.2f}",
@ -289,14 +325,16 @@ def res_plt_2_torch(
)
fig.text(
0.5, 0.03,
f'Left : Position = ({best_position_l[0]:.2f}, {best_position_l[1]:.2f}, {best_position_l[2]:.2f}), '
# f'Left : Position = ({best_position_l[0]:.2f}, {best_position_l[1]:.2f}, {best_position_l[2]:.2f}), '
f'Left : Position = ({best_position_l[2]:.2f}, {best_position_l[1]:.2f}, {best_position_l[0]:.2f}), '
f'Azimuth = {user_azimuth_l:.2f}, Altitude = {user_altitude_l:.2f}, '
f'Intersection = {intersections_l}, Score = {overlap_cortical_l:.2f} / {overlap_vertebral_l:.2f} / {cb_ratio_l:.2f}',
ha='center', fontsize=9
)
fig.text(
0.5, 0.01,
f'Right : Position = ({best_position_r[0]:.2f}, {best_position_r[1]:.2f}, {best_position_r[2]:.2f}), '
# f'Right : Position = ({best_position_r[0]:.2f}, {best_position_r[1]:.2f}, {best_position_r[2]:.2f}), '
f'Right : Position = ({best_position_r[2]:.2f}, {best_position_r[1]:.2f}, {best_position_r[0]:.2f}), '
f'Azimuth = {user_azimuth_r:.2f}, Altitude = {user_altitude_r:.2f}, '
f'Intersection = {intersections_r}, Score = {overlap_cortical_r:.2f} / {overlap_vertebral_r:.2f} / {cb_ratio_r:.2f}',
ha='center', fontsize=9

143
xfr_debug.py Normal file
View file

@ -0,0 +1,143 @@
import os
import SimpleITK as sitk
import torch
# from config.device import get_device
from core.cylinder import create_coordinate_grid
from core.objective import set_global_context
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch, run_pso_torch_xfr
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
standardized_dir = '/mnt/1248/open/cyrou/CBT/Seg/Resample/standardized-xfr/'
azimuth_rotation_dir = '/mnt/1248/open/cyrou/azimuth_rotation'
tilt_contour_dir = '/mnt/1248/open/cyrou/tilt_contour'
Output_dir = '/mnt/1248/open/cyrou/Output'
def get_device(gpu_id=None):
if torch.cuda.is_available():
if gpu_id is None:
max_free = -1
gpu_id = 0
for i in range(torch.cuda.device_count()):
free_mem, _ = torch.cuda.mem_get_info(i)
# print(f'GPU {i}: {torch.cuda.get_device_name(i)} {free_mem}')
if free_mem > max_free:
max_free = free_mem
gpu_id = i
device = torch.device(f"cuda:{gpu_id}")
print(f"Using GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")
else:
device = torch.device("cpu")
print("CUDA not available, using CPU")
return device
def debug_orientation(volume_id, level):
volume_dir = os.path.join(standardized_dir, volume_id)
cortical_path = os.path.join(volume_dir, f'{level}_cortical.nii.gz')
binary_path = os.path.join(volume_dir, f'{level}_binary.nii.gz')
roi_path = os.path.join(volume_dir, f'{level}_roi2.nii.gz')
# azi = azimuth_rotation(binary_path)
# res = analyze_vertebral_tilt_contour(binary_path, edge_type='superior', show_plot=False, debug=False)
azi = azimuth_rotation(binary_path, show_plt=True, save_plt=True, output_path=f'{azimuth_rotation_dir}/{level}_{volume_id}.png')
res = analyze_vertebral_tilt_contour(binary_path, edge_type='superior', show_plot=True, debug=False, save_plt=True, output_path=f'{tilt_contour_dir}/{level}_{volume_id}.png')
alt = res['superior']['tilt_angle_deg']
print(binary_path)
# print(f'Azimuth: {azi}, Alt: {alt}')
print(f'Alt: {alt}')
def debug_pso(volume_id, level):
# ====== PSO ======
swarm_size = 100
max_iter = 100
# ====== DEVICE ======
device = get_device()
# ====== OTHER ======
spacing = [0.5, 0.5, 0.5]
CBT = True
volume_dir = os.path.join(standardized_dir, volume_id)
cortical_path = os.path.join(volume_dir, f'{level}_cortical.nii.gz')
binary_path = os.path.join(volume_dir, f'{level}_binary.nii.gz')
roi_path = os.path.join(volume_dir, f'{level}_roi2.nii.gz')
cortical_image = sitk.ReadImage(cortical_path)
binary_image = sitk.ReadImage(binary_path)
roi_image = sitk.ReadImage(roi_path)
cortical_array = sitk.GetArrayFromImage(cortical_image)
binary_array = sitk.GetArrayFromImage(binary_image)
roi_array = sitk.GetArrayFromImage(roi_image)
image_shape = binary_array.shape
cortical_tensor = torch.tensor(cortical_array, device=device)
binary_tensor = torch.tensor(binary_array, device=device)
grid = create_coordinate_grid(image_shape, device)
set_global_context(
cortical=cortical_tensor,
spine=binary_tensor,
shape=image_shape,
spacing_=spacing,
device_=device,
grid_=grid,
use_tip_penalty=False
)
best_l, loss_l, best_r, loss_r, total_time = run_pso_torch_xfr(
label_str=level,
image1_path=cortical_path,
image2_path=binary_path,
image3_path=roi_path,
folder=Output_dir,
swarm_size=swarm_size,
max_iter=max_iter,
spacing=spacing,
CBT=CBT,
device=device,
optimize_size=True,
grid=grid,
)
# exit()
def main():
# level = 'L1'
# volume_id = '1.3.6.1.4.1.9328.50.4.0001'
# volume_id = '1.3.6.1.4.1.9328.50.4.0003'
# # volume_id = '1.3.6.1.4.1.9328.50.4.0121'
# process_volume(volume_id, level)
# exit()
for volume_id in (
'1.3.6.1.4.1.9328.50.4.0001',
# '1.3.6.1.4.1.9328.50.4.0002',
# '1.3.6.1.4.1.9328.50.4.0003',
# '1.3.6.1.4.1.9328.50.4.0004',
# '1.3.6.1.4.1.9328.50.4.0005',
# '1.3.6.1.4.1.9328.50.4.0006',
):
# for volume_id in sorted(os.listdir(standardized_dir)):
# debug_orientation(volume_id, level)
for level in ('L1', 'L2', 'L3', 'L4', 'L5'):
# debug_orientation(volume_id, level)
debug_pso(volume_id, level)
exit()
if __name__ == '__main__':
main()

24
xfr_preprocess.py Normal file
View file

@ -0,0 +1,24 @@
import os
from imaging.preprocessing import process_dataset
data_root = '/mnt/1220/Public/dataset/Spine/CTSpine1K/data/'
label_root = '/mnt/1220/Public/dataset/Spine/CTSpine1K/label/'
output_dir = '/mnt/1248/open2/cyrou/CBT/Seg/Resample/standardized-xfr/'
label_map = {
'colon': 'conlon',
'COVID-19': 'COVID-19',
'HNSCC-3DCT-RT_neck': 'HNSCC-3DCT-RT_neck',
'liver': 'Liver',
}
def main():
for key, value in label_map.items():
data_dir = os.path.join(data_root, key)
label_dir = os.path.join(label_root, value)
process_dataset(data_dir, label_dir, output_dir)
if __name__ == '__main__':
main()