-rw-r--r--  CMakeLists.txt                     |  20
-rw-r--r--  cmake/FindCUDNN.cmake              | 229
-rw-r--r--  src/examples/mnist/model_lenet.h   |   4
-rw-r--r--  src/graph/expression_operators.cu  | 152
-rw-r--r--  src/graph/node_operators_binary.h  |  83
-rw-r--r--  src/graph/node_operators_unary.h   | 158
-rw-r--r--  src/layers/convolution.cu          |  11
-rw-r--r--  src/models/model_factory.cpp       |   6
-rw-r--r--  src/tensors/tensor_operators.h     |   5
9 files changed, 399 insertions, 269 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c943628d..54774b08 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -29,6 +29,17 @@ set(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
 find_package(CUDA "8.0" REQUIRED)
 if(CUDA_FOUND)
   set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY})
+
+  if(USE_CUDNN)
+    find_package(CUDNN "7.0")
+    if(CUDNN_FOUND)
+      include_directories(${CUDNN_INCLUDE_DIRS})
+      set(EXT_LIBS ${EXT_LIBS} ${CUDNN_LIBRARIES})
+      set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCUDNN")
+      LIST(APPEND CUDA_NVCC_FLAGS -DCUDNN; )
+    endif(CUDNN_FOUND)
+  endif(USE_CUDNN)
+
 endif(CUDA_FOUND)

 if (CMAKE_BUILD_TYPE STREQUAL "Debug")
@@ -40,15 +51,6 @@ endif(CMAKE_BUILD_TYPE STREQUAL "Debug")
 list(REMOVE_DUPLICATES CUDA_NVCC_FLAGS)
 set(CUDA_PROPAGATE_HOST_FLAGS OFF)

-if(USE_CUDNN)
-  find_package(CUDNN "7.0")
-  if(CUDNN_FOUND)
-    include_directories(${CUDNN_INCLUDE_DIRS})
-    set(EXT_LIBS ${EXT_LIBS} ${CUDNN_LIBRARIES})
-    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCUDNN")
-    LIST(APPEND CUDA_NVCC_FLAGS -DCUDNN; )
-  endif(CUDNN_FOUND)
-endif(USE_CUDNN)

 find_package(Tcmalloc)
 if(Tcmalloc_FOUND)
diff --git a/cmake/FindCUDNN.cmake b/cmake/FindCUDNN.cmake
index c287ace4..37388d30 100644
--- a/cmake/FindCUDNN.cmake
+++ b/cmake/FindCUDNN.cmake
@@ -1,51 +1,180 @@
+# Distributed under the OSI-approved BSD 3-Clause License. See accompanying
+# file Copyright.txt or https://cmake.org/licensing for details.
+#.rst:
+# FindCUDNN
+# -------
 #
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-FIND_PATH(CUDNN_INCLUDE_DIR NAME "cudnn.h" PATHS "$ENV{CMAKE_INCLUDE_PATH}")
-FIND_LIBRARY(CUDNN_LIBRARIES NAME "libcudnn.so" PATHS "$ENV{CMAKE_LIBRARY_PATH}")
-
-#message("cudnn include path:${CUDNN_INCLUDE_DIR} lib path: ${CUDNN_LIBRARIES}")
-#message("env include path:$ENV{CUDNN_DIR} next: $ENV{CMAKE_INCLUDE_PATH}")
-INCLUDE(FindPackageHandleStandardArgs)
-find_package_handle_standard_args(CUDNN DEFAULT_MSG CUDNN_INCLUDE_DIR CUDNN_LIBRARIES)
-
-IF(CUDNN_FOUND)
-  FILE(READ ${CUDNN_INCLUDE_DIR}/cudnn.h CUDNN_VERSION_FILE_CONTENTS)
-  STRING(REGEX MATCH "define CUDNN_MAJOR * +([0-9]+)"
-         CUDNN_MAJOR_VERSION "${CUDNN_VERSION_FILE_CONTENTS}")
-  STRING(REGEX REPLACE "define CUDNN_MAJOR * +([0-9]+)" "\\1"
-         CUDNN_MAJOR_VERSION "${CUDNN_MAJOR_VERSION}")
-  STRING(REGEX MATCH "define CUDNN_MINOR * +([0-9]+)"
-         CUDNN_MINOR_VERSION "${CUDNN_VERSION_FILE_CONTENTS}")
-  STRING(REGEX REPLACE "define CUDNN_MINOR * +([0-9]+)" "\\1"
-         CUDNN_MINOR_VERSION "${CUDNN_MINOR_VERSION}")
-  STRING(REGEX MATCH "define CUDNN_PATCHLEVEL * +([0-9]+)"
-         CUDNN_PATCH_VERSION "${CUDNN_VERSION_FILE_CONTENTS}")
-  STRING(REGEX REPLACE "define CUDNN_PATCHLEVEL * +([0-9]+)" "\\1"
-         CUDNN_PATCH_VERSION "${CUDNN_PATCH_VERSION}")
-
-  IF(NOT CUDNN_MAJOR_VERSION)
-    SET(CUDNN_VERSION "???")
-  ELSE()
-    MATH(EXPR CUDNN_VERSION "${CUDNN_MAJOR_VERSION} * 1000 + ${CUDNN_MINOR_VERSION} * 100 + ${CUDNN_PATCH_VERSION}")
-  ENDIF()
-  MESSAGE(STATUS "Found Cudnn_${CUDNN_VERSION} at ${CUDNN_INCLUDE_DIR} ${CUDNN_LIBRARIES}")
-  MARK_AS_ADVANCED(CUDNN_INCLUDE_DIR CUDNN_LIBRARIES)
-
-ENDIF()
+#
+# Find CUDNN library
+#
+# Variables that affect result:
+# <VERSION>, <REQUIRED>, <QUIETLY>: as usual
+#
+# <EXACT> : as usual, plus we do find '5.1' version if you wanted '5'
+#           (not if you wanted '5.0', as usual)
+#
+# Result variables
+# ^^^^^^^^^^^^^^^^
+#
+# This module will set the following variables in your project:
+#
+# ``CUDNN_INCLUDE``
+#   where to find cudnn.h.
+# ``CUDNN_LIBRARY``
+#   the libraries to link against to use CUDNN.
+# ``CUDNN_FOUND``
+#   If false, do not try to use CUDNN.
+# ``CUDNN_VERSION``
+#   Version of the CUDNN library we looked for
+#
+# Exported functions
+# ^^^^^^^^^^^^^^^^^^
+# function(CUDNN_INSTALL version __dest_libdir [__dest_incdir])
+#   This function will try to download and install CUDNN.
+#   CUDNN5 and CUDNN6 are supported.
+#
+#
+
+function(CUDNN_INSTALL version dest_libdir dest_incdir dest_bindir)
+  message(STATUS "CUDNN_INSTALL: Installing CUDNN ${version}, lib:${dest_libdir}, inc:${dest_incdir}, bin:${dest_bindir}")
+  string(REGEX REPLACE "-rc$" "" version_base "${version}")
+  set(tar_libdir cuda/lib64)
+  set(tar_incdir cuda/include)
+
+  if(${CMAKE_SYSTEM_NAME} MATCHES "Linux")
+    set(url_extension tgz)
+    if("${CMAKE_SYSTEM_PROCESSOR}" STREQUAL "x86_64")
+      set(url_arch_name linux-x64 )
+    elseif("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "ppc")
+      set(url_arch_name linux-ppc64le )
+      # TX1 has to be installed via JetPack
+    endif()
+  elseif (APPLE)
+    set(url_extension tgz)
+    set(tar_libdir cuda/lib)
+    set(url_arch_name osx-x64)
+  elseif(WIN32)
+    set(url_extension zip)
+    set(tar_bindir cuda/bin)
+    set(tar_libdir cuda/lib/x64)
+    if(CMAKE_SYSTEM_VERSION MATCHES "10")
+      set(url_arch_name windows10-x64)
+    else()
+      set(url_arch_name windows7-x64)
+    endif()
+  endif()
+
+  # Download and install CUDNN locally if not found on the system
+  if(url_arch_name)
+    set(download_dir ${CMAKE_CURRENT_BINARY_DIR}/downloads/cudnn${version})
+    file(MAKE_DIRECTORY ${download_dir})
+    set(cudnn_filename cudnn-${CUDA_VERSION}-${url_arch_name}-v${version}.${url_extension})
+    set(base_url http://developer.download.nvidia.com/compute/redist/cudnn)
+    set(cudnn_url ${base_url}/v${version_base}/${cudnn_filename})
+    set(cudnn_file ${download_dir}/${cudnn_filename})
+
+    if(NOT EXISTS ${cudnn_file})
+      message(STATUS "Downloading CUDNN library from NVIDIA...")
+      file(DOWNLOAD ${cudnn_url} ${cudnn_file}
+           SHOW_PROGRESS STATUS cudnn_status
+           )
+      execute_process(COMMAND ${CMAKE_COMMAND} -E tar xzvf ${cudnn_file} WORKING_DIRECTORY ${download_dir} RESULT_VARIABLE cudnn_status)
+
+      if(NOT "${cudnn_status}" MATCHES "0")
+        message(STATUS "Was not able to download CUDNN from ${cudnn_url}. Please install CuDNN manually from https://developer.nvidia.com/cuDNN")
+      endif()
+    endif()
+
+    if(dest_bindir AND tar_bindir)
+      file(COPY ${download_dir}/${tar_bindir}/ DESTINATION ${dest_bindir})
+    endif()
+
+    if(dest_incdir)
+      file(COPY ${download_dir}/${tar_incdir}/ DESTINATION ${dest_incdir})
+    endif()
+
+    file(COPY ${download_dir}/${tar_libdir}/ DESTINATION ${dest_libdir} )
+
+    get_filename_component(dest_dir ${dest_libdir} DIRECTORY)
+
+    set(CUDNN_ROOT_DIR ${dest_dir} PARENT_SCOPE)
+    unset(CUDNN_LIBRARY CACHE)
+    unset(CUDNN_INCLUDE_DIR CACHE)
+
+  endif(url_arch_name)
+endfunction()
+
+#####################################################
+
+find_package(PkgConfig)
+pkg_check_modules(PC_CUDNN QUIET CUDNN)
+
+get_filename_component(__libpath_cudart "${CUDA_CUDART_LIBRARY}" PATH)
+
+# We use major only in library search as major/minor is not entirely consistent among platforms.
+# Also, looking for exact minor version of .so is in general not a good idea.
+# More strict enforcement of minor/patch version is done if/when the header file is examined.
+if(CUDNN_FIND_VERSION_EXACT)
+  SET(__cudnn_ver_suffix ".${CUDNN_FIND_VERSION_MAJOR}")
+  SET(__cudnn_lib_win_name cudnn64_${CUDNN_FIND_VERSION_MAJOR})
+else()
+  SET(__cudnn_lib_win_name cudnn64)
+endif()
+
+find_library(CUDNN_LIBRARY
+  NAMES libcudnn.so${__cudnn_ver_suffix} libcudnn${__cudnn_ver_suffix}.dylib ${__cudnn_lib_win_name}
+  PATHS $ENV{LD_LIBRARY_PATH} ${__libpath_cudart} ${CUDNN_ROOT_DIR} ${PC_CUDNN_LIBRARY_DIRS} ${CMAKE_INSTALL_PREFIX}
+  PATH_SUFFIXES lib lib64 bin
+  DOC "CUDNN library."
+)
+
+if(CUDNN_LIBRARY)
+  SET(CUDNN_MAJOR_VERSION ${CUDNN_FIND_VERSION_MAJOR})
+  set(CUDNN_VERSION ${CUDNN_MAJOR_VERSION})
+  get_filename_component(__found_cudnn_root ${CUDNN_LIBRARY} PATH)
+  find_path(CUDNN_INCLUDE_DIR
+    NAMES cudnn.h
+    HINTS ${PC_CUDNN_INCLUDE_DIRS} ${CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_INCLUDE} ${__found_cudnn_root}
+    PATH_SUFFIXES include
+    DOC "Path to CUDNN include directory." )
+endif()
+
+if(CUDNN_LIBRARY AND CUDNN_INCLUDE_DIR)
+  file(READ ${CUDNN_INCLUDE_DIR}/cudnn.h CUDNN_VERSION_FILE_CONTENTS)
+  string(REGEX MATCH "define CUDNN_MAJOR * +([0-9]+)"
+         CUDNN_MAJOR_VERSION "${CUDNN_VERSION_FILE_CONTENTS}")
+  string(REGEX REPLACE "define CUDNN_MAJOR * +([0-9]+)" "\\1"
+         CUDNN_MAJOR_VERSION "${CUDNN_MAJOR_VERSION}")
+  string(REGEX MATCH "define CUDNN_MINOR * +([0-9]+)"
+         CUDNN_MINOR_VERSION "${CUDNN_VERSION_FILE_CONTENTS}")
+  string(REGEX REPLACE "define CUDNN_MINOR * +([0-9]+)" "\\1"
+         CUDNN_MINOR_VERSION "${CUDNN_MINOR_VERSION}")
+  string(REGEX MATCH "define CUDNN_PATCHLEVEL * +([0-9]+)"
+         CUDNN_PATCH_VERSION "${CUDNN_VERSION_FILE_CONTENTS}")
+  string(REGEX REPLACE "define CUDNN_PATCHLEVEL * +([0-9]+)" "\\1"
+         CUDNN_PATCH_VERSION "${CUDNN_PATCH_VERSION}")
+  set(CUDNN_VERSION ${CUDNN_MAJOR_VERSION}.${CUDNN_MINOR_VERSION})
+endif()
+
+if(CUDNN_MAJOR_VERSION)
+  ## Fixing the case where 5.1 does not fit 'exact' 5.
+  if(CUDNN_FIND_VERSION_EXACT AND NOT CUDNN_FIND_VERSION_MINOR)
+    if("${CUDNN_MAJOR_VERSION}" STREQUAL "${CUDNN_FIND_VERSION_MAJOR}")
+      set(CUDNN_VERSION ${CUDNN_FIND_VERSION})
+    endif()
+  endif()
+else()
+  # Try to set CUDNN version from config file
+  set(CUDNN_VERSION ${PC_CUDNN_CFLAGS_OTHER})
+endif()
+
+find_package_handle_standard_args(
+  CUDNN
+  FOUND_VAR CUDNN_FOUND
+  REQUIRED_VARS CUDNN_LIBRARY
+  VERSION_VAR CUDNN_VERSION
+  )
+
+if(CUDNN_FOUND)
+  set(CUDNN_LIBRARIES ${CUDNN_LIBRARY})
+  set(CUDNN_INCLUDE_DIRS ${CUDNN_INCLUDE_DIR})
+  set(CUDNN_DEFINITIONS ${PC_CUDNN_CFLAGS_OTHER})
+endif()
diff --git a/src/examples/mnist/model_lenet.h b/src/examples/mnist/model_lenet.h
index 8dd96634..968ceaf3 100644
--- a/src/examples/mnist/model_lenet.h
+++ b/src/examples/mnist/model_lenet.h
@@ -33,9 +33,6 @@ protected:

     // Construct hidden layers

-    ABORT("TEMPORARY");
-
-    /*
     auto conv_1 = convolution(g)
       ("prefix", "conv_1")
       ("kernel-dims", std::make_pair(3,3))
@@ -102,7 +99,6 @@ protected:
       // Define a top-level node for inference
       return logsoftmax(last);
     }
-    */;
   }
 };
 }
diff --git a/src/graph/expression_operators.cu b/src/graph/expression_operators.cu
index 0cb65c1e..be40a0d4 100644
--- a/src/graph/expression_operators.cu
+++ b/src/graph/expression_operators.cu
@@ -333,81 +333,81 @@ Expr shift(Expr a, Shape shift) {
 //  return Expression<LexicalProbNodeOp>(logits, att, eps, lf);
 //}

-//Expr avg_pooling(
-//    Expr x,
-//    int height,
-//    int width,
-//    int padHeight,
-//    int padWidth,
-//    int strideHeight,
-//    int strideWidth) {
-//  return Expression<PoolingOp>(x,
-//                               height,
-//                               width,
-//                               padHeight,
-//                               padWidth,
-//                               strideHeight,
-//                               strideWidth,
-//                               "avg");
-//}
-//
-//Expr max_pooling(
-//    Expr x,
-//    int height,
-//    int width,
-//    int padHeight,
-//    int padWidth,
-//    int strideHeight,
-//    int strideWidth)
-//{
-//  return Expression<PoolingOp>(x,
-//                               height,
-//                               width,
-//                               padHeight,
-//                               padWidth,
-//                               strideHeight,
-//                               strideWidth,
-//                               "max");
-//}
-//
-//Expr convert2cudnnFormat(Expr x) {
-//  int numWords = x->shape()[0];
-//  int numExamples = x->shape()[1];
-//  int embSize = x->shape()[2];
-//
-//  std::vector<size_t> newIndeces;
-//  for (int b = 0; b < numExamples; ++b) {
-//    for (int t = 0; t < numWords; ++t) {
-//      newIndeces.push_back((t * numExamples) + b);
-//    }
-//  }
-//
-//  auto xRows = reshape(x, {x->shape()[0] * x ->shape()[1], x->shape()[2]});
-//
-//  Shape outShape({numExamples, 1, numWords, embSize});
-//  return reshape(rows(xRows, newIndeces), outShape);
-//}
-//
-//Expr convertFromcudnnFormat(Expr x) {
-//  int batchDim = x->shape()[0];
-//  int sentenceDim = x->shape()[2];
-//  int embSize = x->shape()[3];
-//
-//  auto reshapedX = reshape(x, {batchDim * sentenceDim, embSize});
-//
-//  std::vector<size_t> newIndeces;
-//  for (int t = 0; t < sentenceDim; ++t) {
-//    for (int b = 0; b < batchDim; ++b) {
-//      newIndeces.push_back(b * sentenceDim + t);
-//    }
-//  }
-//
-//  Shape shape({batchDim, sentenceDim, embSize});
-//  return reshape(rows(reshapedX, newIndeces), shape);
-//}
-//
-//Expr pooling_with_masking(Expr x, Expr mask, int width, bool isEven) {
-//  return Expression<PoolingWithMaskingOp>(x, mask, width, isEven);
-//}
+Expr avg_pooling(
+    Expr x,
+    int height,
+    int width,
+    int padHeight,
+    int padWidth,
+    int strideHeight,
+    int strideWidth) {
+  return Expression<PoolingOp>(x,
+                               height,
+                               width,
+                               padHeight,
+                               padWidth,
+                               strideHeight,
+                               strideWidth,
+                               "avg");
+}
+
+Expr max_pooling(
+    Expr x,
+    int height,
+    int width,
+    int padHeight,
+    int padWidth,
+    int strideHeight,
+    int strideWidth)
+{
+  return Expression<PoolingOp>(x,
+                               height,
+                               width,
+                               padHeight,
+                               padWidth,
+                               strideHeight,
+                               strideWidth,
+                               "max");
+}
+
+Expr convert2cudnnFormat(Expr x) {
+  int numWords = x->shape()[0];
+  int numExamples = x->shape()[1];
+  int embSize = x->shape()[2];
+
+  std::vector<size_t> newIndeces;
+  for (int b = 0; b < numExamples; ++b) {
+    for (int t = 0; t < numWords; ++t) {
+      newIndeces.push_back((t * numExamples) + b);
+    }
+  }
+
+  auto xRows = reshape(x, {x->shape()[0] * x ->shape()[1], x->shape()[2]});
+
+  Shape outShape({numExamples, 1, numWords, embSize});
+  return reshape(rows(xRows, newIndeces), outShape);
+}
+
+Expr convertFromcudnnFormat(Expr x) {
+  int batchDim = x->shape()[0];
+  int sentenceDim = x->shape()[2];
+  int embSize = x->shape()[3];
+
+  auto reshapedX = reshape(x, {batchDim * sentenceDim, embSize});
+
+  std::vector<size_t> newIndeces;
+  for (int t = 0; t < sentenceDim; ++t) {
+    for (int b = 0; b < batchDim; ++b) {
+      newIndeces.push_back(b * sentenceDim + t);
+    }
+  }
+
+  Shape shape({batchDim, sentenceDim, embSize});
+  return reshape(rows(reshapedX, newIndeces), shape);
+}
+
+Expr pooling_with_masking(Expr x, Expr mask, int width, bool isEven) {
+  return Expression<PoolingWithMaskingOp>(x, mask, width, isEven);
+}
 }
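Note: the convert2cudnnFormat/convertFromcudnnFormat helpers re-enabled above switch between Marian's time-major [words, batch, emb] row order and the batch-major [batch, 1, words, emb] layout expected by cuDNN by gathering rows through an index permutation. The following is a minimal standalone sketch of that permutation in plain C++ with toy dimensions; the gatherRows helper and the round-trip check are illustrative assumptions, not code from this commit.

    #include <cassert>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Gather rows of a flat [rows x cols] matrix according to an index list,
    // mimicking what rows(reshape(x, ...), newIndeces) does in the graph code.
    static std::vector<int> gatherRows(const std::vector<int>& m,
                                       const std::vector<size_t>& idx,
                                       size_t cols) {
      std::vector<int> out;
      out.reserve(idx.size() * cols);
      for (size_t r : idx)
        for (size_t c = 0; c < cols; ++c)
          out.push_back(m[r * cols + c]);
      return out;
    }

    int main() {
      const size_t numWords = 3, numExamples = 2, embSize = 4;

      // Time-major input: the row for (t, b) lives at index t * numExamples + b.
      std::vector<int> timeMajor(numWords * numExamples * embSize);
      for (size_t i = 0; i < timeMajor.size(); ++i) timeMajor[i] = static_cast<int>(i);

      // convert2cudnnFormat: gather rows so that batch becomes the outer dimension.
      std::vector<size_t> toCudnn;
      for (size_t b = 0; b < numExamples; ++b)
        for (size_t t = 0; t < numWords; ++t)
          toCudnn.push_back(t * numExamples + b);
      std::vector<int> batchMajor = gatherRows(timeMajor, toCudnn, embSize);

      // convertFromcudnnFormat: the inverse permutation restores time-major order.
      std::vector<size_t> fromCudnn;
      for (size_t t = 0; t < numWords; ++t)
        for (size_t b = 0; b < numExamples; ++b)
          fromCudnn.push_back(b * numWords + t);
      std::vector<int> roundTrip = gatherRows(batchMajor, fromCudnn, embSize);

      assert(roundTrip == timeMajor);  // the two conversions are mutual inverses
      std::cout << "round trip ok\n";
      return 0;
    }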
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 9b0655ec..a2a47a61 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -5,6 +5,7 @@
 #include "graph/node.h"
 #include "functional/functional.h"
 #include "tensors/tensor_operators.h"
+#include "tensors/gpu/cudnn_wrappers.h"

 namespace marian {

@@ -743,46 +744,46 @@ struct HighwayNodeOp : public NaryNodeOp {
   const std::string type() { return "highway"; }
 };

-//class ConvolutionOp : public NaryNodeOp {
-//public:
-//  ConvolutionOp(
-//      const std::vector<Expr>& nodes,
-//      int hPad = 0,
-//      int wPad = 0,
-//      int hStride = 1,
-//      int wStride = 1)
-//      : NaryNodeOp(nodes),
-//        conv_(nodes[1]->shape(),
-//              nodes[2]->shape(),
-//              hPad,
-//              wPad,
-//              hStride,
-//              wStride) {
-//    conv_.getOutputShape(nodes[0]->shape(), shape_);
-//  }
-//
-//  NodeOps forwardOps() {
-//    return {NodeOp(conv_.forward(
-//        child(0)->val(),
-//        child(1)->val(),
-//        child(2)->val(),
-//        val_))};
-//  }
-//
-//  NodeOps backwardOps() {
-//    return {NodeOp(conv_.backward(
-//        child(0)->val(),
-//        child(0)->grad(),
-//        child(1)->val(),
-//        child(1)->grad(),
-//        child(2)->grad(),
-//        adj_))};
-//  }
-//
-//  const std::string type() { return "layer_convolution"; }
-//
-//protected:
-//  ConvolutionWrapper conv_;
-//};
+class ConvolutionOp : public NaryNodeOp {
+public:
+  ConvolutionOp(
+      const std::vector<Expr>& nodes,
+      int hPad = 0,
+      int wPad = 0,
+      int hStride = 1,
+      int wStride = 1)
+      : NaryNodeOp(nodes),
+        conv_(nodes[1]->shape(),
+              nodes[2]->shape(),
+              hPad,
+              wPad,
+              hStride,
+              wStride) {
+    conv_.getOutputShape(nodes[0]->shape(), shape_);
+  }
+
+  NodeOps forwardOps() {
+    return {NodeOp(conv_.forward(
+        child(0)->val(),
+        child(1)->val(),
+        child(2)->val(),
+        val_))};
+  }
+
+  NodeOps backwardOps() {
+    return {NodeOp(conv_.backward(
+        child(0)->val(),
+        child(0)->grad(),
+        child(1)->val(),
+        child(1)->grad(),
+        child(2)->grad(),
+        adj_))};
+  }
+
+  const std::string type() { return "layer_convolution"; }
+
+protected:
+  ConvolutionWrapper conv_;
+};
 }
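Note: ConvolutionOp delegates its output shape to ConvolutionWrapper::getOutputShape, which is not part of this diff. The standard cuDNN output-size arithmetic that such a wrapper typically relies on (dilation 1) is out = (in + 2*pad - kernel) / stride + 1. A tiny standalone check of that formula, using the 3x3 kernel from the MNIST example above with assumed padding and stride values; this is background arithmetic, not Marian's exact implementation.

    #include <iostream>

    // Standard cuDNN 2-D convolution output dimension (dilation = 1):
    //   out = (in + 2*pad - kernel) / stride + 1
    static int convOutDim(int in, int kernel, int pad, int stride) {
      return (in + 2 * pad - kernel) / stride + 1;
    }

    int main() {
      // 28x28 MNIST input, 3x3 kernel, padding 1, stride 1 (example values only).
      int h = convOutDim(28, 3, 1, 1);
      int w = convOutDim(28, 3, 1, 1);
      std::cout << h << "x" << w << "\n";  // prints 28x28: spatial size preserved
      return 0;
    }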
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 07c06fda..0a76471b 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -1055,84 +1055,84 @@ struct ShiftNodeOp : public UnaryNodeOp {
 //  Ptr<sparse::CSR> lf_;
 //};

-//class PoolingOp : public UnaryNodeOp {
-//public:
-//  PoolingOp(Expr x,
-//            int height,
-//            int width,
-//            int padHeight,
-//            int padWidth,
-//            int strideHeight,
-//            int strideWidth,
-//            std::string mode)
-//      : UnaryNodeOp(x),
-//        pooling_(height,
-//                 width,
-//                 padHeight,
-//                 padWidth,
-//                 strideHeight,
-//                 strideWidth,
-//                 mode) {
-//  }
-//
-//  NodeOps forwardOps() {
-//    return {NodeOp(pooling_.forward(child(0)->val(), val_))};
-//  }
-//
-//  NodeOps backwardOps() {
-//    return {NodeOp(pooling_.backward(
-//        child(0)->val(),
-//        child(0)->grad(),
-//        val_,
-//        adj_))};
-//  }
-//
-//  const std::string type() { return "layer_pooling"; }
-//
-//
-//protected:
-//  PoolingWrapper pooling_;
-//};
-//
-//class PoolingWithMaskingOp : public UnaryNodeOp {
-//  public:
-//    PoolingWithMaskingOp( Expr x, Expr mask, int width, bool isEven=false)
-//      : UnaryNodeOp(x),
-//        mask_(mask),
-//        width_(width),
-//        isEven_(isEven)
-//    {
-//      auto xShape = x->shape();
-//      int dimBatch = xShape[0];
-//      int dimWord = xShape[1];
-//      int cols = (isEven_) ? xShape[2] - 1 : xShape[2];
-//      int dimSentence = (cols / width_) + (cols % width_ != 0);
-//      shape_ = {dimBatch, dimWord, dimSentence};
-//    }
-//
-//    NodeOps forwardOps() {
-//      return {NodeOp(PoolingWithMaskingForward(val_,
-//                                               child(0)->val(),
-//                                               mask_->val(),
-//                                               width_,
-//                                               isEven_))};
-//    }
-//
-//    NodeOps backwardOps() {
-//      return {NodeOp(PoolingWithMaskingBackward(adj_,
-//                                                child(0)->grad(),
-//                                                child(0)->val(),
-//                                                mask_->val(),
-//                                                width_,
-//                                                isEven_))};
-//    }
-//
-//    const std::string type() {return "layer_pooling";}
-//
-//  protected:
-//    Expr mask_;
-//    int width_;
-//    bool isEven_;
-//};
+class PoolingOp : public UnaryNodeOp {
+public:
+  PoolingOp(Expr x,
+            int height,
+            int width,
+            int padHeight,
+            int padWidth,
+            int strideHeight,
+            int strideWidth,
+            std::string mode)
+      : UnaryNodeOp(x),
+        pooling_(height,
+                 width,
+                 padHeight,
+                 padWidth,
+                 strideHeight,
+                 strideWidth,
+                 mode) {
+  }
+
+  NodeOps forwardOps() {
+    return {NodeOp(pooling_.forward(child(0)->val(), val_))};
+  }
+
+  NodeOps backwardOps() {
+    return {NodeOp(pooling_.backward(
+        child(0)->val(),
+        child(0)->grad(),
+        val_,
+        adj_))};
+  }
+
+  const std::string type() { return "layer_pooling"; }
+
+
+protected:
+  PoolingWrapper pooling_;
+};
+
+class PoolingWithMaskingOp : public UnaryNodeOp {
+  public:
+    PoolingWithMaskingOp( Expr x, Expr mask, int width, bool isEven=false)
+      : UnaryNodeOp(x),
+        mask_(mask),
+        width_(width),
+        isEven_(isEven)
+    {
+      auto xShape = x->shape();
+      int dimBatch = xShape[0];
+      int dimWord = xShape[1];
+      int cols = (isEven_) ? xShape[2] - 1 : xShape[2];
+      int dimSentence = (cols / width_) + (cols % width_ != 0);
+      shape_ = {dimBatch, dimWord, dimSentence};
+    }
+
+    NodeOps forwardOps() {
+      return {NodeOp(PoolingWithMaskingForward(val_,
+                                               child(0)->val(),
+                                               mask_->val(),
+                                               width_,
+                                               isEven_))};
+    }
+
+    NodeOps backwardOps() {
+      return {NodeOp(PoolingWithMaskingBackward(adj_,
+                                                child(0)->grad(),
+                                                child(0)->val(),
+                                                mask_->val(),
+                                                width_,
+                                                isEven_))};
+    }
+
+    const std::string type() {return "layer_pooling";}
+
+  protected:
+    Expr mask_;
+    int width_;
+    bool isEven_;
+};
 }
diff --git a/src/layers/convolution.cu b/src/layers/convolution.cu
index 83e881bf..b0749450 100644
--- a/src/layers/convolution.cu
+++ b/src/layers/convolution.cu
@@ -25,12 +25,11 @@ Expr Convolution::apply(Expr x) {
                             keywords::init=inits::zeros);

   std::vector<Expr> nodes = {x, kernel, bias};
-  ABORT("Temporarily not implemented");
-  //return Expression<ConvolutionOp>(nodes,
-  //                                 paddings.first,
-  //                                 paddings.second,
-  //                                 strides.first,
-  //                                 strides.second);
+  return Expression<ConvolutionOp>(nodes,
+                                   paddings.first,
+                                   paddings.second,
+                                   strides.first,
+                                   strides.second);
 }

 Expr Convolution::apply(const std::vector<Expr>&) {
diff --git a/src/models/model_factory.cpp b/src/models/model_factory.cpp
index 0fe58c3b..ddeb1071 100644
--- a/src/models/model_factory.cpp
+++ b/src/models/model_factory.cpp
@@ -8,13 +8,13 @@
 #include "models/nematus.h"
 #include "models/encdec.h"

-#ifdef USE_CUDNN
+#ifdef CUDNN
 #include "models/char_s2s.h"
 #endif

 #ifdef COMPILE_EXAMPLES
 #include "examples/mnist/model.h"
-#ifdef USE_CUDNN
+#ifdef CUDNN
 #include "examples/mnist/model_lenet.h"
 #endif
 #endif
@@ -26,7 +26,7 @@ Ptr<EncoderBase> EncoderFactory::construct() {
   if(options_->get<std::string>("type") == "s2s")
     return New<EncoderS2S>(options_);

-#ifdef USE_CUDNN
+#ifdef CUDNN
   if(options_->get<std::string>("type") == "char-s2s")
     return New<CharS2SEncoder>(options_);
 #endif
diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h
index 06f76b64..30c42c4a 100644
--- a/src/tensors/tensor_operators.h
+++ b/src/tensors/tensor_operators.h
@@ -209,5 +209,8 @@ namespace marian {
     return cpu::L2Norm(in);
   }
 }
-
+
+  DISPATCH5(PoolingWithMaskingForward, marian::Tensor, marian::Tensor, marian::Tensor, int, bool)
+  DISPATCH6(PoolingWithMaskingBackward, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, int, bool)
+
 }
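Note: the PoolingWithMaskingForward/Backward operators dispatched above take the output shape computed in PoolingWithMaskingOp, namely dimSentence = (cols / width) + (cols % width != 0), i.e. a ceiling division over pooling windows of the given width, with one column dropped first when isEven is set. A tiny standalone check of that arithmetic with assumed toy values (not code from this commit):

    #include <iostream>

    // Output length of PoolingWithMaskingOp for a sequence of `cols` steps pooled
    // in windows of `width`; mirrors the shape computation in the Op constructor.
    static int pooledLength(int cols, int width, bool isEven) {
      if (isEven) cols -= 1;                         // drop the last column when isEven
      return (cols / width) + (cols % width != 0);   // ceiling division
    }

    int main() {
      std::cout << pooledLength(7, 3, false) << "\n";  // ceil(7/3) = 3 windows
      std::cout << pooledLength(8, 3, true)  << "\n";  // (8-1) = 7 -> 3 windows
      return 0;
    }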