Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Leonard <nleonard@twitter.com>2017-05-26 21:01:30 +0300
committerNicholas Leonard <nleonard@twitter.com>2017-05-26 21:23:34 +0300
commitf95fea7875aa866415d5eae25b65af172082cd84 (patch)
tree5445f374c4425b78b2394924768c6c0669212020
parent42fd9c255ff2cc8783d0c56797831d842fee54f5 (diff)
nn.Kmeans
-rw-r--r--Kmeans.lua215
-rwxr-xr-xdoc/simple.md49
-rwxr-xr-xinit.lua1
-rwxr-xr-xtest.lua75
4 files changed, 340 insertions, 0 deletions
diff --git a/Kmeans.lua b/Kmeans.lua
new file mode 100644
index 0000000..56066b6
--- /dev/null
+++ b/Kmeans.lua
@@ -0,0 +1,215 @@
+-- Online (Hard) Kmeans layer.
+local Kmeans, parent = torch.class('nn.Kmeans', 'nn.Module')
+
+function Kmeans:__init(k, dim, scale)
+ parent.__init(self)
+ self.k = k
+ self.dim = dim
+
+ -- scale for online kmean update
+ self.scale = scale
+
+ assert(k > 0, "Clusters cannot be 0 or negative.")
+ assert(dim > 0, "Dimensionality cannot be 0 or negative.")
+
+ -- Kmeans centers -> self.weight
+ self.weight = torch.Tensor(self.k, self.dim)
+
+ self.gradWeight = torch.Tensor(self.weight:size())
+ self.loss = 0 -- within cluster error of the last forward
+
+ self.clusterSampleCount = torch.Tensor(self.k)
+
+ self:reset()
+end
+
+-- Reset
+function Kmeans:reset(stdev)
+ stdev = stdev or 1
+ self.weight:uniform(-stdev, stdev)
+end
+
+-- Initialize Kmeans weight with random samples from input.
+function Kmeans:initRandom(input)
+ local inputDim = input:nDimension()
+ assert(inputDim == 2, "Incorrect input dimensionality. Expecting 2D.")
+
+ local noOfSamples = input:size(1)
+ local dim = input:size(2)
+ assert(dim == self.dim, "Dimensionality of input and weight don't match.")
+ assert(noOfSamples >= self.k, "Need atleast k samples for initialization.")
+
+ local indices = torch.zeros(self.k)
+ indices:random(1, noOfSamples)
+
+ for i=1, self.k do
+ self.weight[i]:copy(input[indices[i]])
+ end
+end
+
+-- Initialize using Kmeans++
+function Kmeans:initKmeansPlus(input, p)
+ self.p = p or self.p or 0.95
+ assert(self.p>=0 and self.p<=1, "P value should be between 0-1.")
+
+ local inputDim = input:nDimension()
+ assert(inputDim == 2, "Incorrect input dimensionality. Expecting 2D.")
+ local noOfSamples = input:size(1)
+
+ local pcount = math.ceil((1-self.p)*noOfSamples)
+ if pcount <= 0 then pcount = 1 end
+
+ local initializedK = 1
+ self.weight[initializedK]:copy(input[torch.random(noOfSamples)])
+ initializedK = initializedK + 1
+
+ local clusters = self.weight.new()
+ local clusterDistances = self.weight.new()
+ local temp = self.weight.new()
+ local expandedSample = self.weight.new()
+ local distances = self.weight.new()
+ distances:resize(noOfSamples):fill(math.huge)
+ local maxScores = self.weight.new()
+ local maxIndx = self.weight.new()
+
+ for k=initializedK, self.k do
+ clusters = self.weight[{{initializedK-1, initializedK-1}}]
+ for i=1, noOfSamples do
+ temp:expand(input[{{i}}], 1, self.dim)
+ expandedSample:resize(temp:size()):copy(temp)
+
+ -- Squared Euclidean distance
+ expandedSample:add(-1, clusters)
+ clusterDistances:norm(expandedSample, 2, 2)
+ clusterDistances:pow(2)
+ distances[i] = math.min(clusterDistances:min(), distances[i])
+ end
+ maxScores, maxIndx = distances:sort(true)
+ local tempIndx = torch.random(pcount)
+ local indx = maxIndx[tempIndx]
+ self.weight[initializedK]:copy(input[indx])
+ initializedK = initializedK + 1
+ end
+end
+
+local function isCudaTensor(tensor)
+ local typename = torch.typename(tensor)
+ if typename and typename:find('torch.Cuda*Tensor') then
+ return true
+ end
+ return false
+end
+
+-- Kmeans updateOutput (forward)
+function Kmeans:updateOutput(input)
+ local inputDim = input:nDimension()
+ assert(inputDim == 2, "Incorrect input dimensionality. Expecting 2D.")
+
+ local batchSize = input:size(1)
+ local dim = input:size(2)
+ assert(dim == self.dim, "Dimensionality of input and weight don't match.")
+
+ assert(input:isContiguous(), "Input is not contiguous.")
+
+ -- a sample copied k times to compute distance between sample and weight
+ self._expandedSamples = self._expandedSamples or self.weight.new()
+
+ -- distance between a sample and weight
+ self._clusterDistances = self._clusterDistances or self.weight.new()
+
+ self._temp = self._temp or input.new()
+ self._tempExpanded = self._tempExpanded or input.new()
+
+ -- Expanding inputs
+ self._temp:view(input, 1, batchSize, self.dim)
+ self._tempExpanded:expand(self._temp, self.k, batchSize, self.dim)
+ self._expandedSamples:resize(self.k, batchSize, self.dim)
+ :copy(self._tempExpanded)
+
+ -- Expanding weights
+ self._tempWeight = self._tempWeight or self.weight.new()
+ self._tempWeightExp = self._tempWeightExp or self.weight.new()
+ self._expandedWeight = self._expanedWeight or self.weight.new()
+ self._tempWeight:view(self.weight, self.k, 1, self.dim)
+ self._tempWeightExp:expand(self._tempWeight, self._expandedSamples:size())
+ self._expandedWeight:resize(self.k, batchSize, self.dim)
+ :copy(self._tempWeightExp)
+
+ -- x-c
+ self._expandedSamples:add(-1, self._expandedWeight)
+ -- Squared Euclidean distance
+ self._clusterDistances:norm(self._expandedSamples, 2, 3)
+ self._clusterDistances:pow(2)
+ self._clusterDistances:resize(self.k, batchSize)
+
+ self._minScore = self._minScore or self.weight.new()
+ self._minIndx = self._minIndx or (isCudaTensor(input) and torch.CudaLongTensor() or torch.LongTensor())
+ self._minScore:min(self._minIndx, self._clusterDistances, 1)
+ self._minIndx:resize(batchSize)
+
+ self.output:resize(batchSize):copy(self._minIndx)
+ self.loss = self._minScore:sum()
+
+ return self.output
+end
+
+-- Kmeans has its own criterion hence gradInput are zeros
+function Kmeans:updateGradInput(input, gradOuput)
+ self.gradInput:resize(input:size()):zero()
+
+ return self.gradInput
+end
+
+-- We define kmeans update rule as c -> c + scale * 1/n * sum_i (x-c).
+-- n is no. of x's belonging to c.
+-- With this update rule and gradient descent will be negative the gradWeights.
+function Kmeans:accGradParameters(input, gradOutput, scale)
+ local scale = self.scale or scale or 1
+ assert(scale > 0 , " Scale has to be positive.")
+
+ -- Update cluster sample count
+ local batchSize = input:size(1)
+ self._cscAdder = self._cscAdder or self.weight.new()
+ self._cscAdder:resize(batchSize):fill(1)
+ self.clusterSampleCount:zero()
+ self.clusterSampleCount:indexAdd(1, self._minIndx, self._cscAdder)
+
+ -- scale * (x[k]-c[k]) where k is nearest cluster to x
+ self._gradWeight = self._gradWeight or self.gradWeight.new()
+ self._gradWeight:index(self.weight, 1, self._minIndx)
+ self._gradWeight:mul(-1)
+ self._gradWeight:add(input)
+ self._gradWeight:mul(-scale)
+
+ self._gradWeight2 = self._gradWeight2 or self.gradWeight.new()
+ self._gradWeight2:resizeAs(self.gradWeight):zero()
+ self._gradWeight2:indexAdd(1, self._minIndx, self._gradWeight)
+
+ -- scale/n * sum_i (x-c)
+ self._ccounts = self._ccounts or self.clusterSampleCount.new()
+ self._ccounts:resize(self.k):copy(self.clusterSampleCount)
+ self._ccounts:add(0.0000001) -- prevent division by zero errors
+
+ self._gradWeight2:cdiv(self._ccounts:view(self.k,1):expandAs(self.gradWeight))
+
+ self.gradWeight:add(self._gradWeight2)
+end
+
+function Kmeans:clearState()
+ -- prevent premature memory allocations
+ self._expandedSamples = nil
+ self._clusterDistances = nil
+ self._temp = nil
+ self._tempExpanded = nil
+ self._tempWeight = nil
+ self._tempWeightExp = nil
+ self._expandedWeight = nil
+ self._minScore = nil
+ self._minIndx = nil
+ self._cscAdder = nil
+end
+
+function Kmeans:type(type, tensorCache)
+ self:clearState()
+ return parent.type(self, type, tensorCache)
+end
diff --git a/doc/simple.md b/doc/simple.md
index 849d9b5..0fa467b 100755
--- a/doc/simple.md
+++ b/doc/simple.md
@@ -16,6 +16,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
* [Euclidean](#nn.Euclidean) : the euclidean distance of the input to `k` mean centers ;
* [WeightedEuclidean](#nn.WeightedEuclidean) : similar to [Euclidean](#nn.Euclidean), but additionally learns a diagonal covariance matrix ;
* [Cosine](#nn.Cosine) : the cosine similarity of the input to `k` mean centers ;
+ * [Kmeans](#nn.Kmeans) : [Kmeans](https://en.wikipedia.org/wiki/K-means_clustering) clustering layer;
* Modules that adapt basic Tensor methods :
* [Copy](#nn.Copy) : a [copy](https://github.com/torch/torch7/blob/master/doc/tensor.md#torch.Tensor.copy) of the input with [type](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-or-string-typetype) casting ;
* [Narrow](#nn.Narrow) : a [narrow](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-narrowdim-index-size) operation over a given dimension ;
@@ -682,6 +683,54 @@ Outputs the [cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity)
The distance `y_j` between center `j` and input `x` is formulated as `y_j = (x ยท w_j) / ( || w_j || * || x || )`.
+<a name='nn.Kmeans'></a>
+## Kmeans ##
+
+```lua
+km = nn.Kmeans(k, dim)
+```
+
+`k` is the number of centroids and `dim` is the dimensionality of samples.
+The `forward` pass computes distances with respect to centroids and returns index of closest centroid.
+Centroids can be updated using gradient descent.
+Centroids can be initialized randomly or by using [kmeans++](https://en.wikipedia.org/wiki/K-means%2B%2B) algoirthm:
+
+```lua
+km:initRandom(samples) -- Randomly initialize centroids from input samples.
+km:initKmeansPlus(samples) -- Use Kmeans++ to initialize centroids.
+```
+
+Example showing how to use Kmeans module to do standard Kmeans clustering.
+
+```lua
+attempts = 10
+iter = 100 -- Number of iterations
+bestKm = nil
+bestLoss = math.huge
+learningRate = 1
+for j=1, attempts do
+ local km = nn.Kmeans(k, dim)
+ km:initKmeansPlus(samples)
+ for i=1, iter do
+ km:zeroGradParameters()
+ km:forward(samples) -- sets km.loss
+ km:backward(samples, gradOutput) -- gradOutput is ignored
+
+ -- Gradient Descent weight/centroids update
+ km:updateParameters(learningRate)
+ end
+
+ if km.loss < bestLoss then
+ bestLoss = km.loss
+ bestKm = km:clone()
+ end
+end
+```
+`nn.Kmeans()` module maintains loss only for the latest forward. If you want to maintain loss over the whole dataset then you who would need do it my adding the module loss for every forward.
+
+You can also use `nn.Kmeans()` as an auxillary layer in your network.
+A call to `forward` will generate an `output` containing the index of the nearest cluster for each sample in the batch.
+The `gradInput` generated by `updateGradInput` will be zero.
<a name="nn.Identity"></a>
## Identity ##
diff --git a/init.lua b/init.lua
index 97485f0..7e7deb3 100755
--- a/init.lua
+++ b/init.lua
@@ -80,6 +80,7 @@ require('nn.CosineDistance')
require('nn.DotProduct')
require('nn.Normalize')
require('nn.Cosine')
+require('nn.Kmeans')
require('nn.Exp')
require('nn.Log')
diff --git a/test.lua b/test.lua
index 44390ae..dbac512 100755
--- a/test.lua
+++ b/test.lua
@@ -8668,6 +8668,81 @@ function nntest.CAddTensorTable()
mytester:assertTensorEq(output[1]+output[2]+output[3], gradInput[1], 0.000001, "CAddTensorTable gradInput1")
end
+-- Unit Test Kmeans layer
+function nntest.Kmeans()
+ local k = 3
+ local dim = 5
+ local batchSize = 200
+ local input = torch.Tensor(batchSize, dim)
+ for i=1, batchSize do
+ input[i]:fill(torch.random(1, k))
+ end
+
+ local verbose = false
+
+ local attempts = 10
+ local iter = 100
+ local bestLoss = 100000000
+ local bestKm = nil
+ local tempLoss = 0
+ local learningRate = 1
+
+ local initTypes = {'random', 'kmeans++'}
+ local useCudas = {false}
+ if pcall(function() require 'cunn' end) then
+ useCudas[2] = true
+ end
+ for _, initType in pairs(initTypes) do
+ for _, useCuda in pairs(useCudas) do
+
+ if useCuda then
+ input = input:cuda()
+ else
+ input = input:double()
+ end
+
+ local timer = torch.Timer()
+ for j=1, attempts do
+ local km = nn.Kmeans(k, dim)
+ if useCuda then km:cuda() end
+
+ if initType == 'kmeans++' then
+ km:initKmeansPlus(input)
+ else
+ km:initRandom(input)
+ end
+
+ for i=1, iter do
+ km:zeroGradParameters()
+
+ km:forward(input)
+ km:backward(input, gradOutput)
+
+ -- Gradient descent
+ km.weight:add(-learningRate, km.gradWeight)
+ tempLoss = km.loss
+ end
+ if verbose then print("Attempt Loss " .. j ..": " .. tempLoss) end
+ if tempLoss < bestLoss then
+ bestLoss = tempLoss
+ end
+ if (initType == 'kmeans++' and bestLoss < 0.00001) or (initType == 'random' and bestLoss < 500) then
+ break
+ end
+ end
+ if verbose then
+ print("InitType: " .. initType .. " useCuda: " .. tostring(useCuda))
+ print("Best Loss: " .. bestLoss)
+ print("Total time: " .. timer:time().real)
+ end
+ if initType == 'kmeans++' then
+ mytester:assert(bestLoss < 0.00001, "Kmeans++ error ("..(useCuda and 'cuda' or 'double')..")")
+ else
+ mytester:assert(bestLoss < 500, "Kmeans error ("..(useCuda and 'cuda' or 'double')..")")
+ end
+ end
+ end
+end
mytester:add(nntest)