-rw-r--r-- | ConcatTable.lua                     |  9 ---------
-rw-r--r-- | Criterion.lua                       |  4 +---
-rw-r--r-- | CrossEntropyCriterion.lua           |  7 -------
-rw-r--r-- | DotProduct.lua                      |  7 -------
-rw-r--r-- | L1HingeEmbeddingCriterion.lua       |  6 ------
-rw-r--r-- | MarginRankingCriterion.lua          |  6 ------
-rw-r--r-- | MixtureTable.lua                    | 19 -------------------
-rw-r--r-- | Module.lua                          | 13 +------------
-rw-r--r-- | PairwiseDistance.lua                | 16 ----------------
-rw-r--r-- | SpatialContrastiveNormalization.lua |  6 ------
-rw-r--r-- | SpatialDivisiveNormalization.lua    | 10 ----------
-rw-r--r-- | SpatialSubtractiveNormalization.lua |  8 --------
-rw-r--r-- | init.lua                            |  2 ++
-rw-r--r-- | utils.lua                           | 12 ++++++++++++
14 files changed, 16 insertions(+), 109 deletions(-)
diff --git a/ConcatTable.lua b/ConcatTable.lua
index 706ee6a..93d9ad5 100644
--- a/ConcatTable.lua
+++ b/ConcatTable.lua
@@ -86,15 +86,6 @@ function ConcatTable:zeroGradParameters()
    end
 end
 
-function ConcatTable:type(type)
-   parent.type(self, type)
-   if torch.type(self.gradInput) == 'table' then
-      for i, gradInput in ipairs(self.gradInput) do
-         self.gradInput[i] = gradInput:type(type)
-      end
-   end
-end
-
 function ConcatTable:__tostring__()
    local tab = '  '
    local line = '\n'
diff --git a/Criterion.lua b/Criterion.lua
index dec514c..0f6e41b 100644
--- a/Criterion.lua
+++ b/Criterion.lua
@@ -31,9 +31,7 @@ end
 function Criterion:type(type)
    -- find all tensors and convert them
    for key,param in pairs(self) do
-      if torch.typename(param) and torch.typename(param):find('torch%..+Tensor') then
-         self[key] = param:type(type)
-      end
+      self[key] = nn._utils.recursiveType(param, type)
    end
    return self
 end
diff --git a/CrossEntropyCriterion.lua b/CrossEntropyCriterion.lua
index 2b3c78c..d4d19e5 100644
--- a/CrossEntropyCriterion.lua
+++ b/CrossEntropyCriterion.lua
@@ -25,11 +25,4 @@ function CrossEntropyCriterion:updateGradInput(input, target)
    return self.gradInput
 end
 
-function CrossEntropyCriterion:type(name)
-   Criterion.type(self, name)
-   self.lsm:type(name)
-   self.nll:type(name)
-   return self
-end
-
 return nn.CrossEntropyCriterion
diff --git a/DotProduct.lua b/DotProduct.lua
index 7f7b524..bc8e854 100644
--- a/DotProduct.lua
+++ b/DotProduct.lua
@@ -27,10 +27,3 @@ function DotProduct:updateGradInput(input, gradOutput)
 
    return self.gradInput
 end
-
-function DotProduct:type(type)
-   for i, tensor in ipairs(self.gradInput) do
-      self.gradInput[i] = tensor:type(type)
-   end
-   return parent.type(self, type)
-end
diff --git a/L1HingeEmbeddingCriterion.lua b/L1HingeEmbeddingCriterion.lua
index 27f715b..6957278 100644
--- a/L1HingeEmbeddingCriterion.lua
+++ b/L1HingeEmbeddingCriterion.lua
@@ -39,9 +39,3 @@ function L1HingeEmbeddingCriterion:updateGradInput(input, y)
    self.gradInput[2]:zero():add(-1, self.gradInput[1])
    return self.gradInput
 end
-
-function L1HingeEmbeddingCriterion:type(type)
-   self.gradInput[1] = self.gradInput[1]:type(type)
-   self.gradInput[2] = self.gradInput[2]:type(type)
-   return parent.type(self, type)
-end
diff --git a/MarginRankingCriterion.lua b/MarginRankingCriterion.lua
index c0573ac..30c6855 100644
--- a/MarginRankingCriterion.lua
+++ b/MarginRankingCriterion.lua
@@ -69,9 +69,3 @@ function MarginRankingCriterion:updateGradInput(input, y)
    end
    return self.gradInput
 end
-
-function MarginRankingCriterion:type(type)
-   self.gradInput[1] = self.gradInput[1]:type(type)
-   self.gradInput[2] = self.gradInput[2]:type(type)
-   return parent.type(self, type)
-end
diff --git a/MixtureTable.lua b/MixtureTable.lua
index 6111a99..77a7d3e 100644
--- a/MixtureTable.lua
+++ b/MixtureTable.lua
@@ -146,22 +146,3 @@ function MixtureTable:updateGradInput(input, gradOutput)
 
    return self.gradInput
 end
-
-function MixtureTable:type(type)
-   self.output = self.output:type(type)
-   self.gradInput[1] = self.gradInput[1]:type(type)
-   self._gaterView = self._gaterView:type(type)
-   self._expert = self._expert:type(type)
-   self._expertView = self._expertView:type(type)
-   self._sum = self._sum:type(type)
-   self._gradInput = self._gradInput:type(type)
-   self._expert2 = self._expert2:type(type)
-   self._expertView2 = self._expertView2:type(type)
-   if torch.type(self.gradInput[2]) == 'table' then
-      for i,expertGradInput in ipairs(self.gradInput[2]) do
-         self.gradInput[2][i] = expertGradInput:type(type)
-      end
-   else
-      self.gradInput[2] = self._gradInput
-   end
-end
diff --git a/Module.lua b/Module.lua
--- a/Module.lua
+++ b/Module.lua
@@ -113,17 +113,6 @@ function Module:clone(...)
    return clone
 end
 
-local function recursiveType(param, type_str)
-   if torch.type(param) == 'table' then
-      for i = 1, #param do
-         param[i] = recursiveType(param[i], type_str)
-      end
-   elseif torch.isTensor(param) then
-      param = param:type(type_str)
-   end
-   return param
-end
-
 function Module:type(type)
    assert(type, 'Module: must provide a type to convert to')
    -- find all tensors and convert them
@@ -132,7 +121,7 @@ function Module:type(type)
       -- are table's of tensors. To be general we need to recursively
       -- cast fields that may be nested tables.
       if key ~= 'modules' then
-         self[key] = recursiveType(self[key], type)
+         self[key] = nn._utils.recursiveType(param, type)
       end
    end
    -- find submodules in classic containers 'modules'
diff --git a/PairwiseDistance.lua b/PairwiseDistance.lua
index affc2e5..79569c9 100644
--- a/PairwiseDistance.lua
+++ b/PairwiseDistance.lua
@@ -86,19 +86,3 @@ function PairwiseDistance:updateGradInput(input, gradOutput)
    self.gradInput[2]:zero():add(-1, self.gradInput[1])
    return self.gradInput
 end
-
--- save away Module:type(type) for later use.
-PairwiseDistance._parent_type = parent.type
-
--- Fix the bug where tmp = nn.PairwiseDistance:cuda() fails to convert table
--- contents. We could, and probably should, change Module.lua to loop over
--- and convert all the table elements in a module, but that might have
--- repercussions, so this is a safer solution.
-function PairwiseDistance:type(type)
-   self:_parent_type(type) -- Call the parent (Module) type function
-   -- Now convert the left over table elements
-   self.gradInput[1] = self.gradInput[1]:type(type)
-   self.gradInput[2] = self.gradInput[2]:type(type)
-   return self
-end
-
diff --git a/SpatialContrastiveNormalization.lua b/SpatialContrastiveNormalization.lua
index b5a6ce7..0ad251a 100644
--- a/SpatialContrastiveNormalization.lua
+++ b/SpatialContrastiveNormalization.lua
@@ -34,9 +34,3 @@ function SpatialContrastiveNormalization:updateGradInput(input, gradOutput)
    self.gradInput = self.normalizer:backward(input, gradOutput)
    return self.gradInput
 end
-
-function SpatialContrastiveNormalization:type(type)
-   parent.type(self,type)
-   self.normalizer:type(type)
-   return self
-end
diff --git a/SpatialDivisiveNormalization.lua b/SpatialDivisiveNormalization.lua
index 23a2c0b..92dfac7 100644
--- a/SpatialDivisiveNormalization.lua
+++ b/SpatialDivisiveNormalization.lua
@@ -113,13 +113,3 @@ function SpatialDivisiveNormalization:updateGradInput(input, gradOutput)
    -- done
    return self.gradInput
 end
-
-function SpatialDivisiveNormalization:type(type)
-   parent.type(self,type)
-   self.meanestimator:type(type)
-   self.stdestimator:type(type)
-   self.divider:type(type)
-   self.normalizer:type(type)
-   self.thresholder:type(type)
-   return self
-end
diff --git a/SpatialSubtractiveNormalization.lua b/SpatialSubtractiveNormalization.lua
index f2c2c31..51bf23a 100644
--- a/SpatialSubtractiveNormalization.lua
+++ b/SpatialSubtractiveNormalization.lua
@@ -90,11 +90,3 @@ function SpatialSubtractiveNormalization:updateGradInput(input, gradOutput)
    -- done
    return self.gradInput
 end
-
-function SpatialSubtractiveNormalization:type(type)
-   parent.type(self,type)
-   self.meanestimator:type(type)
-   self.divider:type(type)
-   self.subtractor:type(type)
-   return self
-end
diff --git a/init.lua b/init.lua
--- a/init.lua
+++ b/init.lua
@@ -1,6 +1,8 @@
 require('torch')
 require('libnn')
 
+include('utils.lua')
+
 include('ErrorMessages.lua')
 
 include('Module.lua')
diff --git a/utils.lua b/utils.lua
new file mode 100644
index 0000000..887b82d
--- /dev/null
+++ b/utils.lua
@@ -0,0 +1,12 @@
+nn._utils = {}
+
+function nn._utils.recursiveType(param, type_str)
+   if torch.type(param) == 'table' then
+      for k, v in pairs(param) do
+         param[k] = nn._utils.recursiveType(v, type_str)
+      end
+   elseif torch.isTensor(param) then
+      param = param:type(type_str)
+   end
+   return param
+end
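
Taken together, the diff consolidates every hand-written :type() override into the single helper in utils.lua. Below is a minimal usage sketch, not part of the commit itself: it assumes a torch/nn install that includes this change, and it relies on PairwiseDistance initializing gradInput as a two-tensor table, as the override removed above implies.

-- Sketch: the shared helper casts nested tables of tensors.
require 'nn'

-- Module:type() now recurses into table-valued state, so modules such as
-- PairwiseDistance need no custom :type() override to convert gradInput.
local pd = nn.PairwiseDistance(2)
pd:type('torch.FloatTensor')
print(torch.type(pd.gradInput[1]))  -- torch.FloatTensor

-- The helper can also be called directly on an arbitrary nested table.
local t = {torch.Tensor(2), nested = {torch.Tensor(3)}}
nn._utils.recursiveType(t, 'torch.FloatTensor')
print(torch.type(t.nested[1]))      -- torch.FloatTensor

One behavioral detail: the removed Module-local recursiveType walked only the array part of a table (for i = 1, #param), while nn._utils.recursiveType iterates with pairs, so string-keyed tensors inside nested tables are now converted as well.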