first commit
This commit is contained in:
commit
f84d8963da
40 changed files with 4801 additions and 0 deletions
216
.gitignore
vendored
Normal file
216
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[codz]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
# Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
# poetry.lock
|
||||
# poetry.toml
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
||||
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
||||
# pdm.lock
|
||||
# pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# pixi
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
||||
# pixi.lock
|
||||
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
||||
# in the .venv directory. It is recommended not to include this directory in version control.
|
||||
.pixi
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# Redis
|
||||
*.rdb
|
||||
*.aof
|
||||
*.pid
|
||||
|
||||
# RabbitMQ
|
||||
mnesia/
|
||||
rabbitmq/
|
||||
rabbitmq-data/
|
||||
|
||||
# ActiveMQ
|
||||
activemq-data/
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.envrc
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
# .idea/
|
||||
|
||||
# Abstra
|
||||
# Abstra is an AI-powered process automation framework.
|
||||
# Ignore directories containing user credentials, local state, and settings.
|
||||
# Learn more at https://abstra.io/docs
|
||||
.abstra/
|
||||
|
||||
# Visual Studio Code
|
||||
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
||||
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
||||
# you could uncomment the following to ignore the entire vscode folder
|
||||
# .vscode/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
# Marimo
|
||||
marimo/_static/
|
||||
marimo/_lsp/
|
||||
__marimo__/
|
||||
|
||||
# Streamlit
|
||||
.streamlit/secrets.toml
|
||||
5
.vscode/settings.json
vendored
Normal file
5
.vscode/settings.json
vendored
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
{
|
||||
"python.analysis.extraPaths": [
|
||||
"${workspaceFolder}"
|
||||
]
|
||||
}
|
||||
1
__init__.py
Normal file
1
__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
#empty
|
||||
0
config/__init__.py
Normal file
0
config/__init__.py
Normal file
18
config/constant.py
Normal file
18
config/constant.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
"""
|
||||
LABEL_MAP: 一般脊椎與label對應的順序,若是自己標記的需要額外修改定義。
|
||||
ALLOWED_DIAMETERS: 螺絲直徑範圍。
|
||||
ALLOWED_LENGTHS: 螺絲長度範圍。
|
||||
OVERLAP_THRESH: 圓柱體與骨頭接觸的比例(0-1),若低於這個閾值則會重跑。
|
||||
spacing: 影像的spacing。
|
||||
"""
|
||||
|
||||
LABEL_MAP = {
|
||||
1: "C1", 2: "C2", 3: "C3", 4: "C4", 5: "C5", 6: "C6", 7: "C7",
|
||||
8: "T1", 9: "T2", 10: "T3", 11: "T4", 12: "T5", 13: "T6", 14: "T7",
|
||||
15: "T8", 16: "T9", 17: "T10", 18: "T11", 19: "T12",
|
||||
20: "L1", 21: "L2", 22: "L3", 23: "L4", 24: "L5"
|
||||
}
|
||||
ALLOWED_DIAMETERS = [3.5, 4.0, 4.5, 5.0]
|
||||
ALLOWED_LENGTHS = [35, 40, 45, 50]
|
||||
OVERLAP_THRESH = 0.50
|
||||
DEFAULT_SPACING = [0.5, 0.5, 0.5]
|
||||
10
config/device.py
Normal file
10
config/device.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
import torch
|
||||
|
||||
def get_device(gpu_id=0):
    """Return a torch.device, preferring the requested CUDA GPU.

    Falls back to the CPU (with a printed notice) when CUDA is unavailable.

    Parameters:
        gpu_id: index of the CUDA device to use when available.
    """
    if not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        return torch.device("cpu")

    selected = torch.device(f"cuda:{gpu_id}")
    print(f"Using GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")
    return selected
