CBT_project/imaging/preprocessing.py
2026-04-10 13:25:27 +08:00

150 lines
No EOL
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import SimpleITK as sitk
from imaging.resample import resample_img
from imaging.affine import standardize_affine
from imaging.segmentation import seg_bone
import json
import glob
from config.constant import LABEL_MAP
from imaging.nifti_io import sitk_to_nibabel, nibabel_to_sitk
PROGRESS_FILE = "progress.json"
def load_progress():
    """Return the saved progress mapping, or an empty dict when no
    progress file exists yet (first run or after a manual reset)."""
    if not os.path.exists(PROGRESS_FILE):
        return {}
    with open(PROGRESS_FILE, "r") as f:
        return json.load(f)
def save_progress(progress):
    """Persist the progress mapping to PROGRESS_FILE as indented JSON."""
    serialized = json.dumps(progress, indent=2)
    with open(PROGRESS_FILE, "w") as f:
        f.write(serialized)
def process_single_image(image_path, label_path, output_dir_base=None):
    """Resample one CT image/label pair and segment every label present.

    The pair is resampled to isotropic 0.5 mm spacing, the labels found in
    the segmentation are written to ``<name>_labels.txt``, and ``seg_bone``
    is run once per label (failures on individual labels are logged and
    skipped, not fatal).

    Parameters
    ----------
    image_path : str
        Path to the CT volume (``*.nii.gz``).
    label_path : str
        Path to the matching segmentation label volume.
    output_dir_base : str
        Base directory under which a per-case output folder is created.
        Effectively required; the ``None`` default is kept only for
        backward interface compatibility.

    Returns
    -------
    dict
        ``{"processed_labels": [...], "missing_labels": []}`` —
        ``missing_labels`` is always empty here; it is reserved for callers
        that compare against an expected label set.

    Raises
    ------
    ValueError
        If ``output_dir_base`` is None.
    """
    if output_dir_base is None:
        # Fail fast with a clear message instead of the confusing TypeError
        # that os.path.join(None, ...) would raise further down.
        raise ValueError("output_dir_base must be provided")

    image = sitk.ReadImage(image_path)
    label = sitk.ReadImage(label_path)

    # Resample both volumes to isotropic 0.5 mm spacing; is_label=True tells
    # the helper to use label-safe (non-smoothing) interpolation.
    resampled_sitk_img = resample_img(image, out_spacing=[0.5, 0.5, 0.5], is_label=False)
    resampled_sitk_lbl = resample_img(label, out_spacing=[0.5, 0.5, 0.5], is_label=True)

    # Collect the labels actually present in this segmentation,
    # e.g. [1, 2, 3, 20, 21].
    lssif = sitk.LabelShapeStatisticsImageFilter()
    lssif.Execute(label)
    existing_labels = lssif.GetLabels()
    print(f"Existing labels in {os.path.basename(label_path)}: {existing_labels}")

    # Per-case output folder named after the image file (without .nii.gz).
    file_name = os.path.basename(image_path)
    name = file_name.replace(".nii.gz", "")
    output_dir = os.path.join(output_dir_base, name)
    os.makedirs(output_dir, exist_ok=True)

    # Record the labels (with human-readable names) alongside the outputs.
    txt_path = os.path.join(output_dir, f"{name}_labels.txt")
    with open(txt_path, "w") as f:
        for lab in existing_labels:
            f.write(f"{lab}\t{LABEL_MAP.get(lab, 'Unknown')}\n")

    # Segment each label; a failure on one label must not abort the case.
    for n in existing_labels:
        try:
            roi_path, binary_path, roi2_path, cortical_path = seg_bone(
                n, name, resampled_sitk_img, resampled_sitk_lbl, output_dir,
                label_map=LABEL_MAP,
            )
            for path in (roi_path, binary_path, roi2_path, cortical_path):
                standardize_affine(path, output_dir)
        except RuntimeError as e:
            print(f"Label {n} could not be processed, skipping. Error: {e}")

    return {
        "processed_labels": list(existing_labels),
        "missing_labels": [],
    }
