# CBT_project/imaging/preprocessing.py
import os
import SimpleITK as sitk
from imaging.resample import resample_img
from imaging.affine import standardize_affine
from imaging.segmentation import seg_bone
import json
import glob
from config.constant import LABEL_MAP
from imaging.nifti_io import sitk_to_nibabel, nibabel_to_sitk
PROGRESS_FILE = "progress.json"
def load_progress():
    """Return the resume-progress dict read from PROGRESS_FILE, or {} if the file does not exist."""
    if not os.path.exists(PROGRESS_FILE):
        return {}
    with open(PROGRESS_FILE, "r") as fh:
        return json.load(fh)
def save_progress(progress):
    """Persist the resume-progress dict to PROGRESS_FILE as pretty-printed JSON."""
    with open(PROGRESS_FILE, "w") as fh:
        json.dump(progress, fh, indent=2)
def process_single_image(image_path, label_path, output_dir_base=None):
    """Resample one CT image/label pair to 0.5 mm isotropic spacing and run
    bone segmentation for every label present in the label volume.

    Parameters
    ----------
    image_path : str
        Path to the image NIfTI file (``*.nii.gz``).
    label_path : str
        Path to the matching segmentation NIfTI file.
    output_dir_base : str
        Directory under which a per-case output folder (named after the
        image file) is created. Effectively required: passing ``None``
        raises ``TypeError`` in ``os.path.join``.

    Returns
    -------
    dict
        ``{"processed_labels": [...], "missing_labels": []}``;
        ``missing_labels`` is currently always empty.
    """
    image = sitk.ReadImage(image_path)
    label = sitk.ReadImage(label_path)
    # Case name derived once from the file name (was computed twice before).
    name = os.path.basename(image_path).replace(".nii.gz", "")

    # NOTE(review): the original also ran sitk.LabelStatisticsImageFilter
    # here, but never read any of its outputs — dead code, removed.

    # Resample both volumes to 0.5 mm isotropic spacing; the label volume
    # uses is_label=True so interpolation keeps label values discrete.
    resampled_sitk_img = resample_img(image, out_spacing=[0.5, 0.5, 0.5], is_label=False)
    resampled_sitk_lbl = resample_img(label, out_spacing=[0.5, 0.5, 0.5], is_label=True)

    # Enumerate the labels actually present in this segmentation,
    # e.g. [1, 2, 3, 20, 21].
    lssif = sitk.LabelShapeStatisticsImageFilter()
    lssif.Execute(label)
    existing_labels = lssif.GetLabels()
    print(f"Existing labels in {os.path.basename(label_path)}: {existing_labels}")

    # Per-case output folder.
    output_dir = os.path.join(output_dir_base, name)
    os.makedirs(output_dir, exist_ok=True)

    # Record the labels found (with their mapped names) to a text file.
    txt_path = os.path.join(output_dir, f"{name}_labels.txt")
    with open(txt_path, "w") as f:
        for lab in existing_labels:
            f.write(f"{lab}\t{LABEL_MAP.get(lab, 'Unknown')}\n")

    # Segment each label that exists; a RuntimeError on one label skips
    # just that label instead of aborting the whole case.
    for n in existing_labels:
        try:
            roi_path, binary_path, roi2_path, cortical_path = seg_bone(
                n, name, resampled_sitk_img, resampled_sitk_lbl, output_dir,
                label_map=LABEL_MAP,
            )
            for path in (roi_path, binary_path, roi2_path, cortical_path):
                standardize_affine(path, output_dir)
        except RuntimeError as e:
            print(f"Label {n} could not be processed, skipping. Error: {e}")

    return {
        "processed_labels": list(existing_labels),
        "missing_labels": [],
    }
