CBT_project/core/scoring.py
Xiao Furen b76f0708f3 1. correct the bounding box and cortical mask
2. make the plot isometric
3. now it should work after tuning the parameters
2026-04-17 00:03:10 +08:00

265 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_tip_torch
from config.constant import OVERLAP_THRESH
def cl_score_torch_xfr(
cortical_tensor: torch.Tensor,
spine_tensor: torch.Tensor,
cylinder_torch: torch.Tensor,
cylinder_o_torch: torch.Tensor,
intersections: int,
diameter: float = None,
length: float = None,
cylinder_tip_torch: torch.Tensor = None # 新增:尖端 mask
) -> float:
"""
漸進式評分:優先確保找到骨頭,再改善細節
"""
cyl_total = cylinder_torch.sum().item()
overlap = ((cortical_tensor == 1) & (cylinder_torch == 1)).sum().item()
null_vox = ((cortical_tensor == 0) & (cylinder_torch == 1)).sum().item()
null_vox2 = ((spine_tensor == 1) & (cylinder_o_torch == 1)).sum().item()
in_bone= ((spine_tensor == 1) & (cylinder_torch == 1)).sum().item()
not_in_bone= ((spine_tensor == 0) & (cylinder_torch == 1)).sum().item()
# if cyl_total == 0:
# return float(1000*1000)
# return float(1e9) # 極差的情況
overlap_ratio = overlap / cyl_total
# if cyl_total < 1000:
# return float((1000 - cyl_total)*10000)
score = cyl_total
# if in_bone == 0:
# return float(not_in_bone*200)
score += overlap*30
score += in_bone*10
score -= not_in_bone*1000
score -= null_vox2*1000
return float(-score)
# === 階段 1首要目標是找到骨頭overlap > 0 ===
if overlap == 0:
# 完全沒有 overlap 是最糟糕的情況
score -= 500000 # 超大懲罰
# 如果連 spine 都沒穿過,更糟
if intersections == 0:
score -= 500000
return float(-score)
# === 階段 2有找到骨頭了開始改善品質 ===
# 1. Overlap 獎勵(非線性,鼓勵快速提升)
if overlap_ratio < 0.1:
# 0-10%:每增加 1% 給大量獎勵(鼓勵探索)
score += overlap * 5000 # 很高的單位獎勵
elif overlap_ratio < 0.3:
# 10-30%:中等獎勵
score += overlap * 3000
elif overlap_ratio < 0.5:
# 30-50%:正常獎勵
score += overlap * 2000
else:
# 50%+:獎勵 + 額外比例獎勵
score += overlap * 2000
score += (overlap_ratio - 0.5) * 100000 # 超過 50% 額外大獎
# 2. Intersection 控制(稍微放寬)
if intersections == 1:
score += 20000 # 完美
elif intersections == 0:
score -= 200000 # 嚴重錯誤(但比完全沒 overlap 好)
elif intersections == 2:
score -= 10000 # 可接受但不理想
else:
score -= intersections * 15000
# 3. Null voxel 懲罰(漸進式)
null_ratio = null_vox / cyl_total
if overlap_ratio < 0.2:
# 如果 overlap 還很少,對 null voxel 寬容一點
score -= null_vox * 300
elif overlap_ratio < 0.4:
score -= null_vox * 600
else:
# overlap 夠高了,開始嚴格要求
if null_ratio > 0.5:
score -= null_vox * 1500
else:
score -= null_vox * 800
# 4. 反向圓柱懲罰
score -= null_vox2 * 1000
# 5. 尺寸合理性(放寬)
if diameter is not None and length is not None:
if diameter < 2.5 or diameter > 6.0: # 放寬從 (3.0, 5.5) 到 (2.5, 6.0)
score -= 3000
if length < 25 or length > 60: # 放寬從 (30, 55) 到 (25, 60)
score -= 3000
# 6. 尖端 breach 懲罰
if cylinder_tip_torch is not None:
tip_total = cylinder_tip_torch.sum().item()
if tip_total > 0:
tip_breach = ((cortical_tensor == 0) & (cylinder_tip_torch == 1)).sum().item()
tip_breach_ratio = tip_breach / tip_total
if tip_breach_ratio > 0:
score -= tip_breach * 5000 # 尖端出界懲罰要比一般 null_vox 重很多
return float(-score)
def cl_score_torch(
cortical_tensor: torch.Tensor,
spine_tensor: torch.Tensor,
cylinder_torch: torch.Tensor,
cylinder_o_torch: torch.Tensor,
intersections: int,
diameter: float = None,
length: float = None,
cylinder_tip_torch: torch.Tensor = None # 新增:尖端 mask
) -> float:
"""
漸進式評分:優先確保找到骨頭,再改善細節
"""
cyl_total = cylinder_torch.sum().item()
overlap = ((cortical_tensor == 1) & (cylinder_torch == 1)).sum().item()
null_vox = ((cortical_tensor == 0) & (cylinder_torch == 1)).sum().item()
null_vox2 = ((spine_tensor == 1) & (cylinder_o_torch == 1)).sum().item()
if cyl_total == 0:
return float(1e9) # 極差的情況
overlap_ratio = overlap / cyl_total
score = 0
# === 階段 1首要目標是找到骨頭overlap > 0 ===
if overlap == 0:
# 完全沒有 overlap 是最糟糕的情況
score -= 500000 # 超大懲罰
# 如果連 spine 都沒穿過,更糟
if intersections == 0:
score -= 500000
return float(-score)
# === 階段 2有找到骨頭了開始改善品質 ===
# 1. Overlap 獎勵(非線性,鼓勵快速提升)
if overlap_ratio < 0.1:
# 0-10%:每增加 1% 給大量獎勵(鼓勵探索)
score += overlap * 5000 # 很高的單位獎勵
elif overlap_ratio < 0.3:
# 10-30%:中等獎勵
score += overlap * 3000
elif overlap_ratio < 0.5:
# 30-50%:正常獎勵
score += overlap * 2000
else:
# 50%+:獎勵 + 額外比例獎勵
score += overlap * 2000
score += (overlap_ratio - 0.5) * 100000 # 超過 50% 額外大獎
# 2. Intersection 控制(稍微放寬)
if intersections == 1:
score += 20000 # 完美
elif intersections == 0:
score -= 200000 # 嚴重錯誤(但比完全沒 overlap 好)
elif intersections == 2:
score -= 10000 # 可接受但不理想
else:
score -= intersections * 15000
# 3. Null voxel 懲罰(漸進式)
null_ratio = null_vox / cyl_total
if overlap_ratio < 0.2:
# 如果 overlap 還很少,對 null voxel 寬容一點
score -= null_vox * 300
elif overlap_ratio < 0.4:
score -= null_vox * 600
else:
# overlap 夠高了,開始嚴格要求
if null_ratio > 0.5:
score -= null_vox * 1500
else:
score -= null_vox * 800
# 4. 反向圓柱懲罰
score -= null_vox2 * 1000
# 5. 尺寸合理性(放寬)
if diameter is not None and length is not None:
if diameter < 2.5 or diameter > 6.0: # 放寬從 (3.0, 5.5) 到 (2.5, 6.0)
score -= 3000
if length < 25 or length > 60: # 放寬從 (30, 55) 到 (25, 60)
score -= 3000
# 6. 尖端 breach 懲罰
if cylinder_tip_torch is not None:
tip_total = cylinder_tip_torch.sum().item()
if tip_total > 0:
tip_breach = ((cortical_tensor == 0) & (cylinder_tip_torch == 1)).sum().item()
tip_breach_ratio = tip_breach / tip_total
if tip_breach_ratio > 0:
score -= tip_breach * 5000 # 尖端出界懲罰要比一般 null_vox 重很多
return float(-score)
def get_overlap_ratio(
position_params: list,
diameter: float,
length: float,
cortical_tensor: torch.Tensor,
image_shape: tuple,
spacing: list,
device: torch.device,
grid=None
) -> float:
"""
計算 Cylinder 與 Cortical Bone 的重疊比例 (%)
"""
# 生成 Cylinder Mask
cyl_mask = generate_cylinder_n_torch(
diameter, length,
position_params[0], position_params[1], position_params[2],
position_params[3], position_params[4],
image_shape, spacing, device, grid
)
# 計算體積 (Voxel count)
cyl_vol = torch.sum(cyl_mask).item()
if cyl_vol == 0:
return 0.0
# 計算重疊部分
# 注意:這裡使用 cortical_tensor (與 cl_score_torch 邏輯一致)
overlap_count = ((cortical_tensor == 1) & (cyl_mask == 1)).sum().item()
return (overlap_count / cyl_vol) * 100.0
def compute_overlap_ratio_from_cylinder_mask(cyl_mask: torch.Tensor,
spine_mask: torch.Tensor,
eps: float = 1e-6) -> float:
"""
一個常見定義: overlap = intersection / cylinder_volume
你也可以改成 intersection / spine_volume 或 Dice依你論文/需求一致即可。
cyl_mask, spine_mask: uint8/bool tensor, same shape
"""
cyl = cyl_mask.bool()
sp = spine_mask.bool()
inter = (cyl & sp).sum().item()
denom = cyl.sum().item()
return float(inter) / float(denom + eps)
def is_solution_ok(loss: float, overlap: float, overlap_thresh: float = OVERLAP_THRESH) -> bool:
return (loss <= 0) and (overlap >= overlap_thresh)