CBT_project/visualization/res_cyl_to_CT.py
2026-04-10 13:25:27 +08:00

127 lines
No EOL
5.2 KiB
Python

from utils.helpers import get_unique_filepath
import nibabel as nib
from datetime import datetime
import os
import numpy as np
import matplotlib.pyplot as plt
from utils.helpers import pad_rgb_to_shape
def _normalize_to_unit(image):
    """Min-max scale ``image`` into [0, 1].

    A plain min-max is used (no pre-division by the maximum), so a zero or
    negative maximum cannot produce inf/NaN or inverted contrast. A constant
    image (zero dynamic range) maps to all zeros instead of NaN.
    """
    lo = float(np.min(image))
    hi = float(np.max(image))
    span = hi - lo
    if span == 0:
        return np.zeros(np.shape(image), dtype=float)
    return np.clip((image - lo) / span, 0.0, 1.0)


def _create_rgb_image(bone_mip, cylinder_mask, cylinder_mask2):
    """Render a 2-D bone projection as grayscale RGB with both cylinders painted white.

    Args:
        bone_mip: 2-D maximum-intensity projection of the bone volume.
        cylinder_mask: Boolean 2-D mask of the first cylinder (same shape as bone_mip).
        cylinder_mask2: Boolean 2-D mask of the second cylinder (same shape as bone_mip).

    Returns:
        Float array of shape ``(*bone_mip.shape, 3)`` with values in [0, 1].
    """
    gray = _normalize_to_unit(bone_mip)
    rgb = np.repeat(gray[..., np.newaxis], 3, axis=-1)
    # A 2-D boolean mask on an (H, W, 3) array selects whole pixels, so all
    # three channels are set to 1 (white).
    rgb[cylinder_mask | cylinder_mask2] = 1.0
    return rgb


def _build_output_paths(output_folder, level, diameter_l, diameter_r, length_l, length_r, CBT):
    """Return the PNG output path for each figure, keyed by view name.

    The file-name suffix for the sagittal view keeps its historical
    ``saggital`` spelling so output file names are unchanged.
    """
    method = 'CBT' if CBT else 'TPS'
    stem = f'{level}_L{diameter_l}_{length_l}_R{diameter_r}_{length_r}_{method}'
    return {
        view: os.path.join(output_folder, f'{stem}_{suffix}.png')
        for view, suffix in (
            ('whole', 'whole'),
            ('axial', 'axial'),
            ('coronal', 'coronal'),
            ('sagittal', 'saggital'),
        )
    }


def _save_single_view(image, title, path):
    """Save one projection as its own 10x10-inch figure, always closing the figure."""
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    try:
        ax.imshow(image)
        ax.set_title(title, pad=10)
        ax.axis("off")
        fig.savefig(get_unique_filepath(path), bbox_inches='tight')
    finally:
        # Without this, every call leaks a figure held by pyplot's registry.
        plt.close(fig)


def cyl_and_CT(diameter_l, diameter_r, length_l, length_r, roi_image_path,
               cylinder_1, cylinder_2, base_folder='Output', CBT=True):
    """Overlay two cylinder masks on a bone ROI volume and save projection images.

    Loads three NIfTI volumes, takes maximum-intensity projections along each
    array axis, and paints every voxel column where a cylinder is present white
    on top of the grayscale bone projection. Four PNGs are written: a combined
    three-panel figure plus one figure per view.

    Files go to ``<base_folder>/<YYYYMMDD>/<name of roi_image_path's parent dir>/``.
    The folder is created if needed. Each path goes through
    ``get_unique_filepath`` (presumably to avoid overwriting existing files;
    confirm in utils.helpers).

    Args:
        diameter_l, diameter_r, length_l, length_r: Used only to compose the
            output file names.
        roi_image_path: Path to the bone ROI NIfTI. The text before the first
            ``_`` in its file name is used as the level prefix.
        cylinder_1, cylinder_2: Paths to the two cylinder NIfTI volumes.
            Voxels > 0 count as cylinder.
        base_folder: Root output directory.
        CBT: Tag file names with ``CBT`` when true, ``TPS`` otherwise.

    Raises:
        ValueError: If either cylinder volume's shape differs from the bone volume's.
    """
    bone_array = nib.load(roi_image_path).get_fdata()
    cylinder_array = nib.load(cylinder_1).get_fdata()
    cylinder_array2 = nib.load(cylinder_2).get_fdata()
    # Check BOTH cylinders (previously only the first), and raise rather than
    # assert so the check survives `python -O`.
    for label, volume in ((cylinder_1, cylinder_array), (cylinder_2, cylinder_array2)):
        if volume.shape != bone_array.shape:
            raise ValueError(
                f"影像 shape 不匹配,請先進行重採樣! "
                f"bone {bone_array.shape} vs {label} {volume.shape}"
            )

    date = datetime.now().strftime("%Y%m%d")
    patient_id = os.path.basename(os.path.dirname(roi_image_path))
    level = os.path.basename(roi_image_path).split('_')[0]
    output_folder = os.path.join(base_folder, date, patient_id)
    os.makedirs(output_folder, exist_ok=True)
    paths = _build_output_paths(output_folder, level, diameter_l, diameter_r,
                                length_l, length_r, CBT)

    # Projection axis -> view label, as assigned by the original titles
    # (axis 0 sagittal, 1 coronal, 2 axial; assumes RAS-like voxel order, confirm).
    views = {}
    for view, axis in (('sagittal', 0), ('coronal', 1), ('axial', 2)):
        rgb = _create_rgb_image(
            np.max(bone_array, axis=axis),
            np.max(cylinder_array, axis=axis) > 0,
            np.max(cylinder_array2, axis=axis) > 0,
        )
        views[view] = np.rot90(rgb)  # rotate for display orientation

    # Pad all views to a common canvas so the three panels line up.
    height = max(img.shape[0] for img in views.values())
    width = max(img.shape[1] for img in views.values())
    views = {
        view: pad_rgb_to_shape(img, (height, width), pad_value=0.0)
        for view, img in views.items()
    }

    panels = (("Axial", 'axial'), ("Coronal", 'coronal'), ("Sagittal", 'sagittal'))
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    try:
        for ax, (title, view) in zip(axs, panels):
            ax.imshow(views[view])
            ax.set_title(title, pad=10)
            ax.axis("off")
            ax.set_aspect("equal")
        fig.subplots_adjust(left=0.05, right=0.75, top=0.9, bottom=0.1, wspace=0.01)
        fig.savefig(get_unique_filepath(paths['whole']), bbox_inches='tight')
    finally:
        plt.close(fig)

    for title, view in (("Sagittal", 'sagittal'), ("Coronal", 'coronal'), ("Axial", 'axial')):
        _save_single_view(views[view], title, paths[view])