adm18/IMPAX/nb2.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using device: cuda\n",
"\n",
"GeForce RTX 2080 Ti\n",
"Memory Usage:\n",
"Allocated: 0.0 GB\n",
"Cached: 2.6 GB\n",
"[1, 35] trian loss: 1433.191\n",
"[2, 35] trian loss: 355.490\n",
"[3, 35] trian loss: 196.416\n",
"[4, 35] trian loss: 147.377\n",
"[5, 35] trian loss: 135.332\n",
"[6, 35] trian loss: 121.579\n",
"[7, 35] trian loss: 107.172\n",
"[8, 35] trian loss: 95.045\n",
"[9, 35] trian loss: 97.039\n",
"[10, 35] trian loss: 83.606\n",
"[11, 35] trian loss: 85.872\n",
"[12, 35] trian loss: 71.687\n",
"[13, 35] trian loss: 85.657\n",
"[14, 35] trian loss: 80.232\n",
"[15, 35] trian loss: 82.041\n",
"[16, 35] trian loss: 71.219\n",
"[17, 35] trian loss: 60.671\n",
"[18, 35] trian loss: 63.611\n",
"[19, 35] trian loss: 59.995\n",
"[20, 35] trian loss: 58.747\n",
"[21, 35] trian loss: 51.676\n",
"[22, 35] trian loss: 52.033\n",
"[23, 35] trian loss: 64.550\n",
"[24, 35] trian loss: 50.678\n",
"[25, 35] trian loss: 61.712\n",
"[26, 35] trian loss: 67.763\n",
"[27, 35] trian loss: 55.693\n",
"[28, 35] trian loss: 52.792\n",
"[29, 35] trian loss: 49.608\n",
"[30, 35] trian loss: 51.096\n",
"[31, 35] trian loss: 61.687\n",
"[32, 35] trian loss: 58.796\n",
"[33, 35] trian loss: 60.086\n",
"[34, 35] trian loss: 40.097\n",
"[35, 35] trian loss: 53.928\n",
"[36, 35] trian loss: 60.938\n",
"[37, 35] trian loss: 56.628\n",
"[38, 35] trian loss: 52.814\n",
"[39, 35] trian loss: 48.954\n",
"[40, 35] trian loss: 44.150\n",
"[41, 35] trian loss: 47.912\n",
"[42, 35] trian loss: 50.415\n",
"[43, 35] trian loss: 61.354\n",
"[44, 35] trian loss: 48.144\n",
"[45, 35] trian loss: 50.795\n",
"[46, 35] trian loss: 52.000\n",
"[47, 35] trian loss: 43.142\n",
"[48, 35] trian loss: 46.815\n",
"[49, 35] trian loss: 51.042\n",
"[50, 35] trian loss: 42.865\n",
"[51, 35] trian loss: 43.683\n",
"[52, 35] trian loss: 36.757\n",
"[53, 35] trian loss: 59.870\n",
"[54, 35] trian loss: 42.820\n",
"[55, 35] trian loss: 40.697\n",
"[56, 35] trian loss: 45.121\n",
"[57, 35] trian loss: 41.697\n",
"[58, 35] trian loss: 51.218\n",
"[59, 35] trian loss: 45.756\n",
"[60, 35] trian loss: 35.945\n",
"[61, 35] trian loss: 39.103\n",
"[62, 35] trian loss: 46.270\n",
"[63, 35] trian loss: 33.054\n",
"[64, 35] trian loss: 36.228\n",
"[65, 35] trian loss: 40.598\n",
"[66, 35] trian loss: 46.197\n",
"[67, 35] trian loss: 47.982\n",
"[68, 35] trian loss: 45.234\n",
"[69, 35] trian loss: 41.814\n",
"[70, 35] trian loss: 35.128\n",
"[71, 35] trian loss: 43.713\n",
"[72, 35] trian loss: 35.335\n",
"[73, 35] trian loss: 52.862\n",
"[74, 35] trian loss: 45.155\n",
"Finished Training\n",
"test loss: 48.523\n"
]
}
],
"source": [
"import random\n",
"import statistics\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import torch.optim as optim\n",
"\n",
"from dataset import *\n",
"from models import *\n",
"\n",
"BATCH_SIZE = 32\n",
"TEST_STEP = 5\n",
"\n",
"trainset = IMPAXDataset('/shares/Public/IMPAX/train')\n",
"testset = IMPAXDataset('/shares/Public/IMPAX/train')\n",
"\n",
"# print(len(trainset))\n",
"# exit()\n",
"\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,\n",
" shuffle=True, num_workers=6)\n",
"\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=TEST_STEP,\n",
" shuffle=True, num_workers=6)\n",
"\n",
"# setting device on GPU if available, else CPU\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"print('Using device:', device)\n",
"print()\n",
"\n",
"#Additional Info when using cuda\n",
"if device.type == 'cuda':\n",
" print(torch.cuda.get_device_name(0))\n",
" print('Memory Usage:')\n",
" print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')\n",
" print('Cached: ', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')\n",
"\n",
"# device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"net = N90().to(device)\n",
"\n",
"# criterion = nn.MSELoss(reduction='sum')\n",
"criterion = nn.MSELoss()\n",
"# optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
"# optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.5)\n",
"\n",
"# optimizer = optim.Adam(net.parameters(), lr=0.01)\n",
"optimizer = optim.Adam(net.parameters())\n",
"\n",
"\n",
"# for epoch in range(3): # 训练所有!整套!数据 3 次\n",
"# for step, (batch_x, batch_y) in enumerate(trainloader): # 每一步 loader 释放一小批数据用来学习\n",
"# # 假设这里就是你训练的地方...\n",
"\n",
"# # 打出来一些数据\n",
"# print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',\n",
"# batch_x.numpy(), '| batch y: ', batch_y.numpy())\n",
"# exit()\n",
"\n",
"\n",
"train_loss = []\n",
"\n",
"for epoch in range(99): # loop over the dataset multiple times\n",
"\n",
" running_loss = 0.0\n",
" for i, data in enumerate(trainloader, 0):\n",
" # get the inputs; data is a list of [inputs, labels]\n",
" inputs, labels = data[0].to(device), data[1].to(device)\n",
"\n",
" # print(inputs[0])\n",
" # print(labels[0])\n",
" # exit()\n",
" # print(inputs)\n",
" # break\n",
" # # continue\n",
"\n",
" # zero the parameter gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # forward + backward + optimize\n",
" outputs = net(inputs)\n",
" loss = criterion(outputs, labels)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" # print statistics\n",
" running_loss += loss.item()\n",
" \n",
"\n",
" print('[%d, %5d] trian loss: %.3f' % (epoch + 1, i, running_loss/i))\n",
" \n",
" train_loss.append(running_loss/i)\n",
" \n",
"# print(train_loss)\n",
"# print(train_loss[-5:])\n",
" \n",
" if epoch > 20:\n",
" if statistics.mean(train_loss[-10:-1]) > statistics.mean(train_loss[-20: -10]):\n",
" break\n",
"\n",
"print('Finished Training')\n",
"\n",
"\n",
"test_loss = []\n",
"\n",
"with torch.no_grad():\n",
" for data in testloader:\n",
" images, labels = data\n",
" outputs = net(images.to(device))\n",
" test_loss.append(criterion(labels, outputs.cpu()))\n",
"\n",
"mean = torch.mean(torch.stack(test_loss))\n",
"print('test loss: %.3f' % (mean.item()))\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABIkAAAYpCAYAAADSMUyfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9yXNcd5bv973zzUyAoqQulaqKVRIHFQeJkzBwlNRVXd3x3JvnlcNv0e6FI2phe+/+Bxzx1t44oh3xop8XtsMbx+tFR7g7ukslThg5SKQ4iOIgiqWhBkpATnf2Avz+cO5lAiSAJJEAzieCAeDmnQDmL8+553zPOVZRFFAURVEURVEURVEURVG2NvZ634CiKIqiKIqiKIqiKIqy/miQSFEURVEURVEURVEURdEgkaIoiqIoiqIoiqIoiqJBIkVRFEVRFEVRFEVRFAUaJFIURVEURVEURVEURVGgQSJFURRFURRFURRFURQFzzFIZFnWv7Ms66ZlWbcty/q753UdRVGeHV2XijKY6NpUlMFE16aiDCa6NhXl+WEVRdH/k1qWA+AWgL8E8CWAaQD/oSiKT/t+MUVRngldl4oymOjaVJTBRNemogwmujYV5fnyvJRE4wBuF0VxpyiKGMD/DeDfP6drKYrybOi6VJTBRNemogwmujYVZTDRtakoz5HnFST6CYAH4ucvH29TFGX90HWpKIOJrk1FGUx0bSrKYKJrU1GeI+56XdiyrF8D+PXjH0ceb0NRFLBtG3meL3UcnkeJnKKsM38oiuIH630TQO+1qShbmIFdm7ZtP2E3aR8ty1q3+1SU583j9/rAvMnVbirKIoO8Nm1bZzYpW5OV2s3nFSR6COCn4ucdj7cZiqL4ewB/DwCWZRWO4yDPcwwPD6PZbAIA6ABnWQbXddFoNNBsNpcMICnKRqUoivsv4DJPXZeP76W0NvVhU9nKDOradBynqNVqyLIMQRAgjmMA5eBQURSo1WqIokiTK8qmotvtvqhLrcpuvphbU5QtzarsZqPRALDwjCntYlEUsCyrJEZQu6lsJlqt1or2f17h1GkAb1mWtdOyLB/AfwvgH5famYEgAJifny8tSsdxYFkW0jTF999/r9lRRVk9K1qXiqK8MFa1NunQygAR/wELDq58mFb7qSgrRu2mogwmq1qbeZ7DsiyjvC2KAnmew3Ec8z3/KcpW5rkoiYqiSC3L+p8A/H8AHAD/qSiKa8vsD9/3kSQJqChyXRdpmhrn1/M8JEmii1ZRVslK16WiKC+G1a5N2kOZAQWALMtKTi+wkDXNskwDRYqyAtRuKspgstq1yedJ2kLP85BlGeI4hm3bpnRby9KUrY41CFI6SnOvXLmCZrOJMAyRpilOnDiBPM8xOztrei6Mjo6q/E/ZdBRFMVsUxeh630cVLTdTtjqDujYdxynCMASwECRiQChJEuPoMngkkysymKQoz5Olekj2o7dkt9tFlmUD+UamT1v9HaXCr7pNUTYTg9STSOK6bhGGIR48eIA0TdFoNDA3N4dDhw6h0+ng1q1bCIIAALBr1y6kaarBImXT0Gq1VmQ3161xdS/iOMYHH3yANE2NExEEAUZGRuA4DrIsw/T0NMbGxozSyLKsZRtdK1sbBhf5/gEWnDLXdRFFUem9s1TjdM/zAABJkjxx/iAIEEWRyUzwfUloXFzXRRzHpdfl9TTwqSjKSuFnFZVDRVGUPq9kppQ/6+eO8ixUfSvLsuA4jnkfye29cBwHAIyfBiwq3hzHQZqmxi5T5SavzXNwv43q4/X6+1T7hmmgSFFeDEVRIMsyzM3N4cCBA3BdF1mWIcsy1Go17N+/3yRabt++jQMHDphnBX7u6XpVXjTrZScGIjxKp4GlZcDiHySKIgALzgKdjkuXLiFNU9TrdeMU8zVFkQRBYB6eWH6RZRmiKEK9Xkee5/B93zw4hWEIy7LADD2w8LCVJAmGh4cBwLwWhiGiKHrCcZZZBz6U0QFO0xSe55Wut1GdX0VR1hd+frmuC8/z4LouXHch9yMbb/KzhoojYCFwrRlSpReu65bePwwWJUlibKrrusaO+b4P27ZNgBKAefBqNBooisLs7/u+yc4zYcJrAOW+WryHLMtKPp68zmZAg7WK8mKwLAue55nPIPrtssWJZVlIkgRxHOPy5cvm+VN+TinKSlhKUbvc61XF6XrYiYF4tzOyOzw8jImJCUxPT5vtvu8DWFAZOY4Dz/NM4Kjb7cK2bXS73ZJ6Q1EIH4qYjUzT1Lyn2u12aT9g4T1VFAWiKMLw8LB54AKAZrOJ4eFhExhiQ1hmQhmwZDAKKJd2yG0Mgna7XQ1wKoqyKhjwZnBIfi6x5MxxHLiui1qtZgJGdII1QK0shcyaS3V3t9stNX0FgCiKkOe5UdVKh7bVapmek2yyzqQN95M/97o+UHaYe6l6BxX+foqirD8M9Lz66qu4fv06Jicnn1AqpmkK13VRr9fxu9/9zjyjrteDurLx6ZUE4XYKZUj1Nf4sv/baLo+v7rPq+17T0X2Cv8ShQ4cwOjqKPM9x7dpC7zE6CWEYIo5jdLtdhGFosll5nqNWq6n8T+lJmqbIsgy+7yPLMvM+kriuazKhQRCYeuT5+flS8LEoCszPzwOAGTvN7CmPoaHJsqwkY2VJyPDwMOI4hud5JoAlZfaKoijPCh+u0zRFnudGLcSAdbWRteu65uGeCkpFqZJlWSlr7vu++bkayKl+jeP4iYAPgzpSUcRytqpDzISLvB79Pb62UbL51UmDiqKsP1mWYefOnTh8+DCiKMLU1BQ8zyslVrIsQ6fTwcsvv2we4LU0VHkWpD2sJglk8IbvKz4fSvWu3MbtvZRHMnDJ1+nj9cNODoSllYqhIAhw8uRJzM/PG/mfHOPr+z6OHDmCKIrMH7nT6ZQUH4pC6FhyakGn0zElZcDC+63dbpuAYxRF6HQ6ABaVP57nmQcubvN937wHi6Iwx1CxRDkrM/x8aJNBps0mmVcU5cXieZ55gI/j2Chu+TAtnYU//elPpYA0ZfWKUkUmO1j2z6SHzLjTgWUyRI6P5r7VQJJUATG4KQOdcmIfz0v1OAOdG42nrTNdh4ry4mAFSrfbxalTp8znDD/DHMcxid/Tp0+b509VEinPQtV+AotJENo5+mzcR9qAqoJIJlP4/qNCnMczICSn2dL+ruV9OxDWltknPszzF0rTFGEYotvt4vLly+h0OhgdXRwywwd1QLM0Sm8YaOT76OOPP0an04HneRgdHTUS+I8++gj1eh2jo6OYmJiA53loNpuo1+tGeeS6LrrdLoaGhnD48GHjrPLB69KlSzh69Cimp6fNdD46xK7r4ty5c7AsC2NjY6UyOEVRlLVCmylLhHzfNw5JEASmNFbLzJTlYCa9OsyBgxqYUa+WWlcdYdnzA0DJaZXlalT78lhZ/iGbrMvg50anV3ZYUVbKUpPyeikNlqPX+5DrfKn36kZ+D1PNXxQFtm/fbrZbloWbN28ijmMcPHjQiBB00IPyrFBQ8M033
+Dhw4eo1+vYs2cPfN+H53n45JNP0Gg08MYbb+D27dvwPA/dbtdUvLCHJLCQzNu5c6dRgdMW//73v8drr72Gzz//HACwd+/e0jCJzz77DK7r4s033zQK3tVUrQxEkIgLL8/z0pQoYCHTNDk5iaIocOLECQAoNR8GUGoMrCgSqSQCFrKYdED5mm3bqNVqSNMUaZqaII5t25icnITruhgZGQGw8F5jzywuZun0AuUmnITONGE/orVGeRVF2ZrwgVo2EOZ2YFEJKZvm09Hl59VmeNhW+k8ve8Ree5xExmaudIizLDNJPqpkmQSRZRwy00lbadu26Vsks6ZSPk/7KbOzG5WN/HCtbDyWCh5Vt20V+FnGxIrcdu/ePeR5jrffftu0jaiqNBRlKWQ8otvtYvv27Zibm0Oj0TD9+Ph9lmV46623jGr25s2bcF0XO3fuNO+527dvG98NWFTfttvt0sRuQp+u2uNotb7eQJSbAcDFixcBLDgAU1NTGB8fN9MvbNvG0aNHTe+Y6ihzbQyorJRjx44
"text/plain": [
"<Figure size 1440x2160 with 25 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.rcParams['figure.figsize'] = [20, 30]\n",
"\n",
"dataiter = iter(testloader)\n",
"# dataiter = iter(trainloader)\n",
"images, labels = dataiter.next()\n",
"output = net(images.to(device))\n",
"\n",
"# torch.set_printoptions(profile=\"full\")\n",
"# print(labels[0])\n",
"# torch.set_printoptions(profile=\"default\")\n",
"\n",
"for j in range(TEST_STEP):\n",
" out = output[j]\n",
" plt.subplot(TEST_STEP,5,j*5+1)\n",
" plt.imshow(images[j][0,:,:], cmap='gray')\n",
" \n",
" plt.subplot(TEST_STEP,5,j*5+2)\n",
" plt.imshow(labels[j][0,:,:], cmap='gray')\n",
" plt.subplot(TEST_STEP,5,j*5+3)\n",
" plt.imshow(out[0,:,:].cpu().detach().numpy(), cmap='gray')\n",
" \n",
" plt.subplot(TEST_STEP,5,j*5+4)\n",
" plt.imshow(labels[j][1,:,:], cmap='gray')\n",
" plt.subplot(TEST_STEP,5,j*5+5)\n",
" plt.imshow(out[1,:,:].cpu().detach().numpy(), cmap='gray')\n",
" \n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}