150 lines
No EOL
5.9 KiB
Python
150 lines
No EOL
5.9 KiB
Python
import os
|
||
import SimpleITK as sitk
|
||
from imaging.resample import resample_img
|
||
from imaging.affine import standardize_affine
|
||
from imaging.segmentation import seg_bone
|
||
import json
|
||
import glob
|
||
from config.constant import LABEL_MAP
|
||
from imaging.nifti_io import sitk_to_nibabel, nibabel_to_sitk
|
||
|
||
# JSON checkpoint file recording per-case completion state so that an
# interrupted batch run can be resumed without reprocessing finished files.
PROGRESS_FILE = "progress.json"
|
||
|
||
def load_progress():
    """Return the progress dict stored in PROGRESS_FILE, or {} if absent."""
    if not os.path.exists(PROGRESS_FILE):
        return {}
    with open(PROGRESS_FILE, "r") as handle:
        return json.load(handle)
|
||
|
||
def save_progress(progress):
    """Persist *progress* to PROGRESS_FILE as pretty-printed JSON."""
    with open(PROGRESS_FILE, "w") as handle:
        json.dump(progress, handle, indent=2)
|
||
|
||
def process_single_image(image_path, label_path, output_dir_base=None):
    """Resample one image/label pair to 0.5 mm spacing and segment every label.

    Parameters
    ----------
    image_path : str
        Path to the ``.nii.gz`` CT image.
    label_path : str
        Path to the matching segmentation label image.
    output_dir_base : str
        Base output directory; a per-case subfolder named after the image
        file (without ``.nii.gz``) is created inside it.

    Returns
    -------
    dict
        ``{"processed_labels": [...], "missing_labels": [...]}`` where
        ``processed_labels`` are the labels segmented successfully and
        ``missing_labels`` are labels present in the file that failed in
        ``seg_bone`` and were skipped.
    """
    image = sitk.ReadImage(image_path)
    label = sitk.ReadImage(label_path)

    # Derive the per-case name once (the original computed it twice).
    file_name = os.path.basename(image_path)
    name = file_name.replace(".nii.gz", "")

    # Resample both volumes to isotropic 0.5 mm spacing; the label volume is
    # resampled with is_label=True so integer label ids are preserved
    # (nearest-neighbour rather than linear interpolation).
    resampled_sitk_img = resample_img(image, out_spacing=[0.5, 0.5, 0.5], is_label=False)
    resampled_sitk_lbl = resample_img(label, out_spacing=[0.5, 0.5, 0.5], is_label=True)

    # Collect the label ids that actually occur in this segmentation.
    lssif = sitk.LabelShapeStatisticsImageFilter()
    lssif.Execute(label)
    existing_labels = lssif.GetLabels()  # e.g. [1, 2, 3, 20, 21]

    print(f"Existing labels in {os.path.basename(label_path)}: {existing_labels}")

    # Per-case output folder.
    output_dir = os.path.join(output_dir_base, name)
    os.makedirs(output_dir, exist_ok=True)

    # Record the labels found in this file alongside their human-readable names.
    txt_path = os.path.join(output_dir, f"{name}_labels.txt")
    with open(txt_path, "w", encoding="utf-8") as f:
        for lab in existing_labels:
            f.write(f"{lab}\t{LABEL_MAP.get(lab, 'Unknown')}\n")

    # Segment each existing label. Bug fix: a label whose seg_bone call raises
    # RuntimeError is now reported in missing_labels instead of being silently
    # counted as processed (the original always returned missing_labels=[]).
    processed, failed = [], []
    for n in existing_labels:
        try:
            roi_path, binary_path, roi2_path, cortical_path = seg_bone(
                n, name, resampled_sitk_img, resampled_sitk_lbl, output_dir,
                label_map=LABEL_MAP,
            )
            # Normalise the affine of every artefact seg_bone produced.
            for path in [roi_path, binary_path, roi2_path, cortical_path]:
                standardize_affine(path, output_dir)
            processed.append(n)
        except RuntimeError as e:
            print(f"Label {n} could not be processed, skipping. Error: {e}")
            failed.append(n)

    return {
        "processed_labels": processed,
        "missing_labels": failed,
    }
