Skip to content

Commit

Permalink
edit weight_drop.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tonghuikang committed Jun 7, 2019
1 parent 668b94d commit 2ca28cf
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 24 deletions.
190 changes: 168 additions & 22 deletions logs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,10 @@
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
Expand All @@ -151,36 +153,180 @@
"Model total parameters: 13787650\n",
"/home/hkmac/.local/lib/python3.6/site-packages/torch/nn/modules/rnn.py:522: RuntimeWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters().\n",
" self.dropout, self.training, self.bidirectional, self.batch_first)\n",
"| epoch 1 | 200/ 261 batches | lr 0.00200 | ms/batch 162.74 | loss 2.98 | ppl 19.62 | bpc 4.294\n",
"| epoch 1 | 200/ 261 batches | lr 0.00200 | ms/batch 159.46 | loss 2.47 | ppl 11.84 | bpc 3.566\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 1 | time: 49.75s | valid loss 2.14 | valid ppl 8.52 | valid bpc 3.091\n",
"| end of epoch 1 | time: 49.22s | valid loss 1.81 | valid ppl 6.10 | valid bpc 2.610\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"Traceback (most recent call last):\n",
" File \"main.py\", line 240, in <module>\n",
" train()\n",
" File \"main.py\", line 196, in train\n",
" output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True)\n",
" File \"/home/hkmac/.local/lib/python3.6/site-packages/torch/nn/modules/module.py\", line 493, in __call__\n",
" result = self.forward(*input, **kwargs)\n",
" File \"/home/hkmac/awd-lstm-lm/model.py\", line 81, in forward\n",
" raw_output, new_h = rnn(raw_output, hidden[l])\n",
" File \"/home/hkmac/.local/lib/python3.6/site-packages/torch/nn/modules/module.py\", line 493, in __call__\n",
" result = self.forward(*input, **kwargs)\n",
" File \"/home/hkmac/awd-lstm-lm/weight_drop.py\", line 46, in forward\n",
" self._setweights()\n",
" File \"/home/hkmac/awd-lstm-lm/weight_drop.py\", line 43, in _setweights\n",
" setattr(self.module, name_w, w)\n",
" File \"/home/hkmac/.local/lib/python3.6/site-packages/torch/nn/modules/module.py\", line 558, in __setattr__\n",
" .format(torch.typename(value), name))\n",
"TypeError: cannot assign 'torch.cuda.FloatTensor' as parameter 'weight_hh_l0' (torch.nn.Parameter or None expected)\n"
"| epoch 2 | 200/ 261 batches | lr 0.00200 | ms/batch 157.20 | loss 1.67 | ppl 5.31 | bpc 2.409\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 2 | time: 49.51s | valid loss 1.36 | valid ppl 3.88 | valid bpc 1.956\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 3 | 200/ 261 batches | lr 0.00200 | ms/batch 157.01 | loss 1.39 | ppl 4.00 | bpc 1.999\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 3 | time: 49.49s | valid loss 1.24 | valid ppl 3.46 | valid bpc 1.793\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 4 | 200/ 261 batches | lr 0.00200 | ms/batch 158.76 | loss 1.29 | ppl 3.64 | bpc 1.865\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 4 | time: 49.22s | valid loss 1.19 | valid ppl 3.29 | valid bpc 1.719\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 5 | 200/ 261 batches | lr 0.00200 | ms/batch 160.15 | loss 1.24 | ppl 3.47 | bpc 1.795\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 5 | time: 49.51s | valid loss 1.16 | valid ppl 3.18 | valid bpc 1.671\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 6 | 200/ 261 batches | lr 0.00200 | ms/batch 158.86 | loss 1.21 | ppl 3.36 | bpc 1.748\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 6 | time: 49.46s | valid loss 1.14 | valid ppl 3.11 | valid bpc 1.639\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 7 | 200/ 261 batches | lr 0.00200 | ms/batch 159.45 | loss 1.19 | ppl 3.28 | bpc 1.714\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 7 | time: 49.53s | valid loss 1.12 | valid ppl 3.06 | valid bpc 1.616\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 8 | 200/ 261 batches | lr 0.00200 | ms/batch 158.21 | loss 1.17 | ppl 3.22 | bpc 1.687\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 8 | time: 49.42s | valid loss 1.11 | valid ppl 3.02 | valid bpc 1.597\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 9 | 200/ 261 batches | lr 0.00200 | ms/batch 161.16 | loss 1.15 | ppl 3.17 | bpc 1.666\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 9 | time: 49.61s | valid loss 1.10 | valid ppl 2.99 | valid bpc 1.580\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 10 | 200/ 261 batches | lr 0.00200 | ms/batch 160.91 | loss 1.14 | ppl 3.13 | bpc 1.648\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 10 | time: 49.20s | valid loss 1.09 | valid ppl 2.96 | valid bpc 1.567\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 11 | 200/ 261 batches | lr 0.00200 | ms/batch 156.48 | loss 1.13 | ppl 3.10 | bpc 1.634\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 11 | time: 49.59s | valid loss 1.08 | valid ppl 2.94 | valid bpc 1.557\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 12 | 200/ 261 batches | lr 0.00200 | ms/batch 158.18 | loss 1.12 | ppl 3.08 | bpc 1.622\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 12 | time: 49.47s | valid loss 1.07 | valid ppl 2.92 | valid bpc 1.546\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 13 | 200/ 261 batches | lr 0.00200 | ms/batch 158.63 | loss 1.11 | ppl 3.05 | bpc 1.608\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 13 | time: 49.32s | valid loss 1.07 | valid ppl 2.90 | valid bpc 1.538\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 14 | 200/ 261 batches | lr 0.00200 | ms/batch 160.47 | loss 1.11 | ppl 3.03 | bpc 1.597\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 14 | time: 49.60s | valid loss 1.06 | valid ppl 2.89 | valid bpc 1.529\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 15 | 200/ 261 batches | lr 0.00200 | ms/batch 159.30 | loss 1.10 | ppl 3.01 | bpc 1.588\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 15 | time: 49.45s | valid loss 1.06 | valid ppl 2.87 | valid bpc 1.522\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 16 | 200/ 261 batches | lr 0.00200 | ms/batch 158.45 | loss 1.09 | ppl 2.99 | bpc 1.578\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 16 | time: 49.56s | valid loss 1.05 | valid ppl 2.86 | valid bpc 1.514\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 17 | 200/ 261 batches | lr 0.00200 | ms/batch 157.30 | loss 1.09 | ppl 2.97 | bpc 1.570\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 17 | time: 49.59s | valid loss 1.04 | valid ppl 2.84 | valid bpc 1.507\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 18 | 200/ 261 batches | lr 0.00200 | ms/batch 158.68 | loss 1.08 | ppl 2.95 | bpc 1.562\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 18 | time: 49.52s | valid loss 1.04 | valid ppl 2.83 | valid bpc 1.502\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 19 | 200/ 261 batches | lr 0.00200 | ms/batch 158.27 | loss 1.08 | ppl 2.94 | bpc 1.554\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 19 | time: 49.46s | valid loss 1.04 | valid ppl 2.82 | valid bpc 1.498\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 20 | 200/ 261 batches | lr 0.00200 | ms/batch 160.60 | loss 1.07 | ppl 2.93 | bpc 1.549\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 20 | time: 49.14s | valid loss 1.03 | valid ppl 2.81 | valid bpc 1.492\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 21 | 200/ 261 batches | lr 0.00200 | ms/batch 158.56 | loss 1.07 | ppl 2.91 | bpc 1.543\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 21 | time: 49.30s | valid loss 1.03 | valid ppl 2.81 | valid bpc 1.489\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 22 | 200/ 261 batches | lr 0.00200 | ms/batch 160.03 | loss 1.07 | ppl 2.90 | bpc 1.537\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 22 | time: 49.15s | valid loss 1.03 | valid ppl 2.80 | valid bpc 1.484\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 23 | 200/ 261 batches | lr 0.00200 | ms/batch 161.16 | loss 1.06 | ppl 2.89 | bpc 1.533\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 23 | time: 49.62s | valid loss 1.03 | valid ppl 2.79 | valid bpc 1.480\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 24 | 200/ 261 batches | lr 0.00200 | ms/batch 159.20 | loss 1.06 | ppl 2.88 | bpc 1.527\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 24 | time: 49.54s | valid loss 1.02 | valid ppl 2.78 | valid bpc 1.477\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 25 | 200/ 261 batches | lr 0.00200 | ms/batch 160.36 | loss 1.06 | ppl 2.88 | bpc 1.524\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 25 | time: 49.49s | valid loss 1.02 | valid ppl 2.77 | valid bpc 1.472\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 26 | 200/ 261 batches | lr 0.00200 | ms/batch 161.67 | loss 1.05 | ppl 2.87 | bpc 1.519\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 26 | time: 49.16s | valid loss 1.02 | valid ppl 2.77 | valid bpc 1.470\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n",
"| epoch 27 | 200/ 261 batches | lr 0.00200 | ms/batch 159.73 | loss 1.05 | ppl 2.86 | bpc 1.514\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 27 | time: 49.44s | valid loss 1.02 | valid ppl 2.76 | valid bpc 1.467\n",
"-----------------------------------------------------------------------------------------\n",
"Saving model (new best validation)\n"
]
}
],
"source": [
"!python3 -u main.py --epochs 500 --nlayers 3 --emsize 200 --nhid 1000 --alpha 0 --beta 0 --dropoute 0 --dropouth 0.25 --dropouti 0.1 --dropout 0.1 --wdrop 0.5 --wdecay 1.2e-6 --bptt 150 --batch_size 128 --optimizer adam --lr 2e-3 --data data/pennchar --save PTBC.pt --when 300 400"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testing WeightDrop\n",
"=-=-=-=-=-=-=-=-=-=\n",
"Testing WeightDrop with Linear\n",
"Applying weight drop of 0.9 to weight\n",
"All items should be different\n",
"Run 1: [tensor(1.2150, device='cuda:0'), tensor(11.1652, device='cuda:0')]\n",
"Run 2: [tensor(8.8544, device='cuda:0'), tensor(-7.6523, device='cuda:0')]\n",
"---\n",
"Testing WeightDrop with LSTM\n",
"Applying weight drop of 0.9 to weight_hh_l0\n",
"/home/hkmac/.local/lib/python3.6/site-packages/torch/nn/modules/rnn.py:522: RuntimeWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters().\n",
" self.dropout, self.training, self.bidirectional, self.batch_first)\n",
"First timesteps should be equal, all others should differ\n",
"Run 1: [tensor(0.3290, device='cuda:0'), tensor(-0.0117, device='cuda:0')]\n",
"Run 2: [tensor(0.3290, device='cuda:0'), tensor(0.0147, device='cuda:0')]\n",
"---\n"
]
}
],
"source": [
"!python3 weight_drop.py"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
4 changes: 2 additions & 2 deletions weight_drop.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def _setweights(self):
mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
if raw_w.is_cuda: mask = mask.cuda()
mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
w = mask.expand_as(raw_w) * raw_w
w = torch.nn.Parameter(mask.expand_as(raw_w) * raw_w)
else:
w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
w = torch.nn.Parameter(torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training))

This comment has been minimized.

Copy link
@vcramach

vcramach Nov 19, 2019

This really is necessary. I was running the LSTM implementation on Colab and got the same error. Made this change and it worked like a charm!

setattr(self.module, name_w, w)

def forward(self, *args):
Expand Down

0 comments on commit 2ca28cf

Please sign in to comment.