CBT_project/config/device.py

10 lines
308 B
Python
Raw Normal View History

2026-04-10 05:25:27 +00:00
import torch
def get_device(gpu_id: int = 0) -> torch.device:
    """Return the torch device to run on, preferring the requested CUDA GPU.

    When CUDA is available, selects ``cuda:<gpu_id>`` and prints the GPU's
    name; otherwise falls back to the CPU and prints a notice.

    Args:
        gpu_id: Index of the CUDA device to use. Ignored when CUDA is not
            available.

    Returns:
        The selected ``torch.device`` (``cuda:<gpu_id>`` or ``cpu``).

    Raises:
        ValueError: If CUDA is available but ``gpu_id`` is not a valid device
            index (negative, or >= ``torch.cuda.device_count()``).
    """
    if not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        return torch.device("cpu")

    # Validate up front so a bad index produces a clear message instead of a
    # less descriptive error raised later from inside torch.
    num_gpus = torch.cuda.device_count()
    if not 0 <= gpu_id < num_gpus:
        raise ValueError(
            f"Invalid gpu_id {gpu_id}: {num_gpus} CUDA device(s) available "
            f"(valid range 0..{num_gpus - 1})"
        )

    device = torch.device(f"cuda:{gpu_id}")
    print(f"Using GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")
    return device