|
||
|
||
def process_dataset(image_dir, label_dir, output_dir, labels_to_process=None):
    """Batch-process every ``*.nii.gz`` image in *image_dir*.

    For each image, the matching label file ``<name>_seg.nii.gz`` is looked up
    in *label_dir* and passed to :func:`process_single_image`. Completed cases
    are checkpointed in PROGRESS_FILE so re-runs skip them, and a human-readable
    summary of processed/missing labels is written to *output_dir*.

    Parameters
    ----------
    image_dir : str
        Directory containing the ``.nii.gz`` CT images.
    label_dir : str
        Directory containing ``<name>_seg.nii.gz`` label files.
    output_dir : str
        Base directory for per-case outputs and the summary file.
    labels_to_process : optional
        Currently unused; kept for interface compatibility.
    """
    image_files = sorted(glob.glob(os.path.join(image_dir, "*.nii.gz")))
    total_files = len(image_files)
    print(f"Total files: {total_files}")

    progress = load_progress()
    all_file_summary = []

    for idx, image_path in enumerate(image_files, 1):
        file_name = os.path.basename(image_path)
        name = file_name.replace(".nii.gz", "")
        # Label files follow the "<name>_seg.nii.gz" naming convention.
        label_path = os.path.join(label_dir, file_name.replace(".nii.gz", "_seg.nii.gz"))

        file_summary = {
            "file_name": file_name,
            "current_labels": [],
            "missing_labels": []
        }

        # Resume support: skip cases already recorded as finished, but still
        # carry their label lists into this run's summary.
        if progress.get(name, {}).get("finished", False):
            print(f"[{idx}/{total_files}] Already finished: {file_name}")
            file_summary["current_labels"] = progress[name].get("processed_labels", [])
            file_summary["missing_labels"] = progress[name].get("missing_labels", [])
            all_file_summary.append(file_summary)
            continue

        if not os.path.exists(label_path):
            print(f"[{idx}/{total_files}] Warning: label not found for {file_name}")
            file_summary["note"] = "Label file not found"
            all_file_summary.append(file_summary)
            continue

        # Broad catch at this batch boundary: one corrupt case must not abort
        # the whole dataset run; the failure is recorded in the summary.
        try:
            result = process_single_image(image_path, label_path, output_dir_base=output_dir)
        except Exception as e:
            print(f"[{idx}/{total_files}] Error processing {file_name}: {e}")
            file_summary["note"] = f"Error: {e}"
            all_file_summary.append(file_summary)
            continue

        file_summary["current_labels"] = result["processed_labels"]
        file_summary["missing_labels"] = result["missing_labels"]
        all_file_summary.append(file_summary)

        # Checkpoint immediately after each case so an interruption loses at
        # most the case currently in flight.
        progress[name] = {
            "finished": True,
            "processed_labels": result["processed_labels"],
            "missing_labels": result["missing_labels"]
        }
        save_progress(progress)

        print(f"[{idx}/{total_files}] Finished: {file_name} | Missing labels: {result['missing_labels'] or 'None'}")

    # --- Summary ---
    summary_path = os.path.join(output_dir, "all_files_label_summary.txt")
    os.makedirs(os.path.dirname(summary_path), exist_ok=True)
    print(f"\nWriting summary to {summary_path}...")

    # Bug fix: the summary contains non-ASCII emoji; without an explicit
    # encoding, open() uses the platform default (e.g. cp1252 on Windows)
    # and f.write would raise UnicodeEncodeError. Force UTF-8.
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("--- CTSpine1K Dataset Label Summary ---\n")
        f.write(f"Total files processed: {total_files}\n\n")

        for summary in all_file_summary:
            f.write("================================================\n")
            f.write(f"File: {summary['file_name']}\n")

            processed_labels_str = ", ".join([str(l) for l in summary['current_labels']])
            f.write(f"Labels processed: {processed_labels_str}\n")

            if summary['missing_labels']:
                missing_str = ", ".join([f"{l} ({LABEL_MAP.get(l, 'Unknown')})" for l in summary['missing_labels']])
                f.write(f"🚨 Missing Labels: {missing_str}\n")
            else:
                f.write("✅ Missing Labels: None\n")

            if "note" in summary:
                f.write(f"Note: {summary['note']}\n")
            f.write("================================================\n\n")

    print("All done! Summary file created.")