commit f84d8963daf31c33446c102b62da60ebb2922767 Author: Xiao Furen Date: Fri Apr 10 13:25:27 2026 +0800 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..64d49ae --- /dev/null +++ b/.gitignore @@ -0,0 +1,216 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[codz] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py.cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +# Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +# poetry.lock +# poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python. +# https://pdm-project.org/en/latest/usage/project/#working-with-version-control +# pdm.lock +# pdm.toml +.pdm-python +.pdm-build/ + +# pixi +# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control. +# pixi.lock +# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one +# in the .venv directory. It is recommended not to include this directory in version control. +.pixi + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# Redis +*.rdb +*.aof +*.pid + +# RabbitMQ +mnesia/ +rabbitmq/ +rabbitmq-data/ + +# ActiveMQ +activemq-data/ + +# SageMath parsed files +*.sage.py + +# Environments +.env +.envrc +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +# .idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Marimo +marimo/_static/ +marimo/_lsp/ +__marimo__/ + +# Streamlit +.streamlit/secrets.toml \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..734461b --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "python.analysis.extraPaths": [ + "${workspaceFolder}" + ] +} \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..5dd8fd0 --- /dev/null +++ b/__init__.py @@ -0,0 +1 @@ +#empty \ No newline at end of file diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/config/constant.py b/config/constant.py new file mode 100644 index 0000000..d257b99 --- /dev/null +++ b/config/constant.py @@ -0,0 +1,18 @@ +""" +LABEL_MAP: 一般脊椎與label對應的順序,若是自己標記的需要額外修改定義。 +ALLOWED_DIAMETERS: 螺絲直徑範圍。 +ALLOWED_LENGTHS: 螺絲長度範圍。 +OVERLAP_THRESH: 圓柱體與骨頭接觸的比例(0-1),若低於這個閾值則會重跑。 +spacing: 影像的spacing。 +""" + +LABEL_MAP = { + 1: "C1", 2: "C2", 3: "C3", 4: "C4", 5: "C5", 6: "C6", 7: "C7", + 8: "T1", 9: "T2", 10: "T3", 11: "T4", 12: "T5", 13: "T6", 14: "T7", + 15: "T8", 16: "T9", 17: "T10", 18: "T11", 19: "T12", + 20: "L1", 21: "L2", 22: "L3", 23: "L4", 24: "L5" +} +ALLOWED_DIAMETERS = [3.5, 4.0, 4.5, 5.0] +ALLOWED_LENGTHS = [35, 40, 45, 50] +OVERLAP_THRESH = 0.50 +DEFAULT_SPACING = [0.5, 0.5, 0.5] \ No newline at end of file diff --git a/config/device.py b/config/device.py new file mode 100644 index 0000000..47807d7 --- /dev/null +++ b/config/device.py @@ -0,0 +1,10 @@ +import torch + +def get_device(gpu_id=0): + if torch.cuda.is_available(): + device = torch.device(f"cuda:{gpu_id}") + print(f"Using GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}") + else: + device = torch.device("cpu") + print("CUDA not available, using CPU") + return device \ No newline at end of file diff --git a/core/__init__.py b/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/core/cylinder.py b/core/cylinder.py new file mode 100644 index 0000000..47d363d --- /dev/null +++ b/core/cylinder.py @@ -0,0 +1,256 @@ +import torch +import numpy as np +from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS + +def create_coordinate_grid( + shape: tuple[int, int, int], + device: torch.device, + dtype: torch.dtype = torch.float32 +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + 建立固定的 3D voxel coordinate grid + returns: + z_t, y_t, x_t with shape = (Z, Y, X) + """ + z_t = torch.arange(shape[0], device=device, dtype=dtype) + y_t = torch.arange(shape[1], device=device, dtype=dtype) + x_t = torch.arange(shape[2], device=device, dtype=dtype) + + z_t, y_t, x_t = torch.meshgrid(z_t, y_t, x_t, indexing='ij') + return z_t, y_t, x_t + +def snap_to_discrete_values(diameter_raw, length_raw): + """ + 將連續值映射到最接近的允許離散值 + + Parameters: + diameter_raw: PSO 給的連續直徑值 + length_raw: PSO 給的連續長度值 + + Returns: + diameter_discrete, length_discrete: 離散化後的值 + """ + # 找最接近的 diameter + diameter_discrete = min(ALLOWED_DIAMETERS, key=lambda x: abs(x - diameter_raw)) + + # 找最接近的 length + length_discrete = min(ALLOWED_LENGTHS, key=lambda x: abs(x - length_raw)) + + return diameter_discrete, length_discrete + + +def generate_cylinder_n_torch( + diameter: float, + length: float, + position_z: float, + position_y: float, + position_x: float, + azimuth: float, + altitude: float, + shape: tuple[int, int, int], + spacing: list[float], + device: torch.device, + grid: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None +) -> torch.Tensor: + + """ + Generate a "forward" (positive z) cylinder mask in 3D space using PyTorch tensors. + Returns a binary mask (torch.uint8) on the specified device. + """ + + if grid is None: + z_t, y_t, x_t = create_coordinate_grid(shape, device) + else: + z_t, y_t, x_t = grid + + azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32)) + altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32)) + + # Shift + z_t = z_t - position_z + y_t = y_t - position_y + x_t = x_t - position_x + + # Apply rotation (same formula as your NumPy version, but in torch) + x_rot = ( + x_t * torch.cos(azimuth_rad_t) * torch.cos(altitude_rad_t) + + y_t * torch.sin(azimuth_rad_t) * torch.cos(altitude_rad_t) + - z_t * torch.sin(altitude_rad_t) + ) + y_rot = -x_t * torch.sin(azimuth_rad_t) + y_t * torch.cos(azimuth_rad_t) + z_rot = ( + x_t * torch.cos(azimuth_rad_t) * torch.sin(altitude_rad_t) + + y_t * torch.sin(azimuth_rad_t) * torch.sin(altitude_rad_t) + + z_t * torch.cos(altitude_rad_t) + ) + + # Handle spacing + # You can expand or generalize for more spacing options + if spacing == [1, 1, 1]: + radius = diameter / 2.0 + mask = ( + (x_rot**2 + y_rot**2 <= radius**2) + & (z_rot >= 0) + & (z_rot <= length) + ) + elif spacing == [0.5, 0.5, 0.5]: + radius = (diameter / 2.0) * 2 + mask = ( + (x_rot**2 + y_rot**2 <= radius**2) + & (z_rot >= 0) + & (z_rot <= length * 2) + ) + else: + raise ValueError(f"Unsupported spacing: {spacing}") + + # Convert boolean mask to uint8 + cylinder_mask = mask.to(torch.uint8) + + return cylinder_mask + +def generate_cylinder_o_torch( + diameter: float, + length: float, + position_z: float, + position_y: float, + position_x: float, + azimuth: float, + altitude: float, + shape: tuple[int, int, int], + spacing: list[float], + device: torch.device, + grid: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None +) -> torch.Tensor: + """ + Generate an "opposite" (negative z) cylinder mask in 3D space using PyTorch tensors. + Returns a binary mask (torch.uint8) on the specified device. + """ + + if grid is None: + z_t, y_t, x_t = create_coordinate_grid(shape, device) + else: + z_t, y_t, x_t = grid + + # Convert angles to torch + azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32)) + altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32)) + + # Shift + z_t = z_t - position_z + y_t = y_t - position_y + x_t = x_t - position_x + + # Apply rotation + x_rot = ( + x_t * torch.cos(azimuth_rad_t) * torch.cos(altitude_rad_t) + + y_t * torch.sin(azimuth_rad_t) * torch.cos(altitude_rad_t) + - z_t * torch.sin(altitude_rad_t) + ) + y_rot = -x_t * torch.sin(azimuth_rad_t) + y_t * torch.cos(azimuth_rad_t) + z_rot = ( + x_t * torch.cos(azimuth_rad_t) * torch.sin(altitude_rad_t) + + y_t * torch.sin(azimuth_rad_t) * torch.sin(altitude_rad_t) + + z_t * torch.cos(altitude_rad_t) + ) + + # Handle spacing + if spacing == [1, 1, 1]: + radius = diameter / 2.0 + mask = ( + (x_rot**2 + y_rot**2 <= radius**2) + & (z_rot <= 0) + & (z_rot <= length) + ) + elif spacing == [0.5, 0.5, 0.5]: + radius = (diameter / 2.0) * 2 + mask = ( + (x_rot**2 + y_rot**2 <= radius**2) + & (z_rot <= 0) + & (z_rot >= -length * 2) + ) + else: + raise ValueError(f"Unsupported spacing: {spacing}") + + cylinder_mask_o = mask.to(torch.uint8) + + return cylinder_mask_o + +def generate_cylinder_numpy(diameter, length, position_z, position_y, position_x, azimuth, altitude, shape, spacing): + + azimuth = np.radians(azimuth) + altitude = np.radians(altitude) + + cylinder_mask = np.zeros(shape, dtype=np.uint8) + + z, y, x = np.mgrid[0:shape[0], 0:shape[1], 0:shape[2]].astype(np.float64) + + z -= float(position_z) + y -= float(position_y) + x -= float(position_x) + + x_rot = x * np.cos(azimuth) * np.cos(altitude) + y * np.sin(azimuth) * np.cos(altitude) - z * np.sin(altitude) + y_rot = -x * np.sin(azimuth) + y * np.cos(azimuth) + z_rot = x * np.cos(azimuth) * np.sin(altitude) + y * np.sin(azimuth) * np.sin(altitude) + z * np.cos(altitude) + + if spacing == [1, 1, 1]: + + radius = diameter / 2.0 + cylinder = (x_rot**2 + y_rot**2 <= radius**2) & (z_rot >= 0) & (z_rot <= length) + + elif spacing == [0.5, 0.5, 0.5]: + radius = diameter / 2.0 * 2 # *2 is for resampling + cylinder = (x_rot**2 + y_rot**2 <= radius**2) & (z_rot >= 0) & (z_rot <= length * 2) + + cylinder_mask[cylinder] = 1 + + return cylinder_mask + +def generate_cylinder_tip_torch( + diameter, length, + position_z, position_y, position_x, + azimuth, altitude, + shape, spacing, device, grid=None, + tip_ratio=0.2 # 取末端 20% 當尖端 +) -> torch.Tensor: + """只生成圓柱末端的 mask""" + + if grid is None: + z_t, y_t, x_t = create_coordinate_grid(shape, device) + else: + z_t, y_t, x_t = grid + + azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32)) + altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32)) + + z_t = z_t - position_z + y_t = y_t - position_y + x_t = x_t - position_x + + x_rot = ( + x_t * torch.cos(azimuth_rad_t) * torch.cos(altitude_rad_t) + + y_t * torch.sin(azimuth_rad_t) * torch.cos(altitude_rad_t) + - z_t * torch.sin(altitude_rad_t) + ) + y_rot = -x_t * torch.sin(azimuth_rad_t) + y_t * torch.cos(azimuth_rad_t) + z_rot = ( + x_t * torch.cos(azimuth_rad_t) * torch.sin(altitude_rad_t) + + y_t * torch.sin(azimuth_rad_t) * torch.sin(altitude_rad_t) + + z_t * torch.cos(altitude_rad_t) + ) + + if spacing == [0.5, 0.5, 0.5]: + radius = (diameter / 2.0) * 2 + total_length = length * 2 + else: + radius = diameter / 2.0 + total_length = length + + tip_start = total_length * (1 - tip_ratio) # 末端 20% 開始的位置 + + mask = ( + (x_rot**2 + y_rot**2 <= radius**2) + & (z_rot >= tip_start) # 只取末端 + & (z_rot <= total_length) + ) + + return mask.to(torch.uint8) \ No newline at end of file diff --git a/core/intersection.py b/core/intersection.py new file mode 100644 index 0000000..00194e1 --- /dev/null +++ b/core/intersection.py @@ -0,0 +1,154 @@ +import torch +import numpy as np + +def bresenham3d(x1: float, y1: float, z1: float, x2: float, y2: float, z2: float) -> list[tuple[int, int, int]]: + """ + 3D Bresenham line in Python (CPU). + We keep this CPU-based because it's discrete and typically not large enough + to warrant GPU acceleration. + """ + x1, y1, z1 = round(x1), round(y1), round(z1) + x2, y2, z2 = round(x2), round(y2), round(z2) + + points = [(x1, y1, z1)] + dx = abs(x2 - x1) + dy = abs(y2 - y1) + dz = abs(z2 - z1) + + xs = 1 if x2 > x1 else -1 + ys = 1 if y2 > y1 else -1 + zs = 1 if z2 > z1 else -1 + + # Driving axis X + if dx >= dy and dx >= dz: + p1 = 2 * dy - dx + p2 = 2 * dz - dx + while x1 != x2: + x1 += xs + if p1 >= 0: + y1 += ys + p1 -= 2 * dx + if p2 >= 0: + z1 += zs + p2 -= 2 * dx + p1 += 2 * dy + p2 += 2 * dz + points.append((x1, y1, z1)) + + # Driving axis Y + elif dy >= dx and dy >= dz: + p1 = 2 * dx - dy + p2 = 2 * dz - dy + while y1 != y2: + y1 += ys + if p1 >= 0: + x1 += xs + p1 -= 2 * dy + if p2 >= 0: + z1 += zs + p2 -= 2 * dy + p1 += 2 * dx + p2 += 2 * dz + points.append((x1, y1, z1)) + + # Driving axis Z + else: + p1 = 2 * dy - dz + p2 = 2 * dx - dz + while z1 != z2: + z1 += zs + if p1 >= 0: + y1 += ys + p1 -= 2 * dz + if p2 >= 0: + x1 += xs + p2 -= 2 * dz + p1 += 2 * dy + p2 += 2 * dx + points.append((x1, y1, z1)) + + return points + + +def center_line_intersections_torch( + position_z: float, + position_y: float, + position_x: float, + azimuth: float, + altitude: float, + length: float, + image_tensor: torch.Tensor, + spacing: list[float], + device:torch.device +) -> tuple[int, torch.Tensor]: + """ + Computes the number of intersections along the 3D center line in the torch-based spine array. + The line generation (Bresenham) is done on CPU, but the intersection counting is done in Torch. + Returns (intersections, line_mask_torch). + """ + azimuth_rad = np.radians(azimuth) + altitude_rad = np.radians(altitude) + + if spacing == [1, 1, 1]: + length2 = length + elif spacing == [0.5, 0.5, 0.5]: + length2 = length * 2 + else: + raise ValueError(f"Unsupported spacing: {spacing}") + + # Direction vectors + direction_z = np.cos(altitude_rad) + direction_y = np.sin(altitude_rad) * np.sin(azimuth_rad) + direction_x = np.sin(altitude_rad) * np.cos(azimuth_rad) + + # Endpoints + end_point = ( + position_z + length2 * direction_z, + position_y + length2 * direction_y, + position_x + length2 * direction_x, + ) + start_opposite = ( + position_z - length2 * direction_z, + position_y - length2 * direction_y, + position_x - length2 * direction_x, + ) + + # Round endpoints + start_point = ( + int(round(position_z)), + int(round(position_y)), + int(round(position_x)), + ) + end_point = tuple(map(int, np.round(end_point))) + start_opposite = tuple(map(int, np.round(start_opposite))) + + # Bresenham line (CPU) + line_points = bresenham3d( + start_opposite[0], start_opposite[1], start_opposite[2], + end_point[0], end_point[1], end_point[2] + ) + + shape = image_tensor.shape # (z, y, x) + line_mask_torch = torch.zeros(shape, device=device, dtype=torch.uint8) + + # Gather values from image_tensor + point_values = [] + for (z_pt, y_pt, x_pt) in line_points: + if 0 <= z_pt < shape[0] and 0 <= y_pt < shape[1] and 0 <= x_pt < shape[2]: + line_mask_torch[z_pt, y_pt, x_pt] = 1 + # Convert to CPU to gather value quickly + point_values.append(image_tensor[z_pt, y_pt, x_pt].item()) + + # Convert to NumPy for the simple difference-based intersection counting + point_values_np = np.array(point_values, dtype=np.int32) + + d1 = np.diff(point_values_np) + d2 = np.diff(d1) + d3 = np.diff(d2) + intersections = ( + np.count_nonzero(np.abs(d1) == 1) + - 2 * np.count_nonzero(np.abs(d2) == 2) + + 2 * np.count_nonzero(np.abs(d3) == 4) + ) + + return intersections, line_mask_torch \ No newline at end of file diff --git a/core/objective.py b/core/objective.py new file mode 100644 index 0000000..f958ad2 --- /dev/null +++ b/core/objective.py @@ -0,0 +1,136 @@ +import torch +from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values, generate_cylinder_tip_torch + +from core.intersection import center_line_intersections_torch +from core.scoring import cl_score_torch + +# Global variables (used in objective_function) +image1_array = None # cortical_nii.gz +image2_array = None # binarynii.gz +image2_shape = None +image3_array = None # roi2.nii.gz +diameter = None +length = None +spacing = [0.5, 0.5, 0.5] +device = None +grid = None +USE_TIP_PENALTY = None + +def set_global_context( + cortical, + spine, + shape, + spacing_, + device_, + grid_, + use_tip_penalty=False # 新增 +): + global cortical_tensor, spine_tensor, image2_shape, spacing, device, grid, USE_TIP_PENALTY + + cortical_tensor = cortical + spine_tensor = spine + image2_shape = shape + spacing = spacing_ + device = device_ + grid = grid_ + USE_TIP_PENALTY = use_tip_penalty + +def cylinder_circle_line_intersection_loss_deductions_torch( + diameter: float, + length: float, + params: list[float], + image_shape: tuple[int, int, int], + cortical_tensor: torch.Tensor, + spine_tensor: torch.Tensor, + spacing: list[float], + device: torch.device +) -> float: + """ + Computes the loss for a given set of cylinder params in PyTorch, + returning a Python float for PSO consumption. + """ + position_z, position_y, position_x, azimuth, altitude = params + + cyl_fwd = generate_cylinder_n_torch( + diameter, + length, + position_z, + position_y, + position_x, + float(azimuth), + float(altitude), + image_shape, + spacing, + device, + grid + ) + + cyl_opp = generate_cylinder_o_torch( + diameter, + length, + position_z, + position_y, + position_x, + float(azimuth), + float(altitude), + image_shape, + spacing, + device, + grid + ) + + + # We call the center_line_intersections in Torch mode + intersections, _ = center_line_intersections_torch( + position_z, + position_y, + position_x, + azimuth, + altitude, + length, + spine_tensor, + spacing, + device + ) + + cyl_tip = None + if USE_TIP_PENALTY: + cyl_tip = generate_cylinder_tip_torch( + diameter, length, + position_z, position_y, position_x, + float(azimuth), float(altitude), + image_shape, spacing, device, grid + ) + + loss_value = cl_score_torch( + cortical_tensor, spine_tensor, + cyl_fwd, cyl_opp, intersections, + cylinder_tip_torch=cyl_tip + ) + + return loss_value + +def objective_function(params: list[float]) -> float: + """ + Wrapper for the PSO objective function, calling our Torch-based loss function. + Now params includes diameter and length at the end. + params = [position_z, position_y, position_x, azimuth, altitude, diameter_raw, length_raw] + """ + position_params = params[:5] # [z, y, x, azimuth, altitude] + diameter_raw = params[5] + length_raw = params[6] + + # 將連續值轉換為離散值 + diameter_discrete, length_discrete = snap_to_discrete_values(diameter_raw, length_raw) + + loss = cylinder_circle_line_intersection_loss_deductions_torch( + diameter_discrete, + length_discrete, + position_params, + image2_shape, + cortical_tensor, + spine_tensor, + spacing, + device + ) + return loss \ No newline at end of file diff --git a/core/optimizer.py b/core/optimizer.py new file mode 100644 index 0000000..e36a9ee --- /dev/null +++ b/core/optimizer.py @@ -0,0 +1,691 @@ +import time +from datetime import datetime +import SimpleITK as sitk +import torch +from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour +from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS +from core.objective import objective_function +from pyswarm import pso +import core.objective # <--- 加入這行,讓我們可以直接操作 objective 模組 +from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid +from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok +from config.constant import OVERLAP_THRESH +from visualization.res_plot_3d import res_plt_2_torch + +def run_pso_torch( + label_str: str, + image1_path: str, + image2_path: str, + image3_path: str, + folder: str, + swarm_size: int, + max_iter: int, + spacing: list, + CBT: bool, + device: torch.device, + optimize_size: bool = True, + grid=None +): + """ + Main function to run PSO. + 如果 optimize_size=True,diameter 和 length 也會被最佳化 + 如果 optimize_size=False,使用預設值(向後兼容) + """ + start_time = time.time() + + # Use global references + global image1_array, image2_array, image2_shape, image3_array + global diameter, length # 這些現在只用於非最佳化模式 + global spine_tensor, cortical_tensor, spine_roi_tensor + + # Load images + image1 = sitk.ReadImage(image1_path) + image2 = sitk.ReadImage(image2_path) + image3 = sitk.ReadImage(image3_path) + image1_array = sitk.GetArrayFromImage(image1) + image2_array = sitk.GetArrayFromImage(image2) + image3_array = sitk.GetArrayFromImage(image3) + image2_shape = image2_array.shape + image_shape = image2_shape + + # Move arrays to torch + cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8) + spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8) + spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8) + + # ================= [跨檔案注入變數:終極防呆版] ================= + import core.objective + + # 1. 注入 Tensors + core.objective.cortical_tensor = cortical_tensor + core.objective.spine_tensor = spine_tensor + core.objective.spine_roi_tensor = spine_roi_tensor + + # 2. 注入 Arrays (以防 objective 裡面偷偷用到 Numpy 陣列) + core.objective.image1_array = image1_array + core.objective.image2_array = image2_array + core.objective.image3_array = image3_array + + # 3. 注入 Shapes (這就是導致這次 NoneType 報錯的真兇!) + core.objective.image2_shape = image2_shape # <--- 解除警報的最關鍵一行 + core.objective.image_shape = image_shape + core.objective.shape = image_shape + + # 4. 注入環境變數 + core.objective.spacing = spacing + core.objective.device = device + core.objective.grid = grid + + # 5. 注入尺寸參數 (兼容固定尺寸模式) + if not optimize_size: + core.objective.diameter = diameter + core.objective.length = length + # ============================================================== + + azi = azimuth_rotation(image2_path) + res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False) + alt = res['superior']['tilt_angle_deg'] + + # 設定基本的 bounds + if CBT == True: + z_bounds = (0, image_shape[0] - 1) + y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1) + x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1) + x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1) + azimuth_bounds_l = ((95-azi), (145-azi)) + azimuth_bounds_r = ((50-azi), (85-azi)) + altitude_bounds = ((60-alt), (75-alt)) + else: + z_bounds = (0, image_shape[0] - 1) + y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1) + x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1) + x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1) + azimuth_bounds_l = (60-azi, 90-azi) + azimuth_bounds_r = (90-azi, 120-azi) + altitude_bounds = (65-alt, 80-alt) + + def eval_overlap_from_position(pos, side: str, optimize_size: bool, + spine_tensor: torch.Tensor, + image_shape, spacing): + """ + 根據 PSO 給的 position 生成 cylinder mask,再算 overlap ratio + side: "L" or "R" 只是方便 debug + """ + if optimize_size: + d, L = snap_to_discrete_values(pos[5], pos[6]) + params_5 = pos[:5] + else: + d, L = diameter, length + params_5 = pos + + cyl_mask = generate_cylinder_n_torch( + d, L, + params_5[0], params_5[1], params_5[2], + params_5[3], params_5[4], + image_shape, spacing, device, grid + ) + + overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor) + return overlap, d, L + + if optimize_size: + # 模式 1:優化 diameter 和 length + print("=== 最佳化模式:最佳化位置、角度、直徑和長度 ===") + + # 設定 diameter 和 length 的 bounds(連續範圍) + diameter_bounds = (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)) + length_bounds = (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS)) + + # bounds 現在有 7 個參數 + lb_l = [z_bounds[0], y_bounds[0], x_bounds_left[0], azimuth_bounds_l[0], + altitude_bounds[0], diameter_bounds[0], length_bounds[0]] + ub_l = [z_bounds[1], y_bounds[1], x_bounds_left[1], azimuth_bounds_l[1], + altitude_bounds[1], diameter_bounds[1], length_bounds[1]] + + lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0], + altitude_bounds[0], diameter_bounds[0], length_bounds[0]] + ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1], + altitude_bounds[1], diameter_bounds[1], length_bounds[1]] + + else: + # 模式 2:固定 diameter 和 length(向後兼容) + print("=== 固定尺寸模式:最佳化位置和角度 ===") + # 使用預設值(需要在調用時提供) + diameter = 4.5 # 或從參數傳入 + length = 45 # 或從參數傳入 + + lb_l = [z_bounds[0], y_bounds[0], x_bounds_left[0], azimuth_bounds_l[0], altitude_bounds[0]] + ub_l = [z_bounds[1], y_bounds[1], x_bounds_left[1], azimuth_bounds_l[1], altitude_bounds[1]] + + lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0], altitude_bounds[0]] + ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1], altitude_bounds[1]] + + best_loss_l = float('inf') + best_loss_r = float('inf') + best_position_l = None + best_position_r = None + + # Left side optimization + print("\n=== 左側 ===") + position_l, loss_l = pso(objective_function, lb_l, ub_l, swarmsize=swarm_size, maxiter=max_iter) + + overlap_l, diameter_l, length_l = eval_overlap_from_position( + position_l, "L", optimize_size, spine_tensor, image_shape, spacing + ) + print(f"[LEFT] overlap: {overlap_l*100:.1f}%") + + if optimize_size: + print(f"[LEFT] Position: {position_l[:5]}") + print(f"[LEFT] Diameter: {diameter_l} mm (raw: {position_l[5]:.2f})") + print(f"[LEFT] Length: {length_l} mm (raw: {position_l[6]:.2f})") + best_position_l = list(position_l[:5]) + [diameter_l, length_l] + else: + print(f"[LEFT] Position: {position_l}") + best_position_l = position_l + + best_loss_l = loss_l + best_overlap_l = overlap_l # 新增 + + # max_retries = 0 + # retries = 0 + + # 左側 retry:loss 要 <=0 且 overlap >= 0.5 才算過關 + # while (best_loss_l > 0 or best_overlap_l < OVERLAP_THRESH) and retries < max_retries: + # position_l, loss_l = pso(objective_function, lb_l, ub_l, swarmsize=swarm_size, maxiter=max_iter) + # overlap_l, diameter_l, length_l = eval_overlap_from_position( + # position_l, "L", optimize_size, spine_tensor, image_shape, spacing + # ) + + # 只要找到更好的 loss(或你想用 loss+overlap 綜合排序也行)就更新 best + # 安全版本:優先選「合格解」;沒有合格解時才用 loss 最小的當備案 + # candidate_pos = (list(position_l[:5]) + [diameter_l, length_l]) if optimize_size else position_l + + # candidate_ok = is_solution_ok(loss_l, overlap_l, OVERLAP_THRESH) + # best_ok = is_solution_ok(best_loss_l, best_overlap_l, OVERLAP_THRESH) + + # if candidate_ok and (not best_ok or loss_l < best_loss_l): + # best_position_l = candidate_pos + # best_loss_l = loss_l + # best_overlap_l = overlap_l + # print(f"[LEFT][retry {retries+1}] ✅ ok | loss={loss_l:.4f}, overlap={overlap_l*100:.1f}%") + # elif (not best_ok) and (loss_l < best_loss_l): + # best 還不合格時,先用更小 loss 的當暫存(至少越來越好) + # best_position_l = candidate_pos + # best_loss_l = loss_l + # best_overlap_l = overlap_l + # print(f"[LEFT][retry {retries+1}] ⚠️ not ok | loss improved={loss_l:.4f}, overlap={overlap_l*100:.1f}%") + # else: + # print(f"[LEFT][retry {retries+1}] ❌ no improve | loss={loss_l:.4f}, overlap={overlap_l*100:.1f}%") + + # retries += 1 + + # Right side optimization + print("\n=== 右側 ===") + position_r, loss_r = pso(objective_function, lb_r, ub_r, swarmsize=swarm_size, maxiter=max_iter) + overlap_r, diameter_r, length_r = eval_overlap_from_position( + position_r, "R", optimize_size, spine_tensor, image_shape, spacing + ) + print(f"[RIGHT] overlap: {overlap_r*100:.1f}%") + + if optimize_size: + diameter_r, length_r = snap_to_discrete_values(position_r[5], position_r[6]) + print(f"[RIGHT] Position: {position_r[:5]}") + print(f"[RIGHT] Diameter: {diameter_r} mm (raw: {position_r[5]:.2f})") + print(f"[RIGHT] Length: {length_r} mm (raw: {position_r[6]:.2f})") + print(f"[RIGHT] Loss: {loss_r}\n") + + best_position_r = list(position_r[:5]) + [diameter_r, length_r] + else: + print(f"[RIGHT] Position: {position_r}") + print(f"[RIGHT] Loss: {loss_r}\n") + best_position_r = position_r + + best_loss_r = loss_r + best_overlap_r = overlap_r + + # 如果需要 retry(loss > 0) + # max_retries = 10 + # retries = 0 + + # while (best_loss_r > 0 or best_overlap_r < OVERLAP_THRESH) and retries < max_retries: + # position_r, loss_r = pso(objective_function, lb_r, ub_r, swarmsize=swarm_size, maxiter=max_iter) + # overlap_r, diameter_r, length_r = eval_overlap_from_position( + # position_r, "R", optimize_size, spine_tensor, image_shape, spacing + # ) + + # 只要找到更好的 loss(或你想用 loss+overlap 綜合排序也行)就更新 best + # 這裡給你一個更安全的版本:優先選「合格解」;沒有合格解時才用 loss 最小的當備案 + # candidate_pos = (list(position_r[:5]) + [diameter_r, length_r]) if optimize_size else position_r + + # candidate_ok = is_solution_ok(loss_r, overlap_r, OVERLAP_THRESH) + # best_ok = is_solution_ok(best_loss_r, best_overlap_r, OVERLAP_THRESH) + + # if candidate_ok and (not best_ok or loss_r < best_loss_r): + # best_position_r = candidate_pos + # best_loss_r = loss_r + # best_overlap_r = overlap_r + # print(f"[RIGHT][retry {retries+1}] ✅ ok | loss={loss_r:.4f}, overlap={overlap_r*100:.1f}%") + # elif (not best_ok) and (loss_r < best_loss_r): + # best 還不合格時,先用更小 loss 的當暫存(至少越來越好) + # best_position_r = candidate_pos + # best_loss_r = loss_r + # best_overlap_r = overlap_r + # print(f"[RIGHT][retry {retries+1}] ⚠️ not ok | loss improved={loss_r:.4f}, overlap={overlap_r*100:.1f}%") + # else: + # print(f"[RIGHT][retry {retries+1}] ❌ no improve | loss={loss_r:.4f}, overlap={overlap_r*100:.1f}%") + + # retries += 1 + + end_time = time.time() + total_time = end_time - start_time + + # 提取最終的 diameter 和 length + if optimize_size: + final_diameter_l = best_position_l[5] + final_length_l = best_position_l[6] + final_diameter_r = best_position_r[5] + final_length_r = best_position_r[6] + + print(f"\n=== 最終結果 ===") + print(f"Left - Diameter: {final_diameter_l} mm, Length: {final_length_l} mm") + print(f"Right - Diameter: {final_diameter_r} mm, Length: {final_length_r} mm") + else: + final_diameter_l = diameter + final_length_l = length + final_diameter_r = diameter + final_length_r = length + + res_plt_2_torch( + spine_tensor, + cortical_tensor, + image_shape, + image2_path, + 'Output', + label_str, + final_diameter_l, + final_length_l, + final_diameter_r, + final_length_r, + best_position_l, + best_position_r, + swarm_size, + max_iter, + total_time, + spacing, + CBT, + device, + grid) + + return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time + +import time +import numpy as np +import SimpleITK as sitk +import torch +from scipy.optimize import differential_evolution +from scipy.optimize import minimize +from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour +from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS +from core.objective import objective_function +from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid +from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok +from config.constant import OVERLAP_THRESH +from visualization.res_plot_3d import res_plt_2_torch + +def run_de_torch( + label_str: str, + image1_path: str, + image2_path: str, + image3_path: str, + folder: str, + swarm_size: int, + max_iter: int, + spacing: list, + CBT: bool, + device: torch.device, + optimize_size: bool = True, + grid=None +): + """ + 使用 Differential Evolution (DE) 進行最佳化 + """ + start_time = time.time() + + global image1_array, image2_array, image2_shape, image3_array + global diameter, length + global spine_tensor, cortical_tensor, spine_roi_tensor + + image1 = sitk.ReadImage(image1_path) + image2 = sitk.ReadImage(image2_path) + image3 = sitk.ReadImage(image3_path) + image1_array = sitk.GetArrayFromImage(image1) + image2_array = sitk.GetArrayFromImage(image2) + image3_array = sitk.GetArrayFromImage(image3) + image2_shape = image2_array.shape + image_shape = image2_shape + + cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8) + spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8) + spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8) + + # ================= [跨檔案注入變數:終極防呆版] ================= + import core.objective + + # 1. 注入 Tensors + core.objective.cortical_tensor = cortical_tensor + core.objective.spine_tensor = spine_tensor + core.objective.spine_roi_tensor = spine_roi_tensor + + # 2. 注入 Arrays (以防 objective 裡面偷偷用到 Numpy 陣列) + core.objective.image1_array = image1_array + core.objective.image2_array = image2_array + core.objective.image3_array = image3_array + + # 3. 注入 Shapes (這就是導致這次 NoneType 報錯的真兇!) + core.objective.image2_shape = image2_shape # <--- 解除警報的最關鍵一行 + core.objective.image_shape = image_shape + core.objective.shape = image_shape + + # 4. 注入環境變數 + core.objective.spacing = spacing + core.objective.device = device + core.objective.grid = grid + + # 5. 注入尺寸參數 (兼容固定尺寸模式) + if not optimize_size: + core.objective.diameter = diameter + core.objective.length = length + # ============================================================== + + azi = azimuth_rotation(image2_path) + res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False) + alt = res['superior']['tilt_angle_deg'] + + if CBT == True: + z_bounds = (0, image_shape[0] - 1) + y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1) + x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1) + x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1) + azimuth_bounds_l = ((95-azi), (145-azi)) + azimuth_bounds_r = ((50-azi), (85-azi)) + altitude_bounds = ((60-alt), (75-alt)) + else: + z_bounds = (0, image_shape[0] - 1) + y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1) + x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1) + x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1) + azimuth_bounds_l = (60-azi, 90-azi) + azimuth_bounds_r = (90-azi, 120-azi) + altitude_bounds = (65-alt, 80-alt) + + def eval_overlap_from_position(pos, side: str, optimize_size: bool, spine_tensor: torch.Tensor, image_shape, spacing): + if optimize_size: + d, L = snap_to_discrete_values(pos[5], pos[6]) + params_5 = pos[:5] + else: + d, L = diameter, length + params_5 = pos + + cyl_mask = generate_cylinder_n_torch( + d, L, params_5[0], params_5[1], params_5[2], params_5[3], params_5[4], + image_shape, spacing, device, grid + ) + overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor) + return overlap, d, L + + if optimize_size: + print("=== DE 最佳化模式:最佳化位置、角度、直徑和長度 ===") + diameter_bounds = (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)) + length_bounds = (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS)) + + bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds, diameter_bounds, length_bounds] + bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds, diameter_bounds, length_bounds] + else: + print("=== DE 固定尺寸模式:最佳化位置和角度 ===") + diameter = 4.5 + length = 45 + + bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds] + bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds] + + # DE 的 popsize 實際粒子數 = popsize * len(bounds) + # 為了跟 PSO 公平比較,我們讓它轉換一下 + de_popsize = max(1, swarm_size // len(bounds_l)) + + # --- 左側最佳化 --- + print("\n=== 左側 (DE) ===") + res_l = differential_evolution(objective_function, bounds_l, popsize=de_popsize, maxiter=max_iter) + position_l, loss_l = res_l.x, res_l.fun + + overlap_l, diameter_l, length_l = eval_overlap_from_position(position_l, "L", optimize_size, spine_tensor, image_shape, spacing) + best_position_l = list(position_l[:5]) + [diameter_l, length_l] if optimize_size else list(position_l) + best_loss_l, best_overlap_l = loss_l, overlap_l + """ + retries = 0 + while (best_loss_l > 0 or best_overlap_l < OVERLAP_THRESH) and retries < 10: + res_l = differential_evolution(objective_function, bounds_l, popsize=de_popsize, maxiter=max_iter) + position_l, loss_l = res_l.x, res_l.fun + overlap_l, diameter_l, length_l = eval_overlap_from_position(position_l, "L", optimize_size, spine_tensor, image_shape, spacing) + + candidate_pos = (list(position_l[:5]) + [diameter_l, length_l]) if optimize_size else list(position_l) + if is_solution_ok(loss_l, overlap_l, OVERLAP_THRESH) and (not is_solution_ok(best_loss_l, best_overlap_l, OVERLAP_THRESH) or loss_l < best_loss_l): + best_position_l, best_loss_l, best_overlap_l = candidate_pos, loss_l, overlap_l + elif (not is_solution_ok(best_loss_l, best_overlap_l, OVERLAP_THRESH)) and (loss_l < best_loss_l): + best_position_l, best_loss_l, best_overlap_l = candidate_pos, loss_l, overlap_l + retries += 1 + """ + # --- 右側最佳化 --- + print("\n=== 右側 (DE) ===") + res_r = differential_evolution(objective_function, bounds_r, popsize=de_popsize, maxiter=max_iter) + position_r, loss_r = res_r.x, res_r.fun + + overlap_r, diameter_r, length_r = eval_overlap_from_position(position_r, "R", optimize_size, spine_tensor, image_shape, spacing) + best_position_r = list(position_r[:5]) + [diameter_r, length_r] if optimize_size else list(position_r) + best_loss_r, best_overlap_r = loss_r, overlap_r + """ + retries = 0 + while (best_loss_r > 0 or best_overlap_r < OVERLAP_THRESH) and retries < 10: + res_r = differential_evolution(objective_function, bounds_r, popsize=de_popsize, maxiter=max_iter) + position_r, loss_r = res_r.x, res_r.fun + overlap_r, diameter_r, length_r = eval_overlap_from_position(position_r, "R", optimize_size, spine_tensor, image_shape, spacing) + + candidate_pos = (list(position_r[:5]) + [diameter_r, length_r]) if optimize_size else list(position_r) + if is_solution_ok(loss_r, overlap_r, OVERLAP_THRESH) and (not is_solution_ok(best_loss_r, best_overlap_r, OVERLAP_THRESH) or loss_r < best_loss_r): + best_position_r, best_loss_r, best_overlap_r = candidate_pos, loss_r, overlap_r + elif (not is_solution_ok(best_loss_r, best_overlap_r, OVERLAP_THRESH)) and (loss_r < best_loss_r): + best_position_r, best_loss_r, best_overlap_r = candidate_pos, loss_r, overlap_r + retries += 1 + """ + total_time = time.time() - start_time + + final_diameter_l = best_position_l[5] if optimize_size else diameter + final_length_l = best_position_l[6] if optimize_size else length + final_diameter_r = best_position_r[5] if optimize_size else diameter + final_length_r = best_position_r[6] if optimize_size else length + + res_plt_2_torch( + spine_tensor, cortical_tensor, image_shape, image2_path, 'Output', label_str, + final_diameter_l, final_length_l, final_diameter_r, final_length_r, + best_position_l, best_position_r, swarm_size, max_iter, total_time, spacing, CBT, device, grid + ) + + return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time + +def run_nm_torch( + label_str: str, + image1_path: str, + image2_path: str, + image3_path: str, + folder: str, + swarm_size: int, # NM 不用 swarm_size,但保留參數以維持介面統一 + max_iter: int, + spacing: list, + CBT: bool, + device: torch.device, + optimize_size: bool = True, + grid=None +): + """ + 使用 Nelder-Mead 進行最佳化 + """ + start_time = time.time() + + global image1_array, image2_array, image2_shape, image3_array + global diameter, length + global spine_tensor, cortical_tensor, spine_roi_tensor + + image1 = sitk.ReadImage(image1_path) + image2 = sitk.ReadImage(image2_path) + image3 = sitk.ReadImage(image3_path) + image1_array = sitk.GetArrayFromImage(image1) + image2_array = sitk.GetArrayFromImage(image2) + image3_array = sitk.GetArrayFromImage(image3) + image2_shape = image2_array.shape + image_shape = image2_shape + + cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8) + spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8) + spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8) + # ================= [跨檔案注入變數:終極防呆版] ================= + import core.objective + + # 1. 注入 Tensors + core.objective.cortical_tensor = cortical_tensor + core.objective.spine_tensor = spine_tensor + core.objective.spine_roi_tensor = spine_roi_tensor + + # 2. 注入 Arrays (以防 objective 裡面偷偷用到 Numpy 陣列) + core.objective.image1_array = image1_array + core.objective.image2_array = image2_array + core.objective.image3_array = image3_array + + # 3. 注入 Shapes (這就是導致這次 NoneType 報錯的真兇!) + core.objective.image2_shape = image2_shape # <--- 解除警報的最關鍵一行 + core.objective.image_shape = image_shape + core.objective.shape = image_shape + + # 4. 注入環境變數 + core.objective.spacing = spacing + core.objective.device = device + core.objective.grid = grid + + # 5. 注入尺寸參數 (兼容固定尺寸模式) + if not optimize_size: + core.objective.diameter = diameter + core.objective.length = length + # ============================================================== + + azi = azimuth_rotation(image2_path) + res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False) + alt = res['superior']['tilt_angle_deg'] + + if CBT == True: + z_bounds = (0, image_shape[0] - 1) + y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1) + x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1) + x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1) + azimuth_bounds_l = ((95-azi), (145-azi)) + azimuth_bounds_r = ((50-azi), (85-azi)) + altitude_bounds = ((60-alt), (75-alt)) + else: + z_bounds = (0, image_shape[0] - 1) + y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1) + x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1) + x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1) + azimuth_bounds_l = (60-azi, 90-azi) + azimuth_bounds_r = (90-azi, 120-azi) + altitude_bounds = (65-alt, 80-alt) + + def eval_overlap_from_position(pos, side: str, optimize_size: bool, spine_tensor: torch.Tensor, image_shape, spacing): + if optimize_size: + d, L = snap_to_discrete_values(pos[5], pos[6]) + params_5 = pos[:5] + else: + d, L = diameter, length + params_5 = pos + + cyl_mask = generate_cylinder_n_torch( + d, L, params_5[0], params_5[1], params_5[2], params_5[3], params_5[4], + image_shape, spacing, device, grid + ) + overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor) + return overlap, d, L + + if optimize_size: + print("=== NM 最佳化模式 ===") + bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds, + (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)), (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS))] + bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds, + (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)), (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS))] + else: + print("=== NM 固定尺寸模式 ===") + diameter, length = 4.5, 45 + bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds] + bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds] + + def get_random_x0(bounds): + # 產生在 Bounds 內的隨機起始點 + return [np.random.uniform(b[0], b[1]) for b in bounds] + + # --- 左側最佳化 --- + print("\n=== 左側 (Nelder-Mead) ===") + x0_l = get_random_x0(bounds_l) + res_l = minimize(objective_function, x0_l, method='Nelder-Mead', bounds=bounds_l, options={'maxiter': max_iter}) + position_l, loss_l = res_l.x, res_l.fun + + overlap_l, diameter_l, length_l = eval_overlap_from_position(position_l, "L", optimize_size, spine_tensor, image_shape, spacing) + best_position_l = list(position_l[:5]) + [diameter_l, length_l] if optimize_size else list(position_l) + best_loss_l, best_overlap_l = loss_l, overlap_l + + retries = 0 + while (best_loss_l > 0 or best_overlap_l < OVERLAP_THRESH) and retries < 10: + x0_l = get_random_x0(bounds_l) # 每次 retry 都換一個隨機起始點 + res_l = minimize(objective_function, x0_l, method='Nelder-Mead', bounds=bounds_l, options={'maxiter': max_iter}) + position_l, loss_l = res_l.x, res_l.fun + overlap_l, diameter_l, length_l = eval_overlap_from_position(position_l, "L", optimize_size, spine_tensor, image_shape, spacing) + + candidate_pos = (list(position_l[:5]) + [diameter_l, length_l]) if optimize_size else list(position_l) + if is_solution_ok(loss_l, overlap_l, OVERLAP_THRESH) and (not is_solution_ok(best_loss_l, best_overlap_l, OVERLAP_THRESH) or loss_l < best_loss_l): + best_position_l, best_loss_l, best_overlap_l = candidate_pos, loss_l, overlap_l + elif (not is_solution_ok(best_loss_l, best_overlap_l, OVERLAP_THRESH)) and (loss_l < best_loss_l): + best_position_l, best_loss_l, best_overlap_l = candidate_pos, loss_l, overlap_l + retries += 1 + + # --- 右側最佳化 --- + print("\n=== 右側 (Nelder-Mead) ===") + x0_r = get_random_x0(bounds_r) + res_r = minimize(objective_function, x0_r, method='Nelder-Mead', bounds=bounds_r, options={'maxiter': max_iter}) + position_r, loss_r = res_r.x, res_r.fun + + overlap_r, diameter_r, length_r = eval_overlap_from_position(position_r, "R", optimize_size, spine_tensor, image_shape, spacing) + best_position_r = list(position_r[:5]) + [diameter_r, length_r] if optimize_size else list(position_r) + best_loss_r, best_overlap_r = loss_r, overlap_r + + retries = 0 + while (best_loss_r > 0 or best_overlap_r < OVERLAP_THRESH) and retries < 10: + x0_r = get_random_x0(bounds_r) + res_r = minimize(objective_function, x0_r, method='Nelder-Mead', bounds=bounds_r, options={'maxiter': max_iter}) + position_r, loss_r = res_r.x, res_r.fun + overlap_r, diameter_r, length_r = eval_overlap_from_position(position_r, "R", optimize_size, spine_tensor, image_shape, spacing) + + candidate_pos = (list(position_r[:5]) + [diameter_r, length_r]) if optimize_size else list(position_r) + if is_solution_ok(loss_r, overlap_r, OVERLAP_THRESH) and (not is_solution_ok(best_loss_r, best_overlap_r, OVERLAP_THRESH) or loss_r < best_loss_r): + best_position_r, best_loss_r, best_overlap_r = candidate_pos, loss_r, overlap_r + elif (not is_solution_ok(best_loss_r, best_overlap_r, OVERLAP_THRESH)) and (loss_r < best_loss_r): + best_position_r, best_loss_r, best_overlap_r = candidate_pos, loss_r, overlap_r + retries += 1 + + total_time = time.time() - start_time + + final_diameter_l = best_position_l[5] if optimize_size else diameter + final_length_l = best_position_l[6] if optimize_size else length + final_diameter_r = best_position_r[5] if optimize_size else diameter + final_length_r = best_position_r[6] if optimize_size else length + + res_plt_2_torch( + spine_tensor, cortical_tensor, image_shape, image2_path, 'Output', label_str, + final_diameter_l, final_length_l, final_diameter_r, final_length_r, + best_position_l, best_position_r, swarm_size, max_iter, total_time, spacing, CBT, device, grid + ) + + return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time \ No newline at end of file diff --git a/core/optimizer_ori.py b/core/optimizer_ori.py new file mode 100644 index 0000000..69e418f --- /dev/null +++ b/core/optimizer_ori.py @@ -0,0 +1,604 @@ +import time +from datetime import datetime +import SimpleITK as sitk +import torch +from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour +from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS +from core.objective import objective_function +from pyswarm import pso +import core.objective # <--- 加入這行,讓我們可以直接操作 objective 模組 +from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid +from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok +from config.constant import OVERLAP_THRESH +from visualization.res_plot_3d import res_plt_2_torch + +def run_pso_torch( + label_str: str, + image1_path: str, + image2_path: str, + image3_path: str, + folder: str, + swarm_size: int, + max_iter: int, + spacing: list, + CBT: bool, + device: torch.device, + optimize_size: bool = True, + grid=None +): + """ + Main function to run PSO. + 如果 optimize_size=True,diameter 和 length 也會被最佳化 + 如果 optimize_size=False,使用預設值(向後兼容) + """ + start_time = time.time() + + # Use global references + global image1_array, image2_array, image2_shape, image3_array + global diameter, length # 這些現在只用於非最佳化模式 + global spine_tensor, cortical_tensor, spine_roi_tensor + + # Load images + image1 = sitk.ReadImage(image1_path) + image2 = sitk.ReadImage(image2_path) + image3 = sitk.ReadImage(image3_path) + image1_array = sitk.GetArrayFromImage(image1) + image2_array = sitk.GetArrayFromImage(image2) + image3_array = sitk.GetArrayFromImage(image3) + image2_shape = image2_array.shape + image_shape = image2_shape + + # Move arrays to torch + cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8) + spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8) + spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8) + + azi = azimuth_rotation(image2_path) + res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False) + alt = res['superior']['tilt_angle_deg'] + + # 設定基本的 bounds + if CBT == True: + z_bounds = (0, image_shape[0] - 1) + y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1) + x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1) + x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1) + azimuth_bounds_l = ((95-azi), (145-azi)) + azimuth_bounds_r = ((50-azi), (85-azi)) + altitude_bounds = ((60-alt), (75-alt)) + else: + z_bounds = (0, image_shape[0] - 1) + y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1) + x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1) + x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1) + azimuth_bounds_l = (60-azi, 90-azi) + azimuth_bounds_r = (90-azi, 120-azi) + altitude_bounds = (65-alt, 80-alt) + + def eval_overlap_from_position(pos, side: str, optimize_size: bool, + spine_tensor: torch.Tensor, + image_shape, spacing): + """ + 根據 PSO 給的 position 生成 cylinder mask,再算 overlap ratio + side: "L" or "R" 只是方便 debug + """ + if optimize_size: + d, L = snap_to_discrete_values(pos[5], pos[6]) + params_5 = pos[:5] + else: + d, L = diameter, length + params_5 = pos + + cyl_mask = generate_cylinder_n_torch( + d, L, + params_5[0], params_5[1], params_5[2], + params_5[3], params_5[4], + image_shape, spacing, device, grid + ) + + overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor) + return overlap, d, L + + if optimize_size: + # 模式 1:優化 diameter 和 length + print("=== 最佳化模式:最佳化位置、角度、直徑和長度 ===") + + # 設定 diameter 和 length 的 bounds(連續範圍) + diameter_bounds = (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)) + length_bounds = (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS)) + + # bounds 現在有 7 個參數 + lb_l = [z_bounds[0], y_bounds[0], x_bounds_left[0], azimuth_bounds_l[0], + altitude_bounds[0], diameter_bounds[0], length_bounds[0]] + ub_l = [z_bounds[1], y_bounds[1], x_bounds_left[1], azimuth_bounds_l[1], + altitude_bounds[1], diameter_bounds[1], length_bounds[1]] + + lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0], + altitude_bounds[0], diameter_bounds[0], length_bounds[0]] + ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1], + altitude_bounds[1], diameter_bounds[1], length_bounds[1]] + + else: + # 模式 2:固定 diameter 和 length(向後兼容) + print("=== 固定尺寸模式:最佳化位置和角度 ===") + # 使用預設值(需要在調用時提供) + diameter = 4.5 # 或從參數傳入 + length = 45 # 或從參數傳入 + + lb_l = [z_bounds[0], y_bounds[0], x_bounds_left[0], azimuth_bounds_l[0], altitude_bounds[0]] + ub_l = [z_bounds[1], y_bounds[1], x_bounds_left[1], azimuth_bounds_l[1], altitude_bounds[1]] + + lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0], altitude_bounds[0]] + ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1], altitude_bounds[1]] + + best_loss_l = float('inf') + best_loss_r = float('inf') + best_position_l = None + best_position_r = None + + # Left side optimization + print("\n=== 左側 ===") + position_l, loss_l = pso(objective_function, lb_l, ub_l, swarmsize=swarm_size, maxiter=max_iter) + + overlap_l, diameter_l, length_l = eval_overlap_from_position( + position_l, "L", optimize_size, spine_tensor, image_shape, spacing + ) + print(f"[LEFT] overlap: {overlap_l*100:.1f}%") + + if optimize_size: + print(f"[LEFT] Position: {position_l[:5]}") + print(f"[LEFT] Diameter: {diameter_l} mm (raw: {position_l[5]:.2f})") + print(f"[LEFT] Length: {length_l} mm (raw: {position_l[6]:.2f})") + best_position_l = list(position_l[:5]) + [diameter_l, length_l] + else: + print(f"[LEFT] Position: {position_l}") + best_position_l = position_l + + best_loss_l = loss_l + best_overlap_l = overlap_l # 新增 + + max_retries = 10 + retries = 0 + + # 左側 retry:loss 要 <=0 且 overlap >= 0.5 才算過關 + while (best_loss_l > 0 or best_overlap_l < OVERLAP_THRESH) and retries < max_retries: + position_l, loss_l = pso(objective_function, lb_l, ub_l, swarmsize=swarm_size, maxiter=max_iter) + overlap_l, diameter_l, length_l = eval_overlap_from_position( + position_l, "L", optimize_size, spine_tensor, image_shape, spacing + ) + + # 只要找到更好的 loss(或你想用 loss+overlap 綜合排序也行)就更新 best + # 安全版本:優先選「合格解」;沒有合格解時才用 loss 最小的當備案 + candidate_pos = (list(position_l[:5]) + [diameter_l, length_l]) if optimize_size else position_l + + candidate_ok = is_solution_ok(loss_l, overlap_l, OVERLAP_THRESH) + best_ok = is_solution_ok(best_loss_l, best_overlap_l, OVERLAP_THRESH) + + if candidate_ok and (not best_ok or loss_l < best_loss_l): + best_position_l = candidate_pos + best_loss_l = loss_l + best_overlap_l = overlap_l + print(f"[LEFT][retry {retries+1}] ✅ ok | loss={loss_l:.4f}, overlap={overlap_l*100:.1f}%") + elif (not best_ok) and (loss_l < best_loss_l): + # best 還不合格時,先用更小 loss 的當暫存(至少越來越好) + best_position_l = candidate_pos + best_loss_l = loss_l + best_overlap_l = overlap_l + print(f"[LEFT][retry {retries+1}] ⚠️ not ok | loss improved={loss_l:.4f}, overlap={overlap_l*100:.1f}%") + else: + print(f"[LEFT][retry {retries+1}] ❌ no improve | loss={loss_l:.4f}, overlap={overlap_l*100:.1f}%") + + retries += 1 + + # Right side optimization + print("\n=== 右側 ===") + position_r, loss_r = pso(objective_function, lb_r, ub_r, swarmsize=swarm_size, maxiter=max_iter) + overlap_r, diameter_r, length_r = eval_overlap_from_position( + position_r, "R", optimize_size, spine_tensor, image_shape, spacing + ) + print(f"[RIGHT] overlap: {overlap_r*100:.1f}%") + + if optimize_size: + diameter_r, length_r = snap_to_discrete_values(position_r[5], position_r[6]) + print(f"[RIGHT] Position: {position_r[:5]}") + print(f"[RIGHT] Diameter: {diameter_r} mm (raw: {position_r[5]:.2f})") + print(f"[RIGHT] Length: {length_r} mm (raw: {position_r[6]:.2f})") + print(f"[RIGHT] Loss: {loss_r}\n") + + best_position_r = list(position_r[:5]) + [diameter_r, length_r] + else: + print(f"[RIGHT] Position: {position_r}") + print(f"[RIGHT] Loss: {loss_r}\n") + best_position_r = position_r + + best_loss_r = loss_r + best_overlap_r = overlap_r + + # 如果需要 retry(loss > 0) + max_retries = 10 + retries = 0 + + while (best_loss_r > 0 or best_overlap_r < OVERLAP_THRESH) and retries < max_retries: + position_r, loss_r = pso(objective_function, lb_r, ub_r, swarmsize=swarm_size, maxiter=max_iter) + overlap_r, diameter_r, length_r = eval_overlap_from_position( + position_r, "R", optimize_size, spine_tensor, image_shape, spacing + ) + + # 只要找到更好的 loss(或你想用 loss+overlap 綜合排序也行)就更新 best + # 這裡給你一個更安全的版本:優先選「合格解」;沒有合格解時才用 loss 最小的當備案 + candidate_pos = (list(position_r[:5]) + [diameter_r, length_r]) if optimize_size else position_r + + candidate_ok = is_solution_ok(loss_r, overlap_r, OVERLAP_THRESH) + best_ok = is_solution_ok(best_loss_r, best_overlap_r, OVERLAP_THRESH) + + if candidate_ok and (not best_ok or loss_r < best_loss_r): + best_position_r = candidate_pos + best_loss_r = loss_r + best_overlap_r = overlap_r + print(f"[RIGHT][retry {retries+1}] ✅ ok | loss={loss_r:.4f}, overlap={overlap_r*100:.1f}%") + elif (not best_ok) and (loss_r < best_loss_r): + # best 還不合格時,先用更小 loss 的當暫存(至少越來越好) + best_position_r = candidate_pos + best_loss_r = loss_r + best_overlap_r = overlap_r + print(f"[RIGHT][retry {retries+1}] ⚠️ not ok | loss improved={loss_r:.4f}, overlap={overlap_r*100:.1f}%") + else: + print(f"[RIGHT][retry {retries+1}] ❌ no improve | loss={loss_r:.4f}, overlap={overlap_r*100:.1f}%") + + retries += 1 + end_time = time.time() + total_time = end_time - start_time + + # 提取最終的 diameter 和 length + if optimize_size: + final_diameter_l = best_position_l[5] + final_length_l = best_position_l[6] + final_diameter_r = best_position_r[5] + final_length_r = best_position_r[6] + + print(f"\n=== 最終結果 ===") + print(f"Left - Diameter: {final_diameter_l} mm, Length: {final_length_l} mm") + print(f"Right - Diameter: {final_diameter_r} mm, Length: {final_length_r} mm") + else: + final_diameter_l = diameter + final_length_l = length + final_diameter_r = diameter + final_length_r = length + + res_plt_2_torch( + spine_tensor, + cortical_tensor, + image_shape, + image2_path, + 'Output', + label_str, + final_diameter_l, + final_length_l, + final_diameter_r, + final_length_r, + best_position_l, + best_position_r, + swarm_size, + max_iter, + total_time, + spacing, + CBT, + device, + grid) + + return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time + +import time +import numpy as np +import SimpleITK as sitk +import torch +from scipy.optimize import differential_evolution +from scipy.optimize import minimize +from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour +from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS +from core.objective import objective_function +from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid +from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok +from config.constant import OVERLAP_THRESH +from visualization.res_plot_3d import res_plt_2_torch + +def run_de_torch( + label_str: str, + image1_path: str, + image2_path: str, + image3_path: str, + folder: str, + swarm_size: int, + max_iter: int, + spacing: list, + CBT: bool, + device: torch.device, + optimize_size: bool = True, + grid=None +): + """ + 使用 Differential Evolution (DE) 進行最佳化 + """ + start_time = time.time() + + global image1_array, image2_array, image2_shape, image3_array + global diameter, length + global spine_tensor, cortical_tensor, spine_roi_tensor + + image1 = sitk.ReadImage(image1_path) + image2 = sitk.ReadImage(image2_path) + image3 = sitk.ReadImage(image3_path) + image1_array = sitk.GetArrayFromImage(image1) + image2_array = sitk.GetArrayFromImage(image2) + image3_array = sitk.GetArrayFromImage(image3) + image2_shape = image2_array.shape + image_shape = image2_shape + + cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8) + spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8) + spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8) + + azi = azimuth_rotation(image2_path) + res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False) + alt = res['superior']['tilt_angle_deg'] + + if CBT == True: + z_bounds = (0, image_shape[0] - 1) + y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1) + x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1) + x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1) + azimuth_bounds_l = ((95-azi), (145-azi)) + azimuth_bounds_r = ((50-azi), (85-azi)) + altitude_bounds = ((60-alt), (75-alt)) + else: + z_bounds = (0, image_shape[0] - 1) + y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1) + x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1) + x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1) + azimuth_bounds_l = (60-azi, 90-azi) + azimuth_bounds_r = (90-azi, 120-azi) + altitude_bounds = (65-alt, 80-alt) + + def eval_overlap_from_position(pos, side: str, optimize_size: bool, spine_tensor: torch.Tensor, image_shape, spacing): + if optimize_size: + d, L = snap_to_discrete_values(pos[5], pos[6]) + params_5 = pos[:5] + else: + d, L = diameter, length + params_5 = pos + + cyl_mask = generate_cylinder_n_torch( + d, L, params_5[0], params_5[1], params_5[2], params_5[3], params_5[4], + image_shape, spacing, device, grid + ) + overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor) + return overlap, d, L + + if optimize_size: + print("=== DE 最佳化模式:最佳化位置、角度、直徑和長度 ===") + diameter_bounds = (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)) + length_bounds = (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS)) + + bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds, diameter_bounds, length_bounds] + bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds, diameter_bounds, length_bounds] + else: + print("=== DE 固定尺寸模式:最佳化位置和角度 ===") + diameter = 4.5 + length = 45 + + bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds] + bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds] + + # DE 的 popsize 實際粒子數 = popsize * len(bounds) + # 為了跟 PSO 公平比較,我們讓它轉換一下 + de_popsize = max(1, swarm_size // len(bounds_l)) + + # --- 左側最佳化 --- + print("\n=== 左側 (DE) ===") + res_l = differential_evolution(objective_function, bounds_l, popsize=de_popsize, maxiter=max_iter) + position_l, loss_l = res_l.x, res_l.fun + + overlap_l, diameter_l, length_l = eval_overlap_from_position(position_l, "L", optimize_size, spine_tensor, image_shape, spacing) + best_position_l = list(position_l[:5]) + [diameter_l, length_l] if optimize_size else list(position_l) + best_loss_l, best_overlap_l = loss_l, overlap_l + + retries = 0 + while (best_loss_l > 0 or best_overlap_l < OVERLAP_THRESH) and retries < 10: + res_l = differential_evolution(objective_function, bounds_l, popsize=de_popsize, maxiter=max_iter) + position_l, loss_l = res_l.x, res_l.fun + overlap_l, diameter_l, length_l = eval_overlap_from_position(position_l, "L", optimize_size, spine_tensor, image_shape, spacing) + + candidate_pos = (list(position_l[:5]) + [diameter_l, length_l]) if optimize_size else list(position_l) + if is_solution_ok(loss_l, overlap_l, OVERLAP_THRESH) and (not is_solution_ok(best_loss_l, best_overlap_l, OVERLAP_THRESH) or loss_l < best_loss_l): + best_position_l, best_loss_l, best_overlap_l = candidate_pos, loss_l, overlap_l + elif (not is_solution_ok(best_loss_l, best_overlap_l, OVERLAP_THRESH)) and (loss_l < best_loss_l): + best_position_l, best_loss_l, best_overlap_l = candidate_pos, loss_l, overlap_l + retries += 1 + + # --- 右側最佳化 --- + print("\n=== 右側 (DE) ===") + res_r = differential_evolution(objective_function, bounds_r, popsize=de_popsize, maxiter=max_iter) + position_r, loss_r = res_r.x, res_r.fun + + overlap_r, diameter_r, length_r = eval_overlap_from_position(position_r, "R", optimize_size, spine_tensor, image_shape, spacing) + best_position_r = list(position_r[:5]) + [diameter_r, length_r] if optimize_size else list(position_r) + best_loss_r, best_overlap_r = loss_r, overlap_r + + retries = 0 + while (best_loss_r > 0 or best_overlap_r < OVERLAP_THRESH) and retries < 10: + res_r = differential_evolution(objective_function, bounds_r, popsize=de_popsize, maxiter=max_iter) + position_r, loss_r = res_r.x, res_r.fun + overlap_r, diameter_r, length_r = eval_overlap_from_position(position_r, "R", optimize_size, spine_tensor, image_shape, spacing) + + candidate_pos = (list(position_r[:5]) + [diameter_r, length_r]) if optimize_size else list(position_r) + if is_solution_ok(loss_r, overlap_r, OVERLAP_THRESH) and (not is_solution_ok(best_loss_r, best_overlap_r, OVERLAP_THRESH) or loss_r < best_loss_r): + best_position_r, best_loss_r, best_overlap_r = candidate_pos, loss_r, overlap_r + elif (not is_solution_ok(best_loss_r, best_overlap_r, OVERLAP_THRESH)) and (loss_r < best_loss_r): + best_position_r, best_loss_r, best_overlap_r = candidate_pos, loss_r, overlap_r + retries += 1 + + total_time = time.time() - start_time + + final_diameter_l = best_position_l[5] if optimize_size else diameter + final_length_l = best_position_l[6] if optimize_size else length + final_diameter_r = best_position_r[5] if optimize_size else diameter + final_length_r = best_position_r[6] if optimize_size else length + + res_plt_2_torch( + spine_tensor, cortical_tensor, image_shape, image2_path, 'Output', label_str, + final_diameter_l, final_length_l, final_diameter_r, final_length_r, + best_position_l, best_position_r, swarm_size, max_iter, total_time, spacing, CBT, device, grid + ) + + return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time + +def run_nm_torch( + label_str: str, + image1_path: str, + image2_path: str, + image3_path: str, + folder: str, + swarm_size: int, # NM 不用 swarm_size,但保留參數以維持介面統一 + max_iter: int, + spacing: list, + CBT: bool, + device: torch.device, + optimize_size: bool = True, + grid=None +): + """ + 使用 Nelder-Mead 進行最佳化 + """ + start_time = time.time() + + global image1_array, image2_array, image2_shape, image3_array + global diameter, length + global spine_tensor, cortical_tensor, spine_roi_tensor + + image1 = sitk.ReadImage(image1_path) + image2 = sitk.ReadImage(image2_path) + image3 = sitk.ReadImage(image3_path) + image1_array = sitk.GetArrayFromImage(image1) + image2_array = sitk.GetArrayFromImage(image2) + image3_array = sitk.GetArrayFromImage(image3) + image2_shape = image2_array.shape + image_shape = image2_shape + + cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8) + spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8) + spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8) + + azi = azimuth_rotation(image2_path) + res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False) + alt = res['superior']['tilt_angle_deg'] + + if CBT == True: + z_bounds = (0, image_shape[0] - 1) + y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1) + x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1) + x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1) + azimuth_bounds_l = ((95-azi), (145-azi)) + azimuth_bounds_r = ((50-azi), (85-azi)) + altitude_bounds = ((60-alt), (75-alt)) + else: + z_bounds = (0, image_shape[0] - 1) + y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1) + x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1) + x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1) + azimuth_bounds_l = (60-azi, 90-azi) + azimuth_bounds_r = (90-azi, 120-azi) + altitude_bounds = (65-alt, 80-alt) + + def eval_overlap_from_position(pos, side: str, optimize_size: bool, spine_tensor: torch.Tensor, image_shape, spacing): + if optimize_size: + d, L = snap_to_discrete_values(pos[5], pos[6]) + params_5 = pos[:5] + else: + d, L = diameter, length + params_5 = pos + + cyl_mask = generate_cylinder_n_torch( + d, L, params_5[0], params_5[1], params_5[2], params_5[3], params_5[4], + image_shape, spacing, device, grid + ) + overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor) + return overlap, d, L + + if optimize_size: + print("=== NM 最佳化模式 ===") + bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds, + (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)), (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS))] + bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds, + (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)), (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS))] + else: + print("=== NM 固定尺寸模式 ===") + diameter, length = 4.5, 45 + bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds] + bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds] + + def get_random_x0(bounds): + # 產生在 Bounds 內的隨機起始點 + return [np.random.uniform(b[0], b[1]) for b in bounds] + + # --- 左側最佳化 --- + print("\n=== 左側 (Nelder-Mead) ===") + x0_l = get_random_x0(bounds_l) + res_l = minimize(objective_function, x0_l, method='Nelder-Mead', bounds=bounds_l, options={'maxiter': max_iter}) + position_l, loss_l = res_l.x, res_l.fun + + overlap_l, diameter_l, length_l = eval_overlap_from_position(position_l, "L", optimize_size, spine_tensor, image_shape, spacing) + best_position_l = list(position_l[:5]) + [diameter_l, length_l] if optimize_size else list(position_l) + best_loss_l, best_overlap_l = loss_l, overlap_l + + retries = 0 + while (best_loss_l > 0 or best_overlap_l < OVERLAP_THRESH) and retries < 10: + x0_l = get_random_x0(bounds_l) # 每次 retry 都換一個隨機起始點 + res_l = minimize(objective_function, x0_l, method='Nelder-Mead', bounds=bounds_l, options={'maxiter': max_iter}) + position_l, loss_l = res_l.x, res_l.fun + overlap_l, diameter_l, length_l = eval_overlap_from_position(position_l, "L", optimize_size, spine_tensor, image_shape, spacing) + + candidate_pos = (list(position_l[:5]) + [diameter_l, length_l]) if optimize_size else list(position_l) + if is_solution_ok(loss_l, overlap_l, OVERLAP_THRESH) and (not is_solution_ok(best_loss_l, best_overlap_l, OVERLAP_THRESH) or loss_l < best_loss_l): + best_position_l, best_loss_l, best_overlap_l = candidate_pos, loss_l, overlap_l + elif (not is_solution_ok(best_loss_l, best_overlap_l, OVERLAP_THRESH)) and (loss_l < best_loss_l): + best_position_l, best_loss_l, best_overlap_l = candidate_pos, loss_l, overlap_l + retries += 1 + + # --- 右側最佳化 --- + print("\n=== 右側 (Nelder-Mead) ===") + x0_r = get_random_x0(bounds_r) + res_r = minimize(objective_function, x0_r, method='Nelder-Mead', bounds=bounds_r, options={'maxiter': max_iter}) + position_r, loss_r = res_r.x, res_r.fun + + overlap_r, diameter_r, length_r = eval_overlap_from_position(position_r, "R", optimize_size, spine_tensor, image_shape, spacing) + best_position_r = list(position_r[:5]) + [diameter_r, length_r] if optimize_size else list(position_r) + best_loss_r, best_overlap_r = loss_r, overlap_r + + retries = 0 + while (best_loss_r > 0 or best_overlap_r < OVERLAP_THRESH) and retries < 10: + x0_r = get_random_x0(bounds_r) + res_r = minimize(objective_function, x0_r, method='Nelder-Mead', bounds=bounds_r, options={'maxiter': max_iter}) + position_r, loss_r = res_r.x, res_r.fun + overlap_r, diameter_r, length_r = eval_overlap_from_position(position_r, "R", optimize_size, spine_tensor, image_shape, spacing) + + candidate_pos = (list(position_r[:5]) + [diameter_r, length_r]) if optimize_size else list(position_r) + if is_solution_ok(loss_r, overlap_r, OVERLAP_THRESH) and (not is_solution_ok(best_loss_r, best_overlap_r, OVERLAP_THRESH) or loss_r < best_loss_r): + best_position_r, best_loss_r, best_overlap_r = candidate_pos, loss_r, overlap_r + elif (not is_solution_ok(best_loss_r, best_overlap_r, OVERLAP_THRESH)) and (loss_r < best_loss_r): + best_position_r, best_loss_r, best_overlap_r = candidate_pos, loss_r, overlap_r + retries += 1 + + total_time = time.time() - start_time + + final_diameter_l = best_position_l[5] if optimize_size else diameter + final_length_l = best_position_l[6] if optimize_size else length + final_diameter_r = best_position_r[5] if optimize_size else diameter + final_length_r = best_position_r[6] if optimize_size else length + + res_plt_2_torch( + spine_tensor, cortical_tensor, image_shape, image2_path, 'Output', label_str, + final_diameter_l, final_length_l, final_diameter_r, final_length_r, + best_position_l, best_position_r, swarm_size, max_iter, total_time, spacing, CBT, device, grid + ) + + return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time \ No newline at end of file diff --git a/core/scoring.py b/core/scoring.py new file mode 100644 index 0000000..c5d7490 --- /dev/null +++ b/core/scoring.py @@ -0,0 +1,151 @@ +import torch +from core.cylinder import generate_cylinder_n_torch, generate_cylinder_tip_torch +from config.constant import OVERLAP_THRESH + +def cl_score_torch( + cortical_tensor: torch.Tensor, + spine_tensor: torch.Tensor, + cylinder_torch: torch.Tensor, + cylinder_o_torch: torch.Tensor, + intersections: int, + diameter: float = None, + length: float = None, + cylinder_tip_torch: torch.Tensor = None # 新增:尖端 mask +) -> float: + """ + 漸進式評分:優先確保找到骨頭,再改善細節 + """ + cyl_total = cylinder_torch.sum().item() + overlap = ((cortical_tensor == 1) & (cylinder_torch == 1)).sum().item() + null_vox = ((cortical_tensor == 0) & (cylinder_torch == 1)).sum().item() + null_vox2 = ((spine_tensor == 1) & (cylinder_o_torch == 1)).sum().item() + + if cyl_total == 0: + return float(1e9) # 極差的情況 + + overlap_ratio = overlap / cyl_total + + score = 0 + + # === 階段 1:首要目標是找到骨頭(overlap > 0) === + if overlap == 0: + # 完全沒有 overlap 是最糟糕的情況 + score -= 500000 # 超大懲罰 + # 如果連 spine 都沒穿過,更糟 + if intersections == 0: + score -= 500000 + return float(-score) + + # === 階段 2:有找到骨頭了,開始改善品質 === + + # 1. Overlap 獎勵(非線性,鼓勵快速提升) + if overlap_ratio < 0.1: + # 0-10%:每增加 1% 給大量獎勵(鼓勵探索) + score += overlap * 5000 # 很高的單位獎勵 + elif overlap_ratio < 0.3: + # 10-30%:中等獎勵 + score += overlap * 3000 + elif overlap_ratio < 0.5: + # 30-50%:正常獎勵 + score += overlap * 2000 + else: + # 50%+:獎勵 + 額外比例獎勵 + score += overlap * 2000 + score += (overlap_ratio - 0.5) * 100000 # 超過 50% 額外大獎 + + # 2. Intersection 控制(稍微放寬) + if intersections == 1: + score += 20000 # 完美 + elif intersections == 0: + score -= 200000 # 嚴重錯誤(但比完全沒 overlap 好) + elif intersections == 2: + score -= 10000 # 可接受但不理想 + else: + score -= intersections * 15000 + + # 3. Null voxel 懲罰(漸進式) + null_ratio = null_vox / cyl_total + + if overlap_ratio < 0.2: + # 如果 overlap 還很少,對 null voxel 寬容一點 + score -= null_vox * 300 + elif overlap_ratio < 0.4: + score -= null_vox * 600 + else: + # overlap 夠高了,開始嚴格要求 + if null_ratio > 0.5: + score -= null_vox * 1500 + else: + score -= null_vox * 800 + + # 4. 反向圓柱懲罰 + score -= null_vox2 * 1000 + + # 5. 尺寸合理性(放寬) + if diameter is not None and length is not None: + if diameter < 2.5 or diameter > 6.0: # 放寬從 (3.0, 5.5) 到 (2.5, 6.0) + score -= 3000 + if length < 25 or length > 60: # 放寬從 (30, 55) 到 (25, 60) + score -= 3000 + + # 6. 尖端 breach 懲罰 + if cylinder_tip_torch is not None: + tip_total = cylinder_tip_torch.sum().item() + if tip_total > 0: + tip_breach = ((cortical_tensor == 0) & (cylinder_tip_torch == 1)).sum().item() + tip_breach_ratio = tip_breach / tip_total + if tip_breach_ratio > 0: + score -= tip_breach * 5000 # 尖端出界懲罰要比一般 null_vox 重很多 + + return float(-score) + +def get_overlap_ratio( + position_params: list, + diameter: float, + length: float, + cortical_tensor: torch.Tensor, + image_shape: tuple, + spacing: list, + device: torch.device, + grid=None +) -> float: + """ + 計算 Cylinder 與 Cortical Bone 的重疊比例 (%) + """ + # 生成 Cylinder Mask + cyl_mask = generate_cylinder_n_torch( + diameter, length, + position_params[0], position_params[1], position_params[2], + position_params[3], position_params[4], + image_shape, spacing, device, grid + ) + + # 計算體積 (Voxel count) + cyl_vol = torch.sum(cyl_mask).item() + + if cyl_vol == 0: + return 0.0 + + # 計算重疊部分 + # 注意:這裡使用 cortical_tensor (與 cl_score_torch 邏輯一致) + overlap_count = ((cortical_tensor == 1) & (cyl_mask == 1)).sum().item() + + return (overlap_count / cyl_vol) * 100.0 + +def compute_overlap_ratio_from_cylinder_mask(cyl_mask: torch.Tensor, + spine_mask: torch.Tensor, + eps: float = 1e-6) -> float: + """ + 一個常見定義: overlap = intersection / cylinder_volume + 你也可以改成 intersection / spine_volume 或 Dice,依你論文/需求一致即可。 + cyl_mask, spine_mask: uint8/bool tensor, same shape + """ + cyl = cyl_mask.bool() + sp = spine_mask.bool() + inter = (cyl & sp).sum().item() + denom = cyl.sum().item() + return float(inter) / float(denom + eps) + +def is_solution_ok(loss: float, overlap: float, overlap_thresh: float = OVERLAP_THRESH) -> bool: + return (loss <= 0) and (overlap >= overlap_thresh) + diff --git a/experiment.ipynb b/experiment.ipynb new file mode 100644 index 0000000..9ccd701 --- /dev/null +++ b/experiment.ipynb @@ -0,0 +1,561 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "52324a94", + "metadata": {}, + "outputs": [], + "source": [ + "# Core\n", + "from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch\n", + "from core.objective import set_global_context\n", + "from core.cylinder import create_coordinate_grid\n", + "\n", + "# Imaging\n", + "from imaging.preprocessing import process_single_image, process_dataset\n", + "\n", + "# Visualization\n", + "from visualization.res_cyl_to_CT import cyl_and_CT\n", + "from visualization.res_cyl_nifti import save_cyl\n", + "\n", + "# Config\n", + "from config.device import get_device\n", + "from config.constant import *\n", + "\n", + "import SimpleITK as sitk\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e680dfe", + "metadata": {}, + "outputs": [], + "source": [ + "# ====== SEGMENTATION: SINGLE DATA ======\n", + "\n", + "image_path = \"/home/cyrou/CBT/CTSpine1K/data/colon/1.3.6.1.4.1.9328.50.4.0783.nii.gz\"\n", + "label_path = \"/home/cyrou/CBT/CTSpine1K/label/colon/1.3.6.1.4.1.9328.50.4.0783_seg.nii.gz\"\n", + "output_dir = \"/home/cyrou/CBT/Dataset\" # set by user\n", + "\n", + "process_single_image(image_path, label_path, output_dir_base=output_dir)\n", + "# If you need to assign special labels, please use labels_to_process = [23, 24] .etc " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1961cf4c", + "metadata": {}, + "outputs": [], + "source": [ + "# ====== SEGMENTATION: DATASET ======\n", + "\n", + "image_dir = \"/home/cyrou/CBT/CTSpine1K/data/colon/\"\n", + "label_dir = \"/home/cyrou/CBT/CTSpine1K/label/colon/\"\n", + "output_dir = \"/home/cyrou/CBT/Dataset\" # set by user\n", + "\n", + "process_dataset(image_dir, label_dir, output_dir, labels_to_process=None)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2bed470b", + "metadata": {}, + "outputs": [], + "source": [ + "# ====== CASE 設定 ======\n", + "cortical_path = \"/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0121/L5_cortical.nii.gz\"\n", + "binary_path = \"/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0121/L5_binary.nii.gz\"\n", + "roi_path = \"/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0121/L5_roi2.nii.gz\"\n", + "\n", + "# ====== PSO ======\n", + "swarm_size = 70\n", + "max_iter = 100\n", + "\n", + "# ====== DEVICE ======\n", + "device = get_device(3) \n", + "\n", + "# ====== OTHER ======\n", + "spacing = [0.5, 0.5, 0.5]\n", + "CBT = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc17e584", + "metadata": {}, + "outputs": [], + "source": [ + "import SimpleITK as sitk\n", + "import torch\n", + "\n", + "cortical_image = sitk.ReadImage(cortical_path)\n", + "binary_image = sitk.ReadImage(binary_path)\n", + "roi_image = sitk.ReadImage(roi_path)\n", + "\n", + "cortical_array = sitk.GetArrayFromImage(cortical_image)\n", + "binary_array = sitk.GetArrayFromImage(binary_image)\n", + "roi_array = sitk.GetArrayFromImage(roi_image) \n", + "image_shape = binary_array.shape\n", + "\n", + "cortical_tensor = torch.tensor(cortical_array, device=device)\n", + "binary_tensor = torch.tensor(binary_array, device=device)\n", + "\n", + "grid = create_coordinate_grid(image_shape, device)\n", + "\n", + "set_global_context(\n", + " cortical=cortical_tensor,\n", + " spine=binary_tensor,\n", + " shape=image_shape,\n", + " spacing_=spacing,\n", + " device_=device,\n", + " grid_=grid,\n", + " use_tip_penalty=False \n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c3a3e47", + "metadata": {}, + "outputs": [], + "source": [ + "best_l, loss_l, best_r, loss_r, total_time = run_de_torch(\n", + " label_str=\"L5\",\n", + " image1_path=cortical_path,\n", + " image2_path=binary_path,\n", + " image3_path=roi_path,\n", + " folder=\"Output\",\n", + " swarm_size=swarm_size,\n", + " max_iter=max_iter,\n", + " spacing=spacing,\n", + " CBT=CBT,\n", + " device=device,\n", + " optimize_size=True,\n", + " grid=grid)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "94e1ee75", + "metadata": {}, + "outputs": [], + "source": [ + "best_l, loss_l, best_r, loss_r, total_time = run_nm_torch(\n", + " label_str=\"L5\",\n", + " image1_path=cortical_path,\n", + " image2_path=binary_path,\n", + " image3_path=roi_path,\n", + " folder=\"Output\",\n", + " swarm_size=swarm_size,\n", + " max_iter=max_iter,\n", + " spacing=spacing,\n", + " CBT=CBT,\n", + " device=device,\n", + " optimize_size=True,\n", + " grid=grid)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be72d637", + "metadata": {}, + "outputs": [], + "source": [ + "best_l, loss_l, best_r, loss_r, total_time = run_pso_torch(\n", + " label_str=\"L5\",\n", + " image1_path=cortical_path,\n", + " image2_path=binary_path,\n", + " image3_path=roi_path,\n", + " folder=\"Output\",\n", + " swarm_size=swarm_size,\n", + " max_iter=max_iter,\n", + " spacing=spacing,\n", + " CBT=CBT,\n", + " device=device,\n", + " optimize_size=True,\n", + " grid=grid)\n", + "\n", + "cylinder_L, cylinder_R = save_cyl(best_l, best_r, spacing, roi_path, output_base='Output', CBT=True)\n", + "\n", + "cyl_and_CT(diameter_l=best_l[5],\n", + " diameter_r=best_r[5],\n", + " length_l=best_l[6],\n", + " length_r=best_r[6],\n", + " roi_image_path=roi_path,\n", + " cylinder_1=cylinder_L,\n", + " cylinder_2=cylinder_R,\n", + " base_folder=\"Output\",\n", + " CBT=CBT)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9723e65b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using GPU 1: NVIDIA A40\n", + "📂 總共偵測到 784 個病人資料夾,準備開始處理...\n", + "\n", + "=========================================\n", + "🚀 === Running 1.3.6.1.4.1.9328.50.4.0001_L1 ===\n", + "=========================================\n", + "👉 執行 CBT 軌跡最佳化...\n", + "Superior 終板傾斜角度: 6.64°\n", + "斜率: 0.1164, R²: 0.7599\n", + "=== 最佳化模式:最佳化位置、角度、直徑和長度 ===\n", + "\n", + "=== 左側 ===\n", + "Stopping search: maximum iterations reached --> 100\n", + "[LEFT] overlap: 15.3%\n", + "[LEFT] Position: [ 27.47724944 40.25868267 37.64458752 100.44943459 58.1295552 ]\n", + "[LEFT] Diameter: 5.0 mm (raw: 4.98)\n", + "[LEFT] Length: 50 mm (raw: 49.97)\n", + "Stopping search: maximum iterations reached --> 100\n", + "[LEFT][retry 1] ⚠️ not ok | loss improved=-15892000.0000, overlap=8.4%\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[2], line 62\u001b[0m\n\u001b[1;32m 59\u001b[0m image_shape \u001b[38;5;241m=\u001b[39m binary_array\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 61\u001b[0m grid \u001b[38;5;241m=\u001b[39m create_coordinate_grid(image_shape, device)\n\u001b[0;32m---> 62\u001b[0m best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time \u001b[38;5;241m=\u001b[39m \u001b[43mrun_pso_torch\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 63\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabel_str\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 64\u001b[0m \u001b[43m \u001b[49m\u001b[43mimage1_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mIMAGE1\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 65\u001b[0m \u001b[43m \u001b[49m\u001b[43mimage2_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mIMAGE2\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 66\u001b[0m \u001b[43m \u001b[49m\u001b[43mimage3_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mIMAGE3\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 67\u001b[0m \u001b[43m \u001b[49m\u001b[43mfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdate\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 68\u001b[0m \u001b[43m \u001b[49m\u001b[43mswarm_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m70\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 69\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_iter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m100\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 70\u001b[0m \u001b[43m \u001b[49m\u001b[43mspacing\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mspacing\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 71\u001b[0m \u001b[43m \u001b[49m\u001b[43mCBT\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 72\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 73\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimize_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 74\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrid\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgrid\u001b[49m\n\u001b[1;32m 75\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 77\u001b[0m cylinder_L, cylinder_R \u001b[38;5;241m=\u001b[39m save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mOutput\u001b[39m\u001b[38;5;124m'\u001b[39m, CBT\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 78\u001b[0m cyl_and_CT(best_pos_l[\u001b[38;5;241m5\u001b[39m], best_pos_r[\u001b[38;5;241m5\u001b[39m], best_pos_l[\u001b[38;5;241m6\u001b[39m], best_pos_r[\u001b[38;5;241m6\u001b[39m], IMAGE3, cylinder_L, cylinder_R, base_folder\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mOutput\u001b[39m\u001b[38;5;124m'\u001b[39m, CBT\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", + "File \u001b[0;32m/mnt/pool0/home/cyrou/CBT/CBT_project/core/optimizer.py:194\u001b[0m, in \u001b[0;36mrun_pso_torch\u001b[0;34m(label_str, image1_path, image2_path, image3_path, folder, swarm_size, max_iter, spacing, CBT, device, optimize_size, grid)\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[38;5;66;03m# 左側 retry:loss 要 <=0 且 overlap >= 0.5 才算過關\u001b[39;00m\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m (best_loss_l \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m best_overlap_l \u001b[38;5;241m<\u001b[39m OVERLAP_THRESH) \u001b[38;5;129;01mand\u001b[39;00m retries \u001b[38;5;241m<\u001b[39m max_retries:\n\u001b[0;32m--> 194\u001b[0m position_l, loss_l \u001b[38;5;241m=\u001b[39m \u001b[43mpso\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobjective_function\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlb_l\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mub_l\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mswarmsize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mswarm_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmaxiter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_iter\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 195\u001b[0m overlap_l, diameter_l, length_l \u001b[38;5;241m=\u001b[39m eval_overlap_from_position(\n\u001b[1;32m 196\u001b[0m position_l, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mL\u001b[39m\u001b[38;5;124m\"\u001b[39m, optimize_size, spine_tensor, image_shape, spacing\n\u001b[1;32m 197\u001b[0m )\n\u001b[1;32m 199\u001b[0m \u001b[38;5;66;03m# 只要找到更好的 loss(或你想用 loss+overlap 綜合排序也行)就更新 best\u001b[39;00m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;66;03m# 安全版本:優先選「合格解」;沒有合格解時才用 loss 最小的當備案\u001b[39;00m\n", + "File \u001b[0;32m~/.conda/envs/envpy/lib/python3.13/site-packages/pyswarm/pso.py:145\u001b[0m, in \u001b[0;36mpso\u001b[0;34m(func, lb, ub, ieqcons, f_ieqcons, args, kwargs, swarmsize, omega, phip, phig, maxiter, minstep, minfunc, debug)\u001b[0m\n\u001b[1;32m 143\u001b[0m x[i, mark1] \u001b[38;5;241m=\u001b[39m lb[mark1]\n\u001b[1;32m 144\u001b[0m x[i, mark2] \u001b[38;5;241m=\u001b[39m ub[mark2]\n\u001b[0;32m--> 145\u001b[0m fx \u001b[38;5;241m=\u001b[39m \u001b[43mobj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;66;03m# Compare particle's best position (if constraints are satisfied)\u001b[39;00m\n\u001b[1;32m 148\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m fx\u001b[38;5;241m<\u001b[39mfp[i] \u001b[38;5;129;01mand\u001b[39;00m is_feasible(x[i, :]):\n", + "File \u001b[0;32m~/.conda/envs/envpy/lib/python3.13/site-packages/pyswarm/pso.py:74\u001b[0m, in \u001b[0;36mpso..\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 71\u001b[0m vlow \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39mvhigh\n\u001b[1;32m 73\u001b[0m \u001b[38;5;66;03m# Check for constraint function(s) #########################################\u001b[39;00m\n\u001b[0;32m---> 74\u001b[0m obj \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m x: \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m f_ieqcons \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 76\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(ieqcons):\n", + "File \u001b[0;32m/mnt/pool0/home/cyrou/CBT/CBT_project/core/objective.py:126\u001b[0m, in \u001b[0;36mobjective_function\u001b[0;34m(params)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[38;5;66;03m# 將連續值轉換為離散值\u001b[39;00m\n\u001b[1;32m 124\u001b[0m diameter_discrete, length_discrete \u001b[38;5;241m=\u001b[39m snap_to_discrete_values(diameter_raw, length_raw)\n\u001b[0;32m--> 126\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mcylinder_circle_line_intersection_loss_deductions_torch\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 127\u001b[0m \u001b[43m \u001b[49m\u001b[43mdiameter_discrete\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 128\u001b[0m \u001b[43m \u001b[49m\u001b[43mlength_discrete\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 129\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 130\u001b[0m \u001b[43m \u001b[49m\u001b[43mimage2_shape\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 131\u001b[0m \u001b[43m \u001b[49m\u001b[43mcortical_tensor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 132\u001b[0m \u001b[43m \u001b[49m\u001b[43mspine_tensor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 133\u001b[0m \u001b[43m \u001b[49m\u001b[43mspacing\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 134\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\n\u001b[1;32m 135\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n", + "File \u001b[0;32m/mnt/pool0/home/cyrou/CBT/CBT_project/core/objective.py:54\u001b[0m, in \u001b[0;36mcylinder_circle_line_intersection_loss_deductions_torch\u001b[0;34m(diameter, length, params, image_shape, cortical_tensor, spine_tensor, spacing, device)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;124;03mComputes the loss for a given set of cylinder params in PyTorch, \u001b[39;00m\n\u001b[1;32m 50\u001b[0m \u001b[38;5;124;03mreturning a Python float for PSO consumption.\u001b[39;00m\n\u001b[1;32m 51\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 52\u001b[0m position_z, position_y, position_x, azimuth, altitude \u001b[38;5;241m=\u001b[39m params\n\u001b[0;32m---> 54\u001b[0m cyl_fwd \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_cylinder_n_torch\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 55\u001b[0m \u001b[43m \u001b[49m\u001b[43mdiameter\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 56\u001b[0m \u001b[43m \u001b[49m\u001b[43mlength\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 57\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_z\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 58\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_y\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 59\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_x\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 60\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mfloat\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mazimuth\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 61\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mfloat\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43maltitude\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 62\u001b[0m \u001b[43m \u001b[49m\u001b[43mimage_shape\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 63\u001b[0m \u001b[43m \u001b[49m\u001b[43mspacing\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 64\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 65\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrid\u001b[49m\n\u001b[1;32m 66\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 68\u001b[0m cyl_opp \u001b[38;5;241m=\u001b[39m generate_cylinder_o_torch(\n\u001b[1;32m 69\u001b[0m diameter,\n\u001b[1;32m 70\u001b[0m length,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 79\u001b[0m grid\n\u001b[1;32m 80\u001b[0m )\n\u001b[1;32m 83\u001b[0m \u001b[38;5;66;03m# We call the center_line_intersections in Torch mode\u001b[39;00m\n", + "File \u001b[0;32m/mnt/pool0/home/cyrou/CBT/CBT_project/core/cylinder.py:83\u001b[0m, in \u001b[0;36mgenerate_cylinder_n_torch\u001b[0;34m(diameter, length, position_z, position_y, position_x, azimuth, altitude, shape, spacing, device, grid)\u001b[0m\n\u001b[1;32m 75\u001b[0m x_rot \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 76\u001b[0m x_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(azimuth_rad_t) \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(altitude_rad_t)\n\u001b[1;32m 77\u001b[0m \u001b[38;5;241m+\u001b[39m y_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(azimuth_rad_t) \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(altitude_rad_t)\n\u001b[1;32m 78\u001b[0m \u001b[38;5;241m-\u001b[39m z_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(altitude_rad_t)\n\u001b[1;32m 79\u001b[0m )\n\u001b[1;32m 80\u001b[0m y_rot \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39mx_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(azimuth_rad_t) \u001b[38;5;241m+\u001b[39m y_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(azimuth_rad_t)\n\u001b[1;32m 81\u001b[0m z_rot \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 82\u001b[0m x_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(azimuth_rad_t) \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(altitude_rad_t)\n\u001b[0;32m---> 83\u001b[0m \u001b[38;5;241m+\u001b[39m y_t \u001b[38;5;241m*\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msin\u001b[49m\u001b[43m(\u001b[49m\u001b[43mazimuth_rad_t\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(altitude_rad_t)\n\u001b[1;32m 84\u001b[0m \u001b[38;5;241m+\u001b[39m z_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(altitude_rad_t)\n\u001b[1;32m 85\u001b[0m )\n\u001b[1;32m 87\u001b[0m \u001b[38;5;66;03m# Handle spacing\u001b[39;00m\n\u001b[1;32m 88\u001b[0m \u001b[38;5;66;03m# You can expand or generalize for more spacing options\u001b[39;00m\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m spacing \u001b[38;5;241m==\u001b[39m [\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1\u001b[39m]:\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "import os\n", + "import traceback\n", + "from datetime import datetime\n", + "from core.optimizer import run_pso_torch\n", + "from visualization.res_cyl_nifti import save_cyl\n", + "from visualization.res_cyl_to_CT import cyl_and_CT\n", + "\n", + "device = get_device(1)\n", + "\n", + "# 只需要傳入 base_dir, patient_id, level 即可自動組裝路徑\n", + "def build_paths(base_dir: str, patient_id: str, level: str):\n", + " case_dir = os.path.join(base_dir, patient_id)\n", + " img1 = os.path.join(case_dir, f\"{level}_cortical.nii.gz\")\n", + " img2 = os.path.join(case_dir, f\"{level}_binary2.nii.gz\")\n", + " img3 = os.path.join(case_dir, f\"{level}_roi2.nii.gz\")\n", + " return img1, img2, img3\n", + "\n", + "if __name__ == \"__main__\":\n", + " base_dir = \"/home/cyrou/CBT/Seg/Resample/standardized/\" # 你的 case 資料夾根目錄\n", + " spacing = [0.5, 0.5, 0.5]\n", + " levels = [\"L1\", \"L2\", \"L3\", \"L4\", \"L5\"] # 鎖定要跑的脊椎節段\n", + " date = datetime.now().strftime(\"%Y%m%d\")\n", + " \n", + " failed = []\n", + " skipped = []\n", + " succeeded = []\n", + "\n", + " # 1. 動態獲取 base_dir 下的所有項目,並篩選出「資料夾」\n", + " # 使用 sorted 讓它照字母/數字順序跑,強迫症看了比較舒服\n", + " all_patients = sorted([d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))])\n", + " \n", + " print(f\"📂 總共偵測到 {len(all_patients)} 個病人資料夾,準備開始處理...\")\n", + "\n", + " # 2. 外層迴圈:遍歷每一個病人資料夾\n", + " for patient_id in all_patients:\n", + " \n", + " # 3. 內層迴圈:遍歷 L1 到 L5\n", + " for level in levels:\n", + " label_str = f\"{patient_id}_{level}\"\n", + " IMAGE1, IMAGE2, IMAGE3 = build_paths(base_dir, patient_id, level)\n", + "\n", + " # 檢查這個 level 的檔案是否齊全\n", + " missing = [p for p in [IMAGE1, IMAGE2, IMAGE3] if not os.path.exists(p)]\n", + " if missing:\n", + " # 缺檔直接跳過,不印出太多落落長的警告,保持終端機乾淨\n", + " print(f\"[SKIP] ⏭️ {label_str} (缺少 {len(missing)} 個檔案)\")\n", + " skipped.append((label_str, missing))\n", + " continue\n", + "\n", + " print(f\"\\n=========================================\")\n", + " print(f\"🚀 === Running {label_str} ===\")\n", + " print(f\"=========================================\")\n", + "\n", + " try:\n", + " # --- 第一階段:CBT = True ---\n", + " print(f\"👉 執行 CBT 軌跡最佳化...\")\n", + " binary_image = sitk.ReadImage(IMAGE2)\n", + " binary_array = sitk.GetArrayFromImage(binary_image)\n", + " image_shape = binary_array.shape\n", + "\n", + " grid = create_coordinate_grid(image_shape, device)\n", + " best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_pso_torch(\n", + " label_str,\n", + " image1_path=IMAGE1,\n", + " image2_path=IMAGE2,\n", + " image3_path=IMAGE3,\n", + " folder=date,\n", + " swarm_size=70,\n", + " max_iter=100,\n", + " spacing=spacing,\n", + " CBT=True,\n", + " device=device,\n", + " optimize_size=True,\n", + " grid=grid\n", + " )\n", + " \n", + " cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base='Output', CBT=True)\n", + " cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6], IMAGE3, cylinder_L, cylinder_R, base_folder='Output', CBT=True)\n", + " \n", + " # --- 第二階段:CBT = False (傳統軌跡 TT) ---\n", + " print(f\"👉 執行傳統軌跡 (TT) 最佳化...\")\n", + " best_pos_l_tt, best_loss_l_tt, best_pos_r_tt, best_loss_r_tt, total_time_tt = run_pso_torch(\n", + " label_str,\n", + " image1_path=IMAGE2, # 注意:你原本的扣這裡是 IMAGE2,如果有改過記得確認\n", + " image2_path=IMAGE2,\n", + " image3_path=IMAGE3,\n", + " folder=date,\n", + " swarm_size=70,\n", + " max_iter=100,\n", + " spacing=spacing,\n", + " CBT=False,\n", + " device=device,\n", + " optimize_size=True,\n", + " grid=grid\n", + " )\n", + " \n", + " cylinder_L_tt, cylinder_R_tt = save_cyl(best_pos_l_tt, best_pos_r_tt, spacing, IMAGE3, output_base='Output', CBT=False)\n", + " cyl_and_CT(best_pos_l_tt[5], best_pos_r_tt[5], best_pos_l_tt[6], best_pos_r_tt[6], IMAGE3, cylinder_L_tt, cylinder_R_tt, base_folder='Output', CBT=False)\n", + "\n", + " # 記錄成功 (把兩次的耗時加起來)\n", + " succeeded.append((label_str, best_loss_l, best_loss_r, total_time + total_time_tt))\n", + " print(f\"[OK] ✅ {label_str} 完成!總耗時={(total_time + total_time_tt):.2f}s\")\n", + "\n", + " except Exception as e:\n", + " print(f\"[FAIL] ❌ {label_str} 發生錯誤: {e}\")\n", + " traceback.print_exc() # 保留剛學到的神技,印出詳細報錯\n", + " failed.append((label_str, str(e)))\n", + "\n", + " # 4. 最終報告\n", + " print(\"\\n\" + \"=\"*40)\n", + " print(\"🏆 ==== 全自動批次處理總結 ====\")\n", + " print(\"=\"*40)\n", + " print(f\"✅ Succeeded (成功): {len(succeeded)} 節段\")\n", + " print(f\"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段\")\n", + " print(f\"❌ Failed (執行錯誤): {len(failed)} 節段\")\n", + " \n", + " if failed:\n", + " print(\"\\n⚠️ 失敗清單:\")\n", + " for fail_label, err_msg in failed:\n", + " print(f\" - {fail_label}: {err_msg}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "bcf55fef", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "usage: ipykernel_launcher.py [-h] --gpu GPU [--total_gpus TOTAL_GPUS]\n", + "ipykernel_launcher.py: error: the following arguments are required: --gpu\n" + ] + }, + { + "ename": "SystemExit", + "evalue": "2", + "output_type": "error", + "traceback": [ + "An exception has occurred, use %tb to see the full traceback.\n", + "\u001b[0;31mSystemExit\u001b[0m\u001b[0;31m:\u001b[0m 2\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cyrou/.conda/envs/envpy/lib/python3.13/site-packages/IPython/core/interactiveshell.py:3585: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.\n", + " warn(\"To exit: use 'exit', 'quit', or Ctrl-D.\", stacklevel=1)\n" + ] + } + ], + "source": [ + "import os\n", + "import traceback\n", + "import argparse\n", + "import torch\n", + "from datetime import datetime\n", + "\n", + "# 請確保你有 import run_pso_torch, save_cyl, cyl_and_CT 等自訂函數\n", + "\n", + "def build_paths(base_dir: str, patient_id: str, level: str):\n", + " case_dir = os.path.join(base_dir, patient_id)\n", + " img1 = os.path.join(case_dir, f\"{level}_cortical.nii.gz\")\n", + " img2 = os.path.join(case_dir, f\"{level}_binary2.nii.gz\")\n", + " img3 = os.path.join(case_dir, f\"{level}_roi2.nii.gz\")\n", + " return img1, img2, img3\n", + "\n", + "if __name__ == \"__main__\":\n", + " # === 1. 設定命令列引數 (GPU 拆分設定) ===\n", + " parser = argparse.ArgumentParser(description=\"Multi-GPU CBT Batch Processing\")\n", + " parser.add_argument('--gpu', type=int, required=True, help='指定這支程式要用的 GPU ID (0, 1, 2, 3)')\n", + " parser.add_argument('--total_gpus', type=int, default=4, help='總共開啟的 GPU 數量')\n", + " args = parser.parse_args()\n", + "\n", + " # === 2. 綁定 GPU ===\n", + " if torch.cuda.is_available():\n", + " device = torch.device(f'cuda:{args.gpu}')\n", + " torch.cuda.set_device(device)\n", + " else:\n", + " device = torch.device('cpu')\n", + " print(\"⚠️ 找不到 CUDA,將使用 CPU\")\n", + "\n", + " print(f\"\\n🚀 啟動批次任務 | 分配至 GPU: [{args.gpu}] | 總共 {args.total_gpus} 個節點協同運算\")\n", + "\n", + " # === 3. 基本參數 ===\n", + " base_dir = \"Seg/Resample/standardized\" \n", + " spacing = [0.5, 0.5, 0.5]\n", + " levels = [\"L1\", \"L2\", \"L3\", \"L4\", \"L5\"] \n", + " date = datetime.now().strftime(\"%Y%m%d\")\n", + " \n", + " failed = []\n", + " skipped = []\n", + " succeeded = []\n", + "\n", + " # 動態獲取所有病人資料夾\n", + " all_patients = sorted([d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))])\n", + " \n", + " # === 4. 餘數分工:只挑選屬於這張 GPU 的病人 ===\n", + " my_patients = [p for i, p in enumerate(all_patients) if i % args.total_gpus == args.gpu]\n", + " \n", + " print(f\"📂 總資料庫有 {len(all_patients)} 個病人。\")\n", + " print(f\"🎯 本 GPU (ID: {args.gpu}) 被分配到 {len(my_patients)} 個病人,準備開始處理...\")\n", + "\n", + " for patient_id in my_patients:\n", + " for level in levels:\n", + " label_str = f\"{patient_id}_{level}\"\n", + " IMAGE1, IMAGE2, IMAGE3 = build_paths(base_dir, patient_id, level)\n", + "\n", + " missing = [p for p in [IMAGE1, IMAGE2, IMAGE3] if not os.path.exists(p)]\n", + " if missing:\n", + " print(f\"[SKIP] ⏭️ {label_str} (缺檔)\")\n", + " skipped.append((label_str, missing))\n", + " continue\n", + "\n", + " print(f\"\\n=========================================\")\n", + " print(f\"🔥 [GPU {args.gpu}] === Running {label_str} ===\")\n", + " print(f\"=========================================\")\n", + "\n", + " try:\n", + " # --- CBT = True ---\n", + " print(f\"👉 [GPU {args.gpu}] 執行 CBT 軌跡最佳化...\")\n", + " binary_image = sitk.ReadImage(IMAGE2)\n", + " binary_array = sitk.GetArrayFromImage(binary_image)\n", + " image_shape = binary_array.shape\n", + "\n", + " grid = create_coordinate_grid(image_shape, device)\n", + " best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_pso_torch(\n", + " label_str,\n", + " image1_path=IMAGE1,\n", + " image2_path=IMAGE2,\n", + " image3_path=IMAGE3,\n", + " folder=date, \n", + " swarm_size=70,\n", + " max_iter=100,\n", + " spacing=spacing,\n", + " CBT=True,\n", + " device=device,\n", + " optimize_size=True,\n", + " grid=grid # 記得把 device 傳進去\n", + " )\n", + " \n", + " cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base='Output', CBT=True)\n", + " cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6], IMAGE3, cylinder_L, cylinder_R, base_folder='Output', CBT=True)\n", + " \n", + " # --- CBT = False ---\n", + " print(f\"👉 [GPU {args.gpu}] 執行傳統軌跡 (TT) 最佳化...\")\n", + " best_pos_l_tt, best_loss_l_tt, best_pos_r_tt, best_loss_r_tt, total_time_tt = run_pso_torch(\n", + " label_str,\n", + " image1_path=IMAGE2, \n", + " image2_path=IMAGE2,\n", + " image3_path=IMAGE3,\n", + " folder=date, \n", + " swarm_size=70,\n", + " max_iter=100,\n", + " spacing=spacing,\n", + " CBT=False,\n", + " device=device,\n", + " optimize_size=True,\n", + " grid=grid # 記得把 device 傳進去\n", + " )\n", + " \n", + " cylinder_L_tt, cylinder_R_tt = save_cyl(best_pos_l_tt, best_pos_r_tt, spacing, IMAGE3, output_base='Output', CBT=False)\n", + " cyl_and_CT(best_pos_l_tt[5], best_pos_r_tt[5], best_pos_l_tt[6], best_pos_r_tt[6], IMAGE3, cylinder_L_tt, cylinder_R_tt, base_folder='Output', CBT=False)\n", + "\n", + " succeeded.append((label_str, best_loss_l, best_loss_r, total_time + total_time_tt))\n", + " print(f\"[OK] ✅ [GPU {args.gpu}] {label_str} 完成!總耗時={(total_time + total_time_tt):.2f}s\")\n", + "\n", + " except Exception as e:\n", + " print(f\"[FAIL] ❌ [GPU {args.gpu}] {label_str} 發生錯誤: {e}\")\n", + " traceback.print_exc()\n", + " failed.append((label_str, str(e)))\n", + "\n", + " # --- GPU 專屬的 Summary ---\n", + " print(\"\\n\" + \"=\"*40)\n", + " print(f\"🏆 ==== GPU {args.gpu} 處理總結 ====\")\n", + " print(\"=\"*40)\n", + " print(f\"✅ Succeeded (成功): {len(succeeded)} 節段\")\n", + " print(f\"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段\")\n", + " print(f\"❌ Failed (執行錯誤): {len(failed)} 節段\")\n", + " \n", + " if failed:\n", + " print(f\"\\n⚠️ GPU {args.gpu} 失敗清單:\")\n", + " for fail_label, err_msg in failed:\n", + " print(f\" - {fail_label}: {err_msg}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "envpy", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/imaging/__init__.py b/imaging/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/imaging/affine.py b/imaging/affine.py new file mode 100644 index 0000000..c6291df --- /dev/null +++ b/imaging/affine.py @@ -0,0 +1,42 @@ +import os +import numpy as np +import nibabel as nib + +def standardize_affine(file_path, output_dir): + + img = nib.load(file_path) + data = img.get_fdata() + affine = img.affine.copy() + + # 初始化翻轉軸 + flip_axes = [] + + # 檢查 X 軸方向 + if affine[0, 0] < 0: + flip_axes.append(0) + affine[0, 0] *= -1 + affine[0, 3] *= -1 # 修正平移部分 + + # 檢查 Y 軸方向 + if affine[1, 1] < 0: + flip_axes.append(1) + affine[1, 1] *= -1 + affine[1, 3] *= -1 # 修正平移部分 + + # 檢查 Z 軸方向 + if affine[2, 2] < 0: + flip_axes.append(2) + affine[2, 2] *= -1 + affine[2, 3] *= -1 # 修正平移部分 + + # 翻轉數據(如果需要) + if flip_axes: + data = np.flip(data, axis=tuple(flip_axes)) + + # 保存修正後的影像 + standardized_img = nib.Nifti1Image(data, affine) + output_path = os.path.join(output_dir, os.path.basename(file_path)) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + nib.save(standardized_img, output_path) + diff --git a/imaging/nifti_io.py b/imaging/nifti_io.py new file mode 100644 index 0000000..66d4704 --- /dev/null +++ b/imaging/nifti_io.py @@ -0,0 +1,53 @@ +import SimpleITK as sitk +import nibabel as nib +import numpy as np + +def read_nifti_sitk(path): + return sitk.ReadImage(path) + +def write_nifti_sitk(img, path): + sitk.WriteImage(img, path) + +def sitk_to_numpy(img): + return sitk.GetArrayFromImage(img) + +def numpy_to_sitk(arr, reference_img): + img = sitk.GetImageFromArray(arr) + img.SetSpacing(reference_img.GetSpacing()) + img.SetOrigin(reference_img.GetOrigin()) + img.SetDirection(reference_img.GetDirection()) + return img + +def sitk_to_nibabel(sitk_image): + data = sitk.GetArrayFromImage(sitk_image) # (z, y, x) + data = np.transpose(data, (2, 1, 0)) # 轉回 (x, y, z) + + origin = np.array(sitk_image.GetOrigin()) + spacing = np.array(sitk_image.GetSpacing()) + direction = np.array(sitk_image.GetDirection()).reshape(3, 3) + + affine = np.eye(4) + affine[:3, :3] = direction * spacing + affine[:3, 3] = origin + + return data, affine + +def nibabel_to_sitk(data, affine): + data_transposed = np.transpose(data, (2, 1, 0)) # (x,y,z) → (z,y,x) + sitk_image = sitk.GetImageFromArray(data_transposed) + + spacing = np.sqrt((affine[:3, :3] ** 2).sum(axis=0)).tolist() + direction = (affine[:3, :3] / spacing).flatten().tolist() + origin = affine[:3, 3].tolist() + + sitk_image.SetSpacing(spacing) + sitk_image.SetDirection(direction) + sitk_image.SetOrigin(origin) + + return sitk_image + +def read_nifti_nib(path): + return nib.load(path) + +def write_nifti_nib(data, affine, path): + nib.save(nib.Nifti1Image(data, affine), path) \ No newline at end of file diff --git a/imaging/orientation.py b/imaging/orientation.py new file mode 100644 index 0000000..6f75892 --- /dev/null +++ b/imaging/orientation.py @@ -0,0 +1,310 @@ +import numpy as np +import SimpleITK as sitk +from scipy.ndimage import center_of_mass +import matplotlib.pyplot as plt + +def azimuth_rotation(image, show_plt=False, save_plt=False, output_path=None): + + img = sitk.ReadImage(image, sitk.sitkUInt8) + arr_zyx = sitk.GetArrayFromImage(img) # (z, y, x) + + max_proj = np.max(arr_zyx, axis=0) # -> (y, x) + + binary_proj = (max_proj > 0).astype(np.uint8) + + ys, xs = np.where(binary_proj > 0) + if len(xs) < 10: + raise ValueError("Not enough foreground points") + + # 2) centroid ←←← 這裡一定會定義 cx, cy + cy = ys.mean() + cx = xs.mean() + centroid = np.array([cy, cx]) + + cy = ys.mean() + cx = xs.mean() + centroid = np.array([cy, cx]) + + y_min = ys.min() + top_row_mask = (ys == y_min) + xs_top_row = xs[top_row_mask] + + # 取這一排的中位數或平均值 + x_center = int(np.median(xs_top_row)) # 或用 np.mean() + top_point = (y_min, x_center) + # print(f"最上排中心點: {top_point}") + + # 計算從 top_point 到 centroid 的向量 + dy = cy - y_min # y 方向的變化 + dx = cx - x_center # x 方向的變化 + + # 計算與 y 軸的夾角 + # 注意:影像座標系中 y 軸向下,所以要特別處理 + angle_rad = np.arctan2(dx, dy) # 弧度 + angle_deg = np.degrees(angle_rad) # 轉成角度 + + if show_plt: + + fig = plt.figure(figsize=(12, 12)) + + # 視覺化時加上角度資訊 + plt.imshow(binary_proj, cmap='gray') + + plt.scatter(x_center, y_min, c='red', s=60, label='Top') + plt.scatter(cx, cy, c='yellow', s=60, label='Centroid') + plt.plot([x_center, cx], [y_min, cy], 'c-', lw=2, + label=f'Angle with y-axis: {angle_deg:.1f}°') + + plt.legend() + plt.title("Top point + centroid-directed line") + plt.axis("off") + + if save_plt: + if output_path is None: + output_path = "azimuth_rotation.png" + fig.savefig(output_path, dpi=200, bbox_inches="tight") + + plt.show() + plt.close(fig) + + return angle_deg + +def split_spine_anterior_posterior(image_path, center_mode='com'): + """ + 從 sagittal view 看,沿著 y 軸(前後方向)將脊椎切成前半部和後半部 + + Parameters: + image_path (str): 影像路徑 + center_mode (str): 'com' 使用質心的 y 座標,'image' 使用圖片中心的 y 座標 + + Returns: + anterior, posterior: 前半部和後半部的 binary mask + """ + + # Sagittal projection: 沿著 x 軸投影 -> (z, y) + img = sitk.ReadImage(image_path, sitk.sitkUInt8) + arr_zyx = sitk.GetArrayFromImage(img) + + # Sagittal projection + max_proj_sagittal = np.max(arr_zyx, axis=2) + binary_proj = (max_proj_sagittal > 0).astype(np.uint8) + + # 決定切割的 y 座標 + if center_mode == 'com': + cz, cy = center_of_mass(binary_proj) + split_y = int(round(cy)) + label = f'Center of Mass (y={split_y})' + + elif center_mode == 'image': + split_y = binary_proj.shape[1] // 2 + label = f'Image Center (y={split_y})' + + # 檢查是否為數字,且範圍在 0 到 1 之間 (不含邊界) + elif isinstance(center_mode, (int, float)) and 0 < center_mode < 1: + ys = np.where(binary_proj > 0)[1] + + if ys.size == 0: # 額外保險:如果投影是空的 + split_y = binary_proj.shape[1] // 2 + else: + y_min, y_max = ys.min(), ys.max() + split_y = int(round(y_min + center_mode * (y_max - y_min))) + + label = f'Custom Ratio {center_mode} (y={split_y})' + + else: + raise ValueError("center_mode 必須是 'com'、'image' 或介於 0 到 1 之間的浮點數 (例如 0.8)") + + # 切割:anterior (y < split_y) 和 posterior (y >= split_y) + anterior = binary_proj.copy() + posterior = binary_proj.copy() + + anterior[:, :split_y] = 0 # 保留spine前半部(image後半部)(y >= split_y) + posterior[:, split_y:] = 0 # 保留spine後半部(image前半部)(y < split_y) + + return anterior, posterior, binary_proj + +def analyze_vertebral_tilt_contour(image_path, edge_type='superior', show_plot=False, debug=False, save_plt=False, output_path=None): + """ + 通過椎體前緣輪廓分析傾斜(可選上或下終板) + + Parameters: + edge_type: 'superior' 上終板, 'inferior' 下終板, 'both' 兩者都分析 + """ + + # 切割出 anterior 部分 + anterior, posterior, binary_proj = split_spine_anterior_posterior(image_path, center_mode='com') + + zs, ys = np.where(anterior > 0) + + if len(zs) == 0: + return None + + from sklearn.linear_model import RANSACRegressor + + results = {} + + # === 根據 edge_type 決定要分析哪些邊 === + edges_to_analyze = [] + if edge_type == 'superior' or edge_type == 'both': + edges_to_analyze.append('superior') + if edge_type == 'inferior' or edge_type == 'both': + edges_to_analyze.append('inferior') + + all_edge_points = {} + all_inliers = {} + all_outliers = {} + all_slopes = {} + all_intercepts = {} + all_angles = {} + + for edge in edges_to_analyze: + # 對每個 y,找對應的邊緣點 + edge_points = [] + unique_ys = np.unique(ys) + + for y in unique_ys: + z_at_y = zs[ys == y] + + if edge == 'superior': + z_edge = z_at_y.max() # 最上面的點(z 最小) + else: # inferior + z_edge = z_at_y.min() # 最下面的點(z 最大) + + edge_points.append([y, z_edge]) + + edge_points = np.array(edge_points) + all_edge_points[edge] = edge_points + + if len(edge_points) < 10: + continue + + # RANSAC 擬合 + X = edge_points[:, 0].reshape(-1, 1) + y_data = edge_points[:, 1] + + ransac = RANSACRegressor( + residual_threshold=5.0, + random_state=42 + ) + ransac.fit(X, y_data) + + inlier_mask = ransac.inlier_mask_ + outlier_mask = ~inlier_mask + + edge_points_inliers = edge_points[inlier_mask] + edge_points_outliers = edge_points[outlier_mask] + + all_inliers[edge] = edge_points_inliers + all_outliers[edge] = edge_points_outliers + + # 獲取擬合結果 + slope = ransac.estimator_.coef_[0] + intercept = ransac.estimator_.intercept_ + + all_slopes[edge] = slope + all_intercepts[edge] = intercept + + # 計算 R² + y_pred = ransac.predict(edge_points_inliers[:, 0].reshape(-1, 1)) + ss_res = np.sum((edge_points_inliers[:, 1] - y_pred) ** 2) + ss_tot = np.sum((edge_points_inliers[:, 1] - np.mean(edge_points_inliers[:, 1])) ** 2) + r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0 + + # 計算傾斜角度 + tilt_angle = np.degrees(np.arctan(slope)) + all_angles[edge] = tilt_angle + + if debug: + print(f"\n=== {edge.upper()} ENDPLATE ===") + print(f"Total points: {len(edge_points)}") + print(f"Inliers: {len(edge_points_inliers)}") + print(f"Outliers: {len(edge_points_outliers)}") + + print(f"{edge.capitalize()} 終板傾斜角度: {tilt_angle:.2f}°") + print(f"斜率: {slope:.4f}, R²: {r_squared:.4f}") + + results[edge] = { + 'tilt_angle_deg': tilt_angle, + 'slope': slope, + 'intercept': intercept, + 'r_squared': r_squared, + 'n_inliers': len(edge_points_inliers), + 'n_outliers': len(edge_points_outliers) + } + + # Visualization + if show_plot: + n_edges = len(edges_to_analyze) + fig, axes = plt.subplots(n_edges, 2, figsize=(16, 6*n_edges)) + + if n_edges == 1: + axes = axes.reshape(1, -1) + + colors = {'superior': 'red', 'inferior': 'cyan'} + + for idx, edge in enumerate(edges_to_analyze): + edge_points = all_edge_points[edge] + edge_points_inliers = all_inliers[edge] + edge_points_outliers = all_outliers[edge] + slope = all_slopes[edge] + intercept = all_intercepts[edge] + tilt_angle = all_angles[edge] + color = colors[edge] + + # 左圖:scatter plot + ax_left = axes[idx, 0] + + if len(edge_points_outliers) > 0: + ax_left.scatter(edge_points_outliers[:, 0], edge_points_outliers[:, 1], + c='lightcoral', s=30, alpha=0.6, marker='x', + label=f'Outliers ({len(edge_points_outliers)})', zorder=3) + + ax_left.scatter(edge_points_inliers[:, 0], edge_points_inliers[:, 1], + c=color, s=20, alpha=0.7, + label=f'Inliers ({len(edge_points_inliers)})', zorder=4) + + # 擬合線 + y_line = np.array([edge_points[:, 0].min(), edge_points[:, 0].max()]) + z_line = slope * y_line + intercept + ax_left.plot(y_line, z_line, 'lime', linewidth=3, + label=f'Angle: {tilt_angle:.1f}°\nR²: {results[edge]["r_squared"]:.3f}', + zorder=5) + + ax_left.set_xlabel('y') + ax_left.set_ylabel('z') + ax_left.set_title(f'{edge.capitalize()} Endplate Analysis\nTilt: {tilt_angle:.1f}°') + ax_left.invert_yaxis() + ax_left.legend() + ax_left.grid(True, alpha=0.3) + + # 右圖:原始影像 + ax_right = axes[idx, 1] + ax_right.imshow(anterior, cmap='gray', aspect='equal') + + if len(edge_points_outliers) > 0: + ax_right.scatter(edge_points_outliers[:, 0], edge_points_outliers[:, 1], + c='red', s=40, alpha=0.7, marker='x', label='Outliers', zorder=4) + + ax_right.scatter(edge_points_inliers[:, 0], edge_points_inliers[:, 1], + c=color, s=25, alpha=0.8, label='Inliers', zorder=3) + + ax_right.plot(y_line, z_line, 'lime', linewidth=3, linestyle='--', + label=f'{edge.capitalize()}: {tilt_angle:.1f}°', zorder=5) + + ax_right.set_title(f'Anterior Half - {edge.capitalize()} Edge') + ax_right.set_xlabel('y') + ax_right.set_ylabel('z') + ax_right.invert_yaxis() + ax_right.legend() + + plt.tight_layout() + + if save_plt: + if output_path is None: + output_path = "analyze_vertebral_tilt_contour.png" + fig.savefig(output_path, dpi=200, bbox_inches="tight") + + plt.show() + plt.close(fig) + + return results diff --git a/imaging/preprocessing.py b/imaging/preprocessing.py new file mode 100644 index 0000000..ababd16 --- /dev/null +++ b/imaging/preprocessing.py @@ -0,0 +1,150 @@ +import os +import SimpleITK as sitk +from imaging.resample import resample_img +from imaging.affine import standardize_affine +from imaging.segmentation import seg_bone +import json +import glob +from config.constant import LABEL_MAP +from imaging.nifti_io import sitk_to_nibabel, nibabel_to_sitk + +PROGRESS_FILE = "progress.json" + +def load_progress(): + if os.path.exists(PROGRESS_FILE): + with open(PROGRESS_FILE, "r") as f: + return json.load(f) + return {} + +def save_progress(progress): + with open(PROGRESS_FILE, "w") as f: + json.dump(progress, f, indent=2) + +def process_single_image(image_path, label_path, output_dir_base=None): + + image = sitk.ReadImage(image_path) + label = sitk.ReadImage(label_path) + file_name = os.path.basename(image_path) + name = file_name.replace(".nii.gz", "") + + # LabelStatisticsImageFilter computes statistics (e.g., mean, minimum, maximum, median) of pixel values in an image, segmented by labels in a corresponding label image. + lsif = sitk.LabelStatisticsImageFilter() + lsif.Execute(image, label) + + # Assume to have some sitk image (itk_image) and label (itk_label) + resampled_sitk_img = resample_img(image, out_spacing=[0.5, 0.5, 0.5], is_label=False) + resampled_sitk_lbl = resample_img(label, out_spacing=[0.5, 0.5, 0.5], is_label=True) + + # 取得現有 label + lssif = sitk.LabelShapeStatisticsImageFilter() + lssif.Execute(label) + existing_labels = lssif.GetLabels() # 這會回傳 list,例如 [1,2,3,20,21] + + print(f"Existing labels in {os.path.basename(label_path)}: {existing_labels}") + + # 建立每個檔案的輸出資料夾 + file_name = os.path.basename(image_path) + name = file_name.replace(".nii.gz", "") + output_dir = os.path.join(output_dir_base, name) + os.makedirs(output_dir, exist_ok=True) + + # 存現有 label 到 txt + txt_path = os.path.join(output_dir, f"{name}_labels.txt") + with open(txt_path, "w") as f: + for lab in existing_labels: + f.write(f"{lab}\t{LABEL_MAP.get(lab, 'Unknown')}\n") + + # 遍歷現有 label 做分割 + for n in existing_labels: + try: + roi_path, binary_path, roi2_path, cortical_path = seg_bone(n, name, resampled_sitk_img, resampled_sitk_lbl, output_dir, label_map=LABEL_MAP) + for path in [roi_path, binary_path, roi2_path, cortical_path]: + standardize_affine(path, output_dir) + except RuntimeError as e: + print(f"Label {n} could not be processed, skipping. Error: {e}") + + return { + "processed_labels": [lab for lab in existing_labels], + "missing_labels": [] + } + +def process_dataset(image_dir, label_dir, output_dir, labels_to_process=None): + image_files = sorted(glob.glob(os.path.join(image_dir, "*.nii.gz"))) + total_files = len(image_files) + print(f"Total files: {total_files}") + + progress = load_progress() + all_file_summary = [] + + for idx, image_path in enumerate(image_files, 1): + file_name = os.path.basename(image_path) + name = file_name.replace(".nii.gz", "") + label_path = os.path.join(label_dir, file_name.replace(".nii.gz", "_seg.nii.gz")) + + file_summary = { + "file_name": file_name, + "current_labels": [], + "missing_labels": [] + } + + if progress.get(name, {}).get("finished", False): + print(f"[{idx}/{total_files}] Already finished: {file_name}") + file_summary["current_labels"] = progress[name].get("processed_labels", []) + file_summary["missing_labels"] = progress[name].get("missing_labels", []) + all_file_summary.append(file_summary) + continue + + if not os.path.exists(label_path): + print(f"[{idx}/{total_files}] Warning: label not found for {file_name}") + file_summary["note"] = "Label file not found" + all_file_summary.append(file_summary) + continue + + try: + result = process_single_image(image_path, label_path, output_dir_base=output_dir) + except Exception as e: + print(f"[{idx}/{total_files}] Error processing {file_name}: {e}") + file_summary["note"] = f"Error: {e}" + all_file_summary.append(file_summary) + continue + + file_summary["current_labels"] = result["processed_labels"] + file_summary["missing_labels"] = result["missing_labels"] + all_file_summary.append(file_summary) + + progress[name] = { + "finished": True, + "processed_labels": result["processed_labels"], + "missing_labels": result["missing_labels"] + } + save_progress(progress) + + print(f"[{idx}/{total_files}] Finished: {file_name} | Missing labels: {result['missing_labels'] or 'None'}") + + # --- Summary --- + summary_path = os.path.join(output_dir, "all_files_label_summary.txt") + os.makedirs(os.path.dirname(summary_path), exist_ok=True) + print(f"\nWriting summary to {summary_path}...") + + with open(summary_path, "w") as f: + f.write("--- CTSpine1K Dataset Label Summary ---\n") + f.write(f"Total files processed: {total_files}\n\n") + + for summary in all_file_summary: + f.write("================================================\n") + f.write(f"File: {summary['file_name']}\n") + + processed_labels_str = ", ".join([str(l) for l in summary['current_labels']]) + f.write(f"Labels processed: {processed_labels_str}\n") + + if summary['missing_labels']: + missing_str = ", ".join([f"{l} ({LABEL_MAP.get(l, 'Unknown')})" for l in summary['missing_labels']]) + f.write(f"🚨 Missing Labels: {missing_str}\n") + else: + f.write("✅ Missing Labels: None\n") + + if "note" in summary: + f.write(f"Note: {summary['note']}\n") + f.write("================================================\n\n") + + print("All done! Summary file created.") \ No newline at end of file diff --git a/imaging/resample.py b/imaging/resample.py new file mode 100644 index 0000000..4051cf2 --- /dev/null +++ b/imaging/resample.py @@ -0,0 +1,29 @@ +import numpy as np +import SimpleITK as sitk +from config.constant import LABEL_MAP + +def resample_img(sitk_image, out_spacing=[0.5, 0.5, 0.5], is_label=False): + + # Resample images to 2mm spacing with SimpleITK + original_spacing = sitk_image.GetSpacing() + original_size = sitk_image.GetSize() + + out_size = [ + int(np.round(original_size[0] * (original_spacing[0] / out_spacing[0]))), + int(np.round(original_size[1] * (original_spacing[1] / out_spacing[1]))), + int(np.round(original_size[2] * (original_spacing[2] / out_spacing[2])))] + + resample = sitk.ResampleImageFilter() + resample.SetOutputSpacing(out_spacing) + resample.SetSize(out_size) + resample.SetOutputDirection(sitk_image.GetDirection()) + resample.SetOutputOrigin(sitk_image.GetOrigin()) + resample.SetTransform(sitk.Transform()) + resample.SetDefaultPixelValue(sitk_image.GetPixelIDValue()) + + if is_label: + resample.SetInterpolator(sitk.sitkNearestNeighbor) + else: + resample.SetInterpolator(sitk.sitkBSpline) + + return resample.Execute(sitk_image) \ No newline at end of file diff --git a/imaging/segmentation.py b/imaging/segmentation.py new file mode 100644 index 0000000..56a4744 --- /dev/null +++ b/imaging/segmentation.py @@ -0,0 +1,75 @@ +import os +import SimpleITK as sitk +from config.constant import LABEL_MAP +import numpy as np + +""" +# 沿用原本 LABEL_MAP +seg_bone(n, name, img, lbl) + +# user 自定義 +my_map = {1: "L1", 2: "L2", 3: "L3"} +seg_bone(n, name, img, lbl, label_map=my_map) +""" + +def seg_bone(n, name, resampled_sitk_img, resampled_sitk_lbl, output_base=None, label_map=LABEL_MAP): + + if output_base==None: + output_base=='Dataset' + + if n not in label_map: + raise ValueError(f"Label {n} not found in label_map") + + label_name = label_map[n] + + lssif = sitk.LabelShapeStatisticsImageFilter() + lssif.Execute(resampled_sitk_lbl) + + if not lssif.HasLabel(n): + raise RuntimeError(f"Label {n} not found") + + bbox2 = lssif.GetBoundingBox(n) + + roi = sitk.RegionOfInterest(resampled_sitk_img, bbox2[3:], bbox2[:3]) + label2 = sitk.RegionOfInterest(resampled_sitk_lbl, bbox2[3:], bbox2[:3]) + roi_path = os.path.join(output_base, f"{label_name}_roi.nii.gz") + sitk.WriteImage(roi, roi_path) + + binary = sitk.BinaryThreshold(label2, lowerThreshold=n, upperThreshold=n, outsideValue=0, insideValue=1) + binary_path = os.path.join(output_base, f"{label_name}_binary.nii.gz") + sitk.WriteImage(binary, binary_path) + + roi_pixel_type = roi.GetPixelID() + binary_cast = sitk.Cast(binary, roi_pixel_type) + roi2 = roi * binary_cast + roi2_path = os.path.join(output_base, f"{label_name}_roi2.nii.gz") + sitk.WriteImage(roi2, roi2_path) + + lsif = sitk.LabelStatisticsImageFilter() + label2_int = sitk.Cast(label2, sitk.sitkUInt16) + lsif.Execute(roi2, label2_int) + labels_in_roi = lsif.GetLabels() + if n in labels_in_roi: + roi_hu = sitk.GetArrayFromImage(roi2) + threshold = np.percentile(roi_hu, 60) + else: + threshold = lsif.GetMedian(labels_in_roi[0]) + + cortical = sitk.BinaryThreshold(roi2, lowerThreshold=threshold, upperThreshold=10000, outsideValue=0, insideValue=1) + cortical_path = os.path.join(output_base, f"{label_name}_cortical.nii.gz") + sitk.WriteImage(cortical, cortical_path) + + return roi_path, binary_path, roi2_path, cortical_path + +""" +Dataset/ +└── standardized/ + └── subject001/ + ├── L1_roi.nii.gz + ├── L1_binary.nii.gz + ├── L1_roi2.nii.gz + ├── L1_cortical.nii.gz + ├── L2_roi.nii.gz + ├── L2_binary.nii.gz + ... +""" \ No newline at end of file diff --git a/optimization_benchmark_results_DE.csv b/optimization_benchmark_results_DE.csv new file mode 100644 index 0000000..6fe4d5d --- /dev/null +++ b/optimization_benchmark_results_DE.csv @@ -0,0 +1,31 @@ +Patient_ID,Method,Run,Total_Time_sec,Left_Cortical_Score,Left_Bone_Score,Left_C_B_Ratio,Right_Cortical_Score,Right_Bone_Score,Right_C_B_Ratio,Left_Loss,Right_Loss +000121,DE,1,35.753113746643066,79.55674436377531,96.79021780664884,0.821950256612712,87.70637541275083,97.85369570739142,0.8963011031797534,-11164556.744363775,-13010306.375412751 +000121,DE,2,35.85626721382141,79.35336048879837,99.08350305498982,0.8008735868448099,87.6778455284553,97.73882113821138,0.8970626462178322,-11129753.360488798,-13004677.845528455 +000121,DE,3,31.091505527496338,79.55239064089523,99.12258392675484,0.8025657472738935,87.7418533604888,98.30702647657841,0.8925288100479088,-11130152.390640896,-12984341.853360489 +000141,DE,1,140.6565911769867,86.313240728941,96.68663183382185,0.8927112165546328,91.54715065363625,96.59855311587765,0.9477072657995007,-12680113.24072894,-13716747.150653636 +000141,DE,2,192.32001185417175,86.34512598625605,96.27131585645202,0.8968935888962327,91.14729076570846,96.60391757822437,0.9435154707044107,-12705945.125986256,-13703347.290765708 +000141,DE,3,205.5456428527832,85.99695585996956,96.61339421613394,0.8901142181961402,90.93793628633075,96.54778525193552,0.9418956224530038,-12685796.95585997,-13712737.93628633 +000151,DE,1,184.93012046813965,84.59359954687058,96.37496459926366,0.8777549221275346,81.10938163294185,92.27394934201217,0.8790062873792364,-11051193.59954687,-10356109.381632943 +000151,DE,2,255.19188165664673,84.5317135188586,96.27065969769741,0.8780630961115188,80.6579133135677,91.6278413101793,0.8802773497688753,-11085531.713518858,-10327657.913313568 +000151,DE,3,318.5383884906769,84.57409238593021,96.08701794038707,0.8801822993237282,80.9469964664311,92.09893992932862,0.878913443830571,-11070974.092385931,-10367546.996466432 +000161,DE,1,293.82185339927673,78.75095201827875,97.14394516374715,0.8106624852998824,69.2503896840017,93.45330877143262,0.7410159211523882,-9339550.95201828,-7986250.389684001 +000161,DE,2,191.1857042312622,68.1406907098254,90.35300114693513,0.7541607898448519,71.80098553489111,84.10427594976952,0.8537138537138537,-8717140.690709826,-7609600.9855348915 +000161,DE,3,231.82087516784668,78.40345748061523,95.01716028981822,0.8251505016722409,69.55722167208941,93.80393266374311,0.7415171165736691,-9315203.457480615,-8031957.22167209 +000171,DE,1,190.15103244781494,79.46289752650176,94.53003533568905,0.8406100478468899,89.35980654193713,96.29629629629629,0.9279672217816549,-10053062.897526503,-13240559.806541936 +000171,DE,2,202.20453667640686,79.63538722442057,94.1351045788581,0.8459690737126557,93.69841269841271,97.98412698412699,0.9562611372104326,-10089835.38722442,-11529098.412698412 +000171,DE,3,206.49177026748657,79.90092002830856,94.8761500353857,0.8421602267641354,89.3779417376924,95.76389772293601,0.9333156216790648,-10080900.920028308,-13289377.941737693 +000181,DE,1,240.98572325706482,73.77466581795035,97.83577339274348,0.7540663630448926,80.7670534956128,91.80583073874894,0.8797595190380763,-9471774.66581795,-10197567.053495612 +000181,DE,2,225.65672206878662,72.20582032897511,91.57879938141431,0.7884556340190358,80.67940552016985,92.04529370134466,0.8765185299092726,-8721605.820328975,-10186679.405520169 +000181,DE,3,229.29561853408813,73.67884884757417,97.75881828600535,0.7536798228474664,80.4436908294475,91.64900381517592,0.8777366635831021,-9511078.848847574,-10265243.690829448 +000191,DE,1,157.32477927207947,80.80101716465353,95.2638270820089,0.8481815148481815,79.08787541713015,95.67773716828222,0.8266068759342302,-9113401.017164653,-8877287.87541713 +000191,DE,2,183.22506833076477,80.77839555202542,95.20254169976171,0.8484899048890372,79.23358244553982,95.80219430752108,0.8270539419087135,-9101778.395552026,-8926433.58244554 +000191,DE,3,175.80304598808289,80.66000317309218,95.36728541964145,0.8457827316586258,78.95238095238095,95.38095238095238,0.8277583624563155,-9124460.003173092,-8896152.38095238 +000201,DE,1,236.4226109981537,77.4121271530598,91.79386640526536,0.8433257055682686,71.38433515482696,87.14025500910748,0.8191889632107022,-9615012.12715306,-6598584.335154827 +000201,DE,2,253.00249886512756,77.48911669709311,91.81294761971634,0.8439889874579382,76.36016544702512,91.01177219217308,0.8390141583639222,-9598089.116697093,-7789560.165447025 +000201,DE,3,218.03053379058838,77.02247191011236,91.26404494382022,0.8439519852262235,71.35056425191118,87.25882781215873,0.8176887776387151,-9574222.471910112,-6583150.564251911 +000211,DE,1,187.18615746498108,80.2594472645234,93.30231246474902,0.860208553725253,72.7966425028615,85.05659417525118,0.8558612440191388,-10114259.447264524,-9730596.642502861 +000211,DE,2,210.16704440116882,80.55634001694436,94.22479525557752,0.8549378090813726,78.69571367944546,85.98104399490735,0.9152681803224744,-10122956.340016944,-9908895.713679446 +000211,DE,3,254.02969360351562,80.57950530035336,93.80918727915194,0.8589724273014917,78.66120604434403,86.21663606835193,0.9123669123669124,-10126379.505300354,-9916861.206044344 +000221,DE,1,175.14598894119263,74.39877847054332,90.29138567247742,0.8239853438556932,83.13391038696538,94.51374745417516,0.8795959595959596,-10091798.778470544,-11923133.910386965 +000221,DE,2,192.38243532180786,74.53368861819565,90.2042887958381,0.8262765508510338,83.17411794675837,94.95605655330532,0.8759221998658617,-10144933.688618196,-11969374.117946759 +000221,DE,3,193.16686582565308,74.54545454545455,90.38779402415767,0.8247292164861443,80.73070955000705,92.04401184934406,0.8770881226053638,-10129945.454545455,-10333930.709550006 diff --git a/optimization_benchmark_results_NM.csv b/optimization_benchmark_results_NM.csv new file mode 100644 index 0000000..e12e415 --- /dev/null +++ b/optimization_benchmark_results_NM.csv @@ -0,0 +1,30 @@ +Patient_ID,Method,Run,Total_Time_sec,Left_Cortical_Score,Left_Bone_Score,Left_C_B_Ratio,Right_Cortical_Score,Right_Bone_Score,Right_C_B_Ratio,Left_Loss,Right_Loss +000121,NM,1,62.64465260505676,79.69253657327052,93.47879990081826,0.8525198938992041,76.40900366428197,86.93072762170651,0.878964271376957,-4455492.53657327,-7722809.003664281 +000121,NM,2,112.54990696907043,19.996510207642647,24.2017099982551,0.8262436914203318,19.996510207642647,25.423137323329264,0.7865477007549759,-1815500.0,-2052500.0 +000121,NM,3,79.86152529716492,19.99636759898293,45.67744278968398,0.4377733598409543,74.47483154974238,81.78755449861276,0.9105888054276715,-1971500.0,-6530074.831549742 +000141,NM,1,56.64281606674194,60.010930952814725,70.75970122062306,0.8480947476828014,92.42144177449168,99.74121996303143,0.9266123054114158,-4730010.930952814,-4716421.441774491 +000141,NM,2,48.48319458961487,72.2156862745098,88.68627450980392,0.8142825558257794,81.45263827082009,95.61347743165925,0.8518949468085106,-6244615.68627451,-9345852.63827082 +000141,NM,3,63.44100284576416,29.994107248084855,48.00628560204282,0.6247954173486088,29.993650793650794,50.80634920634921,0.5903524118970257,-2409600.0,-3760200.0 +000151,NM,1,41.89803433418274,85.00988979937836,95.0127154563436,0.894721189591078,72.36600197991797,86.67798048366568,0.834883341491271,-9918209.889799379,-8683166.001979917 +000151,NM,2,107.57207727432251,39.95674850527923,58.84747487596997,0.6789883268482492,29.994175888177054,57.94991263832265,0.5175879396984925,-3234000.0,-631800.0 +000151,NM,3,84.16005349159241,29.998184129289996,67.75013619030325,0.4427767354596623,83.55547550432276,97.24423631123919,0.859233191331728,-2620000.0,-8386155.475504322 +000161,NM,1,96.87153625488281,19.997175540177942,20.81626888857506,0.960651289009498,29.998409416255768,60.076348019723234,0.4993380990203866,-2538500.0,-2675400.0 +000161,NM,2,72.83068442344666,29.998588965711864,31.451954282489066,0.9537909376401973,55.50992470910335,74.19575633127995,0.7481549815498154,-3391400.0,-1076509.9247091033 +000161,NM,3,57.63184142112732,57.56774619960344,70.74245428508482,0.813765182186235,29.998409416255768,69.9538730714172,0.42883128694861306,-3624767.746199603,-1873400.0 +000171,NM,1,63.1676971912384,39.99206506645507,58.5399722277326,0.6831582514401897,29.998014691284496,51.737145126067105,0.5798158096699924,-963000.0,-2371400.0 +000171,NM,2,57.605836153030396,29.994765311463965,65.34636189146745,0.4590120160213618,39.98435054773083,80.67292644757433,0.495635305528613,-1853800.0,-2237200.0 +000171,NM,3,77.0465567111969,19.996513249651322,22.7510460251046,0.878927203065134,39.98223575688365,54.19363024996828,0.737766331070007,-2054300.0,-3359000.0 +000181,NM,2,83.35590553283691,29.998727249586356,60.73564973908616,0.4939228834870076,29.994756161510228,36.89914350638001,0.8128848886783515,-3633000.0,-2581000.0 +000181,NM,3,25.232823610305786,39.99717633771001,71.59395736269943,0.5586669295996844,52.23455735838315,77.25590663250783,0.6761238025055268,-3027000.0,-1836834.5573583832 +000191,NM,1,67.51091599464417,81.00364096881431,95.77331011556119,0.845785123966942,29.992557677995535,65.86454973951874,0.455367231638418,-9169003.640968814,-1268800.0 +000191,NM,2,103.78111028671265,19.99607920015683,43.952166241913346,0.4549509366636932,19.9836867862969,30.859162588363244,0.6475770925110131,-1825700.0,-2634200.0 +000191,NM,3,90.03080677986145,29.995223690495145,46.63270179907658,0.6432229429839537,39.45765376301827,86.40204362350167,0.45667500568569475,-3003800.0,-1880400.0 +000201,NM,1,52.31811261177063,79.57905722734388,93.3205774917377,0.852749301025163,29.995029821073558,63.24552683896621,0.4742632612966601,-8079379.057227343,-1912800.0 +000201,NM,2,44.66054558753967,39.9974548231102,57.64825655383049,0.6938189845474613,64.27206116597642,95.63555272379739,0.6720519653564291,-2867000.0,-6088872.061165976 +000201,NM,3,36.97987747192383,85.58890147225368,92.46885617214043,0.9255970606246172,29.99825388510564,73.19713637157325,0.4098282442748092,-5619388.901472254,-1344600.0 +000211,NM,1,72.89356565475464,83.39688592310137,95.9644105497299,0.8690397350993377,9.998585772875124,30.26446047235186,0.33037383177570095,-9519396.8859231,-1229800.0 +000211,NM,2,118.73922371864319,19.99593578540947,30.156472261735416,0.6630727762803235,19.99636098981077,39.082969432314414,0.5116387337057727,-1760900.0,-1947900.0 +000211,NM,3,115.41684770584106,84.35417327832434,97.03268803554427,0.8693376941946035,9.998432847516064,25.952045133991536,0.38526570048309183,-9507554.173278324,-1457100.0 +000221,NM,1,73.53109002113342,9.99825205383674,17.811571403600766,0.5613346418056918,88.07339449541286,98.01636498884206,0.8985580571717684,-1305300.0,-6585273.394495413 +000221,NM,2,111.19093751907349,19.996390543223246,21.746977079949467,0.9195020746887967,29.998725627628392,48.426150121065376,0.6194736842105263,-1724100.0,-3706200.0 +000221,NM,3,74.673180103302,19.997459026807267,29.005208995045102,0.6894437144108629,39.996080736821476,54.399372917891434,0.7352305475504323,-2822900.0,-2181800.0 diff --git a/optimization_benchmark_results_PSO.csv b/optimization_benchmark_results_PSO.csv new file mode 100644 index 0000000..657387a --- /dev/null +++ b/optimization_benchmark_results_PSO.csv @@ -0,0 +1,4 @@ +Patient_ID,Method,Run,Total_Time_sec,Left_Cortical_Score,Left_Bone_Score,Left_C_B_Ratio,Right_Cortical_Score,Right_Bone_Score,Right_C_B_Ratio,Left_Loss,Right_Loss +00010,PSO,1,116.09450578689575,78.9866581956798,90.86721728081322,0.8692536269882887,87.45029425799268,95.67361221568316,0.9140482128013299,-8620586.65819568,-10406250.294257993 +00010,PSO,2,119.36032223701477,79.74112571092371,91.21396352225926,0.8742205977209203,80.71401346715791,93.19019184347606,0.8661213360599862,-7217341.125710924,-11510314.013467157 +00010,PSO,3,116.97214722633362,81.31569409206234,91.3374411018485,0.8902777777777778,88.62724577010292,96.33699633699634,0.91997103023719,-7985515.694092063,-9695027.245770102 diff --git a/results/comparison_results.csv b/results/comparison_results.csv new file mode 100644 index 0000000..78d3d0e --- /dev/null +++ b/results/comparison_results.csv @@ -0,0 +1 @@ +case,side,method,repeat,cortical_ratio,vertebral_ratio,cb_ratio,loss,time_sec diff --git a/results/results_DE.csv b/results/results_DE.csv new file mode 100644 index 0000000..909c5c7 --- /dev/null +++ b/results/results_DE.csv @@ -0,0 +1,10 @@ +case,side,method,repeat,seed,cortical_ratio,vertebral_ratio,cb_ratio,diameter,length,z,y,x,azimuth,altitude,cyl_points,overlap_cortical_voxels,overlap_vertebral_voxels,overlap_warning,cortical_outside_spine,loss,time_sec +1.3.6.1.4.1.9328.50.4.0001,L,DE,1,1,99.99,24.43,4.0922,5.0,50.0,7.5637,69.229,63.3989,123.7477,67.3613,7903,7902,1931,1,2000437,-15843187.3466,46.5 +1.3.6.1.4.1.9328.50.4.0001,L,DE,2,2,100.0,25.26,3.9587,5.0,50.0,4.3,67.2525,59.4038,116.5368,65.9693,7949,7949,2008,1,2000437,-15938000.0,46.95 +1.3.6.1.4.1.9328.50.4.0001,L,DE,3,3,100.0,19.18,5.2127,5.0,50.0,16.8451,72.3062,56.3851,123.8802,60.6996,7892,7892,1514,1,2000437,-15820000.0,48.35 +1.3.6.1.4.1.9328.50.4.0001,R,DE,1,11,99.97,13.0,7.6885,5.0,50.0,26.1806,48.2449,143.6017,90.1226,63.2828,7975,7973,1037,1,2000437,-15984374.9216,52.15 +1.3.6.1.4.1.9328.50.4.0001,R,DE,2,12,99.98,6.65,15.041,5.0,50.0,28.1923,62.7821,145.4942,90.0013,63.4555,8064,8062,536,1,2000437,-16162375.1984,54.69 +1.3.6.1.4.1.9328.50.4.0001,R,DE,3,13,99.9,66.77,1.4961,5.0,50.0,29.0259,33.4412,117.5387,90.0831,63.5098,8030,8022,5362,1,2000437,-16107500.3736,53.13 +1.3.6.1.4.1.9328.50.4.0002,L,DE,1,101,100.0,16.05,6.2311,5.0,50.0,74.8105,64.7153,67.9007,108.4366,83.9283,7901,7901,1268,1,2955801,-15842000.0,67.79 +1.3.6.1.4.1.9328.50.4.0002,L,DE,2,102,99.99,64.19,1.5578,5.0,50.0,47.0961,43.6089,70.2426,99.1139,78.8503,7877,7876,5056,1,2955801,-15821187.3048,65.29 +1.3.6.1.4.1.9328.50.4.0002,L,DE,3,103,100.0,22.02,4.5416,5.0,50.0,68.4024,44.1539,55.9956,97.1283,82.8774,7907,7907,1741,1,2955801,-15854000.0,66.76 diff --git a/results/results_GA.csv b/results/results_GA.csv new file mode 100644 index 0000000..78d3d0e --- /dev/null +++ b/results/results_GA.csv @@ -0,0 +1 @@ +case,side,method,repeat,cortical_ratio,vertebral_ratio,cb_ratio,loss,time_sec diff --git a/results/results_PSO.csv b/results/results_PSO.csv new file mode 100644 index 0000000..e0ff9a9 --- /dev/null +++ b/results/results_PSO.csv @@ -0,0 +1,61 @@ +case,side,method,repeat,seed,cortical_ratio,vertebral_ratio,cb_ratio,diameter,length,z,y,x,azimuth,altitude,cyl_points,overlap_cortical_voxels,overlap_vertebral_voxels,overlap_warning,cortical_outside_spine,loss,time_sec +1.3.6.1.4.1.9328.50.4.0001,L,PSO,1,1,99.99,8.69,11.5064,5.0,50.0,7.4176,36.793,49.1919,116.5719,65.9124,8044,8043,699,1,2000437,-16125187.5684,48.67 +1.3.6.1.4.1.9328.50.4.0001,L,PSO,2,2,100.0,9.16,10.9133,5.0,50.0,5.6307,53.3182,65.5052,130.5898,61.4896,7934,7934,727,1,2000437,-15908000.0,44.6 +1.3.6.1.4.1.9328.50.4.0001,L,PSO,3,3,100.0,21.26,4.7037,5.0,50.0,4.6025,75.6423,62.4257,126.8619,68.2206,7968,7968,1694,1,2000437,-15976000.0,46.46 +1.3.6.1.4.1.9328.50.4.0001,R,PSO,1,11,99.99,10.11,9.8857,5.0,50.0,20.5102,49.6687,140.5223,90.0478,62.6068,7959,7958,805,1,2000437,-15955187.4356,54.07 +1.3.6.1.4.1.9328.50.4.0001,R,PSO,2,12,99.95,24.83,4.0259,5.0,50.0,10.9234,62.9727,100.2,71.5645,64.5968,7939,7935,1971,1,2000437,-15906749.6158,50.59 +1.3.6.1.4.1.9328.50.4.0001,R,PSO,3,13,99.99,13.2,7.5755,5.0,50.0,10.0558,57.709,104.1599,70.0213,59.1209,7925,7924,1046,1,2000437,-15887187.3817,48.66 +1.3.6.1.4.1.9328.50.4.0002,L,PSO,1,101,100.0,56.5,1.7698,5.0,50.0,51.3303,40.1518,74.2,105.9376,82.1742,7918,7918,4474,1,2955801,-15906000.0,66.71 +1.3.6.1.4.1.9328.50.4.0002,L,PSO,2,102,100.0,14.09,7.0958,5.0,50.0,76.8487,50.3333,72.7445,113.1861,82.4419,7926,7926,1117,1,2955801,-15892000.0,67.51 +1.3.6.1.4.1.9328.50.4.0002,L,PSO,3,103,100.0,15.19,6.5849,5.0,50.0,75.4254,61.8117,65.6238,114.4484,80.5996,7915,7915,1202,1,2955801,-15870000.0,69.06 +1.3.6.1.4.1.9328.50.4.0002,R,PSO,1,111,99.91,15.41,6.4817,5.0,50.0,5.6976,82.3753,113.0593,56.3006,74.4506,7973,7966,1229,1,2955801,-15966312.2037,60.3 +1.3.6.1.4.1.9328.50.4.0002,R,PSO,2,112,100.0,4.12,24.2831,5.0,50.0,6.1936,37.2415,118.2451,81.1364,88.5141,7892,7892,325,1,2955801,-15854000.0,66.59 +1.3.6.1.4.1.9328.50.4.0002,R,PSO,3,113,100.0,8.74,11.4435,5.0,50.0,76.495,47.1313,116.2606,84.6961,79.5085,7896,7896,690,1,2955801,-15832000.0,68.45 +1.3.6.1.4.1.9328.50.4.0003,L,PSO,1,201,20.0,33.57,0.5956,5.0,45.0,56.3852,63.9562,17.4898,109.292,97.2601,4316,863,1449,0,0,-1543100.0,58.56 +1.3.6.1.4.1.9328.50.4.0003,L,PSO,2,202,19.99,37.37,0.535,4.5,40.0,58.1093,70.7978,13.3408,105.8432,96.161,3251,650,1215,0,0,-1154700.0,59.49 +1.3.6.1.4.1.9328.50.4.0003,L,PSO,3,203,29.99,46.94,0.6389,5.0,35.0,56.1536,67.1454,23.1849,122.2648,98.3765,3534,1060,1659,0,0,-1685600.0,54.63 +1.3.6.1.4.1.9328.50.4.0003,R,PSO,1,211,30.0,36.28,0.8268,4.5,45.0,52.26,61.821,174.191,67.415,101.7994,3947,1184,1432,0,0,-1861200.0,59.29 +1.3.6.1.4.1.9328.50.4.0003,R,PSO,2,212,30.0,37.93,0.7909,5.0,45.0,49.2119,73.0297,177.9171,69.6459,104.6915,4577,1373,1736,0,0,-2185600.0,154.05 +1.3.6.1.4.1.9328.50.4.0003,R,PSO,3,213,72.6,96.98,0.7486,4.0,45.0,47.5138,37.4122,118.8,68.5949,97.5772,4529,3288,4392,0,0,-3715798.8077,116.51 +1.3.6.1.4.1.9328.50.4.0004,L,PSO,1,301,10.0,34.7,0.2881,5.0,45.0,52.9404,69.3304,33.8691,122.0144,82.5102,5141,514,1784,0,0,-1170900.0,55.93 +1.3.6.1.4.1.9328.50.4.0004,L,PSO,2,302,72.22,82.03,0.8804,4.0,45.0,27.8612,42.2801,69.5263,106.5231,72.4112,4525,3268,3712,0,0,-4742620.9945,56.87 +1.3.6.1.4.1.9328.50.4.0004,L,PSO,3,303,20.0,23.0,0.8695,4.5,50.0,66.1798,62.605,71.8798,143.1133,72.1335,4431,886,1019,0,0,-1558500.0,56.84 +1.3.6.1.4.1.9328.50.4.0004,R,PSO,1,311,10.0,29.66,0.337,5.0,45.0,39.5416,39.5878,142.8218,52.787,81.3579,5222,522,1549,0,0,-1190000.0,51.0 +1.3.6.1.4.1.9328.50.4.0004,R,PSO,2,312,73.86,85.83,0.8605,5.0,40.0,27.5925,50.8509,128.0866,74.6366,72.0293,6281,4639,5391,0,0,-6538257.666,57.0 +1.3.6.1.4.1.9328.50.4.0004,R,PSO,3,313,78.66,92.39,0.8515,4.5,45.0,23.4699,44.1544,115.4623,66.7588,68.4046,5793,4557,5352,0,0,-7034863.9047,56.72 +1.3.6.1.4.1.9328.50.4.0005,L,PSO,1,401,19.99,41.52,0.4815,4.5,45.0,86.7698,41.0198,59.0922,66.1304,73.5371,2861,572,1188,0,0,-939300.0,61.61 +1.3.6.1.4.1.9328.50.4.0005,L,PSO,2,402,68.99,76.62,0.9004,4.0,35.0,67.478,35.5161,62.2094,79.0893,78.6353,3525,2432,2701,0,0,-3528592.9078,65.13 +1.3.6.1.4.1.9328.50.4.0005,L,PSO,3,403,67.91,78.01,0.8704,4.5,35.0,68.7886,35.1374,60.2172,76.6768,79.7908,4462,3030,3481,0,0,-4257306.7683,65.38 +1.3.6.1.4.1.9328.50.4.0005,R,PSO,1,411,19.99,29.69,0.6734,4.5,45.0,73.758,66.3197,195.9131,47.802,81.4935,2651,530,787,0,0,-942700.0,60.99 +1.3.6.1.4.1.9328.50.4.0005,R,PSO,2,412,30.0,44.41,0.6754,5.0,45.0,64.2517,57.7756,165.137,30.7208,76.554,5487,1646,2437,0,0,-2534400.0,67.84 +1.3.6.1.4.1.9328.50.4.0005,R,PSO,3,413,30.0,48.03,0.6245,5.0,45.0,62.0927,58.3166,163.1757,28.6541,79.3388,5517,1655,2650,0,0,-2550800.0,68.7 +1.3.6.1.4.1.9328.50.4.0006,L,PSO,1,501,19.99,36.76,0.5438,4.5,45.0,48.161,83.507,23.63,126.7736,80.398,2606,521,958,0,0,-926500.0,65.1 +1.3.6.1.4.1.9328.50.4.0006,L,PSO,2,502,19.99,21.15,0.9453,5.0,40.0,86.9474,89.9983,80.2344,143.267,75.3227,3891,778,823,0,0,-1355100.0,66.54 +1.3.6.1.4.1.9328.50.4.0006,L,PSO,3,503,87.62,91.7,0.9555,4.5,40.0,37.0386,49.0451,78.8935,100.2066,81.9978,5095,4464,4672,0,0,-7843815.3091,70.14 +1.3.6.1.4.1.9328.50.4.0006,R,PSO,1,511,84.55,93.9,0.9004,5.0,45.0,36.9954,47.6081,125.6671,67.2978,79.8972,7080,5986,6648,0,0,-9821348.0226,70.52 +1.3.6.1.4.1.9328.50.4.0006,R,PSO,2,512,85.56,97.02,0.8819,4.5,45.0,35.0025,48.3838,132.0857,71.1536,75.9332,5736,4908,5565,0,0,-8354164.8536,70.63 +1.3.6.1.4.1.9328.50.4.0006,R,PSO,3,513,83.52,95.45,0.8751,5.0,45.0,35.3546,48.4405,132.2396,71.1812,76.3927,7077,5911,6755,0,0,-9646724.0921,70.72 +1.3.6.1.4.1.9328.50.4.0007,L,PSO,1,601,70.97,95.87,0.7403,4.5,40.0,36.9433,55.6157,64.2063,115.1423,79.1659,5085,3609,4875,0,0,-3608173.4513,59.55 +1.3.6.1.4.1.9328.50.4.0007,L,PSO,2,602,70.98,96.63,0.7346,4.5,45.0,35.6859,52.2457,63.4741,112.467,78.4136,5720,4060,5527,0,0,-4897979.021,60.47 +1.3.6.1.4.1.9328.50.4.0007,L,PSO,3,603,20.0,21.66,0.9231,4.5,45.0,31.1183,86.2203,55.6987,146.9233,81.005,4321,864,936,0,0,-1392900.0,58.01 +1.3.6.1.4.1.9328.50.4.0007,R,PSO,1,611,72.02,81.92,0.8791,5.0,40.0,32.9074,50.676,127.462,82.3338,79.7994,6290,4530,5153,0,0,-5524019.0779,60.6 +1.3.6.1.4.1.9328.50.4.0007,R,PSO,2,612,79.91,85.73,0.9322,4.5,40.0,31.1365,41.2923,111.6,79.8617,72.6211,5107,4081,4378,0,0,-6314109.9276,58.61 +1.3.6.1.4.1.9328.50.4.0007,R,PSO,3,613,72.34,78.22,0.9248,5.0,35.0,33.3902,49.9694,124.8573,82.6153,80.3251,5509,3985,4309,0,0,-5853136.1772,59.22 +1.3.6.1.4.1.9328.50.4.0008,L,PSO,1,701,72.5,79.06,0.9171,4.5,40.0,96.018,63.7441,83.1531,103.4561,102.5322,5095,3694,4028,0,0,-4761702.4534,90.81 +1.3.6.1.4.1.9328.50.4.0008,L,PSO,2,702,10.0,34.26,0.2918,4.5,40.0,91.3609,88.0287,26.9624,126.4576,100.0414,3001,300,1028,0,0,-670700.0,82.41 +1.3.6.1.4.1.9328.50.4.0008,L,PSO,3,703,76.63,80.29,0.9545,4.5,40.0,92.4717,44.9791,85.0,101.9128,91.5303,5088,3899,4085,0,0,-6386431.2893,88.53 +1.3.6.1.4.1.9328.50.4.0008,R,PSO,1,711,19.96,41.5,0.481,4.5,45.0,100.5249,98.418,190.6688,59.3931,102.7534,3046,608,1264,0,0,-1082600.0,85.95 +1.3.6.1.4.1.9328.50.4.0008,R,PSO,2,712,10.0,23.19,0.4312,4.5,50.0,100.0967,77.2015,178.2021,59.5183,102.775,4671,467,1083,0,0,-1063800.0,90.26 +1.3.6.1.4.1.9328.50.4.0008,R,PSO,3,713,57.92,61.87,0.9362,4.5,40.0,100.8996,81.4359,134.1989,79.9791,102.6145,5093,2950,3151,0,0,-3357522.6389,93.01 +1.3.6.1.4.1.9328.50.4.0009,L,PSO,1,801,72.55,91.83,0.7901,5.0,45.0,32.8651,47.155,58.961,108.0286,72.0106,7035,5104,6460,0,0,-4581751.5281,57.55 +1.3.6.1.4.1.9328.50.4.0009,L,PSO,2,802,92.17,97.94,0.9412,4.5,40.0,24.7077,39.9498,69.0,117.5817,64.9292,5137,4735,5031,0,0,-8810574.4209,54.11 +1.3.6.1.4.1.9328.50.4.0009,L,PSO,3,803,85.67,94.4,0.9075,4.5,45.0,24.2003,37.5681,68.921,116.1667,66.0548,5737,4915,5416,0,0,-7685071.954,55.68 +1.3.6.1.4.1.9328.50.4.0009,R,PSO,1,811,95.17,97.9,0.9721,5.0,40.0,25.8881,43.4766,111.5397,78.9595,62.6884,6296,5992,6164,0,0,-11591971.5375,55.69 +1.3.6.1.4.1.9328.50.4.0009,R,PSO,2,812,89.54,94.23,0.9503,5.0,45.0,25.4625,39.9189,110.1685,76.9905,65.9668,7086,6345,6677,0,0,-11054742.7604,56.89 +1.3.6.1.4.1.9328.50.4.0009,R,PSO,3,813,90.49,93.69,0.9658,5.0,40.0,29.9746,43.9718,113.0538,79.0801,69.1544,6309,5709,5911,0,0,-10428489.7765,112.27 +1.3.6.1.4.1.9328.50.4.0010,L,PSO,1,901,85.08,94.76,0.8978,4.0,35.0,63.2001,36.7243,79.0574,102.9994,85.2987,3533,3006,3348,0,0,-5263483.4984,222.6 +1.3.6.1.4.1.9328.50.4.0010,L,PSO,2,902,10.0,20.85,0.4795,4.0,45.0,88.7729,54.613,28.9483,117.0101,84.3316,3281,328,684,0,0,-744100.0,220.5 +1.3.6.1.4.1.9328.50.4.0010,L,PSO,3,903,19.99,43.84,0.456,4.5,40.0,90.5944,80.2496,21.8845,130.5236,83.6175,2231,446,978,0,0,-792500.0,198.56 +1.3.6.1.4.1.9328.50.4.0010,R,PSO,1,911,85.57,95.5,0.896,4.5,45.0,62.0497,36.0756,127.2002,81.6334,82.4432,5738,4910,5480,0,0,-8410169.885,238.93 +1.3.6.1.4.1.9328.50.4.0010,R,PSO,2,912,19.99,36.71,0.5447,5.0,50.0,56.5369,70.4833,183.4086,53.8323,82.6558,3781,756,1388,0,0,-1350500.0,249.84 +1.3.6.1.4.1.9328.50.4.0010,R,PSO,3,913,19.99,34.85,0.5738,5.0,40.0,54.9824,66.8207,180.2205,49.5794,81.3042,3851,770,1342,0,0,-1370700.0,201.47 diff --git a/results/results_RandomSearch.csv b/results/results_RandomSearch.csv new file mode 100644 index 0000000..17871cb --- /dev/null +++ b/results/results_RandomSearch.csv @@ -0,0 +1,3 @@ +case,side,method,repeat,seed,cortical_ratio,vertebral_ratio,cb_ratio,diameter,length,z,y,x,azimuth,altitude,loss,time_sec +1.3.6.1.4.1.9328.50.4.0001,L,RandomSearch,1,1,99.97,26.92,3.7139,5.0,50,42.4153,32.9548,61.7995,109.8793,65.1445,-15748374.545,45.15 +1.3.6.1.4.1.9328.50.4.0001,L,RandomSearch,2,2,99.95,22.52,4.4387,5.0,50,17.1245,72.8784,61.9283,128.65,63.801,-15740749.0835,44.92 diff --git a/run_all_cases.py b/run_all_cases.py new file mode 100644 index 0000000..a652bec --- /dev/null +++ b/run_all_cases.py @@ -0,0 +1,146 @@ +import os +import traceback +import argparse +import torch +import SimpleITK as sitk +from datetime import datetime +# Core +from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch +from core.objective import set_global_context +from core.cylinder import create_coordinate_grid + +# Imaging +from imaging.preprocessing import process_single_image, process_dataset + +# Visualization +from visualization.res_cyl_to_CT import cyl_and_CT +from visualization.res_cyl_nifti import save_cyl + +# Config +from config.device import get_device +from config.constant import * + +def build_paths(base_dir: str, patient_id: str, level: str): + case_dir = os.path.join(base_dir, patient_id) + img1 = os.path.join(case_dir, f"{level}_cortical.nii.gz") + img2 = os.path.join(case_dir, f"{level}_binary2.nii.gz") + img3 = os.path.join(case_dir, f"{level}_roi2.nii.gz") + return img1, img2, img3 + +if __name__ == "__main__": + # === 1. 設定命令列引數 (GPU 拆分設定) === + parser = argparse.ArgumentParser(description="Multi-GPU CBT Batch Processing") + parser.add_argument('--gpu', type=int, required=True, help='指定這支程式要用的 GPU ID (0, 1, 2, 3)') + parser.add_argument('--total_gpus', type=int, default=4, help='總共開啟的 GPU 數量') + args = parser.parse_args() + + # === 2. 綁定 GPU === + if torch.cuda.is_available(): + device = torch.device(f'cuda:{args.gpu}') + torch.cuda.set_device(device) + else: + device = torch.device('cpu') + print("⚠️ 找不到 CUDA,將使用 CPU") + + print(f"\n🚀 啟動批次任務 | 分配至 GPU: [{args.gpu}] | 總共 {args.total_gpus} 個節點協同運算") + + # === 3. 基本參數 === + base_dir = "/home/cyrou/CBT/Seg/Resample/standardized/" + spacing = [0.5, 0.5, 0.5] + levels = ["L5"] + date = datetime.now().strftime("%Y%m%d") + + failed = [] + skipped = [] + succeeded = [] + + # 動態獲取所有病人資料夾 + all_patients = sorted([d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))]) + + # === 4. 餘數分工:只挑選屬於這張 GPU 的病人 === + my_patients = [p for i, p in enumerate(all_patients) if i % args.total_gpus == args.gpu] + + print(f"📂 總資料庫有 {len(all_patients)} 個病人。") + print(f"🎯 本 GPU (ID: {args.gpu}) 被分配到 {len(my_patients)} 個病人,準備開始處理...") + + for patient_id in my_patients: + for level in levels: + label_str = f"{patient_id}_{level}" + IMAGE1, IMAGE2, IMAGE3 = build_paths(base_dir, patient_id, level) + + missing = [p for p in [IMAGE1, IMAGE2, IMAGE3] if not os.path.exists(p)] + if missing: + print(f"[SKIP] ⏭️ {label_str} (缺檔)") + skipped.append((label_str, missing)) + continue + + print(f"\n=========================================") + print(f"🔥 [GPU {args.gpu}] === Running {label_str} ===") + print(f"=========================================") + + try: + # --- CBT = True --- + print(f"👉 [GPU {args.gpu}] 執行 CBT 軌跡最佳化...") + binary_image = sitk.ReadImage(IMAGE2) + binary_array = sitk.GetArrayFromImage(binary_image) + image_shape = binary_array.shape + + grid = create_coordinate_grid(image_shape, device) + best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_pso_torch( + label_str, + image1_path=IMAGE1, + image2_path=IMAGE2, + image3_path=IMAGE3, + folder=date, + swarm_size=70, + max_iter=100, + spacing=spacing, + CBT=True, + device=device, + optimize_size=True, + grid=grid # 記得把 device 傳進去 + ) + + cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base='Output', CBT=True) + cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6], IMAGE3, cylinder_L, cylinder_R, base_folder='Output', CBT=True) + + # --- CBT = False --- + # print(f"👉 [GPU {args.gpu}] 執行傳統軌跡 (TT) 最佳化...") + # best_pos_l_tt, best_loss_l_tt, best_pos_r_tt, best_loss_r_tt, total_time_tt = run_pso_torch( + # label_str, + # image1_path=IMAGE2, + # image2_path=IMAGE2, + # image3_path=IMAGE3, + # folder=date, + # swarm_size=70, + # max_iter=100, + # spacing=spacing, + # CBT=False, + # device=device, + # optimize_size=True, + # grid=grid # 記得把 device 傳進去 + # ) + + # cylinder_L_tt, cylinder_R_tt = save_cyl(best_pos_l_tt, best_pos_r_tt, spacing, IMAGE3, output_base='Output', CBT=False) + # cyl_and_CT(best_pos_l_tt[5], best_pos_r_tt[5], best_pos_l_tt[6], best_pos_r_tt[6], IMAGE3, cylinder_L_tt, cylinder_R_tt, base_folder='Output', CBT=False) + + # succeeded.append((label_str, best_loss_l, best_loss_r, total_time + total_time_tt)) + # print(f"[OK] ✅ [GPU {args.gpu}] {label_str} 完成!總耗時={(total_time + total_time_tt):.2f}s") + + except Exception as e: + print(f"[FAIL] ❌ [GPU {args.gpu}] {label_str} 發生錯誤: {e}") + traceback.print_exc() + failed.append((label_str, str(e))) + + # --- GPU 專屬的 Summary --- + print("\n" + "="*40) + print(f"🏆 ==== GPU {args.gpu} 處理總結 ====") + print("="*40) + print(f"✅ Succeeded (成功): {len(succeeded)} 節段") + print(f"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段") + print(f"❌ Failed (執行錯誤): {len(failed)} 節段") + + if failed: + print(f"\n⚠️ GPU {args.gpu} 失敗清單:") + for fail_label, err_msg in failed: + print(f" - {fail_label}: {err_msg}") \ No newline at end of file diff --git a/run_benchmark.py b/run_benchmark.py new file mode 100644 index 0000000..7a0eb25 --- /dev/null +++ b/run_benchmark.py @@ -0,0 +1,157 @@ +import os +import time +import pandas as pd +import torch +import argparse +import SimpleITK as sitk +import traceback +import core.objective +from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch +from core.cylinder import generate_cylinder_n_torch, create_coordinate_grid + +def main(): + parser = argparse.ArgumentParser(description="CBT Optimization Benchmark") + parser.add_argument('--method', type=str, required=True, choices=['PSO', 'DE', 'NM']) + parser.add_argument('--gpu', type=int, default=0) + args = parser.parse_args() + + if torch.cuda.is_available(): + device = torch.device(f'cuda:{args.gpu}') + torch.cuda.set_device(device) + else: + device = torch.device('cpu') + + print(f"\n🚀 啟動任務: 演算法 [{args.method}] | 運行於 [{device}]") + + base_dir = "/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0{}" + patients = [121, 141, 151, 161, 171, 181, 191, 201, 211, 221] + + runs_per_method = 3 + + CBT = True + optimize_size = True + spacing = [0.5, 0.5, 0.5] + + if args.method == 'PSO': + method_func, swarm, iters = run_pso_torch, 70, 100 + elif args.method == 'DE': + method_func, swarm, iters = run_de_torch, 70, 100 + elif args.method == 'NM': + method_func, swarm, iters = run_nm_torch, 70, 7000 + + results_data = [] + + for p_id in patients: + folder_num = str(p_id) + current_dir = base_dir.format(folder_num) + + cortical_path = os.path.join(current_dir, "L5_cortical.nii.gz") + binary_path = os.path.join(current_dir, "L5_binary2.nii.gz") + roi_path = os.path.join(current_dir, "L5_roi2.nii.gz") + + if not all(os.path.exists(p) for p in [cortical_path, binary_path, roi_path]): + print(f"⚠️ 找不到病人 000{p_id} 的影像檔案,跳過此病人。") + continue + + print(f"\n=========================================") + print(f"📍 開始處理病人 ID: 000{p_id} (L5) - {args.method}") + print(f"=========================================") + # 在每個 case 進入 run_idx 之前先做 + img_b = sitk.GetArrayFromImage(sitk.ReadImage(binary_path)) + image_shape = img_b.shape + grid = create_coordinate_grid(image_shape, device) + + for run_idx in range(1, runs_per_method + 1): + label_str = f"Patient_000{p_id}_{args.method}_Run{run_idx}" + print(f"\n---> 執行 {args.method} (第 {run_idx}/{runs_per_method} 次)") + + try: + # 1. 呼叫演算法 (只會拿到 5 個回傳值) + best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = method_func( + label_str=label_str, + image1_path=cortical_path, + image2_path=binary_path, + image3_path=roi_path, + folder="Output", + swarm_size=swarm, + max_iter=iters, + spacing=spacing, + CBT=CBT, + device=device, + optimize_size=optimize_size, + grid=grid + ) + + # ========== 2. 在外部獨立計算分數 (絕對不會有 NameError) ========== + # 讀取影像為 Numpy 並轉為 GPU Tensor + img_c = sitk.GetArrayFromImage(sitk.ReadImage(cortical_path)) + img_b = sitk.GetArrayFromImage(sitk.ReadImage(binary_path)) + image_shape = img_b.shape + + cortical_eval = torch.from_numpy(img_c).to(device=device, dtype=torch.uint8) + spine_eval = torch.from_numpy(img_b).to(device=device, dtype=torch.uint8) + + # 解析 Diameter 與 Length + d_l, l_l = (best_pos_l[5], best_pos_l[6]) if optimize_size else (4.5, 45) + d_r, l_r = (best_pos_r[5], best_pos_r[6]) if optimize_size else (4.5, 45) + + # 重新生成左側與右側的 Cylinder Mask + cyl_l = generate_cylinder_n_torch( + d_l, l_l, best_pos_l[0], best_pos_l[1], best_pos_l[2], best_pos_l[3], best_pos_l[4], + image_shape, spacing, device, grid=grid + ) + cyl_r = generate_cylinder_n_torch( + d_r, l_r, best_pos_r[0], best_pos_r[1], best_pos_r[2], best_pos_r[3], best_pos_r[4], + image_shape, spacing, device, grid=grid + ) + + # 計算分數 + cyl_points_l = torch.sum(cyl_l).item() + cyl_points_r = torch.sum(cyl_r).item() + + overlap_c_l = ((cortical_eval == 1) & (cyl_l == 1)).sum().item() + overlap_c_r = ((cortical_eval == 1) & (cyl_r == 1)).sum().item() + overlap_b_l = ((spine_eval == 1) & (cyl_l == 1)).sum().item() + overlap_b_r = ((spine_eval == 1) & (cyl_r == 1)).sum().item() + + score_c_l = (overlap_c_l / cyl_points_l * 100) if cyl_points_l > 0 else 0 + score_c_r = (overlap_c_r / cyl_points_r * 100) if cyl_points_r > 0 else 0 + score_b_l = (overlap_b_l / cyl_points_l * 100) if cyl_points_l > 0 else 0 + score_b_r = (overlap_b_r / cyl_points_r * 100) if cyl_points_r > 0 else 0 + + cb_ratio_l = (score_c_l / score_b_l) if score_b_l > 0 else 0 + cb_ratio_r = (score_c_r / score_b_r) if score_b_r > 0 else 0 + # =============================================================== + + # 3. 儲存結果 + run_result = { + "Patient_ID": f"000{p_id}", + "Method": args.method, + "Run": run_idx, + "Total_Time_sec": total_time, + "Left_Cortical_Score": score_c_l, + "Left_Bone_Score": score_b_l, + "Left_C_B_Ratio": cb_ratio_l, + "Right_Cortical_Score": score_c_r, + "Right_Bone_Score": score_b_r, + "Right_C_B_Ratio": cb_ratio_r, + "Left_Loss": best_loss_l, + "Right_Loss": best_loss_r + } + results_data.append(run_result) + + print(f"✅ 完成 | 耗時: {total_time:.2f}s | " + f"左C/B: {cb_ratio_l:.3f} | 右C/B: {cb_ratio_r:.3f}") + + except Exception as e: + print(f"❌ {args.method} Run {run_idx} 發生錯誤: {e}") + traceback.print_exc() # 印出詳細報錯足跡 + + if results_data: + df = pd.DataFrame(results_data) + csv_filename = f"optimization_benchmark_results_{args.method}.csv" + df.to_csv(csv_filename, index=False) + print(f"\n🎉 {args.method} 測試完成!結果已儲存至: {csv_filename}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/run_de.py b/run_de.py new file mode 100644 index 0000000..462a1c4 --- /dev/null +++ b/run_de.py @@ -0,0 +1,109 @@ +import os +import traceback +import argparse +import torch +import SimpleITK as sitk +from datetime import datetime +# Core +from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch +from core.objective import set_global_context +from core.cylinder import create_coordinate_grid + +# Imaging +from imaging.preprocessing import process_single_image, process_dataset + +# Visualization +from visualization.res_cyl_to_CT import cyl_and_CT +from visualization.res_cyl_nifti import save_cyl + +# Config +from config.device import get_device +from config.constant import * + +if __name__ == "__main__": + # === 1. 設定命令列引數 (GPU 拆分設定) === + parser = argparse.ArgumentParser(description="DE CBT Processing") + parser.add_argument('--gpu', type=int, default=0) + args = parser.parse_args() + + if torch.cuda.is_available(): + device = torch.device(f'cuda:{args.gpu}') + torch.cuda.set_device(device) + else: + device = torch.device('cpu') + + base_dir = "/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0{}" + patients = [121, 141, 151, 161, 171, 181, 191, 201, 211, 221] + + CBT = True + optimize_size = True + spacing = [0.5, 0.5, 0.5] + levels = ["L5"] + date = datetime.now().strftime("%Y%m%d") + + failed = [] + skipped = [] + succeeded = [] + + for p_id in patients: + folder_num = str(p_id) + current_dir = base_dir.format(folder_num) + + IMAGE1 = os.path.join(current_dir, "L5_cortical.nii.gz") + IMAGE2 = os.path.join(current_dir, "L5_binary2.nii.gz") + IMAGE3 = os.path.join(current_dir, "L5_roi2.nii.gz") + + missing = [p for p in [IMAGE1, IMAGE2, IMAGE3] if not os.path.exists(p)] + label_str = levels[0] + if missing: + print(f"[SKIP] ⏭️ {label_str} (缺檔)") + skipped.append((label_str, missing)) + continue + + print(f"\n=========================================") + print(f"🔥 [GPU {args.gpu}] === Running {label_str} ===") + print(f"=========================================") + + try: + # --- CBT = True --- + print(f"👉 [GPU {args.gpu}] 執行 CBT 軌跡最佳化...") + binary_image = sitk.ReadImage(IMAGE2) + binary_array = sitk.GetArrayFromImage(binary_image) + image_shape = binary_array.shape + + grid = create_coordinate_grid(image_shape, device) + best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_nm_torch( + 'NM', + image1_path=IMAGE1, + image2_path=IMAGE2, + image3_path=IMAGE3, + folder=date, + swarm_size=70, + max_iter=100, + spacing=spacing, + CBT=True, + device=device, + optimize_size=True, + grid=grid # 記得把 device 傳進去 + ) + + cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base='Output', CBT=True) + cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6], IMAGE3, cylinder_L, cylinder_R, base_folder='Output', CBT=True) + + except Exception as e: + print(f"[FAIL] ❌ [GPU {args.gpu}] {label_str} 發生錯誤: {e}") + traceback.print_exc() + failed.append((label_str, str(e))) + + # --- GPU 專屬的 Summary --- + print("\n" + "="*40) + print(f"🏆 ==== GPU {args.gpu} 處理總結 ====") + print("="*40) + print(f"✅ Succeeded (成功): {len(succeeded)} 節段") + print(f"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段") + print(f"❌ Failed (執行錯誤): {len(failed)} 節段") + + if failed: + print(f"\n⚠️ GPU {args.gpu} 失敗清單:") + for fail_label, err_msg in failed: + print(f" - {fail_label}: {err_msg}") \ No newline at end of file diff --git a/run_pso.py b/run_pso.py new file mode 100644 index 0000000..bb5d51c --- /dev/null +++ b/run_pso.py @@ -0,0 +1,109 @@ +import os +import traceback +import argparse +import torch +import SimpleITK as sitk +from datetime import datetime +# Core +from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch +from core.objective import set_global_context +from core.cylinder import create_coordinate_grid + +# Imaging +from imaging.preprocessing import process_single_image, process_dataset + +# Visualization +from visualization.res_cyl_to_CT import cyl_and_CT +from visualization.res_cyl_nifti import save_cyl + +# Config +from config.device import get_device +from config.constant import * + +if __name__ == "__main__": + # === 1. 設定命令列引數 (GPU 拆分設定) === + parser = argparse.ArgumentParser(description="PSO CBT Processing") + parser.add_argument('--gpu', type=int, default=0) + args = parser.parse_args() + + if torch.cuda.is_available(): + device = torch.device(f'cuda:{args.gpu}') + torch.cuda.set_device(device) + else: + device = torch.device('cpu') + + base_dir = "/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0{}" + patients = [121, 141, 151, 161, 171, 181, 191, 201, 211, 221] + + CBT = True + optimize_size = True + spacing = [0.5, 0.5, 0.5] + levels = ["L5"] + date = datetime.now().strftime("%Y%m%d") + + failed = [] + skipped = [] + succeeded = [] + + for p_id in patients: + folder_num = str(p_id) + current_dir = base_dir.format(folder_num) + + IMAGE1 = os.path.join(current_dir, "L5_cortical.nii.gz") + IMAGE2 = os.path.join(current_dir, "L5_binary2.nii.gz") + IMAGE3 = os.path.join(current_dir, "L5_roi2.nii.gz") + + missing = [p for p in [IMAGE1, IMAGE2, IMAGE3] if not os.path.exists(p)] + if missing: + print(f"[SKIP] ⏭️ {label_str} (缺檔)") + skipped.append((label_str, missing)) + continue + label_str = levels + + print(f"\n=========================================") + print(f"🔥 [GPU {args.gpu}] === Running {label_str} ===") + print(f"=========================================") + + try: + # --- CBT = True --- + print(f"👉 [GPU {args.gpu}] 執行 CBT 軌跡最佳化...") + binary_image = sitk.ReadImage(IMAGE2) + binary_array = sitk.GetArrayFromImage(binary_image) + image_shape = binary_array.shape + + grid = create_coordinate_grid(image_shape, device) + best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_pso_torch( + label_str, + image1_path=IMAGE1, + image2_path=IMAGE2, + image3_path=IMAGE3, + folder=date, + swarm_size=70, + max_iter=100, + spacing=spacing, + CBT=True, + device=device, + optimize_size=True, + grid=grid # 記得把 device 傳進去 + ) + + cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base='Output', CBT=True) + cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6], IMAGE3, cylinder_L, cylinder_R, base_folder='Output', CBT=True) + + except Exception as e: + print(f"[FAIL] ❌ [GPU {args.gpu}] {label_str} 發生錯誤: {e}") + traceback.print_exc() + failed.append((label_str, str(e))) + + # --- GPU 專屬的 Summary --- + print("\n" + "="*40) + print(f"🏆 ==== GPU {args.gpu} 處理總結 ====") + print("="*40) + print(f"✅ Succeeded (成功): {len(succeeded)} 節段") + print(f"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段") + print(f"❌ Failed (執行錯誤): {len(failed)} 節段") + + if failed: + print(f"\n⚠️ GPU {args.gpu} 失敗清單:") + for fail_label, err_msg in failed: + print(f" - {fail_label}: {err_msg}") \ No newline at end of file diff --git a/structure.txt b/structure.txt new file mode 100644 index 0000000..3c8bf26 --- /dev/null +++ b/structure.txt @@ -0,0 +1,49 @@ +CBT_project/ +│ +├── notebooks/ +│ └── experiment.ipynb +│ +├── core/ +│ ├── cylinder.py +│ │ ├── create_coordinate_grid +│ │ ├── snap_to_discrete_values +│ │ ├── generate_cylinder_n_torch +│ │ ├── generate_cylinder_o_torch +│ │ └── generate_cylinder_numpy +│ │ +│ ├── intersection.py +│ │ ├── bresenham3d +│ │ └── center_line_intersections_torch +│ │ +│ ├── scoring.py +│ │ ├── cl_score_torch +│ │ ├── get_overlap_ratio +│ │ └── compute_overlap_ratio_from_cylinder_mask +│ │ +│ ├── objective.py +│ │ ├── cylinder_circle_line_intersection_loss_deductions_torch +│ │ └── objective_function +│ │ +│ └── optimizer.py +│ └── run_pso_torch +│ +├── imaging/ +│ ├── nifti_io.py +│ ├── preprocessing.py + +│ └── orientation.py +│ +├── visualization/ +│ ├── plot_cylinder.py +│ └── plot_overlay.py +│ +├── utils/ +│ └── helpers.py +│ +├── config/ +│ ├── constants.py +│ └── device.py +│ +├── run_optimization.py +│ +└── __init__.py \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/helpers.py b/utils/helpers.py new file mode 100644 index 0000000..bd7383d --- /dev/null +++ b/utils/helpers.py @@ -0,0 +1,52 @@ +import os +import numpy as np +import matplotlib.pyplot as plt + +def get_unique_filepath(path: str) -> str: + """ + 如果檔案已存在,自動加 _1, _2... 避免覆蓋 + """ + current_path = path + count = 1 + + base_name, file_ext = os.path.splitext(path) + + while os.path.exists(current_path): + current_path = f"{base_name}_{count}{file_ext}" + count += 1 + + return current_path + +def save_with_unique_name(folder, label_str, way, diameter_l, length_l, diameter_r, length_r, swarm_size, max_iter): + + base_name = f'{label_str}_{way}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_{swarm_size}_{max_iter}.png' + file_path = os.path.join(folder, base_name) + + count = 1 + while os.path.exists(file_path): + file_name, file_ext = os.path.splitext(base_name) + file_path = os.path.join(folder, f"{file_name}_{count}{file_ext}") + count += 1 + + return file_path + +def pad_rgb_to_shape(rgb, target_hw, pad_value=0.0): + """Pad RGB image (H,W,3) to target (H,W), centered.""" + h, w, c = rgb.shape + H, W = target_hw + assert c == 3 + + if h > H or w > W: + raise ValueError(f"target {target_hw} smaller than rgb {(h,w)}") + + pad_h1 = (H - h) // 2 + pad_h2 = H - h - pad_h1 + pad_w1 = (W - w) // 2 + pad_w2 = W - w - pad_w1 + + return np.pad( + rgb, + ((pad_h1, pad_h2), (pad_w1, pad_w2), (0, 0)), + mode="constant", + constant_values=pad_value + ) diff --git a/visualization/__init__.py b/visualization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/visualization/res_cyl_nifti.py b/visualization/res_cyl_nifti.py new file mode 100644 index 0000000..428f00e --- /dev/null +++ b/visualization/res_cyl_nifti.py @@ -0,0 +1,83 @@ +import SimpleITK as sitk +import os +from datetime import datetime +from core.cylinder import generate_cylinder_numpy + +def save_cyl(best_position_l, best_position_r, spacing, IMAGE3, output_base, CBT=True): + + image3 = sitk.ReadImage(IMAGE3) #Reading image2 or image3 can get the same shape and origin actually + image3_array = sitk.GetArrayFromImage(image3) + image3_shape = image3_array.shape + origin = image3.GetOrigin() + diameter_l = best_position_l[5] + diameter_r = best_position_r[5] + length_l = best_position_l[6] + length_r = best_position_r[6] + + cyl_l = generate_cylinder_numpy(diameter_l, + length_l, + best_position_l[0], + best_position_l[1], + best_position_l[2], + best_position_l[3], + best_position_l[4], + image3_shape, + spacing) + + cyl_r = generate_cylinder_numpy(diameter_r, + length_r, + best_position_r[0], + best_position_r[1], + best_position_r[2], + best_position_r[3], + best_position_r[4], + image3_shape, + spacing) + + cyl_L = sitk.GetImageFromArray(cyl_l) + cyl_R = sitk.GetImageFromArray(cyl_r) + + cyl_L.SetOrigin(origin) + cyl_R.SetOrigin(origin) + + image_path = IMAGE3 + date_str = datetime.now().strftime("%Y%m%d") + patient_id = os.path.basename(os.path.dirname(image_path)) + file_name = os.path.basename(image_path) + level = file_name.split('_')[0] + output_folder = os.path.join(output_base, date_str, patient_id) + + os.makedirs(output_folder, exist_ok=True) + + def get_unique_filename(path): + + if not os.path.exists(path): + return path + + # 特別處理 .nii.gz 檔案 + if path.endswith(".nii.gz"): + base = path[:-7] # 去除 .nii.gz + ext = ".nii.gz" + else: + base, ext = os.path.splitext(path) + + i = 1 + new_path = f"{base}_{i}{ext}" + while os.path.exists(new_path): + i += 1 + new_path = f"{base}_{i}{ext}" + return new_path + + if CBT: + output_path_l = get_unique_filename(f'{output_folder}/{level}_{diameter_l}_{length_l}_CBT_L.nii.gz') + output_path_r = get_unique_filename(f'{output_folder}/{level}_{diameter_r}_{length_r}_CBT_R.nii.gz') + sitk.WriteImage(cyl_L, output_path_l) + sitk.WriteImage(cyl_R, output_path_r) + + else: + output_path_l = get_unique_filename(f'{output_folder}/{level}_{diameter_l}_{length_l}_TPS_L.nii.gz') + output_path_r = get_unique_filename(f'{output_folder}/{level}_{diameter_r}_{length_r}_TPS_R.nii.gz') + sitk.WriteImage(cyl_L, output_path_l) + sitk.WriteImage(cyl_R, output_path_r) + + return output_path_l, output_path_r \ No newline at end of file diff --git a/visualization/res_cyl_to_CT.py b/visualization/res_cyl_to_CT.py new file mode 100644 index 0000000..6ea03db --- /dev/null +++ b/visualization/res_cyl_to_CT.py @@ -0,0 +1,127 @@ +from utils.helpers import get_unique_filepath +import nibabel as nib +from datetime import datetime +import os +import numpy as np +import matplotlib.pyplot as plt +from utils.helpers import pad_rgb_to_shape + +def cyl_and_CT(diameter_l, diameter_r, length_l, length_r, roi_image_path, cylinder_1, cylinder_2, base_folder='Output', CBT=True): + + bone_nifti = nib.load(roi_image_path) + bone_array = bone_nifti.get_fdata() + + cylinder_nifti = nib.load(cylinder_1) + cylinder_nifti2 = nib.load(cylinder_2) + cylinder_array = cylinder_nifti.get_fdata() + cylinder_array2 = cylinder_nifti2.get_fdata() + + assert bone_array.shape == cylinder_array.shape, "影像 shape 不匹配,請先進行重採樣!" + + date = datetime.now().strftime("%Y%m%d") + patient_id = os.path.basename(os.path.dirname(roi_image_path)) + file_name = os.path.basename(roi_image_path) + level = file_name.split('_')[0] + + output_folder = os.path.join(base_folder, date, patient_id) + + if CBT: + output_file = f'{output_folder}/{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_CBT_whole.png' + output_axial = f'{output_folder}/{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_CBT_axial.png' + output_coronal = f'{output_folder}/{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_CBT_coronal.png' + output_saggital = f'{output_folder}/{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_CBT_saggital.png' + else: + output_file = f'{output_folder}/{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_TPS_whole.png' + output_axial = f'{output_folder}/{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_TPS_axial.png' + output_coronal = f'{output_folder}/{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_TPS_coronal.png' + output_saggital = f'{output_folder}/{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_TPS_saggital.png' + + mip_bone_z = np.max(bone_array, axis=0) + mip_cylinder_z = np.max(cylinder_array, axis=0) + mip_cylinder2_z = np.max(cylinder_array2, axis=0) + + mip_bone_x = np.max(bone_array, axis=1) + mip_cylinder_x = np.max(cylinder_array, axis=1) + mip_cylinder2_x = np.max(cylinder_array2, axis=1) + + mip_bone_y = np.max(bone_array, axis=2) + mip_cylinder_y = np.max(cylinder_array, axis=2) + mip_cylinder2_y = np.max(cylinder_array2, axis=2) + + bone_norm_z = mip_bone_z / np.max(mip_bone_z) + bone_norm_x = mip_bone_x / np.max(mip_bone_x) + bone_norm_y = mip_bone_y / np.max(mip_bone_y) + + cylinder_mask_z = mip_cylinder_z > 0 + cylinder_mask_x = mip_cylinder_x > 0 + cylinder_mask_y = mip_cylinder_y > 0 + cylinder_mask2_z = mip_cylinder2_z > 0 + cylinder_mask2_x = mip_cylinder2_x > 0 + cylinder_mask2_y = mip_cylinder2_y > 0 + + def create_rgb_image(bone_norm, cylinder_mask, cylinder_mask2): + bone_norm = (bone_norm - np.min(bone_norm)) / (np.max(bone_norm) - np.min(bone_norm)) + bone_norm = np.clip(bone_norm, 0, 1) + + rgb_image = np.zeros((bone_norm.shape[0], bone_norm.shape[1], 3)) + rgb_image[..., 0] = bone_norm + rgb_image[..., 1] = bone_norm + rgb_image[..., 2] = bone_norm + rgb_image[cylinder_mask, :] = 1 + rgb_image[cylinder_mask2, :] = 1 + + return rgb_image + + def crop_center(img, cropx, cropy): + y, x = img.shape[0], img.shape[1] + startx = x//2 - cropx//2 + starty = y//2 - cropy//2 + return img[starty:starty+cropy, startx:startx+cropx, :] + + rgb_image_z = np.rot90(create_rgb_image(bone_norm_z, cylinder_mask_z, cylinder_mask2_z)) + rgb_image_x = np.rot90(create_rgb_image(bone_norm_x, cylinder_mask_x, cylinder_mask2_x)) + rgb_image_y = np.rot90(create_rgb_image(bone_norm_y, cylinder_mask_y, cylinder_mask2_y)) + + H = max(rgb_image_y.shape[0], rgb_image_x.shape[0], rgb_image_z.shape[0]) + W = max(rgb_image_y.shape[1], rgb_image_x.shape[1], rgb_image_z.shape[1]) + + rgb_image_y = pad_rgb_to_shape(rgb_image_y, (H, W), pad_value=0.0) + rgb_image_x = pad_rgb_to_shape(rgb_image_x, (H, W), pad_value=0.0) + rgb_image_z = pad_rgb_to_shape(rgb_image_z, (H, W), pad_value=0.0) + + fig, axs = plt.subplots(1, 3, figsize=(15, 5)) + axs[0].imshow(rgb_image_y) + axs[0].set_title("Axial", pad=10) + axs[0].axis("off") + axs[0].set_aspect("equal") + + axs[1].imshow(rgb_image_x) + axs[1].set_title("Coronal", pad=10) + axs[1].axis("off") + axs[1].set_aspect("equal") + + axs[2].imshow(rgb_image_z) + axs[2].set_title("Saggital", pad=10) + axs[2].axis("off") + axs[2].set_aspect("equal") + + plt.subplots_adjust(left=0.05, right=0.75, top=0.9, bottom=0.1, wspace=0.01) + plt.savefig(get_unique_filepath(output_file), bbox_inches='tight') + + fig, ax = plt.subplots(1, 1, figsize=(10, 10)) + ax.imshow(rgb_image_z) + ax.set_title("Saggital", pad=10) + ax.axis("off") + plt.savefig(get_unique_filepath(output_saggital), bbox_inches='tight') + + fig, ax = plt.subplots(1, 1, figsize=(10, 10)) + ax.imshow(rgb_image_x) + ax.set_title("Coronal", pad=10) + ax.axis("off") + plt.savefig(get_unique_filepath(output_coronal), bbox_inches='tight') + + fig, ax = plt.subplots(1, 1, figsize=(10, 10)) + ax.imshow(rgb_image_y) + ax.set_title("Axial", pad=10) + ax.axis("off") + plt.savefig(get_unique_filepath(output_axial), bbox_inches='tight') \ No newline at end of file diff --git a/visualization/res_plot_3d.py b/visualization/res_plot_3d.py new file mode 100644 index 0000000..73e8e8c --- /dev/null +++ b/visualization/res_plot_3d.py @@ -0,0 +1,366 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt +import os +from datetime import datetime +import csv + +from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values +from core.intersection import center_line_intersections_torch +from core.scoring import cl_score_torch, compute_overlap_ratio_from_cylinder_mask +from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour +from utils.helpers import save_with_unique_name + +def res_plt_2_torch( + spine_tensor: torch.Tensor, + cortical_tensor: torch.Tensor, + image_shape: tuple[int, int, int], + image2_path: str, + base_folder: str, + label_str: str, + diameter_l: float, + length_l: float, + diameter_r: float, + length_r: float, + best_position_l: list[float], + best_position_r: list[float], + swarm_size: int, + max_iter: int, + total_time: float, + spacing: list[float], + CBT: bool, + device: torch.device, + grid=None +) -> None: + """ + Same plotting function as before, but it uses torch-based generation + and then moves data to CPU for matplotlib 3D scatter. + """ + cyl_l = generate_cylinder_n_torch( + diameter_l, + length_l, + best_position_l[0], + best_position_l[1], + best_position_l[2], + best_position_l[3], + best_position_l[4], + image_shape, + spacing, + device, + grid + ) + + cyl_lo = generate_cylinder_o_torch( + diameter_l, + length_l, + best_position_l[0], + best_position_l[1], + best_position_l[2], + best_position_l[3], + best_position_l[4], + image_shape, + spacing, + device, + grid + ) + cyl_r = generate_cylinder_n_torch( + diameter_r, + length_r, + best_position_r[0], + best_position_r[1], + best_position_r[2], + best_position_r[3], + best_position_r[4], + image_shape, + spacing, + device, + grid + ) + cyl_ro = generate_cylinder_o_torch( + diameter_r, + length_r, + best_position_r[0], + best_position_r[1], + best_position_r[2], + best_position_r[3], + best_position_r[4], + image_shape, + spacing, + device, + grid + ) + + intersections_l, line_mask_l = center_line_intersections_torch( + best_position_l[0], + best_position_l[1], + best_position_l[2], + best_position_l[3], + best_position_l[4], + int(length_l), + spine_tensor, + spacing, + device + ) + loss_l = cl_score_torch(cortical_tensor, spine_tensor, cyl_l, cyl_lo, intersections_l) + + intersections_r, line_mask_r = center_line_intersections_torch( + best_position_r[0], + best_position_r[1], + best_position_r[2], + best_position_r[3], + best_position_r[4], + int(length_r), + spine_tensor, + spacing, + device + ) + loss_r = cl_score_torch(cortical_tensor, spine_tensor, cyl_r, cyl_ro, intersections_r) + + azi = azimuth_rotation(image2_path) + res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False) + alt = res['superior']['tilt_angle_deg'] + + # Move data to CPU for plotting + line_mask_l_cpu = line_mask_l.cpu().numpy() + line_mask_r_cpu = line_mask_r.cpu().numpy() + cyl_l_cpu = cyl_l.cpu().numpy() + cyl_lo_cpu = cyl_lo.cpu().numpy() + cyl_r_cpu = cyl_r.cpu().numpy() + cyl_ro_cpu = cyl_ro.cpu().numpy() + spine_cpu = spine_tensor.cpu().numpy() + + z_lin1, y_lin1, x_lin1 = np.where(line_mask_l_cpu == 1) + z_lin2, y_lin2, x_lin2 = np.where(line_mask_r_cpu == 1) + + z_cyl_l1, y_cyl_l1, x_cyl_l1 = np.where(cyl_l_cpu == 1) + z_cyl_l2, y_cyl_l2, x_cyl_l2 = np.where(cyl_lo_cpu == 1) + z_cyl_r1, y_cyl_r1, x_cyl_r1 = np.where(cyl_r_cpu == 1) + z_cyl_r2, y_cyl_r2, x_cyl_r2 = np.where(cyl_ro_cpu == 1) + + z_img, y_img, x_img = np.where(spine_cpu == 1) + + fig = plt.figure(figsize=(12, 12)) + + ax1 = fig.add_subplot(221, projection='3d') + ax1.scatter(x_lin1, y_lin1, z_lin1, c='r', marker='o', s=1) + ax1.scatter(x_lin2, y_lin2, z_lin2, c='r', marker='o', s=1) + ax1.scatter(x_cyl_l1, y_cyl_l1, z_cyl_l1, c='darkcyan', marker='o', label='Cylinder(L)') + ax1.scatter(x_cyl_l2, y_cyl_l2, z_cyl_l2, c='pink', marker='o') + ax1.scatter(x_cyl_r1, y_cyl_r1, z_cyl_r1, c='blue', marker='o', label='Cylinder(R)') + ax1.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o') + ax1.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine') + ax1.set_xlabel('X-axis'); ax1.set_ylabel('Y-axis'); ax1.set_zlabel('Z-axis') + + ax2 = fig.add_subplot(222, projection='3d') + ax2.view_init(elev=90, azim=-90, roll=0) + ax2.scatter(x_lin1, y_lin1, z_lin1, c='r', marker='o', s=1) + ax2.scatter(x_lin2, y_lin2, z_lin2, c='r', marker='o', s=1) + ax2.scatter(x_cyl_l1, y_cyl_l1, z_cyl_l1, c='darkcyan', marker='o', label='Cylinder(L)') + ax2.scatter(x_cyl_l2, y_cyl_l2, z_cyl_l2, c='pink', marker='o') + ax2.scatter(x_cyl_r1, y_cyl_r1, z_cyl_r1, c='blue', marker='o', label='Cylinder(R)') + ax2.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o') + ax2.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine') + ax2.set_xlabel('X-axis'); ax2.set_ylabel('Y-axis'); ax2.set_zlabel('Z-axis') + ax2.legend() + + ax3 = fig.add_subplot(223, projection='3d') + ax3.view_init(elev=0, azim=90, roll=0) + ax3.scatter(x_lin1, y_lin1, z_lin1, c='r', marker='o', s=1) + ax3.scatter(x_lin2, y_lin2, z_lin2, c='r', marker='o', s=1) + ax3.scatter(x_cyl_l1, y_cyl_l1, z_cyl_l1, c='darkcyan', marker='o', label='Cylinder(L)') + ax3.scatter(x_cyl_l2, y_cyl_l2, z_cyl_l2, c='pink', marker='o') + ax3.scatter(x_cyl_r1, y_cyl_r1, z_cyl_r1, c='blue', marker='o', label='Cylinder(R)') + ax3.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o') + ax3.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine') + ax3.set_xlabel('X-axis'); ax3.set_ylabel('Y-axis'); ax3.set_zlabel('Z-axis') + + ax4 = fig.add_subplot(224, projection='3d') + ax4.view_init(elev=0, azim=0, roll=0) + ax4.scatter(x_lin1, y_lin1, z_lin1, c='r', marker='o', s=1) + ax4.scatter(x_lin2, y_lin2, z_lin2, c='r', marker='o', s=1) + ax4.scatter(x_cyl_l1, y_cyl_l1, z_cyl_l1, c='darkcyan', marker='o', label='Cylinder(L)') + ax4.scatter(x_cyl_l2, y_cyl_l2, z_cyl_l2, c='pink', marker='o') + ax4.scatter(x_cyl_r1, y_cyl_r1, z_cyl_r1, c='blue', marker='o', label='Cylinder(R)') + ax4.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o') + ax4.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine') + ax4.set_xlabel('X-axis'); ax4.set_ylabel('Y-axis'); ax4.set_zlabel('Z-axis') + + cyl_points_l = torch.sum(cyl_l).item() + cyl_points_r = torch.sum(cyl_r).item() + + overlap_l = ((cortical_tensor == 1) & (cyl_l == 1)).sum().item() + overlap_r = ((cortical_tensor == 1) & (cyl_r == 1)).sum().item() + overlap_b_l = ((spine_tensor == 1) & (cyl_l == 1)).sum().item() + overlap_b_r = ((spine_tensor == 1) & (cyl_r == 1)).sum().item() + + overlap_cortical_l = (overlap_l / cyl_points_l) * 100 + overlap_cortical_r = (overlap_r / cyl_points_r) * 100 + overlap_vertebral_l = (overlap_b_l / cyl_points_l) * 100 + overlap_vertebral_r = (overlap_b_r / cyl_points_r) * 100 + cb_ratio_l = overlap_cortical_l/overlap_vertebral_l + cb_ratio_r = overlap_cortical_r/overlap_vertebral_r + user_altitude_l = 90 - best_position_l[4] - alt + user_altitude_r = 90 - best_position_r[4] - alt + user_azimuth_l = 90 - best_position_l[3] - azi + user_azimuth_r = 90 - best_position_r[3] - azi + + date_str = datetime.now().strftime("%Y%m%d") + patient_id = os.path.basename(os.path.dirname(image2_path)) + output_folder = os.path.join(base_folder, date_str, patient_id) + os.makedirs(output_folder, exist_ok=True) + csv_path = os.path.join(output_folder, 'output.csv') + + # 檢查檔案是否存在 (決定是否寫入標題) + file_exists = os.path.isfile(csv_path) + + # 欄位標題 (Header) + headers = [ + 'Label', 'Side', 'Diameter', 'Length', 'Swarm_Size', 'Max_Iter', + 'Position_XYZ', 'Raw_Azimuth', 'Azimuth_Diff', 'Raw_Altitude', 'Altitude_Diff', + 'Intersections', 'Best_Loss', 'Overlap_Cortical', 'Overlap_Bone', + 'Cortical_Bone_Ratio', 'User_Azimuth', 'User_Altitude', 'Total_Time' + ] + + try: + with open(csv_path, 'a', newline='') as csvfile: + writer = csv.writer(csvfile) + + # 如果是新檔案,寫入 Header + if not file_exists: + writer.writerow(headers) + + # 寫入 Left 數據 + writer.writerow([ + label_str, + 'L', + diameter_l, + length_l, + swarm_size, + max_iter, + f"({best_position_l[0]:.2f}, {best_position_l[1]:.2f}, {best_position_l[2]:.2f})", + f"{best_position_l[3]:.2f}", + f"{best_position_l[3]-azi:.2f}", + f"{best_position_l[4]:.2f}", + f"{best_position_l[4]-alt:.2f}", + intersections_l, + f"{loss_l:.2f}", + f"{overlap_cortical_l:.2f}", + f"{overlap_vertebral_l:.2f}", + f"{(overlap_cortical_l/overlap_vertebral_l if overlap_vertebral_l!=0 else 0):.2f}", + f"{user_azimuth_l:.2f}", + f"{user_altitude_l:.2f}", + f"{total_time:.2f}" + ]) + + # 寫入 Right 數據 + writer.writerow([ + label_str, + 'R', + diameter_r, + length_r, + swarm_size, + max_iter, + f"({best_position_r[0]:.2f}, {best_position_r[1]:.2f}, {best_position_r[2]:.2f})", + f"{best_position_r[3]:.2f}", + f"{best_position_r[3]-azi:.2f}", + f"{best_position_r[4]:.2f}", + f"{best_position_r[4]-alt:.2f}", + intersections_r, + f"{loss_r:.2f}", + f"{overlap_cortical_r:.2f}", + f"{overlap_vertebral_r:.2f}", + f"{(overlap_cortical_r/overlap_vertebral_r if overlap_vertebral_r!=0 else 0):.2f}", + f"{user_azimuth_r:.2f}", + f"{user_altitude_r:.2f}", + f"{total_time:.2f}" + ]) + print(f"[CSV Saved] {csv_path}") + + except Exception as e: + print(f"[Error] Failed to write CSV: {e}") + + fig.text(0.5, 0.98, f'{label_str} Best Position', ha='center', fontsize=15) + fig.text( + 0.5, 0.44, + f'L: Diameter = {diameter_l} mm, {length_l} mm, ' + f'R: Diameter = {diameter_r} mm, {length_r} mm, ' + f'Swarm size = {swarm_size}, Iteration = {max_iter}, Total time = {total_time:.2f} s', + ha='center', fontsize=12 + ) + fig.text( + 0.5, 0.03, + f'Left : Position = ({best_position_l[0]:.2f}, {best_position_l[1]:.2f}, {best_position_l[2]:.2f}), ' + f'Azimuth = {user_azimuth_l:.2f}, Altitude = {user_altitude_l:.2f}, ' + f'Intersection = {intersections_l}, Score = {overlap_cortical_l:.2f} / {overlap_vertebral_l:.2f} / {cb_ratio_l:.2f}', + ha='center', fontsize=9 + ) + fig.text( + 0.5, 0.01, + f'Right : Position = ({best_position_r[0]:.2f}, {best_position_r[1]:.2f}, {best_position_r[2]:.2f}), ' + f'Azimuth = {user_azimuth_r:.2f}, Altitude = {user_altitude_r:.2f}, ' + f'Intersection = {intersections_r}, Score = {overlap_cortical_r:.2f} / {overlap_vertebral_r:.2f} / {cb_ratio_r:.2f}', + ha='center', fontsize=9 + ) + + fig.tight_layout() + + date_str = datetime.now().strftime("%Y%m%d") + file_name = os.path.basename(image2_path) + level = file_name.split('_')[0] + output_folder = os.path.join(base_folder, date_str, patient_id) + os.makedirs(output_folder, exist_ok=True) + + if CBT == True: + way = 'CBT' + + else: + way = 'TPS' + + path = save_with_unique_name(output_folder, label_str, way, + diameter_l, length_l, diameter_r, length_r, + swarm_size, max_iter) + + fig.savefig(path, dpi=200, bbox_inches="tight") + print("[Saved figure]", path) + plt.close(fig) + + +def eval_overlap_from_position( + pos, + optimize_size: bool, + spine_tensor: torch.Tensor, + image_shape, + spacing, + device: torch.device, + grid=None, + fixed_diameter: float | None = None, + fixed_length: float | None = None, +): + """ + 根據 position 生成 cylinder mask,再算 overlap ratio + """ + + if optimize_size: + d, L = snap_to_discrete_values(pos[5], pos[6]) + params_5 = pos[:5] + else: + if fixed_diameter is None or fixed_length is None: + raise ValueError("fixed_diameter and fixed_length must be provided when optimize_size=False") + d, L = fixed_diameter, fixed_length + params_5 = pos + + z, y, x, az, alt = params_5 + + cyl_mask = generate_cylinder_n_torch( + d, L, + z, y, x, + az, alt, + image_shape, spacing, + device=device, + grid=grid + ) + + overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor) + return overlap, d, L + +