CBT_project/xfr_debug.py
Xiao Furen b76f0708f3 1. correct the bounding box and cortical mask
2. make the plot isometric
3. now it should work after tuning the parameters
2026-04-17 00:03:10 +08:00

143 lines
No EOL
4.5 KiB
Python

import os
import SimpleITK as sitk
import torch
# from config.device import get_device
from core.cylinder import create_coordinate_grid
from core.objective import set_global_context
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch, run_pso_torch_xfr
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
# Root of the input data: one sub-directory per volume ID, each holding the
# per-level `{level}_cortical.nii.gz`, `{level}_binary.nii.gz` and
# `{level}_roi2.nii.gz` files read by the debug routines in this file.
standardized_dir = '/mnt/1248/open/cyrou/CBT/Seg/Resample/standardized-xfr/'
# Destination for the PNG passed as `output_path` to `azimuth_rotation`.
azimuth_rotation_dir = '/mnt/1248/open/cyrou/azimuth_rotation'
# Destination for the PNG passed as `output_path` to
# `analyze_vertebral_tilt_contour`.
tilt_contour_dir = '/mnt/1248/open/cyrou/tilt_contour'
# Passed as `folder` to `run_pso_torch_xfr`.
Output_dir = '/mnt/1248/open/cyrou/Output'
def get_device(gpu_id=None):
    """Choose the torch device to compute on and print the choice.

    Args:
        gpu_id: CUDA device index to use. When ``None`` and CUDA is
            available, every visible GPU is queried and the one reporting
            the most free memory is taken (the lowest index wins a tie).

    Returns:
        ``torch.device("cuda:<gpu_id>")`` when CUDA is available, otherwise
        ``torch.device("cpu")``; ``gpu_id`` is ignored in the CPU case.
    """
    if not torch.cuda.is_available():
        cpu = torch.device("cpu")
        print("CUDA not available, using CPU")
        return cpu

    if gpu_id is None:
        # mem_get_info(idx) returns (free, total); rank GPUs by the free part.
        # max() keeps the first maximal element, and default=0 covers an
        # empty device list, matching a manual "strictly greater" scan.
        gpu_id = max(
            range(torch.cuda.device_count()),
            key=lambda idx: torch.cuda.mem_get_info(idx)[0],
            default=0,
        )

    chosen = torch.device(f"cuda:{gpu_id}")
    print(f"Using GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")
    return chosen
def debug_orientation(volume_id, level):
    """Run both orientation analyses on one vertebra and print its tilt angle.

    The binary mask ``<standardized_dir>/<volume_id>/<level>_binary.nii.gz``
    is passed to ``azimuth_rotation`` and ``analyze_vertebral_tilt_contour``
    with plotting and saving enabled; the figures are written to
    ``azimuth_rotation_dir`` and ``tilt_contour_dir`` as
    ``<level>_<volume_id>.png``.

    Args:
        volume_id: Name of the volume's sub-directory under
            ``standardized_dir``.
        level: Level label used as the file-name prefix (``main`` iterates
            ``'L1'`` .. ``'L5'``).

    Returns:
        Tuple ``(azi, alt)``: the value returned by ``azimuth_rotation`` and
        the ``['superior']['tilt_angle_deg']`` entry of the tilt analysis.
    """
    binary_path = os.path.join(
        standardized_dir, volume_id, f'{level}_binary.nii.gz'
    )

    azi = azimuth_rotation(
        binary_path,
        show_plt=True,
        save_plt=True,
        output_path=f'{azimuth_rotation_dir}/{level}_{volume_id}.png',
    )
    res = analyze_vertebral_tilt_contour(
        binary_path,
        edge_type='superior',
        show_plot=True,
        debug=False,
        save_plt=True,
        output_path=f'{tilt_contour_dir}/{level}_{volume_id}.png',
    )
    alt = res['superior']['tilt_angle_deg']

    print(binary_path)
    print(f'Alt: {alt}')
    # Returned so the computed angles are usable by a caller instead of being
    # discarded after printing.
    return azi, alt
