diff options
author | Adam Lerer <alerer@fb.com> | 2017-02-10 20:18:11 +0300 |
---|---|---|
committer | Adam Lerer <alerer@fb.com> | 2017-02-28 21:23:36 +0300 |
commit | d46ff6fdeb3f3284e49da4348927c1ce8d7e758e (patch) | |
tree | 8ff0f406252ea499346bd3ada90290bec3e44292 | |
parent | c8f3c0b17582ab5f85a507c82d6a37c1d08f1dee (diff) |
address comments and add tests
-rw-r--r-- | lib/TH/generic/THTensorMath.c | 35 | ||||
-rw-r--r-- | test/test.lua | 114 |
2 files changed, 76 insertions, 73 deletions
diff --git a/lib/TH/generic/THTensorMath.c b/lib/TH/generic/THTensorMath.c index 6b483dd..bc48c6a 100644 --- a/lib/TH/generic/THTensorMath.c +++ b/lib/TH/generic/THTensorMath.c @@ -1538,6 +1538,9 @@ void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int *tempValues__data = *t_data; *tempIndices__data = tempIndices__counter[dimension]; }); + + THTensor_(free)(tempValues_); + THLongTensor_free(tempIndices_); } } @@ -1643,7 +1646,7 @@ void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension) temp_->size[dimension] = t->size[dimension]; temp_->stride[dimension] = 0; - THTensor_(cadd)(temp_, temp_, 1, t); + TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data + *t_data;); THTensor_(free)(temp_); } } @@ -1675,7 +1678,7 @@ void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension) temp_->size[dimension] = t->size[dimension]; temp_->stride[dimension] = 0; - THTensor_(cmul)(temp_, temp_, t); + TH_TENSOR_APPLY2(real, temp_, real, t, *temp__data = *temp__data * *t_data;); THTensor_(free)(temp_); } } @@ -2749,35 +2752,11 @@ void THTensor_(lerp)(THTensor *r_, THTensor *a, THTensor *b, real weight) void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension) { - THLongStorage *dim; - THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "invalid dimension %d", dimension + TH_INDEX_BASE); - dim = THTensor_(newSizeOf)(t); - THLongStorage_set(dim, dimension, 1); - THTensor_(resize)(r_, dim, NULL); - THLongStorage_free(dim); - - // two implementations optimized for data locality - if (t->stride[dimension] == 1) { - TH_TENSOR_DIM_APPLY2(real, t, real, r_, dimension, - accreal sum = 0; - long i; - for(i = 0; i < t_size; i++) - sum += t_data[i*t_stride]; - *r__data = (real)sum/t_size;); - } else { - THTensor_(zero)(r_); - THTensor *temp_ = THTensor_(newWithTensor)(r_); - // r_.expand_as(t) - temp_->size[dimension] = t->size[dimension]; - temp_->stride[dimension] = 0; - - THTensor_(cadd)(temp_, temp_, 1, t); - THTensor_(free)(temp_); - THTensor_(div)(r_, r_, t->size[dimension]); - } + THTensor_(sum)(r_, t, dimension); + THTensor_(div)(r_, r_, t->size[dimension]); } void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int flag) diff --git a/test/test.lua b/test/test.lua index ed00c88..943aaa1 100644 --- a/test/test.lua +++ b/test/test.lua @@ -343,6 +343,7 @@ function torchtest.round() end function torchtest.max() -- torch.max([resval, resind,] x [,dim]) + -- torch.max( x ) -- contiguous local m1 = torch.randn(100,100) @@ -357,6 +358,7 @@ function torchtest.max() -- torch.max([resval, resind,] x [,dim]) end local err = res1 - res2 mytester:assertlt(err, precision, 'error in torch.max - contiguous') + -- non-contiguous local m1 = torch.randn(10,10,10) local m2 = m1[{{}, 4, {}}] @@ -371,33 +373,34 @@ function torchtest.max() -- torch.max([resval, resind,] x [,dim]) end local err = res1 - res2 mytester:assertlt(err, precision, 'error in torch.max - non-contiguous') + -- torch.max([resval, resind,] x ,dim]) - local m1 = torch.randn(100,100) - local res1val, res1ind = torch.max(m1, 2) - local res2val = res1val:clone():zero() - local res2ind = res1ind:clone():zero() - for i=1, m1:size(1) do - res2val[i] = m1[i][1] - res2ind[i] = 1 - for j=1, m1:size(2) do - if m1[i][j] > res2val[i][1] then - res2val[i] = m1[i][j] - res2ind[i] = j + function lua_max(t, dim) + assert(t:nDimension() == 2) + max_val = t:narrow(dim, 1, 1):clone() + max_ind = t:narrow(dim, 1, 1):clone():long():fill(1) + other = 3 - dim + for i = 1, t:size(other) do + for j = 1, t:size(dim) do + val = t:select(other, i):select(dim, j) + max = max_val:select(other, i):select(dim, 1) + if val > max then + max_val:select(other, i):fill(val) + max_ind:select(other, i):fill(j) + end end end + return max_val, max_ind end - local errval = res1val:clone():zero() - for i = 1, res1val:size(1) do - errval[i] = math.abs(res1val[i][1] - res2val[i][1]) - mytester:asserteq(res1ind[i][1], res2ind[i][1], 'error in torch.max - non-contiguous') - end - local maxerr = 0 - for i = 1, errval:size(1) do - if errval[i][1] > maxerr then - maxerr = errval[i] - end + + local m1 = torch.randn(100,100) + for dim = 1,2 do + local res1val, res1ind = torch.max(m1, dim) + local res2val, res2ind = lua_max(m1, dim) + mytester:asserteq((res1val-res2val):abs():max(), 0, 'error in torch.max') + mytester:asserteq((res1ind-res2ind):abs():max(), 0, 'error in torch.max') end - mytester:assertlt(maxerr, precision, 'error in torch.max - non-contiguous') + -- NaNs for index in pairs{1, 5, 100} do local m1 = torch.randn(100) @@ -439,33 +442,34 @@ function torchtest.min() -- torch.min([resval, resind,] x [,dim]) end local err = res1 - res2 mytester:assertlt(err, precision, 'error in torch.min - non-contiguous') - -- torch.min([resval, resind,] x ,dim]) - local m1 = torch.randn(100,100) - local res1val, res1ind = torch.min(m1, 2) - local res2val = res1val:clone():zero() - local res2ind = res1ind:clone():zero() - for i=1, m1:size(1) do - res2val[i] = m1[i][1] - res2ind[i] = 1 - for j=1, m1:size(2) do - if m1[i][j] < res2val[i][1] then - res2val[i] = m1[i][j] - res2ind[i] = j + + -- torch.max([resval, resind,] x ,dim]) + function lua_min(t, dim) + assert(t:nDimension() == 2) + max_val = t:narrow(dim, 1, 1):clone() + max_ind = t:narrow(dim, 1, 1):clone():long():fill(1) + other = 3 - dim + for i = 1, t:size(other) do + for j = 1, t:size(dim) do + val = t:select(other, i):select(dim, j) + max = max_val:select(other, i):select(dim, 1) + if val < max then + max_val:select(other, i):fill(val) + max_ind:select(other, i):fill(j) + end end end + return max_val, max_ind end - local errval = res1val:clone():zero() - for i = 1, res1val:size(1) do - errval[i] = math.abs(res1val[i][1] - res2val[i][1]) - mytester:asserteq(res1ind[i][1], res2ind[i][1], 'error in torch.min - non-contiguous') - end - local minerr = 0 - for i = 1, errval:size(1) do - if errval[i][1] < minerr then - minerr = errval[i] - end + + local m1 = torch.randn(100,100) + for dim = 1,2 do + local res1val, res1ind = torch.min(m1, dim) + local res2val, res2ind = lua_min(m1, dim) + mytester:asserteq((res1val-res2val):abs():max(), 0, 'error in torch.max') + mytester:asserteq((res1ind-res2ind):abs():max(), 0, 'error in torch.max') end - mytester:assertlt(minerr, precision, 'error in torch.min - non-contiguous') + -- NaNs for index in pairs{1, 5, 100} do local m1 = torch.randn(100) @@ -1533,6 +1537,16 @@ function torchtest.sum() local mxx = torch.Tensor() torch.sum(mxx,x,2) mytester:asserteq(maxdiff(mx,mxx),0,'torch.sum value') + + local y = torch.rand(5, 5, 5) + for i=1,3 do + local a = y:sum(i) + local b = y:narrow(i, 1, 1):clone():zero() + for j = 1, 5 do + b:add(y:narrow(i, j, 1)) + end + mytester:asserteq(maxdiff(a, b), 0, 'torch.sum value') + end end function torchtest.prod() local x = torch.rand(msize,msize) @@ -1540,6 +1554,16 @@ function torchtest.prod() local mxx = torch.Tensor() torch.prod(mxx,x,2) mytester:asserteq(maxdiff(mx,mxx),0,'torch.prod value') + + local y = torch.rand(5, 5, 5) + for i=1,3 do + local a = y:prod(i) + local b = y:narrow(i, 1, 1):clone():fill(1) + for j = 1, 5 do + b:cmul(y:narrow(i, j, 1)) + end + mytester:asserteq(maxdiff(a, b), 0, 'torch.sum value') + end end function torchtest.cumsum() local x = torch.rand(msize,msize) |