|
||||
0
core/__init__.py
Normal file
0
core/__init__.py
Normal file
256
core/cylinder.py
Normal file
256
core/cylinder.py
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
|
||||
|
||||
def create_coordinate_grid(
    shape: tuple[int, int, int],
    device: torch.device,
    dtype: torch.dtype = torch.float32
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build a dense 3D voxel coordinate grid for the given volume shape.

    Returns:
        (z, y, x) index tensors, each of shape (Z, Y, X), on `device`.
    """
    axes = [torch.arange(extent, device=device, dtype=dtype) for extent in shape]
    return torch.meshgrid(axes[0], axes[1], axes[2], indexing='ij')
|
||||
|
||||
def snap_to_discrete_values(diameter_raw, length_raw):
    """
    Snap continuous optimizer outputs to the nearest allowed discrete sizes.

    Parameters:
        diameter_raw: continuous diameter proposed by PSO.
        length_raw: continuous length proposed by PSO.

    Returns:
        (diameter, length) drawn from ALLOWED_DIAMETERS / ALLOWED_LENGTHS.
    """
    def nearest(value, choices):
        # Pick the candidate with the smallest absolute distance to `value`.
        return min(choices, key=lambda candidate: abs(candidate - value))

    return nearest(diameter_raw, ALLOWED_DIAMETERS), nearest(length_raw, ALLOWED_LENGTHS)
|
||||
|
||||
|
||||
def generate_cylinder_n_torch(
|
||||
diameter: float,
|
||||
length: float,
|
||||
position_z: float,
|
||||
position_y: float,
|
||||
position_x: float,
|
||||
azimuth: float,
|
||||
altitude: float,
|
||||
shape: tuple[int, int, int],
|
||||
spacing: list[float],
|
||||
device: torch.device,
|
||||
grid: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None
|
||||
) -> torch.Tensor:
|
||||
|
||||
"""
|
||||
Generate a "forward" (positive z) cylinder mask in 3D space using PyTorch tensors.
|
||||
Returns a binary mask (torch.uint8) on the specified device.
|
||||
"""
|
||||
|
||||
if grid is None:
|
||||
z_t, y_t, x_t = create_coordinate_grid(shape, device)
|
||||
else:
|
||||
z_t, y_t, x_t = grid
|
||||
|
||||
azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32))
|
||||
altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32))
|
||||
|
||||
# Shift
|
||||
z_t = z_t - position_z
|
||||
y_t = y_t - position_y
|
||||
x_t = x_t - position_x
|
||||
|
||||
# Apply rotation (same formula as your NumPy version, but in torch)
|
||||
x_rot = (
|
||||
x_t * torch.cos(azimuth_rad_t) * torch.cos(altitude_rad_t)
|
||||
+ y_t * torch.sin(azimuth_rad_t) * torch.cos(altitude_rad_t)
|
||||
- z_t * torch.sin(altitude_rad_t)
|
||||
)
|
||||
y_rot = -x_t * torch.sin(azimuth_rad_t) + y_t * torch.cos(azimuth_rad_t)
|
||||
z_rot = (
|
||||
x_t * torch.cos(azimuth_rad_t) * torch.sin(altitude_rad_t)
|
||||
+ y_t * torch.sin(azimuth_rad_t) * torch.sin(altitude_rad_t)
|
||||
+ z_t * torch.cos(altitude_rad_t)
|
||||
)
|
||||
|
||||
# Handle spacing
|
||||
# You can expand or generalize for more spacing options
|
||||
if spacing == [1, 1, 1]:
|
||||
radius = diameter / 2.0
|
||||
mask = (
|
||||
(x_rot**2 + y_rot**2 <= radius**2)
|
||||
& (z_rot >= 0)
|
||||
& (z_rot <= length)
|
||||
)
|
||||
elif spacing == [0.5, 0.5, 0.5]:
|
||||
radius = (diameter / 2.0) * 2
|
||||
mask = (
|
||||
(x_rot**2 + y_rot**2 <= radius**2)
|
||||
& (z_rot >= 0)
|
||||
& (z_rot <= length * 2)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported spacing: {spacing}")
|
||||
|
||||
# Convert boolean mask to uint8
|
||||
cylinder_mask = mask.to(torch.uint8)
|
||||
|
||||
return cylinder_mask
|
||||
|
||||
def generate_cylinder_o_torch(
|
||||
diameter: float,
|
||||
length: float,
|
||||
position_z: float,
|
||||
position_y: float,
|
||||
position_x: float,
|
||||
azimuth: float,
|
||||
altitude: float,
|
||||
shape: tuple[int, int, int],
|
||||
spacing: list[float],
|
||||
device: torch.device,
|
||||
grid: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Generate an "opposite" (negative z) cylinder mask in 3D space using PyTorch tensors.
|
||||
Returns a binary mask (torch.uint8) on the specified device.
|
||||
"""
|
||||
|
||||
if grid is None:
|
||||
z_t, y_t, x_t = create_coordinate_grid(shape, device)
|
||||
else:
|
||||
z_t, y_t, x_t = grid
|
||||
|
||||
# Convert angles to torch
|
||||
azimuth_rad_t = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32))
|
||||
altitude_rad_t = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32))
|
||||
|
||||
# Shift
|
||||
z_t = z_t - position_z
|
||||
y_t = y_t - position_y
|
||||
x_t = x_t - position_x
|
||||
|
||||
# Apply rotation
|
||||
x_rot = (
|
||||
x_t * torch.cos(azimuth_rad_t) * torch.cos(altitude_rad_t)
|
||||
+ y_t * torch.sin(azimuth_rad_t) * torch.cos(altitude_rad_t)
|
||||
- z_t * torch.sin(altitude_rad_t)
|
||||
)
|
||||
y_rot = -x_t * torch.sin(azimuth_rad_t) + y_t * torch.cos(azimuth_rad_t)
|
||||
z_rot = (
|
||||
x_t * torch.cos(azimuth_rad_t) * torch.sin(altitude_rad_t)
|
||||
+ y_t * torch.sin(azimuth_rad_t) * torch.sin(altitude_rad_t)
|
||||
+ z_t * torch.cos(altitude_rad_t)
|
||||
)
|
||||
|
||||
# Handle spacing
|
||||
if spacing == [1, 1, 1]:
|
||||
radius = diameter / 2.0
|
||||
mask = (
|
||||
(x_rot**2 + y_rot**2 <= radius**2)
|
||||
& (z_rot <= 0)
|
||||
& (z_rot <= length)
|
||||
)
|
||||
elif spacing == [0.5, 0.5, 0.5]:
|
||||
radius = (diameter / 2.0) * 2
|
||||
mask = (
|
||||
(x_rot**2 + y_rot**2 <= radius**2)
|
||||
& (z_rot <= 0)
|
||||
& (z_rot >= -length * 2)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported spacing: {spacing}")
|
||||
|
||||
cylinder_mask_o = mask.to(torch.uint8)
|
||||
|
||||
return cylinder_mask_o
|
||||
|
||||
def generate_cylinder_numpy(diameter, length, position_z, position_y, position_x, azimuth, altitude, shape, spacing):
    """
    NumPy reference implementation of the forward (positive-z) cylinder mask.

    Parameters mirror generate_cylinder_n_torch (sizes in mm, position in
    voxels, angles in degrees).

    Returns:
        np.uint8 array of shape `shape` with 1 inside the cylinder.

    Raises:
        ValueError: for spacings other than [1, 1, 1] and [0.5, 0.5, 0.5]
        (previously this case fell through and crashed with a NameError on
        the undefined `cylinder` variable).
    """
    azimuth = np.radians(azimuth)
    altitude = np.radians(altitude)

    cylinder_mask = np.zeros(shape, dtype=np.uint8)

    z, y, x = np.mgrid[0:shape[0], 0:shape[1], 0:shape[2]].astype(np.float64)

    # Shift so the cylinder base sits at the origin.
    z -= float(position_z)
    y -= float(position_y)
    x -= float(position_x)

    # Rotate into the cylinder-aligned frame.
    x_rot = x * np.cos(azimuth) * np.cos(altitude) + y * np.sin(azimuth) * np.cos(altitude) - z * np.sin(altitude)
    y_rot = -x * np.sin(azimuth) + y * np.cos(azimuth)
    z_rot = x * np.cos(azimuth) * np.sin(altitude) + y * np.sin(azimuth) * np.sin(altitude) + z * np.cos(altitude)

    if spacing == [1, 1, 1]:
        radius = diameter / 2.0
        cylinder = (x_rot**2 + y_rot**2 <= radius**2) & (z_rot >= 0) & (z_rot <= length)
    elif spacing == [0.5, 0.5, 0.5]:
        radius = diameter / 2.0 * 2  # *2 is for resampling
        cylinder = (x_rot**2 + y_rot**2 <= radius**2) & (z_rot >= 0) & (z_rot <= length * 2)
    else:
        # BUGFIX: raise explicitly (like the torch variants) instead of
        # reaching the line below with `cylinder` undefined.
        raise ValueError(f"Unsupported spacing: {spacing}")

    cylinder_mask[cylinder] = 1

    return cylinder_mask
|
||||
|
||||
def generate_cylinder_tip_torch(
    diameter, length,
    position_z, position_y, position_x,
    azimuth, altitude,
    shape, spacing, device, grid=None,
    tip_ratio=0.2  # take the last 20% of the cylinder as the tip
) -> torch.Tensor:
    """
    Build a mask covering only the tip segment (the last `tip_ratio` fraction)
    of the forward cylinder. Returns a torch.uint8 mask on `device`.

    NOTE(review): unlike the sibling generators, spacings other than
    [0.5, 0.5, 0.5] silently fall through to the unscaled branch instead
    of raising — confirm this is intentional.
    """
    if grid is not None:
        zz, yy, xx = grid
    else:
        zz, yy, xx = create_coordinate_grid(shape, device)

    az = torch.deg2rad(torch.tensor(azimuth, device=device, dtype=torch.float32))
    al = torch.deg2rad(torch.tensor(altitude, device=device, dtype=torch.float32))
    cos_az, sin_az = torch.cos(az), torch.sin(az)
    cos_al, sin_al = torch.cos(al), torch.sin(al)

    # Shift into the cylinder's base frame.
    dz = zz - position_z
    dy = yy - position_y
    dx = xx - position_x

    # Rotate into the cylinder-aligned frame.
    x_rot = dx * cos_az * cos_al + dy * sin_az * cos_al - dz * sin_al
    y_rot = dy * cos_az - dx * sin_az
    z_rot = dx * cos_az * sin_al + dy * sin_az * sin_al + dz * cos_al

    # 0.5 mm spacing doubles all voxel-space measurements.
    if spacing == [0.5, 0.5, 0.5]:
        radius = (diameter / 2.0) * 2
        total_length = length * 2
    else:
        radius = diameter / 2.0
        total_length = length

    tip_start = total_length * (1 - tip_ratio)  # axial offset where the tip begins

    tip = (
        (x_rot ** 2 + y_rot ** 2 <= radius ** 2)
        & (z_rot >= tip_start)  # keep only the tip segment
        & (z_rot <= total_length)
    )
    return tip.to(torch.uint8)
|
||||
154
core/intersection.py
Normal file
154
core/intersection.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
def bresenham3d(x1: float, y1: float, z1: float, x2: float, y2: float, z2: float) -> list[tuple[int, int, int]]:
    """
    3D Bresenham line in Python (CPU).

    Inputs are rounded to the nearest integer voxel; the returned list contains
    every voxel from start to end inclusive. Kept CPU-based because the line is
    discrete and typically short.
    """
    x, y, z = round(x1), round(y1), round(z1)
    ex, ey, ez = round(x2), round(y2), round(z2)

    dx, dy, dz = abs(ex - x), abs(ey - y), abs(ez - z)
    sx = 1 if ex > x else -1
    sy = 1 if ey > y else -1
    sz = 1 if ez > z else -1

    points = [(x, y, z)]

    if dx >= dy and dx >= dz:
        # X is the driving axis.
        err_y, err_z = 2 * dy - dx, 2 * dz - dx
        while x != ex:
            x += sx
            if err_y >= 0:
                y += sy
                err_y -= 2 * dx
            if err_z >= 0:
                z += sz
                err_z -= 2 * dx
            err_y += 2 * dy
            err_z += 2 * dz
            points.append((x, y, z))
    elif dy >= dx and dy >= dz:
        # Y is the driving axis.
        err_x, err_z = 2 * dx - dy, 2 * dz - dy
        while y != ey:
            y += sy
            if err_x >= 0:
                x += sx
                err_x -= 2 * dy
            if err_z >= 0:
                z += sz
                err_z -= 2 * dy
            err_x += 2 * dx
            err_z += 2 * dz
            points.append((x, y, z))
    else:
        # Z is the driving axis.
        err_y, err_x = 2 * dy - dz, 2 * dx - dz
        while z != ez:
            z += sz
            if err_y >= 0:
                y += sy
                err_y -= 2 * dz
            if err_x >= 0:
                x += sx
                err_x -= 2 * dz
            err_y += 2 * dy
            err_x += 2 * dx
            points.append((x, y, z))

    return points
|
||||
|
||||
|
||||
def center_line_intersections_torch(
    position_z: float,
    position_y: float,
    position_x: float,
    azimuth: float,
    altitude: float,
    length: float,
    image_tensor: torch.Tensor,
    spacing: list[float],
    device: torch.device
) -> tuple[int, torch.Tensor]:
    """
    Computes the number of intersections along the 3D center line in the torch-based spine array.
    The line generation (Bresenham) is done on CPU, but the intersection counting is done in Torch.
    Returns (intersections, line_mask_torch).

    Parameters:
        position_z/y/x: center-line origin in voxel coordinates.
        azimuth, altitude: orientation in degrees (spherical angles).
        length: half-line length (mm); doubled in voxel units for 0.5 mm spacing.
        image_tensor: binary volume of shape (z, y, x) sampled along the line.
        spacing: voxel spacing; only [1, 1, 1] and [0.5, 0.5, 0.5] are supported.
        device: torch device for the returned line mask.

    Raises:
        ValueError: for unsupported spacing values.
    """
    azimuth_rad = np.radians(azimuth)
    altitude_rad = np.radians(altitude)

    # Convert the physical length into voxel units for this spacing.
    if spacing == [1, 1, 1]:
        length2 = length
    elif spacing == [0.5, 0.5, 0.5]:
        length2 = length * 2
    else:
        raise ValueError(f"Unsupported spacing: {spacing}")

    # Unit direction vector of the center line from the spherical angles.
    direction_z = np.cos(altitude_rad)
    direction_y = np.sin(altitude_rad) * np.sin(azimuth_rad)
    direction_x = np.sin(altitude_rad) * np.cos(azimuth_rad)

    # Endpoints: one full length forward and one full length backward from
    # the origin, so the line spans both cylinder halves.
    end_point = (
        position_z + length2 * direction_z,
        position_y + length2 * direction_y,
        position_x + length2 * direction_x,
    )
    start_opposite = (
        position_z - length2 * direction_z,
        position_y - length2 * direction_y,
        position_x - length2 * direction_x,
    )

    # Round endpoints to integer voxel coordinates.
    # NOTE(review): start_point is computed here but never used below —
    # the Bresenham line runs from start_opposite to end_point.
    start_point = (
        int(round(position_z)),
        int(round(position_y)),
        int(round(position_x)),
    )
    end_point = tuple(map(int, np.round(end_point)))
    start_opposite = tuple(map(int, np.round(start_opposite)))

    # Bresenham line (CPU)
    line_points = bresenham3d(
        start_opposite[0], start_opposite[1], start_opposite[2],
        end_point[0], end_point[1], end_point[2]
    )

    shape = image_tensor.shape  # (z, y, x)
    line_mask_torch = torch.zeros(shape, device=device, dtype=torch.uint8)

    # Sample the binary volume along the line; out-of-bounds points are skipped.
    point_values = []
    for (z_pt, y_pt, x_pt) in line_points:
        if 0 <= z_pt < shape[0] and 0 <= y_pt < shape[1] and 0 <= x_pt < shape[2]:
            line_mask_torch[z_pt, y_pt, x_pt] = 1
            # Convert to CPU to gather value quickly (one scalar per point).
            point_values.append(image_tensor[z_pt, y_pt, x_pt].item())

    # Convert to NumPy for the simple difference-based intersection counting.
    point_values_np = np.array(point_values, dtype=np.int32)

    # |d1| == 1 marks a 0<->1 transition along the sampled profile; the d2/d3
    # terms appear to correct for one-voxel spikes that would otherwise be
    # double counted — NOTE(review): verify this correction against ground truth.
    d1 = np.diff(point_values_np)
    d2 = np.diff(d1)
    d3 = np.diff(d2)
    intersections = (
        np.count_nonzero(np.abs(d1) == 1)
        - 2 * np.count_nonzero(np.abs(d2) == 2)
        + 2 * np.count_nonzero(np.abs(d3) == 4)
    )

    return intersections, line_mask_torch
|
||||
136
core/objective.py
Normal file
136
core/objective.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
import torch
|
||||
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values, generate_cylinder_tip_torch
|
||||
|
||||
from core.intersection import center_line_intersections_torch
|
||||
from core.scoring import cl_score_torch
|
||||
|
||||
# Global module-level context (used in objective_function). These are injected
# by the PSO driver — via set_global_context() or direct attribute assignment —
# before objective_function is called.
image1_array = None  # cortical bone mask array (cortical_nii.gz)
image2_array = None  # binary spine mask array (binarynii.gz)
image2_shape = None  # (z, y, x) shape of the spine volume
image3_array = None  # ROI mask array (roi2.nii.gz)
diameter = None  # fixed screw diameter; only used when size is not optimized
length = None  # fixed screw length; only used when size is not optimized
spacing = [0.5, 0.5, 0.5]  # voxel spacing of the images
device = None  # torch device used for mask generation
grid = None  # precomputed coordinate grid (see create_coordinate_grid)
USE_TIP_PENALTY = None  # when truthy, a screw-tip mask is added to scoring
|
||||
|
||||
def set_global_context(
    cortical,
    spine,
    shape,
    spacing_,
    device_,
    grid_,
    use_tip_penalty=False  # newly added flag
):
    """
    Inject the shared optimization context into this module's globals.

    objective_function() cannot take these as arguments (the PSO library only
    passes the particle position), so the driver must call this once before
    starting the optimization.

    Parameters:
        cortical: torch tensor of the cortical bone mask.
        spine: torch tensor of the binary spine mask.
        shape: (z, y, x) shape of the spine volume.
        spacing_: voxel spacing, e.g. [0.5, 0.5, 0.5].
        device_: torch device used for mask generation.
        grid_: precomputed coordinate grid (see create_coordinate_grid).
        use_tip_penalty: when True, a screw-tip mask is included in scoring.
    """
    global cortical_tensor, spine_tensor, image2_shape, spacing, device, grid, USE_TIP_PENALTY

    cortical_tensor = cortical
    spine_tensor = spine
    image2_shape = shape
    spacing = spacing_
    device = device_
    grid = grid_
    USE_TIP_PENALTY = use_tip_penalty
|
||||
|
||||
def cylinder_circle_line_intersection_loss_deductions_torch(
    diameter: float,
    length: float,
    params: list[float],
    image_shape: tuple[int, int, int],
    cortical_tensor: torch.Tensor,
    spine_tensor: torch.Tensor,
    spacing: list[float],
    device: torch.device
) -> float:
    """
    Computes the loss for a given set of cylinder params in PyTorch,
    returning a Python float for PSO consumption.

    Parameters:
        diameter, length: discrete screw size (already snapped).
        params: [position_z, position_y, position_x, azimuth, altitude].
        image_shape: (z, y, x) shape of the volumes.
        cortical_tensor: cortical bone mask tensor.
        spine_tensor: binary spine mask tensor.
        spacing: voxel spacing of the volumes.
        device: torch device for mask generation.

    NOTE(review): also reads the module-level globals `grid` and
    `USE_TIP_PENALTY`, which must be injected beforehand (via
    set_global_context or direct attribute assignment) — confirm they are
    set before PSO starts.
    """
    position_z, position_y, position_x, azimuth, altitude = params

    # Forward (positive-z) half of the screw cylinder.
    cyl_fwd = generate_cylinder_n_torch(
        diameter,
        length,
        position_z,
        position_y,
        position_x,
        float(azimuth),
        float(altitude),
        image_shape,
        spacing,
        device,
        grid
    )

    # Opposite (negative-z) half of the screw cylinder.
    cyl_opp = generate_cylinder_o_torch(
        diameter,
        length,
        position_z,
        position_y,
        position_x,
        float(azimuth),
        float(altitude),
        image_shape,
        spacing,
        device,
        grid
    )

    # We call the center_line_intersections in Torch mode to count how many
    # times the screw's center line crosses the spine mask boundary.
    intersections, _ = center_line_intersections_torch(
        position_z,
        position_y,
        position_x,
        azimuth,
        altitude,
        length,
        spine_tensor,
        spacing,
        device
    )

    # Optional tip mask: adds a screw-tip term to the scoring when enabled.
    cyl_tip = None
    if USE_TIP_PENALTY:
        cyl_tip = generate_cylinder_tip_torch(
            diameter, length,
            position_z, position_y, position_x,
            float(azimuth), float(altitude),
            image_shape, spacing, device, grid
        )

    loss_value = cl_score_torch(
        cortical_tensor, spine_tensor,
        cyl_fwd, cyl_opp, intersections,
        cylinder_tip_torch=cyl_tip
    )

    return loss_value
|
||||
|
||||
def objective_function(params: list[float]) -> float:
    """
    Wrapper for the PSO objective function, calling our Torch-based loss function.
    Now params includes diameter and length at the end.
    params = [position_z, position_y, position_x, azimuth, altitude, diameter_raw, length_raw]

    NOTE(review): relies on the module globals (image2_shape, cortical_tensor,
    spine_tensor, spacing, device) having been injected before the first call.
    """
    position_params = params[:5]  # [z, y, x, azimuth, altitude]
    diameter_raw = params[5]
    length_raw = params[6]

    # Snap the continuous PSO values to the allowed discrete screw sizes.
    diameter_discrete, length_discrete = snap_to_discrete_values(diameter_raw, length_raw)

    loss = cylinder_circle_line_intersection_loss_deductions_torch(
        diameter_discrete,
        length_discrete,
        position_params,
        image2_shape,
        cortical_tensor,
        spine_tensor,
        spacing,
        device
    )
    return loss
|
||||
691
core/optimizer.py
Normal file
691
core/optimizer.py
Normal file
|
|
@ -0,0 +1,691 @@
|
|||
import time
|
||||
from datetime import datetime
|
||||
import SimpleITK as sitk
|
||||
import torch
|
||||
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
|
||||
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
|
||||
from core.objective import objective_function
|
||||
from pyswarm import pso
|
||||
import core.objective # <--- 加入這行,讓我們可以直接操作 objective 模組
|
||||
from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid
|
||||
from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok
|
||||
from config.constant import OVERLAP_THRESH
|
||||
from visualization.res_plot_3d import res_plt_2_torch
|
||||
|
||||
def run_pso_torch(
|
||||
label_str: str,
|
||||
image1_path: str,
|
||||
image2_path: str,
|
||||
image3_path: str,
|
||||
folder: str,
|
||||
swarm_size: int,
|
||||
max_iter: int,
|
||||
spacing: list,
|
||||
CBT: bool,
|
||||
device: torch.device,
|
||||
optimize_size: bool = True,
|
||||
grid=None
|
||||
):
|
||||
"""
|
||||
Main function to run PSO.
|
||||
如果 optimize_size=True,diameter 和 length 也會被最佳化
|
||||
如果 optimize_size=False,使用預設值(向後兼容)
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Use global references
|
||||
global image1_array, image2_array, image2_shape, image3_array
|
||||
global diameter, length # 這些現在只用於非最佳化模式
|
||||
global spine_tensor, cortical_tensor, spine_roi_tensor
|
||||
|
||||
# Load images
|
||||
image1 = sitk.ReadImage(image1_path)
|
||||
image2 = sitk.ReadImage(image2_path)
|
||||
image3 = sitk.ReadImage(image3_path)
|
||||
image1_array = sitk.GetArrayFromImage(image1)
|
||||
image2_array = sitk.GetArrayFromImage(image2)
|
||||
image3_array = sitk.GetArrayFromImage(image3)
|
||||
image2_shape = image2_array.shape
|
||||
image_shape = image2_shape
|
||||
|
||||
# Move arrays to torch
|
||||
cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
|
||||
spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
|
||||
spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)
|
||||
|
||||
# ================= [跨檔案注入變數:終極防呆版] =================
|
||||
import core.objective
|
||||
|
||||
# 1. 注入 Tensors
|
||||
core.objective.cortical_tensor = cortical_tensor
|
||||
core.objective.spine_tensor = spine_tensor
|
||||
core.objective.spine_roi_tensor = spine_roi_tensor
|
||||
|
||||
# 2. 注入 Arrays (以防 objective 裡面偷偷用到 Numpy 陣列)
|
||||
core.objective.image1_array = image1_array
|
||||
core.objective.image2_array = image2_array
|
||||
core.objective.image3_array = image3_array
|
||||
|
||||
# 3. 注入 Shapes (這就是導致這次 NoneType 報錯的真兇!)
|
||||
core.objective.image2_shape = image2_shape # <--- 解除警報的最關鍵一行
|
||||
core.objective.image_shape = image_shape
|
||||
core.objective.shape = image_shape
|
||||
|
||||
# 4. 注入環境變數
|
||||
core.objective.spacing = spacing
|
||||
core.objective.device = device
|
||||
core.objective.grid = grid
|
||||
|
||||
# 5. 注入尺寸參數 (兼容固定尺寸模式)
|
||||
if not optimize_size:
|
||||
core.objective.diameter = diameter
|
||||
core.objective.length = length
|
||||
# ==============================================================
|
||||
|
||||
azi = azimuth_rotation(image2_path)
|
||||
res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
|
||||
alt = res['superior']['tilt_angle_deg']
|
||||
|
||||
# 設定基本的 bounds
|
||||
if CBT == True:
|
||||
z_bounds = (0, image_shape[0] - 1)
|
||||
y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
|
||||
x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
|
||||
x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
|
||||
azimuth_bounds_l = ((95-azi), (145-azi))
|
||||
azimuth_bounds_r = ((50-azi), (85-azi))
|
||||
altitude_bounds = ((60-alt), (75-alt))
|
||||
else:
|
||||
z_bounds = (0, image_shape[0] - 1)
|
||||
y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
|
||||
x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
|
||||
x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
|
||||
azimuth_bounds_l = (60-azi, 90-azi)
|
||||
azimuth_bounds_r = (90-azi, 120-azi)
|
||||
altitude_bounds = (65-alt, 80-alt)
|
||||
|
||||
def eval_overlap_from_position(pos, side: str, optimize_size: bool,
|
||||
spine_tensor: torch.Tensor,
|
||||
image_shape, spacing):
|
||||
"""
|
||||
根據 PSO 給的 position 生成 cylinder mask,再算 overlap ratio
|
||||
side: "L" or "R" 只是方便 debug
|
||||
"""
|
||||
if optimize_size:
|
||||
d, L = snap_to_discrete_values(pos[5], pos[6])
|
||||
params_5 = pos[:5]
|
||||
else:
|
||||
d, L = diameter, length
|
||||
params_5 = pos
|
||||
|
||||
cyl_mask = generate_cylinder_n_torch(
|
||||
d, L,
|
||||
params_5[0], params_5[1], params_5[2],
|
||||
params_5[3], params_5[4],
|
||||
image_shape, spacing, device, grid
|
||||
)
|
||||
|
||||
overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
|
||||
return overlap, d, L
|
||||
|
||||
if optimize_size:
|
||||
# 模式 1:優化 diameter 和 length
|
||||
print("=== 最佳化模式:最佳化位置、角度、直徑和長度 ===")
|
||||
|
||||
# 設定 diameter 和 length 的 bounds(連續範圍)
|
||||
diameter_bounds = (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS))
|
||||
length_bounds = (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS))
|
||||
|
||||
# bounds 現在有 7 個參數
|
||||
lb_l = [z_bounds[0], y_bounds[0], x_bounds_left[0], azimuth_bounds_l[0],
|
||||
altitude_bounds[0], diameter_bounds[0], length_bounds[0]]
|
||||
ub_l = [z_bounds[1], y_bounds[1], x_bounds_left[1], azimuth_bounds_l[1],
|
||||
altitude_bounds[1], diameter_bounds[1], length_bounds[1]]
|
||||
|
||||
lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0],
|
||||
altitude_bounds[0], diameter_bounds[0], length_bounds[0]]
|
||||
ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1],
|
||||
altitude_bounds[1], diameter_bounds[1], length_bounds[1]]
|
||||
|
||||
else:
|
||||
# 模式 2:固定 diameter 和 length(向後兼容)
|
||||
print("=== 固定尺寸模式:最佳化位置和角度 ===")
|
||||
# 使用預設值(需要在調用時提供)
|
||||
diameter = 4.5 # 或從參數傳入
|
||||
length = 45 # 或從參數傳入
|
||||
|
||||
lb_l = [z_bounds[0], y_bounds[0], x_bounds_left[0], azimuth_bounds_l[0], altitude_bounds[0]]
|
||||
ub_l = [z_bounds[1], y_bounds[1], x_bounds_left[1], azimuth_bounds_l[1], altitude_bounds[1]]
|
||||
|
||||
lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0], altitude_bounds[0]]
|
||||
ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1], altitude_bounds[1]]
|
||||
|
||||
best_loss_l = float('inf')
|
||||
best_loss_r = float('inf')
|
||||
best_position_l = None
|
||||
best_position_r = None
|
||||
|
||||
# Left side optimization
|
||||
print("\n=== 左側 ===")
|
||||
position_l, loss_l = pso(objective_function, lb_l, ub_l, swarmsize=swarm_size, maxiter=max_iter)
|
||||
|
||||
overlap_l, diameter_l, length_l = eval_overlap_from_position(
|
||||
position_l, "L", optimize_size, spine_tensor, image_shape, spacing
|
||||
)
|
||||
print(f"[LEFT] overlap: {overlap_l*100:.1f}%")
|
||||
|
||||
if optimize_size:
|
||||
print(f"[LEFT] Position: {position_l[:5]}")
|
||||
print(f"[LEFT] Diameter: {diameter_l} mm (raw: {position_l[5]:.2f})")
|
||||
print(f"[LEFT] Length: {length_l} mm (raw: {position_l[6]:.2f})")
|
||||
best_position_l = list(position_l[:5]) + [diameter_l, length_l]
|
||||
else:
|
||||
print(f"[LEFT] Position: {position_l}")
|
||||
best_position_l = position_l
|
||||
|
||||
best_loss_l = loss_l
|
||||
best_overlap_l = overlap_l # 新增
|
||||
|
||||
# max_retries = 0
|
||||
# retries = 0
|
||||
|
||||
# 左側 retry:loss 要 <=0 且 overlap >= 0.5 才算過關
|
||||
# while (best_loss_l > 0 or best_overlap_l < OVERLAP_THRESH) and retries < max_retries:
|
||||
# position_l, loss_l = pso(objective_function, lb_l, ub_l, swarmsize=swarm_size, maxiter=max_iter)
|
||||
# overlap_l, diameter_l, length_l = eval_overlap_from_position(
|
||||
# position_l, "L", optimize_size, spine_tensor, image_shape, spacing
|
||||
# )
|
||||
|
||||
# 只要找到更好的 loss(或你想用 loss+overlap 綜合排序也行)就更新 best
|
||||
# 安全版本:優先選「合格解」;沒有合格解時才用 loss 最小的當備案
|
||||
# candidate_pos = (list(position_l[:5]) + [diameter_l, length_l]) if optimize_size else position_l
|
||||
|
||||
# candidate_ok = is_solution_ok(loss_l, overlap_l, OVERLAP_THRESH)
|
||||
# best_ok = is_solution_ok(best_loss_l, best_overlap_l, OVERLAP_THRESH)
|
||||
|
||||
# if candidate_ok and (not best_ok or loss_l < best_loss_l):
|
||||
# best_position_l = candidate_pos
|
||||
# best_loss_l = loss_l
|
||||
# best_overlap_l = overlap_l
|
||||
# print(f"[LEFT][retry {retries+1}] ✅ ok | loss={loss_l:.4f}, overlap={overlap_l*100:.1f}%")
|
||||
# elif (not best_ok) and (loss_l < best_loss_l):
|
||||
# best 還不合格時,先用更小 loss 的當暫存(至少越來越好)
|
||||
# best_position_l = candidate_pos
|
||||
# best_loss_l = loss_l
|
||||
# best_overlap_l = overlap_l
|
||||
# print(f"[LEFT][retry {retries+1}] ⚠️ not ok | loss improved={loss_l:.4f}, overlap={overlap_l*100:.1f}%")
|
||||
# else:
|
||||
# print(f"[LEFT][retry {retries+1}] ❌ no improve | loss={loss_l:.4f}, overlap={overlap_l*100:.1f}%")
|
||||
|
||||
# retries += 1
|
||||
|
||||
# Right side optimization
|
||||
print("\n=== 右側 ===")
|
||||
position_r, loss_r = pso(objective_function, lb_r, ub_r, swarmsize=swarm_size, maxiter=max_iter)
|
||||
overlap_r, diameter_r, length_r = eval_overlap_from_position(
|
||||
position_r, "R", optimize_size, spine_tensor, image_shape, spacing
|
||||
)
|
||||
print(f"[RIGHT] overlap: {overlap_r*100:.1f}%")
|
||||
|
||||
if optimize_size:
|
||||
diameter_r, length_r = snap_to_discrete_values(position_r[5], position_r[6])
|
||||
print(f"[RIGHT] Position: {position_r[:5]}")
|
||||
print(f"[RIGHT] Diameter: {diameter_r} mm (raw: {position_r[5]:.2f})")
|
||||
print(f"[RIGHT] Length: {length_r} mm (raw: {position_r[6]:.2f})")
|
||||
print(f"[RIGHT] Loss: {loss_r}\n")
|
||||
|
||||
best_position_r = list(position_r[:5]) + [diameter_r, length_r]
|
||||
else:
|
||||
print(f"[RIGHT] Position: {position_r}")
|
||||
print(f"[RIGHT] Loss: {loss_r}\n")
|
||||
best_position_r = position_r
|
||||
|
||||
best_loss_r = loss_r
|
||||
best_overlap_r = overlap_r
|
||||
|
||||
# 如果需要 retry(loss > 0)
|
||||
# max_retries = 10
|
||||
# retries = 0
|
||||
|
||||
# while (best_loss_r > 0 or best_overlap_r < OVERLAP_THRESH) and retries < max_retries:
|
||||
# position_r, loss_r = pso(objective_function, lb_r, ub_r, swarmsize=swarm_size, maxiter=max_iter)
|
||||
# overlap_r, diameter_r, length_r = eval_overlap_from_position(
|
||||
# position_r, "R", optimize_size, spine_tensor, image_shape, spacing
|
||||
# )
|
||||
|
||||
# 只要找到更好的 loss(或你想用 loss+overlap 綜合排序也行)就更新 best
|
||||
# 這裡給你一個更安全的版本:優先選「合格解」;沒有合格解時才用 loss 最小的當備案
|
||||
# candidate_pos = (list(position_r[:5]) + [diameter_r, length_r]) if optimize_size else position_r
|
||||
|
||||
# candidate_ok = is_solution_ok(loss_r, overlap_r, OVERLAP_THRESH)
|
||||
# best_ok = is_solution_ok(best_loss_r, best_overlap_r, OVERLAP_THRESH)
|
||||
|
||||
# if candidate_ok and (not best_ok or loss_r < best_loss_r):
|
||||
# best_position_r = candidate_pos
|
||||
# best_loss_r = loss_r
|
||||
# best_overlap_r = overlap_r
|
||||
# print(f"[RIGHT][retry {retries+1}] ✅ ok | loss={loss_r:.4f}, overlap={overlap_r*100:.1f}%")
|
||||
# elif (not best_ok) and (loss_r < best_loss_r):
|
||||
# best 還不合格時,先用更小 loss 的當暫存(至少越來越好)
|
||||
# best_position_r = candidate_pos
|
||||
# best_loss_r = loss_r
|
||||
# best_overlap_r = overlap_r
|
||||
# print(f"[RIGHT][retry {retries+1}] ⚠️ not ok | loss improved={loss_r:.4f}, overlap={overlap_r*100:.1f}%")
|
||||
# else:
|
||||
# print(f"[RIGHT][retry {retries+1}] ❌ no improve | loss={loss_r:.4f}, overlap={overlap_r*100:.1f}%")
|
||||
|
||||
# retries += 1
|
||||
|
||||
end_time = time.time()
|
||||
total_time = end_time - start_time
|
||||
|
||||
# 提取最終的 diameter 和 length
|
||||
if optimize_size:
|
||||
final_diameter_l = best_position_l[5]
|
||||
final_length_l = best_position_l[6]
|
||||
final_diameter_r = best_position_r[5]
|
||||
final_length_r = best_position_r[6]
|
||||
|
||||
print(f"\n=== 最終結果 ===")
|
||||
print(f"Left - Diameter: {final_diameter_l} mm, Length: {final_length_l} mm")
|
||||
print(f"Right - Diameter: {final_diameter_r} mm, Length: {final_length_r} mm")
|
||||
else:
|
||||
final_diameter_l = diameter
|
||||
final_length_l = length
|
||||
final_diameter_r = diameter
|
||||
final_length_r = length
|
||||
|
||||
res_plt_2_torch(
|
||||
spine_tensor,
|
||||
cortical_tensor,
|
||||
image_shape,
|
||||
image2_path,
|
||||
'Output',
|
||||
label_str,
|
||||
final_diameter_l,
|
||||
final_length_l,
|
||||
final_diameter_r,
|
||||
final_length_r,
|
||||
best_position_l,
|
||||
best_position_r,
|
||||
swarm_size,
|
||||
max_iter,
|
||||
total_time,
|
||||
spacing,
|
||||
CBT,
|
||||
device,
|
||||
grid)
|
||||
|
||||
return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
import SimpleITK as sitk
|
||||
import torch
|
||||
from scipy.optimize import differential_evolution
|
||||
from scipy.optimize import minimize
|
||||
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
|
||||
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
|
||||
from core.objective import objective_function
|
||||
from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid
|
||||
from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok
|
||||
from config.constant import OVERLAP_THRESH
|
||||
from visualization.res_plot_3d import res_plt_2_torch
|
||||
|
||||
def run_de_torch(
    label_str: str,
    image1_path: str,
    image2_path: str,
    image3_path: str,
    folder: str,
    swarm_size: int,
    max_iter: int,
    spacing: list,
    CBT: bool,
    device: torch.device,
    optimize_size: bool = True,
    grid=None
):
    """Optimize bilateral screw trajectories with Differential Evolution (DE).

    Loads the three label volumes, mirrors the shared state that
    ``core.objective.objective_function`` reads from module globals into the
    ``core.objective`` module, derives patient-specific search bounds from the
    vertebral azimuth/tilt, then optimizes the left and right trajectories
    independently with :func:`scipy.optimize.differential_evolution`.

    Args:
        label_str: Label used when naming the output plot.
        image1_path: Path to the cortical-bone mask volume.
        image2_path: Path to the spine (vertebra) mask volume.
        image3_path: Path to the spine ROI mask volume.
        folder: Kept for interface parity with the other optimizers
            (not used directly here).
        swarm_size: Nominal particle count; converted to DE's per-dimension
            ``popsize`` so DE and PSO evaluate comparable population sizes.
        max_iter: Maximum number of DE generations.
        spacing: Voxel spacing (mm) of the volumes.
        CBT: If True, use cortical-bone-trajectory angular bounds; otherwise
            the traditional-trajectory bounds.
        device: Torch device the mask tensors are moved to.
        optimize_size: If True, screw diameter and length are part of the
            7-parameter search; otherwise fixed defaults are used (5 params).
        grid: Optional precomputed coordinate grid for cylinder generation.

    Returns:
        Tuple ``(best_position_l, best_loss_l, best_position_r, best_loss_r,
        total_time)``.
    """
    start_time = time.time()

    # objective_function reads these via module-level globals, so declare them
    # global here and mirror them into core.objective below.
    global image1_array, image2_array, image2_shape, image3_array
    global diameter, length
    global spine_tensor, cortical_tensor, spine_roi_tensor

    image1 = sitk.ReadImage(image1_path)
    image2 = sitk.ReadImage(image2_path)
    image3 = sitk.ReadImage(image3_path)
    image1_array = sitk.GetArrayFromImage(image1)
    image2_array = sitk.GetArrayFromImage(image2)
    image3_array = sitk.GetArrayFromImage(image3)
    image2_shape = image2_array.shape
    image_shape = image2_shape

    cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
    spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
    spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)

    # BUGFIX: in fixed-size mode the defaults must exist *before* they are
    # injected into core.objective.  Previously they were only assigned in the
    # bounds `else` branch further down, so the first fixed-size call raised
    # NameError on `diameter`/`length` (or injected a stale global left over
    # from a previous run).
    if not optimize_size:
        diameter = 4.5
        length = 45

    # ================= [Cross-module state injection] =================
    import core.objective

    # 1. Tensors
    core.objective.cortical_tensor = cortical_tensor
    core.objective.spine_tensor = spine_tensor
    core.objective.spine_roi_tensor = spine_roi_tensor

    # 2. Arrays (in case the objective uses the NumPy arrays directly)
    core.objective.image1_array = image1_array
    core.objective.image2_array = image2_array
    core.objective.image3_array = image3_array

    # 3. Shapes (missing these previously caused NoneType errors downstream)
    core.objective.image2_shape = image2_shape
    core.objective.image_shape = image_shape
    core.objective.shape = image_shape

    # 4. Environment
    core.objective.spacing = spacing
    core.objective.device = device
    core.objective.grid = grid

    # 5. Size parameters (fixed-size mode only)
    if not optimize_size:
        core.objective.diameter = diameter
        core.objective.length = length
    # ==================================================================

    # Patient-specific orientation: the angular bounds below are shifted by
    # the measured vertebral azimuth (azi) and superior-endplate tilt (alt).
    azi = azimuth_rotation(image2_path)
    res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
    alt = res['superior']['tilt_angle_deg']

    if CBT == True:
        z_bounds = (0, image_shape[0] - 1)
        y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
        x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
        x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
        azimuth_bounds_l = ((95-azi), (145-azi))
        azimuth_bounds_r = ((50-azi), (85-azi))
        altitude_bounds = ((60-alt), (75-alt))
    else:
        z_bounds = (0, image_shape[0] - 1)
        y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
        x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
        x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
        azimuth_bounds_l = (60-azi, 90-azi)
        azimuth_bounds_r = (90-azi, 120-azi)
        altitude_bounds = (65-alt, 80-alt)

    def eval_overlap_from_position(pos, side: str, optimize_size: bool, spine_tensor: torch.Tensor, image_shape, spacing):
        """Build the cylinder mask for a candidate position and return
        ``(overlap_ratio, diameter, length)``.  *side* ("L"/"R") is for
        debugging only."""
        if optimize_size:
            d, L = snap_to_discrete_values(pos[5], pos[6])
            params_5 = pos[:5]
        else:
            d, L = diameter, length
            params_5 = pos

        cyl_mask = generate_cylinder_n_torch(
            d, L, params_5[0], params_5[1], params_5[2], params_5[3], params_5[4],
            image_shape, spacing, device, grid
        )
        overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
        return overlap, d, L

    if optimize_size:
        print("=== DE 最佳化模式:最佳化位置、角度、直徑和長度 ===")
        diameter_bounds = (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS))
        length_bounds = (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS))

        bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds, diameter_bounds, length_bounds]
        bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds, diameter_bounds, length_bounds]
    else:
        print("=== DE 固定尺寸模式:最佳化位置和角度 ===")
        bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds]
        bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds]

    # DE's effective population = popsize * len(bounds); convert swarm_size so
    # the comparison against PSO is fair.
    de_popsize = max(1, swarm_size // len(bounds_l))

    # --- Left side ---
    print("\n=== 左側 (DE) ===")
    res_l = differential_evolution(objective_function, bounds_l, popsize=de_popsize, maxiter=max_iter)
    position_l, loss_l = res_l.x, res_l.fun

    overlap_l, diameter_l, length_l = eval_overlap_from_position(position_l, "L", optimize_size, spine_tensor, image_shape, spacing)
    best_position_l = (list(position_l[:5]) + [diameter_l, length_l]) if optimize_size else list(position_l)
    best_loss_l, best_overlap_l = loss_l, overlap_l

    # --- Right side ---
    print("\n=== 右側 (DE) ===")
    res_r = differential_evolution(objective_function, bounds_r, popsize=de_popsize, maxiter=max_iter)
    position_r, loss_r = res_r.x, res_r.fun

    overlap_r, diameter_r, length_r = eval_overlap_from_position(position_r, "R", optimize_size, spine_tensor, image_shape, spacing)
    best_position_r = (list(position_r[:5]) + [diameter_r, length_r]) if optimize_size else list(position_r)
    best_loss_r, best_overlap_r = loss_r, overlap_r

    total_time = time.time() - start_time

    final_diameter_l = best_position_l[5] if optimize_size else diameter
    final_length_l = best_position_l[6] if optimize_size else length
    final_diameter_r = best_position_r[5] if optimize_size else diameter
    final_length_r = best_position_r[6] if optimize_size else length

    res_plt_2_torch(
        spine_tensor, cortical_tensor, image_shape, image2_path, 'Output', label_str,
        final_diameter_l, final_length_l, final_diameter_r, final_length_r,
        best_position_l, best_position_r, swarm_size, max_iter, total_time, spacing, CBT, device, grid
    )

    return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time
|
||||
|
||||
def run_nm_torch(
    label_str: str,
    image1_path: str,
    image2_path: str,
    image3_path: str,
    folder: str,
    swarm_size: int,  # NM does not use swarm_size; kept for interface parity
    max_iter: int,
    spacing: list,
    CBT: bool,
    device: torch.device,
    optimize_size: bool = True,
    grid=None
):
    """Optimize bilateral screw trajectories with Nelder-Mead.

    Same pipeline as the DE/PSO variants: load the label volumes, inject the
    shared state that ``core.objective.objective_function`` reads from module
    globals, derive patient-specific bounds, then run bounded Nelder-Mead from
    a random start point for each side.  Because NM is a local method, each
    side retries (up to 10 times) from fresh random starts while the solution
    is not acceptable (loss > 0 or overlap below ``OVERLAP_THRESH``).

    Args:
        label_str: Label used when naming the output plot.
        image1_path: Path to the cortical-bone mask volume.
        image2_path: Path to the spine (vertebra) mask volume.
        image3_path: Path to the spine ROI mask volume.
        folder: Kept for interface parity (not used directly here).
        swarm_size: Unused by NM; kept so all optimizers share one signature.
        max_iter: Nelder-Mead ``maxiter`` per attempt.
        spacing: Voxel spacing (mm) of the volumes.
        CBT: If True, use cortical-bone-trajectory angular bounds.
        device: Torch device the mask tensors are moved to.
        optimize_size: If True, diameter/length join the search (7 params);
            otherwise fixed defaults are used (5 params).
        grid: Optional precomputed coordinate grid for cylinder generation.

    Returns:
        Tuple ``(best_position_l, best_loss_l, best_position_r, best_loss_r,
        total_time)``.
    """
    start_time = time.time()

    # objective_function reads these via module-level globals.
    global image1_array, image2_array, image2_shape, image3_array
    global diameter, length
    global spine_tensor, cortical_tensor, spine_roi_tensor

    image1 = sitk.ReadImage(image1_path)
    image2 = sitk.ReadImage(image2_path)
    image3 = sitk.ReadImage(image3_path)
    image1_array = sitk.GetArrayFromImage(image1)
    image2_array = sitk.GetArrayFromImage(image2)
    image3_array = sitk.GetArrayFromImage(image3)
    image2_shape = image2_array.shape
    image_shape = image2_shape

    cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
    spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
    spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)

    # BUGFIX: in fixed-size mode the defaults must exist *before* they are
    # injected into core.objective.  Previously they were only assigned in the
    # bounds `else` branch further down, so the first fixed-size call raised
    # NameError on `diameter`/`length` (or injected a stale global left over
    # from a previous run).
    if not optimize_size:
        diameter, length = 4.5, 45

    # ================= [Cross-module state injection] =================
    import core.objective

    # 1. Tensors
    core.objective.cortical_tensor = cortical_tensor
    core.objective.spine_tensor = spine_tensor
    core.objective.spine_roi_tensor = spine_roi_tensor

    # 2. Arrays (in case the objective uses the NumPy arrays directly)
    core.objective.image1_array = image1_array
    core.objective.image2_array = image2_array
    core.objective.image3_array = image3_array

    # 3. Shapes (missing these previously caused NoneType errors downstream)
    core.objective.image2_shape = image2_shape
    core.objective.image_shape = image_shape
    core.objective.shape = image_shape

    # 4. Environment
    core.objective.spacing = spacing
    core.objective.device = device
    core.objective.grid = grid

    # 5. Size parameters (fixed-size mode only)
    if not optimize_size:
        core.objective.diameter = diameter
        core.objective.length = length
    # ==================================================================

    # Patient-specific orientation: angular bounds are shifted by the measured
    # vertebral azimuth (azi) and superior-endplate tilt (alt).
    azi = azimuth_rotation(image2_path)
    res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
    alt = res['superior']['tilt_angle_deg']

    if CBT == True:
        z_bounds = (0, image_shape[0] - 1)
        y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
        x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
        x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
        azimuth_bounds_l = ((95-azi), (145-azi))
        azimuth_bounds_r = ((50-azi), (85-azi))
        altitude_bounds = ((60-alt), (75-alt))
    else:
        z_bounds = (0, image_shape[0] - 1)
        y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
        x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
        x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
        azimuth_bounds_l = (60-azi, 90-azi)
        azimuth_bounds_r = (90-azi, 120-azi)
        altitude_bounds = (65-alt, 80-alt)

    def eval_overlap_from_position(pos, side: str, optimize_size: bool, spine_tensor: torch.Tensor, image_shape, spacing):
        """Build the cylinder mask for a candidate position and return
        ``(overlap_ratio, diameter, length)``.  *side* ("L"/"R") is for
        debugging only."""
        if optimize_size:
            d, L = snap_to_discrete_values(pos[5], pos[6])
            params_5 = pos[:5]
        else:
            d, L = diameter, length
            params_5 = pos

        cyl_mask = generate_cylinder_n_torch(
            d, L, params_5[0], params_5[1], params_5[2], params_5[3], params_5[4],
            image_shape, spacing, device, grid
        )
        overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
        return overlap, d, L

    if optimize_size:
        print("=== NM 最佳化模式 ===")
        bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds,
                    (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)), (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS))]
        bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds,
                    (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)), (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS))]
    else:
        print("=== NM 固定尺寸模式 ===")
        bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds]
        bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds]

    def get_random_x0(bounds):
        """Draw a uniformly random start point inside *bounds*."""
        return [np.random.uniform(b[0], b[1]) for b in bounds]

    def _update_best(candidate_pos, loss, overlap, best_pos, best_loss, best_overlap):
        """Prefer qualifying solutions; otherwise fall back to lowest loss.

        A candidate replaces the incumbent when (a) it qualifies and either
        the incumbent does not or the candidate's loss is lower, or (b) the
        incumbent does not qualify and the candidate improves the loss.
        """
        cand_ok = is_solution_ok(loss, overlap, OVERLAP_THRESH)
        best_ok = is_solution_ok(best_loss, best_overlap, OVERLAP_THRESH)
        if cand_ok and (not best_ok or loss < best_loss):
            return candidate_pos, loss, overlap
        if (not best_ok) and (loss < best_loss):
            return candidate_pos, loss, overlap
        return best_pos, best_loss, best_overlap

    # --- Left side ---
    print("\n=== 左側 (Nelder-Mead) ===")
    x0_l = get_random_x0(bounds_l)
    res_l = minimize(objective_function, x0_l, method='Nelder-Mead', bounds=bounds_l, options={'maxiter': max_iter})
    position_l, loss_l = res_l.x, res_l.fun

    overlap_l, diameter_l, length_l = eval_overlap_from_position(position_l, "L", optimize_size, spine_tensor, image_shape, spacing)
    best_position_l = (list(position_l[:5]) + [diameter_l, length_l]) if optimize_size else list(position_l)
    best_loss_l, best_overlap_l = loss_l, overlap_l

    retries = 0
    while (best_loss_l > 0 or best_overlap_l < OVERLAP_THRESH) and retries < 10:
        x0_l = get_random_x0(bounds_l)  # fresh random start each retry
        res_l = minimize(objective_function, x0_l, method='Nelder-Mead', bounds=bounds_l, options={'maxiter': max_iter})
        position_l, loss_l = res_l.x, res_l.fun
        overlap_l, diameter_l, length_l = eval_overlap_from_position(position_l, "L", optimize_size, spine_tensor, image_shape, spacing)

        candidate_pos = (list(position_l[:5]) + [diameter_l, length_l]) if optimize_size else list(position_l)
        best_position_l, best_loss_l, best_overlap_l = _update_best(
            candidate_pos, loss_l, overlap_l, best_position_l, best_loss_l, best_overlap_l
        )
        retries += 1

    # --- Right side ---
    print("\n=== 右側 (Nelder-Mead) ===")
    x0_r = get_random_x0(bounds_r)
    res_r = minimize(objective_function, x0_r, method='Nelder-Mead', bounds=bounds_r, options={'maxiter': max_iter})
    position_r, loss_r = res_r.x, res_r.fun

    overlap_r, diameter_r, length_r = eval_overlap_from_position(position_r, "R", optimize_size, spine_tensor, image_shape, spacing)
    best_position_r = (list(position_r[:5]) + [diameter_r, length_r]) if optimize_size else list(position_r)
    best_loss_r, best_overlap_r = loss_r, overlap_r

    retries = 0
    while (best_loss_r > 0 or best_overlap_r < OVERLAP_THRESH) and retries < 10:
        x0_r = get_random_x0(bounds_r)
        res_r = minimize(objective_function, x0_r, method='Nelder-Mead', bounds=bounds_r, options={'maxiter': max_iter})
        position_r, loss_r = res_r.x, res_r.fun
        overlap_r, diameter_r, length_r = eval_overlap_from_position(position_r, "R", optimize_size, spine_tensor, image_shape, spacing)

        candidate_pos = (list(position_r[:5]) + [diameter_r, length_r]) if optimize_size else list(position_r)
        best_position_r, best_loss_r, best_overlap_r = _update_best(
            candidate_pos, loss_r, overlap_r, best_position_r, best_loss_r, best_overlap_r
        )
        retries += 1

    total_time = time.time() - start_time

    final_diameter_l = best_position_l[5] if optimize_size else diameter
    final_length_l = best_position_l[6] if optimize_size else length
    final_diameter_r = best_position_r[5] if optimize_size else diameter
    final_length_r = best_position_r[6] if optimize_size else length

    res_plt_2_torch(
        spine_tensor, cortical_tensor, image_shape, image2_path, 'Output', label_str,
        final_diameter_l, final_length_l, final_diameter_r, final_length_r,
        best_position_l, best_position_r, swarm_size, max_iter, total_time, spacing, CBT, device, grid
    )

    return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time
|
||||
604
core/optimizer_ori.py
Normal file
604
core/optimizer_ori.py
Normal file
|
|
@ -0,0 +1,604 @@
|
|||
import time
|
||||
from datetime import datetime
|
||||
import SimpleITK as sitk
|
||||
import torch
|
||||
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
|
||||
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
|
||||
from core.objective import objective_function
|
||||
from pyswarm import pso
|
||||
import core.objective # <--- 加入這行,讓我們可以直接操作 objective 模組
|
||||
from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid
|
||||
from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok
|
||||
from config.constant import OVERLAP_THRESH
|
||||
from visualization.res_plot_3d import res_plt_2_torch
|
||||
|
||||
def run_pso_torch(
|
||||
label_str: str,
|
||||
image1_path: str,
|
||||
image2_path: str,
|
||||
image3_path: str,
|
||||
folder: str,
|
||||
swarm_size: int,
|
||||
max_iter: int,
|
||||
spacing: list,
|
||||
CBT: bool,
|
||||
device: torch.device,
|
||||
optimize_size: bool = True,
|
||||
grid=None
|
||||
):
|
||||
"""
|
||||
Main function to run PSO.
|
||||
如果 optimize_size=True,diameter 和 length 也會被最佳化
|
||||
如果 optimize_size=False,使用預設值(向後兼容)
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Use global references
|
||||
global image1_array, image2_array, image2_shape, image3_array
|
||||
global diameter, length # 這些現在只用於非最佳化模式
|
||||
global spine_tensor, cortical_tensor, spine_roi_tensor
|
||||
|
||||
# Load images
|
||||
image1 = sitk.ReadImage(image1_path)
|
||||
image2 = sitk.ReadImage(image2_path)
|
||||
image3 = sitk.ReadImage(image3_path)
|
||||
image1_array = sitk.GetArrayFromImage(image1)
|
||||
image2_array = sitk.GetArrayFromImage(image2)
|
||||
image3_array = sitk.GetArrayFromImage(image3)
|
||||
image2_shape = image2_array.shape
|
||||
image_shape = image2_shape
|
||||
|
||||
# Move arrays to torch
|
||||
cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
|
||||
spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
|
||||
spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)
|
||||
|
||||
azi = azimuth_rotation(image2_path)
|
||||
res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
|
||||
alt = res['superior']['tilt_angle_deg']
|
||||
|
||||
# 設定基本的 bounds
|
||||
if CBT == True:
|
||||
z_bounds = (0, image_shape[0] - 1)
|
||||
y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
|
||||
x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
|
||||
x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
|
||||
azimuth_bounds_l = ((95-azi), (145-azi))
|
||||
azimuth_bounds_r = ((50-azi), (85-azi))
|
||||
altitude_bounds = ((60-alt), (75-alt))
|
||||
else:
|
||||
z_bounds = (0, image_shape[0] - 1)
|
||||
y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
|
||||
x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
|
||||
x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
|
||||
azimuth_bounds_l = (60-azi, 90-azi)
|
||||
azimuth_bounds_r = (90-azi, 120-azi)
|
||||
altitude_bounds = (65-alt, 80-alt)
|
||||
|
||||
def eval_overlap_from_position(pos, side: str, optimize_size: bool,
|
||||
spine_tensor: torch.Tensor,
|
||||
image_shape, spacing):
|
||||
"""
|
||||
根據 PSO 給的 position 生成 cylinder mask,再算 overlap ratio
|
||||
side: "L" or "R" 只是方便 debug
|
||||
"""
|
||||
if optimize_size:
|
||||
d, L = snap_to_discrete_values(pos[5], pos[6])
|
||||
params_5 = pos[:5]
|
||||
else:
|
||||
d, L = diameter, length
|
||||
params_5 = pos
|
||||
|
||||
cyl_mask = generate_cylinder_n_torch(
|
||||
d, L,
|
||||
params_5[0], params_5[1], params_5[2],
|
||||
params_5[3], params_5[4],
|
||||
image_shape, spacing, device, grid
|
||||
)
|
||||
|
||||
overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
|
||||
return overlap, d, L
|
||||
|
||||
if optimize_size:
|
||||
# 模式 1:優化 diameter 和 length
|
||||
print("=== 最佳化模式:最佳化位置、角度、直徑和長度 ===")
|
||||
|
||||
# 設定 diameter 和 length 的 bounds(連續範圍)
|
||||
diameter_bounds = (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS))
|
||||
length_bounds = (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS))
|
||||
|
||||
# bounds 現在有 7 個參數
|
||||
lb_l = [z_bounds[0], y_bounds[0], x_bounds_left[0], azimuth_bounds_l[0],
|
||||
altitude_bounds[0], diameter_bounds[0], length_bounds[0]]
|
||||
ub_l = [z_bounds[1], y_bounds[1], x_bounds_left[1], azimuth_bounds_l[1],
|
||||
altitude_bounds[1], diameter_bounds[1], length_bounds[1]]
|
||||
|
||||
lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0],
|
||||
altitude_bounds[0], diameter_bounds[0], length_bounds[0]]
|
||||
ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1],
|
||||
altitude_bounds[1], diameter_bounds[1], length_bounds[1]]
|
||||
|
||||
else:
|
||||
# 模式 2:固定 diameter 和 length(向後兼容)
|
||||
print("=== 固定尺寸模式:最佳化位置和角度 ===")
|
||||
# 使用預設值(需要在調用時提供)
|
||||
diameter = 4.5 # 或從參數傳入
|
||||
length = 45 # 或從參數傳入
|
||||
|
||||
lb_l = [z_bounds[0], y_bounds[0], x_bounds_left[0], azimuth_bounds_l[0], altitude_bounds[0]]
|
||||
ub_l = [z_bounds[1], y_bounds[1], x_bounds_left[1], azimuth_bounds_l[1], altitude_bounds[1]]
|
||||
|
||||
lb_r = [z_bounds[0], y_bounds[0], x_bounds_right[0], azimuth_bounds_r[0], altitude_bounds[0]]
|
||||
ub_r = [z_bounds[1], y_bounds[1], x_bounds_right[1], azimuth_bounds_r[1], altitude_bounds[1]]
|
||||
|
||||
best_loss_l = float('inf')
|
||||
best_loss_r = float('inf')
|
||||
best_position_l = None
|
||||
best_position_r = None
|
||||
|
||||
# Left side optimization
|
||||
print("\n=== 左側 ===")
|
||||
position_l, loss_l = pso(objective_function, lb_l, ub_l, swarmsize=swarm_size, maxiter=max_iter)
|
||||
|
||||
overlap_l, diameter_l, length_l = eval_overlap_from_position(
|
||||
position_l, "L", optimize_size, spine_tensor, image_shape, spacing
|
||||
)
|
||||
print(f"[LEFT] overlap: {overlap_l*100:.1f}%")
|
||||
|
||||
if optimize_size:
|
||||
print(f"[LEFT] Position: {position_l[:5]}")
|
||||
print(f"[LEFT] Diameter: {diameter_l} mm (raw: {position_l[5]:.2f})")
|
||||
print(f"[LEFT] Length: {length_l} mm (raw: {position_l[6]:.2f})")
|
||||
best_position_l = list(position_l[:5]) + [diameter_l, length_l]
|
||||
else:
|
||||
print(f"[LEFT] Position: {position_l}")
|
||||
best_position_l = position_l
|
||||
|
||||
best_loss_l = loss_l
|
||||
best_overlap_l = overlap_l # 新增
|
||||
|
||||
max_retries = 10
|
||||
retries = 0
|
||||
|
||||
# 左側 retry:loss 要 <=0 且 overlap >= 0.5 才算過關
|
||||
while (best_loss_l > 0 or best_overlap_l < OVERLAP_THRESH) and retries < max_retries:
|
||||
position_l, loss_l = pso(objective_function, lb_l, ub_l, swarmsize=swarm_size, maxiter=max_iter)
|
||||
overlap_l, diameter_l, length_l = eval_overlap_from_position(
|
||||
position_l, "L", optimize_size, spine_tensor, image_shape, spacing
|
||||
)
|
||||
|
||||
# 只要找到更好的 loss(或你想用 loss+overlap 綜合排序也行)就更新 best
|
||||
# 安全版本:優先選「合格解」;沒有合格解時才用 loss 最小的當備案
|
||||
candidate_pos = (list(position_l[:5]) + [diameter_l, length_l]) if optimize_size else position_l
|
||||
|
||||
candidate_ok = is_solution_ok(loss_l, overlap_l, OVERLAP_THRESH)
|
||||
best_ok = is_solution_ok(best_loss_l, best_overlap_l, OVERLAP_THRESH)
|
||||
|
||||
if candidate_ok and (not best_ok or loss_l < best_loss_l):
|
||||
best_position_l = candidate_pos
|
||||
best_loss_l = loss_l
|
||||
best_overlap_l = overlap_l
|
||||
print(f"[LEFT][retry {retries+1}] ✅ ok | loss={loss_l:.4f}, overlap={overlap_l*100:.1f}%")
|
||||
elif (not best_ok) and (loss_l < best_loss_l):
|
||||
# best 還不合格時,先用更小 loss 的當暫存(至少越來越好)
|
||||
best_position_l = candidate_pos
|
||||
best_loss_l = loss_l
|
||||
best_overlap_l = overlap_l
|
||||
print(f"[LEFT][retry {retries+1}] ⚠️ not ok | loss improved={loss_l:.4f}, overlap={overlap_l*100:.1f}%")
|
||||
else:
|
||||
print(f"[LEFT][retry {retries+1}] ❌ no improve | loss={loss_l:.4f}, overlap={overlap_l*100:.1f}%")
|
||||
|
||||
retries += 1
|
||||
|
||||
# Right side optimization
|
||||
print("\n=== 右側 ===")
|
||||
position_r, loss_r = pso(objective_function, lb_r, ub_r, swarmsize=swarm_size, maxiter=max_iter)
|
||||
overlap_r, diameter_r, length_r = eval_overlap_from_position(
|
||||
position_r, "R", optimize_size, spine_tensor, image_shape, spacing
|
||||
)
|
||||
print(f"[RIGHT] overlap: {overlap_r*100:.1f}%")
|
||||
|
||||
if optimize_size:
|
||||
diameter_r, length_r = snap_to_discrete_values(position_r[5], position_r[6])
|
||||
print(f"[RIGHT] Position: {position_r[:5]}")
|
||||
print(f"[RIGHT] Diameter: {diameter_r} mm (raw: {position_r[5]:.2f})")
|
||||
print(f"[RIGHT] Length: {length_r} mm (raw: {position_r[6]:.2f})")
|
||||
print(f"[RIGHT] Loss: {loss_r}\n")
|
||||
|
||||
best_position_r = list(position_r[:5]) + [diameter_r, length_r]
|
||||
else:
|
||||
print(f"[RIGHT] Position: {position_r}")
|
||||
print(f"[RIGHT] Loss: {loss_r}\n")
|
||||
best_position_r = position_r
|
||||
|
||||
best_loss_r = loss_r
|
||||
best_overlap_r = overlap_r
|
||||
|
||||
# 如果需要 retry(loss > 0)
|
||||
max_retries = 10
|
||||
retries = 0
|
||||
|
||||
while (best_loss_r > 0 or best_overlap_r < OVERLAP_THRESH) and retries < max_retries:
|
||||
position_r, loss_r = pso(objective_function, lb_r, ub_r, swarmsize=swarm_size, maxiter=max_iter)
|
||||
overlap_r, diameter_r, length_r = eval_overlap_from_position(
|
||||
position_r, "R", optimize_size, spine_tensor, image_shape, spacing
|
||||
)
|
||||
|
||||
# 只要找到更好的 loss(或你想用 loss+overlap 綜合排序也行)就更新 best
|
||||
# 這裡給你一個更安全的版本:優先選「合格解」;沒有合格解時才用 loss 最小的當備案
|
||||
candidate_pos = (list(position_r[:5]) + [diameter_r, length_r]) if optimize_size else position_r
|
||||
|
||||
candidate_ok = is_solution_ok(loss_r, overlap_r, OVERLAP_THRESH)
|
||||
best_ok = is_solution_ok(best_loss_r, best_overlap_r, OVERLAP_THRESH)
|
||||
|
||||
if candidate_ok and (not best_ok or loss_r < best_loss_r):
|
||||
best_position_r = candidate_pos
|
||||
best_loss_r = loss_r
|
||||
best_overlap_r = overlap_r
|
||||
print(f"[RIGHT][retry {retries+1}] ✅ ok | loss={loss_r:.4f}, overlap={overlap_r*100:.1f}%")
|
||||
elif (not best_ok) and (loss_r < best_loss_r):
|
||||
# best 還不合格時,先用更小 loss 的當暫存(至少越來越好)
|
||||
best_position_r = candidate_pos
|
||||
best_loss_r = loss_r
|
||||
best_overlap_r = overlap_r
|
||||
print(f"[RIGHT][retry {retries+1}] ⚠️ not ok | loss improved={loss_r:.4f}, overlap={overlap_r*100:.1f}%")
|
||||
else:
|
||||
print(f"[RIGHT][retry {retries+1}] ❌ no improve | loss={loss_r:.4f}, overlap={overlap_r*100:.1f}%")
|
||||
|
||||
retries += 1
|
||||
end_time = time.time()
|
||||
total_time = end_time - start_time
|
||||
|
||||
# 提取最終的 diameter 和 length
|
||||
if optimize_size:
|
||||
final_diameter_l = best_position_l[5]
|
||||
final_length_l = best_position_l[6]
|
||||
final_diameter_r = best_position_r[5]
|
||||
final_length_r = best_position_r[6]
|
||||
|
||||
print(f"\n=== 最終結果 ===")
|
||||
print(f"Left - Diameter: {final_diameter_l} mm, Length: {final_length_l} mm")
|
||||
print(f"Right - Diameter: {final_diameter_r} mm, Length: {final_length_r} mm")
|
||||
else:
|
||||
final_diameter_l = diameter
|
||||
final_length_l = length
|
||||
final_diameter_r = diameter
|
||||
final_length_r = length
|
||||
|
||||
res_plt_2_torch(
|
||||
spine_tensor,
|
||||
cortical_tensor,
|
||||
image_shape,
|
||||
image2_path,
|
||||
'Output',
|
||||
label_str,
|
||||
final_diameter_l,
|
||||
final_length_l,
|
||||
final_diameter_r,
|
||||
final_length_r,
|
||||
best_position_l,
|
||||
best_position_r,
|
||||
swarm_size,
|
||||
max_iter,
|
||||
total_time,
|
||||
spacing,
|
||||
CBT,
|
||||
device,
|
||||
grid)
|
||||
|
||||
return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
import SimpleITK as sitk
|
||||
import torch
|
||||
from scipy.optimize import differential_evolution
|
||||
from scipy.optimize import minimize
|
||||
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
|
||||
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
|
||||
from core.objective import objective_function
|
||||
from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid
|
||||
from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok
|
||||
from config.constant import OVERLAP_THRESH
|
||||
from visualization.res_plot_3d import res_plt_2_torch
|
||||
|
||||
def run_de_torch(
    label_str: str,
    image1_path: str,
    image2_path: str,
    image3_path: str,
    folder: str,
    swarm_size: int,
    max_iter: int,
    spacing: list,
    CBT: bool,
    device: torch.device,
    optimize_size: bool = True,
    grid=None
):
    """
    Optimize screw placement with Differential Evolution (DE).

    Runs one DE search per side (left/right), each followed by up to 10
    retry runs until a qualified solution is found (loss <= 0 and overlap
    ratio >= OVERLAP_THRESH). The interface mirrors run_pso_torch so the
    optimizers are interchangeable.

    Returns:
        (best_position_l, best_loss_l, best_position_r, best_loss_r, total_time)
    """
    start_time = time.time()

    # Globals shared with objective_function, which reads these tensors.
    global image1_array, image2_array, image2_shape, image3_array
    global diameter, length  # only used in fixed-size mode
    global spine_tensor, cortical_tensor, spine_roi_tensor

    # Load the three label volumes (cortical mask, spine mask, spine ROI).
    image1_array = sitk.GetArrayFromImage(sitk.ReadImage(image1_path))
    image2_array = sitk.GetArrayFromImage(sitk.ReadImage(image2_path))
    image3_array = sitk.GetArrayFromImage(sitk.ReadImage(image3_path))
    image2_shape = image2_array.shape
    image_shape = image2_shape

    cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
    spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
    spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)

    # Patient-specific orientation: azimuth from the axial rotation of the
    # vertebra, altitude from the superior endplate tilt.
    azi = azimuth_rotation(image2_path)
    res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
    alt = res['superior']['tilt_angle_deg']

    # Spatial bounds are identical for both trajectory types; only the
    # angular windows differ between CBT and the traditional trajectory.
    z_bounds = (0, image_shape[0] - 1)
    y_bounds = (image_shape[1] / 5, image_shape[1] / 2 - 1)
    x_bounds_left = (0, image_shape[2] / 2 - image_shape[2] / 10 - 1)
    x_bounds_right = (image_shape[2] / 2 + image_shape[2] / 10, image_shape[2] - 1)
    if CBT:
        azimuth_bounds_l = (95 - azi, 145 - azi)
        azimuth_bounds_r = (50 - azi, 85 - azi)
        altitude_bounds = (60 - alt, 75 - alt)
    else:
        azimuth_bounds_l = (60 - azi, 90 - azi)
        azimuth_bounds_r = (90 - azi, 120 - azi)
        altitude_bounds = (65 - alt, 80 - alt)

    def eval_overlap_from_position(pos):
        # Build the cylinder described by `pos` and measure its overlap with
        # the spine mask. Returns (overlap_ratio, diameter_mm, length_mm).
        if optimize_size:
            d, L = snap_to_discrete_values(pos[5], pos[6])
            params_5 = pos[:5]
        else:
            d, L = diameter, length
            params_5 = pos
        cyl_mask = generate_cylinder_n_torch(
            d, L, params_5[0], params_5[1], params_5[2], params_5[3], params_5[4],
            image_shape, spacing, device, grid
        )
        overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
        return overlap, d, L

    if optimize_size:
        print("=== DE 最佳化模式:最佳化位置、角度、直徑和長度 ===")
        diameter_bounds = (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS))
        length_bounds = (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS))
        bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds, diameter_bounds, length_bounds]
        bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds, diameter_bounds, length_bounds]
    else:
        print("=== DE 固定尺寸模式:最佳化位置和角度 ===")
        diameter = 4.5
        length = 45
        bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds]
        bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds]

    # DE's effective population is popsize * n_params; convert so the search
    # budget is comparable to PSO's swarm_size.
    de_popsize = max(1, swarm_size // len(bounds_l))

    def _optimize_side(bounds, header):
        # One DE search plus up to 10 retries. Qualified solutions
        # (loss <= 0 and overlap >= OVERLAP_THRESH) are preferred; while no
        # qualified solution exists, the lowest-loss candidate is kept as a
        # fallback so each retry is at least no worse than the last.
        print(header)
        result = differential_evolution(objective_function, bounds, popsize=de_popsize, maxiter=max_iter)
        position, loss = result.x, result.fun
        overlap, d, L = eval_overlap_from_position(position)
        best_pos = list(position[:5]) + [d, L] if optimize_size else list(position)
        best_loss, best_overlap = loss, overlap

        retries = 0
        while (best_loss > 0 or best_overlap < OVERLAP_THRESH) and retries < 10:
            result = differential_evolution(objective_function, bounds, popsize=de_popsize, maxiter=max_iter)
            position, loss = result.x, result.fun
            overlap, d, L = eval_overlap_from_position(position)

            candidate = (list(position[:5]) + [d, L]) if optimize_size else list(position)
            candidate_ok = is_solution_ok(loss, overlap, OVERLAP_THRESH)
            best_ok = is_solution_ok(best_loss, best_overlap, OVERLAP_THRESH)
            if candidate_ok and (not best_ok or loss < best_loss):
                best_pos, best_loss, best_overlap = candidate, loss, overlap
            elif (not best_ok) and (loss < best_loss):
                best_pos, best_loss, best_overlap = candidate, loss, overlap
            retries += 1
        return best_pos, best_loss

    best_position_l, best_loss_l = _optimize_side(bounds_l, "\n=== 左側 (DE) ===")
    best_position_r, best_loss_r = _optimize_side(bounds_r, "\n=== 右側 (DE) ===")

    total_time = time.time() - start_time

    # In optimize-size mode the snapped diameter/length live in slots 5-6 of
    # the best position; otherwise fall back to the fixed globals.
    final_diameter_l = best_position_l[5] if optimize_size else diameter
    final_length_l = best_position_l[6] if optimize_size else length
    final_diameter_r = best_position_r[5] if optimize_size else diameter
    final_length_r = best_position_r[6] if optimize_size else length

    res_plt_2_torch(
        spine_tensor, cortical_tensor, image_shape, image2_path, 'Output', label_str,
        final_diameter_l, final_length_l, final_diameter_r, final_length_r,
        best_position_l, best_position_r, swarm_size, max_iter, total_time, spacing, CBT, device, grid
    )

    return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time
|
||||
|
||||
def run_nm_torch(
    label_str: str,
    image1_path: str,
    image2_path: str,
    image3_path: str,
    folder: str,
    swarm_size: int,  # unused by Nelder-Mead; kept for interface parity
    max_iter: int,
    spacing: list,
    CBT: bool,
    device: torch.device,
    optimize_size: bool = True,
    grid=None
):
    """
    Optimize screw placement with Nelder-Mead.

    Each side (left/right) is searched from a random start point inside the
    box constraints, with up to 10 retries from fresh random starts until a
    qualified solution is found (loss <= 0 and overlap >= OVERLAP_THRESH).
    The interface mirrors run_pso_torch / run_de_torch.

    Returns:
        (best_position_l, best_loss_l, best_position_r, best_loss_r, total_time)
    """
    start_time = time.time()

    # Globals shared with objective_function, which reads these tensors.
    global image1_array, image2_array, image2_shape, image3_array
    global diameter, length  # only used in fixed-size mode
    global spine_tensor, cortical_tensor, spine_roi_tensor

    # Load the three label volumes (cortical mask, spine mask, spine ROI).
    image1_array = sitk.GetArrayFromImage(sitk.ReadImage(image1_path))
    image2_array = sitk.GetArrayFromImage(sitk.ReadImage(image2_path))
    image3_array = sitk.GetArrayFromImage(sitk.ReadImage(image3_path))
    image2_shape = image2_array.shape
    image_shape = image2_shape

    cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
    spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
    spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)

    # Patient-specific orientation: azimuth from the axial rotation of the
    # vertebra, altitude from the superior endplate tilt.
    azi = azimuth_rotation(image2_path)
    res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
    alt = res['superior']['tilt_angle_deg']

    # Spatial bounds are identical for both trajectory types; only the
    # angular windows differ between CBT and the traditional trajectory.
    z_bounds = (0, image_shape[0] - 1)
    y_bounds = (image_shape[1] / 5, image_shape[1] / 2 - 1)
    x_bounds_left = (0, image_shape[2] / 2 - image_shape[2] / 10 - 1)
    x_bounds_right = (image_shape[2] / 2 + image_shape[2] / 10, image_shape[2] - 1)
    if CBT:
        azimuth_bounds_l = (95 - azi, 145 - azi)
        azimuth_bounds_r = (50 - azi, 85 - azi)
        altitude_bounds = (60 - alt, 75 - alt)
    else:
        azimuth_bounds_l = (60 - azi, 90 - azi)
        azimuth_bounds_r = (90 - azi, 120 - azi)
        altitude_bounds = (65 - alt, 80 - alt)

    def eval_overlap_from_position(pos):
        # Build the cylinder described by `pos` and measure its overlap with
        # the spine mask. Returns (overlap_ratio, diameter_mm, length_mm).
        if optimize_size:
            d, L = snap_to_discrete_values(pos[5], pos[6])
            params_5 = pos[:5]
        else:
            d, L = diameter, length
            params_5 = pos
        cyl_mask = generate_cylinder_n_torch(
            d, L, params_5[0], params_5[1], params_5[2], params_5[3], params_5[4],
            image_shape, spacing, device, grid
        )
        overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
        return overlap, d, L

    if optimize_size:
        print("=== NM 最佳化模式 ===")
        diameter_bounds = (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS))
        length_bounds = (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS))
        bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds,
                    diameter_bounds, length_bounds]
        bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds,
                    diameter_bounds, length_bounds]
    else:
        print("=== NM 固定尺寸模式 ===")
        diameter, length = 4.5, 45
        bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds]
        bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds]

    def _random_x0(bounds):
        # Uniform random start point inside the box constraints; Nelder-Mead
        # is a local method, so each retry needs a fresh start.
        return [np.random.uniform(lo, hi) for lo, hi in bounds]

    def _optimize_side(bounds, header):
        # One Nelder-Mead run plus up to 10 retries from new random starts.
        # Qualified solutions (loss <= 0 and overlap >= OVERLAP_THRESH) are
        # preferred; while none exists, the lowest-loss candidate is kept as
        # a fallback so each retry is at least no worse than the last.
        print(header)
        result = minimize(objective_function, _random_x0(bounds), method='Nelder-Mead',
                          bounds=bounds, options={'maxiter': max_iter})
        position, loss = result.x, result.fun
        overlap, d, L = eval_overlap_from_position(position)
        best_pos = list(position[:5]) + [d, L] if optimize_size else list(position)
        best_loss, best_overlap = loss, overlap

        retries = 0
        while (best_loss > 0 or best_overlap < OVERLAP_THRESH) and retries < 10:
            result = minimize(objective_function, _random_x0(bounds), method='Nelder-Mead',
                              bounds=bounds, options={'maxiter': max_iter})
            position, loss = result.x, result.fun
            overlap, d, L = eval_overlap_from_position(position)

            candidate = (list(position[:5]) + [d, L]) if optimize_size else list(position)
            candidate_ok = is_solution_ok(loss, overlap, OVERLAP_THRESH)
            best_ok = is_solution_ok(best_loss, best_overlap, OVERLAP_THRESH)
            if candidate_ok and (not best_ok or loss < best_loss):
                best_pos, best_loss, best_overlap = candidate, loss, overlap
            elif (not best_ok) and (loss < best_loss):
                best_pos, best_loss, best_overlap = candidate, loss, overlap
            retries += 1
        return best_pos, best_loss

    best_position_l, best_loss_l = _optimize_side(bounds_l, "\n=== 左側 (Nelder-Mead) ===")
    best_position_r, best_loss_r = _optimize_side(bounds_r, "\n=== 右側 (Nelder-Mead) ===")

    total_time = time.time() - start_time

    # In optimize-size mode the snapped diameter/length live in slots 5-6 of
    # the best position; otherwise fall back to the fixed globals.
    final_diameter_l = best_position_l[5] if optimize_size else diameter
    final_length_l = best_position_l[6] if optimize_size else length
    final_diameter_r = best_position_r[5] if optimize_size else diameter
    final_length_r = best_position_r[6] if optimize_size else length

    res_plt_2_torch(
        spine_tensor, cortical_tensor, image_shape, image2_path, 'Output', label_str,
        final_diameter_l, final_length_l, final_diameter_r, final_length_r,
        best_position_l, best_position_r, swarm_size, max_iter, total_time, spacing, CBT, device, grid
    )

    return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time
