diff options
author | Soumith Chintala <soumith@gmail.com> | 2015-01-08 22:58:52 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2015-01-08 22:58:52 +0300 |
commit | c596c339786cf0674ef31b43a6c243678bb000e2 (patch) | |
tree | 4cf6b285e8daa62bd1cb68ce684e0d5d4d6ec377 | |
parent | 27acf6315e30181936a309e9831e18baec1a3f28 (diff) | |
parent | 608ac7c2bbf52677a38c6f2d7e45b8817e4ba3af (diff) |
Merge pull request #139 from jjh42/prettyprint
Added more informative pretty-printing.
-rw-r--r-- | Linear.lua | 6 | ||||
-rw-r--r-- | Reshape.lua | 6 | ||||
-rw-r--r-- | SpatialZeroPadding.lua | 7 | ||||
-rw-r--r-- | test.lua | 16 |
4 files changed, 35 insertions, 0 deletions
@@ -99,3 +99,9 @@ end -- we do not need to accumulate parameters when sharing Linear.sharedAccUpdateGradParameters = Linear.accUpdateGradParameters + + +function Linear:__tostring__() + return torch.type(self) .. + string.format('(%d -> %d)', self.weight:size(2), self.weight:size(1)) +end diff --git a/Reshape.lua b/Reshape.lua index 0baab17..dcc787b 100644 --- a/Reshape.lua +++ b/Reshape.lua @@ -61,3 +61,9 @@ function Reshape:updateGradInput(input, gradOutput) self.gradInput:viewAs(gradOutput, input) return self.gradInput end + + +function Reshape:__tostring__() + return torch.type(self) .. '(' .. + table.concat(self.size:totable(), 'x') .. ')' +end diff --git a/SpatialZeroPadding.lua b/SpatialZeroPadding.lua index 8e3756d..72d1c63 100644 --- a/SpatialZeroPadding.lua +++ b/SpatialZeroPadding.lua @@ -95,3 +95,10 @@ function SpatialZeroPadding:updateGradInput(input, gradOutput) end return self.gradInput end + + +function SpatialZeroPadding:__tostring__() + return torch.type(self) .. + string.format('(l=%d,r=%d,t=%d,b=%d)', self.pad_l, self.pad_r, + self.pad_t, self.pad_b) +end @@ -22,6 +22,22 @@ local function equal(t1, t2, msg) end +--[[ Generate tests to exercise the tostring component of modules. ]] +local tostringTestModules = { + nnLinear = nn.Linear(1, 2), + nnReshape = nn.Reshape(10), + nnSpatialZeroPadding = nn.SpatialZeroPadding(1, 1, 1, 1)} +for test_name, component in pairs(tostringTestModules) do + nntest['tostring' .. test_name] = + function () + mytester:assert(tostring(component):find(torch.type(component) .. '(', + 1, true), + 'nn components should have a descriptive tostring' .. + ' beginning with the classname') + end +end + + function nntest.Add() local inj_vals = {math.random(3,5), 1} -- Also test the inj = 1 spatial case local ini = math.random(3,5) |