CBT_project/core/objective.py
2026-04-10 13:25:27 +08:00

136 lines
No EOL
3.5 KiB
Python

import torch
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values, generate_cylinder_tip_torch
from core.intersection import center_line_intersections_torch
from core.scoring import cl_score_torch
# Module-level context read by objective_function and by the loss function
# (which reads `grid` and `USE_TIP_PENALTY` directly).
# set_global_context() assigns cortical_tensor, spine_tensor, image2_shape,
# spacing, device, grid and USE_TIP_PENALTY.
image1_array = None  # cortical_nii.gz
image2_array = None  # binarynii.gz
image2_shape = None
image3_array = None  # roi2.nii.gz
# Declared here so the names exist (as None) before set_global_context() is
# called, instead of only coming into existence as a side effect of that call.
cortical_tensor = None
spine_tensor = None
# NOTE(review): `diameter` and `length` are not assigned anywhere in the
# visible code; the loss function receives them as parameters instead.
diameter = None
length = None
spacing = [0.5, 0.5, 0.5]
device = None
grid = None
USE_TIP_PENALTY = None
def set_global_context(
    cortical,
    spine,
    shape,
    spacing_,
    device_,
    grid_,
    use_tip_penalty=False,  # newly added option
):
    """Store the shared optimisation context in module-level globals.

    Every argument is bound, unchanged, to a module global so that
    ``objective_function`` can be called with only the parameter vector.

    Args:
        cortical: Stored as ``cortical_tensor``.
        spine: Stored as ``spine_tensor``.
        shape: Stored as ``image2_shape``.
        spacing_: Stored as ``spacing``.
        device_: Stored as ``device``.
        grid_: Stored as ``grid``.
        use_tip_penalty: Stored as ``USE_TIP_PENALTY``; when truthy the loss
            function additionally generates a cylinder tip.
    """
    global cortical_tensor, spine_tensor
    global image2_shape, spacing
    global device, grid
    global USE_TIP_PENALTY

    cortical_tensor, spine_tensor = cortical, spine
    image2_shape, spacing = shape, spacing_
    device, grid = device_, grid_
    USE_TIP_PENALTY = use_tip_penalty
def cylinder_circle_line_intersection_loss_deductions_torch(
    diameter: float,
    length: float,
    params: list[float],
    image_shape: tuple[int, int, int],
    cortical_tensor: torch.Tensor,
    spine_tensor: torch.Tensor,
    spacing: list[float],
    device: torch.device
) -> float:
    """Score one cylinder placement with the Torch-based scoring function.

    Builds the two cylinders (``generate_cylinder_n_torch`` and
    ``generate_cylinder_o_torch``), the centre-line intersections and, when
    the module-level ``USE_TIP_PENALTY`` flag is truthy, a cylinder tip, then
    hands everything to ``cl_score_torch``.

    Args:
        diameter: Cylinder diameter forwarded to the cylinder generators.
        length: Cylinder length forwarded to the generators and to the
            centre-line intersection routine.
        params: Exactly five values, unpacked in the order
            ``position_z, position_y, position_x, azimuth, altitude``.
        image_shape: Forwarded to the cylinder generators.
        cortical_tensor: Forwarded to ``cl_score_torch``.
        spine_tensor: Forwarded to the intersection routine and to
            ``cl_score_torch``.
        spacing: Forwarded to the generators and the intersection routine.
        device: Forwarded to the generators and the intersection routine.

    Returns:
        Whatever ``cl_score_torch`` returns, unmodified.
        NOTE(review): the annotation promises a Python float for the PSO
        caller, but no conversion happens here - confirm that
        ``cl_score_torch`` already returns one.

    Note:
        ``grid`` and ``USE_TIP_PENALTY`` are read from module globals rather
        than passed in as arguments.
    """
    z, y, x, azimuth, altitude = params

    # All cylinder generators receive the same argument list; the angles are
    # coerced with float() for them only.
    cylinder_args = (
        diameter,
        length,
        z,
        y,
        x,
        float(azimuth),
        float(altitude),
        image_shape,
        spacing,
        device,
        grid,
    )
    forward_cylinder = generate_cylinder_n_torch(*cylinder_args)
    opposite_cylinder = generate_cylinder_o_torch(*cylinder_args)

    # The intersection routine gets the angles as supplied (no float()
    # coercion); the second element of its result is not used.
    centre_line_hits, _unused = center_line_intersections_torch(
        z, y, x, azimuth, altitude, length, spine_tensor, spacing, device
    )

    # The tip is optional: None is passed to the scorer when the flag is off.
    tip_cylinder = (
        generate_cylinder_tip_torch(*cylinder_args) if USE_TIP_PENALTY else None
    )

    return cl_score_torch(
        cortical_tensor,
        spine_tensor,
        forward_cylinder,
        opposite_cylinder,
        centre_line_hits,
        cylinder_tip_torch=tip_cylinder,
    )
def objective_function(params: list[float]) -> float:
    """PSO objective wrapper around the Torch-based loss function.

    The shared context (tensors, image shape, spacing, device) is read from
    module globals that ``set_global_context`` assigns, so that function must
    be called once before the optimiser starts evaluating this objective.

    Args:
        params: ``[position_z, position_y, position_x, azimuth, altitude,
            diameter_raw, length_raw]``. The first five entries are passed on
            as the position/orientation parameters; entries 5 and 6 are the
            continuous diameter and length, which are snapped to discrete
            values before scoring.

    Returns:
        The loss returned by
        ``cylinder_circle_line_intersection_loss_deductions_torch``.

    Raises:
        RuntimeError: If the cortical or spine tensor has not been provided
            via ``set_global_context``.
    """
    # Look the tensors up defensively: depending on whether the context was
    # ever initialised the names may be missing from the module or be None.
    # Failing here gives a clear message instead of a NameError or an obscure
    # error deep inside the torch code.
    module_context = globals()
    cortical = module_context.get("cortical_tensor")
    spine = module_context.get("spine_tensor")
    if cortical is None or spine is None:
        raise RuntimeError(
            "objective_function() called before set_global_context(): "
            "cortical/spine tensors are not initialised"
        )

    position_params = params[:5]  # [z, y, x, azimuth, altitude]
    diameter_raw = params[5]
    length_raw = params[6]

    # Convert the continuous values to discrete values.
    diameter_discrete, length_discrete = snap_to_discrete_values(
        diameter_raw, length_raw
    )

    return cylinder_circle_line_intersection_loss_deductions_torch(
        diameter_discrete,
        length_discrete,
        position_params,
        image2_shape,
        cortical,
        spine,
        spacing,
        device,
    )