def _write_summary(summary_path, total_files, all_file_summary):
    """Write the human-readable per-file label summary to summary_path.

    Each entry in all_file_summary is a dict with keys 'file_name',
    'current_labels', 'missing_labels' and an optional 'note'.
    """
    with open(summary_path, "w") as f:
        f.write("--- CTSpine1K Dataset Label Summary ---\n")
        f.write(f"Total files processed: {total_files}\n\n")
        for summary in all_file_summary:
            f.write("================================================\n")
            f.write(f"File: {summary['file_name']}\n")
            processed_labels_str = ", ".join(str(l) for l in summary['current_labels'])
            f.write(f"Labels processed: {processed_labels_str}\n")
            if summary['missing_labels']:
                missing_str = ", ".join(
                    f"{l} ({LABEL_MAP.get(l, 'Unknown')})" for l in summary['missing_labels']
                )
                f.write(f"🚨 Missing Labels: {missing_str}\n")
            else:
                f.write("✅ Missing Labels: None\n")
            if "note" in summary:
                f.write(f"Note: {summary['note']}\n")
            f.write("================================================\n\n")


def process_dataset(image_dir, label_dir, output_dir, labels_to_process=None):
    """Process every image/label pair in a dataset and write a summary.

    Resumable: per-file completion is recorded in PROGRESS_FILE after each
    case, so a rerun skips files already marked finished and a crash loses
    at most one file's work.

    Parameters
    ----------
    image_dir : str
        Directory containing ``*.nii.gz`` image volumes.
    label_dir : str
        Directory containing matching ``*_seg.nii.gz`` label volumes.
    output_dir : str
        Base directory for per-case outputs and the summary text file.
    labels_to_process : optional
        Currently unused; kept for interface compatibility.
        TODO: thread through to process_single_image to restrict which
        labels are segmented.
    """
    image_files = sorted(glob.glob(os.path.join(image_dir, "*.nii.gz")))
    total_files = len(image_files)
    print(f"Total files: {total_files}")

    progress = load_progress()
    all_file_summary = []

    for idx, image_path in enumerate(image_files, 1):
        file_name = os.path.basename(image_path)
        name = file_name.replace(".nii.gz", "")
        # Label volumes share the image stem but carry a _seg suffix.
        label_path = os.path.join(label_dir, file_name.replace(".nii.gz", "_seg.nii.gz"))
        file_summary = {
            "file_name": file_name,
            "current_labels": [],
            "missing_labels": [],
        }

        # Resume support: reuse results recorded by a previous run.
        if progress.get(name, {}).get("finished", False):
            print(f"[{idx}/{total_files}] Already finished: {file_name}")
            file_summary["current_labels"] = progress[name].get("processed_labels", [])
            file_summary["missing_labels"] = progress[name].get("missing_labels", [])
            all_file_summary.append(file_summary)
            continue

        if not os.path.exists(label_path):
            print(f"[{idx}/{total_files}] Warning: label not found for {file_name}")
            file_summary["note"] = "Label file not found"
            all_file_summary.append(file_summary)
            continue

        try:
            result = process_single_image(image_path, label_path, output_dir_base=output_dir)
        except Exception as e:
            # Keep going on per-file failures; record the error in the summary.
            print(f"[{idx}/{total_files}] Error processing {file_name}: {e}")
            file_summary["note"] = f"Error: {e}"
            all_file_summary.append(file_summary)
            continue

        file_summary["current_labels"] = result["processed_labels"]
        file_summary["missing_labels"] = result["missing_labels"]
        all_file_summary.append(file_summary)

        # Persist progress immediately so an interrupted run can resume here.
        progress[name] = {
            "finished": True,
            "processed_labels": result["processed_labels"],
            "missing_labels": result["missing_labels"]
        }
        save_progress(progress)
        print(f"[{idx}/{total_files}] Finished: {file_name} | Missing labels: {result['missing_labels'] or 'None'}")

    # --- Summary ---
    summary_path = os.path.join(output_dir, "all_files_label_summary.txt")
    os.makedirs(os.path.dirname(summary_path), exist_ok=True)
    print(f"\nWriting summary to {summary_path}...")
    _write_summary(summary_path, total_files, all_file_summary)
    print("All done! Summary file created.")