143 lines
4.5 KiB
Python
143 lines
4.5 KiB
Python
|
|
import os
|
||
|
|
|
||
|
|
import SimpleITK as sitk
|
||
|
|
import torch
|
||
|
|
|
||
|
|
# from config.device import get_device
|
||
|
|
from core.cylinder import create_coordinate_grid
|
||
|
|
from core.objective import set_global_context
|
||
|
|
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch, run_pso_torch_xfr
|
||
|
|
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
|
||
|
|
|
||
|
|
# Root of the shared data mount; every path below hangs off it.
_DATA_ROOT = '/mnt/1248/open/cyrou'

# Standardized (resampled) segmentation volumes; the debug helpers below join a
# volume id onto this to find that volume's per-level NIfTI files.
standardized_dir = _DATA_ROOT + '/CBT/Seg/Resample/standardized-xfr/'

# Destinations for the debug plots and the PSO output folder.
azimuth_rotation_dir = _DATA_ROOT + '/azimuth_rotation'
tilt_contour_dir = _DATA_ROOT + '/tilt_contour'
Output_dir = _DATA_ROOT + '/Output'
|
||
|
|
|
||
|
|
def get_device(gpu_id=None):
    """Return a torch device, preferring the CUDA GPU with the most free memory.

    Args:
        gpu_id: CUDA device index to use. When ``None`` and CUDA is available,
            every visible GPU is queried with ``torch.cuda.mem_get_info`` and
            the one reporting the most free memory is picked (ties resolve to
            the lowest index). Ignored when CUDA is unavailable.

    Returns:
        ``torch.device("cuda:<gpu_id>")`` when CUDA is available, otherwise
        ``torch.device("cpu")``. The choice is also printed.
    """
    if not torch.cuda.is_available():
        cpu = torch.device("cpu")
        print("CUDA not available, using CPU")
        return cpu

    if gpu_id is None:
        # max() keeps the first maximal element, i.e. the lowest index on ties;
        # default=0 mirrors the original fallback when no device is listed.
        gpu_id = max(
            range(torch.cuda.device_count()),
            key=lambda idx: torch.cuda.mem_get_info(idx)[0],
            default=0,
        )

    chosen = torch.device(f"cuda:{gpu_id}")
    print(f"Using GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")
    return chosen
|
||
|
|
|
||
|
|
def debug_orientation(volume_id, level):
    """Run the orientation analyses on one vertebra and print its tilt angle.

    Uses ``<standardized_dir>/<volume_id>/<level>_binary.nii.gz`` as input to
    ``azimuth_rotation`` and to ``analyze_vertebral_tilt_contour`` (superior
    edge). Both are asked to show and save their plot as
    ``<level>_<volume_id>.png`` under ``azimuth_rotation_dir`` and
    ``tilt_contour_dir`` respectively.

    Args:
        volume_id: Name of the volume's sub-directory inside ``standardized_dir``.
        level: Vertebral level label used as the file prefix, e.g. ``'L1'``.

    Returns:
        Tuple ``(azi, alt)``: whatever ``azimuth_rotation`` returns, and the
        superior-edge ``tilt_angle_deg``. (The original returned ``None``;
        existing callers ignore the return value.)
    """
    binary_path = os.path.join(standardized_dir, volume_id, f'{level}_binary.nii.gz')

    # The helpers are told to save plots into these folders; create them up
    # front so a fresh output location does not make the save step fail.
    os.makedirs(azimuth_rotation_dir, exist_ok=True)
    os.makedirs(tilt_contour_dir, exist_ok=True)

    azi = azimuth_rotation(
        binary_path,
        show_plt=True,
        save_plt=True,
        output_path=f'{azimuth_rotation_dir}/{level}_{volume_id}.png',
    )
    res = analyze_vertebral_tilt_contour(
        binary_path,
        edge_type='superior',
        show_plot=True,
        debug=False,
        save_plt=True,
        output_path=f'{tilt_contour_dir}/{level}_{volume_id}.png',
    )
    alt = res['superior']['tilt_angle_deg']

    print(binary_path)
    print(f'Alt: {alt}')
    return azi, alt
