179 lines
No EOL
4.9 KiB
Python
179 lines
No EOL
4.9 KiB
Python
import random
|
|
|
|
from scipy.ndimage import map_coordinates
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values, generate_cylinder_tip_torch
|
|
from core.intersection import center_line_intersections_torch
|
|
from core.scoring import cl_score_torch, cl_score_torch_xfr
|
|
|
|
# Module-level context read by the objective functions (the PSO objectives
# receive only the particle vector). set_global_context() overwrites
# image2_shape, spacing, device, grid and USE_TIP_PENALTY.
image1_array = None  # cortical_nii.gz
image2_array = None  # binarynii.gz
image2_shape = None
image3_array = None  # roi2.nii.gz
# NOTE(review): image1_array, image2_array, image3_array, diameter and length
# are not read by the functions in this file; kept for external importers.
diameter = None
length = None
spacing = [0.5, 0.5, 0.5]
device = None
grid = None
USE_TIP_PENALTY = None

# Assigned by set_global_context() and read by objective_function /
# objective_function_xfr. Declared here so the names exist (as None) before
# the context is installed, instead of raising NameError on first access.
cortical_tensor = None
spine_tensor = None
|
|
|
|
def set_global_context(
    cortical,
    spine,
    shape,
    spacing_,
    device_,
    grid_,
    use_tip_penalty=False,  # newly added option
):
    """Install the data and settings the PSO objective functions read.

    The objective functions receive only the particle vector, so every other
    input they need is stored here as a module-level global.

    Args:
        cortical: stored as ``cortical_tensor``.
        spine: stored as ``spine_tensor``.
        shape: stored as ``image2_shape`` (forwarded as ``image_shape``).
        spacing_: stored as ``spacing``.
        device_: stored as ``device``.
        grid_: stored as ``grid``; passed to the cylinder generators.
        use_tip_penalty: stored as ``USE_TIP_PENALTY``; when truthy the loss
            also builds a cylinder tip volume.
    """
    global cortical_tensor, spine_tensor, image2_shape
    global spacing, device, grid, USE_TIP_PENALTY

    cortical_tensor, spine_tensor = cortical, spine
    image2_shape, spacing = shape, spacing_
    device, grid = device_, grid_
    USE_TIP_PENALTY = use_tip_penalty
|
|
|
|
def cylinder_circle_line_intersection_loss_deductions_torch(
    diameter: float,
    length: float,
    params: list[float],
    image_shape: tuple[int, int, int],
    cortical_tensor: torch.Tensor,
    spine_tensor: torch.Tensor,
    spacing: list[float],
    device: torch.device
) -> float:
    """Score one candidate cylinder placement with the Torch pipeline.

    ``params`` is unpacked as ``(position_z, position_y, position_x, azimuth,
    altitude)``. The function builds the forward and opposite cylinder
    volumes and the centre-line intersections against ``spine_tensor``.
    When the module flag ``USE_TIP_PENALTY`` is truthy it also builds a tip
    volume. All of these are combined by ``cl_score_torch_xfr``.

    The cylinder generators additionally receive the module-level ``grid``
    (installed by ``set_global_context``); it is not a parameter here.

    Returns:
        The value returned by ``cl_score_torch_xfr``. It is annotated as a
        float for PSO consumption. NOTE(review): confirm the scorer does not
        return a 0-d tensor.
    """
    z, y, x, azimuth, altitude = params
    az_float, alt_float = float(azimuth), float(altitude)

    # The forward, opposite and tip generators all take the same arguments.
    shared_args = (
        diameter, length,
        z, y, x,
        az_float, alt_float,
        image_shape, spacing, device, grid,
    )

    forward_cyl = generate_cylinder_n_torch(*shared_args)
    opposite_cyl = generate_cylinder_o_torch(*shared_args)

    # The centre-line routine is given the raw (not float-cast) angles.
    intersections, _ = center_line_intersections_torch(
        z, y, x, azimuth, altitude, length, spine_tensor, spacing, device
    )

    tip_cyl = generate_cylinder_tip_torch(*shared_args) if USE_TIP_PENALTY else None

    return cl_score_torch_xfr(
        cortical_tensor, spine_tensor,
        forward_cyl, opposite_cyl, intersections,
        cylinder_tip_torch=tip_cyl,
    )
|
|
|
|
def objective_function_xfr(params: list[float], y_indices) -> float:
    """PSO objective where the y coordinate is looked up rather than optimised.

    Args:
        params: six values ``[position_z, position_x, azimuth, altitude,
            diameter_raw, length_raw]``. Unlike ``objective_function`` there is
            no ``position_y`` entry.
        y_indices: 2-D indexable table; ``y_indices[iz, ix]`` supplies the y
            coordinate for the rounded (z, x) position. NOTE(review):
            presumably a per-column surface map; confirm against the caller.

    Returns:
        The loss from ``cylinder_circle_line_intersection_loss_deductions_torch``
        for the placement ``[z, y, x, azimuth, altitude]``. Diameter and length
        are used as-is (not snapped to discrete values).

    Reads the module globals installed by ``set_global_context``:
    ``image2_shape``, ``cortical_tensor``, ``spine_tensor``, ``spacing`` and
    ``device``.
    """
    z, x, azimuth, altitude, diameter_raw, length_raw = params

    # int() guarantees an integer index. On some NumPy versions, round() of a
    # NumPy float scalar returns a float scalar, which cannot index an array.
    # Rounding is still Python's round-half-to-even, as before.
    # NOTE(review): a negative rounded index wraps to the opposite edge.
    y = y_indices[int(round(z)), int(round(x))]

    position_params = [z, y, x, azimuth, altitude]

    # Continuous diameter/length are used directly here, whereas
    # objective_function snaps them with snap_to_discrete_values.
    return cylinder_circle_line_intersection_loss_deductions_torch(
        diameter_raw,
        length_raw,
        position_params,
        image2_shape,
        cortical_tensor,
        spine_tensor,
        spacing,
        device,
    )
|
|
|
|
def objective_function(params: list[float]) -> float:
    """PSO objective over the full seven-value particle.

    ``params`` is ``[position_z, position_y, position_x, azimuth, altitude,
    diameter_raw, length_raw]``. The first five entries are forwarded as the
    placement. Diameter and length are first passed through
    ``snap_to_discrete_values``. The remaining inputs come from the module
    globals installed by ``set_global_context``.
    """
    placement = params[:5]  # [z, y, x, azimuth, altitude]
    raw_diameter, raw_length = params[5], params[6]

    # Map the continuous particle values onto the allowed discrete sizes.
    snapped_diameter, snapped_length = snap_to_discrete_values(raw_diameter, raw_length)

    return cylinder_circle_line_intersection_loss_deductions_torch(
        snapped_diameter, snapped_length, placement, image2_shape,
        cortical_tensor, spine_tensor, spacing, device,
    )