Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Lerer <alerer@fb.com>2017-02-10 20:18:11 +0300
committerAdam Lerer <alerer@fb.com>2017-02-28 21:23:36 +0300
commitd46ff6fdeb3f3284e49da4348927c1ce8d7e758e (patch)
tree8ff0f406252ea499346bd3ada90290bec3e44292
parentc8f3c0b17582ab5f85a507c82d6a37c1d08f1dee (diff)
address comments and add tests
-rw-r--r--lib/TH/generic/THTensorMath.c35
-rw-r--r--test/test.lua114
2 files changed, 76 insertions, 73 deletions
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c
index 6b483dd..bc48c6a 100644
--- a/lib/TH/generic/THTensorMath.c
+++ b/lib/TH/generic/THTensorMath.c
@@ -1538,6 +1538,9 @@ void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
*tempValues__data = *t_data;
*tempIndices__data = tempIndices__counter[dimension];
});
+
+ THTensor_(free)(tempValues_);
+ THLongTensor_free(tempIndices_);
}
}
@@ -1643,7 +1646,7 @@ void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension)
temp_->size[dimension] = t->size[dimension];
temp_->stride[dimension] = 0;
- THTensor_(cadd)(temp_, temp_, 1, t);
+ TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data + *t_data;);
THTensor_(free)(temp_);
}
}
@@ -1675,7 +1678,7 @@ void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension)
temp_->size[dimension] = t->size[dimension];
temp_->stride[dimension] = 0;
- THTensor_(cmul)(temp_, temp_, t);
+ TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data * *t_data;);
THTensor_(free)(temp_);
}
}
@@ -2749,35 +2752,11 @@ void THTensor_(lerp)(THTensor *r_, THTensor *a, THTensor *b, real weight)
void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension)
{
- THLongStorage *dim;
-
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "invalid dimension %d",
dimension + TH_INDEX_BASE);
- dim = THTensor_(newSizeOf)(t);
- THLongStorage_set(dim, dimension, 1);
- THTensor_(resize)(r_, dim, NULL);
- THLongStorage_free(dim);
-
- // two implementations optimized for data locality
- if (t->stride[dimension] == 1) {
- TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension,
- accreal sum = 0;
- long i;
- for(i = 0; i < t_size; i++)
- sum += t_data[i*t_stride];
- *r__data = (real)sum/t_size;);
- } else {
- THTensor_(zero)(r_);
- THTensor *temp_ = THTensor_(newWithTensor)(r_);
- // r_.expand_as(t)
- temp_->size[dimension] = t->size[dimension];
- temp_->stride[dimension] = 0;
-
- THTensor_(cadd)(temp_, temp_, 1, t);
- THTensor_(free)(temp_);
- THTensor_(div)(r_, r_, t->size[dimension]);
- }
+ THTensor_(sum)(r_, t, dimension);
+ THTensor_(div)(r_, r_, t->size[dimension]);
}
void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag)
diff --git a/test/test.lua b/test/test.lua
index ed00c88..943aaa1 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -343,6 +343,7 @@ function torchtest.round()
end
function torchtest.max() -- torch.max([resval, resind,] x [,dim])
+
-- torch.max( x )
-- contiguous
local m1 = torch.randn(100,100)
@@ -357,6 +358,7 @@ function torchtest.max() -- torch.max([resval, resind,] x [,dim])
end
local err = res1 - res2
mytester:assertlt(err, precision, 'error in torch.max - contiguous')
+
-- non-contiguous
local m1 = torch.randn(10,10,10)
local m2 = m1[{{}, 4, {}}]
@@ -371,33 +373,34 @@ function torchtest.max() -- torch.max([resval, resind,] x [,dim])
end
local err = res1 - res2
mytester:assertlt(err, precision, 'error in torch.max - non-contiguous')
+
-- torch.max([resval, resind,] x ,dim])
- local m1 = torch.randn(100,100)
- local res1val, res1ind = torch.max(m1, 2)
- local res2val = res1val:clone():zero()
- local res2ind = res1ind:clone():zero()
- for i=1, m1:size(1) do
- res2val[i] = m1[i][1]
- res2ind[i] = 1
- for j=1, m1:size(2) do
- if m1[i][j] > res2val[i][1] then
- res2val[i] = m1[i][j]
- res2ind[i] = j
+ function lua_max(t, dim)
+ assert(t:nDimension() == 2)
+ max_val = t:narrow(dim, 1, 1):clone()
+ max_ind = t:narrow(dim, 1, 1):clone():long():fill(1)
+ other = 3 - dim
+ for i = 1, t:size(other) do
+ for j = 1, t:size(dim) do
+ val = t:select(other, i):select(dim, j)
+ max = max_val:select(other, i):select(dim, 1)
+ if val > max then
+ max_val:select(other, i):fill(val)
+ max_ind:select(other, i):fill(j)
+ end
end
end
+ return max_val, max_ind
end
- local errval = res1val:clone():zero()
- for i = 1, res1val:size(1) do
- errval[i] = math.abs(res1val[i][1] - res2val[i][1])
- mytester:asserteq(res1ind[i][1], res2ind[i][1], 'error in torch.max - non-contiguous')
- end
- local maxerr = 0
- for i = 1, errval:size(1) do
- if errval[i][1] > maxerr then
- maxerr = errval[i]
- end
+
+ local m1 = torch.randn(100,100)
+ for dim = 1,2 do
+ local res1val, res1ind = torch.max(m1, dim)
+ local res2val, res2ind = lua_max(m1, dim)
+ mytester:asserteq((res1val-res2val):abs():max(), 0, 'error in torch.max')
+ mytester:asserteq((res1ind-res2ind):abs():max(), 0, 'error in torch.max')
end
- mytester:assertlt(maxerr, precision, 'error in torch.max - non-contiguous')
+
-- NaNs
for index in pairs{1, 5, 100} do
local m1 = torch.randn(100)
@@ -439,33 +442,34 @@ function torchtest.min() -- torch.min([resval, resind,] x [,dim])
end
local err = res1 - res2
mytester:assertlt(err, precision, 'error in torch.min - non-contiguous')
- -- torch.min([resval, resind,] x ,dim])
- local m1 = torch.randn(100,100)
- local res1val, res1ind = torch.min(m1, 2)
- local res2val = res1val:clone():zero()
- local res2ind = res1ind:clone():zero()
- for i=1, m1:size(1) do
- res2val[i] = m1[i][1]
- res2ind[i] = 1
- for j=1, m1:size(2) do
- if m1[i][j] < res2val[i][1] then
- res2val[i] = m1[i][j]
- res2ind[i] = j
+
+ -- torch.max([resval, resind,] x ,dim])
+ function lua_min(t, dim)
+ assert(t:nDimension() == 2)
+ max_val = t:narrow(dim, 1, 1):clone()
+ max_ind = t:narrow(dim, 1, 1):clone():long():fill(1)
+ other = 3 - dim
+ for i = 1, t:size(other) do
+ for j = 1, t:size(dim) do
+ val = t:select(other, i):select(dim, j)
+ max = max_val:select(other, i):select(dim, 1)
+ if val < max then
+ max_val:select(other, i):fill(val)
+ max_ind:select(other, i):fill(j)
+ end
end
end
+ return max_val, max_ind
end
- local errval = res1val:clone():zero()
- for i = 1, res1val:size(1) do
- errval[i] = math.abs(res1val[i][1] - res2val[i][1])
- mytester:asserteq(res1ind[i][1], res2ind[i][1], 'error in torch.min - non-contiguous')
- end
- local minerr = 0
- for i = 1, errval:size(1) do
- if errval[i][1] < minerr then
- minerr = errval[i]
- end
+
+ local m1 = torch.randn(100,100)
+ for dim = 1,2 do
+ local res1val, res1ind = torch.min(m1, dim)
+ local res2val, res2ind = lua_min(m1, dim)
+ mytester:asserteq((res1val-res2val):abs():max(), 0, 'error in torch.max')
+ mytester:asserteq((res1ind-res2ind):abs():max(), 0, 'error in torch.max')
end
- mytester:assertlt(minerr, precision, 'error in torch.min - non-contiguous')
+
-- NaNs
for index in pairs{1, 5, 100} do
local m1 = torch.randn(100)
@@ -1533,6 +1537,16 @@ function torchtest.sum()
local mxx = torch.Tensor()
torch.sum(mxx,x,2)
mytester:asserteq(maxdiff(mx,mxx),0,'torch.sum value')
+
+ local y = torch.rand(5, 5, 5)
+ for i=1,3 do
+ local a = y:sum(i)
+ local b = y:narrow(i, 1, 1):clone():zero()
+ for j = 1, 5 do
+ b:add(y:narrow(i, j, 1))
+ end
+ mytester:asserteq(maxdiff(a, b), 0, 'torch.sum value')
+ end
end
function torchtest.prod()
local x = torch.rand(msize,msize)
@@ -1540,6 +1554,16 @@ function torchtest.prod()
local mxx = torch.Tensor()
torch.prod(mxx,x,2)
mytester:asserteq(maxdiff(mx,mxx),0,'torch.prod value')
+
+ local y = torch.rand(5, 5, 5)
+ for i=1,3 do
+ local a = y:prod(i)
+ local b = y:narrow(i, 1, 1):clone():fill(1)
+ for j = 1, 5 do
+ b:cmul(y:narrow(i, j, 1))
+ end
+ mytester:asserteq(maxdiff(a, b), 0, 'torch.sum value')
+ end
end
function torchtest.cumsum()
local x = torch.rand(msize,msize)