Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDominik Grewe <dominikg@google.com>2015-04-28 18:07:32 +0300
committerDominik Grewe <dominikg@google.com>2015-04-28 18:36:55 +0300
commitd0da63e83825d4631a7f299766a1fa968eb5ccd3 (patch)
treeedab79c4cf664ef0dac24fb851923059027e659e
parent485dd619695c47e49ab56ff518edad52b74475fc (diff)
Make type() truly recursive.
Recursively iterate over the whole table, converting each tensor to the given type. Removes need for many specialized type() functions.
-rw-r--r--ConcatTable.lua9
-rw-r--r--Criterion.lua4
-rw-r--r--CrossEntropyCriterion.lua7
-rw-r--r--DotProduct.lua7
-rw-r--r--L1HingeEmbeddingCriterion.lua6
-rw-r--r--MarginRankingCriterion.lua6
-rw-r--r--MixtureTable.lua19
-rw-r--r--Module.lua13
-rw-r--r--PairwiseDistance.lua16
-rw-r--r--SpatialContrastiveNormalization.lua6
-rw-r--r--SpatialDivisiveNormalization.lua10
-rw-r--r--SpatialSubtractiveNormalization.lua8
-rw-r--r--init.lua2
-rw-r--r--utils.lua12
14 files changed, 16 insertions, 109 deletions
diff --git a/ConcatTable.lua b/ConcatTable.lua
index 706ee6a..93d9ad5 100644
--- a/ConcatTable.lua
+++ b/ConcatTable.lua
@@ -86,15 +86,6 @@ function ConcatTable:zeroGradParameters()
end
end
-function ConcatTable:type(type)
- parent.type(self, type)
- if torch.type(self.gradInput) == 'table' then
- for i, gradInput in ipairs(self.gradInput) do
- self.gradInput[i] = gradInput:type(type)
- end
- end
-end
-
function ConcatTable:__tostring__()
local tab = ' '
local line = '\n'
diff --git a/Criterion.lua b/Criterion.lua
index dec514c..0f6e41b 100644
--- a/Criterion.lua
+++ b/Criterion.lua
@@ -31,9 +31,7 @@ end
function Criterion:type(type)
-- find all tensors and convert them
for key,param in pairs(self) do
- if torch.typename(param) and torch.typename(param):find('torch%..+Tensor') then
- self[key] = param:type(type)
- end
+ self[key] = nn._utils.recursiveType(param, type)
end
return self
end
diff --git a/CrossEntropyCriterion.lua b/CrossEntropyCriterion.lua
index 2b3c78c..d4d19e5 100644
--- a/CrossEntropyCriterion.lua
+++ b/CrossEntropyCriterion.lua
@@ -25,11 +25,4 @@ function CrossEntropyCriterion:updateGradInput(input, target)
return self.gradInput
end
-function CrossEntropyCriterion:type(name)
- Criterion.type(self, name)
- self.lsm:type(name)
- self.nll:type(name)
- return self
-end
-
return nn.CrossEntropyCriterion
diff --git a/DotProduct.lua b/DotProduct.lua
index 7f7b524..bc8e854 100644
--- a/DotProduct.lua
+++ b/DotProduct.lua
@@ -27,10 +27,3 @@ function DotProduct:updateGradInput(input, gradOutput)
return self.gradInput
end
-
-function DotProduct:type(type)
- for i, tensor in ipairs(self.gradInput) do
- self.gradInput[i] = tensor:type(type)
- end
- return parent.type(self, type)
-end
diff --git a/L1HingeEmbeddingCriterion.lua b/L1HingeEmbeddingCriterion.lua
index 27f715b..6957278 100644
--- a/L1HingeEmbeddingCriterion.lua
+++ b/L1HingeEmbeddingCriterion.lua
@@ -39,9 +39,3 @@ function L1HingeEmbeddingCriterion:updateGradInput(input, y)
self.gradInput[2]:zero():add(-1, self.gradInput[1])
return self.gradInput
end
-
-function L1HingeEmbeddingCriterion:type(type)
- self.gradInput[1] = self.gradInput[1]:type(type)
- self.gradInput[2] = self.gradInput[2]:type(type)
- return parent.type(self, type)
-end
diff --git a/MarginRankingCriterion.lua b/MarginRankingCriterion.lua
index c0573ac..30c6855 100644
--- a/MarginRankingCriterion.lua
+++ b/MarginRankingCriterion.lua
@@ -69,9 +69,3 @@ function MarginRankingCriterion:updateGradInput(input, y)
end
return self.gradInput
end
-
-function MarginRankingCriterion:type(type)
- self.gradInput[1] = self.gradInput[1]:type(type)
- self.gradInput[2] = self.gradInput[2]:type(type)
- return parent.type(self, type)
-end
diff --git a/MixtureTable.lua b/MixtureTable.lua
index 6111a99..77a7d3e 100644
--- a/MixtureTable.lua
+++ b/MixtureTable.lua
@@ -146,22 +146,3 @@ function MixtureTable:updateGradInput(input, gradOutput)
return self.gradInput
end
-
-function MixtureTable:type(type)
- self.output = self.output:type(type)
- self.gradInput[1] = self.gradInput[1]:type(type)
- self._gaterView = self._gaterView:type(type)
- self._expert = self._expert:type(type)
- self._expertView = self._expertView:type(type)
- self._sum = self._sum:type(type)
- self._gradInput = self._gradInput:type(type)
- self._expert2 = self._expert2:type(type)
- self._expertView2 = self._expertView2:type(type)
- if torch.type(self.gradInput[2]) == 'table' then
- for i,expertGradInput in ipairs(self.gradInput[2]) do
- self.gradInput[2][i] = expertGradInput:type(type)
- end
- else
- self.gradInput[2] = self._gradInput
- end
-end
diff --git a/Module.lua b/Module.lua
index d3f5a26..d6b16fb 100644
--- a/Module.lua
+++ b/Module.lua
@@ -113,17 +113,6 @@ function Module:clone(...)
return clone
end
-local function recursiveType(param, type_str)
- if torch.type(param) == 'table' then
- for i = 1, #param do
- param[i] = recursiveType(param[i], type_str)
- end
- elseif torch.isTensor(param) then
- param = param:type(type_str)
- end
- return param
-end
-
function Module:type(type)
assert(type, 'Module: must provide a type to convert to')
-- find all tensors and convert them
@@ -132,7 +121,7 @@ function Module:type(type)
-- are table's of tensors. To be general we need to recursively
-- cast fields that may be nested tables.
if key ~= 'modules' then
- self[key] = recursiveType(self[key], type)
+ self[key] = nn._utils.recursiveType(param, type)
end
end
-- find submodules in classic containers 'modules'
diff --git a/PairwiseDistance.lua b/PairwiseDistance.lua
index affc2e5..79569c9 100644
--- a/PairwiseDistance.lua
+++ b/PairwiseDistance.lua
@@ -86,19 +86,3 @@ function PairwiseDistance:updateGradInput(input, gradOutput)
self.gradInput[2]:zero():add(-1, self.gradInput[1])
return self.gradInput
end
-
--- save away Module:type(type) for later use.
-PairwiseDistance._parent_type = parent.type
-
--- Fix the bug where tmp = nn.PairwiseDistance:cuda() fails to convert table
--- contents. We could, and probably should, change Module.lua to loop over
--- and convert all the table elements in a module, but that might have
--- repercussions, so this is a safer solution.
-function PairwiseDistance:type(type)
- self:_parent_type(type) -- Call the parent (Module) type function
- -- Now convert the left over table elements
- self.gradInput[1] = self.gradInput[1]:type(type)
- self.gradInput[2] = self.gradInput[2]:type(type)
- return self
-end
-
diff --git a/SpatialContrastiveNormalization.lua b/SpatialContrastiveNormalization.lua
index b5a6ce7..0ad251a 100644
--- a/SpatialContrastiveNormalization.lua
+++ b/SpatialContrastiveNormalization.lua
@@ -34,9 +34,3 @@ function SpatialContrastiveNormalization:updateGradInput(input, gradOutput)
self.gradInput = self.normalizer:backward(input, gradOutput)
return self.gradInput
end
-
-function SpatialContrastiveNormalization:type(type)
- parent.type(self,type)
- self.normalizer:type(type)
- return self
-end
diff --git a/SpatialDivisiveNormalization.lua b/SpatialDivisiveNormalization.lua
index 23a2c0b..92dfac7 100644
--- a/SpatialDivisiveNormalization.lua
+++ b/SpatialDivisiveNormalization.lua
@@ -113,13 +113,3 @@ function SpatialDivisiveNormalization:updateGradInput(input, gradOutput)
-- done
return self.gradInput
end
-
-function SpatialDivisiveNormalization:type(type)
- parent.type(self,type)
- self.meanestimator:type(type)
- self.stdestimator:type(type)
- self.divider:type(type)
- self.normalizer:type(type)
- self.thresholder:type(type)
- return self
-end
diff --git a/SpatialSubtractiveNormalization.lua b/SpatialSubtractiveNormalization.lua
index f2c2c31..51bf23a 100644
--- a/SpatialSubtractiveNormalization.lua
+++ b/SpatialSubtractiveNormalization.lua
@@ -90,11 +90,3 @@ function SpatialSubtractiveNormalization:updateGradInput(input, gradOutput)
-- done
return self.gradInput
end
-
-function SpatialSubtractiveNormalization:type(type)
- parent.type(self,type)
- self.meanestimator:type(type)
- self.divider:type(type)
- self.subtractor:type(type)
- return self
-end
diff --git a/init.lua b/init.lua
index 42946af..c2b2996 100644
--- a/init.lua
+++ b/init.lua
@@ -1,6 +1,8 @@
require('torch')
require('libnn')
+include('utils.lua')
+
include('ErrorMessages.lua')
include('Module.lua')
diff --git a/utils.lua b/utils.lua
new file mode 100644
index 0000000..887b82d
--- /dev/null
+++ b/utils.lua
@@ -0,0 +1,12 @@
+nn._utils = {}
+
+function nn._utils.recursiveType(param, type_str)
+ if torch.type(param) == 'table' then
+ for k, v in pairs(param) do
+ param[k] = nn._utils.recursiveType(v, type_str)
+ end
+ elseif torch.isTensor(param) then
+ param = param:type(type_str)
+ end
+ return param
+end