|
||
|
|
|
||
|
|
def debug_pso(volume_id, level, *, swarm_size=100, max_iter=100,
              spacing=(0.5, 0.5, 0.5), CBT=True, device=None):
    """Run the PSO screw-trajectory optimisation on one vertebra.

    Loads ``<level>_cortical.nii.gz`` and ``<level>_binary.nii.gz`` from
    ``<standardized_dir>/<volume_id>/`` as tensors, installs them via
    ``set_global_context``, then calls ``run_pso_torch_xfr``. The ROI file
    ``<level>_roi2.nii.gz`` is only passed through by path.

    Args:
        volume_id: Name of the volume's sub-directory inside ``standardized_dir``.
        level: Vertebral level label used as the file prefix, e.g. ``'L1'``.
        swarm_size: PSO swarm size (default 100, as before).
        max_iter: PSO iteration count (default 100, as before).
        spacing: Voxel spacing handed to both the objective context and the
            optimiser; converted to a list as the original passed one.
        CBT: Forwarded unchanged to ``run_pso_torch_xfr``.
        device: Torch device to use; ``None`` selects one via ``get_device()``.

    Returns:
        ``(best_l, loss_l, best_r, loss_r, total_time)`` as unpacked from
        ``run_pso_torch_xfr`` (the original discarded these).

    Raises:
        FileNotFoundError: if any of the three input volumes is missing.
    """
    if device is None:
        device = get_device()
    spacing = list(spacing)

    volume_dir = os.path.join(standardized_dir, volume_id)
    cortical_path = os.path.join(volume_dir, f'{level}_cortical.nii.gz')
    binary_path = os.path.join(volume_dir, f'{level}_binary.nii.gz')
    roi_path = os.path.join(volume_dir, f'{level}_roi2.nii.gz')

    # The original read the ROI image only to discard it; an existence check
    # keeps the early failure on a missing file without loading the volume.
    missing = [p for p in (cortical_path, binary_path, roi_path) if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError(f'Missing input volume(s) for {volume_id}/{level}: {missing}')

    cortical_array = sitk.GetArrayFromImage(sitk.ReadImage(cortical_path))
    binary_array = sitk.GetArrayFromImage(sitk.ReadImage(binary_path))
    image_shape = binary_array.shape

    cortical_tensor = torch.tensor(cortical_array, device=device)
    binary_tensor = torch.tensor(binary_array, device=device)
    grid = create_coordinate_grid(image_shape, device)

    set_global_context(
        cortical=cortical_tensor,
        spine=binary_tensor,
        shape=image_shape,
        spacing_=spacing,
        device_=device,
        grid_=grid,
        use_tip_penalty=False,
    )

    best_l, loss_l, best_r, loss_r, total_time = run_pso_torch_xfr(
        label_str=level,
        image1_path=cortical_path,
        image2_path=binary_path,
        image3_path=roi_path,
        folder=Output_dir,
        swarm_size=swarm_size,
        max_iter=max_iter,
        spacing=spacing,
        CBT=CBT,
        device=device,
        optimize_size=True,
        grid=grid,
    )
    return best_l, loss_l, best_r, loss_r, total_time
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Run ``debug_pso`` on every lumbar level (L1-L5) of the selected volumes.

    Toggle entries in ``volume_ids`` to choose which volumes are processed.
    """
    volume_ids = (
        '1.3.6.1.4.1.9328.50.4.0001',
        # '1.3.6.1.4.1.9328.50.4.0002',
        # '1.3.6.1.4.1.9328.50.4.0003',
        # '1.3.6.1.4.1.9328.50.4.0004',
        # '1.3.6.1.4.1.9328.50.4.0005',
        # '1.3.6.1.4.1.9328.50.4.0006',
    )
    # To process every volume on disk instead:
    # volume_ids = sorted(os.listdir(standardized_dir))
    levels = ('L1', 'L2', 'L3', 'L4', 'L5')

    for volume_id in volume_ids:
        for level in levels:
            # debug_orientation(volume_id, level)
            debug_pso(volume_id, level)
    # No trailing exit(): that builtin is injected by `site` (absent under
    # `python -S` / embedded interpreters) and would kill any importing caller.
    # Returning normally yields the same exit status when run as a script.


if __name__ == '__main__':
    main()
|