151 lines
5.1 KiB
Python
151 lines
5.1 KiB
Python
import torch
|
||
from core.cylinder import generate_cylinder_n_torch, generate_cylinder_tip_torch
|
||
from config.constant import OVERLAP_THRESH
|
||
|
||
def cl_score_torch(
|
||
cortical_tensor: torch.Tensor,
|
||
spine_tensor: torch.Tensor,
|
||
cylinder_torch: torch.Tensor,
|
||
cylinder_o_torch: torch.Tensor,
|
||
intersections: int,
|
||
diameter: float = None,
|
||
length: float = None,
|
||
cylinder_tip_torch: torch.Tensor = None # 新增:尖端 mask
|
||
) -> float:
|
||
"""
|
||
漸進式評分:優先確保找到骨頭,再改善細節
|
||
"""
|
||
cyl_total = cylinder_torch.sum().item()
|
||
overlap = ((cortical_tensor == 1) & (cylinder_torch == 1)).sum().item()
|
||
null_vox = ((cortical_tensor == 0) & (cylinder_torch == 1)).sum().item()
|
||
null_vox2 = ((spine_tensor == 1) & (cylinder_o_torch == 1)).sum().item()
|
||
|
||
if cyl_total == 0:
|
||
return float(1e9) # 極差的情況
|
||
|
||
overlap_ratio = overlap / cyl_total
|
||
|
||
score = 0
|
||
|
||
# === 階段 1:首要目標是找到骨頭(overlap > 0) ===
|
||
if overlap == 0:
|
||
# 完全沒有 overlap 是最糟糕的情況
|
||
score -= 500000 # 超大懲罰
|
||
# 如果連 spine 都沒穿過,更糟
|
||
if intersections == 0:
|
||
score -= 500000
|
||
return float(-score)
|
||
|
||
# === 階段 2:有找到骨頭了,開始改善品質 ===
|
||
|
||
# 1. Overlap 獎勵(非線性,鼓勵快速提升)
|
||
if overlap_ratio < 0.1:
|
||
# 0-10%:每增加 1% 給大量獎勵(鼓勵探索)
|
||
score += overlap * 5000 # 很高的單位獎勵
|
||
elif overlap_ratio < 0.3:
|
||
# 10-30%:中等獎勵
|
||
score += overlap * 3000
|
||
elif overlap_ratio < 0.5:
|
||
# 30-50%:正常獎勵
|
||
score += overlap * 2000
|
||
else:
|
||
# 50%+:獎勵 + 額外比例獎勵
|
||
score += overlap * 2000
|
||
score += (overlap_ratio - 0.5) * 100000 # 超過 50% 額外大獎
|
||
|
||
# 2. Intersection 控制(稍微放寬)
|
||
if intersections == 1:
|
||
score += 20000 # 完美
|
||
elif intersections == 0:
|
||
score -= 200000 # 嚴重錯誤(但比完全沒 overlap 好)
|
||
elif intersections == 2:
|
||
score -= 10000 # 可接受但不理想
|
||
else:
|
||
score -= intersections * 15000
|
||
|
||
# 3. Null voxel 懲罰(漸進式)
|
||
null_ratio = null_vox / cyl_total
|
||
|
||
if overlap_ratio < 0.2:
|
||
# 如果 overlap 還很少,對 null voxel 寬容一點
|
||
score -= null_vox * 300
|
||
elif overlap_ratio < 0.4:
|
||
score -= null_vox * 600
|
||
else:
|
||
# overlap 夠高了,開始嚴格要求
|
||
if null_ratio > 0.5:
|
||
score -= null_vox * 1500
|
||
else:
|
||
score -= null_vox * 800
|
||
|
||
# 4. 反向圓柱懲罰
|
||
score -= null_vox2 * 1000
|
||
|
||
# 5. 尺寸合理性(放寬)
|
||
if diameter is not None and length is not None:
|
||
if diameter < 2.5 or diameter > 6.0: # 放寬從 (3.0, 5.5) 到 (2.5, 6.0)
|
||
score -= 3000
|
||
if length < 25 or length > 60: # 放寬從 (30, 55) 到 (25, 60)
|
||
score -= 3000
|
||
|
||
# 6. 尖端 breach 懲罰
|
||
if cylinder_tip_torch is not None:
|
||
tip_total = cylinder_tip_torch.sum().item()
|
||
if tip_total > 0:
|
||
tip_breach = ((cortical_tensor == 0) & (cylinder_tip_torch == 1)).sum().item()
|
||
tip_breach_ratio = tip_breach / tip_total
|
||
if tip_breach_ratio > 0:
|
||
score -= tip_breach * 5000 # 尖端出界懲罰要比一般 null_vox 重很多
|
||
|
||
return float(-score)
|
||
|
||
def get_overlap_ratio(
|
||
position_params: list,
|
||
diameter: float,
|
||
length: float,
|
||
cortical_tensor: torch.Tensor,
|
||
image_shape: tuple,
|
||
spacing: list,
|
||
device: torch.device,
|
||
grid=None
|
||
) -> float:
|
||
"""
|
||
計算 Cylinder 與 Cortical Bone 的重疊比例 (%)
|
||
"""
|
||
# 生成 Cylinder Mask
|
||
cyl_mask = generate_cylinder_n_torch(
|
||
diameter, length,
|
||
position_params[0], position_params[1], position_params[2],
|
||
position_params[3], position_params[4],
|
||
image_shape, spacing, device, grid
|
||
)
|
||
|
||
# 計算體積 (Voxel count)
|
||
cyl_vol = torch.sum(cyl_mask).item()
|
||
|
||
if cyl_vol == 0:
|
||
return 0.0
|
||
|
||
# 計算重疊部分
|
||
# 注意:這裡使用 cortical_tensor (與 cl_score_torch 邏輯一致)
|
||
overlap_count = ((cortical_tensor == 1) & (cyl_mask == 1)).sum().item()
|
||
|
||
return (overlap_count / cyl_vol) * 100.0
|
||
|
||
def compute_overlap_ratio_from_cylinder_mask(cyl_mask: torch.Tensor,
|
||
spine_mask: torch.Tensor,
|
||
eps: float = 1e-6) -> float:
|
||
"""
|
||
一個常見定義: overlap = intersection / cylinder_volume
|
||
你也可以改成 intersection / spine_volume 或 Dice,依你論文/需求一致即可。
|
||
cyl_mask, spine_mask: uint8/bool tensor, same shape
|
||
"""
|
||
cyl = cyl_mask.bool()
|
||
sp = spine_mask.bool()
|
||
inter = (cyl & sp).sum().item()
|
||
denom = cyl.sum().item()
|
||
return float(inter) / float(denom + eps)
|
||
|
||
def is_solution_ok(loss: float, overlap: float, overlap_thresh: float = OVERLAP_THRESH) -> bool:
|
||
return (loss <= 0) and (overlap >= overlap_thresh)
|
||
|