Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/nn.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSoumith Chintala <soumith@gmail.com>2015-01-08 22:58:52 +0300
committerSoumith Chintala <soumith@gmail.com>2015-01-08 22:58:52 +0300
commitc596c339786cf0674ef31b43a6c243678bb000e2 (patch)
tree4cf6b285e8daa62bd1cb68ce684e0d5d4d6ec377
parent27acf6315e30181936a309e9831e18baec1a3f28 (diff)
parent608ac7c2bbf52677a38c6f2d7e45b8817e4ba3af (diff)
Merge pull request #139 from jjh42/prettyprint
Added more informative pretty-printing.
-rw-r--r--Linear.lua6
-rw-r--r--Reshape.lua6
-rw-r--r--SpatialZeroPadding.lua7
-rw-r--r--test.lua16
4 files changed, 35 insertions, 0 deletions
diff --git a/Linear.lua b/Linear.lua
index 5e05c2f..246d86b 100644
--- a/Linear.lua
+++ b/Linear.lua
@@ -99,3 +99,9 @@ end
-- we do not need to accumulate parameters when sharing
Linear.sharedAccUpdateGradParameters = Linear.accUpdateGradParameters
+
+
+function Linear:__tostring__()
+ return torch.type(self) ..
+ string.format('(%d -> %d)', self.weight:size(2), self.weight:size(1))
+end
diff --git a/Reshape.lua b/Reshape.lua
index 0baab17..dcc787b 100644
--- a/Reshape.lua
+++ b/Reshape.lua
@@ -61,3 +61,9 @@ function Reshape:updateGradInput(input, gradOutput)
self.gradInput:viewAs(gradOutput, input)
return self.gradInput
end
+
+
+function Reshape:__tostring__()
+ return torch.type(self) .. '(' ..
+ table.concat(self.size:totable(), 'x') .. ')'
+end
diff --git a/SpatialZeroPadding.lua b/SpatialZeroPadding.lua
index 8e3756d..72d1c63 100644
--- a/SpatialZeroPadding.lua
+++ b/SpatialZeroPadding.lua
@@ -95,3 +95,10 @@ function SpatialZeroPadding:updateGradInput(input, gradOutput)
end
return self.gradInput
end
+
+
+function SpatialZeroPadding:__tostring__()
+ return torch.type(self) ..
+ string.format('(l=%d,r=%d,t=%d,b=%d)', self.pad_l, self.pad_r,
+ self.pad_t, self.pad_b)
+end
diff --git a/test.lua b/test.lua
index 27c3dde..65ff3b1 100644
--- a/test.lua
+++ b/test.lua
@@ -22,6 +22,22 @@ local function equal(t1, t2, msg)
end
+--[[ Generate tests to exercise the tostring component of modules. ]]
+local tostringTestModules = {
+ nnLinear = nn.Linear(1, 2),
+ nnReshape = nn.Reshape(10),
+ nnSpatialZeroPadding = nn.SpatialZeroPadding(1, 1, 1, 1)}
+for test_name, component in pairs(tostringTestModules) do
+ nntest['tostring' .. test_name] =
+ function ()
+ mytester:assert(tostring(component):find(torch.type(component) .. '(',
+ 1, true),
+ 'nn components should have a descriptive tostring' ..
+ ' beginning with the classname')
+ end
+end
+
+
function nntest.Add()
local inj_vals = {math.random(3,5), 1} -- Also test the inj = 1 spatial case
local ini = math.random(3,5)