"""Debug driver for vertebral orientation analysis and PSO optimisation.

For a volume ``<volume_id>`` and level ``<level>`` (e.g. ``'L1'``), the inputs are
``standardized_dir/<volume_id>/<level>_{cortical,binary,roi2}.nii.gz``.
``debug_orientation`` plots orientation diagnostics for one level;
``debug_pso`` runs ``run_pso_torch_xfr`` on it.
"""
import os

import SimpleITK as sitk
import torch

from core.cylinder import create_coordinate_grid
from core.objective import set_global_context
from core.optimizer import run_pso_torch, run_de_torch, run_nm_torch, run_pso_torch_xfr
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour

standardized_dir = '/mnt/1248/open/cyrou/CBT/Seg/Resample/standardized-xfr/'
azimuth_rotation_dir = '/mnt/1248/open/cyrou/azimuth_rotation'
tilt_contour_dir = '/mnt/1248/open/cyrou/tilt_contour'
Output_dir = '/mnt/1248/open/cyrou/Output'


def get_device(gpu_id=None):
    """Return the torch device to run on.

    Args:
        gpu_id: Index of the CUDA device to use. When ``None``, the visible GPU
            with the most free memory (per ``torch.cuda.mem_get_info``) is chosen.

    Returns:
        ``torch.device('cuda:<id>')`` when CUDA is available, otherwise
        ``torch.device('cpu')`` (``gpu_id`` is ignored in that case).

    Raises:
        ValueError: If an explicit integer ``gpu_id`` is not a visible CUDA device.
    """
    if not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        return torch.device("cpu")

    n_gpus = torch.cuda.device_count()
    if gpu_id is None:
        # max() returns the first index on ties, same as a strict '>' scan.
        gpu_id = max(range(n_gpus), key=lambda i: torch.cuda.mem_get_info(i)[0])
    elif isinstance(gpu_id, int) and not 0 <= gpu_id < n_gpus:
        raise ValueError(
            f"gpu_id={gpu_id} is out of range; {n_gpus} CUDA device(s) visible"
        )

    device = torch.device(f"cuda:{gpu_id}")
    print(f"Using GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")
    return device


def _level_paths(volume_id, level):
    """Return the (cortical, binary, roi) mask paths for one level of one volume."""
    volume_dir = os.path.join(standardized_dir, volume_id)
    return tuple(
        os.path.join(volume_dir, f'{level}_{suffix}.nii.gz')
        for suffix in ('cortical', 'binary', 'roi2')
    )


def debug_orientation(volume_id, level):
    """Show, save and print orientation diagnostics for one level of one volume.

    Runs ``azimuth_rotation`` and ``analyze_vertebral_tilt_contour`` (superior
    edge) on the binary mask. Both plots are shown and saved as
    ``<level>_<volume_id>.png``, under ``azimuth_rotation_dir`` and
    ``tilt_contour_dir`` respectively. Then the mask path and the superior tilt
    angle are printed.
    """
    _, binary_path, _ = _level_paths(volume_id, level)
    plot_name = f'{level}_{volume_id}.png'

    # The plotting helpers are given a file path; make sure its directory
    # exists (no-op if it already does).
    os.makedirs(azimuth_rotation_dir, exist_ok=True)
    os.makedirs(tilt_contour_dir, exist_ok=True)

    # Called for its plot output only; the returned azimuth is not reported.
    azimuth_rotation(
        binary_path, show_plt=True, save_plt=True,
        output_path=os.path.join(azimuth_rotation_dir, plot_name),
    )
    res = analyze_vertebral_tilt_contour(
        binary_path, edge_type='superior', show_plot=True, debug=False,
        save_plt=True, output_path=os.path.join(tilt_contour_dir, plot_name),
    )
    alt = res['superior']['tilt_angle_deg']
    print(binary_path)
    print(f'Alt: {alt}')


def debug_pso(volume_id, level, *, swarm_size=100, max_iter=100,
              spacing=None, cbt=True, device=None):
    """Run ``run_pso_torch_xfr`` on one level of one volume.

    Args:
        volume_id: Sub-directory name under ``standardized_dir``.
        level: Level prefix of the mask files, e.g. ``'L1'``.
        swarm_size: PSO swarm size forwarded to the optimiser.
        max_iter: Maximum PSO iterations forwarded to the optimiser.
        spacing: Voxel spacing forwarded to ``set_global_context`` and the
            optimiser; defaults to ``[0.5, 0.5, 0.5]`` (units presumably mm —
            TODO confirm).
        cbt: Forwarded as ``CBT=`` to ``run_pso_torch_xfr``.
        device: Torch device; defaults to ``get_device()``.

    Returns:
        ``(best_l, loss_l, best_r, loss_r, total_time)`` as returned by
        ``run_pso_torch_xfr``.
    """
    if spacing is None:
        spacing = [0.5, 0.5, 0.5]
    if device is None:
        device = get_device()

    cortical_path, binary_path, roi_path = _level_paths(volume_id, level)
    os.makedirs(Output_dir, exist_ok=True)

    cortical_array = sitk.GetArrayFromImage(sitk.ReadImage(cortical_path))
    binary_array = sitk.GetArrayFromImage(sitk.ReadImage(binary_path))
    # The ROI mask is not loaded here; only its path is passed on, to the optimiser.
    image_shape = binary_array.shape

    cortical_tensor = torch.tensor(cortical_array, device=device)
    binary_tensor = torch.tensor(binary_array, device=device)
    grid = create_coordinate_grid(image_shape, device)

    set_global_context(
        cortical=cortical_tensor,
        spine=binary_tensor,
        shape=image_shape,
        spacing_=spacing,
        device_=device,
        grid_=grid,
        use_tip_penalty=False,
    )

    best_l, loss_l, best_r, loss_r, total_time = run_pso_torch_xfr(
        label_str=level,
        image1_path=cortical_path,
        image2_path=binary_path,
        image3_path=roi_path,
        folder=Output_dir,
        swarm_size=swarm_size,
        max_iter=max_iter,
        spacing=spacing,
        CBT=cbt,
        device=device,
        optimize_size=True,
        grid=grid,
    )
    return best_l, loss_l, best_r, loss_r, total_time


def main():
    """Run ``debug_pso`` for each selected volume at levels L1-L5."""
    volume_ids = (
        '1.3.6.1.4.1.9328.50.4.0001',
        # '1.3.6.1.4.1.9328.50.4.0002',
        # '1.3.6.1.4.1.9328.50.4.0003',
        # '1.3.6.1.4.1.9328.50.4.0004',
        # '1.3.6.1.4.1.9328.50.4.0005',
        # '1.3.6.1.4.1.9328.50.4.0006',
    )
    # To process every standardized volume instead:
    #     volume_ids = sorted(os.listdir(standardized_dir))
    for volume_id in volume_ids:
        for level in ('L1', 'L2', 'L3', 'L4', 'L5'):
            # Swap in debug_orientation(volume_id, level) to inspect orientation only.
            debug_pso(volume_id, level)


if __name__ == '__main__':
    main()