|
||||
151
core/scoring.py
Normal file
151
core/scoring.py
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
import torch
|
||||
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_tip_torch
|
||||
from config.constant import OVERLAP_THRESH
|
||||
|
||||
def cl_score_torch(
    cortical_tensor: torch.Tensor,
    spine_tensor: torch.Tensor,
    cylinder_torch: torch.Tensor,
    cylinder_o_torch: torch.Tensor,
    intersections: int,
    diameter: float = None,
    length: float = None,
    cylinder_tip_torch: torch.Tensor = None  # optional mask of the cylinder tip
) -> float:
    """
    Progressive scoring: first make sure bone is found at all, then refine.

    A reward accumulator is built up and negated on return, so the returned
    value is a loss (lower is better).
    """
    total_vox = cylinder_torch.sum().item()
    hit = ((cortical_tensor == 1) & (cylinder_torch == 1)).sum().item()
    miss = ((cortical_tensor == 0) & (cylinder_torch == 1)).sum().item()
    reverse_hit = ((spine_tensor == 1) & (cylinder_o_torch == 1)).sum().item()

    # Degenerate (empty) cylinder: worst possible loss.
    if total_vox == 0:
        return float(1e9)

    hit_ratio = hit / total_vox
    gain = 0

    # === Stage 1: the primary goal is touching bone at all (hit > 0) ===
    if hit == 0:
        # No cortical contact whatsoever: massive penalty.
        gain -= 500000
        if intersections == 0:
            # Not even crossing the spine is worse still.
            gain -= 500000
        return float(-gain)

    # === Stage 2: bone was found; now improve quality ===

    # 1. Non-linear overlap reward: steeper per-voxel reward at low coverage
    #    to encourage exploration, plus a big bonus beyond 50% coverage.
    if hit_ratio < 0.1:
        gain += hit * 5000
    elif hit_ratio < 0.3:
        gain += hit * 3000
    elif hit_ratio < 0.5:
        gain += hit * 2000
    else:
        gain += hit * 2000
        gain += (hit_ratio - 0.5) * 100000

    # 2. Intersection-count control: exactly one pass through the spine is
    #    ideal; zero is a severe error (though better than no overlap at all).
    if intersections == 1:
        gain += 20000
    elif intersections == 0:
        gain -= 200000
    elif intersections == 2:
        gain -= 10000
    else:
        gain -= intersections * 15000

    # 3. Progressive penalty for cylinder voxels off cortical bone: lenient
    #    while coverage is still low, strict once coverage is established.
    miss_ratio = miss / total_vox
    if hit_ratio < 0.2:
        gain -= miss * 300
    elif hit_ratio < 0.4:
        gain -= miss * 600
    elif miss_ratio > 0.5:
        gain -= miss * 1500
    else:
        gain -= miss * 800

    # 4. Penalty for the mirrored/opposite cylinder hitting the spine.
    gain -= reverse_hit * 1000

    # 5. Size sanity checks (relaxed ranges).
    if diameter is not None and length is not None:
        if diameter < 2.5 or diameter > 6.0:
            gain -= 3000
        if length < 25 or length > 60:
            gain -= 3000

    # 6. Tip-breach penalty: the tip leaving cortical bone is punished much
    #    harder than an ordinary off-bone voxel.
    if cylinder_tip_torch is not None:
        tip_total = cylinder_tip_torch.sum().item()
        if tip_total > 0:
            tip_breach = ((cortical_tensor == 0) & (cylinder_tip_torch == 1)).sum().item()
            if tip_breach / tip_total > 0:
                gain -= tip_breach * 5000

    return float(-gain)
|
||||
|
||||
def get_overlap_ratio(
    position_params: list,
    diameter: float,
    length: float,
    cortical_tensor: torch.Tensor,
    image_shape: tuple,
    spacing: list,
    device: torch.device,
    grid=None
) -> float:
    """
    Return the percentage (0-100) of cylinder voxels lying on cortical bone.
    """
    # Rasterize the cylinder described by the 5 position/angle parameters.
    mask = generate_cylinder_n_torch(
        diameter, length,
        position_params[0], position_params[1], position_params[2],
        position_params[3], position_params[4],
        image_shape, spacing, device, grid
    )

    # Cylinder volume in voxels.
    volume = torch.sum(mask).item()
    if volume == 0:
        # An empty cylinder cannot overlap anything.
        return 0.0

    # Count voxels shared with cortical bone (same convention as
    # cl_score_torch, which also compares against the cortical mask).
    hits = ((cortical_tensor == 1) & (mask == 1)).sum().item()
    return (hits / volume) * 100.0
|
||||
|
||||
def compute_overlap_ratio_from_cylinder_mask(cyl_mask: torch.Tensor,
                                             spine_mask: torch.Tensor,
                                             eps: float = 1e-6) -> float:
    """
    Overlap defined as intersection / cylinder volume.

    Other definitions (intersection / spine volume, Dice) would also work;
    what matters is applying one definition consistently. Both masks are
    uint8/bool tensors of identical shape. `eps` guards against a
    zero-volume cylinder.
    """
    cylinder = cyl_mask.bool()
    intersection = (cylinder & spine_mask.bool()).sum().item()
    volume = cylinder.sum().item()
    return float(intersection) / float(volume + eps)
|
||||
|
||||
def is_solution_ok(loss: float, overlap: float, overlap_thresh: float = OVERLAP_THRESH) -> bool:
    """A candidate qualifies when its loss is non-positive and its overlap meets the threshold."""
    if loss > 0:
        return False
    return overlap >= overlap_thresh
