1. correct the bounding box and cortical mask
2. make the plot isometric 3. now it should work after tuning the parameters
This commit is contained in:
parent
f84d8963da
commit
b76f0708f3
11 changed files with 2524 additions and 54 deletions
|
|
@ -12,7 +12,17 @@ LABEL_MAP = {
|
||||||
15: "T8", 16: "T9", 17: "T10", 18: "T11", 19: "T12",
|
15: "T8", 16: "T9", 17: "T10", 18: "T11", 19: "T12",
|
||||||
20: "L1", 21: "L2", 22: "L3", 23: "L4", 24: "L5"
|
20: "L1", 21: "L2", 22: "L3", 23: "L4", 24: "L5"
|
||||||
}
|
}
|
||||||
ALLOWED_DIAMETERS = [3.5, 4.0, 4.5, 5.0]
|
ALLOWED_DIAMETERS = [
|
||||||
ALLOWED_LENGTHS = [35, 40, 45, 50]
|
3.5,
|
||||||
|
4.0,
|
||||||
|
4.5,
|
||||||
|
5.0,
|
||||||
|
]
|
||||||
|
ALLOWED_LENGTHS = [
|
||||||
|
35,
|
||||||
|
40,
|
||||||
|
45,
|
||||||
|
50,
|
||||||
|
]
|
||||||
OVERLAP_THRESH = 0.50
|
OVERLAP_THRESH = 0.50
|
||||||
DEFAULT_SPACING = [0.5, 0.5, 0.5]
|
DEFAULT_SPACING = [0.5, 0.5, 0.5]
|
||||||
|
|
@ -38,6 +38,27 @@ def snap_to_discrete_values(diameter_raw, length_raw):
|
||||||
|
|
||||||
return diameter_discrete, length_discrete
|
return diameter_discrete, length_discrete
|
||||||
|
|
||||||
|
def round_down_to_discrete_values(diameter_raw, length_raw):
|
||||||
|
"""
|
||||||
|
將連續值映射到小於等於該值的最接近允許離散值 (向下取整)
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
diameter_raw: PSO 給的連續直徑值
|
||||||
|
length_raw: PSO 給的連續長度值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
diameter_discrete, length_discrete: 離散化後的值
|
||||||
|
"""
|
||||||
|
# 找小於等於 raw 的最接近 diameter,如果都大於 raw 則取最小值
|
||||||
|
valid_diameters = [d for d in ALLOWED_DIAMETERS if d <= diameter_raw]
|
||||||
|
diameter_discrete = max(valid_diameters) if valid_diameters else min(ALLOWED_DIAMETERS)
|
||||||
|
|
||||||
|
# 找小於等於 raw 的最接近 length,如果都大於 raw 則取最小值
|
||||||
|
valid_lengths = [l for l in ALLOWED_LENGTHS if l <= length_raw]
|
||||||
|
length_discrete = max(valid_lengths) if valid_lengths else min(ALLOWED_LENGTHS)
|
||||||
|
|
||||||
|
return diameter_discrete, length_discrete
|
||||||
|
|
||||||
|
|
||||||
def generate_cylinder_n_torch(
|
def generate_cylinder_n_torch(
|
||||||
diameter: float,
|
diameter: float,
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,13 @@
|
||||||
import torch
|
import random
|
||||||
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values, generate_cylinder_tip_torch
|
|
||||||
|
|
||||||
|
from scipy.ndimage import map_coordinates
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values, generate_cylinder_tip_torch
|
||||||
from core.intersection import center_line_intersections_torch
|
from core.intersection import center_line_intersections_torch
|
||||||
from core.scoring import cl_score_torch
|
from core.scoring import cl_score_torch, cl_score_torch_xfr
|
||||||
|
|
||||||
# Global variables (used in objective_function)
|
# Global variables (used in objective_function)
|
||||||
image1_array = None # cortical_nii.gz
|
image1_array = None # cortical_nii.gz
|
||||||
|
|
@ -102,7 +107,8 @@ def cylinder_circle_line_intersection_loss_deductions_torch(
|
||||||
image_shape, spacing, device, grid
|
image_shape, spacing, device, grid
|
||||||
)
|
)
|
||||||
|
|
||||||
loss_value = cl_score_torch(
|
# loss_value = cl_score_torch(
|
||||||
|
loss_value = cl_score_torch_xfr(
|
||||||
cortical_tensor, spine_tensor,
|
cortical_tensor, spine_tensor,
|
||||||
cyl_fwd, cyl_opp, intersections,
|
cyl_fwd, cyl_opp, intersections,
|
||||||
cylinder_tip_torch=cyl_tip
|
cylinder_tip_torch=cyl_tip
|
||||||
|
|
@ -110,6 +116,42 @@ def cylinder_circle_line_intersection_loss_deductions_torch(
|
||||||
|
|
||||||
return loss_value
|
return loss_value
|
||||||
|
|
||||||
|
def objective_function_xfr(params: list[float], y_indices) -> float:
|
||||||
|
"""
|
||||||
|
Wrapper for the PSO objective function, calling our Torch-based loss function.
|
||||||
|
Now params includes diameter and length at the end.
|
||||||
|
params = [position_z, position_y, position_x, azimuth, altitude, diameter_raw, length_raw]
|
||||||
|
"""
|
||||||
|
|
||||||
|
# position_params = params[:5] # [z, y, x, azimuth, altitude]
|
||||||
|
# diameter_raw = params[5]
|
||||||
|
# length_raw = params[6]
|
||||||
|
|
||||||
|
z, x, azimuth, altitude, diameter_raw, length_raw = params
|
||||||
|
y = y_indices[round(z), round(x)] #+ random.uniform(-0.5, 0.5)
|
||||||
|
|
||||||
|
# coords = np.array([[z], [x]])
|
||||||
|
# result = map_coordinates(y_indices, coords, order=1)
|
||||||
|
# y= result[0]
|
||||||
|
|
||||||
|
position_params = [z, y, x, azimuth, altitude]
|
||||||
|
|
||||||
|
# 將連續值轉換為離散值
|
||||||
|
# diameter_discrete, length_discrete = snap_to_discrete_values(diameter_raw, length_raw)
|
||||||
|
diameter_discrete, length_discrete = diameter_raw, length_raw
|
||||||
|
|
||||||
|
loss = cylinder_circle_line_intersection_loss_deductions_torch(
|
||||||
|
diameter_discrete,
|
||||||
|
length_discrete,
|
||||||
|
position_params,
|
||||||
|
image2_shape,
|
||||||
|
cortical_tensor,
|
||||||
|
spine_tensor,
|
||||||
|
spacing,
|
||||||
|
device
|
||||||
|
)
|
||||||
|
return loss
|
||||||
|
|
||||||
def objective_function(params: list[float]) -> float:
|
def objective_function(params: list[float]) -> float:
|
||||||
"""
|
"""
|
||||||
Wrapper for the PSO objective function, calling our Torch-based loss function.
|
Wrapper for the PSO objective function, calling our Torch-based loss function.
|
||||||
|
|
@ -122,6 +164,7 @@ def objective_function(params: list[float]) -> float:
|
||||||
|
|
||||||
# 將連續值轉換為離散值
|
# 將連續值轉換為離散值
|
||||||
diameter_discrete, length_discrete = snap_to_discrete_values(diameter_raw, length_raw)
|
diameter_discrete, length_discrete = snap_to_discrete_values(diameter_raw, length_raw)
|
||||||
|
# diameter_discrete, length_discrete = diameter_raw, length_raw
|
||||||
|
|
||||||
loss = cylinder_circle_line_intersection_loss_deductions_torch(
|
loss = cylinder_circle_line_intersection_loss_deductions_torch(
|
||||||
diameter_discrete,
|
diameter_discrete,
|
||||||
|
|
|
||||||
|
|
@ -4,15 +4,37 @@ import SimpleITK as sitk
|
||||||
import torch
|
import torch
|
||||||
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
|
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
|
||||||
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
|
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
|
||||||
from core.objective import objective_function
|
from core.objective import objective_function, objective_function_xfr
|
||||||
from pyswarm import pso
|
from pyswarm import pso
|
||||||
import core.objective # <--- 加入這行,讓我們可以直接操作 objective 模組
|
import core.objective # <--- 加入這行,讓我們可以直接操作 objective 模組
|
||||||
from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid
|
from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid, round_down_to_discrete_values
|
||||||
from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok
|
from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok
|
||||||
from config.constant import OVERLAP_THRESH
|
from config.constant import OVERLAP_THRESH
|
||||||
from visualization.res_plot_3d import res_plt_2_torch
|
from visualization.res_plot_3d import res_plt_2_torch
|
||||||
|
|
||||||
def run_pso_torch(
|
|
||||||
|
def get_first_nonzero_y(arr):
|
||||||
|
# 1. Create a boolean mask where elements are non-zero
|
||||||
|
mask = arr != 0
|
||||||
|
|
||||||
|
# 2. Find the index of the first True value along the Y axis (axis=1)
|
||||||
|
y_indices = np.argmax(mask, axis=1)
|
||||||
|
|
||||||
|
# 3. Edge Case Handling: If a whole (x, z) column is zero, argmax returns 0.
|
||||||
|
# We need to distinguish this from an actual non-zero value at index 0.
|
||||||
|
has_nonzero = np.any(mask, axis=1)
|
||||||
|
|
||||||
|
# 4. Replace indices where there were no non-zeros with a sentinel value (e.g., -1)
|
||||||
|
y_indices = np.where(has_nonzero, y_indices, -1)
|
||||||
|
|
||||||
|
y_indices = np.where(y_indices > arr.shape[1] * .4, -1, y_indices)
|
||||||
|
|
||||||
|
return y_indices.astype(np.float32)
|
||||||
|
|
||||||
|
def constraint_y(x, y_indices):
|
||||||
|
return y_indices[round(x[0]), round(x[1])]
|
||||||
|
|
||||||
|
def run_pso_torch_xfr(
|
||||||
label_str: str,
|
label_str: str,
|
||||||
image1_path: str,
|
image1_path: str,
|
||||||
image2_path: str,
|
image2_path: str,
|
||||||
|
|
@ -24,7 +46,9 @@ def run_pso_torch(
|
||||||
CBT: bool,
|
CBT: bool,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
optimize_size: bool = True,
|
optimize_size: bool = True,
|
||||||
grid=None
|
grid=None,
|
||||||
|
debug=False,
|
||||||
|
omega = 0.9,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Main function to run PSO.
|
Main function to run PSO.
|
||||||
|
|
@ -86,15 +110,51 @@ def run_pso_torch(
|
||||||
res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
|
res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
|
||||||
alt = res['superior']['tilt_angle_deg']
|
alt = res['superior']['tilt_angle_deg']
|
||||||
|
|
||||||
|
y_indices = get_first_nonzero_y(image2_array)
|
||||||
|
|
||||||
|
# flat_min_index = np.argmin(y_indices)
|
||||||
|
# z_border, x_border = np.unravel_index(flat_min_index, y_indices.shape)
|
||||||
|
|
||||||
|
x_with_nonzero = np.where(np.any(image2_array[:,9,:] != 0, axis=0))[0]
|
||||||
|
x1 = x_with_nonzero[0]
|
||||||
|
x2 = x_with_nonzero[-1]
|
||||||
|
|
||||||
|
x_mid = (x1+x2)/2
|
||||||
|
x1 = x_mid-9
|
||||||
|
x2 = x_mid+9
|
||||||
|
|
||||||
|
z_sum = np.sum(image2_array, axis=(1, 2))
|
||||||
|
z_with_nonzero = np.where(z_sum > 0)[0]
|
||||||
|
z1 = z_with_nonzero[0]
|
||||||
|
z2 = z_with_nonzero[-1]
|
||||||
|
|
||||||
|
# print(x1,x2)
|
||||||
|
# exit()
|
||||||
|
|
||||||
|
# print(x_border, z_border)
|
||||||
|
# exit()
|
||||||
|
# import sys
|
||||||
|
# import numpy
|
||||||
|
# numpy.set_printoptions(threshold=sys.maxsize)
|
||||||
|
# print(y_indices)
|
||||||
|
# exit()
|
||||||
|
|
||||||
# 設定基本的 bounds
|
# 設定基本的 bounds
|
||||||
if CBT == True:
|
if CBT == True:
|
||||||
z_bounds = (0, image_shape[0] - 1)
|
z_bounds = (0, image_shape[0] / 2)
|
||||||
y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
|
# z_bounds = (z1, (z1+z2)/2)
|
||||||
x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
|
x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
|
||||||
x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
|
x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
|
||||||
azimuth_bounds_l = ((95-azi), (145-azi))
|
x_bounds_right = (x2, image_shape[2]*.9)
|
||||||
azimuth_bounds_r = ((50-azi), (85-azi))
|
x_bounds_left = (image_shape[2]*.1, x1)
|
||||||
altitude_bounds = ((60-alt), (75-alt))
|
|
||||||
|
# azimuth_bounds_l = ((95-azi), (145-azi))
|
||||||
|
# azimuth_bounds_r = ((50-azi), (85-azi))
|
||||||
|
# altitude_bounds = ((60-alt), (75-alt))
|
||||||
|
azimuth_bounds_l = ((95-azi), (105-azi))
|
||||||
|
azimuth_bounds_r = ((75-azi), (85-azi))
|
||||||
|
altitude_bounds = ((55-alt), (70-alt))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
z_bounds = (0, image_shape[0] - 1)
|
z_bounds = (0, image_shape[0] - 1)
|
||||||
y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
|
y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
|
||||||
|
|
@ -112,7 +172,8 @@ def run_pso_torch(
|
||||||
side: "L" or "R" 只是方便 debug
|
side: "L" or "R" 只是方便 debug
|
||||||
"""
|
"""
|
||||||
if optimize_size:
|
if optimize_size:
|
||||||
d, L = snap_to_discrete_values(pos[5], pos[6])
|
# d, L = snap_to_discrete_values(pos[5], pos[6])
|
||||||
|
d, L = round_down_to_discrete_values(pos[5], pos[6])
|
||||||
params_5 = pos[:5]
|
params_5 = pos[:5]
|
||||||
else:
|
else:
|
||||||
d, L = diameter, length
|
d, L = diameter, length
|
||||||
|
|
@ -133,18 +194,18 @@ def run_pso_torch(
|
||||||
print("=== 最佳化模式:最佳化位置、角度、直徑和長度 ===")
|
print("=== 最佳化模式:最佳化位置、角度、直徑和長度 ===")
|
||||||
|
|
||||||
# 設定 diameter 和 length 的 bounds(連續範圍)
|
# 設定 diameter 和 length 的 bounds(連續範圍)
|
||||||
diameter_bounds = (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS))
|
diameter_bounds = (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)*1.01)
|
||||||
length_bounds = (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS))
|
length_bounds = (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS)*1.01)
|
||||||
|
|
||||||
# bounds 現在有 7 個參數
|
# bounds 現在有 7 個參數
|
||||||
lb_l = [z_bounds[0], y_bounds[0], x_bounds_left[0], azimuth_bounds_l[0],
|
lb_l = [z_bounds[0], x_bounds_left[0], azimuth_bounds_l[0],
|
||||||
altitude_bounds[0], diameter_bounds[0], length_bounds[0]]
|
altitude_bounds[0], diameter_bounds[0], length_bounds[0]]
|
||||||
ub_l = [z_bounds[1], y_bounds[1], x_bounds_left[1], azimuth_bounds_l[1],
|
ub_l = [z_bounds[1], x_bounds_left[1], azimuth_bounds_l[1],
|
||||||
altitude_bounds[1], diameter_bounds[1], length_bounds[1]]
|
altitude_bounds[1], diameter_bounds[1], length_bounds[1]]
|
||||||
|
|
||||||
lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0],
|
lb_r = [z_bounds[0], x_bounds_right[0], azimuth_bounds_r[0],
|
||||||
altitude_bounds[0], diameter_bounds[0], length_bounds[0]]
|
altitude_bounds[0], diameter_bounds[0], length_bounds[0]]
|
||||||
ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1],
|
ub_r = [z_bounds[1], x_bounds_right[1], azimuth_bounds_r[1],
|
||||||
altitude_bounds[1], diameter_bounds[1], length_bounds[1]]
|
altitude_bounds[1], diameter_bounds[1], length_bounds[1]]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
@ -160,6 +221,12 @@ def run_pso_torch(
|
||||||
lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0], altitude_bounds[0]]
|
lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0], altitude_bounds[0]]
|
||||||
ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1], altitude_bounds[1]]
|
ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1], altitude_bounds[1]]
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(lb_l)
|
||||||
|
print(ub_l)
|
||||||
|
print(lb_r)
|
||||||
|
print(ub_r)
|
||||||
|
|
||||||
best_loss_l = float('inf')
|
best_loss_l = float('inf')
|
||||||
best_loss_r = float('inf')
|
best_loss_r = float('inf')
|
||||||
best_position_l = None
|
best_position_l = None
|
||||||
|
|
@ -167,7 +234,18 @@ def run_pso_torch(
|
||||||
|
|
||||||
# Left side optimization
|
# Left side optimization
|
||||||
print("\n=== 左側 ===")
|
print("\n=== 左側 ===")
|
||||||
position_l, loss_l = pso(objective_function, lb_l, ub_l, swarmsize=swarm_size, maxiter=max_iter)
|
kwargs = {'y_indices': y_indices}
|
||||||
|
position_l, loss_l = pso(objective_function_xfr, lb_l, ub_l,
|
||||||
|
|
||||||
|
# ieqcons=[constraint_y],
|
||||||
|
kwargs=kwargs,
|
||||||
|
swarmsize=swarm_size,
|
||||||
|
omega = omega,
|
||||||
|
maxiter=max_iter, debug=debug)
|
||||||
|
|
||||||
|
z, x, azimuth, altitude, diameter, length = position_l
|
||||||
|
y = y_indices[round(z), round(x)]
|
||||||
|
position_l = z, y, x, azimuth, altitude, diameter, length
|
||||||
|
|
||||||
overlap_l, diameter_l, length_l = eval_overlap_from_position(
|
overlap_l, diameter_l, length_l = eval_overlap_from_position(
|
||||||
position_l, "L", optimize_size, spine_tensor, image_shape, spacing
|
position_l, "L", optimize_size, spine_tensor, image_shape, spacing
|
||||||
|
|
@ -178,6 +256,7 @@ def run_pso_torch(
|
||||||
print(f"[LEFT] Position: {position_l[:5]}")
|
print(f"[LEFT] Position: {position_l[:5]}")
|
||||||
print(f"[LEFT] Diameter: {diameter_l} mm (raw: {position_l[5]:.2f})")
|
print(f"[LEFT] Diameter: {diameter_l} mm (raw: {position_l[5]:.2f})")
|
||||||
print(f"[LEFT] Length: {length_l} mm (raw: {position_l[6]:.2f})")
|
print(f"[LEFT] Length: {length_l} mm (raw: {position_l[6]:.2f})")
|
||||||
|
print(f"[LEFT] Loss: {loss_l}\n")
|
||||||
best_position_l = list(position_l[:5]) + [diameter_l, length_l]
|
best_position_l = list(position_l[:5]) + [diameter_l, length_l]
|
||||||
else:
|
else:
|
||||||
print(f"[LEFT] Position: {position_l}")
|
print(f"[LEFT] Position: {position_l}")
|
||||||
|
|
@ -221,14 +300,25 @@ def run_pso_torch(
|
||||||
|
|
||||||
# Right side optimization
|
# Right side optimization
|
||||||
print("\n=== 右側 ===")
|
print("\n=== 右側 ===")
|
||||||
position_r, loss_r = pso(objective_function, lb_r, ub_r, swarmsize=swarm_size, maxiter=max_iter)
|
position_r, loss_r = pso(objective_function_xfr, lb_r, ub_r,
|
||||||
|
# ieqcons=[constraint_y],
|
||||||
|
kwargs=kwargs,
|
||||||
|
swarmsize=swarm_size,
|
||||||
|
omega = omega,
|
||||||
|
maxiter=max_iter, debug=debug)
|
||||||
|
|
||||||
|
z, x, azimuth, altitude, diameter, length = position_r
|
||||||
|
y = y_indices[round(z), round(x)]
|
||||||
|
position_r = z, y, x, azimuth, altitude, diameter, length
|
||||||
|
|
||||||
overlap_r, diameter_r, length_r = eval_overlap_from_position(
|
overlap_r, diameter_r, length_r = eval_overlap_from_position(
|
||||||
position_r, "R", optimize_size, spine_tensor, image_shape, spacing
|
position_r, "R", optimize_size, spine_tensor, image_shape, spacing
|
||||||
)
|
)
|
||||||
print(f"[RIGHT] overlap: {overlap_r*100:.1f}%")
|
print(f"[RIGHT] overlap: {overlap_r*100:.1f}%")
|
||||||
|
|
||||||
if optimize_size:
|
if optimize_size:
|
||||||
diameter_r, length_r = snap_to_discrete_values(position_r[5], position_r[6])
|
# diameter_r, length_r = snap_to_discrete_values(position_r[5], position_r[6])
|
||||||
|
diameter_r, length_r = round_down_to_discrete_values(position_r[5], position_r[6])
|
||||||
print(f"[RIGHT] Position: {position_r[:5]}")
|
print(f"[RIGHT] Position: {position_r[:5]}")
|
||||||
print(f"[RIGHT] Diameter: {diameter_r} mm (raw: {position_r[5]:.2f})")
|
print(f"[RIGHT] Diameter: {diameter_r} mm (raw: {position_r[5]:.2f})")
|
||||||
print(f"[RIGHT] Length: {length_r} mm (raw: {position_r[6]:.2f})")
|
print(f"[RIGHT] Length: {length_r} mm (raw: {position_r[6]:.2f})")
|
||||||
|
|
@ -300,7 +390,332 @@ def run_pso_torch(
|
||||||
cortical_tensor,
|
cortical_tensor,
|
||||||
image_shape,
|
image_shape,
|
||||||
image2_path,
|
image2_path,
|
||||||
'Output',
|
folder,
|
||||||
|
label_str,
|
||||||
|
final_diameter_l,
|
||||||
|
final_length_l,
|
||||||
|
final_diameter_r,
|
||||||
|
final_length_r,
|
||||||
|
best_position_l,
|
||||||
|
best_position_r,
|
||||||
|
swarm_size,
|
||||||
|
max_iter,
|
||||||
|
total_time,
|
||||||
|
spacing,
|
||||||
|
CBT,
|
||||||
|
device,
|
||||||
|
grid)
|
||||||
|
|
||||||
|
return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time
|
||||||
|
|
||||||
|
|
||||||
|
def run_pso_torch(
|
||||||
|
label_str: str,
|
||||||
|
image1_path: str,
|
||||||
|
image2_path: str,
|
||||||
|
image3_path: str,
|
||||||
|
folder: str,
|
||||||
|
swarm_size: int,
|
||||||
|
max_iter: int,
|
||||||
|
spacing: list,
|
||||||
|
CBT: bool,
|
||||||
|
device: torch.device,
|
||||||
|
optimize_size: bool = True,
|
||||||
|
grid=None,
|
||||||
|
debug=True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Main function to run PSO.
|
||||||
|
如果 optimize_size=True,diameter 和 length 也會被最佳化
|
||||||
|
如果 optimize_size=False,使用預設值(向後兼容)
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Use global references
|
||||||
|
global image1_array, image2_array, image2_shape, image3_array
|
||||||
|
global diameter, length # 這些現在只用於非最佳化模式
|
||||||
|
global spine_tensor, cortical_tensor, spine_roi_tensor
|
||||||
|
|
||||||
|
# Load images
|
||||||
|
image1 = sitk.ReadImage(image1_path)
|
||||||
|
image2 = sitk.ReadImage(image2_path)
|
||||||
|
image3 = sitk.ReadImage(image3_path)
|
||||||
|
image1_array = sitk.GetArrayFromImage(image1)
|
||||||
|
image2_array = sitk.GetArrayFromImage(image2)
|
||||||
|
image3_array = sitk.GetArrayFromImage(image3)
|
||||||
|
image2_shape = image2_array.shape
|
||||||
|
image_shape = image2_shape
|
||||||
|
|
||||||
|
# Move arrays to torch
|
||||||
|
cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
|
||||||
|
spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
|
||||||
|
spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)
|
||||||
|
|
||||||
|
# ================= [跨檔案注入變數:終極防呆版] =================
|
||||||
|
import core.objective
|
||||||
|
|
||||||
|
# 1. 注入 Tensors
|
||||||
|
core.objective.cortical_tensor = cortical_tensor
|
||||||
|
core.objective.spine_tensor = spine_tensor
|
||||||
|
core.objective.spine_roi_tensor = spine_roi_tensor
|
||||||
|
|
||||||
|
# 2. 注入 Arrays (以防 objective 裡面偷偷用到 Numpy 陣列)
|
||||||
|
core.objective.image1_array = image1_array
|
||||||
|
core.objective.image2_array = image2_array
|
||||||
|
core.objective.image3_array = image3_array
|
||||||
|
|
||||||
|
# 3. 注入 Shapes (這就是導致這次 NoneType 報錯的真兇!)
|
||||||
|
core.objective.image2_shape = image2_shape # <--- 解除警報的最關鍵一行
|
||||||
|
core.objective.image_shape = image_shape
|
||||||
|
core.objective.shape = image_shape
|
||||||
|
|
||||||
|
# 4. 注入環境變數
|
||||||
|
core.objective.spacing = spacing
|
||||||
|
core.objective.device = device
|
||||||
|
core.objective.grid = grid
|
||||||
|
|
||||||
|
# 5. 注入尺寸參數 (兼容固定尺寸模式)
|
||||||
|
if not optimize_size:
|
||||||
|
core.objective.diameter = diameter
|
||||||
|
core.objective.length = length
|
||||||
|
# ==============================================================
|
||||||
|
|
||||||
|
azi = azimuth_rotation(image2_path)
|
||||||
|
res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
|
||||||
|
alt = res['superior']['tilt_angle_deg']
|
||||||
|
|
||||||
|
# 設定基本的 bounds
|
||||||
|
if CBT == True:
|
||||||
|
z_bounds = (0, image_shape[0] - 1)
|
||||||
|
y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
|
||||||
|
x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
|
||||||
|
x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
|
||||||
|
azimuth_bounds_l = ((95-azi), (145-azi))
|
||||||
|
azimuth_bounds_r = ((50-azi), (85-azi))
|
||||||
|
altitude_bounds = ((60-alt), (75-alt))
|
||||||
|
|
||||||
|
# xfr
|
||||||
|
# z_bounds = (0, image_shape[0] - 1)
|
||||||
|
y_bounds = (0, image_shape[1]/2)
|
||||||
|
# x_bounds_right = (image_shape[2]/2, image_shape[2] - 1)
|
||||||
|
# x_bounds_left = (0, image_shape[2]/2)
|
||||||
|
# azimuth_bounds_l = (90, 135)
|
||||||
|
# azimuth_bounds_r = (45, 90)
|
||||||
|
# altitude_bounds = (0, 90)
|
||||||
|
else:
|
||||||
|
z_bounds = (0, image_shape[0] - 1)
|
||||||
|
y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
|
||||||
|
x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
|
||||||
|
x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
|
||||||
|
azimuth_bounds_l = (60-azi, 90-azi)
|
||||||
|
azimuth_bounds_r = (90-azi, 120-azi)
|
||||||
|
altitude_bounds = (65-alt, 80-alt)
|
||||||
|
|
||||||
|
def eval_overlap_from_position(pos, side: str, optimize_size: bool,
|
||||||
|
spine_tensor: torch.Tensor,
|
||||||
|
image_shape, spacing):
|
||||||
|
"""
|
||||||
|
根據 PSO 給的 position 生成 cylinder mask,再算 overlap ratio
|
||||||
|
side: "L" or "R" 只是方便 debug
|
||||||
|
"""
|
||||||
|
if optimize_size:
|
||||||
|
# d, L = snap_to_discrete_values(pos[5], pos[6])
|
||||||
|
d, L = round_down_to_discrete_values(pos[5], pos[6])
|
||||||
|
params_5 = pos[:5]
|
||||||
|
else:
|
||||||
|
d, L = diameter, length
|
||||||
|
params_5 = pos
|
||||||
|
|
||||||
|
cyl_mask = generate_cylinder_n_torch(
|
||||||
|
d, L,
|
||||||
|
params_5[0], params_5[1], params_5[2],
|
||||||
|
params_5[3], params_5[4],
|
||||||
|
image_shape, spacing, device, grid
|
||||||
|
)
|
||||||
|
|
||||||
|
overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
|
||||||
|
return overlap, d, L
|
||||||
|
|
||||||
|
if optimize_size:
|
||||||
|
# 模式 1:優化 diameter 和 length
|
||||||
|
print("=== 最佳化模式:最佳化位置、角度、直徑和長度 ===")
|
||||||
|
|
||||||
|
# 設定 diameter 和 length 的 bounds(連續範圍)
|
||||||
|
diameter_bounds = (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS))
|
||||||
|
length_bounds = (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS))
|
||||||
|
|
||||||
|
# bounds 現在有 7 個參數
|
||||||
|
lb_l = [z_bounds[0], y_bounds[0], x_bounds_left[0], azimuth_bounds_l[0],
|
||||||
|
altitude_bounds[0], diameter_bounds[0], length_bounds[0]]
|
||||||
|
ub_l = [z_bounds[1], y_bounds[1], x_bounds_left[1], azimuth_bounds_l[1],
|
||||||
|
altitude_bounds[1], diameter_bounds[1], length_bounds[1]]
|
||||||
|
|
||||||
|
lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0],
|
||||||
|
altitude_bounds[0], diameter_bounds[0], length_bounds[0]]
|
||||||
|
ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1],
|
||||||
|
altitude_bounds[1], diameter_bounds[1], length_bounds[1]]
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 模式 2:固定 diameter 和 length(向後兼容)
|
||||||
|
print("=== 固定尺寸模式:最佳化位置和角度 ===")
|
||||||
|
# 使用預設值(需要在調用時提供)
|
||||||
|
diameter = 4.5 # 或從參數傳入
|
||||||
|
length = 45 # 或從參數傳入
|
||||||
|
|
||||||
|
lb_l = [z_bounds[0], y_bounds[0], x_bounds_left[0], azimuth_bounds_l[0], altitude_bounds[0]]
|
||||||
|
ub_l = [z_bounds[1], y_bounds[1], x_bounds_left[1], azimuth_bounds_l[1], altitude_bounds[1]]
|
||||||
|
|
||||||
|
lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0], altitude_bounds[0]]
|
||||||
|
ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1], altitude_bounds[1]]
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(lb_l)
|
||||||
|
print(ub_l)
|
||||||
|
print(lb_r)
|
||||||
|
print(ub_r)
|
||||||
|
|
||||||
|
best_loss_l = float('inf')
|
||||||
|
best_loss_r = float('inf')
|
||||||
|
best_position_l = None
|
||||||
|
best_position_r = None
|
||||||
|
|
||||||
|
# Left side optimization
|
||||||
|
print("\n=== 左側 ===")
|
||||||
|
position_l, loss_l = pso(objective_function, lb_l, ub_l, swarmsize=swarm_size, maxiter=max_iter, debug=debug)
|
||||||
|
|
||||||
|
overlap_l, diameter_l, length_l = eval_overlap_from_position(
|
||||||
|
position_l, "L", optimize_size, spine_tensor, image_shape, spacing
|
||||||
|
)
|
||||||
|
print(f"[LEFT] overlap: {overlap_l*100:.1f}%")
|
||||||
|
|
||||||
|
if optimize_size:
|
||||||
|
print(f"[LEFT] Position: {position_l[:5]}")
|
||||||
|
print(f"[LEFT] Diameter: {diameter_l} mm (raw: {position_l[5]:.2f})")
|
||||||
|
print(f"[LEFT] Length: {length_l} mm (raw: {position_l[6]:.2f})")
|
||||||
|
best_position_l = list(position_l[:5]) + [diameter_l, length_l]
|
||||||
|
else:
|
||||||
|
print(f"[LEFT] Position: {position_l}")
|
||||||
|
best_position_l = position_l
|
||||||
|
|
||||||
|
best_loss_l = loss_l
|
||||||
|
best_overlap_l = overlap_l # 新增
|
||||||
|
|
||||||
|
# max_retries = 0
|
||||||
|
# retries = 0
|
||||||
|
|
||||||
|
# 左側 retry:loss 要 <=0 且 overlap >= 0.5 才算過關
|
||||||
|
# while (best_loss_l > 0 or best_overlap_l < OVERLAP_THRESH) and retries < max_retries:
|
||||||
|
# position_l, loss_l = pso(objective_function, lb_l, ub_l, swarmsize=swarm_size, maxiter=max_iter)
|
||||||
|
# overlap_l, diameter_l, length_l = eval_overlap_from_position(
|
||||||
|
# position_l, "L", optimize_size, spine_tensor, image_shape, spacing
|
||||||
|
# )
|
||||||
|
|
||||||
|
# 只要找到更好的 loss(或你想用 loss+overlap 綜合排序也行)就更新 best
|
||||||
|
# 安全版本:優先選「合格解」;沒有合格解時才用 loss 最小的當備案
|
||||||
|
# candidate_pos = (list(position_l[:5]) + [diameter_l, length_l]) if optimize_size else position_l
|
||||||
|
|
||||||
|
# candidate_ok = is_solution_ok(loss_l, overlap_l, OVERLAP_THRESH)
|
||||||
|
# best_ok = is_solution_ok(best_loss_l, best_overlap_l, OVERLAP_THRESH)
|
||||||
|
|
||||||
|
# if candidate_ok and (not best_ok or loss_l < best_loss_l):
|
||||||
|
# best_position_l = candidate_pos
|
||||||
|
# best_loss_l = loss_l
|
||||||
|
# best_overlap_l = overlap_l
|
||||||
|
# print(f"[LEFT][retry {retries+1}] ✅ ok | loss={loss_l:.4f}, overlap={overlap_l*100:.1f}%")
|
||||||
|
# elif (not best_ok) and (loss_l < best_loss_l):
|
||||||
|
# best 還不合格時,先用更小 loss 的當暫存(至少越來越好)
|
||||||
|
# best_position_l = candidate_pos
|
||||||
|
# best_loss_l = loss_l
|
||||||
|
# best_overlap_l = overlap_l
|
||||||
|
# print(f"[LEFT][retry {retries+1}] ⚠️ not ok | loss improved={loss_l:.4f}, overlap={overlap_l*100:.1f}%")
|
||||||
|
# else:
|
||||||
|
# print(f"[LEFT][retry {retries+1}] ❌ no improve | loss={loss_l:.4f}, overlap={overlap_l*100:.1f}%")
|
||||||
|
|
||||||
|
# retries += 1
|
||||||
|
|
||||||
|
# Right side optimization
|
||||||
|
print("\n=== 右側 ===")
|
||||||
|
position_r, loss_r = pso(objective_function, lb_r, ub_r, swarmsize=swarm_size, maxiter=max_iter, debug=debug)
|
||||||
|
overlap_r, diameter_r, length_r = eval_overlap_from_position(
|
||||||
|
position_r, "R", optimize_size, spine_tensor, image_shape, spacing
|
||||||
|
)
|
||||||
|
print(f"[RIGHT] overlap: {overlap_r*100:.1f}%")
|
||||||
|
|
||||||
|
if optimize_size:
|
||||||
|
# diameter_r, length_r = snap_to_discrete_values(position_r[5], position_r[6])
|
||||||
|
diameter_r, length_r = round_down_to_discrete_values(position_r[5], position_r[6])
|
||||||
|
print(f"[RIGHT] Position: {position_r[:5]}")
|
||||||
|
print(f"[RIGHT] Diameter: {diameter_r} mm (raw: {position_r[5]:.2f})")
|
||||||
|
print(f"[RIGHT] Length: {length_r} mm (raw: {position_r[6]:.2f})")
|
||||||
|
print(f"[RIGHT] Loss: {loss_r}\n")
|
||||||
|
|
||||||
|
best_position_r = list(position_r[:5]) + [diameter_r, length_r]
|
||||||
|
else:
|
||||||
|
print(f"[RIGHT] Position: {position_r}")
|
||||||
|
print(f"[RIGHT] Loss: {loss_r}\n")
|
||||||
|
best_position_r = position_r
|
||||||
|
|
||||||
|
best_loss_r = loss_r
|
||||||
|
best_overlap_r = overlap_r
|
||||||
|
|
||||||
|
# 如果需要 retry(loss > 0)
|
||||||
|
# max_retries = 10
|
||||||
|
# retries = 0
|
||||||
|
|
||||||
|
# while (best_loss_r > 0 or best_overlap_r < OVERLAP_THRESH) and retries < max_retries:
|
||||||
|
# position_r, loss_r = pso(objective_function, lb_r, ub_r, swarmsize=swarm_size, maxiter=max_iter)
|
||||||
|
# overlap_r, diameter_r, length_r = eval_overlap_from_position(
|
||||||
|
# position_r, "R", optimize_size, spine_tensor, image_shape, spacing
|
||||||
|
# )
|
||||||
|
|
||||||
|
# 只要找到更好的 loss(或你想用 loss+overlap 綜合排序也行)就更新 best
|
||||||
|
# 這裡給你一個更安全的版本:優先選「合格解」;沒有合格解時才用 loss 最小的當備案
|
||||||
|
# candidate_pos = (list(position_r[:5]) + [diameter_r, length_r]) if optimize_size else position_r
|
||||||
|
|
||||||
|
# candidate_ok = is_solution_ok(loss_r, overlap_r, OVERLAP_THRESH)
|
||||||
|
# best_ok = is_solution_ok(best_loss_r, best_overlap_r, OVERLAP_THRESH)
|
||||||
|
|
||||||
|
# if candidate_ok and (not best_ok or loss_r < best_loss_r):
|
||||||
|
# best_position_r = candidate_pos
|
||||||
|
# best_loss_r = loss_r
|
||||||
|
# best_overlap_r = overlap_r
|
||||||
|
# print(f"[RIGHT][retry {retries+1}] ✅ ok | loss={loss_r:.4f}, overlap={overlap_r*100:.1f}%")
|
||||||
|
# elif (not best_ok) and (loss_r < best_loss_r):
|
||||||
|
# best 還不合格時,先用更小 loss 的當暫存(至少越來越好)
|
||||||
|
# best_position_r = candidate_pos
|
||||||
|
# best_loss_r = loss_r
|
||||||
|
# best_overlap_r = overlap_r
|
||||||
|
# print(f"[RIGHT][retry {retries+1}] ⚠️ not ok | loss improved={loss_r:.4f}, overlap={overlap_r*100:.1f}%")
|
||||||
|
# else:
|
||||||
|
# print(f"[RIGHT][retry {retries+1}] ❌ no improve | loss={loss_r:.4f}, overlap={overlap_r*100:.1f}%")
|
||||||
|
|
||||||
|
# retries += 1
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
total_time = end_time - start_time
|
||||||
|
|
||||||
|
# 提取最終的 diameter 和 length
|
||||||
|
if optimize_size:
|
||||||
|
final_diameter_l = best_position_l[5]
|
||||||
|
final_length_l = best_position_l[6]
|
||||||
|
final_diameter_r = best_position_r[5]
|
||||||
|
final_length_r = best_position_r[6]
|
||||||
|
|
||||||
|
print(f"\n=== 最終結果 ===")
|
||||||
|
print(f"Left - Diameter: {final_diameter_l} mm, Length: {final_length_l} mm")
|
||||||
|
print(f"Right - Diameter: {final_diameter_r} mm, Length: {final_length_r} mm")
|
||||||
|
else:
|
||||||
|
final_diameter_l = diameter
|
||||||
|
final_length_l = length
|
||||||
|
final_diameter_r = diameter
|
||||||
|
final_length_r = length
|
||||||
|
|
||||||
|
res_plt_2_torch(
|
||||||
|
spine_tensor,
|
||||||
|
cortical_tensor,
|
||||||
|
image_shape,
|
||||||
|
image2_path,
|
||||||
|
folder,
|
||||||
label_str,
|
label_str,
|
||||||
final_diameter_l,
|
final_diameter_l,
|
||||||
final_length_l,
|
final_length_l,
|
||||||
|
|
@ -420,7 +835,8 @@ def run_de_torch(
|
||||||
|
|
||||||
def eval_overlap_from_position(pos, side: str, optimize_size: bool, spine_tensor: torch.Tensor, image_shape, spacing):
|
def eval_overlap_from_position(pos, side: str, optimize_size: bool, spine_tensor: torch.Tensor, image_shape, spacing):
|
||||||
if optimize_size:
|
if optimize_size:
|
||||||
d, L = snap_to_discrete_values(pos[5], pos[6])
|
# d, L = snap_to_discrete_values(pos[5], pos[6])
|
||||||
|
d, L = round_down_to_discrete_values(pos[5], pos[6])
|
||||||
params_5 = pos[:5]
|
params_5 = pos[:5]
|
||||||
else:
|
else:
|
||||||
d, L = diameter, length
|
d, L = diameter, length
|
||||||
|
|
@ -598,7 +1014,8 @@ def run_nm_torch(
|
||||||
|
|
||||||
def eval_overlap_from_position(pos, side: str, optimize_size: bool, spine_tensor: torch.Tensor, image_shape, spacing):
|
def eval_overlap_from_position(pos, side: str, optimize_size: bool, spine_tensor: torch.Tensor, image_shape, spacing):
|
||||||
if optimize_size:
|
if optimize_size:
|
||||||
d, L = snap_to_discrete_values(pos[5], pos[6])
|
# d, L = snap_to_discrete_values(pos[5], pos[6])
|
||||||
|
d, L = round_down_to_discrete_values(pos[5], pos[6])
|
||||||
params_5 = pos[:5]
|
params_5 = pos[:5]
|
||||||
else:
|
else:
|
||||||
d, L = diameter, length
|
d, L = diameter, length
|
||||||
|
|
|
||||||
114
core/scoring.py
114
core/scoring.py
|
|
@ -2,6 +2,120 @@ import torch
|
||||||
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_tip_torch
|
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_tip_torch
|
||||||
from config.constant import OVERLAP_THRESH
|
from config.constant import OVERLAP_THRESH
|
||||||
|
|
||||||
|
def cl_score_torch_xfr(
|
||||||
|
cortical_tensor: torch.Tensor,
|
||||||
|
spine_tensor: torch.Tensor,
|
||||||
|
cylinder_torch: torch.Tensor,
|
||||||
|
cylinder_o_torch: torch.Tensor,
|
||||||
|
intersections: int,
|
||||||
|
diameter: float = None,
|
||||||
|
length: float = None,
|
||||||
|
cylinder_tip_torch: torch.Tensor = None # 新增:尖端 mask
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
漸進式評分:優先確保找到骨頭,再改善細節
|
||||||
|
"""
|
||||||
|
cyl_total = cylinder_torch.sum().item()
|
||||||
|
overlap = ((cortical_tensor == 1) & (cylinder_torch == 1)).sum().item()
|
||||||
|
null_vox = ((cortical_tensor == 0) & (cylinder_torch == 1)).sum().item()
|
||||||
|
null_vox2 = ((spine_tensor == 1) & (cylinder_o_torch == 1)).sum().item()
|
||||||
|
|
||||||
|
in_bone= ((spine_tensor == 1) & (cylinder_torch == 1)).sum().item()
|
||||||
|
not_in_bone= ((spine_tensor == 0) & (cylinder_torch == 1)).sum().item()
|
||||||
|
|
||||||
|
# if cyl_total == 0:
|
||||||
|
# return float(1000*1000)
|
||||||
|
# return float(1e9) # 極差的情況
|
||||||
|
|
||||||
|
overlap_ratio = overlap / cyl_total
|
||||||
|
|
||||||
|
# if cyl_total < 1000:
|
||||||
|
# return float((1000 - cyl_total)*10000)
|
||||||
|
|
||||||
|
score = cyl_total
|
||||||
|
|
||||||
|
# if in_bone == 0:
|
||||||
|
# return float(not_in_bone*200)
|
||||||
|
|
||||||
|
score += overlap*30
|
||||||
|
score += in_bone*10
|
||||||
|
score -= not_in_bone*1000
|
||||||
|
score -= null_vox2*1000
|
||||||
|
|
||||||
|
return float(-score)
|
||||||
|
|
||||||
|
# === 階段 1:首要目標是找到骨頭(overlap > 0) ===
|
||||||
|
if overlap == 0:
|
||||||
|
# 完全沒有 overlap 是最糟糕的情況
|
||||||
|
score -= 500000 # 超大懲罰
|
||||||
|
# 如果連 spine 都沒穿過,更糟
|
||||||
|
if intersections == 0:
|
||||||
|
score -= 500000
|
||||||
|
return float(-score)
|
||||||
|
|
||||||
|
# === 階段 2:有找到骨頭了,開始改善品質 ===
|
||||||
|
|
||||||
|
# 1. Overlap 獎勵(非線性,鼓勵快速提升)
|
||||||
|
if overlap_ratio < 0.1:
|
||||||
|
# 0-10%:每增加 1% 給大量獎勵(鼓勵探索)
|
||||||
|
score += overlap * 5000 # 很高的單位獎勵
|
||||||
|
elif overlap_ratio < 0.3:
|
||||||
|
# 10-30%:中等獎勵
|
||||||
|
score += overlap * 3000
|
||||||
|
elif overlap_ratio < 0.5:
|
||||||
|
# 30-50%:正常獎勵
|
||||||
|
score += overlap * 2000
|
||||||
|
else:
|
||||||
|
# 50%+:獎勵 + 額外比例獎勵
|
||||||
|
score += overlap * 2000
|
||||||
|
score += (overlap_ratio - 0.5) * 100000 # 超過 50% 額外大獎
|
||||||
|
|
||||||
|
# 2. Intersection 控制(稍微放寬)
|
||||||
|
if intersections == 1:
|
||||||
|
score += 20000 # 完美
|
||||||
|
elif intersections == 0:
|
||||||
|
score -= 200000 # 嚴重錯誤(但比完全沒 overlap 好)
|
||||||
|
elif intersections == 2:
|
||||||
|
score -= 10000 # 可接受但不理想
|
||||||
|
else:
|
||||||
|
score -= intersections * 15000
|
||||||
|
|
||||||
|
# 3. Null voxel 懲罰(漸進式)
|
||||||
|
null_ratio = null_vox / cyl_total
|
||||||
|
|
||||||
|
if overlap_ratio < 0.2:
|
||||||
|
# 如果 overlap 還很少,對 null voxel 寬容一點
|
||||||
|
score -= null_vox * 300
|
||||||
|
elif overlap_ratio < 0.4:
|
||||||
|
score -= null_vox * 600
|
||||||
|
else:
|
||||||
|
# overlap 夠高了,開始嚴格要求
|
||||||
|
if null_ratio > 0.5:
|
||||||
|
score -= null_vox * 1500
|
||||||
|
else:
|
||||||
|
score -= null_vox * 800
|
||||||
|
|
||||||
|
# 4. 反向圓柱懲罰
|
||||||
|
score -= null_vox2 * 1000
|
||||||
|
|
||||||
|
# 5. 尺寸合理性(放寬)
|
||||||
|
if diameter is not None and length is not None:
|
||||||
|
if diameter < 2.5 or diameter > 6.0: # 放寬從 (3.0, 5.5) 到 (2.5, 6.0)
|
||||||
|
score -= 3000
|
||||||
|
if length < 25 or length > 60: # 放寬從 (30, 55) 到 (25, 60)
|
||||||
|
score -= 3000
|
||||||
|
|
||||||
|
# 6. 尖端 breach 懲罰
|
||||||
|
if cylinder_tip_torch is not None:
|
||||||
|
tip_total = cylinder_tip_torch.sum().item()
|
||||||
|
if tip_total > 0:
|
||||||
|
tip_breach = ((cortical_tensor == 0) & (cylinder_tip_torch == 1)).sum().item()
|
||||||
|
tip_breach_ratio = tip_breach / tip_total
|
||||||
|
if tip_breach_ratio > 0:
|
||||||
|
score -= tip_breach * 5000 # 尖端出界懲罰要比一般 null_vox 重很多
|
||||||
|
|
||||||
|
return float(-score)
|
||||||
|
|
||||||
def cl_score_torch(
|
def cl_score_torch(
|
||||||
cortical_tensor: torch.Tensor,
|
cortical_tensor: torch.Tensor,
|
||||||
spine_tensor: torch.Tensor,
|
spine_tensor: torch.Tensor,
|
||||||
|
|
|
||||||
|
|
@ -102,6 +102,8 @@ def process_dataset(image_dir, label_dir, output_dir, labels_to_process=None):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = process_single_image(image_path, label_path, output_dir_base=output_dir)
|
result = process_single_image(image_path, label_path, output_dir_base=output_dir)
|
||||||
|
# print(result)
|
||||||
|
# exit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[{idx}/{total_files}] Error processing {file_name}: {e}")
|
print(f"[{idx}/{total_files}] Error processing {file_name}: {e}")
|
||||||
file_summary["note"] = f"Error: {e}"
|
file_summary["note"] = f"Error: {e}"
|
||||||
|
|
|
||||||
|
|
@ -22,38 +22,71 @@ def seg_bone(n, name, resampled_sitk_img, resampled_sitk_lbl, output_base=None,
|
||||||
|
|
||||||
label_name = label_map[n]
|
label_name = label_map[n]
|
||||||
|
|
||||||
lssif = sitk.LabelShapeStatisticsImageFilter()
|
# lssif = sitk.LabelShapeStatisticsImageFilter()
|
||||||
lssif.Execute(resampled_sitk_lbl)
|
# lssif.Execute(resampled_sitk_lbl)
|
||||||
|
|
||||||
if not lssif.HasLabel(n):
|
# if not lssif.HasLabel(n):
|
||||||
raise RuntimeError(f"Label {n} not found")
|
# raise RuntimeError(f"Label {n} not found")
|
||||||
|
|
||||||
bbox2 = lssif.GetBoundingBox(n)
|
# bbox2 = lssif.GetBoundingBox(n)
|
||||||
|
|
||||||
|
# 1. 提取標籤 n 的二值遮罩 (將標籤 n 設為 1,其餘為 0)
|
||||||
|
binary_mask = sitk.BinaryThreshold(resampled_sitk_lbl, n, n, 1, 0)
|
||||||
|
|
||||||
|
# 2. 獲取所有連通區域
|
||||||
|
# 連通區域濾波器會將 binary_mask 中的不同物體標記為 1, 2, 3...
|
||||||
|
cc_image = sitk.ConnectedComponent(binary_mask)
|
||||||
|
|
||||||
|
# 3. 根據區域大小(像素/體積)重新標記
|
||||||
|
# RelabelComponent 會按大小排序,最大的物體標籤會被設為 1
|
||||||
|
relabeled_cc = sitk.RelabelComponent(cc_image, sortByObjectSize=True)
|
||||||
|
largest_mask = relabeled_cc == 1
|
||||||
|
|
||||||
|
# 4. 計算形狀統計信息
|
||||||
|
shape_stats = sitk.LabelShapeStatisticsImageFilter()
|
||||||
|
shape_stats.Execute(relabeled_cc)
|
||||||
|
|
||||||
|
# 檢查是否有找到任何組件
|
||||||
|
if shape_stats.GetNumberOfLabels() < 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 5. 獲取最大組件(標籤為 1)的邊界框
|
||||||
|
# 格式通常為 [x_start, y_start, z_start, x_size, y_size, z_size]
|
||||||
|
bbox2 = shape_stats.GetBoundingBox(1)
|
||||||
|
|
||||||
roi = sitk.RegionOfInterest(resampled_sitk_img, bbox2[3:], bbox2[:3])
|
roi = sitk.RegionOfInterest(resampled_sitk_img, bbox2[3:], bbox2[:3])
|
||||||
label2 = sitk.RegionOfInterest(resampled_sitk_lbl, bbox2[3:], bbox2[:3])
|
|
||||||
roi_path = os.path.join(output_base, f"{label_name}_roi.nii.gz")
|
roi_path = os.path.join(output_base, f"{label_name}_roi.nii.gz")
|
||||||
sitk.WriteImage(roi, roi_path)
|
sitk.WriteImage(roi, roi_path)
|
||||||
|
|
||||||
binary = sitk.BinaryThreshold(label2, lowerThreshold=n, upperThreshold=n, outsideValue=0, insideValue=1)
|
# label2 = sitk.RegionOfInterest(resampled_sitk_lbl, bbox2[3:], bbox2[:3])
|
||||||
|
# binary = sitk.BinaryThreshold(label2, lowerThreshold=n, upperThreshold=n, outsideValue=0, insideValue=1)
|
||||||
|
binary = sitk.RegionOfInterest(largest_mask, bbox2[3:], bbox2[:3])
|
||||||
binary_path = os.path.join(output_base, f"{label_name}_binary.nii.gz")
|
binary_path = os.path.join(output_base, f"{label_name}_binary.nii.gz")
|
||||||
sitk.WriteImage(binary, binary_path)
|
sitk.WriteImage(binary, binary_path)
|
||||||
|
|
||||||
roi_pixel_type = roi.GetPixelID()
|
# roi_pixel_type = roi.GetPixelID()
|
||||||
binary_cast = sitk.Cast(binary, roi_pixel_type)
|
# binary_cast = sitk.Cast(binary, roi_pixel_type)
|
||||||
roi2 = roi * binary_cast
|
# roi2 = roi * binary_cast
|
||||||
|
roi2 = sitk.Mask(roi, binary)
|
||||||
roi2_path = os.path.join(output_base, f"{label_name}_roi2.nii.gz")
|
roi2_path = os.path.join(output_base, f"{label_name}_roi2.nii.gz")
|
||||||
sitk.WriteImage(roi2, roi2_path)
|
sitk.WriteImage(roi2, roi2_path)
|
||||||
|
|
||||||
lsif = sitk.LabelStatisticsImageFilter()
|
# lsif = sitk.LabelStatisticsImageFilter()
|
||||||
label2_int = sitk.Cast(label2, sitk.sitkUInt16)
|
# label2_int = sitk.Cast(label2, sitk.sitkUInt16)
|
||||||
lsif.Execute(roi2, label2_int)
|
# lsif.Execute(roi2, label2_int)
|
||||||
labels_in_roi = lsif.GetLabels()
|
# labels_in_roi = lsif.GetLabels()
|
||||||
if n in labels_in_roi:
|
# if n in labels_in_roi:
|
||||||
roi_hu = sitk.GetArrayFromImage(roi2)
|
# roi_hu = sitk.GetArrayFromImage(roi2)
|
||||||
threshold = np.percentile(roi_hu, 60)
|
# threshold = np.percentile(roi_hu, 60)
|
||||||
else:
|
# else:
|
||||||
threshold = lsif.GetMedian(labels_in_roi[0])
|
# threshold = lsif.GetMedian(labels_in_roi[0])
|
||||||
|
|
||||||
|
stats = sitk.LabelStatisticsImageFilter()
|
||||||
|
stats.UseHistogramsOn() # Required for median calculation
|
||||||
|
stats.Execute(roi, binary)
|
||||||
|
|
||||||
|
# Get the median for the region where mask == 1
|
||||||
|
threshold = stats.GetMedian(1)
|
||||||
|
|
||||||
cortical = sitk.BinaryThreshold(roi2, lowerThreshold=threshold, upperThreshold=10000, outsideValue=0, insideValue=1)
|
cortical = sitk.BinaryThreshold(roi2, lowerThreshold=threshold, upperThreshold=10000, outsideValue=0, insideValue=1)
|
||||||
cortical_path = os.path.join(output_base, f"{label_name}_cortical.nii.gz")
|
cortical_path = os.path.join(output_base, f"{label_name}_cortical.nii.gz")
|
||||||
|
|
|
||||||
1625
progress.json
Normal file
1625
progress.json
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -7,10 +7,37 @@ import csv
|
||||||
|
|
||||||
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values
|
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values
|
||||||
from core.intersection import center_line_intersections_torch
|
from core.intersection import center_line_intersections_torch
|
||||||
from core.scoring import cl_score_torch, compute_overlap_ratio_from_cylinder_mask
|
from core.scoring import cl_score_torch, compute_overlap_ratio_from_cylinder_mask, cl_score_torch_xfr
|
||||||
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
|
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
|
||||||
from utils.helpers import save_with_unique_name
|
from utils.helpers import save_with_unique_name
|
||||||
|
|
||||||
|
def set_axes_equal_3d(ax):
|
||||||
|
"""
|
||||||
|
Make axes of 3D plot have equal scale so that spheres appear as spheres,
|
||||||
|
cubes as cubes, etc.
|
||||||
|
"""
|
||||||
|
x_limits = ax.get_xlim3d()
|
||||||
|
y_limits = ax.get_ylim3d()
|
||||||
|
z_limits = ax.get_zlim3d()
|
||||||
|
|
||||||
|
x_range = abs(x_limits[1] - x_limits[0])
|
||||||
|
x_middle = np.mean(x_limits)
|
||||||
|
y_range = abs(y_limits[1] - y_limits[0])
|
||||||
|
y_middle = np.mean(y_limits)
|
||||||
|
z_range = abs(z_limits[1] - z_limits[0])
|
||||||
|
z_middle = np.mean(z_limits)
|
||||||
|
|
||||||
|
plot_radius = 0.5*max([x_range, y_range, z_range])
|
||||||
|
|
||||||
|
ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius])
|
||||||
|
ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius])
|
||||||
|
ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius])
|
||||||
|
|
||||||
|
try:
|
||||||
|
ax.set_box_aspect([1, 1, 1])
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
def res_plt_2_torch(
|
def res_plt_2_torch(
|
||||||
spine_tensor: torch.Tensor,
|
spine_tensor: torch.Tensor,
|
||||||
cortical_tensor: torch.Tensor,
|
cortical_tensor: torch.Tensor,
|
||||||
|
|
@ -114,7 +141,8 @@ def res_plt_2_torch(
|
||||||
spacing,
|
spacing,
|
||||||
device
|
device
|
||||||
)
|
)
|
||||||
loss_r = cl_score_torch(cortical_tensor, spine_tensor, cyl_r, cyl_ro, intersections_r)
|
# loss_r = cl_score_torch(cortical_tensor, spine_tensor, cyl_r, cyl_ro, intersections_r)
|
||||||
|
loss_r = cl_score_torch_xfr(cortical_tensor, spine_tensor, cyl_r, cyl_ro, intersections_r)
|
||||||
|
|
||||||
azi = azimuth_rotation(image2_path)
|
azi = azimuth_rotation(image2_path)
|
||||||
res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
|
res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
|
||||||
|
|
@ -150,6 +178,7 @@ def res_plt_2_torch(
|
||||||
ax1.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o')
|
ax1.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o')
|
||||||
ax1.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine')
|
ax1.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine')
|
||||||
ax1.set_xlabel('X-axis'); ax1.set_ylabel('Y-axis'); ax1.set_zlabel('Z-axis')
|
ax1.set_xlabel('X-axis'); ax1.set_ylabel('Y-axis'); ax1.set_zlabel('Z-axis')
|
||||||
|
set_axes_equal_3d(ax1)
|
||||||
|
|
||||||
ax2 = fig.add_subplot(222, projection='3d')
|
ax2 = fig.add_subplot(222, projection='3d')
|
||||||
ax2.view_init(elev=90, azim=-90, roll=0)
|
ax2.view_init(elev=90, azim=-90, roll=0)
|
||||||
|
|
@ -161,6 +190,7 @@ def res_plt_2_torch(
|
||||||
ax2.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o')
|
ax2.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o')
|
||||||
ax2.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine')
|
ax2.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine')
|
||||||
ax2.set_xlabel('X-axis'); ax2.set_ylabel('Y-axis'); ax2.set_zlabel('Z-axis')
|
ax2.set_xlabel('X-axis'); ax2.set_ylabel('Y-axis'); ax2.set_zlabel('Z-axis')
|
||||||
|
set_axes_equal_3d(ax2)
|
||||||
ax2.legend()
|
ax2.legend()
|
||||||
|
|
||||||
ax3 = fig.add_subplot(223, projection='3d')
|
ax3 = fig.add_subplot(223, projection='3d')
|
||||||
|
|
@ -173,6 +203,7 @@ def res_plt_2_torch(
|
||||||
ax3.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o')
|
ax3.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o')
|
||||||
ax3.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine')
|
ax3.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine')
|
||||||
ax3.set_xlabel('X-axis'); ax3.set_ylabel('Y-axis'); ax3.set_zlabel('Z-axis')
|
ax3.set_xlabel('X-axis'); ax3.set_ylabel('Y-axis'); ax3.set_zlabel('Z-axis')
|
||||||
|
set_axes_equal_3d(ax3)
|
||||||
|
|
||||||
ax4 = fig.add_subplot(224, projection='3d')
|
ax4 = fig.add_subplot(224, projection='3d')
|
||||||
ax4.view_init(elev=0, azim=0, roll=0)
|
ax4.view_init(elev=0, azim=0, roll=0)
|
||||||
|
|
@ -184,6 +215,7 @@ def res_plt_2_torch(
|
||||||
ax4.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o')
|
ax4.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o')
|
||||||
ax4.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine')
|
ax4.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine')
|
||||||
ax4.set_xlabel('X-axis'); ax4.set_ylabel('Y-axis'); ax4.set_zlabel('Z-axis')
|
ax4.set_xlabel('X-axis'); ax4.set_ylabel('Y-axis'); ax4.set_zlabel('Z-axis')
|
||||||
|
set_axes_equal_3d(ax4)
|
||||||
|
|
||||||
cyl_points_l = torch.sum(cyl_l).item()
|
cyl_points_l = torch.sum(cyl_l).item()
|
||||||
cyl_points_r = torch.sum(cyl_r).item()
|
cyl_points_r = torch.sum(cyl_r).item()
|
||||||
|
|
@ -217,7 +249,7 @@ def res_plt_2_torch(
|
||||||
headers = [
|
headers = [
|
||||||
'Label', 'Side', 'Diameter', 'Length', 'Swarm_Size', 'Max_Iter',
|
'Label', 'Side', 'Diameter', 'Length', 'Swarm_Size', 'Max_Iter',
|
||||||
'Position_XYZ', 'Raw_Azimuth', 'Azimuth_Diff', 'Raw_Altitude', 'Altitude_Diff',
|
'Position_XYZ', 'Raw_Azimuth', 'Azimuth_Diff', 'Raw_Altitude', 'Altitude_Diff',
|
||||||
'Intersections', 'Best_Loss', 'Overlap_Cortical', 'Overlap_Bone',
|
'Intersections', 'Best_Loss', 'cyl_points', 'Overlap_Cortical', 'Overlap_Bone',
|
||||||
'Cortical_Bone_Ratio', 'User_Azimuth', 'User_Altitude', 'Total_Time'
|
'Cortical_Bone_Ratio', 'User_Azimuth', 'User_Altitude', 'Total_Time'
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -237,13 +269,15 @@ def res_plt_2_torch(
|
||||||
length_l,
|
length_l,
|
||||||
swarm_size,
|
swarm_size,
|
||||||
max_iter,
|
max_iter,
|
||||||
f"({best_position_l[0]:.2f}, {best_position_l[1]:.2f}, {best_position_l[2]:.2f})",
|
# f"({best_position_l[0]:.2f}, {best_position_l[1]:.2f}, {best_position_l[2]:.2f})",
|
||||||
|
f"({best_position_l[2]:.2f}, {best_position_l[1]:.2f}, {best_position_l[0]:.2f})",
|
||||||
f"{best_position_l[3]:.2f}",
|
f"{best_position_l[3]:.2f}",
|
||||||
f"{best_position_l[3]-azi:.2f}",
|
f"{best_position_l[3]-azi:.2f}",
|
||||||
f"{best_position_l[4]:.2f}",
|
f"{best_position_l[4]:.2f}",
|
||||||
f"{best_position_l[4]-alt:.2f}",
|
f"{best_position_l[4]-alt:.2f}",
|
||||||
intersections_l,
|
intersections_l,
|
||||||
f"{loss_l:.2f}",
|
f"{loss_l:.2f}",
|
||||||
|
cyl_points_l,
|
||||||
f"{overlap_cortical_l:.2f}",
|
f"{overlap_cortical_l:.2f}",
|
||||||
f"{overlap_vertebral_l:.2f}",
|
f"{overlap_vertebral_l:.2f}",
|
||||||
f"{(overlap_cortical_l/overlap_vertebral_l if overlap_vertebral_l!=0 else 0):.2f}",
|
f"{(overlap_cortical_l/overlap_vertebral_l if overlap_vertebral_l!=0 else 0):.2f}",
|
||||||
|
|
@ -260,13 +294,15 @@ def res_plt_2_torch(
|
||||||
length_r,
|
length_r,
|
||||||
swarm_size,
|
swarm_size,
|
||||||
max_iter,
|
max_iter,
|
||||||
f"({best_position_r[0]:.2f}, {best_position_r[1]:.2f}, {best_position_r[2]:.2f})",
|
# f"({best_position_r[0]:.2f}, {best_position_r[1]:.2f}, {best_position_r[2]:.2f})",
|
||||||
|
f"({best_position_r[2]:.2f}, {best_position_r[1]:.2f}, {best_position_r[0]:.2f})",
|
||||||
f"{best_position_r[3]:.2f}",
|
f"{best_position_r[3]:.2f}",
|
||||||
f"{best_position_r[3]-azi:.2f}",
|
f"{best_position_r[3]-azi:.2f}",
|
||||||
f"{best_position_r[4]:.2f}",
|
f"{best_position_r[4]:.2f}",
|
||||||
f"{best_position_r[4]-alt:.2f}",
|
f"{best_position_r[4]-alt:.2f}",
|
||||||
intersections_r,
|
intersections_r,
|
||||||
f"{loss_r:.2f}",
|
f"{loss_r:.2f}",
|
||||||
|
cyl_points_r,
|
||||||
f"{overlap_cortical_r:.2f}",
|
f"{overlap_cortical_r:.2f}",
|
||||||
f"{overlap_vertebral_r:.2f}",
|
f"{overlap_vertebral_r:.2f}",
|
||||||
f"{(overlap_cortical_r/overlap_vertebral_r if overlap_vertebral_r!=0 else 0):.2f}",
|
f"{(overlap_cortical_r/overlap_vertebral_r if overlap_vertebral_r!=0 else 0):.2f}",
|
||||||
|
|
@ -289,14 +325,16 @@ def res_plt_2_torch(
|
||||||
)
|
)
|
||||||
fig.text(
|
fig.text(
|
||||||
0.5, 0.03,
|
0.5, 0.03,
|
||||||
f'Left : Position = ({best_position_l[0]:.2f}, {best_position_l[1]:.2f}, {best_position_l[2]:.2f}), '
|
# f'Left : Position = ({best_position_l[0]:.2f}, {best_position_l[1]:.2f}, {best_position_l[2]:.2f}), '
|
||||||
|
f'Left : Position = ({best_position_l[2]:.2f}, {best_position_l[1]:.2f}, {best_position_l[0]:.2f}), '
|
||||||
f'Azimuth = {user_azimuth_l:.2f}, Altitude = {user_altitude_l:.2f}, '
|
f'Azimuth = {user_azimuth_l:.2f}, Altitude = {user_altitude_l:.2f}, '
|
||||||
f'Intersection = {intersections_l}, Score = {overlap_cortical_l:.2f} / {overlap_vertebral_l:.2f} / {cb_ratio_l:.2f}',
|
f'Intersection = {intersections_l}, Score = {overlap_cortical_l:.2f} / {overlap_vertebral_l:.2f} / {cb_ratio_l:.2f}',
|
||||||
ha='center', fontsize=9
|
ha='center', fontsize=9
|
||||||
)
|
)
|
||||||
fig.text(
|
fig.text(
|
||||||
0.5, 0.01,
|
0.5, 0.01,
|
||||||
f'Right : Position = ({best_position_r[0]:.2f}, {best_position_r[1]:.2f}, {best_position_r[2]:.2f}), '
|
# f'Right : Position = ({best_position_r[0]:.2f}, {best_position_r[1]:.2f}, {best_position_r[2]:.2f}), '
|
||||||
|
f'Right : Position = ({best_position_r[2]:.2f}, {best_position_r[1]:.2f}, {best_position_r[0]:.2f}), '
|
||||||
f'Azimuth = {user_azimuth_r:.2f}, Altitude = {user_altitude_r:.2f}, '
|
f'Azimuth = {user_azimuth_r:.2f}, Altitude = {user_altitude_r:.2f}, '
|
||||||
f'Intersection = {intersections_r}, Score = {overlap_cortical_r:.2f} / {overlap_vertebral_r:.2f} / {cb_ratio_r:.2f}',
|
f'Intersection = {intersections_r}, Score = {overlap_cortical_r:.2f} / {overlap_vertebral_r:.2f} / {cb_ratio_r:.2f}',
|
||||||
ha='center', fontsize=9
|
ha='center', fontsize=9
|
||||||
|
|
|
||||||
143
xfr_debug.py
Normal file
143
xfr_debug.py
Normal file
|
|
@ -0,0 +1,143 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
import SimpleITK as sitk
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# from config.device import get_device
|
||||||
|
from core.cylinder import create_coordinate_grid
|
||||||
|
from core.objective import set_global_context
|
||||||
|
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch, run_pso_torch_xfr
|
||||||
|
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
|
||||||
|
|
||||||
|
standardized_dir = '/mnt/1248/open/cyrou/CBT/Seg/Resample/standardized-xfr/'
|
||||||
|
|
||||||
|
azimuth_rotation_dir = '/mnt/1248/open/cyrou/azimuth_rotation'
|
||||||
|
tilt_contour_dir = '/mnt/1248/open/cyrou/tilt_contour'
|
||||||
|
Output_dir = '/mnt/1248/open/cyrou/Output'
|
||||||
|
|
||||||
|
def get_device(gpu_id=None):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
if gpu_id is None:
|
||||||
|
max_free = -1
|
||||||
|
gpu_id = 0
|
||||||
|
for i in range(torch.cuda.device_count()):
|
||||||
|
free_mem, _ = torch.cuda.mem_get_info(i)
|
||||||
|
# print(f'GPU {i}: {torch.cuda.get_device_name(i)} {free_mem}')
|
||||||
|
if free_mem > max_free:
|
||||||
|
max_free = free_mem
|
||||||
|
gpu_id = i
|
||||||
|
device = torch.device(f"cuda:{gpu_id}")
|
||||||
|
print(f"Using GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")
|
||||||
|
else:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
print("CUDA not available, using CPU")
|
||||||
|
return device
|
||||||
|
|
||||||
|
def debug_orientation(volume_id, level):
|
||||||
|
|
||||||
|
|
||||||
|
volume_dir = os.path.join(standardized_dir, volume_id)
|
||||||
|
|
||||||
|
cortical_path = os.path.join(volume_dir, f'{level}_cortical.nii.gz')
|
||||||
|
binary_path = os.path.join(volume_dir, f'{level}_binary.nii.gz')
|
||||||
|
roi_path = os.path.join(volume_dir, f'{level}_roi2.nii.gz')
|
||||||
|
|
||||||
|
# azi = azimuth_rotation(binary_path)
|
||||||
|
# res = analyze_vertebral_tilt_contour(binary_path, edge_type='superior', show_plot=False, debug=False)
|
||||||
|
azi = azimuth_rotation(binary_path, show_plt=True, save_plt=True, output_path=f'{azimuth_rotation_dir}/{level}_{volume_id}.png')
|
||||||
|
res = analyze_vertebral_tilt_contour(binary_path, edge_type='superior', show_plot=True, debug=False, save_plt=True, output_path=f'{tilt_contour_dir}/{level}_{volume_id}.png')
|
||||||
|
alt = res['superior']['tilt_angle_deg']
|
||||||
|
|
||||||
|
print(binary_path)
|
||||||
|
# print(f'Azimuth: {azi}, Alt: {alt}')
|
||||||
|
print(f'Alt: {alt}')
|
||||||
|
|
||||||
|
def debug_pso(volume_id, level):
|
||||||
|
# ====== PSO ======
|
||||||
|
swarm_size = 100
|
||||||
|
max_iter = 100
|
||||||
|
|
||||||
|
# ====== DEVICE ======
|
||||||
|
device = get_device()
|
||||||
|
|
||||||
|
# ====== OTHER ======
|
||||||
|
spacing = [0.5, 0.5, 0.5]
|
||||||
|
CBT = True
|
||||||
|
|
||||||
|
volume_dir = os.path.join(standardized_dir, volume_id)
|
||||||
|
|
||||||
|
cortical_path = os.path.join(volume_dir, f'{level}_cortical.nii.gz')
|
||||||
|
binary_path = os.path.join(volume_dir, f'{level}_binary.nii.gz')
|
||||||
|
roi_path = os.path.join(volume_dir, f'{level}_roi2.nii.gz')
|
||||||
|
|
||||||
|
cortical_image = sitk.ReadImage(cortical_path)
|
||||||
|
binary_image = sitk.ReadImage(binary_path)
|
||||||
|
roi_image = sitk.ReadImage(roi_path)
|
||||||
|
|
||||||
|
cortical_array = sitk.GetArrayFromImage(cortical_image)
|
||||||
|
binary_array = sitk.GetArrayFromImage(binary_image)
|
||||||
|
roi_array = sitk.GetArrayFromImage(roi_image)
|
||||||
|
image_shape = binary_array.shape
|
||||||
|
|
||||||
|
cortical_tensor = torch.tensor(cortical_array, device=device)
|
||||||
|
binary_tensor = torch.tensor(binary_array, device=device)
|
||||||
|
|
||||||
|
grid = create_coordinate_grid(image_shape, device)
|
||||||
|
|
||||||
|
set_global_context(
|
||||||
|
cortical=cortical_tensor,
|
||||||
|
spine=binary_tensor,
|
||||||
|
shape=image_shape,
|
||||||
|
spacing_=spacing,
|
||||||
|
device_=device,
|
||||||
|
grid_=grid,
|
||||||
|
use_tip_penalty=False
|
||||||
|
)
|
||||||
|
|
||||||
|
best_l, loss_l, best_r, loss_r, total_time = run_pso_torch_xfr(
|
||||||
|
label_str=level,
|
||||||
|
image1_path=cortical_path,
|
||||||
|
image2_path=binary_path,
|
||||||
|
image3_path=roi_path,
|
||||||
|
folder=Output_dir,
|
||||||
|
swarm_size=swarm_size,
|
||||||
|
max_iter=max_iter,
|
||||||
|
spacing=spacing,
|
||||||
|
CBT=CBT,
|
||||||
|
device=device,
|
||||||
|
optimize_size=True,
|
||||||
|
grid=grid,
|
||||||
|
)
|
||||||
|
|
||||||
|
# exit()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# level = 'L1'
|
||||||
|
|
||||||
|
# volume_id = '1.3.6.1.4.1.9328.50.4.0001'
|
||||||
|
# volume_id = '1.3.6.1.4.1.9328.50.4.0003'
|
||||||
|
# # volume_id = '1.3.6.1.4.1.9328.50.4.0121'
|
||||||
|
# process_volume(volume_id, level)
|
||||||
|
# exit()
|
||||||
|
|
||||||
|
for volume_id in (
|
||||||
|
'1.3.6.1.4.1.9328.50.4.0001',
|
||||||
|
# '1.3.6.1.4.1.9328.50.4.0002',
|
||||||
|
# '1.3.6.1.4.1.9328.50.4.0003',
|
||||||
|
# '1.3.6.1.4.1.9328.50.4.0004',
|
||||||
|
# '1.3.6.1.4.1.9328.50.4.0005',
|
||||||
|
# '1.3.6.1.4.1.9328.50.4.0006',
|
||||||
|
):
|
||||||
|
# for volume_id in sorted(os.listdir(standardized_dir)):
|
||||||
|
# debug_orientation(volume_id, level)
|
||||||
|
for level in ('L1', 'L2', 'L3', 'L4', 'L5'):
|
||||||
|
# debug_orientation(volume_id, level)
|
||||||
|
debug_pso(volume_id, level)
|
||||||
|
|
||||||
|
exit()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
24
xfr_preprocess.py
Normal file
24
xfr_preprocess.py
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
from imaging.preprocessing import process_dataset
|
||||||
|
|
||||||
|
data_root = '/mnt/1220/Public/dataset/Spine/CTSpine1K/data/'
|
||||||
|
label_root = '/mnt/1220/Public/dataset/Spine/CTSpine1K/label/'
|
||||||
|
output_dir = '/mnt/1248/open2/cyrou/CBT/Seg/Resample/standardized-xfr/'
|
||||||
|
|
||||||
|
label_map = {
|
||||||
|
'colon': 'conlon',
|
||||||
|
'COVID-19': 'COVID-19',
|
||||||
|
'HNSCC-3DCT-RT_neck': 'HNSCC-3DCT-RT_neck',
|
||||||
|
'liver': 'Liver',
|
||||||
|
}
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
for key, value in label_map.items():
|
||||||
|
data_dir = os.path.join(data_root, key)
|
||||||
|
label_dir = os.path.join(label_root, value)
|
||||||
|
process_dataset(data_dir, label_dir, output_dir)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
Loading…
Reference in a new issue