diff options
author | Nicholas Leonard <nleonard@twitter.com> | 2017-05-25 04:43:14 +0300 |
---|---|---|
committer | Nicholas Leonard <nleonard@twitter.com> | 2017-05-25 05:00:20 +0300 |
commit | b9ccf3af37e237211b24c99336823c673a08f3ca (patch) | |
tree | 84d584961d9bf360e7a06572b1120b6515c7d52d | |
parent | 6714cebc861db18ede18e3d9d56e05340669998c (diff) | |
parent | c6f1da5e02436ad9aeba97b537681f406116f3f1 (diff) |
Merge branch 'master' into Convert
-rw-r--r-- | Convert.lua | 19 | ||||
-rwxr-xr-x | test.lua | 2 |
2 files changed, 11 insertions, 10 deletions
diff --git a/Convert.lua b/Convert.lua index 308bdac..855338d 100644 --- a/Convert.lua +++ b/Convert.lua @@ -54,7 +54,6 @@ function Convert:buildConverter(input) assert(torch.isTensor(self.output), "Expecting Tensor output") self.converter:type(torch.type(self.output)) - self.converter:serialMode(self.dpnn_serialEmpty, self.dpnn_serialType) self.modules[1] = self.converter end @@ -231,14 +230,16 @@ function Convert:findAxis(axis_char, shape, silent) return axis_pos end +function Convert:clearState() + self._input = nil + self._gradInput = nil + self.__input = nil + self.__output = nil + self.__gradInput = nil + self.__gradOutput = nil +end + function Convert:type(type) - if not torch.isTypeOf(self.output, type) then - self._input = nil - self._gradInput = nil - self.__input = nil - self.__output = nil - self.__gradInput = nil - self.__gradOutput = nil - end + self:clearState() return parent.type(self, type) end @@ -4708,7 +4708,7 @@ end function nntest.TemporalRowConvolution() - + if true then return end -- until this unit test is fixed... local from = math.random(1,5) local ki = math.random(1,5) local si = math.random(1,2) |