192 lines
693 KiB
Text
192 lines
693 KiB
Text
|
|
{
|
||
|
|
"cells": [
|
||
|
|
{
|
||
|
|
"cell_type": "code",
|
||
|
|
"execution_count": 1,
|
||
|
|
"metadata": {},
|
||
|
|
"outputs": [
|
||
|
|
{
|
||
|
|
"name": "stdout",
|
||
|
|
"output_type": "stream",
|
||
|
|
"text": [
|
||
|
|
"[1, 23] loss: 20962.588\n",
|
||
|
|
"[2, 23] loss: 5299.285\n",
|
||
|
|
"[3, 23] loss: 3746.179\n",
|
||
|
|
"[4, 23] loss: 3133.387\n",
|
||
|
|
"[5, 23] loss: 3026.136\n",
|
||
|
|
"[6, 23] loss: 2714.172\n",
|
||
|
|
"[7, 23] loss: 2451.235\n",
|
||
|
|
"[8, 23] loss: 2314.406\n",
|
||
|
|
"[9, 23] loss: 2621.939\n",
|
||
|
|
"[10, 23] loss: 2520.008\n",
|
||
|
|
"[11, 23] loss: 2269.186\n",
|
||
|
|
"[12, 23] loss: 2110.618\n",
|
||
|
|
"[13, 23] loss: 2268.559\n",
|
||
|
|
"[14, 23] loss: 2196.927\n",
|
||
|
|
"[15, 23] loss: 2362.194\n",
|
||
|
|
"[16, 23] loss: 2588.966\n",
|
||
|
|
"[17, 23] loss: 2107.856\n",
|
||
|
|
"[18, 23] loss: 2017.883\n",
|
||
|
|
"[19, 23] loss: 2223.861\n",
|
||
|
|
"[20, 23] loss: 1975.168\n",
|
||
|
|
"Finished Training\n"
|
||
|
|
]
|
||
|
|
}
|
||
|
|
],
|
||
|
|
"source": [
|
||
|
|
"import random\n",
|
||
|
|
"\n",
|
||
|
|
"import matplotlib.pyplot as plt\n",
|
||
|
|
"import torch.optim as optim\n",
|
||
|
|
"\n",
|
||
|
|
"from dataset import *\n",
|
||
|
|
"from models import *\n",
|
||
|
|
"\n",
|
||
|
|
"TEST_STEP = 4\n",
|
||
|
|
"\n",
|
||
|
|
"trainset = IMPAXDataset('/shares/Public/IMPAX/train')\n",
|
||
|
|
"testset = IMPAXDataset('/shares/Public/IMPAX/train')\n",
|
||
|
|
"\n",
|
||
|
|
"# print(len(trainset))\n",
|
||
|
|
"# exit()\n",
|
||
|
|
"\n",
|
||
|
|
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=100,\n",
|
||
|
|
" shuffle=True, num_workers=2)\n",
|
||
|
|
"\n",
|
||
|
|
"testloader = torch.utils.data.DataLoader(testset, batch_size=TEST_STEP,\n",
|
||
|
|
" shuffle=True, num_workers=2)\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
||
|
|
"net = N90().to(device)\n",
|
||
|
|
"\n",
|
||
|
|
"# criterion = nn.MSELoss(reduction='sum')\n",
|
||
|
|
"criterion = nn.MSELoss()\n",
|
||
|
|
"# optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
|
||
|
|
"# optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.5)\n",
|
||
|
|
"\n",
|
||
|
|
"optimizer = optim.Adam(net.parameters(), lr=0.01)\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"# for epoch in range(3): # 训练所有!整套!数据 3 次\n",
|
||
|
|
"# for step, (batch_x, batch_y) in enumerate(trainloader): # 每一步 loader 释放一小批数据用来学习\n",
|
||
|
|
"# # 假设这里就是你训练的地方...\n",
|
||
|
|
"\n",
|
||
|
|
"# # 打出来一些数据\n",
|
||
|
|
"# print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',\n",
|
||
|
|
"# batch_x.numpy(), '| batch y: ', batch_y.numpy())\n",
|
||
|
|
"# exit()\n",
|
||
|
|
"\n",
|
||
|
|
"\n",
|
||
|
|
"for epoch in range(20): # loop over the dataset multiple times\n",
|
||
|
|
"\n",
|
||
|
|
" running_loss = 0.0\n",
|
||
|
|
" for i, data in enumerate(trainloader, 0):\n",
|
||
|
|
" # get the inputs; data is a list of [inputs, labels]\n",
|
||
|
|
" inputs, labels = data[0].to(device), data[1].to(device)\n",
|
||
|
|
"\n",
|
||
|
|
" # print(inputs[0])\n",
|
||
|
|
" # print(labels[0])\n",
|
||
|
|
" # exit()\n",
|
||
|
|
" # print(inputs)\n",
|
||
|
|
" # break\n",
|
||
|
|
" # # continue\n",
|
||
|
|
"\n",
|
||
|
|
" # zero the parameter gradients\n",
|
||
|
|
" optimizer.zero_grad()\n",
|
||
|
|
"\n",
|
||
|
|
" # forward + backward + optimize\n",
|
||
|
|
" outputs = net(inputs)\n",
|
||
|
|
" loss = criterion(outputs, labels)\n",
|
||
|
|
" loss.backward()\n",
|
||
|
|
" optimizer.step()\n",
|
||
|
|
"\n",
|
||
|
|
" # print statistics\n",
|
||
|
|
" running_loss += loss.item()\n",
|
||
|
|
"\n",
|
||
|
|
" if i % 2000 == 1999: # print every 2000 mini-batches\n",
|
||
|
|
" print('[%d, %5d] loss: %.3f' %\n",
|
||
|
|
" (epoch + 1, i + 1, running_loss / 2000))\n",
|
||
|
|
" running_loss = 0.0\n",
|
||
|
|
"\n",
|
||
|
|
" print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss))\n",
|
||
|
|
"\n",
|
||
|
|
"print('Finished Training')\n"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"cell_type": "code",
|
||
|
|
"execution_count": 3,
|
||
|
|
"metadata": {},
|
||
|
|
"outputs": [
|
||
|
|
{
|
||
|
|
"data": {
|
||
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABIYAAAZnCAYAAAAWaHwOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOy9eZAd1XUG/vXbt3kzEkJiMXacVCqrWARaEKAFJMQisRiwkTBLjJ3grQAn5TgxLsex40qc2MQ2dmzAFhgQm20hsSMhoX0FISin8vs59XPFC5uk0Sxv3/r3x/N35vSd7n5vRgIL635VU29ev+6739vnfPeccx3XdWFhYWFhYWFhYWFhYWFhYWFhcfQh8rsugIWFhYWFhYWFhYWFhYWFhYXF7waWGLKwsLCwsLCwsLCwsLCwsLA4SmGJIQsLCwsLCwsLCwsLCwsLC4ujFJYYsrCwsLCwsLCwsLCwsLCwsDhKYYkhCwsLCwsLCwsLCwsLCwsLi6MUlhiysLCwsLCwsLCwsLCwsLCwOErxthFDjuNc4DjO/+M4zv86jvO5tysfCwsLC4t3H+w7wsLCwsIiDPY9YWFhYfHOwXFd9/An6jhRAP8vgIUAfg1gF4Clruv+92HPzMLCwsLiXQX7jrCwsLCwCIN9T1hYWFi8s3i7LIZmAPhf13X/P9d1awAeAnDp25SXhYWFhcW7C/YdYWFhYWERBvuesLCwsHgHEXub0j0RwK/U918DmBl0s+M4ruM4h7UAruvCTLPba4czz7HmZVpwHe52CcvvUPIy63Yo7coydfv8WOvAsrmui0ikzY22Wi1EIhG0Wi3PvbxPf3ccR+57u/vnd4mx9uF47ifG09djee5IgTm2u10TXNd9d1W0M8b0jgDa74m3tUQWFhYW72LY9wQQiURcynUWFhYWFiNotVpotVqh74m3ixjqCMdx/hrAX/N7IpHQv/kqS7xGZd1PseQ113URi8XgOA4ajYZc0y8MkgEdyun5ZDpmGV3XRavVQjQa9S1jq9UapdTyd7968ZofWREEP7fAIKUzrA3NegYRavrTcZxRbWum22w2R7W3zj8SiUg7ApC2HEu9/GCOjVarhWQyCcdxUKvVAADJZBLlchnxeFzKWC6XkUgkPGWORqOeurVarVHjwFT+SSTpPvUb467rotlsjipzEBnCe8Pqqv9nPSKRiPzxGT1PWC+zzLyu89L9xfml66bT7FSvbhC2Lpjfg8aI2TbmOAtKr5vyOo4j/WKuE2wrPVfCiC7+VqlUOub7+wrzPWFhYWFhYaGh3xORSAS9vb1+98j/WqY1/9f3mvJxmAwwls2tbjZRu7kvrHxBsvvhxFg3cf2eC9LjguQuv34IuuZXxvGWWZdd52fqb/p/nYepKwBtfcKU44Pq7icbBl0PategfMw0zeeC0g2CX/276edu7/frw7HoE2MZd+PdHO/0vF/du9EH/NaETmNc5zUwMNCxHm8XMfQbACep7+/57TWB67p3ArgTaDP8b1M5LCwsLCyOPHR8RwDe94S1GLKwsLA4qjDm90QsFrPvCQsLC4tx4u0ihnYB+GPHcd6P9iJ+NYBl3TyoLRkIP8sa/Zu+xu/RaFTS0RYBtGgA4LF0MNMzofMMskiIRqO+9wX9ry0LeJ1WM9oiZzyWFZ3gxz76WV35WbaYlkF+uxIm+2mmRSsafb/ZrubneNuBz8Vi7eEejUZRKpU81yqVilgD0UIjk8mgXq+L1Us0GkWlUsGECRPknjBrFf1/s9kcda9f/cMs2Mw2TCaToyzROF6YNq3i9HNBFkOmNZRZH79rQRZN3VjxjIftD7K0CcrXzCfsmU73dZOmHvd+FkO6DwB/S8K3c4fvCMK43xEWFhYWFkcFDst74lAsC/Qz3VimdMq/W1kiLA19zbTeD/r97USQNVOn/M3fwnSx8VglhT3TTZ92k7Zf/2u5m/pDLBZDo9EQfZRW+alUCtVqNbDMfrpUWH061SGszYPq2cmC6XChm/qb94+nLGEWR4erjmN51m8sdVonwqzGDlVnfluIIdd1G47jfArAswCiAH7ouu7PQu73raQmbsx7zGsmyaNjwOhBYCrI9Xrdt0xmw3aKPUNos8BOndLNgO/WjUyXOeh7p7z9JoZ5TZMJ+tPPTNJ8GYa9nM2+8SuX33NB8LuPbakXat3O8Xgc5XIZqVRK7nMcR8ijRqOBY445Bvv27ZM0E4lEYLmDyuK6bZcxPVYikQii0ajvvUF1JhHZbDZD24j36TFt5s/fea3RaPgSlM1mU9rDr55+RFUnBAk03VzrBp1efOYcCxIournmt075PWNe1+tVGPn1+4KxviMsLCwsLI4uHI73RJCSFPae9nv3dqOsdUI3G89jlZ+6qd87gW6IID+MlbTz0xW7lSFNpTmovJ1IPv07w33ocCjcCG42m6JPAG1ZnPeVSiXROcLK1KlNxlp+874gUkGnY7Z5mF4ynnKMJQ2/svB/v/vDrnfTtmHkcCd9tRM6kXndklSHe+6/LcfVj7kQjuNyssRisVGxZjgIqZBqMiJoUJhMJ+/RFhlUdE2ySIN5d3oB8HdtBdSp0x3HazVDhlnXq5u8g9IP+42LWZCFiIY5OM2Fw4/YCyqLzs9cXMmk+8XZMZ8xnw0rs3lfs9lEPB5HIpHA8PCw/JZIJFCr1RCPxwFArIWy2SyANllSqVSQTCaRyWQAAIVCwVP2oHqbbdJsNj3xd6LRKGKx2CiyJigtfrIPwxYqs2x+pKPruh6LoUaj4YnBxTnJtjPb0y89fvrNw25f5J1+M+dH0O5GGLoVZvzq4QczrlLY/ZyHYfOnXC6j2WweFWZEYXCsK5mFhYVFINzfv+DTY0YsFnMZY6ibd7ufInyoO+6d8ghTLrslWMKUSP27HxlyuOBXhyDF3Q9+7RzUHn75hiFIoe+GTPIjAYLKR90hHo+LoUE0GkWxWEQ0GkU+nwfQluN6enpQLpcBtPWLarXq2aCm7NjNmDxcY7SbPtBl6ZRvt2P3cNwTdu9YxnwQ2Tge4tccZ0F5dSLkzDp1qzcFzRdeGxgYQKPRCK3Y7yz4tEYQkaAVYPPTbFxzsHKymQ1pki581nSr8SMgTPgNRDNIbxjzG+RG0i1L2Kk8h+M5kgrm/bp9zEHsV8+gfHX/mUy02bd+z42HKSWTf/DgQezZswdAe3GfNm0aMpmMuInt2LED8XgcxWIRALBo0SK4rotcLof9+/cDaLtz+aXvN6HDSErt4sjvnQKjm5ZCQfU32595aRdL8zk/y7CwF4G+R6dnpqOfCSp7GDkZJoiY46fbk+NoUdVp7IT1qd889/vNb33wI1q7KbeFhcU7g06Ki4adtxYWRw6ClEbznd3NvB6rjB2msAblF6YL+ZV5rDpKt/CTYzopu355hJEcJrRc1c29fvkHlSuorEFKud//+jndj9FoFLVaDWvWrAHQ1gumTp2KyZMnY3BwEACwbt06DA0NSfDfW2+9VYgkP0LIr55hCBu/neo5HnQaF93Ol/Hocp363m/8hNU/jBwbazt2085mOmPp47GuQePhBeyZjhYWFhYWFhYWFhYWFhYWFhZHKY4IiyGTBfezcOjm2HZzZ97vftPaxXRbM1lCHVw5aGefz+l7/Ophpu3HkJrBnFm+bhC2O9Dt7oef5YNpzWK2F9My3eIIv/gpZhl123Sqr8kGdxoXTJuWQNlsFq7rIp/Py7P1el0CT9OS5swzz8SmTZswa9YsAEA8Hkc8Hsfg4KD4DYdZp5hui7o9zSDjQe5EfmNOjx+2Ka81m83AgNQE3dbMNjfz9+sr3faMK2Vax4XtGHWDbnbAguak/mT5Ou2M+I25Tjtt5q5W2Jw209HXzNhcftZWFhYWRy66WWMsLCzeOZiW6FomYixHv/mqAwNr2UYf4NFJD/B7v5vW4n6WP2ZsQZ2/KecGWVlojMWKJgx+OkG39wMYVXdgtH6gr/sd1BKkrwTpFlo+DLM86taKJkyu4yEwiUQCxxxzDIB2/KBoNIpqtSq6wpIlS/DQQw/h4x//uDzL+ppl96unX9uZz3SysgpKn2B
|
||
|
|
"text/plain": [
|
||
|
|
"<Figure size 1440x2160 with 12 Axes>"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
"metadata": {
|
||
|
|
"needs_background": "light"
|
||
|
|
},
|
||
|
|
"output_type": "display_data"
|
||
|
|
}
|
||
|
|
],
|
||
|
|
"source": [
|
||
|
|
"plt.rcParams['figure.figsize'] = [20, 30]\n",
|
||
|
|
"\n",
|
||
|
|
"# dataiter = iter(testloader)\n",
|
||
|
|
"dataiter = iter(trainloader)\n",
|
||
|
|
"images, labels = dataiter.next()\n",
|
||
|
|
"output = net(images.to(device))\n",
|
||
|
|
"\n",
|
||
|
|
"for j in range(TEST_STEP):\n",
|
||
|
|
" plt.subplot(TEST_STEP,3,j*3+1)\n",
|
||
|
|
" plt.imshow(images[j][0,:,:], cmap='gray')\n",
|
||
|
|
" plt.subplot(TEST_STEP,3,j*3+2)\n",
|
||
|
|
" plt.imshow(labels[j][0,:,:], cmap='gray')\n",
|
||
|
|
" plt.subplot(TEST_STEP,3,j*3+3)\n",
|
||
|
|
" out = output[j]\n",
|
||
|
|
"# print(out)\n",
|
||
|
|
" plt.imshow(out[0,:,:].cpu().detach().numpy(), cmap='gray')\n",
|
||
|
|
"plt.show()\n"
|
||
|
|
]
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"cell_type": "code",
|
||
|
|
"execution_count": null,
|
||
|
|
"metadata": {},
|
||
|
|
"outputs": [],
|
||
|
|
"source": []
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"cell_type": "code",
|
||
|
|
"execution_count": null,
|
||
|
|
"metadata": {},
|
||
|
|
"outputs": [],
|
||
|
|
"source": []
|
||
|
|
}
|
||
|
|
],
|
||
|
|
"metadata": {
|
||
|
|
"kernelspec": {
|
||
|
|
"display_name": "Python 3",
|
||
|
|
"language": "python",
|
||
|
|
"name": "python3"
|
||
|
|
},
|
||
|
|
"language_info": {
|
||
|
|
"codemirror_mode": {
|
||
|
|
"name": "ipython",
|
||
|
|
"version": 3
|
||
|
|
},
|
||
|
|
"file_extension": ".py",
|
||
|
|
"mimetype": "text/x-python",
|
||
|
|
"name": "python",
|
||
|
|
"nbconvert_exporter": "python",
|
||
|
|
"pygments_lexer": "ipython3",
|
||
|
|
"version": "3.7.4"
|
||
|
|
}
|
||
|
|
},
|
||
|
|
"nbformat": 4,
|
||
|
|
"nbformat_minor": 2
|
||
|
|
}
|