Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Nicholas Leonard <nleonard@twitter.com> 2017-05-09 22:26:36 +0300
committer: Nicholas Leonard <nleonard@twitter.com> 2017-05-10 17:38:51 +0300
commit: 17bbce696980189e27f0cd12d5f219c8cd8ffbc0 (patch)
tree: b02b37bc0daf5e65b43f745d15f654591825000e /DontCast.lua
parent: 3752f2426b55bc32cbd0ef112649d47dc674baa8 (diff)
Decorator modules
Diffstat (limited to 'DontCast.lua')
-rw-r--r-- DontCast.lua 124
1 files changed, 124 insertions, 0 deletions
diff --git a/DontCast.lua b/DontCast.lua
new file mode 100644
index 0000000..b89f543
--- /dev/null
+++ b/DontCast.lua
@@ -0,0 +1,124 @@
+-- Declare the class nn.DontCast with nn.Decorator as its parent class.
+-- `parent` is kept so the constructor below can call the parent constructor.
+-- What sets this decorator apart is its type() method: it converts only the
+-- decorator's own output/gradInput buffers (subject to castout/castin) and
+-- does not forward the conversion to the wrapped module.
+local DontCast, parent = torch.class("nn.DontCast", "nn.Decorator")
+
+-- utility functions
+
+-- Copy `src` into `dst`, producing tensors of type `type_str`.
+--   dst      : result of a previous call (table, tensor or nil); reused when compatible
+--   src      : a tensor, or a (possibly nested) table of tensors
+--   type_str : type name handed to torch.getmetatable to build new tensors
+-- Returns the filled destination. A `src` that is neither a table nor a
+-- tensor leaves `dst` untouched and returns it as-is.
+local function recursiveTypeCopy(dst, src, type_str)
+   if torch.type(src) == 'table' then
+      -- reuse the destination table if there is one, otherwise start fresh
+      local copy = (torch.type(dst) == 'table') and dst or {}
+      for key, value in pairs(src) do
+         copy[key] = recursiveTypeCopy(copy[key], value, type_str)
+      end
+      return copy
+   end
+
+   if not torch.isTensor(src) then
+      return dst
+   end
+
+   -- reuse the destination tensor only if it already has the requested type
+   local buffer = (torch.type(dst) == type_str) and dst
+   buffer = buffer or torch.getmetatable(type_str).new()
+   buffer:resize(src:size())
+   if src:nElement() > 0 then
+      buffer:copy(src)
+   end
+   return buffer
+end
+
+-- Look for the first tensor reachable inside `src` and report its type.
+-- Returns two values: a type string and a flag telling whether a tensor
+-- was found.
+--   * non-table `src`: torch.type(src) and torch.isTensor(src)
+--   * table `src`: entries are visited with pairs() and searched
+--     recursively; the first entry that yields a tensor wins. When none
+--     does, the values obtained for the last visited entry are returned
+--     (both nil for an empty table).
+local function tableTensorType(src)
+   if type(src) ~= 'table' then
+      return torch.type(src), torch.isTensor(src)
+   end
+
+   local type_str, found
+   for _, value in pairs(src) do
+      type_str, found = tableTensorType(value)
+      if found then
+         return type_str, true
+      end
+   end
+   return type_str, found
+end
+
+-- DontCast methods and constructor
+
+-- Constructor.
+--   module     : the module to decorate (passed on to the parent constructor)
+--   castin     : when set, inputs are converted to moduleType before reaching
+--                the module and gradInput is copied back into this
+--                decorator's own buffer
+--   castout    : same idea for output / gradOutput; falls back to castin
+--                when nil
+--   moduleType : tensor type name of the decorated module. When omitted and
+--                casting is requested, it is inferred from module.output,
+--                then from module:parameters(); an error is raised if
+--                neither contains a tensor.
+function DontCast:__init(module, castin, castout, moduleType)
+   parent.__init(self, module)
+   self.castin = castin
+   self.castout = (castout == nil) and castin or castout
+   self.moduleType = moduleType
+
+   local mustInfer = (self.castin or self.castout) and not self.moduleType
+   if mustInfer then
+      -- first choice: the type of the module's output buffer
+      local inferred, isTensor = tableTensorType(module.output)
+      if not isTensor then
+         -- second choice: the type of the module's parameters
+         inferred, isTensor = tableTensorType(module:parameters())
+      end
+      if not isTensor then
+         error"Cannot extrapolate moduleType. Provide constructor argument 4"
+      end
+      self.moduleType = inferred
+   end
+end
+
+-- Forward pass through the decorated module.
+-- With castin set and an input whose tensor type differs from moduleType,
+-- the input is first copied into self._input as moduleType. With castout
+-- set, the module's output is copied into this decorator's own self.output,
+-- keeping that buffer's current tensor type; otherwise the module's output
+-- is exposed directly. Returns self.output.
+function DontCast:updateOutput(input)
+   local inputNeedsCast = self.castin and tableTensorType(input) ~= self.moduleType
+   if inputNeedsCast then
+      self._input = recursiveTypeCopy(self._input, input, self.moduleType)
+      input = self._input
+   end
+
+   local moduleOutput = self.modules[1]:updateOutput(input)
+
+   if not self.castout then
+      self.output = moduleOutput
+      return self.output
+   end
+
+   -- keep whatever tensor type our own output buffer currently has
+   local outputType = tableTensorType(self.output)
+   self.output = recursiveTypeCopy(self.output, moduleOutput, outputType)
+   return self.output
+end
+
+-- Backward pass (gradient w.r.t. input) through the decorated module.
+-- With castin set and a mismatching input type, the converted input stored
+-- in self._input by updateOutput is used instead. With castout set and a
+-- mismatching gradOutput type, gradOutput is copied into self._gradOutput as
+-- moduleType. With castin set, the module's gradInput is copied into this
+-- decorator's own self.gradInput, keeping that buffer's current tensor type;
+-- otherwise the module's gradInput is exposed directly.
+-- Returns self.gradInput.
+function DontCast:updateGradInput(input, gradOutput)
+   if self.castin and tableTensorType(input) ~= self.moduleType then
+      input = self._input
+   end
+
+   local gradNeedsCast = self.castout and tableTensorType(gradOutput) ~= self.moduleType
+   if gradNeedsCast then
+      self._gradOutput = recursiveTypeCopy(self._gradOutput, gradOutput, self.moduleType)
+      gradOutput = self._gradOutput
+   end
+
+   local moduleGradInput = self.modules[1]:updateGradInput(input, gradOutput)
+
+   if not self.castin then
+      self.gradInput = moduleGradInput
+      return self.gradInput
+   end
+
+   -- keep whatever tensor type our own gradInput buffer currently has
+   local gradInputType = tableTensorType(self.gradInput)
+   self.gradInput = recursiveTypeCopy(self.gradInput, moduleGradInput, gradInputType)
+   return self.gradInput
+end
+
+-- Forward accGradParameters to the decorated module.
+-- When casting applies and the given input / gradOutput do not have
+-- moduleType, the copies stored in self._input / self._gradOutput are passed
+-- instead; nothing is converted here, so those copies must already exist
+-- from the earlier updateOutput / updateGradInput calls.
+function DontCast:accGradParameters(input, gradOutput, scale)
+   local swapInput = self.castin and tableTensorType(input) ~= self.moduleType
+   local swapGradOutput = self.castout and tableTensorType(gradOutput) ~= self.moduleType
+
+   if swapInput then
+      input = self._input
+   end
+   if swapGradOutput then
+      gradOutput = self._gradOutput
+   end
+
+   self.modules[1]:accGradParameters(input, gradOutput, scale)
+end
+
+-- Forward accUpdateGradParameters to the decorated module.
+-- When casting applies and the given input / gradOutput do not have
+-- moduleType, the copies stored in self._input / self._gradOutput are passed
+-- instead; nothing is converted here, so those copies must already exist
+-- from the earlier updateOutput / updateGradInput calls.
+function DontCast:accUpdateGradParameters(input, gradOutput, lr)
+   local swapInput = self.castin and tableTensorType(input) ~= self.moduleType
+   local swapGradOutput = self.castout and tableTensorType(gradOutput) ~= self.moduleType
+
+   if swapInput then
+      input = self._input
+   end
+   if swapGradOutput then
+      gradOutput = self._gradOutput
+   end
+
+   self.modules[1]:accUpdateGradParameters(input, gradOutput, lr)
+end
+
+-- dont cast (the essence thereof)
+-- Convert only this decorator's own buffers to `type`; the decorated module
+-- is not converted.
+--   type : tensor type name to convert to, or nil
+-- self.output is converted when castout is set and self.gradInput when
+-- castin is set, each only if its current tensor type differs from `type`.
+-- Returns self after a conversion request.
+-- A nil `type` is forwarded to the parent class instead of being used as a
+-- conversion target: with castin/castout set it would otherwise reach
+-- recursiveTypeCopy as a nil type name, from which no tensor can be built.
+-- NOTE(review): assumes the inherited type() treats a nil argument as a
+-- query and converts nothing -- confirm against nn.Module:type.
+function DontCast:type(type)
+   if not type then
+      return parent.type(self)
+   end
+   if self.castout and tableTensorType(self.output) ~= type then
+      self.output = recursiveTypeCopy(nil, self.output, type)
+   end
+   if self.castin and tableTensorType(self.gradInput) ~= type then
+      self.gradInput = recursiveTypeCopy(nil, self.gradInput, type)
+   end
+   return self
+end