From d34673d67bc8cefd7055edf55a2c88d28b3ab158 Mon Sep 17 00:00:00 2001 From: Gregory Chanan Date: Tue, 9 May 2017 13:50:15 -0700 Subject: Add Infer Size N, for expansion of fused operations. --- lib/TH/THStorage.c | 43 +++++++++++++++++++++++++++++++++++++++++++ lib/TH/THStorage.h | 1 + 2 files changed, 44 insertions(+) diff --git a/lib/TH/THStorage.c b/lib/TH/THStorage.c index 7636e4b..b7d4a5d 100644 --- a/lib/TH/THStorage.c +++ b/lib/TH/THStorage.c @@ -104,6 +104,49 @@ TH_API int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long di return 0; } +TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes, long *dims, int raiseErrors) { + THArgCheck(n > 0, 2, "n must be greater than 0"); + THArgCheck(sizes != NULL, 1, "sizesA must not be null"); + THArgCheck(dims != NULL, 1, "dims must not be null"); + + ptrdiff_t ndim = 0; + for (int j = 0; j < n; ++j) { + THArgCheck(sizes[ j ] != NULL, 1, "size %d must not be null", j); + THArgCheck(dims[ j ], 1, "Can't expand empty tensor %d", j); + ptrdiff_t ndim = dims[ j ] > ndim ? dims[ j ] : ndim; + } + + long *expandedSizes = THAlloc(sizeof(long)*ndim); + + for (long i = ndim - 1; i >= 0; --i) { + long max_dim_size = 1; + long offset = ndim - 1 - i; + for (int j = 0; j < n; ++j) { + long dim = dims[ j ] - 1 - offset; + long size = (dim >= 0) ? sizes[ i ][ dim ] : 1; + if (size != max_dim_size) { + if (max_dim_size == 1){ + max_dim_size = size; + } else if (size == 1) { + // we'll expand, nothing to do + } else { + THFree(expandedSizes); + if (raiseErrors) { + THError("The size of tensor %i (%ld) must match the expanded size of tensor (%ld) at " + "non-singleton dimension %ld.", j, size, max_dim_size, i); + } + return -1; + } + } + } + expandedSizes[ i ] = max_dim_size; + } + THLongStorage_resize(output, ndim); + memcpy(THLongStorage_data(output), expandedSizes, sizeof(long)*ndim); + THFree(expandedSizes); + return 0; +} + TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim, THLongStorage *sizes, long **esz, long **est, int raiseErrors) { ptrdiff_t ndim = THLongStorage_size(sizes); diff --git a/lib/TH/THStorage.h b/lib/TH/THStorage.h index 9cbc452..7aab593 100644 --- a/lib/TH/THStorage.h +++ b/lib/TH/THStorage.h @@ -31,6 +31,7 @@ typedef struct { TH_API THDescBuff THLongStorage_sizeDesc(const THLongStorage *size); TH_API THLongStorage *THLongStorage_newInferSize(THLongStorage *size, ptrdiff_t nElement); TH_API int THLongStorage_inferSize2(THLongStorage *output, long *sizesA, long dimsA, long *sizesB, long dimsB, int raiseErrors); +TH_API int THLongStorage_inferSizeN(THLongStorage *output, int n, long **sizes, long *dims, int raiseErrors); TH_API int THLongStorage_inferExpandGeometry(long *tensorSizes, long *tensorStrides, long tensorDim, THLongStorage *sizes, long **esz, long **est, int raiseErrors); #endif -- cgit v1.2.3