def debug_pso(volume_id, level):
    """Run the PSO trajectory optimisation for one vertebra of one volume.

    Loads the cortical and binary masks of ``level`` from
    ``<standardized_dir>/<volume_id>/``, moves them to the device chosen by
    ``get_device()``, registers them through ``set_global_context`` and then
    calls ``run_pso_torch_xfr``. The ROI volume is handed to the optimiser by
    path only, so it is not read here.

    Args:
        volume_id: Name of the volume's sub-directory under
            ``standardized_dir``.
        level: Level label used as the file-name prefix and passed to the
            optimiser as ``label_str``.

    Returns:
        The 5-tuple ``(best_l, loss_l, best_r, loss_r, total_time)`` produced
        by ``run_pso_torch_xfr`` (presumably left/right best parameters and
        losses plus elapsed time -- confirm against ``core.optimizer``).
    """
    # ====== PSO ======
    swarm_size = 100
    max_iter = 100

    # ====== DEVICE ======
    device = get_device()

    # ====== OTHER ======
    spacing = [0.5, 0.5, 0.5]  # NOTE(review): presumably voxel size in mm -- confirm
    use_cbt = True

    volume_dir = os.path.join(standardized_dir, volume_id)
    cortical_path = os.path.join(volume_dir, f'{level}_cortical.nii.gz')
    binary_path = os.path.join(volume_dir, f'{level}_binary.nii.gz')
    roi_path = os.path.join(volume_dir, f'{level}_roi2.nii.gz')

    cortical_array = sitk.GetArrayFromImage(sitk.ReadImage(cortical_path))
    binary_array = sitk.GetArrayFromImage(sitk.ReadImage(binary_path))

    # The coordinate grid and the objective context share the binary mask's
    # array shape.
    image_shape = binary_array.shape
    cortical_tensor = torch.tensor(cortical_array, device=device)
    binary_tensor = torch.tensor(binary_array, device=device)
    grid = create_coordinate_grid(image_shape, device)

    # Must run before the optimiser call below, as in the original ordering.
    set_global_context(
        cortical=cortical_tensor,
        spine=binary_tensor,
        shape=image_shape,
        spacing_=spacing,
        device_=device,
        grid_=grid,
        use_tip_penalty=False,
    )

    best_l, loss_l, best_r, loss_r, total_time = run_pso_torch_xfr(
        label_str=level,
        image1_path=cortical_path,
        image2_path=binary_path,
        image3_path=roi_path,
        folder=Output_dir,
        swarm_size=swarm_size,
        max_iter=max_iter,
        spacing=spacing,
        CBT=use_cbt,
        device=device,
        optimize_size=True,
        grid=grid,
    )
    # Hand the results back instead of silently dropping them.
    return best_l, loss_l, best_r, loss_r, total_time
def main():
    """Run ``debug_pso`` for every selected volume and level, then exit.

    The volume list is edited by (un)commenting entries; the levels are
    ``'L1'`` .. ``'L5'``. After both loops finish the process is terminated
    via ``sys.exit()``, which raises ``SystemExit``.
    """
    # Function-scope import: ``sys.exit`` replaces the ``site``-provided
    # ``exit()`` helper, which is meant for the interactive shell and is not
    # defined when the ``site`` module is not loaded (e.g. ``python -S``).
    import sys

    volume_ids = (
        '1.3.6.1.4.1.9328.50.4.0001',
        # '1.3.6.1.4.1.9328.50.4.0002',
        # '1.3.6.1.4.1.9328.50.4.0003',
        # '1.3.6.1.4.1.9328.50.4.0004',
        # '1.3.6.1.4.1.9328.50.4.0005',
        # '1.3.6.1.4.1.9328.50.4.0006',
    )
    # To sweep the whole dataset instead:
    # volume_ids = sorted(os.listdir(standardized_dir))
    levels = ('L1', 'L2', 'L3', 'L4', 'L5')

    for volume_id in volume_ids:
        for level in levels:
            # debug_orientation(volume_id, level)
            debug_pso(volume_id, level)

    # NOTE(review): termination is placed after both loops -- confirm this
    # matches the intended run scope.
    sys.exit()
# Run the debug driver only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()