Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch.github.io.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornicholas-leonard <nick@nikopia.org>2016-07-22 22:13:27 +0300
committernicholas-leonard <nick@nikopia.org>2016-07-22 22:13:27 +0300
commitd20bca37c2e99dc7fc4f8e906324483ebc050c8b (patch)
tree90172f7b408ecd698215386dab27f741229e59f3
parentd68e87389470e2c3e1a739169e95d9e62544d617 (diff)
added figures
-rw-r--r--blog/_posts/2016-05-11-nce.md73
-rw-r--r--blog/_posts/images/LM-Linear.pngbin0 -> 10672 bytes
-rw-r--r--blog/_posts/images/LM-NCE.pngbin0 -> 5671 bytes
-rw-r--r--blog/_posts/images/LM-params.pngbin0 -> 8196 bytes
4 files changed, 53 insertions, 20 deletions
diff --git a/blog/_posts/2016-05-11-nce.md b/blog/_posts/2016-05-11-nce.md
index 8492d70..12c94fc 100644
--- a/blog/_posts/2016-05-11-nce.md
+++ b/blog/_posts/2016-05-11-nce.md
@@ -9,7 +9,7 @@ picture: https://raw.githubusercontent.com/torch/torch.github.io/master/blog/_po
<!---# Language modeling a billion words -->
- * [Word versus character level language models](#nce.char)
+ * [Word versus character language models](#nce.char)
* [Recurrent neural network language models](#nce.rnnlm)
* [Loading the Google billion words dataset](#nce.gbw)
* [Building a multi-layer LSTM](#nce.lstm)
@@ -26,7 +26,7 @@ The enormity of the dataset caused us to contribute some novel open-source Torch
We also provide scripts so that you can train and evaluate your own language models.
<a name='nce.char'></a>
-## Word versus character level language models
+## Word versus character language models
In recent months you may have noticed increased interest in generative character-level
RNNLMs like [char-rnn](https://github.com/karpathy/char-rnn)
@@ -62,15 +62,16 @@ Progress isn't made by early risers. It's made by lazy men trying to find easier
After tokenization, the word-level model might view this sequence as containing 22 tokens.
On the other hand, the char-level will view this sequence as containing 102 tokens.
-This longer sequence makes the task of the char-level model harder than the word-level model, as it
-must take into account dependencies between more tokens over more time-steps.[[8]](#nce.ref)
+This longer sequence makes the task of the character model harder than the word model, as it
+must take into account dependencies between more tokens over more time-steps.
Another issue with character language models is that they need to learn spelling in
addition to syntax, semantics, etc.
+In any case, word language models will typically have lower error than character models.[[8]](#nce.ref)
-The main advantage of character over word level language models is that they
+The main advantage of character over word language models is that they
have a really small vocabulary. For example, the GBW dataset will contain approximately 800 characters
-compared to 800,000 words (after pruning low-frequency tokens). In practice this means that char-level models will
-require less memory and have faster inference than their word-level counterparts.
+compared to 800,000 words (after pruning low-frequency tokens). In practice this means that character models will
+require less memory and have faster inference than their word counterparts.
Another advantage is that they do not require tokenization as a preprocessing step.
<a name='nce.rnnlm'></a>
@@ -125,6 +126,22 @@ The main advantage is that LSTMs can learn dependencies between words seperated
It isn't as prone to the problems of vanishing gradients as the different gates can preserve the gradients during back-propagation.
To create a LM, the word embeddings (`W[x->h]x[t]` in eq.1) would be fed to the LSTM and the resulting hidden state would be fed to eq. 2.
+The error of language model is traditionally measured using perplexity.
+Perplexity is a measure of how surprised the model is to see a sequence of text.
+If you feed it in a sequence of words, and for each successive word the model is able to
+predict with high likelihood what word comes next, it will have low perplexity.
+If the next word in the sequence `s` of length `T` is indexed by `s[t]` and the model-inferred likelihood is `y[t]` such that
+the likelihood of that word is `y[t][s[t]]`, then the perplexity of that sequence of words is:
+
+```
+ log(y[1][s[1]) + log(y[2][s[2]) + ... + log(y[T][s[T])
+PPL(s,y) = exp( -------------------------------------------------------- )
+ -T
+```
+
+The lower the perplexity, the better.
+
+
<a name='nce.gbw'></a>
## Loading the Google billion words dataset
@@ -303,7 +320,7 @@ lm:add(lookup) -- input is seqlen x batchsize
A sub-class of `LookupTable`, we use the [LookupTableMaskZero](https://github.com/Element-Research/rnn#rnn.LookupTableMaskZero)
to learn word embeddings. The main difference is that it supports zero-indexes, which are forwarded as zero-tensors.
-Then we have the actual multi-layer LSTM implementation, which uses the fast [SeqLSTM](https://github.com/Element-Research/rnn#rnn.SeqLSTM) module:
+Then we have the actual multi-layer LSTM implementation, which uses the [SeqLSTM](https://github.com/Element-Research/rnn#rnn.SeqLSTM) module:
```lua
local inputsize = opt.inputsize
@@ -318,7 +335,7 @@ for i,hiddensize in ipairs(opt.hiddensize) do
end
```
-The `SeqLSTM` implemention is very fast and it benchmarked by the [rnn-benchmarks](https://github.com/glample/rnn-benchmarks#lstm) repository.
+As demonstrated in the [rnn-benchmarks](https://github.com/glample/rnn-benchmarks#lstm) repository, the `SeqLSTM` implemention is very fast.
Next we split the output of the SeqLSTM (which is a `seqlen x batchsize x outputsize` Tensor) into a table containing a `batchsize x outputsize` tensor for
each time-step:
@@ -340,17 +357,23 @@ outputlayer = nn.Sequential()
However, when training with large vocabularies, like the 793471 words that makes up the GBW dataset ,
the output layer quickly becomes a bottleneck.
-For example, if you are training your model with a `batchsize = 32` (number of sequences per batch) and a `seqlen = 100`
+For example, if you are training your model with a `batchsize = 128` (number of sequences per batch) and a `seqlen = 50`
(size of sequence to backpropagate through time),
-the output of that layer will have shape `seqlen x batchsize x vocabsize`, or `32 x 100 x 793471`.
-For a `FloatTensor` or `CudaTensor`, that single tensor will take up 10.156GB of memory!
-The number can be double for gradients, and doubled again as both Linear and SoftMax store a copy for the output.
-If somehow you can find a way to put >40GB on a GPU (or distribute it over many), you then run in the problem of
+the output of that layer will have shape `seqlen x batchsize x vocabsize`, or `128 x 50 x 793471`.
+For a `FloatTensor` or `CudaTensor`, that single tensor will take up 20GB of memory!
+The number can be double for `gradInput` (i.e. gradients with respect to input),
+and double again as both `Linear` and `SoftMax` store a copy for the `output`.
+
+![Scale of output layer buffers with Linear](images/LM-Linear.png)
+
+Excluding parameters and their gradients, the above figure outlines the approximate memory consumption of a 4-layer LSTM with 2048 units with a `seqlen=50`.
+Even if somehow you can find a way to put 80GB on a GPU (or distribute it over many), you still run into the problem of
forward/backward propagating through that `outputlayer` in a reasonable time-frame.
+<a name='nce.nce'></a>
### The solution: noise contrastive estimation
-The output layer of the LM uses Noise Contrastive Estimation (NCE) to speed up training and reduce memory consumption:
+The output layer of the LM uses NCE to speed up training and reduce memory consumption:
```lua
local unigram = trainset.wordfreq:float()
@@ -372,6 +395,11 @@ The [NCEModule](https://github.com/Element-Research/dpnn#nn.NCEModule) is a more
nn.Sequential():add(nn.Linear(inputsize, #trainset.ivocab)):add(nn.LogSoftMax())
```
+For evaluating perplexity, the model still implements `Linear` + `SoftMax`.
+NCE is useful for reducing the memory consumption during training (compare to the figure above):
+
+![Scale of output layer buffers with NCE](images/LM-NCE.png)
+
Along with the [NCECriterion](https://github.com/Element-Research/dpnn#nn.NCECriterion),
the `NCEModule` implements the algorithm is described in [[1]](#nce.ref).
I won't go into the details of the algorithm as it involves a lot of math which is more appropriately detailed in the reference papers.
@@ -509,10 +537,15 @@ The `--temperature` flag can be reduced to make the sampling more deterministic.
As can be observed in the previous section, training a 2-layer LSTM with only 250 hidden units will not yield the best
generated samples. The model needs much more capacity than what can fit on a 12GB GPU.
-The [multigpu-nce-rnnlm.lua](https://github.com/Element-Research/rnn/blob/master/examples/multigpu-nce-rnnlm.lua) script can be used
-to train a model on four GPUs.
+For parameters and their gradients, a 4x2048 LSTM model requires the following:
+
+![LM parameter memory consumption](images/LM-params.png)
+
+This doesn't include all the intermediate buffers required for the different modules (outlined in [NCE section](#nce.nce)).
+The solution was of course to distribution the model over more GPUs.
+The [multigpu-nce-rnnlm.lua](https://github.com/Element-Research/rnn/blob/master/examples/multigpu-nce-rnnlm.lua) script is thus provided to train a language model on four GPUs.
-It uses the [GPU](https://github.com/torch/nn/blob/master/doc/simple.md#nn.GPU) to decorate modules such that
+It uses the [GPU](https://github.com/torch/nn/blob/master/doc/simple.md#nn.GPU) (which we contributed it to the [nn](https://github.com/torch/nn)) to decorate modules such that
all their operations and memory are hosted on a specified device.
The `GPU` module won't parallelize kernel execution over different GPU-devices.
But it does allow us to distribute large models over devices.
@@ -536,7 +569,7 @@ end
Basically, the embedding space is split into two tables.
For a 2048 unit embedding space, half, i.e. 1024 units, are located on each of two devices.
-We use `Concat` to concatenate them back together after a `forward`.
+We use [Concat](https://github.com/torch/nn/blob/master/doc/containers.md#nn.Concat) to concatenate them back together after a `forward`.
For the hidden layers (i.e. `SeqLSTM`), we just distribute them on the devices used by the input layer.
The hidden layers use up little memory (approximately 1GB each) so they aren't the problem.
@@ -919,7 +952,7 @@ For those not familiar with this game, terms like
[exotics](http://destiny.wikia.com/wiki/Exotic) and [Crota raid](http://destiny.wikia.com/wiki/Crota%27s_End)
may seem odd. But these are all part of the game vocabulary.
-The particular model (a 4x1500 LSTM with dropout) only backpropagates through 50 time-steps.
+The particular model (a 4x1572 LSTM with dropout) only backpropagates through 50 time-steps.
What I would like to see is for the `COMMENTS` to actually answer the question posed by the `TITLE` and `SELFTEXT`.
This is a very difficult semantic problem which I hope the Reddit dataset will help solve.
More to follow in my next Torch blog post.
diff --git a/blog/_posts/images/LM-Linear.png b/blog/_posts/images/LM-Linear.png
new file mode 100644
index 0000000..46c92d9
--- /dev/null
+++ b/blog/_posts/images/LM-Linear.png
Binary files differ
diff --git a/blog/_posts/images/LM-NCE.png b/blog/_posts/images/LM-NCE.png
new file mode 100644
index 0000000..39b6fad
--- /dev/null
+++ b/blog/_posts/images/LM-NCE.png
Binary files differ
diff --git a/blog/_posts/images/LM-params.png b/blog/_posts/images/LM-params.png
new file mode 100644
index 0000000..1ae0e05
--- /dev/null
+++ b/blog/_posts/images/LM-params.png
Binary files differ