def process_dataset(image_dir, label_dir, output_dir, labels_to_process=None):
    """Run the preprocessing pipeline over a whole dataset, with resume support.

    For each ``*.nii.gz`` image in *image_dir*, the matching label file is
    expected at ``<label_dir>/<name>_seg.nii.gz``. Each finished case is
    recorded in ``progress.json`` so an interrupted run resumes where it
    stopped, and a human-readable report is written to
    ``<output_dir>/all_files_label_summary.txt``.

    Parameters
    ----------
    image_dir : str
        Directory containing the input image volumes.
    label_dir : str
        Directory containing the segmentation volumes.
    output_dir : str
        Base directory for per-case outputs and the summary file.
    labels_to_process : optional
        Accepted for API compatibility but currently unused — every label
        present in each segmentation is processed.
    """
    image_files = sorted(glob.glob(os.path.join(image_dir, "*.nii.gz")))
    total_files = len(image_files)
    print(f"Total files: {total_files}")

    progress = load_progress()
    all_file_summary = []

    for idx, image_path in enumerate(image_files, 1):
        file_name = os.path.basename(image_path)
        name = file_name.replace(".nii.gz", "")
        label_path = os.path.join(label_dir, file_name.replace(".nii.gz", "_seg.nii.gz"))
        file_summary = {
            "file_name": file_name,
            "current_labels": [],
            "missing_labels": [],
        }

        # Resume support: skip cases already marked finished in progress.json,
        # replaying their recorded labels into the summary.
        if progress.get(name, {}).get("finished", False):
            print(f"[{idx}/{total_files}] Already finished: {file_name}")
            file_summary["current_labels"] = progress[name].get("processed_labels", [])
            file_summary["missing_labels"] = progress[name].get("missing_labels", [])
            all_file_summary.append(file_summary)
            continue

        if not os.path.exists(label_path):
            print(f"[{idx}/{total_files}] Warning: label not found for {file_name}")
            file_summary["note"] = "Label file not found"
            all_file_summary.append(file_summary)
            continue

        try:
            result = process_single_image(image_path, label_path, output_dir_base=output_dir)
        except Exception as e:
            # A failure on one case is logged in the summary and must not
            # abort the whole batch.
            print(f"[{idx}/{total_files}] Error processing {file_name}: {e}")
            file_summary["note"] = f"Error: {e}"
            all_file_summary.append(file_summary)
            continue

        file_summary["current_labels"] = result["processed_labels"]
        file_summary["missing_labels"] = result["missing_labels"]
        all_file_summary.append(file_summary)

        # Persist progress after every case so an interruption loses at most
        # the case that was in flight.
        progress[name] = {
            "finished": True,
            "processed_labels": result["processed_labels"],
            "missing_labels": result["missing_labels"],
        }
        save_progress(progress)
        print(f"[{idx}/{total_files}] Finished: {file_name} | Missing labels: {result['missing_labels'] or 'None'}")

    # --- Summary ---
    summary_path = os.path.join(output_dir, "all_files_label_summary.txt")
    os.makedirs(os.path.dirname(summary_path), exist_ok=True)
    print(f"\nWriting summary to {summary_path}...")
    with open(summary_path, "w") as f:
        f.write("--- CTSpine1K Dataset Label Summary ---\n")
        f.write(f"Total files processed: {total_files}\n\n")
        for summary in all_file_summary:
            f.write("================================================\n")
            f.write(f"File: {summary['file_name']}\n")
            processed_labels_str = ", ".join(str(l) for l in summary['current_labels'])
            f.write(f"Labels processed: {processed_labels_str}\n")
            if summary['missing_labels']:
                missing_str = ", ".join(
                    f"{l} ({LABEL_MAP.get(l, 'Unknown')})" for l in summary['missing_labels']
                )
                f.write(f"🚨 Missing Labels: {missing_str}\n")
            else:
                f.write("✅ Missing Labels: None\n")
            if "note" in summary:
                f.write(f"Note: {summary['note']}\n")
            f.write("================================================\n\n")

    print("All done! Summary file created.")