|
||||
|
||||
561
experiment.ipynb
Normal file
561
experiment.ipynb
Normal file
|
|
@ -0,0 +1,561 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "52324a94",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Core\n",
|
||||
"from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch\n",
|
||||
"from core.objective import set_global_context\n",
|
||||
"from core.cylinder import create_coordinate_grid\n",
|
||||
"\n",
|
||||
"# Imaging\n",
|
||||
"from imaging.preprocessing import process_single_image, process_dataset\n",
|
||||
"\n",
|
||||
"# Visualization\n",
|
||||
"from visualization.res_cyl_to_CT import cyl_and_CT\n",
|
||||
"from visualization.res_cyl_nifti import save_cyl\n",
|
||||
"\n",
|
||||
"# Config\n",
|
||||
"from config.device import get_device\n",
|
||||
"from config.constant import *\n",
|
||||
"\n",
|
||||
"import SimpleITK as sitk\n",
|
||||
"import os"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4e680dfe",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# ====== SEGMENTATION: SINGLE DATA ======\n",
|
||||
"\n",
|
||||
"image_path = \"/home/cyrou/CBT/CTSpine1K/data/colon/1.3.6.1.4.1.9328.50.4.0783.nii.gz\"\n",
|
||||
"label_path = \"/home/cyrou/CBT/CTSpine1K/label/colon/1.3.6.1.4.1.9328.50.4.0783_seg.nii.gz\"\n",
|
||||
"output_dir = \"/home/cyrou/CBT/Dataset\" # set by user\n",
|
||||
"\n",
|
||||
"process_single_image(image_path, label_path, output_dir_base=output_dir)\n",
|
||||
"# If you need to assign special labels, please use labels_to_process = [23, 24] .etc "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1961cf4c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# ====== SEGMENTATION: DATASET ======\n",
|
||||
"\n",
|
||||
"image_dir = \"/home/cyrou/CBT/CTSpine1K/data/colon/\"\n",
|
||||
"label_dir = \"/home/cyrou/CBT/CTSpine1K/label/colon/\"\n",
|
||||
"output_dir = \"/home/cyrou/CBT/Dataset\" # set by user\n",
|
||||
"\n",
|
||||
"process_dataset(image_dir, label_dir, output_dir, labels_to_process=None)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2bed470b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# ====== CASE 設定 ======\n",
|
||||
"cortical_path = \"/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0121/L5_cortical.nii.gz\"\n",
|
||||
"binary_path = \"/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0121/L5_binary.nii.gz\"\n",
|
||||
"roi_path = \"/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0121/L5_roi2.nii.gz\"\n",
|
||||
"\n",
|
||||
"# ====== PSO ======\n",
|
||||
"swarm_size = 70\n",
|
||||
"max_iter = 100\n",
|
||||
"\n",
|
||||
"# ====== DEVICE ======\n",
|
||||
"device = get_device(3) \n",
|
||||
"\n",
|
||||
"# ====== OTHER ======\n",
|
||||
"spacing = [0.5, 0.5, 0.5]\n",
|
||||
"CBT = True"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dc17e584",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import SimpleITK as sitk\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"cortical_image = sitk.ReadImage(cortical_path)\n",
|
||||
"binary_image = sitk.ReadImage(binary_path)\n",
|
||||
"roi_image = sitk.ReadImage(roi_path)\n",
|
||||
"\n",
|
||||
"cortical_array = sitk.GetArrayFromImage(cortical_image)\n",
|
||||
"binary_array = sitk.GetArrayFromImage(binary_image)\n",
|
||||
"roi_array = sitk.GetArrayFromImage(roi_image) \n",
|
||||
"image_shape = binary_array.shape\n",
|
||||
"\n",
|
||||
"cortical_tensor = torch.tensor(cortical_array, device=device)\n",
|
||||
"binary_tensor = torch.tensor(binary_array, device=device)\n",
|
||||
"\n",
|
||||
"grid = create_coordinate_grid(image_shape, device)\n",
|
||||
"\n",
|
||||
"set_global_context(\n",
|
||||
" cortical=cortical_tensor,\n",
|
||||
" spine=binary_tensor,\n",
|
||||
" shape=image_shape,\n",
|
||||
" spacing_=spacing,\n",
|
||||
" device_=device,\n",
|
||||
" grid_=grid,\n",
|
||||
" use_tip_penalty=False \n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2c3a3e47",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"best_l, loss_l, best_r, loss_r, total_time = run_de_torch(\n",
|
||||
" label_str=\"L5\",\n",
|
||||
" image1_path=cortical_path,\n",
|
||||
" image2_path=binary_path,\n",
|
||||
" image3_path=roi_path,\n",
|
||||
" folder=\"Output\",\n",
|
||||
" swarm_size=swarm_size,\n",
|
||||
" max_iter=max_iter,\n",
|
||||
" spacing=spacing,\n",
|
||||
" CBT=CBT,\n",
|
||||
" device=device,\n",
|
||||
" optimize_size=True,\n",
|
||||
" grid=grid)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "94e1ee75",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"best_l, loss_l, best_r, loss_r, total_time = run_nm_torch(\n",
|
||||
" label_str=\"L5\",\n",
|
||||
" image1_path=cortical_path,\n",
|
||||
" image2_path=binary_path,\n",
|
||||
" image3_path=roi_path,\n",
|
||||
" folder=\"Output\",\n",
|
||||
" swarm_size=swarm_size,\n",
|
||||
" max_iter=max_iter,\n",
|
||||
" spacing=spacing,\n",
|
||||
" CBT=CBT,\n",
|
||||
" device=device,\n",
|
||||
" optimize_size=True,\n",
|
||||
" grid=grid)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "be72d637",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"best_l, loss_l, best_r, loss_r, total_time = run_pso_torch(\n",
|
||||
" label_str=\"L5\",\n",
|
||||
" image1_path=cortical_path,\n",
|
||||
" image2_path=binary_path,\n",
|
||||
" image3_path=roi_path,\n",
|
||||
" folder=\"Output\",\n",
|
||||
" swarm_size=swarm_size,\n",
|
||||
" max_iter=max_iter,\n",
|
||||
" spacing=spacing,\n",
|
||||
" CBT=CBT,\n",
|
||||
" device=device,\n",
|
||||
" optimize_size=True,\n",
|
||||
" grid=grid)\n",
|
||||
"\n",
|
||||
"cylinder_L, cylinder_R = save_cyl(best_l, best_r, spacing, roi_path, output_base='Output', CBT=True)\n",
|
||||
"\n",
|
||||
"cyl_and_CT(diameter_l=best_l[5],\n",
|
||||
" diameter_r=best_r[5],\n",
|
||||
" length_l=best_l[6],\n",
|
||||
" length_r=best_r[6],\n",
|
||||
" roi_image_path=roi_path,\n",
|
||||
" cylinder_1=cylinder_L,\n",
|
||||
" cylinder_2=cylinder_R,\n",
|
||||
" base_folder=\"Output\",\n",
|
||||
" CBT=CBT)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9723e65b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using GPU 1: NVIDIA A40\n",
|
||||
"📂 總共偵測到 784 個病人資料夾,準備開始處理...\n",
|
||||
"\n",
|
||||
"=========================================\n",
|
||||
"🚀 === Running 1.3.6.1.4.1.9328.50.4.0001_L1 ===\n",
|
||||
"=========================================\n",
|
||||
"👉 執行 CBT 軌跡最佳化...\n",
|
||||
"Superior 終板傾斜角度: 6.64°\n",
|
||||
"斜率: 0.1164, R²: 0.7599\n",
|
||||
"=== 最佳化模式:最佳化位置、角度、直徑和長度 ===\n",
|
||||
"\n",
|
||||
"=== 左側 ===\n",
|
||||
"Stopping search: maximum iterations reached --> 100\n",
|
||||
"[LEFT] overlap: 15.3%\n",
|
||||
"[LEFT] Position: [ 27.47724944 40.25868267 37.64458752 100.44943459 58.1295552 ]\n",
|
||||
"[LEFT] Diameter: 5.0 mm (raw: 4.98)\n",
|
||||
"[LEFT] Length: 50 mm (raw: 49.97)\n",
|
||||
"Stopping search: maximum iterations reached --> 100\n",
|
||||
"[LEFT][retry 1] ⚠️ not ok | loss improved=-15892000.0000, overlap=8.4%\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "KeyboardInterrupt",
|
||||
"evalue": "",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[2], line 62\u001b[0m\n\u001b[1;32m 59\u001b[0m image_shape \u001b[38;5;241m=\u001b[39m binary_array\u001b[38;5;241m.\u001b[39mshape\n\u001b[1;32m 61\u001b[0m grid \u001b[38;5;241m=\u001b[39m create_coordinate_grid(image_shape, device)\n\u001b[0;32m---> 62\u001b[0m best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time \u001b[38;5;241m=\u001b[39m \u001b[43mrun_pso_torch\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 63\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabel_str\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 64\u001b[0m \u001b[43m \u001b[49m\u001b[43mimage1_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mIMAGE1\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 65\u001b[0m \u001b[43m \u001b[49m\u001b[43mimage2_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mIMAGE2\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 66\u001b[0m \u001b[43m \u001b[49m\u001b[43mimage3_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mIMAGE3\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 67\u001b[0m \u001b[43m \u001b[49m\u001b[43mfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdate\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 68\u001b[0m \u001b[43m \u001b[49m\u001b[43mswarm_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m70\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 69\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_iter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m100\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 70\u001b[0m \u001b[43m \u001b[49m\u001b[43mspacing\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mspacing\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 71\u001b[0m \u001b[43m \u001b[49m\u001b[43mCBT\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 72\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 73\u001b[0m 
\u001b[43m \u001b[49m\u001b[43moptimize_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 74\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrid\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgrid\u001b[49m\n\u001b[1;32m 75\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 77\u001b[0m cylinder_L, cylinder_R \u001b[38;5;241m=\u001b[39m save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mOutput\u001b[39m\u001b[38;5;124m'\u001b[39m, CBT\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 78\u001b[0m cyl_and_CT(best_pos_l[\u001b[38;5;241m5\u001b[39m], best_pos_r[\u001b[38;5;241m5\u001b[39m], best_pos_l[\u001b[38;5;241m6\u001b[39m], best_pos_r[\u001b[38;5;241m6\u001b[39m], IMAGE3, cylinder_L, cylinder_R, base_folder\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mOutput\u001b[39m\u001b[38;5;124m'\u001b[39m, CBT\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
|
||||
"File \u001b[0;32m/mnt/pool0/home/cyrou/CBT/CBT_project/core/optimizer.py:194\u001b[0m, in \u001b[0;36mrun_pso_torch\u001b[0;34m(label_str, image1_path, image2_path, image3_path, folder, swarm_size, max_iter, spacing, CBT, device, optimize_size, grid)\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[38;5;66;03m# 左側 retry:loss 要 <=0 且 overlap >= 0.5 才算過關\u001b[39;00m\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m (best_loss_l \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m best_overlap_l \u001b[38;5;241m<\u001b[39m OVERLAP_THRESH) \u001b[38;5;129;01mand\u001b[39;00m retries \u001b[38;5;241m<\u001b[39m max_retries:\n\u001b[0;32m--> 194\u001b[0m position_l, loss_l \u001b[38;5;241m=\u001b[39m \u001b[43mpso\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobjective_function\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlb_l\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mub_l\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mswarmsize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mswarm_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmaxiter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_iter\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 195\u001b[0m overlap_l, diameter_l, length_l \u001b[38;5;241m=\u001b[39m eval_overlap_from_position(\n\u001b[1;32m 196\u001b[0m position_l, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mL\u001b[39m\u001b[38;5;124m\"\u001b[39m, optimize_size, spine_tensor, image_shape, spacing\n\u001b[1;32m 197\u001b[0m )\n\u001b[1;32m 199\u001b[0m \u001b[38;5;66;03m# 只要找到更好的 loss(或你想用 loss+overlap 綜合排序也行)就更新 best\u001b[39;00m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;66;03m# 安全版本:優先選「合格解」;沒有合格解時才用 loss 最小的當備案\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/envpy/lib/python3.13/site-packages/pyswarm/pso.py:145\u001b[0m, in \u001b[0;36mpso\u001b[0;34m(func, lb, ub, ieqcons, f_ieqcons, args, kwargs, swarmsize, omega, phip, phig, maxiter, minstep, minfunc, debug)\u001b[0m\n\u001b[1;32m 143\u001b[0m x[i, mark1] \u001b[38;5;241m=\u001b[39m lb[mark1]\n\u001b[1;32m 144\u001b[0m x[i, mark2] \u001b[38;5;241m=\u001b[39m ub[mark2]\n\u001b[0;32m--> 145\u001b[0m fx \u001b[38;5;241m=\u001b[39m \u001b[43mobj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;66;03m# Compare particle's best position (if constraints are satisfied)\u001b[39;00m\n\u001b[1;32m 148\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m fx\u001b[38;5;241m<\u001b[39mfp[i] \u001b[38;5;129;01mand\u001b[39;00m is_feasible(x[i, :]):\n",
|
||||
"File \u001b[0;32m~/.conda/envs/envpy/lib/python3.13/site-packages/pyswarm/pso.py:74\u001b[0m, in \u001b[0;36mpso.<locals>.<lambda>\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 71\u001b[0m vlow \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39mvhigh\n\u001b[1;32m 73\u001b[0m \u001b[38;5;66;03m# Check for constraint function(s) #########################################\u001b[39;00m\n\u001b[0;32m---> 74\u001b[0m obj \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mlambda\u001b[39;00m x: \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m f_ieqcons \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 76\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(ieqcons):\n",
|
||||
"File \u001b[0;32m/mnt/pool0/home/cyrou/CBT/CBT_project/core/objective.py:126\u001b[0m, in \u001b[0;36mobjective_function\u001b[0;34m(params)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[38;5;66;03m# 將連續值轉換為離散值\u001b[39;00m\n\u001b[1;32m 124\u001b[0m diameter_discrete, length_discrete \u001b[38;5;241m=\u001b[39m snap_to_discrete_values(diameter_raw, length_raw)\n\u001b[0;32m--> 126\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mcylinder_circle_line_intersection_loss_deductions_torch\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 127\u001b[0m \u001b[43m \u001b[49m\u001b[43mdiameter_discrete\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 128\u001b[0m \u001b[43m \u001b[49m\u001b[43mlength_discrete\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 129\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 130\u001b[0m \u001b[43m \u001b[49m\u001b[43mimage2_shape\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 131\u001b[0m \u001b[43m \u001b[49m\u001b[43mcortical_tensor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 132\u001b[0m \u001b[43m \u001b[49m\u001b[43mspine_tensor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 133\u001b[0m \u001b[43m \u001b[49m\u001b[43mspacing\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 134\u001b[0m \u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\n\u001b[1;32m 135\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n",
|
||||
"File \u001b[0;32m/mnt/pool0/home/cyrou/CBT/CBT_project/core/objective.py:54\u001b[0m, in \u001b[0;36mcylinder_circle_line_intersection_loss_deductions_torch\u001b[0;34m(diameter, length, params, image_shape, cortical_tensor, spine_tensor, spacing, device)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;124;03mComputes the loss for a given set of cylinder params in PyTorch, \u001b[39;00m\n\u001b[1;32m 50\u001b[0m \u001b[38;5;124;03mreturning a Python float for PSO consumption.\u001b[39;00m\n\u001b[1;32m 51\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 52\u001b[0m position_z, position_y, position_x, azimuth, altitude \u001b[38;5;241m=\u001b[39m params\n\u001b[0;32m---> 54\u001b[0m cyl_fwd \u001b[38;5;241m=\u001b[39m \u001b[43mgenerate_cylinder_n_torch\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 55\u001b[0m \u001b[43m \u001b[49m\u001b[43mdiameter\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 56\u001b[0m \u001b[43m \u001b[49m\u001b[43mlength\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 57\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_z\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 58\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_y\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 59\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_x\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 60\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mfloat\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mazimuth\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 61\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mfloat\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43maltitude\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 62\u001b[0m \u001b[43m \u001b[49m\u001b[43mimage_shape\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 63\u001b[0m \u001b[43m \u001b[49m\u001b[43mspacing\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 64\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 65\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrid\u001b[49m\n\u001b[1;32m 66\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 68\u001b[0m cyl_opp \u001b[38;5;241m=\u001b[39m generate_cylinder_o_torch(\n\u001b[1;32m 69\u001b[0m diameter,\n\u001b[1;32m 70\u001b[0m length,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 79\u001b[0m grid\n\u001b[1;32m 80\u001b[0m )\n\u001b[1;32m 83\u001b[0m \u001b[38;5;66;03m# We call the center_line_intersections in Torch mode\u001b[39;00m\n",
|
||||
"File \u001b[0;32m/mnt/pool0/home/cyrou/CBT/CBT_project/core/cylinder.py:83\u001b[0m, in \u001b[0;36mgenerate_cylinder_n_torch\u001b[0;34m(diameter, length, position_z, position_y, position_x, azimuth, altitude, shape, spacing, device, grid)\u001b[0m\n\u001b[1;32m 75\u001b[0m x_rot \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 76\u001b[0m x_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(azimuth_rad_t) \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(altitude_rad_t)\n\u001b[1;32m 77\u001b[0m \u001b[38;5;241m+\u001b[39m y_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(azimuth_rad_t) \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(altitude_rad_t)\n\u001b[1;32m 78\u001b[0m \u001b[38;5;241m-\u001b[39m z_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(altitude_rad_t)\n\u001b[1;32m 79\u001b[0m )\n\u001b[1;32m 80\u001b[0m y_rot \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39mx_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(azimuth_rad_t) \u001b[38;5;241m+\u001b[39m y_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(azimuth_rad_t)\n\u001b[1;32m 81\u001b[0m z_rot \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 82\u001b[0m x_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(azimuth_rad_t) \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(altitude_rad_t)\n\u001b[0;32m---> 83\u001b[0m \u001b[38;5;241m+\u001b[39m y_t \u001b[38;5;241m*\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msin\u001b[49m\u001b[43m(\u001b[49m\u001b[43mazimuth_rad_t\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39msin(altitude_rad_t)\n\u001b[1;32m 84\u001b[0m \u001b[38;5;241m+\u001b[39m z_t \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mcos(altitude_rad_t)\n\u001b[1;32m 85\u001b[0m )\n\u001b[1;32m 87\u001b[0m \u001b[38;5;66;03m# Handle spacing\u001b[39;00m\n\u001b[1;32m 
88\u001b[0m \u001b[38;5;66;03m# You can expand or generalize for more spacing options\u001b[39;00m\n\u001b[1;32m 89\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m spacing \u001b[38;5;241m==\u001b[39m [\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1\u001b[39m]:\n",
|
||||
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import traceback\n",
|
||||
"from datetime import datetime\n",
|
||||
"from core.optimizer import run_pso_torch\n",
|
||||
"from visualization.res_cyl_nifti import save_cyl\n",
|
||||
"from visualization.res_cyl_to_CT import cyl_and_CT\n",
|
||||
"\n",
|
||||
"device = get_device(1)\n",
|
||||
"\n",
|
||||
"# 只需要傳入 base_dir, patient_id, level 即可自動組裝路徑\n",
|
||||
"def build_paths(base_dir: str, patient_id: str, level: str):\n",
|
||||
" case_dir = os.path.join(base_dir, patient_id)\n",
|
||||
" img1 = os.path.join(case_dir, f\"{level}_cortical.nii.gz\")\n",
|
||||
" img2 = os.path.join(case_dir, f\"{level}_binary2.nii.gz\")\n",
|
||||
" img3 = os.path.join(case_dir, f\"{level}_roi2.nii.gz\")\n",
|
||||
" return img1, img2, img3\n",
|
||||
"\n",
|
||||
"if __name__ == \"__main__\":\n",
|
||||
" base_dir = \"/home/cyrou/CBT/Seg/Resample/standardized/\" # 你的 case 資料夾根目錄\n",
|
||||
" spacing = [0.5, 0.5, 0.5]\n",
|
||||
" levels = [\"L1\", \"L2\", \"L3\", \"L4\", \"L5\"] # 鎖定要跑的脊椎節段\n",
|
||||
" date = datetime.now().strftime(\"%Y%m%d\")\n",
|
||||
" \n",
|
||||
" failed = []\n",
|
||||
" skipped = []\n",
|
||||
" succeeded = []\n",
|
||||
"\n",
|
||||
" # 1. 動態獲取 base_dir 下的所有項目,並篩選出「資料夾」\n",
|
||||
" # 使用 sorted 讓它照字母/數字順序跑,強迫症看了比較舒服\n",
|
||||
" all_patients = sorted([d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))])\n",
|
||||
" \n",
|
||||
" print(f\"📂 總共偵測到 {len(all_patients)} 個病人資料夾,準備開始處理...\")\n",
|
||||
"\n",
|
||||
" # 2. 外層迴圈:遍歷每一個病人資料夾\n",
|
||||
" for patient_id in all_patients:\n",
|
||||
" \n",
|
||||
" # 3. 內層迴圈:遍歷 L1 到 L5\n",
|
||||
" for level in levels:\n",
|
||||
" label_str = f\"{patient_id}_{level}\"\n",
|
||||
" IMAGE1, IMAGE2, IMAGE3 = build_paths(base_dir, patient_id, level)\n",
|
||||
"\n",
|
||||
" # 檢查這個 level 的檔案是否齊全\n",
|
||||
" missing = [p for p in [IMAGE1, IMAGE2, IMAGE3] if not os.path.exists(p)]\n",
|
||||
" if missing:\n",
|
||||
" # 缺檔直接跳過,不印出太多落落長的警告,保持終端機乾淨\n",
|
||||
" print(f\"[SKIP] ⏭️ {label_str} (缺少 {len(missing)} 個檔案)\")\n",
|
||||
" skipped.append((label_str, missing))\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" print(f\"\\n=========================================\")\n",
|
||||
" print(f\"🚀 === Running {label_str} ===\")\n",
|
||||
" print(f\"=========================================\")\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" # --- 第一階段:CBT = True ---\n",
|
||||
" print(f\"👉 執行 CBT 軌跡最佳化...\")\n",
|
||||
" binary_image = sitk.ReadImage(IMAGE2)\n",
|
||||
" binary_array = sitk.GetArrayFromImage(binary_image)\n",
|
||||
" image_shape = binary_array.shape\n",
|
||||
"\n",
|
||||
" grid = create_coordinate_grid(image_shape, device)\n",
|
||||
" best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_pso_torch(\n",
|
||||
" label_str,\n",
|
||||
" image1_path=IMAGE1,\n",
|
||||
" image2_path=IMAGE2,\n",
|
||||
" image3_path=IMAGE3,\n",
|
||||
" folder=date,\n",
|
||||
" swarm_size=70,\n",
|
||||
" max_iter=100,\n",
|
||||
" spacing=spacing,\n",
|
||||
" CBT=True,\n",
|
||||
" device=device,\n",
|
||||
" optimize_size=True,\n",
|
||||
" grid=grid\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base='Output', CBT=True)\n",
|
||||
" cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6], IMAGE3, cylinder_L, cylinder_R, base_folder='Output', CBT=True)\n",
|
||||
" \n",
|
||||
" # --- 第二階段:CBT = False (傳統軌跡 TT) ---\n",
|
||||
" print(f\"👉 執行傳統軌跡 (TT) 最佳化...\")\n",
|
||||
" best_pos_l_tt, best_loss_l_tt, best_pos_r_tt, best_loss_r_tt, total_time_tt = run_pso_torch(\n",
|
||||
" label_str,\n",
|
||||
" image1_path=IMAGE2, # 注意:你原本的扣這裡是 IMAGE2,如果有改過記得確認\n",
|
||||
" image2_path=IMAGE2,\n",
|
||||
" image3_path=IMAGE3,\n",
|
||||
" folder=date,\n",
|
||||
" swarm_size=70,\n",
|
||||
" max_iter=100,\n",
|
||||
" spacing=spacing,\n",
|
||||
" CBT=False,\n",
|
||||
" device=device,\n",
|
||||
" optimize_size=True,\n",
|
||||
" grid=grid\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" cylinder_L_tt, cylinder_R_tt = save_cyl(best_pos_l_tt, best_pos_r_tt, spacing, IMAGE3, output_base='Output', CBT=False)\n",
|
||||
" cyl_and_CT(best_pos_l_tt[5], best_pos_r_tt[5], best_pos_l_tt[6], best_pos_r_tt[6], IMAGE3, cylinder_L_tt, cylinder_R_tt, base_folder='Output', CBT=False)\n",
|
||||
"\n",
|
||||
" # 記錄成功 (把兩次的耗時加起來)\n",
|
||||
" succeeded.append((label_str, best_loss_l, best_loss_r, total_time + total_time_tt))\n",
|
||||
" print(f\"[OK] ✅ {label_str} 完成!總耗時={(total_time + total_time_tt):.2f}s\")\n",
|
||||
"\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"[FAIL] ❌ {label_str} 發生錯誤: {e}\")\n",
|
||||
" traceback.print_exc() # 保留剛學到的神技,印出詳細報錯\n",
|
||||
" failed.append((label_str, str(e)))\n",
|
||||
"\n",
|
||||
" # 4. 最終報告\n",
|
||||
" print(\"\\n\" + \"=\"*40)\n",
|
||||
" print(\"🏆 ==== 全自動批次處理總結 ====\")\n",
|
||||
" print(\"=\"*40)\n",
|
||||
" print(f\"✅ Succeeded (成功): {len(succeeded)} 節段\")\n",
|
||||
" print(f\"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段\")\n",
|
||||
" print(f\"❌ Failed (執行錯誤): {len(failed)} 節段\")\n",
|
||||
" \n",
|
||||
" if failed:\n",
|
||||
" print(\"\\n⚠️ 失敗清單:\")\n",
|
||||
" for fail_label, err_msg in failed:\n",
|
||||
" print(f\" - {fail_label}: {err_msg}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "bcf55fef",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"usage: ipykernel_launcher.py [-h] --gpu GPU [--total_gpus TOTAL_GPUS]\n",
|
||||
"ipykernel_launcher.py: error: the following arguments are required: --gpu\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "SystemExit",
|
||||
"evalue": "2",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"An exception has occurred, use %tb to see the full traceback.\n",
|
||||
"\u001b[0;31mSystemExit\u001b[0m\u001b[0;31m:\u001b[0m 2\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/cyrou/.conda/envs/envpy/lib/python3.13/site-packages/IPython/core/interactiveshell.py:3585: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.\n",
|
||||
" warn(\"To exit: use 'exit', 'quit', or Ctrl-D.\", stacklevel=1)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import traceback\n",
|
||||
"import argparse\n",
|
||||
"import torch\n",
|
||||
"from datetime import datetime\n",
|
||||
"\n",
|
||||
"# 請確保你有 import run_pso_torch, save_cyl, cyl_and_CT 等自訂函數\n",
|
||||
"\n",
|
||||
"def build_paths(base_dir: str, patient_id: str, level: str):\n",
|
||||
" case_dir = os.path.join(base_dir, patient_id)\n",
|
||||
" img1 = os.path.join(case_dir, f\"{level}_cortical.nii.gz\")\n",
|
||||
" img2 = os.path.join(case_dir, f\"{level}_binary2.nii.gz\")\n",
|
||||
" img3 = os.path.join(case_dir, f\"{level}_roi2.nii.gz\")\n",
|
||||
" return img1, img2, img3\n",
|
||||
"\n",
|
||||
"if __name__ == \"__main__\":\n",
|
||||
" # === 1. 設定命令列引數 (GPU 拆分設定) ===\n",
|
||||
" parser = argparse.ArgumentParser(description=\"Multi-GPU CBT Batch Processing\")\n",
|
||||
" parser.add_argument('--gpu', type=int, required=True, help='指定這支程式要用的 GPU ID (0, 1, 2, 3)')\n",
|
||||
" parser.add_argument('--total_gpus', type=int, default=4, help='總共開啟的 GPU 數量')\n",
|
||||
" args = parser.parse_args()\n",
|
||||
"\n",
|
||||
" # === 2. 綁定 GPU ===\n",
|
||||
" if torch.cuda.is_available():\n",
|
||||
" device = torch.device(f'cuda:{args.gpu}')\n",
|
||||
" torch.cuda.set_device(device)\n",
|
||||
" else:\n",
|
||||
" device = torch.device('cpu')\n",
|
||||
" print(\"⚠️ 找不到 CUDA,將使用 CPU\")\n",
|
||||
"\n",
|
||||
" print(f\"\\n🚀 啟動批次任務 | 分配至 GPU: [{args.gpu}] | 總共 {args.total_gpus} 個節點協同運算\")\n",
|
||||
"\n",
|
||||
" # === 3. 基本參數 ===\n",
|
||||
" base_dir = \"Seg/Resample/standardized\" \n",
|
||||
" spacing = [0.5, 0.5, 0.5]\n",
|
||||
" levels = [\"L1\", \"L2\", \"L3\", \"L4\", \"L5\"] \n",
|
||||
" date = datetime.now().strftime(\"%Y%m%d\")\n",
|
||||
" \n",
|
||||
" failed = []\n",
|
||||
" skipped = []\n",
|
||||
" succeeded = []\n",
|
||||
"\n",
|
||||
" # 動態獲取所有病人資料夾\n",
|
||||
" all_patients = sorted([d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))])\n",
|
||||
" \n",
|
||||
" # === 4. 餘數分工:只挑選屬於這張 GPU 的病人 ===\n",
|
||||
" my_patients = [p for i, p in enumerate(all_patients) if i % args.total_gpus == args.gpu]\n",
|
||||
" \n",
|
||||
" print(f\"📂 總資料庫有 {len(all_patients)} 個病人。\")\n",
|
||||
" print(f\"🎯 本 GPU (ID: {args.gpu}) 被分配到 {len(my_patients)} 個病人,準備開始處理...\")\n",
|
||||
"\n",
|
||||
" for patient_id in my_patients:\n",
|
||||
" for level in levels:\n",
|
||||
" label_str = f\"{patient_id}_{level}\"\n",
|
||||
" IMAGE1, IMAGE2, IMAGE3 = build_paths(base_dir, patient_id, level)\n",
|
||||
"\n",
|
||||
" missing = [p for p in [IMAGE1, IMAGE2, IMAGE3] if not os.path.exists(p)]\n",
|
||||
" if missing:\n",
|
||||
" print(f\"[SKIP] ⏭️ {label_str} (缺檔)\")\n",
|
||||
" skipped.append((label_str, missing))\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" print(f\"\\n=========================================\")\n",
|
||||
" print(f\"🔥 [GPU {args.gpu}] === Running {label_str} ===\")\n",
|
||||
" print(f\"=========================================\")\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" # --- CBT = True ---\n",
|
||||
" print(f\"👉 [GPU {args.gpu}] 執行 CBT 軌跡最佳化...\")\n",
|
||||
" binary_image = sitk.ReadImage(IMAGE2)\n",
|
||||
" binary_array = sitk.GetArrayFromImage(binary_image)\n",
|
||||
" image_shape = binary_array.shape\n",
|
||||
"\n",
|
||||
" grid = create_coordinate_grid(image_shape, device)\n",
|
||||
" best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_pso_torch(\n",
|
||||
" label_str,\n",
|
||||
" image1_path=IMAGE1,\n",
|
||||
" image2_path=IMAGE2,\n",
|
||||
" image3_path=IMAGE3,\n",
|
||||
" folder=date, \n",
|
||||
" swarm_size=70,\n",
|
||||
" max_iter=100,\n",
|
||||
" spacing=spacing,\n",
|
||||
" CBT=True,\n",
|
||||
" device=device,\n",
|
||||
" optimize_size=True,\n",
|
||||
" grid=grid # 記得把 device 傳進去\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base='Output', CBT=True)\n",
|
||||
" cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6], IMAGE3, cylinder_L, cylinder_R, base_folder='Output', CBT=True)\n",
|
||||
" \n",
|
||||
" # --- CBT = False ---\n",
|
||||
" print(f\"👉 [GPU {args.gpu}] 執行傳統軌跡 (TT) 最佳化...\")\n",
|
||||
" best_pos_l_tt, best_loss_l_tt, best_pos_r_tt, best_loss_r_tt, total_time_tt = run_pso_torch(\n",
|
||||
" label_str,\n",
|
||||
" image1_path=IMAGE2, \n",
|
||||
" image2_path=IMAGE2,\n",
|
||||
" image3_path=IMAGE3,\n",
|
||||
" folder=date, \n",
|
||||
" swarm_size=70,\n",
|
||||
" max_iter=100,\n",
|
||||
" spacing=spacing,\n",
|
||||
" CBT=False,\n",
|
||||
" device=device,\n",
|
||||
" optimize_size=True,\n",
|
||||
" grid=grid # 記得把 device 傳進去\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" cylinder_L_tt, cylinder_R_tt = save_cyl(best_pos_l_tt, best_pos_r_tt, spacing, IMAGE3, output_base='Output', CBT=False)\n",
|
||||
" cyl_and_CT(best_pos_l_tt[5], best_pos_r_tt[5], best_pos_l_tt[6], best_pos_r_tt[6], IMAGE3, cylinder_L_tt, cylinder_R_tt, base_folder='Output', CBT=False)\n",
|
||||
"\n",
|
||||
" succeeded.append((label_str, best_loss_l, best_loss_r, total_time + total_time_tt))\n",
|
||||
" print(f\"[OK] ✅ [GPU {args.gpu}] {label_str} 完成!總耗時={(total_time + total_time_tt):.2f}s\")\n",
|
||||
"\n",
|
||||
" except Exception as e:\n",
|
||||
" print(f\"[FAIL] ❌ [GPU {args.gpu}] {label_str} 發生錯誤: {e}\")\n",
|
||||
" traceback.print_exc()\n",
|
||||
" failed.append((label_str, str(e)))\n",
|
||||
"\n",
|
||||
" # --- GPU 專屬的 Summary ---\n",
|
||||
" print(\"\\n\" + \"=\"*40)\n",
|
||||
" print(f\"🏆 ==== GPU {args.gpu} 處理總結 ====\")\n",
|
||||
" print(\"=\"*40)\n",
|
||||
" print(f\"✅ Succeeded (成功): {len(succeeded)} 節段\")\n",
|
||||
" print(f\"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段\")\n",
|
||||
" print(f\"❌ Failed (執行錯誤): {len(failed)} 節段\")\n",
|
||||
" \n",
|
||||
" if failed:\n",
|
||||
" print(f\"\\n⚠️ GPU {args.gpu} 失敗清單:\")\n",
|
||||
" for fail_label, err_msg in failed:\n",
|
||||
" print(f\" - {fail_label}: {err_msg}\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "envpy",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.13.0"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
0
imaging/__init__.py
Normal file
0
imaging/__init__.py
Normal file
42
imaging/affine.py
Normal file
42
imaging/affine.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import nibabel as nib
|
||||
|
||||
def standardize_affine(file_path, output_dir):
    """Rewrite a NIfTI file so its affine has a non-negative diagonal.

    For every spatial axis whose diagonal affine entry is negative, the
    voxel data is flipped along that axis and the corresponding diagonal
    and translation entries of the affine are negated.  The corrected
    image is written to ``output_dir`` under the original base name.

    NOTE(review): negating ``affine[i, 3]`` alone is a simplified origin
    correction; an exact flip update would also account for the axis
    extent ``(shape[i] - 1) * spacing`` — confirm this matches the
    downstream convention.

    Parameters
    ----------
    file_path : str
        Path of the input NIfTI image.
    output_dir : str
        Directory the standardized image is written to (created if needed).
    """
    img = nib.load(file_path)
    data = img.get_fdata()
    affine = img.affine.copy()

    # Collect the axes that need flipping (the original repeated the same
    # check three times for X, Y and Z).
    flip_axes = []
    for axis in range(3):
        if affine[axis, axis] < 0:
            flip_axes.append(axis)
            affine[axis, axis] *= -1
            affine[axis, 3] *= -1  # adjust the translation part as well

    # Flip the voxel data only when at least one axis was negated.
    if flip_axes:
        data = np.flip(data, axis=tuple(flip_axes))

    # Save the corrected image under the original file name.
    standardized_img = nib.Nifti1Image(data, affine)
    output_path = os.path.join(output_dir, os.path.basename(file_path))
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    nib.save(standardized_img, output_path)
|
||||
|
||||
53
imaging/nifti_io.py
Normal file
53
imaging/nifti_io.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
import SimpleITK as sitk
|
||||
import nibabel as nib
|
||||
import numpy as np
|
||||
|
||||
def read_nifti_sitk(path):
    """Load the image at *path* as a SimpleITK image."""
    return sitk.ReadImage(path)
|
||||
|
||||
def write_nifti_sitk(img, path):
    """Write the SimpleITK image *img* to *path*."""
    sitk.WriteImage(img, path)
|
||||
|
||||
def sitk_to_numpy(img):
    """Return the voxel data of *img* as a NumPy array.

    Note: as elsewhere in this module, ``GetArrayFromImage`` yields axes
    in (z, y, x) order.
    """
    return sitk.GetArrayFromImage(img)
|
||||
|
||||
def numpy_to_sitk(arr, reference_img):
    """Build a SimpleITK image from *arr*, copying the spatial metadata
    (origin, spacing, direction) of *reference_img* so the result sits in
    the same physical space."""
    result = sitk.GetImageFromArray(arr)
    result.SetOrigin(reference_img.GetOrigin())
    result.SetDirection(reference_img.GetDirection())
    result.SetSpacing(reference_img.GetSpacing())
    return result
|
||||
|
||||
def sitk_to_nibabel(sitk_image):
    """Convert a SimpleITK image to a ``(data, affine)`` pair in
    nibabel's (x, y, z) voxel convention.

    NOTE(review): no LPS→RAS axis flip is applied here — confirm the
    downstream code expects the SimpleITK physical frame as-is.
    """
    # SimpleITK hands voxels back in (z, y, x); transpose to (x, y, z).
    voxels = np.transpose(sitk.GetArrayFromImage(sitk_image), (2, 1, 0))

    origin = np.asarray(sitk_image.GetOrigin())
    spacing = np.asarray(sitk_image.GetSpacing())
    direction = np.asarray(sitk_image.GetDirection()).reshape(3, 3)

    # Column-scale the direction cosines by the spacing and use the
    # origin as the translation column of a homogeneous 4x4 affine.
    affine = np.eye(4)
    affine[:3, :3] = direction * spacing
    affine[:3, 3] = origin

    return voxels, affine
|
||||
|
||||
def nibabel_to_sitk(data, affine):
    """Convert a nibabel-style ``(data, affine)`` pair into a SimpleITK image."""
    # nibabel stores voxels as (x, y, z); SimpleITK expects (z, y, x).
    sitk_image = sitk.GetImageFromArray(np.transpose(data, (2, 1, 0)))

    # Recover spacing as the column norms of the 3x3 affine block, the
    # direction cosines as the spacing-normalized columns, and the origin
    # as the translation column.
    spacing = np.sqrt((affine[:3, :3] ** 2).sum(axis=0)).tolist()
    direction = (affine[:3, :3] / spacing).flatten().tolist()
    origin = affine[:3, 3].tolist()

    sitk_image.SetOrigin(origin)
    sitk_image.SetSpacing(spacing)
    sitk_image.SetDirection(direction)

    return sitk_image
|
||||
|
||||
def read_nifti_nib(path):
    """Load the NIfTI file at *path* with nibabel."""
    return nib.load(path)
|
||||
|
||||
def write_nifti_nib(data, affine, path):
    """Wrap ``(data, affine)`` in a Nifti1Image and save it to *path*."""
    image = nib.Nifti1Image(data, affine)
    nib.save(image, path)
|
||||
310
imaging/orientation.py
Normal file
310
imaging/orientation.py
Normal file
|
|
@ -0,0 +1,310 @@
|
|||
import numpy as np
|
||||
import SimpleITK as sitk
|
||||
from scipy.ndimage import center_of_mass
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def azimuth_rotation(image, show_plt=False, save_plt=False, output_path=None):
    """Estimate the in-plane (azimuth) tilt of a binary volume.

    The volume is max-projected along z onto the (y, x) plane; the angle
    between the image y axis and the line running from the center of the
    topmost foreground row to the foreground centroid is returned, in
    degrees.

    Parameters
    ----------
    image : str
        Path of a (binary) image readable by SimpleITK.
    show_plt : bool
        Display a matplotlib figure of the projection and fitted line.
    save_plt : bool
        When plotting, also save the figure.
    output_path : str or None
        Destination of the saved figure; defaults to
        ``"azimuth_rotation.png"``.

    Returns
    -------
    float
        Signed angle in degrees between the top-to-centroid vector and
        the y axis.

    Raises
    ------
    ValueError
        If fewer than 10 foreground pixels survive the projection.
    """
    img = sitk.ReadImage(image, sitk.sitkUInt8)
    arr_zyx = sitk.GetArrayFromImage(img)  # (z, y, x)

    max_proj = np.max(arr_zyx, axis=0)  # -> (y, x)
    binary_proj = (max_proj > 0).astype(np.uint8)

    ys, xs = np.where(binary_proj > 0)
    if len(xs) < 10:
        raise ValueError("Not enough foreground points")

    # Foreground centroid.  (The original computed cy/cx/centroid twice
    # in a row; the duplicate and the unused `centroid` array are gone.)
    cy = ys.mean()
    cx = xs.mean()

    # Center of the topmost foreground row: median x of that row.
    y_min = ys.min()
    xs_top_row = xs[ys == y_min]
    x_center = int(np.median(xs_top_row))

    # Vector from the top point to the centroid.  Image y grows downward,
    # so arctan2(dx, dy) measures the angle relative to the y axis.
    dy = cy - y_min
    dx = cx - x_center
    angle_rad = np.arctan2(dx, dy)
    angle_deg = np.degrees(angle_rad)

    if show_plt:
        fig = plt.figure(figsize=(12, 12))

        # Visualize the projection with the angle annotation.
        plt.imshow(binary_proj, cmap='gray')

        plt.scatter(x_center, y_min, c='red', s=60, label='Top')
        plt.scatter(cx, cy, c='yellow', s=60, label='Centroid')
        plt.plot([x_center, cx], [y_min, cy], 'c-', lw=2,
                 label=f'Angle with y-axis: {angle_deg:.1f}°')

        plt.legend()
        plt.title("Top point + centroid-directed line")
        plt.axis("off")

        if save_plt:
            if output_path is None:
                output_path = "azimuth_rotation.png"
            fig.savefig(output_path, dpi=200, bbox_inches="tight")

        plt.show()
        plt.close(fig)

    return angle_deg
|
||||
|
||||
def split_spine_anterior_posterior(image_path, center_mode='com'):
    """Split a sagittal max-projection of the spine into anterior and
    posterior halves along the y (front-back) axis.

    The volume at *image_path* is projected along x (sagittal view,
    giving a (z, y) binary image) and cut at a y coordinate chosen by
    *center_mode*:

    - ``'com'``   : y of the projection's center of mass
    - ``'image'`` : middle column of the projection
    - a number strictly between 0 and 1: relative position between the
      foreground's minimal and maximal y

    Returns
    -------
    (anterior, posterior, binary_proj)
        Binary masks keeping only the anterior half (y >= split) and the
        posterior half (y < split), plus the full binary projection.
    """
    img = sitk.ReadImage(image_path, sitk.sitkUInt8)
    arr_zyx = sitk.GetArrayFromImage(img)

    # Sagittal projection: collapse the x axis -> (z, y).
    max_proj_sagittal = np.max(arr_zyx, axis=2)
    binary_proj = (max_proj_sagittal > 0).astype(np.uint8)

    # Decide where to cut along y.
    if center_mode == 'com':
        cz, cy = center_of_mass(binary_proj)
        split_y = int(round(cy))
        label = f'Center of Mass (y={split_y})'
    elif center_mode == 'image':
        split_y = binary_proj.shape[1] // 2
        label = f'Image Center (y={split_y})'
    elif isinstance(center_mode, (int, float)) and 0 < center_mode < 1:
        # Numeric ratio strictly inside (0, 1): interpolate between the
        # foreground's y extremes.
        ys = np.where(binary_proj > 0)[1]
        if ys.size == 0:  # safety net: the projection is empty
            split_y = binary_proj.shape[1] // 2
        else:
            y_min, y_max = ys.min(), ys.max()
            split_y = int(round(y_min + center_mode * (y_max - y_min)))
        label = f'Custom Ratio {center_mode} (y={split_y})'
    else:
        raise ValueError("center_mode 必須是 'com'、'image' 或介於 0 到 1 之間的浮點數 (例如 0.8)")

    # Zero-out the complementary half in two copies of the projection.
    anterior = binary_proj.copy()
    posterior = binary_proj.copy()

    anterior[:, :split_y] = 0   # keep the anterior spine half (y >= split_y)
    posterior[:, split_y:] = 0  # keep the posterior spine half (y < split_y)

    return anterior, posterior, binary_proj
|
||||
|
||||
def analyze_vertebral_tilt_contour(image_path, edge_type='superior', show_plot=False, debug=False, save_plt=False, output_path=None):
    """Estimate vertebral endplate tilt from the anterior contour.

    The anterior half of the sagittal projection is extracted, the edge
    profile (one z per y column) is traced for the requested endplate(s),
    and a RANSAC line fit yields a tilt angle for each edge.

    Parameters:
        image_path: path of the binary vertebra image.
        edge_type: 'superior' (upper endplate), 'inferior' (lower
            endplate), or 'both'.
        show_plot: render matplotlib diagnostics.
        debug: print per-edge fitting statistics.
        save_plt: save the diagnostic figure (only when show_plot is on).
        output_path: figure destination; defaults to
            "analyze_vertebral_tilt_contour.png".

    Returns:
        dict mapping edge name -> fit summary (tilt_angle_deg, slope,
        intercept, r_squared, n_inliers, n_outliers); None when the
        anterior mask is empty.
    """

    # Anterior/posterior split of the sagittal projection.
    anterior, posterior, binary_proj = split_spine_anterior_posterior(image_path, center_mode='com')

    zs, ys = np.where(anterior > 0)

    if len(zs) == 0:
        return None

    from sklearn.linear_model import RANSACRegressor

    results = {}

    # === Decide which edges to analyze from edge_type ===
    edges_to_analyze = []
    if edge_type == 'superior' or edge_type == 'both':
        edges_to_analyze.append('superior')
    if edge_type == 'inferior' or edge_type == 'both':
        edges_to_analyze.append('inferior')

    # Per-edge intermediates, keyed by edge name; kept around for the
    # optional visualization pass below.
    all_edge_points = {}
    all_inliers = {}
    all_outliers = {}
    all_slopes = {}
    all_intercepts = {}
    all_angles = {}

    for edge in edges_to_analyze:
        # For every y column, pick the corresponding edge voxel.
        edge_points = []
        unique_ys = np.unique(ys)

        for y in unique_ys:
            z_at_y = zs[ys == y]

            if edge == 'superior':
                # NOTE(review): the original comment claimed "topmost
                # (smallest z)" but the code takes the maximum z —
                # confirm which orientation is intended.
                z_edge = z_at_y.max()
            else:  # inferior
                z_edge = z_at_y.min()

            edge_points.append([y, z_edge])

        edge_points = np.array(edge_points)
        all_edge_points[edge] = edge_points

        # Too few contour points for a meaningful fit — skip this edge.
        if len(edge_points) < 10:
            continue

        # RANSAC line fit: z as a function of y.
        X = edge_points[:, 0].reshape(-1, 1)
        y_data = edge_points[:, 1]

        ransac = RANSACRegressor(
            residual_threshold=5.0,
            random_state=42
        )
        ransac.fit(X, y_data)

        inlier_mask = ransac.inlier_mask_
        outlier_mask = ~inlier_mask

        edge_points_inliers = edge_points[inlier_mask]
        edge_points_outliers = edge_points[outlier_mask]

        all_inliers[edge] = edge_points_inliers
        all_outliers[edge] = edge_points_outliers

        # Fitted line parameters.
        slope = ransac.estimator_.coef_[0]
        intercept = ransac.estimator_.intercept_

        all_slopes[edge] = slope
        all_intercepts[edge] = intercept

        # R² computed over the inliers only.
        y_pred = ransac.predict(edge_points_inliers[:, 0].reshape(-1, 1))
        ss_res = np.sum((edge_points_inliers[:, 1] - y_pred) ** 2)
        ss_tot = np.sum((edge_points_inliers[:, 1] - np.mean(edge_points_inliers[:, 1])) ** 2)
        r_squared = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0

        # Tilt angle of the fitted line, in degrees.
        tilt_angle = np.degrees(np.arctan(slope))
        all_angles[edge] = tilt_angle

        if debug:
            print(f"\n=== {edge.upper()} ENDPLATE ===")
            print(f"Total points: {len(edge_points)}")
            print(f"Inliers: {len(edge_points_inliers)}")
            print(f"Outliers: {len(edge_points_outliers)}")

            print(f"{edge.capitalize()} 終板傾斜角度: {tilt_angle:.2f}°")
            print(f"斜率: {slope:.4f}, R²: {r_squared:.4f}")

        results[edge] = {
            'tilt_angle_deg': tilt_angle,
            'slope': slope,
            'intercept': intercept,
            'r_squared': r_squared,
            'n_inliers': len(edge_points_inliers),
            'n_outliers': len(edge_points_outliers)
        }

    # Visualization
    if show_plot:
        n_edges = len(edges_to_analyze)
        fig, axes = plt.subplots(n_edges, 2, figsize=(16, 6*n_edges))

        # With a single edge, subplots() returns a 1-D axes array.
        if n_edges == 1:
            axes = axes.reshape(1, -1)

        colors = {'superior': 'red', 'inferior': 'cyan'}

        for idx, edge in enumerate(edges_to_analyze):
            edge_points = all_edge_points[edge]
            edge_points_inliers = all_inliers[edge]
            edge_points_outliers = all_outliers[edge]
            slope = all_slopes[edge]
            intercept = all_intercepts[edge]
            tilt_angle = all_angles[edge]
            color = colors[edge]

            # Left panel: scatter of the contour points + fitted line.
            ax_left = axes[idx, 0]

            if len(edge_points_outliers) > 0:
                ax_left.scatter(edge_points_outliers[:, 0], edge_points_outliers[:, 1],
                                c='lightcoral', s=30, alpha=0.6, marker='x',
                                label=f'Outliers ({len(edge_points_outliers)})', zorder=3)

            ax_left.scatter(edge_points_inliers[:, 0], edge_points_inliers[:, 1],
                            c=color, s=20, alpha=0.7,
                            label=f'Inliers ({len(edge_points_inliers)})', zorder=4)

            # Fitted line spanning the full y range of the points.
            y_line = np.array([edge_points[:, 0].min(), edge_points[:, 0].max()])
            z_line = slope * y_line + intercept
            ax_left.plot(y_line, z_line, 'lime', linewidth=3,
                         label=f'Angle: {tilt_angle:.1f}°\nR²: {results[edge]["r_squared"]:.3f}',
                         zorder=5)

            ax_left.set_xlabel('y')
            ax_left.set_ylabel('z')
            ax_left.set_title(f'{edge.capitalize()} Endplate Analysis\nTilt: {tilt_angle:.1f}°')
            ax_left.invert_yaxis()
            ax_left.legend()
            ax_left.grid(True, alpha=0.3)

            # Right panel: anterior mask with the same overlays.
            ax_right = axes[idx, 1]
            ax_right.imshow(anterior, cmap='gray', aspect='equal')

            if len(edge_points_outliers) > 0:
                ax_right.scatter(edge_points_outliers[:, 0], edge_points_outliers[:, 1],
                                 c='red', s=40, alpha=0.7, marker='x', label='Outliers', zorder=4)

            ax_right.scatter(edge_points_inliers[:, 0], edge_points_inliers[:, 1],
                             c=color, s=25, alpha=0.8, label='Inliers', zorder=3)

            ax_right.plot(y_line, z_line, 'lime', linewidth=3, linestyle='--',
                          label=f'{edge.capitalize()}: {tilt_angle:.1f}°', zorder=5)

            ax_right.set_title(f'Anterior Half - {edge.capitalize()} Edge')
            ax_right.set_xlabel('y')
            ax_right.set_ylabel('z')
            ax_right.invert_yaxis()
            ax_right.legend()

        plt.tight_layout()

        if save_plt:
            if output_path is None:
                output_path = "analyze_vertebral_tilt_contour.png"
            fig.savefig(output_path, dpi=200, bbox_inches="tight")

        plt.show()
        plt.close(fig)

    return results
|
||||
150
imaging/preprocessing.py
Normal file
150
imaging/preprocessing.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
import os
|
||||
import SimpleITK as sitk
|
||||
from imaging.resample import resample_img
|
||||
from imaging.affine import standardize_affine
|
||||
from imaging.segmentation import seg_bone
|
||||
import json
|
||||
import glob
|
||||
from config.constant import LABEL_MAP
|
||||
from imaging.nifti_io import sitk_to_nibabel, nibabel_to_sitk
|
||||
|
||||
PROGRESS_FILE = "progress.json"
|
||||
|
||||
def load_progress():
    """Return the saved progress dict, or {} when no progress file exists."""
    # EAFP: attempt the read and fall back when the file is absent.
    try:
        with open(PROGRESS_FILE, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        return {}
|
||||
|
||||
def save_progress(progress):
    """Persist *progress* to PROGRESS_FILE as pretty-printed JSON."""
    serialized = json.dumps(progress, indent=2)
    with open(PROGRESS_FILE, "w") as f:
        f.write(serialized)
|
||||
|
||||
def process_single_image(image_path, label_path, output_dir_base=None):
    """Resample one image/label pair and segment every label present.

    The image and its label map are resampled to 0.5 mm isotropic
    spacing, the labels actually present are recorded in a text file,
    and ``seg_bone`` (followed by affine standardization of its outputs)
    is run for each of them.

    Parameters:
        image_path: path of the input image (*.nii.gz).
        label_path: path of the matching segmentation (*.nii.gz).
        output_dir_base: base directory; a sub-folder named after the
            image file stem is created inside it.

    Returns:
        dict with 'processed_labels' (labels found in the file) and
        'missing_labels' (always empty here; maintained for callers).
    """
    image = sitk.ReadImage(image_path)
    label = sitk.ReadImage(label_path)

    # Compute the file stem once (the original did this twice; it also
    # ran a LabelStatisticsImageFilter whose results were never used —
    # that dead work has been removed).
    file_name = os.path.basename(image_path)
    name = file_name.replace(".nii.gz", "")

    # Resample both volumes to 0.5 mm isotropic spacing; the label map
    # uses nearest-neighbour interpolation (is_label=True).
    resampled_sitk_img = resample_img(image, out_spacing=[0.5, 0.5, 0.5], is_label=False)
    resampled_sitk_lbl = resample_img(label, out_spacing=[0.5, 0.5, 0.5], is_label=True)

    # Collect the labels that actually occur in the segmentation.
    lssif = sitk.LabelShapeStatisticsImageFilter()
    lssif.Execute(label)
    existing_labels = lssif.GetLabels()  # e.g. [1, 2, 3, 20, 21]

    print(f"Existing labels in {os.path.basename(label_path)}: {existing_labels}")

    # Per-file output folder.
    output_dir = os.path.join(output_dir_base, name)
    os.makedirs(output_dir, exist_ok=True)

    # Record the present labels (with their names) in a text file.
    txt_path = os.path.join(output_dir, f"{name}_labels.txt")
    with open(txt_path, "w") as f:
        for lab in existing_labels:
            f.write(f"{lab}\t{LABEL_MAP.get(lab, 'Unknown')}\n")

    # Segment each label; skip the ones seg_bone cannot process.
    for n in existing_labels:
        try:
            roi_path, binary_path, roi2_path, cortical_path = seg_bone(n, name, resampled_sitk_img, resampled_sitk_lbl, output_dir, label_map=LABEL_MAP)
            for path in [roi_path, binary_path, roi2_path, cortical_path]:
                standardize_affine(path, output_dir)
        except RuntimeError as e:
            print(f"Label {n} could not be processed, skipping. Error: {e}")

    return {
        "processed_labels": list(existing_labels),
        "missing_labels": []
    }
|
||||
|
||||
def process_dataset(image_dir, label_dir, output_dir, labels_to_process=None):
    """Run ``process_single_image`` over every image in *image_dir*.

    Matching labels are looked up in *label_dir* as
    ``<stem>_seg.nii.gz``.  Progress is checkpointed after every file via
    ``save_progress`` so an interrupted run resumes where it stopped, and
    a human-readable per-file summary is written to
    ``<output_dir>/all_files_label_summary.txt`` at the end.

    Parameters:
        image_dir: directory containing the ``*.nii.gz`` images.
        label_dir: directory containing ``*_seg.nii.gz`` segmentations.
        output_dir: base directory for all outputs and the summary file.
        labels_to_process: NOTE(review) — currently unused; kept in the
            signature, presumably for future filtering.
    """
    image_files = sorted(glob.glob(os.path.join(image_dir, "*.nii.gz")))
    total_files = len(image_files)
    print(f"Total files: {total_files}")

    # Resume state from a previous (possibly interrupted) run.
    progress = load_progress()
    all_file_summary = []

    for idx, image_path in enumerate(image_files, 1):
        file_name = os.path.basename(image_path)
        name = file_name.replace(".nii.gz", "")
        label_path = os.path.join(label_dir, file_name.replace(".nii.gz", "_seg.nii.gz"))

        # One summary record per file, appended on every exit path below.
        file_summary = {
            "file_name": file_name,
            "current_labels": [],
            "missing_labels": []
        }

        # Skip files already completed in a previous run; reuse their
        # recorded label lists for the summary.
        if progress.get(name, {}).get("finished", False):
            print(f"[{idx}/{total_files}] Already finished: {file_name}")
            file_summary["current_labels"] = progress[name].get("processed_labels", [])
            file_summary["missing_labels"] = progress[name].get("missing_labels", [])
            all_file_summary.append(file_summary)
            continue

        # No matching segmentation on disk — record and move on.
        if not os.path.exists(label_path):
            print(f"[{idx}/{total_files}] Warning: label not found for {file_name}")
            file_summary["note"] = "Label file not found"
            all_file_summary.append(file_summary)
            continue

        # Any failure is recorded in the summary; the run continues with
        # the next file rather than aborting the whole dataset.
        try:
            result = process_single_image(image_path, label_path, output_dir_base=output_dir)
        except Exception as e:
            print(f"[{idx}/{total_files}] Error processing {file_name}: {e}")
            file_summary["note"] = f"Error: {e}"
            all_file_summary.append(file_summary)
            continue

        file_summary["current_labels"] = result["processed_labels"]
        file_summary["missing_labels"] = result["missing_labels"]
        all_file_summary.append(file_summary)

        # Checkpoint immediately so a crash loses at most one file.
        progress[name] = {
            "finished": True,
            "processed_labels": result["processed_labels"],
            "missing_labels": result["missing_labels"]
        }
        save_progress(progress)

        print(f"[{idx}/{total_files}] Finished: {file_name} | Missing labels: {result['missing_labels'] or 'None'}")

    # --- Summary ---
    summary_path = os.path.join(output_dir, "all_files_label_summary.txt")
    os.makedirs(os.path.dirname(summary_path), exist_ok=True)
    print(f"\nWriting summary to {summary_path}...")

    with open(summary_path, "w") as f:
        f.write("--- CTSpine1K Dataset Label Summary ---\n")
        f.write(f"Total files processed: {total_files}\n\n")

        for summary in all_file_summary:
            f.write("================================================\n")
            f.write(f"File: {summary['file_name']}\n")

            processed_labels_str = ", ".join([str(l) for l in summary['current_labels']])
            f.write(f"Labels processed: {processed_labels_str}\n")

            if summary['missing_labels']:
                missing_str = ", ".join([f"{l} ({LABEL_MAP.get(l, 'Unknown')})" for l in summary['missing_labels']])
                f.write(f"🚨 Missing Labels: {missing_str}\n")
            else:
                f.write("✅ Missing Labels: None\n")

            if "note" in summary:
                f.write(f"Note: {summary['note']}\n")
            f.write("================================================\n\n")

    print("All done! Summary file created.")
|
||||
29
imaging/resample.py
Normal file
29
imaging/resample.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
import numpy as np
|
||||
import SimpleITK as sitk
|
||||
from config.constant import LABEL_MAP
|
||||
|
||||
def resample_img(sitk_image, out_spacing=(0.5, 0.5, 0.5), is_label=False):
    """Resample *sitk_image* onto a grid with *out_spacing* voxel size.

    The output size is derived so the physical extent of the image is
    preserved; direction and origin are carried over unchanged.  (The
    original comment said "2mm spacing" but the actual default is 0.5 mm.)

    Parameters:
        sitk_image: the SimpleITK image to resample.
        out_spacing: target spacing per axis in mm; default 0.5 mm
            isotropic.  (Changed from a mutable list default to a tuple.)
        is_label: when True, use nearest-neighbour interpolation so label
            values are not blended; otherwise use B-spline.

    Returns:
        The resampled SimpleITK image.
    """
    original_spacing = sitk_image.GetSpacing()
    original_size = sitk_image.GetSize()

    # New voxel counts that keep the same physical field of view.
    out_size = [
        int(np.round(original_size[axis] * (original_spacing[axis] / out_spacing[axis])))
        for axis in range(3)
    ]

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(list(out_spacing))
    resample.SetSize(out_size)
    resample.SetOutputDirection(sitk_image.GetDirection())
    resample.SetOutputOrigin(sitk_image.GetOrigin())
    resample.SetTransform(sitk.Transform())
    # NOTE(review): this passes the pixel *type id* as the default pixel
    # value (fill value outside the input domain), which looks
    # unintended — a background intensity such as 0 or the image minimum
    # is the usual choice.  Preserved as-is; confirm before changing.
    resample.SetDefaultPixelValue(sitk_image.GetPixelIDValue())

    if is_label:
        resample.SetInterpolator(sitk.sitkNearestNeighbor)
    else:
        resample.SetInterpolator(sitk.sitkBSpline)

    return resample.Execute(sitk_image)
|
||||
75
imaging/segmentation.py
Normal file
75
imaging/segmentation.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
import os
|
||||
import SimpleITK as sitk
|
||||
from config.constant import LABEL_MAP
|
||||
import numpy as np
|
||||
|
||||
"""
|
||||
# 沿用原本 LABEL_MAP
|
||||
seg_bone(n, name, img, lbl)
|
||||
|
||||
# user 自定義
|
||||
my_map = {1: "L1", 2: "L2", 3: "L3"}
|
||||
seg_bone(n, name, img, lbl, label_map=my_map)
|
||||
"""
|
||||
|
||||
def seg_bone(n, name, resampled_sitk_img, resampled_sitk_lbl, output_base=None, label_map=LABEL_MAP):
    """Extract one labelled bone as ROI / binary / masked / cortical volumes.

    Crops image and label to the bounding box of label *n*, writes the
    cropped intensity ROI, a binary mask of the label, the masked
    intensity image, and a thresholded "cortical" mask into *output_base*.

    Parameters:
        n: integer label id to extract.
        name: subject/file stem.  NOTE(review): currently unused in the
            body; kept for interface compatibility.
        resampled_sitk_img: resampled intensity image (SimpleITK).
        resampled_sitk_lbl: resampled label image (SimpleITK).
        output_base: output directory; defaults to 'Dataset'.
        label_map: mapping of label id -> bone name used in file names.

    Returns:
        (roi_path, binary_path, roi2_path, cortical_path) file paths.

    Raises:
        ValueError: if *n* is not in *label_map*.
        RuntimeError: if *n* does not occur in the label image.
    """
    # BUG FIX: the original read `output_base == 'Dataset'` — a comparison
    # whose result was discarded — so the default was never assigned and
    # os.path.join(None, ...) crashed whenever output_base was omitted.
    if output_base is None:
        output_base = 'Dataset'

    if n not in label_map:
        raise ValueError(f"Label {n} not found in label_map")

    label_name = label_map[n]

    lssif = sitk.LabelShapeStatisticsImageFilter()
    lssif.Execute(resampled_sitk_lbl)

    if not lssif.HasLabel(n):
        raise RuntimeError(f"Label {n} not found")

    # Bounding box used as (index, size): bbox2[:3] start, bbox2[3:] size.
    bbox2 = lssif.GetBoundingBox(n)

    # Crop image and label to the label's bounding box and save the ROI.
    roi = sitk.RegionOfInterest(resampled_sitk_img, bbox2[3:], bbox2[:3])
    label2 = sitk.RegionOfInterest(resampled_sitk_lbl, bbox2[3:], bbox2[:3])
    roi_path = os.path.join(output_base, f"{label_name}_roi.nii.gz")
    sitk.WriteImage(roi, roi_path)

    # Binary mask of exactly this label inside the crop.
    binary = sitk.BinaryThreshold(label2, lowerThreshold=n, upperThreshold=n, outsideValue=0, insideValue=1)
    binary_path = os.path.join(output_base, f"{label_name}_binary.nii.gz")
    sitk.WriteImage(binary, binary_path)

    # Intensity image masked to the label.
    roi_pixel_type = roi.GetPixelID()
    binary_cast = sitk.Cast(binary, roi_pixel_type)
    roi2 = roi * binary_cast
    roi2_path = os.path.join(output_base, f"{label_name}_roi2.nii.gz")
    sitk.WriteImage(roi2, roi2_path)

    # Cortical threshold: 60th percentile of the masked ROI array when
    # label n survived the crop; otherwise the median of the first label
    # reported by the statistics filter.
    lsif = sitk.LabelStatisticsImageFilter()
    label2_int = sitk.Cast(label2, sitk.sitkUInt16)
    lsif.Execute(roi2, label2_int)
    labels_in_roi = lsif.GetLabels()
    if n in labels_in_roi:
        roi_hu = sitk.GetArrayFromImage(roi2)
        # NOTE(review): the percentile is taken over the whole cropped
        # array, background zeros included — confirm that is intended.
        threshold = np.percentile(roi_hu, 60)
    else:
        threshold = lsif.GetMedian(labels_in_roi[0])

    cortical = sitk.BinaryThreshold(roi2, lowerThreshold=threshold, upperThreshold=10000, outsideValue=0, insideValue=1)
    cortical_path = os.path.join(output_base, f"{label_name}_cortical.nii.gz")
    sitk.WriteImage(cortical, cortical_path)

    return roi_path, binary_path, roi2_path, cortical_path
|
||||
|
||||
"""
|
||||
Dataset/
|
||||
└── standardized/
|
||||
└── subject001/
|
||||
├── L1_roi.nii.gz
|
||||
├── L1_binary.nii.gz
|
||||
├── L1_roi2.nii.gz
|
||||
├── L1_cortical.nii.gz
|
||||
├── L2_roi.nii.gz
|
||||
├── L2_binary.nii.gz
|
||||
...
|
||||
"""
|
||||
31
optimization_benchmark_results_DE.csv
Normal file
31
optimization_benchmark_results_DE.csv
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
Patient_ID,Method,Run,Total_Time_sec,Left_Cortical_Score,Left_Bone_Score,Left_C_B_Ratio,Right_Cortical_Score,Right_Bone_Score,Right_C_B_Ratio,Left_Loss,Right_Loss
|
||||
000121,DE,1,35.753113746643066,79.55674436377531,96.79021780664884,0.821950256612712,87.70637541275083,97.85369570739142,0.8963011031797534,-11164556.744363775,-13010306.375412751
|
||||
000121,DE,2,35.85626721382141,79.35336048879837,99.08350305498982,0.8008735868448099,87.6778455284553,97.73882113821138,0.8970626462178322,-11129753.360488798,-13004677.845528455
|
||||
000121,DE,3,31.091505527496338,79.55239064089523,99.12258392675484,0.8025657472738935,87.7418533604888,98.30702647657841,0.8925288100479088,-11130152.390640896,-12984341.853360489
|
||||
000141,DE,1,140.6565911769867,86.313240728941,96.68663183382185,0.8927112165546328,91.54715065363625,96.59855311587765,0.9477072657995007,-12680113.24072894,-13716747.150653636
|
||||
000141,DE,2,192.32001185417175,86.34512598625605,96.27131585645202,0.8968935888962327,91.14729076570846,96.60391757822437,0.9435154707044107,-12705945.125986256,-13703347.290765708
|
||||
000141,DE,3,205.5456428527832,85.99695585996956,96.61339421613394,0.8901142181961402,90.93793628633075,96.54778525193552,0.9418956224530038,-12685796.95585997,-13712737.93628633
|
||||
000151,DE,1,184.93012046813965,84.59359954687058,96.37496459926366,0.8777549221275346,81.10938163294185,92.27394934201217,0.8790062873792364,-11051193.59954687,-10356109.381632943
|
||||
000151,DE,2,255.19188165664673,84.5317135188586,96.27065969769741,0.8780630961115188,80.6579133135677,91.6278413101793,0.8802773497688753,-11085531.713518858,-10327657.913313568
|
||||
000151,DE,3,318.5383884906769,84.57409238593021,96.08701794038707,0.8801822993237282,80.9469964664311,92.09893992932862,0.878913443830571,-11070974.092385931,-10367546.996466432
|
||||
000161,DE,1,293.82185339927673,78.75095201827875,97.14394516374715,0.8106624852998824,69.2503896840017,93.45330877143262,0.7410159211523882,-9339550.95201828,-7986250.389684001
|
||||
000161,DE,2,191.1857042312622,68.1406907098254,90.35300114693513,0.7541607898448519,71.80098553489111,84.10427594976952,0.8537138537138537,-8717140.690709826,-7609600.9855348915
|
||||
000161,DE,3,231.82087516784668,78.40345748061523,95.01716028981822,0.8251505016722409,69.55722167208941,93.80393266374311,0.7415171165736691,-9315203.457480615,-8031957.22167209
|
||||
000171,DE,1,190.15103244781494,79.46289752650176,94.53003533568905,0.8406100478468899,89.35980654193713,96.29629629629629,0.9279672217816549,-10053062.897526503,-13240559.806541936
|
||||
000171,DE,2,202.20453667640686,79.63538722442057,94.1351045788581,0.8459690737126557,93.69841269841271,97.98412698412699,0.9562611372104326,-10089835.38722442,-11529098.412698412
|
||||
000171,DE,3,206.49177026748657,79.90092002830856,94.8761500353857,0.8421602267641354,89.3779417376924,95.76389772293601,0.9333156216790648,-10080900.920028308,-13289377.941737693
|
||||
000181,DE,1,240.98572325706482,73.77466581795035,97.83577339274348,0.7540663630448926,80.7670534956128,91.80583073874894,0.8797595190380763,-9471774.66581795,-10197567.053495612
|
||||
000181,DE,2,225.65672206878662,72.20582032897511,91.57879938141431,0.7884556340190358,80.67940552016985,92.04529370134466,0.8765185299092726,-8721605.820328975,-10186679.405520169
|
||||
000181,DE,3,229.29561853408813,73.67884884757417,97.75881828600535,0.7536798228474664,80.4436908294475,91.64900381517592,0.8777366635831021,-9511078.848847574,-10265243.690829448
|
||||
000191,DE,1,157.32477927207947,80.80101716465353,95.2638270820089,0.8481815148481815,79.08787541713015,95.67773716828222,0.8266068759342302,-9113401.017164653,-8877287.87541713
|
||||
000191,DE,2,183.22506833076477,80.77839555202542,95.20254169976171,0.8484899048890372,79.23358244553982,95.80219430752108,0.8270539419087135,-9101778.395552026,-8926433.58244554
|
||||
000191,DE,3,175.80304598808289,80.66000317309218,95.36728541964145,0.8457827316586258,78.95238095238095,95.38095238095238,0.8277583624563155,-9124460.003173092,-8896152.38095238
|
||||
000201,DE,1,236.4226109981537,77.4121271530598,91.79386640526536,0.8433257055682686,71.38433515482696,87.14025500910748,0.8191889632107022,-9615012.12715306,-6598584.335154827
|
||||
000201,DE,2,253.00249886512756,77.48911669709311,91.81294761971634,0.8439889874579382,76.36016544702512,91.01177219217308,0.8390141583639222,-9598089.116697093,-7789560.165447025
|
||||
000201,DE,3,218.03053379058838,77.02247191011236,91.26404494382022,0.8439519852262235,71.35056425191118,87.25882781215873,0.8176887776387151,-9574222.471910112,-6583150.564251911
|
||||
000211,DE,1,187.18615746498108,80.2594472645234,93.30231246474902,0.860208553725253,72.7966425028615,85.05659417525118,0.8558612440191388,-10114259.447264524,-9730596.642502861
|
||||
000211,DE,2,210.16704440116882,80.55634001694436,94.22479525557752,0.8549378090813726,78.69571367944546,85.98104399490735,0.9152681803224744,-10122956.340016944,-9908895.713679446
|
||||
000211,DE,3,254.02969360351562,80.57950530035336,93.80918727915194,0.8589724273014917,78.66120604434403,86.21663606835193,0.9123669123669124,-10126379.505300354,-9916861.206044344
|
||||
000221,DE,1,175.14598894119263,74.39877847054332,90.29138567247742,0.8239853438556932,83.13391038696538,94.51374745417516,0.8795959595959596,-10091798.778470544,-11923133.910386965
|
||||
000221,DE,2,192.38243532180786,74.53368861819565,90.2042887958381,0.8262765508510338,83.17411794675837,94.95605655330532,0.8759221998658617,-10144933.688618196,-11969374.117946759
|
||||
000221,DE,3,193.16686582565308,74.54545454545455,90.38779402415767,0.8247292164861443,80.73070955000705,92.04401184934406,0.8770881226053638,-10129945.454545455,-10333930.709550006
|
||||
|
30
optimization_benchmark_results_NM.csv
Normal file
30
optimization_benchmark_results_NM.csv
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
Patient_ID,Method,Run,Total_Time_sec,Left_Cortical_Score,Left_Bone_Score,Left_C_B_Ratio,Right_Cortical_Score,Right_Bone_Score,Right_C_B_Ratio,Left_Loss,Right_Loss
|
||||
000121,NM,1,62.64465260505676,79.69253657327052,93.47879990081826,0.8525198938992041,76.40900366428197,86.93072762170651,0.878964271376957,-4455492.53657327,-7722809.003664281
|
||||
000121,NM,2,112.54990696907043,19.996510207642647,24.2017099982551,0.8262436914203318,19.996510207642647,25.423137323329264,0.7865477007549759,-1815500.0,-2052500.0
|
||||
000121,NM,3,79.86152529716492,19.99636759898293,45.67744278968398,0.4377733598409543,74.47483154974238,81.78755449861276,0.9105888054276715,-1971500.0,-6530074.831549742
|
||||
000141,NM,1,56.64281606674194,60.010930952814725,70.75970122062306,0.8480947476828014,92.42144177449168,99.74121996303143,0.9266123054114158,-4730010.930952814,-4716421.441774491
|
||||
000141,NM,2,48.48319458961487,72.2156862745098,88.68627450980392,0.8142825558257794,81.45263827082009,95.61347743165925,0.8518949468085106,-6244615.68627451,-9345852.63827082
|
||||
000141,NM,3,63.44100284576416,29.994107248084855,48.00628560204282,0.6247954173486088,29.993650793650794,50.80634920634921,0.5903524118970257,-2409600.0,-3760200.0
|
||||
000151,NM,1,41.89803433418274,85.00988979937836,95.0127154563436,0.894721189591078,72.36600197991797,86.67798048366568,0.834883341491271,-9918209.889799379,-8683166.001979917
|
||||
000151,NM,2,107.57207727432251,39.95674850527923,58.84747487596997,0.6789883268482492,29.994175888177054,57.94991263832265,0.5175879396984925,-3234000.0,-631800.0
|
||||
000151,NM,3,84.16005349159241,29.998184129289996,67.75013619030325,0.4427767354596623,83.55547550432276,97.24423631123919,0.859233191331728,-2620000.0,-8386155.475504322
|
||||
000161,NM,1,96.87153625488281,19.997175540177942,20.81626888857506,0.960651289009498,29.998409416255768,60.076348019723234,0.4993380990203866,-2538500.0,-2675400.0
|
||||
000161,NM,2,72.83068442344666,29.998588965711864,31.451954282489066,0.9537909376401973,55.50992470910335,74.19575633127995,0.7481549815498154,-3391400.0,-1076509.9247091033
|
||||
000161,NM,3,57.63184142112732,57.56774619960344,70.74245428508482,0.813765182186235,29.998409416255768,69.9538730714172,0.42883128694861306,-3624767.746199603,-1873400.0
|
||||
000171,NM,1,63.1676971912384,39.99206506645507,58.5399722277326,0.6831582514401897,29.998014691284496,51.737145126067105,0.5798158096699924,-963000.0,-2371400.0
|
||||
000171,NM,2,57.605836153030396,29.994765311463965,65.34636189146745,0.4590120160213618,39.98435054773083,80.67292644757433,0.495635305528613,-1853800.0,-2237200.0
|
||||
000171,NM,3,77.0465567111969,19.996513249651322,22.7510460251046,0.878927203065134,39.98223575688365,54.19363024996828,0.737766331070007,-2054300.0,-3359000.0
|
||||
000181,NM,2,83.35590553283691,29.998727249586356,60.73564973908616,0.4939228834870076,29.994756161510228,36.89914350638001,0.8128848886783515,-3633000.0,-2581000.0
|
||||
000181,NM,3,25.232823610305786,39.99717633771001,71.59395736269943,0.5586669295996844,52.23455735838315,77.25590663250783,0.6761238025055268,-3027000.0,-1836834.5573583832
|
||||
000191,NM,1,67.51091599464417,81.00364096881431,95.77331011556119,0.845785123966942,29.992557677995535,65.86454973951874,0.455367231638418,-9169003.640968814,-1268800.0
|
||||
000191,NM,2,103.78111028671265,19.99607920015683,43.952166241913346,0.4549509366636932,19.9836867862969,30.859162588363244,0.6475770925110131,-1825700.0,-2634200.0
|
||||
000191,NM,3,90.03080677986145,29.995223690495145,46.63270179907658,0.6432229429839537,39.45765376301827,86.40204362350167,0.45667500568569475,-3003800.0,-1880400.0
|
||||
000201,NM,1,52.31811261177063,79.57905722734388,93.3205774917377,0.852749301025163,29.995029821073558,63.24552683896621,0.4742632612966601,-8079379.057227343,-1912800.0
|
||||
000201,NM,2,44.66054558753967,39.9974548231102,57.64825655383049,0.6938189845474613,64.27206116597642,95.63555272379739,0.6720519653564291,-2867000.0,-6088872.061165976
|
||||
000201,NM,3,36.97987747192383,85.58890147225368,92.46885617214043,0.9255970606246172,29.99825388510564,73.19713637157325,0.4098282442748092,-5619388.901472254,-1344600.0
|
||||
000211,NM,1,72.89356565475464,83.39688592310137,95.9644105497299,0.8690397350993377,9.998585772875124,30.26446047235186,0.33037383177570095,-9519396.8859231,-1229800.0
|
||||
000211,NM,2,118.73922371864319,19.99593578540947,30.156472261735416,0.6630727762803235,19.99636098981077,39.082969432314414,0.5116387337057727,-1760900.0,-1947900.0
|
||||
000211,NM,3,115.41684770584106,84.35417327832434,97.03268803554427,0.8693376941946035,9.998432847516064,25.952045133991536,0.38526570048309183,-9507554.173278324,-1457100.0
|
||||
000221,NM,1,73.53109002113342,9.99825205383674,17.811571403600766,0.5613346418056918,88.07339449541286,98.01636498884206,0.8985580571717684,-1305300.0,-6585273.394495413
|
||||
000221,NM,2,111.19093751907349,19.996390543223246,21.746977079949467,0.9195020746887967,29.998725627628392,48.426150121065376,0.6194736842105263,-1724100.0,-3706200.0
|
||||
000221,NM,3,74.673180103302,19.997459026807267,29.005208995045102,0.6894437144108629,39.996080736821476,54.399372917891434,0.7352305475504323,-2822900.0,-2181800.0
|
||||
|
4
optimization_benchmark_results_PSO.csv
Normal file
4
optimization_benchmark_results_PSO.csv
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
Patient_ID,Method,Run,Total_Time_sec,Left_Cortical_Score,Left_Bone_Score,Left_C_B_Ratio,Right_Cortical_Score,Right_Bone_Score,Right_C_B_Ratio,Left_Loss,Right_Loss
|
||||
00010,PSO,1,116.09450578689575,78.9866581956798,90.86721728081322,0.8692536269882887,87.45029425799268,95.67361221568316,0.9140482128013299,-8620586.65819568,-10406250.294257993
|
||||
00010,PSO,2,119.36032223701477,79.74112571092371,91.21396352225926,0.8742205977209203,80.71401346715791,93.19019184347606,0.8661213360599862,-7217341.125710924,-11510314.013467157
|
||||
00010,PSO,3,116.97214722633362,81.31569409206234,91.3374411018485,0.8902777777777778,88.62724577010292,96.33699633699634,0.91997103023719,-7985515.694092063,-9695027.245770102
|
||||
|
1
results/comparison_results.csv
Normal file
1
results/comparison_results.csv
Normal file
|
|
@ -0,0 +1 @@
|
|||
case,side,method,repeat,cortical_ratio,vertebral_ratio,cb_ratio,loss,time_sec
|
||||
|
10
results/results_DE.csv
Normal file
10
results/results_DE.csv
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
case,side,method,repeat,seed,cortical_ratio,vertebral_ratio,cb_ratio,diameter,length,z,y,x,azimuth,altitude,cyl_points,overlap_cortical_voxels,overlap_vertebral_voxels,overlap_warning,cortical_outside_spine,loss,time_sec
|
||||
1.3.6.1.4.1.9328.50.4.0001,L,DE,1,1,99.99,24.43,4.0922,5.0,50.0,7.5637,69.229,63.3989,123.7477,67.3613,7903,7902,1931,1,2000437,-15843187.3466,46.5
|
||||
1.3.6.1.4.1.9328.50.4.0001,L,DE,2,2,100.0,25.26,3.9587,5.0,50.0,4.3,67.2525,59.4038,116.5368,65.9693,7949,7949,2008,1,2000437,-15938000.0,46.95
|
||||
1.3.6.1.4.1.9328.50.4.0001,L,DE,3,3,100.0,19.18,5.2127,5.0,50.0,16.8451,72.3062,56.3851,123.8802,60.6996,7892,7892,1514,1,2000437,-15820000.0,48.35
|
||||
1.3.6.1.4.1.9328.50.4.0001,R,DE,1,11,99.97,13.0,7.6885,5.0,50.0,26.1806,48.2449,143.6017,90.1226,63.2828,7975,7973,1037,1,2000437,-15984374.9216,52.15
|
||||
1.3.6.1.4.1.9328.50.4.0001,R,DE,2,12,99.98,6.65,15.041,5.0,50.0,28.1923,62.7821,145.4942,90.0013,63.4555,8064,8062,536,1,2000437,-16162375.1984,54.69
|
||||
1.3.6.1.4.1.9328.50.4.0001,R,DE,3,13,99.9,66.77,1.4961,5.0,50.0,29.0259,33.4412,117.5387,90.0831,63.5098,8030,8022,5362,1,2000437,-16107500.3736,53.13
|
||||
1.3.6.1.4.1.9328.50.4.0002,L,DE,1,101,100.0,16.05,6.2311,5.0,50.0,74.8105,64.7153,67.9007,108.4366,83.9283,7901,7901,1268,1,2955801,-15842000.0,67.79
|
||||
1.3.6.1.4.1.9328.50.4.0002,L,DE,2,102,99.99,64.19,1.5578,5.0,50.0,47.0961,43.6089,70.2426,99.1139,78.8503,7877,7876,5056,1,2955801,-15821187.3048,65.29
|
||||
1.3.6.1.4.1.9328.50.4.0002,L,DE,3,103,100.0,22.02,4.5416,5.0,50.0,68.4024,44.1539,55.9956,97.1283,82.8774,7907,7907,1741,1,2955801,-15854000.0,66.76
|
||||
|
1
results/results_GA.csv
Normal file
1
results/results_GA.csv
Normal file
|
|
@ -0,0 +1 @@
|
|||
case,side,method,repeat,cortical_ratio,vertebral_ratio,cb_ratio,loss,time_sec
|
||||
|
61
results/results_PSO.csv
Normal file
61
results/results_PSO.csv
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
case,side,method,repeat,seed,cortical_ratio,vertebral_ratio,cb_ratio,diameter,length,z,y,x,azimuth,altitude,cyl_points,overlap_cortical_voxels,overlap_vertebral_voxels,overlap_warning,cortical_outside_spine,loss,time_sec
|
||||
1.3.6.1.4.1.9328.50.4.0001,L,PSO,1,1,99.99,8.69,11.5064,5.0,50.0,7.4176,36.793,49.1919,116.5719,65.9124,8044,8043,699,1,2000437,-16125187.5684,48.67
|
||||
1.3.6.1.4.1.9328.50.4.0001,L,PSO,2,2,100.0,9.16,10.9133,5.0,50.0,5.6307,53.3182,65.5052,130.5898,61.4896,7934,7934,727,1,2000437,-15908000.0,44.6
|
||||
1.3.6.1.4.1.9328.50.4.0001,L,PSO,3,3,100.0,21.26,4.7037,5.0,50.0,4.6025,75.6423,62.4257,126.8619,68.2206,7968,7968,1694,1,2000437,-15976000.0,46.46
|
||||
1.3.6.1.4.1.9328.50.4.0001,R,PSO,1,11,99.99,10.11,9.8857,5.0,50.0,20.5102,49.6687,140.5223,90.0478,62.6068,7959,7958,805,1,2000437,-15955187.4356,54.07
|
||||
1.3.6.1.4.1.9328.50.4.0001,R,PSO,2,12,99.95,24.83,4.0259,5.0,50.0,10.9234,62.9727,100.2,71.5645,64.5968,7939,7935,1971,1,2000437,-15906749.6158,50.59
|
||||
1.3.6.1.4.1.9328.50.4.0001,R,PSO,3,13,99.99,13.2,7.5755,5.0,50.0,10.0558,57.709,104.1599,70.0213,59.1209,7925,7924,1046,1,2000437,-15887187.3817,48.66
|
||||
1.3.6.1.4.1.9328.50.4.0002,L,PSO,1,101,100.0,56.5,1.7698,5.0,50.0,51.3303,40.1518,74.2,105.9376,82.1742,7918,7918,4474,1,2955801,-15906000.0,66.71
|
||||
1.3.6.1.4.1.9328.50.4.0002,L,PSO,2,102,100.0,14.09,7.0958,5.0,50.0,76.8487,50.3333,72.7445,113.1861,82.4419,7926,7926,1117,1,2955801,-15892000.0,67.51
|
||||
1.3.6.1.4.1.9328.50.4.0002,L,PSO,3,103,100.0,15.19,6.5849,5.0,50.0,75.4254,61.8117,65.6238,114.4484,80.5996,7915,7915,1202,1,2955801,-15870000.0,69.06
|
||||
1.3.6.1.4.1.9328.50.4.0002,R,PSO,1,111,99.91,15.41,6.4817,5.0,50.0,5.6976,82.3753,113.0593,56.3006,74.4506,7973,7966,1229,1,2955801,-15966312.2037,60.3
|
||||
1.3.6.1.4.1.9328.50.4.0002,R,PSO,2,112,100.0,4.12,24.2831,5.0,50.0,6.1936,37.2415,118.2451,81.1364,88.5141,7892,7892,325,1,2955801,-15854000.0,66.59
|
||||
1.3.6.1.4.1.9328.50.4.0002,R,PSO,3,113,100.0,8.74,11.4435,5.0,50.0,76.495,47.1313,116.2606,84.6961,79.5085,7896,7896,690,1,2955801,-15832000.0,68.45
|
||||
1.3.6.1.4.1.9328.50.4.0003,L,PSO,1,201,20.0,33.57,0.5956,5.0,45.0,56.3852,63.9562,17.4898,109.292,97.2601,4316,863,1449,0,0,-1543100.0,58.56
|
||||
1.3.6.1.4.1.9328.50.4.0003,L,PSO,2,202,19.99,37.37,0.535,4.5,40.0,58.1093,70.7978,13.3408,105.8432,96.161,3251,650,1215,0,0,-1154700.0,59.49
|
||||
1.3.6.1.4.1.9328.50.4.0003,L,PSO,3,203,29.99,46.94,0.6389,5.0,35.0,56.1536,67.1454,23.1849,122.2648,98.3765,3534,1060,1659,0,0,-1685600.0,54.63
|
||||
1.3.6.1.4.1.9328.50.4.0003,R,PSO,1,211,30.0,36.28,0.8268,4.5,45.0,52.26,61.821,174.191,67.415,101.7994,3947,1184,1432,0,0,-1861200.0,59.29
|
||||
1.3.6.1.4.1.9328.50.4.0003,R,PSO,2,212,30.0,37.93,0.7909,5.0,45.0,49.2119,73.0297,177.9171,69.6459,104.6915,4577,1373,1736,0,0,-2185600.0,154.05
|
||||
1.3.6.1.4.1.9328.50.4.0003,R,PSO,3,213,72.6,96.98,0.7486,4.0,45.0,47.5138,37.4122,118.8,68.5949,97.5772,4529,3288,4392,0,0,-3715798.8077,116.51
|
||||
1.3.6.1.4.1.9328.50.4.0004,L,PSO,1,301,10.0,34.7,0.2881,5.0,45.0,52.9404,69.3304,33.8691,122.0144,82.5102,5141,514,1784,0,0,-1170900.0,55.93
|
||||
1.3.6.1.4.1.9328.50.4.0004,L,PSO,2,302,72.22,82.03,0.8804,4.0,45.0,27.8612,42.2801,69.5263,106.5231,72.4112,4525,3268,3712,0,0,-4742620.9945,56.87
|
||||
1.3.6.1.4.1.9328.50.4.0004,L,PSO,3,303,20.0,23.0,0.8695,4.5,50.0,66.1798,62.605,71.8798,143.1133,72.1335,4431,886,1019,0,0,-1558500.0,56.84
|
||||
1.3.6.1.4.1.9328.50.4.0004,R,PSO,1,311,10.0,29.66,0.337,5.0,45.0,39.5416,39.5878,142.8218,52.787,81.3579,5222,522,1549,0,0,-1190000.0,51.0
|
||||
1.3.6.1.4.1.9328.50.4.0004,R,PSO,2,312,73.86,85.83,0.8605,5.0,40.0,27.5925,50.8509,128.0866,74.6366,72.0293,6281,4639,5391,0,0,-6538257.666,57.0
|
||||
1.3.6.1.4.1.9328.50.4.0004,R,PSO,3,313,78.66,92.39,0.8515,4.5,45.0,23.4699,44.1544,115.4623,66.7588,68.4046,5793,4557,5352,0,0,-7034863.9047,56.72
|
||||
1.3.6.1.4.1.9328.50.4.0005,L,PSO,1,401,19.99,41.52,0.4815,4.5,45.0,86.7698,41.0198,59.0922,66.1304,73.5371,2861,572,1188,0,0,-939300.0,61.61
|
||||
1.3.6.1.4.1.9328.50.4.0005,L,PSO,2,402,68.99,76.62,0.9004,4.0,35.0,67.478,35.5161,62.2094,79.0893,78.6353,3525,2432,2701,0,0,-3528592.9078,65.13
|
||||
1.3.6.1.4.1.9328.50.4.0005,L,PSO,3,403,67.91,78.01,0.8704,4.5,35.0,68.7886,35.1374,60.2172,76.6768,79.7908,4462,3030,3481,0,0,-4257306.7683,65.38
|
||||
1.3.6.1.4.1.9328.50.4.0005,R,PSO,1,411,19.99,29.69,0.6734,4.5,45.0,73.758,66.3197,195.9131,47.802,81.4935,2651,530,787,0,0,-942700.0,60.99
|
||||
1.3.6.1.4.1.9328.50.4.0005,R,PSO,2,412,30.0,44.41,0.6754,5.0,45.0,64.2517,57.7756,165.137,30.7208,76.554,5487,1646,2437,0,0,-2534400.0,67.84
|
||||
1.3.6.1.4.1.9328.50.4.0005,R,PSO,3,413,30.0,48.03,0.6245,5.0,45.0,62.0927,58.3166,163.1757,28.6541,79.3388,5517,1655,2650,0,0,-2550800.0,68.7
|
||||
1.3.6.1.4.1.9328.50.4.0006,L,PSO,1,501,19.99,36.76,0.5438,4.5,45.0,48.161,83.507,23.63,126.7736,80.398,2606,521,958,0,0,-926500.0,65.1
|
||||
1.3.6.1.4.1.9328.50.4.0006,L,PSO,2,502,19.99,21.15,0.9453,5.0,40.0,86.9474,89.9983,80.2344,143.267,75.3227,3891,778,823,0,0,-1355100.0,66.54
|
||||
1.3.6.1.4.1.9328.50.4.0006,L,PSO,3,503,87.62,91.7,0.9555,4.5,40.0,37.0386,49.0451,78.8935,100.2066,81.9978,5095,4464,4672,0,0,-7843815.3091,70.14
|
||||
1.3.6.1.4.1.9328.50.4.0006,R,PSO,1,511,84.55,93.9,0.9004,5.0,45.0,36.9954,47.6081,125.6671,67.2978,79.8972,7080,5986,6648,0,0,-9821348.0226,70.52
|
||||
1.3.6.1.4.1.9328.50.4.0006,R,PSO,2,512,85.56,97.02,0.8819,4.5,45.0,35.0025,48.3838,132.0857,71.1536,75.9332,5736,4908,5565,0,0,-8354164.8536,70.63
|
||||
1.3.6.1.4.1.9328.50.4.0006,R,PSO,3,513,83.52,95.45,0.8751,5.0,45.0,35.3546,48.4405,132.2396,71.1812,76.3927,7077,5911,6755,0,0,-9646724.0921,70.72
|
||||
1.3.6.1.4.1.9328.50.4.0007,L,PSO,1,601,70.97,95.87,0.7403,4.5,40.0,36.9433,55.6157,64.2063,115.1423,79.1659,5085,3609,4875,0,0,-3608173.4513,59.55
|
||||
1.3.6.1.4.1.9328.50.4.0007,L,PSO,2,602,70.98,96.63,0.7346,4.5,45.0,35.6859,52.2457,63.4741,112.467,78.4136,5720,4060,5527,0,0,-4897979.021,60.47
|
||||
1.3.6.1.4.1.9328.50.4.0007,L,PSO,3,603,20.0,21.66,0.9231,4.5,45.0,31.1183,86.2203,55.6987,146.9233,81.005,4321,864,936,0,0,-1392900.0,58.01
|
||||
1.3.6.1.4.1.9328.50.4.0007,R,PSO,1,611,72.02,81.92,0.8791,5.0,40.0,32.9074,50.676,127.462,82.3338,79.7994,6290,4530,5153,0,0,-5524019.0779,60.6
|
||||
1.3.6.1.4.1.9328.50.4.0007,R,PSO,2,612,79.91,85.73,0.9322,4.5,40.0,31.1365,41.2923,111.6,79.8617,72.6211,5107,4081,4378,0,0,-6314109.9276,58.61
|
||||
1.3.6.1.4.1.9328.50.4.0007,R,PSO,3,613,72.34,78.22,0.9248,5.0,35.0,33.3902,49.9694,124.8573,82.6153,80.3251,5509,3985,4309,0,0,-5853136.1772,59.22
|
||||
1.3.6.1.4.1.9328.50.4.0008,L,PSO,1,701,72.5,79.06,0.9171,4.5,40.0,96.018,63.7441,83.1531,103.4561,102.5322,5095,3694,4028,0,0,-4761702.4534,90.81
|
||||
1.3.6.1.4.1.9328.50.4.0008,L,PSO,2,702,10.0,34.26,0.2918,4.5,40.0,91.3609,88.0287,26.9624,126.4576,100.0414,3001,300,1028,0,0,-670700.0,82.41
|
||||
1.3.6.1.4.1.9328.50.4.0008,L,PSO,3,703,76.63,80.29,0.9545,4.5,40.0,92.4717,44.9791,85.0,101.9128,91.5303,5088,3899,4085,0,0,-6386431.2893,88.53
|
||||
1.3.6.1.4.1.9328.50.4.0008,R,PSO,1,711,19.96,41.5,0.481,4.5,45.0,100.5249,98.418,190.6688,59.3931,102.7534,3046,608,1264,0,0,-1082600.0,85.95
|
||||
1.3.6.1.4.1.9328.50.4.0008,R,PSO,2,712,10.0,23.19,0.4312,4.5,50.0,100.0967,77.2015,178.2021,59.5183,102.775,4671,467,1083,0,0,-1063800.0,90.26
|
||||
1.3.6.1.4.1.9328.50.4.0008,R,PSO,3,713,57.92,61.87,0.9362,4.5,40.0,100.8996,81.4359,134.1989,79.9791,102.6145,5093,2950,3151,0,0,-3357522.6389,93.01
|
||||
1.3.6.1.4.1.9328.50.4.0009,L,PSO,1,801,72.55,91.83,0.7901,5.0,45.0,32.8651,47.155,58.961,108.0286,72.0106,7035,5104,6460,0,0,-4581751.5281,57.55
|
||||
1.3.6.1.4.1.9328.50.4.0009,L,PSO,2,802,92.17,97.94,0.9412,4.5,40.0,24.7077,39.9498,69.0,117.5817,64.9292,5137,4735,5031,0,0,-8810574.4209,54.11
|
||||
1.3.6.1.4.1.9328.50.4.0009,L,PSO,3,803,85.67,94.4,0.9075,4.5,45.0,24.2003,37.5681,68.921,116.1667,66.0548,5737,4915,5416,0,0,-7685071.954,55.68
|
||||
1.3.6.1.4.1.9328.50.4.0009,R,PSO,1,811,95.17,97.9,0.9721,5.0,40.0,25.8881,43.4766,111.5397,78.9595,62.6884,6296,5992,6164,0,0,-11591971.5375,55.69
|
||||
1.3.6.1.4.1.9328.50.4.0009,R,PSO,2,812,89.54,94.23,0.9503,5.0,45.0,25.4625,39.9189,110.1685,76.9905,65.9668,7086,6345,6677,0,0,-11054742.7604,56.89
|
||||
1.3.6.1.4.1.9328.50.4.0009,R,PSO,3,813,90.49,93.69,0.9658,5.0,40.0,29.9746,43.9718,113.0538,79.0801,69.1544,6309,5709,5911,0,0,-10428489.7765,112.27
|
||||
1.3.6.1.4.1.9328.50.4.0010,L,PSO,1,901,85.08,94.76,0.8978,4.0,35.0,63.2001,36.7243,79.0574,102.9994,85.2987,3533,3006,3348,0,0,-5263483.4984,222.6
|
||||
1.3.6.1.4.1.9328.50.4.0010,L,PSO,2,902,10.0,20.85,0.4795,4.0,45.0,88.7729,54.613,28.9483,117.0101,84.3316,3281,328,684,0,0,-744100.0,220.5
|
||||
1.3.6.1.4.1.9328.50.4.0010,L,PSO,3,903,19.99,43.84,0.456,4.5,40.0,90.5944,80.2496,21.8845,130.5236,83.6175,2231,446,978,0,0,-792500.0,198.56
|
||||
1.3.6.1.4.1.9328.50.4.0010,R,PSO,1,911,85.57,95.5,0.896,4.5,45.0,62.0497,36.0756,127.2002,81.6334,82.4432,5738,4910,5480,0,0,-8410169.885,238.93
|
||||
1.3.6.1.4.1.9328.50.4.0010,R,PSO,2,912,19.99,36.71,0.5447,5.0,50.0,56.5369,70.4833,183.4086,53.8323,82.6558,3781,756,1388,0,0,-1350500.0,249.84
|
||||
1.3.6.1.4.1.9328.50.4.0010,R,PSO,3,913,19.99,34.85,0.5738,5.0,40.0,54.9824,66.8207,180.2205,49.5794,81.3042,3851,770,1342,0,0,-1370700.0,201.47
|
||||
|
3
results/results_RandomSearch.csv
Normal file
3
results/results_RandomSearch.csv
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
case,side,method,repeat,seed,cortical_ratio,vertebral_ratio,cb_ratio,diameter,length,z,y,x,azimuth,altitude,loss,time_sec
|
||||
1.3.6.1.4.1.9328.50.4.0001,L,RandomSearch,1,1,99.97,26.92,3.7139,5.0,50,42.4153,32.9548,61.7995,109.8793,65.1445,-15748374.545,45.15
|
||||
1.3.6.1.4.1.9328.50.4.0001,L,RandomSearch,2,2,99.95,22.52,4.4387,5.0,50,17.1245,72.8784,61.9283,128.65,63.801,-15740749.0835,44.92
|
||||
|
146
run_all_cases.py
Normal file
146
run_all_cases.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
import os
|
||||
import traceback
|
||||
import argparse
|
||||
import torch
|
||||
import SimpleITK as sitk
|
||||
from datetime import datetime
|
||||
# Core
|
||||
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch
|
||||
from core.objective import set_global_context
|
||||
from core.cylinder import create_coordinate_grid
|
||||
|
||||
# Imaging
|
||||
from imaging.preprocessing import process_single_image, process_dataset
|
||||
|
||||
# Visualization
|
||||
from visualization.res_cyl_to_CT import cyl_and_CT
|
||||
from visualization.res_cyl_nifti import save_cyl
|
||||
|
||||
# Config
|
||||
from config.device import get_device
|
||||
from config.constant import *
|
||||
|
||||
def build_paths(base_dir: str, patient_id: str, level: str):
    """Return the (cortical, binary, ROI) NIfTI paths for one patient level.

    Args:
        base_dir: Root folder holding one sub-folder per patient.
        patient_id: Name of the patient's sub-folder.
        level: Vertebral level prefix, e.g. "L5".

    Returns:
        Tuple of three paths: ``<level>_cortical.nii.gz``,
        ``<level>_binary2.nii.gz`` and ``<level>_roi2.nii.gz``.
    """
    case_dir = os.path.join(base_dir, patient_id)
    suffixes = ("cortical", "binary2", "roi2")
    cortical, binary, roi = (
        os.path.join(case_dir, f"{level}_{s}.nii.gz") for s in suffixes
    )
    return cortical, binary, roi
|
||||
|
||||
if __name__ == "__main__":
    # === 1. Command-line arguments (multi-GPU sharding setup) ===
    parser = argparse.ArgumentParser(description="Multi-GPU CBT Batch Processing")
    parser.add_argument('--gpu', type=int, required=True, help='指定這支程式要用的 GPU ID (0, 1, 2, 3)')
    parser.add_argument('--total_gpus', type=int, default=4, help='總共開啟的 GPU 數量')
    args = parser.parse_args()

    # === 2. Bind this process to its GPU (fall back to CPU when CUDA is absent) ===
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')
        print("⚠️ 找不到 CUDA,將使用 CPU")

    print(f"\n🚀 啟動批次任務 | 分配至 GPU: [{args.gpu}] | 總共 {args.total_gpus} 個節點協同運算")

    # === 3. Basic parameters ===
    base_dir = "/home/cyrou/CBT/Seg/Resample/standardized/"
    spacing = [0.5, 0.5, 0.5]
    levels = ["L5"]
    date = datetime.now().strftime("%Y%m%d")

    failed = []
    skipped = []
    succeeded = []

    # Discover every patient folder dynamically.
    all_patients = sorted([d for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d))])

    # === 4. Modulo sharding: keep only the patients assigned to this GPU ===
    my_patients = [p for i, p in enumerate(all_patients) if i % args.total_gpus == args.gpu]

    print(f"📂 總資料庫有 {len(all_patients)} 個病人。")
    print(f"🎯 本 GPU (ID: {args.gpu}) 被分配到 {len(my_patients)} 個病人,準備開始處理...")

    for patient_id in my_patients:
        for level in levels:
            label_str = f"{patient_id}_{level}"
            IMAGE1, IMAGE2, IMAGE3 = build_paths(base_dir, patient_id, level)

            # Skip any case with missing input volumes instead of crashing.
            missing = [p for p in [IMAGE1, IMAGE2, IMAGE3] if not os.path.exists(p)]
            if missing:
                print(f"[SKIP] ⏭️ {label_str} (缺檔)")
                skipped.append((label_str, missing))
                continue

            print(f"\n=========================================")
            print(f"🔥 [GPU {args.gpu}] === Running {label_str} ===")
            print(f"=========================================")

            try:
                # --- CBT = True ---
                print(f"👉 [GPU {args.gpu}] 執行 CBT 軌跡最佳化...")
                binary_image = sitk.ReadImage(IMAGE2)
                binary_array = sitk.GetArrayFromImage(binary_image)
                image_shape = binary_array.shape

                # Precompute the coordinate grid once per case (already on `device`).
                grid = create_coordinate_grid(image_shape, device)
                best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_pso_torch(
                    label_str,
                    image1_path=IMAGE1,
                    image2_path=IMAGE2,
                    image3_path=IMAGE3,
                    folder=date,
                    swarm_size=70,
                    max_iter=100,
                    spacing=spacing,
                    CBT=True,
                    device=device,
                    optimize_size=True,
                    grid=grid  # pass the precomputed grid through to the optimizer
                )

                cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base='Output', CBT=True)
                cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6], IMAGE3, cylinder_L, cylinder_R, base_folder='Output', CBT=True)

                # FIX: record the successful case. Previously the only
                # `succeeded.append(...)` lived in the commented-out TT branch
                # below, so the summary always reported 0 successes.
                succeeded.append((label_str, best_loss_l, best_loss_r, total_time))
                print(f"[OK] ✅ [GPU {args.gpu}] {label_str} 完成!總耗時={total_time:.2f}s")

                # --- CBT = False (traditional trajectory, currently disabled) ---
                # print(f"👉 [GPU {args.gpu}] 執行傳統軌跡 (TT) 最佳化...")
                # best_pos_l_tt, best_loss_l_tt, best_pos_r_tt, best_loss_r_tt, total_time_tt = run_pso_torch(
                #     label_str,
                #     image1_path=IMAGE2,
                #     image2_path=IMAGE2,
                #     image3_path=IMAGE3,
                #     folder=date,
                #     swarm_size=70,
                #     max_iter=100,
                #     spacing=spacing,
                #     CBT=False,
                #     device=device,
                #     optimize_size=True,
                #     grid=grid
                # )
                # cylinder_L_tt, cylinder_R_tt = save_cyl(best_pos_l_tt, best_pos_r_tt, spacing, IMAGE3, output_base='Output', CBT=False)
                # cyl_and_CT(best_pos_l_tt[5], best_pos_r_tt[5], best_pos_l_tt[6], best_pos_r_tt[6], IMAGE3, cylinder_L_tt, cylinder_R_tt, base_folder='Output', CBT=False)

            except Exception as e:
                print(f"[FAIL] ❌ [GPU {args.gpu}] {label_str} 發生錯誤: {e}")
                traceback.print_exc()
                failed.append((label_str, str(e)))

    # --- Per-GPU summary ---
    print("\n" + "="*40)
    print(f"🏆 ==== GPU {args.gpu} 處理總結 ====")
    print("="*40)
    print(f"✅ Succeeded (成功): {len(succeeded)} 節段")
    print(f"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段")
    print(f"❌ Failed (執行錯誤): {len(failed)} 節段")

    if failed:
        print(f"\n⚠️ GPU {args.gpu} 失敗清單:")
        for fail_label, err_msg in failed:
            print(f" - {fail_label}: {err_msg}")
|
||||
157
run_benchmark.py
Normal file
157
run_benchmark.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
import os
|
||||
import time
|
||||
import pandas as pd
|
||||
import torch
|
||||
import argparse
|
||||
import SimpleITK as sitk
|
||||
import traceback
|
||||
import core.objective
|
||||
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch
|
||||
from core.cylinder import generate_cylinder_n_torch, create_coordinate_grid
|
||||
|
||||
def main():
    """Benchmark one optimizer (PSO / DE / NM) over a fixed patient cohort.

    For each patient, the chosen optimizer is run ``runs_per_method`` times on
    the L5 level.  Cortical/bone overlap scores are recomputed here,
    independently of the optimizer internals, and each run's metrics are
    appended to ``optimization_benchmark_results_<METHOD>.csv``.
    """
    parser = argparse.ArgumentParser(description="CBT Optimization Benchmark")
    parser.add_argument('--method', type=str, required=True, choices=['PSO', 'DE', 'NM'])
    parser.add_argument('--gpu', type=int, default=0)
    args = parser.parse_args()

    # Bind to the requested GPU, or fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')

    print(f"\n🚀 啟動任務: 演算法 [{args.method}] | 運行於 [{device}]")

    base_dir = "/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0{}"
    patients = [121, 141, 151, 161, 171, 181, 191, 201, 211, 221]

    runs_per_method = 3

    CBT = True
    optimize_size = True
    spacing = [0.5, 0.5, 0.5]

    # Map the method name to its driver and (population size, iteration budget).
    # argparse `choices` guarantees exactly one branch is taken.
    if args.method == 'PSO':
        method_func, swarm, iters = run_pso_torch, 70, 100
    elif args.method == 'DE':
        method_func, swarm, iters = run_de_torch, 70, 100
    elif args.method == 'NM':
        method_func, swarm, iters = run_nm_torch, 70, 7000

    results_data = []

    for p_id in patients:
        folder_num = str(p_id)
        current_dir = base_dir.format(folder_num)

        cortical_path = os.path.join(current_dir, "L5_cortical.nii.gz")
        binary_path = os.path.join(current_dir, "L5_binary2.nii.gz")
        roi_path = os.path.join(current_dir, "L5_roi2.nii.gz")

        if not all(os.path.exists(p) for p in [cortical_path, binary_path, roi_path]):
            print(f"⚠️ 找不到病人 000{p_id} 的影像檔案,跳過此病人。")
            continue

        print(f"\n=========================================")
        print(f"📍 開始處理病人 ID: 000{p_id} (L5) - {args.method}")
        print(f"=========================================")

        # FIX: the volumes and their GPU tensors are invariant across the
        # repeated runs of one patient, so load and transfer them ONCE per
        # patient.  The original re-read both NIfTI files from disk and
        # re-uploaded them to the device inside every run (and read the
        # binary volume twice per patient).
        img_c = sitk.GetArrayFromImage(sitk.ReadImage(cortical_path))
        img_b = sitk.GetArrayFromImage(sitk.ReadImage(binary_path))
        image_shape = img_b.shape
        grid = create_coordinate_grid(image_shape, device)

        cortical_eval = torch.from_numpy(img_c).to(device=device, dtype=torch.uint8)
        spine_eval = torch.from_numpy(img_b).to(device=device, dtype=torch.uint8)

        for run_idx in range(1, runs_per_method + 1):
            label_str = f"Patient_000{p_id}_{args.method}_Run{run_idx}"
            print(f"\n---> 執行 {args.method} (第 {run_idx}/{runs_per_method} 次)")

            try:
                # 1. Run the optimizer (returns 5 values).
                best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = method_func(
                    label_str=label_str,
                    image1_path=cortical_path,
                    image2_path=binary_path,
                    image3_path=roi_path,
                    folder="Output",
                    swarm_size=swarm,
                    max_iter=iters,
                    spacing=spacing,
                    CBT=CBT,
                    device=device,
                    optimize_size=optimize_size,
                    grid=grid
                )

                # ===== 2. Re-score the result independently of the optimizer =====
                # Diameter / length come from the position vector when size was
                # optimized, otherwise the fixed defaults are used.
                d_l, l_l = (best_pos_l[5], best_pos_l[6]) if optimize_size else (4.5, 45)
                d_r, l_r = (best_pos_r[5], best_pos_r[6]) if optimize_size else (4.5, 45)

                # Regenerate the left / right cylinder masks.
                cyl_l = generate_cylinder_n_torch(
                    d_l, l_l, best_pos_l[0], best_pos_l[1], best_pos_l[2], best_pos_l[3], best_pos_l[4],
                    image_shape, spacing, device, grid=grid
                )
                cyl_r = generate_cylinder_n_torch(
                    d_r, l_r, best_pos_r[0], best_pos_r[1], best_pos_r[2], best_pos_r[3], best_pos_r[4],
                    image_shape, spacing, device, grid=grid
                )

                # Overlap scores: percentage of cylinder voxels inside the
                # cortical mask / the whole bone mask.
                cyl_points_l = torch.sum(cyl_l).item()
                cyl_points_r = torch.sum(cyl_r).item()

                overlap_c_l = ((cortical_eval == 1) & (cyl_l == 1)).sum().item()
                overlap_c_r = ((cortical_eval == 1) & (cyl_r == 1)).sum().item()
                overlap_b_l = ((spine_eval == 1) & (cyl_l == 1)).sum().item()
                overlap_b_r = ((spine_eval == 1) & (cyl_r == 1)).sum().item()

                score_c_l = (overlap_c_l / cyl_points_l * 100) if cyl_points_l > 0 else 0
                score_c_r = (overlap_c_r / cyl_points_r * 100) if cyl_points_r > 0 else 0
                score_b_l = (overlap_b_l / cyl_points_l * 100) if cyl_points_l > 0 else 0
                score_b_r = (overlap_b_r / cyl_points_r * 100) if cyl_points_r > 0 else 0

                cb_ratio_l = (score_c_l / score_b_l) if score_b_l > 0 else 0
                cb_ratio_r = (score_c_r / score_b_r) if score_b_r > 0 else 0
                # ===============================================================

                # 3. Record the run.
                run_result = {
                    "Patient_ID": f"000{p_id}",
                    "Method": args.method,
                    "Run": run_idx,
                    "Total_Time_sec": total_time,
                    "Left_Cortical_Score": score_c_l,
                    "Left_Bone_Score": score_b_l,
                    "Left_C_B_Ratio": cb_ratio_l,
                    "Right_Cortical_Score": score_c_r,
                    "Right_Bone_Score": score_b_r,
                    "Right_C_B_Ratio": cb_ratio_r,
                    "Left_Loss": best_loss_l,
                    "Right_Loss": best_loss_r
                }
                results_data.append(run_result)

                print(f"✅ 完成 | 耗時: {total_time:.2f}s | "
                      f"左C/B: {cb_ratio_l:.3f} | 右C/B: {cb_ratio_r:.3f}")

            except Exception as e:
                # Keep benchmarking the remaining runs; just log the failure.
                print(f"❌ {args.method} Run {run_idx} 發生錯誤: {e}")
                traceback.print_exc()

    if results_data:
        df = pd.DataFrame(results_data)
        csv_filename = f"optimization_benchmark_results_{args.method}.csv"
        df.to_csv(csv_filename, index=False)
        print(f"\n🎉 {args.method} 測試完成!結果已儲存至: {csv_filename}")
|
||||
|
||||
# Script entry point: run the benchmark only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
||||
109
run_de.py
Normal file
109
run_de.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
import os
import traceback
import argparse
import torch
import SimpleITK as sitk
from datetime import datetime

# Core
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch
from core.objective import set_global_context
from core.cylinder import create_coordinate_grid

# Imaging
from imaging.preprocessing import process_single_image, process_dataset

# Visualization
from visualization.res_cyl_to_CT import cyl_and_CT
from visualization.res_cyl_nifti import save_cyl

# Config
from config.device import get_device
from config.constant import *


if __name__ == "__main__":
    # === 1. Command-line arguments (per-GPU sharding) ===
    parser = argparse.ArgumentParser(description="DE CBT Processing")
    parser.add_argument('--gpu', type=int, default=0)
    args = parser.parse_args()

    # Pin the process to the requested CUDA device; fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')

    # Patient folders share a fixed DICOM-UID prefix with a numeric suffix.
    base_dir = "/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0{}"
    patients = [121, 141, 151, 161, 171, 181, 191, 201, 211, 221]

    CBT = True
    optimize_size = True
    spacing = [0.5, 0.5, 0.5]  # isotropic voxel spacing in mm
    levels = ["L5"]
    date = datetime.now().strftime("%Y%m%d")

    failed = []
    skipped = []
    succeeded = []

    for p_id in patients:
        folder_num = str(p_id)
        current_dir = base_dir.format(folder_num)

        IMAGE1 = os.path.join(current_dir, "L5_cortical.nii.gz")
        IMAGE2 = os.path.join(current_dir, "L5_binary2.nii.gz")
        IMAGE3 = os.path.join(current_dir, "L5_roi2.nii.gz")

        missing = [p for p in [IMAGE1, IMAGE2, IMAGE3] if not os.path.exists(p)]
        label_str = levels[0]
        if missing:
            print(f"[SKIP] ⏭️ {label_str} (缺檔)")
            skipped.append((label_str, missing))
            continue

        print(f"\n=========================================")
        print(f"🔥 [GPU {args.gpu}] === Running {label_str} ===")
        print(f"=========================================")

        try:
            # --- CBT = True ---
            print(f"👉 [GPU {args.gpu}] 執行 CBT 軌跡最佳化...")
            binary_image = sitk.ReadImage(IMAGE2)
            binary_array = sitk.GetArrayFromImage(binary_image)
            image_shape = binary_array.shape

            grid = create_coordinate_grid(image_shape, device)
            # BUG FIX: this DE driver previously called run_nm_torch('NM', ...)
            # (copy-paste leftover from the Nelder-Mead script); it now runs
            # the DE optimizer it is named after and imports.
            best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_de_torch(
                label_str,
                image1_path=IMAGE1,
                image2_path=IMAGE2,
                image3_path=IMAGE3,
                folder=date,
                swarm_size=70,
                max_iter=100,
                spacing=spacing,
                CBT=True,
                device=device,
                optimize_size=True,
                grid=grid  # grid is pre-built on the target device
            )

            cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base='Output', CBT=True)
            cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6], IMAGE3, cylinder_L, cylinder_R, base_folder='Output', CBT=True)
            # BUG FIX: record the success — previously `succeeded` was never
            # appended, so the summary below always reported 0 successes.
            succeeded.append(label_str)

        except Exception as e:
            print(f"[FAIL] ❌ [GPU {args.gpu}] {label_str} 發生錯誤: {e}")
            traceback.print_exc()
            failed.append((label_str, str(e)))

    # --- Per-GPU summary ---
    print("\n" + "="*40)
    print(f"🏆 ==== GPU {args.gpu} 處理總結 ====")
    print("="*40)
    print(f"✅ Succeeded (成功): {len(succeeded)} 節段")
    print(f"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段")
    print(f"❌ Failed (執行錯誤): {len(failed)} 節段")

    if failed:
        print(f"\n⚠️ GPU {args.gpu} 失敗清單:")
        for fail_label, err_msg in failed:
            print(f"   - {fail_label}: {err_msg}")
|
||||
109
run_pso.py
Normal file
109
run_pso.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
import os
import traceback
import argparse
import torch
import SimpleITK as sitk
from datetime import datetime

# Core
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch
from core.objective import set_global_context
from core.cylinder import create_coordinate_grid

# Imaging
from imaging.preprocessing import process_single_image, process_dataset

# Visualization
from visualization.res_cyl_to_CT import cyl_and_CT
from visualization.res_cyl_nifti import save_cyl

# Config
from config.device import get_device
from config.constant import *


if __name__ == "__main__":
    # === 1. Command-line arguments (per-GPU sharding) ===
    parser = argparse.ArgumentParser(description="PSO CBT Processing")
    parser.add_argument('--gpu', type=int, default=0)
    args = parser.parse_args()

    # Pin the process to the requested CUDA device; fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{args.gpu}')
        torch.cuda.set_device(device)
    else:
        device = torch.device('cpu')

    # Patient folders share a fixed DICOM-UID prefix with a numeric suffix.
    base_dir = "/home/cyrou/CBT/Seg/Resample/standardized/1.3.6.1.4.1.9328.50.4.0{}"
    patients = [121, 141, 151, 161, 171, 181, 191, 201, 211, 221]

    CBT = True
    optimize_size = True
    spacing = [0.5, 0.5, 0.5]  # isotropic voxel spacing in mm
    levels = ["L5"]
    date = datetime.now().strftime("%Y%m%d")

    failed = []
    skipped = []
    succeeded = []

    for p_id in patients:
        folder_num = str(p_id)
        current_dir = base_dir.format(folder_num)

        IMAGE1 = os.path.join(current_dir, "L5_cortical.nii.gz")
        IMAGE2 = os.path.join(current_dir, "L5_binary2.nii.gz")
        IMAGE3 = os.path.join(current_dir, "L5_roi2.nii.gz")

        missing = [p for p in [IMAGE1, IMAGE2, IMAGE3] if not os.path.exists(p)]
        # BUG FIX: label_str must be defined before the skip branch uses it.
        # It was previously assigned only after `continue` (NameError on the
        # first missing-file patient) and to the whole `levels` list instead
        # of the level string — matches run_de.py's convention now.
        label_str = levels[0]
        if missing:
            print(f"[SKIP] ⏭️ {label_str} (缺檔)")
            skipped.append((label_str, missing))
            continue

        print(f"\n=========================================")
        print(f"🔥 [GPU {args.gpu}] === Running {label_str} ===")
        print(f"=========================================")

        try:
            # --- CBT = True ---
            print(f"👉 [GPU {args.gpu}] 執行 CBT 軌跡最佳化...")
            binary_image = sitk.ReadImage(IMAGE2)
            binary_array = sitk.GetArrayFromImage(binary_image)
            image_shape = binary_array.shape

            grid = create_coordinate_grid(image_shape, device)
            best_pos_l, best_loss_l, best_pos_r, best_loss_r, total_time = run_pso_torch(
                label_str,
                image1_path=IMAGE1,
                image2_path=IMAGE2,
                image3_path=IMAGE3,
                folder=date,
                swarm_size=70,
                max_iter=100,
                spacing=spacing,
                CBT=True,
                device=device,
                optimize_size=True,
                grid=grid  # grid is pre-built on the target device
            )

            cylinder_L, cylinder_R = save_cyl(best_pos_l, best_pos_r, spacing, IMAGE3, output_base='Output', CBT=True)
            cyl_and_CT(best_pos_l[5], best_pos_r[5], best_pos_l[6], best_pos_r[6], IMAGE3, cylinder_L, cylinder_R, base_folder='Output', CBT=True)
            # BUG FIX: record the success — previously `succeeded` was never
            # appended, so the summary below always reported 0 successes.
            succeeded.append(label_str)

        except Exception as e:
            print(f"[FAIL] ❌ [GPU {args.gpu}] {label_str} 發生錯誤: {e}")
            traceback.print_exc()
            failed.append((label_str, str(e)))

    # --- Per-GPU summary ---
    print("\n" + "="*40)
    print(f"🏆 ==== GPU {args.gpu} 處理總結 ====")
    print("="*40)
    print(f"✅ Succeeded (成功): {len(succeeded)} 節段")
    print(f"⏭️ Skipped (缺檔跳過): {len(skipped)} 節段")
    print(f"❌ Failed (執行錯誤): {len(failed)} 節段")

    if failed:
        print(f"\n⚠️ GPU {args.gpu} 失敗清單:")
        for fail_label, err_msg in failed:
            print(f"   - {fail_label}: {err_msg}")
|
||||
49
structure.txt
Normal file
49
structure.txt
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
CBT_project/
|
||||
│
|
||||
├── notebooks/
|
||||
│ └── experiment.ipynb
|
||||
│
|
||||
├── core/
|
||||
│ ├── cylinder.py
|
||||
│ │ ├── create_coordinate_grid
|
||||
│ │ ├── snap_to_discrete_values
|
||||
│ │ ├── generate_cylinder_n_torch
|
||||
│ │ ├── generate_cylinder_o_torch
|
||||
│ │ └── generate_cylinder_numpy
|
||||
│ │
|
||||
│ ├── intersection.py
|
||||
│ │ ├── bresenham3d
|
||||
│ │ └── center_line_intersections_torch
|
||||
│ │
|
||||
│ ├── scoring.py
|
||||
│ │ ├── cl_score_torch
|
||||
│ │ ├── get_overlap_ratio
|
||||
│ │ └── compute_overlap_ratio_from_cylinder_mask
|
||||
│ │
|
||||
│ ├── objective.py
|
||||
│ │ ├── cylinder_circle_line_intersection_loss_deductions_torch
|
||||
│ │ └── objective_function
|
||||
│ │
|
||||
│ └── optimizer.py
|
||||
│ └── run_pso_torch
|
||||
│
|
||||
├── imaging/
|
||||
│ ├── nifti_io.py
|
||||
│ ├── preprocessing.py
|
||||
|
||||
│ └── orientation.py
|
||||
│
|
||||
├── visualization/
|
||||
│ ├── plot_cylinder.py
|
||||
│ └── plot_overlay.py
|
||||
│
|
||||
├── utils/
|
||||
│ └── helpers.py
|
||||
│
|
||||
├── config/
|
||||
│ ├── constants.py
|
||||
│ └── device.py
|
||||
│
|
||||
├── run_optimization.py
|
||||
│
|
||||
└── __init__.py
|
||||
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
52
utils/helpers.py
Normal file
52
utils/helpers.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def get_unique_filepath(path: str) -> str:
    """Return *path* if it is unused; otherwise append _1, _2, ... before the
    extension until an unused name is found, and return that."""
    stem, ext = os.path.splitext(path)
    candidate = path
    suffix = 1
    while os.path.exists(candidate):
        candidate = f"{stem}_{suffix}{ext}"
        suffix += 1
    return candidate
|
||||
|
||||
def save_with_unique_name(folder, label_str, way, diameter_l, length_l, diameter_r, length_r, swarm_size, max_iter):
    """Build an unused .png path under *folder* that encodes the run
    parameters; on collision append _1, _2, ... before the extension."""
    stem = f'{label_str}_{way}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_{swarm_size}_{max_iter}'
    candidate = os.path.join(folder, stem + '.png')
    attempt = 1
    while os.path.exists(candidate):
        candidate = os.path.join(folder, f"{stem}_{attempt}.png")
        attempt += 1
    return candidate
|
||||
|
||||
def pad_rgb_to_shape(rgb, target_hw, pad_value=0.0):
    """Pad an RGB image (H, W, 3) to target (H, W), centered.

    Parameters
    ----------
    rgb : np.ndarray
        Image of shape (h, w, 3).
    target_hw : tuple[int, int]
        Target (H, W); must be at least as large as (h, w).
    pad_value : float
        Constant fill value for the border.

    Raises
    ------
    ValueError
        If the image is not 3-channel or is larger than the target.
    """
    h, w, c = rgb.shape
    H, W = target_hw
    # BUG FIX: was `assert c == 3` — asserts are stripped under `python -O`,
    # so validate explicitly.
    if c != 3:
        raise ValueError(f"expected a 3-channel image, got {c} channels")

    if h > H or w > W:
        raise ValueError(f"target {target_hw} smaller than rgb {(h,w)}")

    # Split any odd remainder so the image stays centered.
    pad_h1 = (H - h) // 2
    pad_h2 = H - h - pad_h1
    pad_w1 = (W - w) // 2
    pad_w2 = W - w - pad_w1

    return np.pad(
        rgb,
        ((pad_h1, pad_h2), (pad_w1, pad_w2), (0, 0)),
        mode="constant",
        constant_values=pad_value
    )
|
||||
0
visualization/__init__.py
Normal file
0
visualization/__init__.py
Normal file
83
visualization/res_cyl_nifti.py
Normal file
83
visualization/res_cyl_nifti.py
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
import SimpleITK as sitk
|
||||
import os
|
||||
from datetime import datetime
|
||||
from core.cylinder import generate_cylinder_numpy
|
||||
|
||||
def save_cyl(best_position_l, best_position_r, spacing, IMAGE3, output_base, CBT=True):
    """Rasterize the optimized left/right cylinders and write them as NIfTI masks.

    Parameters
    ----------
    best_position_l, best_position_r : sequence
        Per-side optimizer result; indices 0-4 are the pose parameters fed to
        generate_cylinder_numpy, index 5 the diameter, index 6 the length.
    spacing : list[float]
        Voxel spacing (mm) used when rasterizing.
    IMAGE3 : str
        ROI image path; supplies shape and geometry (image2 would give the same).
    output_base : str
        Root output folder; files land in <output_base>/<date>/<patient_id>/.
    CBT : bool
        Selects the CBT/TPS tag in the output filenames.

    Returns
    -------
    tuple[str, str]
        Paths of the written left and right cylinder masks.
    """
    # Reading image2 or image3 yields the same shape and geometry.
    image3 = sitk.ReadImage(IMAGE3)
    image3_array = sitk.GetArrayFromImage(image3)
    image3_shape = image3_array.shape
    diameter_l = best_position_l[5]
    diameter_r = best_position_r[5]
    length_l = best_position_l[6]
    length_r = best_position_r[6]

    cyl_l = generate_cylinder_numpy(diameter_l,
                                    length_l,
                                    best_position_l[0],
                                    best_position_l[1],
                                    best_position_l[2],
                                    best_position_l[3],
                                    best_position_l[4],
                                    image3_shape,
                                    spacing)

    cyl_r = generate_cylinder_numpy(diameter_r,
                                    length_r,
                                    best_position_r[0],
                                    best_position_r[1],
                                    best_position_r[2],
                                    best_position_r[3],
                                    best_position_r[4],
                                    image3_shape,
                                    spacing)

    cyl_L = sitk.GetImageFromArray(cyl_l)
    cyl_R = sitk.GetImageFromArray(cyl_r)

    # BUG FIX: copy the full geometry (origin + spacing + direction) from the
    # reference image instead of only the origin; previously the masks were
    # written with the default 1 mm spacing and did not overlay the CT
    # correctly in a viewer. The arrays share image3's shape, so
    # CopyInformation is valid here.
    cyl_L.CopyInformation(image3)
    cyl_R.CopyInformation(image3)

    image_path = IMAGE3
    date_str = datetime.now().strftime("%Y%m%d")
    # Folder layout assumption: .../<patient_id>/<level>_<suffix>.nii.gz
    patient_id = os.path.basename(os.path.dirname(image_path))
    file_name = os.path.basename(image_path)
    level = file_name.split('_')[0]
    output_folder = os.path.join(output_base, date_str, patient_id)

    os.makedirs(output_folder, exist_ok=True)

    def get_unique_filename(path):
        # Append _1, _2, ... before the extension until the name is unused.
        if not os.path.exists(path):
            return path

        # .nii.gz has a double extension; splitext alone would cut at .gz.
        if path.endswith(".nii.gz"):
            base = path[:-7]  # strip .nii.gz
            ext = ".nii.gz"
        else:
            base, ext = os.path.splitext(path)

        i = 1
        new_path = f"{base}_{i}{ext}"
        while os.path.exists(new_path):
            i += 1
            new_path = f"{base}_{i}{ext}"
        return new_path

    if CBT:
        output_path_l = get_unique_filename(f'{output_folder}/{level}_{diameter_l}_{length_l}_CBT_L.nii.gz')
        output_path_r = get_unique_filename(f'{output_folder}/{level}_{diameter_r}_{length_r}_CBT_R.nii.gz')
    else:
        output_path_l = get_unique_filename(f'{output_folder}/{level}_{diameter_l}_{length_l}_TPS_L.nii.gz')
        output_path_r = get_unique_filename(f'{output_folder}/{level}_{diameter_r}_{length_r}_TPS_R.nii.gz')

    sitk.WriteImage(cyl_L, output_path_l)
    sitk.WriteImage(cyl_R, output_path_r)

    return output_path_l, output_path_r
|
||||
127
visualization/res_cyl_to_CT.py
Normal file
127
visualization/res_cyl_to_CT.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
from utils.helpers import get_unique_filepath
|
||||
import nibabel as nib
|
||||
from datetime import datetime
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from utils.helpers import pad_rgb_to_shape
|
||||
|
||||
def cyl_and_CT(diameter_l, diameter_r, length_l, length_r, roi_image_path, cylinder_1, cylinder_2, base_folder='Output', CBT=True):
    """Overlay both cylinder masks on maximum-intensity projections of the CT
    ROI and save PNG views.

    Writes a combined three-view figure plus individual axial/coronal/sagittal
    images under <base_folder>/<date>/<patient_id>/, never overwriting
    existing files (get_unique_filepath appends _1, _2, ...).

    Parameters
    ----------
    diameter_l, diameter_r, length_l, length_r : float
        Screw dimensions, used only in the output filenames.
    roi_image_path : str
        CT ROI NIfTI; its shape must match the cylinder masks.
    cylinder_1, cylinder_2 : str
        Paths of the left/right cylinder mask NIfTIs.
    base_folder : str
        Root output folder.
    CBT : bool
        Selects the CBT/TPS tag in the output filenames.
    """
    bone_nifti = nib.load(roi_image_path)
    bone_array = bone_nifti.get_fdata()

    cylinder_nifti = nib.load(cylinder_1)
    cylinder_nifti2 = nib.load(cylinder_2)
    cylinder_array = cylinder_nifti.get_fdata()
    cylinder_array2 = cylinder_nifti2.get_fdata()

    assert bone_array.shape == cylinder_array.shape, "影像 shape 不匹配,請先進行重採樣!"

    date = datetime.now().strftime("%Y%m%d")
    # Folder layout assumption: .../<patient_id>/<level>_<suffix>.nii.gz
    patient_id = os.path.basename(os.path.dirname(roi_image_path))
    file_name = os.path.basename(roi_image_path)
    level = file_name.split('_')[0]

    output_folder = os.path.join(base_folder, date, patient_id)
    # BUG FIX: create the destination folder. Previously this function relied
    # on save_cyl() having created it in the same run and crashed with
    # FileNotFoundError when called standalone.
    os.makedirs(output_folder, exist_ok=True)

    if CBT:
        output_file = f'{output_folder}/{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_CBT_whole.png'
        output_axial = f'{output_folder}/{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_CBT_axial.png'
        output_coronal = f'{output_folder}/{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_CBT_coronal.png'
        output_saggital = f'{output_folder}/{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_CBT_saggital.png'
    else:
        output_file = f'{output_folder}/{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_TPS_whole.png'
        output_axial = f'{output_folder}/{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_TPS_axial.png'
        output_coronal = f'{output_folder}/{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_TPS_coronal.png'
        output_saggital = f'{output_folder}/{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_TPS_saggital.png'

    # Maximum-intensity projections along each array axis.
    mip_bone_z = np.max(bone_array, axis=0)
    mip_cylinder_z = np.max(cylinder_array, axis=0)
    mip_cylinder2_z = np.max(cylinder_array2, axis=0)

    mip_bone_x = np.max(bone_array, axis=1)
    mip_cylinder_x = np.max(cylinder_array, axis=1)
    mip_cylinder2_x = np.max(cylinder_array2, axis=1)

    mip_bone_y = np.max(bone_array, axis=2)
    mip_cylinder_y = np.max(cylinder_array, axis=2)
    mip_cylinder2_y = np.max(cylinder_array2, axis=2)

    bone_norm_z = mip_bone_z / np.max(mip_bone_z)
    bone_norm_x = mip_bone_x / np.max(mip_bone_x)
    bone_norm_y = mip_bone_y / np.max(mip_bone_y)

    cylinder_mask_z = mip_cylinder_z > 0
    cylinder_mask_x = mip_cylinder_x > 0
    cylinder_mask_y = mip_cylinder_y > 0
    cylinder_mask2_z = mip_cylinder2_z > 0
    cylinder_mask2_x = mip_cylinder2_x > 0
    cylinder_mask2_y = mip_cylinder2_y > 0

    def create_rgb_image(bone_norm, cylinder_mask, cylinder_mask2):
        # Grayscale bone with both cylinder silhouettes burned in as white.
        bone_norm = (bone_norm - np.min(bone_norm)) / (np.max(bone_norm) - np.min(bone_norm))
        bone_norm = np.clip(bone_norm, 0, 1)

        rgb_image = np.zeros((bone_norm.shape[0], bone_norm.shape[1], 3))
        rgb_image[..., 0] = bone_norm
        rgb_image[..., 1] = bone_norm
        rgb_image[..., 2] = bone_norm
        rgb_image[cylinder_mask, :] = 1
        rgb_image[cylinder_mask2, :] = 1

        return rgb_image

    def crop_center(img, cropx, cropy):
        # Center-crop helper (currently unused here; kept for compatibility).
        y, x = img.shape[0], img.shape[1]
        startx = x//2 - cropx//2
        starty = y//2 - cropy//2
        return img[starty:starty+cropy, startx:startx+cropx, :]

    rgb_image_z = np.rot90(create_rgb_image(bone_norm_z, cylinder_mask_z, cylinder_mask2_z))
    rgb_image_x = np.rot90(create_rgb_image(bone_norm_x, cylinder_mask_x, cylinder_mask2_x))
    rgb_image_y = np.rot90(create_rgb_image(bone_norm_y, cylinder_mask_y, cylinder_mask2_y))

    # Pad the three views to a common canvas so they tile cleanly.
    H = max(rgb_image_y.shape[0], rgb_image_x.shape[0], rgb_image_z.shape[0])
    W = max(rgb_image_y.shape[1], rgb_image_x.shape[1], rgb_image_z.shape[1])

    rgb_image_y = pad_rgb_to_shape(rgb_image_y, (H, W), pad_value=0.0)
    rgb_image_x = pad_rgb_to_shape(rgb_image_x, (H, W), pad_value=0.0)
    rgb_image_z = pad_rgb_to_shape(rgb_image_z, (H, W), pad_value=0.0)

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].imshow(rgb_image_y)
    axs[0].set_title("Axial", pad=10)
    axs[0].axis("off")
    axs[0].set_aspect("equal")

    axs[1].imshow(rgb_image_x)
    axs[1].set_title("Coronal", pad=10)
    axs[1].axis("off")
    axs[1].set_aspect("equal")

    axs[2].imshow(rgb_image_z)
    axs[2].set_title("Saggital", pad=10)
    axs[2].axis("off")
    axs[2].set_aspect("equal")

    plt.subplots_adjust(left=0.05, right=0.75, top=0.9, bottom=0.1, wspace=0.01)
    plt.savefig(get_unique_filepath(output_file), bbox_inches='tight')
    # BUG FIX: close each figure after saving; previously four figures leaked
    # per call, which exhausts memory over a multi-patient batch.
    plt.close(fig)

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(rgb_image_z)
    ax.set_title("Saggital", pad=10)
    ax.axis("off")
    plt.savefig(get_unique_filepath(output_saggital), bbox_inches='tight')
    plt.close(fig)

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(rgb_image_x)
    ax.set_title("Coronal", pad=10)
    ax.axis("off")
    plt.savefig(get_unique_filepath(output_coronal), bbox_inches='tight')
    plt.close(fig)

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(rgb_image_y)
    ax.set_title("Axial", pad=10)
    ax.axis("off")
    plt.savefig(get_unique_filepath(output_axial), bbox_inches='tight')
    plt.close(fig)
|
||||
366
visualization/res_plot_3d.py
Normal file
366
visualization/res_plot_3d.py
Normal file
|
|
@ -0,0 +1,366 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
from datetime import datetime
|
||||
import csv
|
||||
|
||||
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_o_torch, snap_to_discrete_values
|
||||
from core.intersection import center_line_intersections_torch
|
||||
from core.scoring import cl_score_torch, compute_overlap_ratio_from_cylinder_mask
|
||||
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
|
||||
from utils.helpers import save_with_unique_name
|
||||
|
||||
def res_plt_2_torch(
    spine_tensor: torch.Tensor,
    cortical_tensor: torch.Tensor,
    image_shape: tuple[int, int, int],
    image2_path: str,
    base_folder: str,
    label_str: str,
    diameter_l: float,
    length_l: float,
    diameter_r: float,
    length_r: float,
    best_position_l: list[float],
    best_position_r: list[float],
    swarm_size: int,
    max_iter: int,
    total_time: float,
    spacing: list[float],
    CBT: bool,
    device: torch.device,
    grid=None
) -> None:
    """
    Render 3D scatter views of the optimized screw cylinders and log metrics.

    Uses torch-based cylinder generation, then moves the data to CPU for the
    matplotlib 3D scatter. Also appends one CSV row per side to output.csv and
    saves a 4-viewpoint figure under <base_folder>/<date>/<patient_id>/.
    """
    # Inner and outer cylinder masks for each side (n = inner, o = outer).
    cyl_l = generate_cylinder_n_torch(
        diameter_l,
        length_l,
        best_position_l[0],
        best_position_l[1],
        best_position_l[2],
        best_position_l[3],
        best_position_l[4],
        image_shape,
        spacing,
        device,
        grid
    )

    cyl_lo = generate_cylinder_o_torch(
        diameter_l,
        length_l,
        best_position_l[0],
        best_position_l[1],
        best_position_l[2],
        best_position_l[3],
        best_position_l[4],
        image_shape,
        spacing,
        device,
        grid
    )
    cyl_r = generate_cylinder_n_torch(
        diameter_r,
        length_r,
        best_position_r[0],
        best_position_r[1],
        best_position_r[2],
        best_position_r[3],
        best_position_r[4],
        image_shape,
        spacing,
        device,
        grid
    )
    cyl_ro = generate_cylinder_o_torch(
        diameter_r,
        length_r,
        best_position_r[0],
        best_position_r[1],
        best_position_r[2],
        best_position_r[3],
        best_position_r[4],
        image_shape,
        spacing,
        device,
        grid
    )

    # Center-line intersection counts feed both the loss and the CSV/figure.
    intersections_l, line_mask_l = center_line_intersections_torch(
        best_position_l[0],
        best_position_l[1],
        best_position_l[2],
        best_position_l[3],
        best_position_l[4],
        int(length_l),
        spine_tensor,
        spacing,
        device
    )
    loss_l = cl_score_torch(cortical_tensor, spine_tensor, cyl_l, cyl_lo, intersections_l)

    intersections_r, line_mask_r = center_line_intersections_torch(
        best_position_r[0],
        best_position_r[1],
        best_position_r[2],
        best_position_r[3],
        best_position_r[4],
        int(length_r),
        spine_tensor,
        spacing,
        device
    )
    loss_r = cl_score_torch(cortical_tensor, spine_tensor, cyl_r, cyl_ro, intersections_r)

    # Vertebral reference angles, used to report surgeon-relative angles below.
    azi = azimuth_rotation(image2_path)
    res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
    alt = res['superior']['tilt_angle_deg']

    # Move data to CPU for plotting
    line_mask_l_cpu = line_mask_l.cpu().numpy()
    line_mask_r_cpu = line_mask_r.cpu().numpy()
    cyl_l_cpu = cyl_l.cpu().numpy()
    cyl_lo_cpu = cyl_lo.cpu().numpy()
    cyl_r_cpu = cyl_r.cpu().numpy()
    cyl_ro_cpu = cyl_ro.cpu().numpy()
    spine_cpu = spine_tensor.cpu().numpy()

    z_lin1, y_lin1, x_lin1 = np.where(line_mask_l_cpu == 1)
    z_lin2, y_lin2, x_lin2 = np.where(line_mask_r_cpu == 1)

    z_cyl_l1, y_cyl_l1, x_cyl_l1 = np.where(cyl_l_cpu == 1)
    z_cyl_l2, y_cyl_l2, x_cyl_l2 = np.where(cyl_lo_cpu == 1)
    z_cyl_r1, y_cyl_r1, x_cyl_r1 = np.where(cyl_r_cpu == 1)
    z_cyl_r2, y_cyl_r2, x_cyl_r2 = np.where(cyl_ro_cpu == 1)

    z_img, y_img, x_img = np.where(spine_cpu == 1)

    fig = plt.figure(figsize=(12, 12))

    # Four viewpoints of the same scene: default, top (axial), side, front.
    ax1 = fig.add_subplot(221, projection='3d')
    ax1.scatter(x_lin1, y_lin1, z_lin1, c='r', marker='o', s=1)
    ax1.scatter(x_lin2, y_lin2, z_lin2, c='r', marker='o', s=1)
    ax1.scatter(x_cyl_l1, y_cyl_l1, z_cyl_l1, c='darkcyan', marker='o', label='Cylinder(L)')
    ax1.scatter(x_cyl_l2, y_cyl_l2, z_cyl_l2, c='pink', marker='o')
    ax1.scatter(x_cyl_r1, y_cyl_r1, z_cyl_r1, c='blue', marker='o', label='Cylinder(R)')
    ax1.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o')
    ax1.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine')
    ax1.set_xlabel('X-axis'); ax1.set_ylabel('Y-axis'); ax1.set_zlabel('Z-axis')

    ax2 = fig.add_subplot(222, projection='3d')
    ax2.view_init(elev=90, azim=-90, roll=0)
    ax2.scatter(x_lin1, y_lin1, z_lin1, c='r', marker='o', s=1)
    ax2.scatter(x_lin2, y_lin2, z_lin2, c='r', marker='o', s=1)
    ax2.scatter(x_cyl_l1, y_cyl_l1, z_cyl_l1, c='darkcyan', marker='o', label='Cylinder(L)')
    ax2.scatter(x_cyl_l2, y_cyl_l2, z_cyl_l2, c='pink', marker='o')
    ax2.scatter(x_cyl_r1, y_cyl_r1, z_cyl_r1, c='blue', marker='o', label='Cylinder(R)')
    ax2.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o')
    ax2.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine')
    ax2.set_xlabel('X-axis'); ax2.set_ylabel('Y-axis'); ax2.set_zlabel('Z-axis')
    ax2.legend()

    ax3 = fig.add_subplot(223, projection='3d')
    ax3.view_init(elev=0, azim=90, roll=0)
    ax3.scatter(x_lin1, y_lin1, z_lin1, c='r', marker='o', s=1)
    ax3.scatter(x_lin2, y_lin2, z_lin2, c='r', marker='o', s=1)
    ax3.scatter(x_cyl_l1, y_cyl_l1, z_cyl_l1, c='darkcyan', marker='o', label='Cylinder(L)')
    ax3.scatter(x_cyl_l2, y_cyl_l2, z_cyl_l2, c='pink', marker='o')
    ax3.scatter(x_cyl_r1, y_cyl_r1, z_cyl_r1, c='blue', marker='o', label='Cylinder(R)')
    ax3.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o')
    ax3.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine')
    ax3.set_xlabel('X-axis'); ax3.set_ylabel('Y-axis'); ax3.set_zlabel('Z-axis')

    ax4 = fig.add_subplot(224, projection='3d')
    ax4.view_init(elev=0, azim=0, roll=0)
    ax4.scatter(x_lin1, y_lin1, z_lin1, c='r', marker='o', s=1)
    ax4.scatter(x_lin2, y_lin2, z_lin2, c='r', marker='o', s=1)
    ax4.scatter(x_cyl_l1, y_cyl_l1, z_cyl_l1, c='darkcyan', marker='o', label='Cylinder(L)')
    ax4.scatter(x_cyl_l2, y_cyl_l2, z_cyl_l2, c='pink', marker='o')
    ax4.scatter(x_cyl_r1, y_cyl_r1, z_cyl_r1, c='blue', marker='o', label='Cylinder(R)')
    ax4.scatter(x_cyl_r2, y_cyl_r2, z_cyl_r2, c='pink', marker='o')
    ax4.scatter(x_img, y_img, z_img, c='lightblue', marker='+', alpha=0.04, label='Spine')
    ax4.set_xlabel('X-axis'); ax4.set_ylabel('Y-axis'); ax4.set_zlabel('Z-axis')

    cyl_points_l = torch.sum(cyl_l).item()
    cyl_points_r = torch.sum(cyl_r).item()

    overlap_l = ((cortical_tensor == 1) & (cyl_l == 1)).sum().item()
    overlap_r = ((cortical_tensor == 1) & (cyl_r == 1)).sum().item()
    overlap_b_l = ((spine_tensor == 1) & (cyl_l == 1)).sum().item()
    overlap_b_r = ((spine_tensor == 1) & (cyl_r == 1)).sum().item()

    # Percentage of cylinder voxels lying in cortical bone / whole vertebra.
    overlap_cortical_l = (overlap_l / cyl_points_l) * 100
    overlap_cortical_r = (overlap_r / cyl_points_r) * 100
    overlap_vertebral_l = (overlap_b_l / cyl_points_l) * 100
    overlap_vertebral_r = (overlap_b_r / cyl_points_r) * 100
    # BUG FIX: guard the cortical/bone ratio against a zero denominator; the
    # CSV rows below already guarded this but the figure annotation did not,
    # so a cylinder entirely outside the vertebra raised ZeroDivisionError.
    cb_ratio_l = overlap_cortical_l / overlap_vertebral_l if overlap_vertebral_l != 0 else 0
    cb_ratio_r = overlap_cortical_r / overlap_vertebral_r if overlap_vertebral_r != 0 else 0
    user_altitude_l = 90 - best_position_l[4] - alt
    user_altitude_r = 90 - best_position_r[4] - alt
    user_azimuth_l = 90 - best_position_l[3] - azi
    user_azimuth_r = 90 - best_position_r[3] - azi

    date_str = datetime.now().strftime("%Y%m%d")
    patient_id = os.path.basename(os.path.dirname(image2_path))
    output_folder = os.path.join(base_folder, date_str, patient_id)
    os.makedirs(output_folder, exist_ok=True)
    csv_path = os.path.join(output_folder, 'output.csv')

    # Only write the header when creating a new file.
    file_exists = os.path.isfile(csv_path)

    headers = [
        'Label', 'Side', 'Diameter', 'Length', 'Swarm_Size', 'Max_Iter',
        'Position_XYZ', 'Raw_Azimuth', 'Azimuth_Diff', 'Raw_Altitude', 'Altitude_Diff',
        'Intersections', 'Best_Loss', 'Overlap_Cortical', 'Overlap_Bone',
        'Cortical_Bone_Ratio', 'User_Azimuth', 'User_Altitude', 'Total_Time'
    ]

    try:
        with open(csv_path, 'a', newline='') as csvfile:
            writer = csv.writer(csvfile)

            if not file_exists:
                writer.writerow(headers)

            # Left-side row
            writer.writerow([
                label_str,
                'L',
                diameter_l,
                length_l,
                swarm_size,
                max_iter,
                f"({best_position_l[0]:.2f}, {best_position_l[1]:.2f}, {best_position_l[2]:.2f})",
                f"{best_position_l[3]:.2f}",
                f"{best_position_l[3]-azi:.2f}",
                f"{best_position_l[4]:.2f}",
                f"{best_position_l[4]-alt:.2f}",
                intersections_l,
                f"{loss_l:.2f}",
                f"{overlap_cortical_l:.2f}",
                f"{overlap_vertebral_l:.2f}",
                # CONSISTENCY: reuse the guarded ratio computed above instead
                # of recomputing it inline.
                f"{cb_ratio_l:.2f}",
                f"{user_azimuth_l:.2f}",
                f"{user_altitude_l:.2f}",
                f"{total_time:.2f}"
            ])

            # Right-side row
            writer.writerow([
                label_str,
                'R',
                diameter_r,
                length_r,
                swarm_size,
                max_iter,
                f"({best_position_r[0]:.2f}, {best_position_r[1]:.2f}, {best_position_r[2]:.2f})",
                f"{best_position_r[3]:.2f}",
                f"{best_position_r[3]-azi:.2f}",
                f"{best_position_r[4]:.2f}",
                f"{best_position_r[4]-alt:.2f}",
                intersections_r,
                f"{loss_r:.2f}",
                f"{overlap_cortical_r:.2f}",
                f"{overlap_vertebral_r:.2f}",
                f"{cb_ratio_r:.2f}",
                f"{user_azimuth_r:.2f}",
                f"{user_altitude_r:.2f}",
                f"{total_time:.2f}"
            ])
        print(f"[CSV Saved] {csv_path}")

    except Exception as e:
        print(f"[Error] Failed to write CSV: {e}")

    fig.text(0.5, 0.98, f'{label_str} Best Position', ha='center', fontsize=15)
    fig.text(
        0.5, 0.44,
        f'L: Diameter = {diameter_l} mm, {length_l} mm, '
        f'R: Diameter = {diameter_r} mm, {length_r} mm, '
        f'Swarm size = {swarm_size}, Iteration = {max_iter}, Total time = {total_time:.2f} s',
        ha='center', fontsize=12
    )
    fig.text(
        0.5, 0.03,
        f'Left : Position = ({best_position_l[0]:.2f}, {best_position_l[1]:.2f}, {best_position_l[2]:.2f}), '
        f'Azimuth = {user_azimuth_l:.2f}, Altitude = {user_altitude_l:.2f}, '
        f'Intersection = {intersections_l}, Score = {overlap_cortical_l:.2f} / {overlap_vertebral_l:.2f} / {cb_ratio_l:.2f}',
        ha='center', fontsize=9
    )
    fig.text(
        0.5, 0.01,
        f'Right : Position = ({best_position_r[0]:.2f}, {best_position_r[1]:.2f}, {best_position_r[2]:.2f}), '
        f'Azimuth = {user_azimuth_r:.2f}, Altitude = {user_altitude_r:.2f}, '
        f'Intersection = {intersections_r}, Score = {overlap_cortical_r:.2f} / {overlap_vertebral_r:.2f} / {cb_ratio_r:.2f}',
        ha='center', fontsize=9
    )

    fig.tight_layout()

    date_str = datetime.now().strftime("%Y%m%d")
    file_name = os.path.basename(image2_path)
    level = file_name.split('_')[0]
    output_folder = os.path.join(base_folder, date_str, patient_id)
    os.makedirs(output_folder, exist_ok=True)

    if CBT == True:
        way = 'CBT'
    else:
        way = 'TPS'

    path = save_with_unique_name(output_folder, label_str, way,
                                 diameter_l, length_l, diameter_r, length_r,
                                 swarm_size, max_iter)

    fig.savefig(path, dpi=200, bbox_inches="tight")
    print("[Saved figure]", path)
    plt.close(fig)
|
||||
|
||||
|
||||
def eval_overlap_from_position(
    pos,
    optimize_size: bool,
    spine_tensor: torch.Tensor,
    image_shape,
    spacing,
    device: torch.device,
    grid=None,
    fixed_diameter: float | None = None,
    fixed_length: float | None = None,
):
    """
    Generate a cylinder mask from *pos* and compute its overlap ratio with
    the spine mask.

    When optimize_size is True, pos carries [z, y, x, azimuth, altitude,
    diameter, length] and the size entries are snapped to the discrete grid;
    otherwise pos carries only the five pose values and fixed_diameter /
    fixed_length must be supplied.

    Returns (overlap_ratio, diameter, length).
    """
    if optimize_size:
        # Last two entries encode the screw size; snap to allowed values.
        diameter, length = snap_to_discrete_values(pos[5], pos[6])
        pose = pos[:5]
    else:
        if fixed_diameter is None or fixed_length is None:
            raise ValueError("fixed_diameter and fixed_length must be provided when optimize_size=False")
        diameter, length = fixed_diameter, fixed_length
        pose = pos

    z, y, x, azimuth, altitude = pose

    cyl_mask = generate_cylinder_n_torch(
        diameter, length,
        z, y, x,
        azimuth, altitude,
        image_shape, spacing,
        device=device,
        grid=grid,
    )

    ratio = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
    return ratio, diameter, length
|
||||
|
||||
|
||||
Loading…
Reference in a new issue