diff options
Diffstat (limited to 'tensorflow/sentencepiece_processor_ops.cc')
-rw-r--r-- | tensorflow/sentencepiece_processor_ops.cc | 652 |
1 files changed, 0 insertions, 652 deletions
diff --git a/tensorflow/sentencepiece_processor_ops.cc b/tensorflow/sentencepiece_processor_ops.cc deleted file mode 100644 index 7cf915f..0000000 --- a/tensorflow/sentencepiece_processor_ops.cc +++ /dev/null @@ -1,652 +0,0 @@ -// Copyright 2018 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.! - -#include <mutex> -#include <string> -#include <unordered_map> -#include <vector> - -#include "sentencepiece_processor.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/hash/hash.h" - -typedef int int32; -typedef long long int int64; -typedef unsigned long long int uint64; - -namespace sentencepiece { -using ::tensorflow::DEVICE_CPU; -using ::tensorflow::Hash64; -using ::tensorflow::OpKernel; -using ::tensorflow::OpKernelConstruction; -using ::tensorflow::OpKernelContext; -using ::tensorflow::Tensor; -using ::tensorflow::TensorShapeUtils; -using ::tensorflow::tstring; -using ::tensorflow::shape_inference::DimensionHandle; -using ::tensorflow::shape_inference::InferenceContext; -using ::tensorflow::shape_inference::ShapeHandle; - -namespace { - -// A utility function to convert sentencepiece::util::Status to -// ::tensorflow::Status -::tensorflow::Status ToTFStatus(const sentencepiece::util::Status& s) { - if (s.ok()) return ::tensorflow::Status(); - return ::tensorflow::Status(static_cast<::tensorflow::error::Code>(s.code()), - ::tensorflow::string(s.error_message())); -} - -// Global cache to reuse SentencePieceProcessor with the same -// model file or model proto. The instance is managed with shared_ptr so -// the instance is deleted when no client is using it (refcount is zero). -class SentencePieceProcessorCache { - public: - std::shared_ptr<SentencePieceProcessor> get( - const std::string key, bool is_proto, - sentencepiece::util::Status* status) { - std::lock_guard<std::mutex> l(mutex_); - - const uint64 fp = Hash64(key.data(), key.size()); - auto sp = data_[fp].lock(); - - if (sp) { - *status = sp->status(); - return sp; - } - - sp = std::make_shared<SentencePieceProcessor>(); - *status = is_proto ? sp->LoadFromSerializedProto(key) : sp->Load(key); - if (!status->ok()) return nullptr; - - data_[fp] = sp; - return sp; - } - - private: - std::mutex mutex_; - std::unordered_map<uint64, std::weak_ptr<SentencePieceProcessor>> data_; -}; - -class SentencePieceBaseOp : public OpKernel { - public: - explicit SentencePieceBaseOp(OpKernelConstruction* context) - : OpKernel(context) { - std::string model_file_attr, model_proto_attr; - OP_REQUIRES_OK(context, context->GetAttr("model_file", &model_file_attr)); - OP_REQUIRES_OK(context, context->GetAttr("model_proto", &model_proto_attr)); - - // Initializes global cache. - static SentencePieceProcessorCache* cache = new SentencePieceProcessorCache; - sentencepiece::util::Status status; - - OP_REQUIRES(context, - ((model_proto_attr.empty() && !model_file_attr.empty()) || - (!model_proto_attr.empty() && model_file_attr.empty())), - ::tensorflow::errors::InvalidArgument( - "Either `model_proto` or `model_file` must be set.")); - - if (!model_file_attr.empty()) { - sentencepiece_processor_ = cache->get(model_file_attr, false, &status); - } else { - // Loads serialized sentencepiece model proto to enable embedding the - // relatively small sentencepiece model proto into the tensorflow graph - // such that the tensorflow graph is self-contained. - sentencepiece_processor_ = cache->get(model_proto_attr, true, &status); - } - - OP_REQUIRES_OK(context, ToTFStatus(status)); - OP_REQUIRES(context, sentencepiece_processor_, - ::tensorflow::errors::InvalidArgument( - "Failed to initialize SentencePieceProcessor")); - - // Sets extra options to add <s>, </s>. - auto has_attribute = [&context](const std::string& name) { - bool flag = false; - context->GetAttr(name, &flag).IgnoreError(); - return flag; - }; - - if (has_attribute("add_bos")) { - bos_id_ = sentencepiece_processor_->bos_id(); - OP_REQUIRES(context, bos_id_ >= 0, - ::tensorflow::errors::InvalidArgument( - "`bos_id` is not defined in model")); - } - - if (has_attribute("add_eos")) { - eos_id_ = sentencepiece_processor_->eos_id(); - OP_REQUIRES(context, eos_id_ >= 0, - ::tensorflow::errors::InvalidArgument( - "`eos_id` is not defined in model")); - } - - reverse_ = has_attribute("reverse"); - - pad_id_ = sentencepiece_processor_->pad_id(); - if (pad_id_ == -1) pad_id_ = sentencepiece_processor_->unk_id(); - } - - protected: - void GetPad(int32* pad) const { *pad = pad_id_; } - - void GetPad(tstring* pad) const { - pad->clear(); - if (sentencepiece_processor_ && pad_id_ >= 0 && - pad_id_ != sentencepiece_processor_->unk_id()) - *pad = sentencepiece_processor_->IdToPiece(pad_id_); - } - - std::shared_ptr<SentencePieceProcessor> sentencepiece_processor_; - int bos_id_ = -1; - int eos_id_ = -1; - int pad_id_ = -1; - bool reverse_ = false; -}; -} // namespace - -class SentencePieceGetPieceSizeOp : public SentencePieceBaseOp { - public: - explicit SentencePieceGetPieceSizeOp(OpKernelConstruction* context) - : SentencePieceBaseOp(context) {} - - void Compute(OpKernelContext* context) override { - Tensor* vocab_size_tensor = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(0, {}, &vocab_size_tensor)); - vocab_size_tensor->scalar<int32>()() = - sentencepiece_processor_->GetPieceSize(); - } -}; - -template <typename S, typename T> -class SentencePieceConvertPieceOp : public SentencePieceBaseOp { - public: - explicit SentencePieceConvertPieceOp(OpKernelConstruction* context) - : SentencePieceBaseOp(context) {} - - void Compute(OpKernelContext* context) override { - const Tensor* input_tensor = nullptr; - OP_REQUIRES_OK(context, context->input("input", &input_tensor)); - - Tensor* output_tensor = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor->shape(), - &output_tensor)); - for (int i = 0; i < input_tensor->NumElements(); ++i) - output_tensor->flat<T>()(i) = Convert(input_tensor->flat<S>()(i)); - } - - int32 Convert(const std::string& piece) const { - return sentencepiece_processor_->PieceToId(piece); - } - - std::string Convert(int32 id) const { - if (id >= 0 && id < sentencepiece_processor_->GetPieceSize()) { - return sentencepiece_processor_->IdToPiece(id); - } - return ""; - } -}; - -class SentencePieceGetPieceTypeOp : public SentencePieceBaseOp { - public: - explicit SentencePieceGetPieceTypeOp(OpKernelConstruction* context) - : SentencePieceBaseOp(context) { - OP_REQUIRES_OK(context, context->GetAttr("piece_type", &piece_type_)); - } - - void Compute(OpKernelContext* context) override { - const Tensor* input_tensor = nullptr; - OP_REQUIRES_OK(context, context->input("input", &input_tensor)); - - Tensor* output_tensor = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor->shape(), - &output_tensor)); - - for (int i = 0; i < input_tensor->NumElements(); ++i) { - const int id = input_tensor->flat<int32>()(i); - switch (piece_type_) { - case 0: - output_tensor->flat<bool>()(i) = - sentencepiece_processor_->IsUnknown(id); - break; - case 1: - output_tensor->flat<bool>()(i) = - sentencepiece_processor_->IsControl(id); - break; - case 2: - output_tensor->flat<bool>()(i) = - sentencepiece_processor_->IsUnused(id); - break; - default: - break; - } - } - } - - private: - int piece_type_; -}; - -template <typename T, typename U = T> -class SentencePieceEncodeOpBase : public SentencePieceBaseOp { - public: - explicit SentencePieceEncodeOpBase(OpKernelConstruction* context) - : SentencePieceBaseOp(context) {} - - void Compute(OpKernelContext* context) override { - const Tensor* input_tensor = nullptr; - - OP_REQUIRES_OK(context, context->input("input", &input_tensor)); - OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor->shape()), - ::tensorflow::errors::InvalidArgument( - "`input` must be a vector, got shape: ", - input_tensor->shape().DebugString())); - const auto& input_sentences = input_tensor->vec<tstring>(); - const int64 batch_size = input_sentences.size(); - - const Tensor* nbest_size_tensor = nullptr; - OP_REQUIRES_OK(context, context->input("nbest_size", &nbest_size_tensor)); - OP_REQUIRES(context, nbest_size_tensor->dims() <= 1, - ::tensorflow::errors::InvalidArgument( - "`nbest_size` must be a scalar or vector. got shape: ", - nbest_size_tensor->shape().DebugString())); - if (nbest_size_tensor->dims() == 1) { - OP_REQUIRES( - context, batch_size == nbest_size_tensor->dim_size(0), - ::tensorflow::errors::InvalidArgument( - "`nbest_size` must have the same batch size as `input`.")); - } - - const Tensor* alpha_tensor = nullptr; - OP_REQUIRES_OK(context, context->input("alpha", &alpha_tensor)); - OP_REQUIRES(context, alpha_tensor->dims() <= 1, - ::tensorflow::errors::InvalidArgument( - "`alpha` must be a scalar or vector, got shape: ", - alpha_tensor->shape().DebugString())); - if (alpha_tensor->dims() == 1) { - OP_REQUIRES(context, batch_size == alpha_tensor->dim_size(0), - ::tensorflow::errors::InvalidArgument( - "`alpha` must have the same batch size as `input`.")); - } - - std::vector<std::vector<U>> pieces(batch_size); - - for (int64 i = 0; i < batch_size; ++i) { - const int32 nbest_size = nbest_size_tensor->dims() == 1 - ? nbest_size_tensor->vec<int32>()(i) - : nbest_size_tensor->scalar<int32>()(); - if (nbest_size == 0 || nbest_size == 1) { - OP_REQUIRES_OK(context, - ToTFStatus(sentencepiece_processor_->Encode( - absl::string_view(input_sentences(i)), &pieces[i]))); - } else { - const float alpha = alpha_tensor->dims() == 1 - ? alpha_tensor->vec<float>()(i) - : alpha_tensor->scalar<float>()(); - OP_REQUIRES_OK(context, - ToTFStatus(sentencepiece_processor_->SampleEncode( - absl::string_view(input_sentences(i)), nbest_size, - alpha, &pieces[i]))); - } - RewritePieces(&pieces[i]); - } - - MakeOutputTensor(context, pieces); - } - - protected: - void RewritePieces(std::vector<std::string>* pieces) const { - if (reverse_) std::reverse(pieces->begin(), pieces->end()); - if (bos_id_ > 0) - pieces->insert(pieces->begin(), - sentencepiece_processor_->IdToPiece(bos_id_)); - if (eos_id_ > 0) - pieces->push_back(sentencepiece_processor_->IdToPiece(eos_id_)); - } - - void RewritePieces(std::vector<int32>* pieces) const { - if (reverse_) std::reverse(pieces->begin(), pieces->end()); - if (bos_id_ > 0) pieces->insert(pieces->begin(), bos_id_); - if (eos_id_ > 0) pieces->push_back(eos_id_); - } - - virtual void MakeOutputTensor(OpKernelContext* context, - const std::vector<std::vector<U>>& pieces) = 0; -}; - -template <typename T, typename U = T> -class SentencePieceEncodeSparseOp : public SentencePieceEncodeOpBase<T, U> { - public: - explicit SentencePieceEncodeSparseOp(OpKernelConstruction* context) - : SentencePieceEncodeOpBase<T, U>(context) {} - - protected: - void MakeOutputTensor(OpKernelContext* context, - const std::vector<std::vector<U>>& pieces) override { - const int64 batch_size = pieces.size(); - - int64 max_sequence_length = 0; - int64 indices_size = 0; - for (int row = 0; row < batch_size; ++row) { - const int col_size = pieces[row].size(); - max_sequence_length = std::max<int64>(col_size, max_sequence_length); - indices_size += col_size; - } - - // Creates the indices output tensor. - Tensor* indices_tensor = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(0, {indices_size, 2}, - &indices_tensor)); - - auto indices_tensor_output = indices_tensor->matrix<int64>(); - int item_idx = 0; - for (int row = 0; row < batch_size; ++row) { - for (int col = 0; col < pieces[row].size(); ++col) { - indices_tensor_output(item_idx, 0) = row; - indices_tensor_output(item_idx, 1) = col; - ++item_idx; - } - } - - // Creates the values output tensor. - Tensor* values_tensor = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(1, {indices_size}, &values_tensor)); - - auto values_tensor_output = values_tensor->flat<T>(); - item_idx = 0; - for (int row = 0; row < batch_size; ++row) { - std::copy(pieces[row].begin(), pieces[row].end(), - &values_tensor_output(item_idx)); - item_idx += pieces[row].size(); - } - - // Creates the shape output tensor. - Tensor* shape_tensor = nullptr; - OP_REQUIRES_OK(context, context->allocate_output(2, {2}, &shape_tensor)); - - auto shape_tensor_output = shape_tensor->flat<int64>(); - shape_tensor_output(0) = batch_size; - shape_tensor_output(1) = max_sequence_length; - } -}; - -template <typename T, typename U = T> -class SentencePieceEncodeDenseOp : public SentencePieceEncodeOpBase<T, U> { - public: - explicit SentencePieceEncodeDenseOp(OpKernelConstruction* context) - : SentencePieceEncodeOpBase<T, U>(context) { - this->GetPad(&pad_); - } - - // protected: - void MakeOutputTensor(OpKernelContext* context, - const std::vector<std::vector<U>>& pieces) override { - const int64 batch_size = pieces.size(); - - int64 max_sequence_length = 0; - for (int row = 0; row < batch_size; ++row) { - max_sequence_length = - std::max<int64>(pieces[row].size(), max_sequence_length); - } - - Tensor* values_tensor = nullptr; - Tensor* length_tensor = nullptr; - - OP_REQUIRES_OK( - context, context->allocate_output(0, {batch_size, max_sequence_length}, - &values_tensor)); - OP_REQUIRES_OK(context, - context->allocate_output(1, {batch_size}, &length_tensor)); - - auto values_tensor_output = values_tensor->matrix<T>(); - auto length_tensor_output = length_tensor->vec<int32>(); - - U pad = pad_; - - for (int row = 0; row < batch_size; ++row) { - for (int col = 0; col < max_sequence_length; ++col) { - values_tensor_output(row, col) = - col < pieces[row].size() ? pieces[row][col] : pad; - } - length_tensor_output(row) = pieces[row].size(); - } - } - - private: - T pad_; -}; - -template <typename T, typename U = T> -class SentencePieceDecodeOp : public SentencePieceBaseOp { - public: - explicit SentencePieceDecodeOp(OpKernelConstruction* context) - : SentencePieceBaseOp(context) {} - - void Compute(OpKernelContext* context) override { - const Tensor* input_tensor = nullptr; - const Tensor* length_tensor = nullptr; - - OP_REQUIRES_OK(context, context->input("input", &input_tensor)); - OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_tensor->shape()), - ::tensorflow::errors::InvalidArgument( - "`input` must be a 2-D matrix. got shape: ", - input_tensor->shape().DebugString())); - OP_REQUIRES_OK(context, context->input("sequence_length", &length_tensor)); - OP_REQUIRES(context, TensorShapeUtils::IsVector(length_tensor->shape()), - ::tensorflow::errors::InvalidArgument( - "`sequence_length` must be a vector. got shape: ", - length_tensor->shape().DebugString())); - OP_REQUIRES( - context, input_tensor->dim_size(0) == length_tensor->dim_size(0), - ::tensorflow::errors::InvalidArgument( - "`sequence_length` must have the same batch size as `input`.")); - - const auto& input_sentences = input_tensor->matrix<T>(); - const auto& sequence_length = length_tensor->vec<int32>(); - const int64 batch_size = input_tensor->dim_size(0); - const int max_sequence_length = input_tensor->dim_size(1); - - Tensor* values_tensor = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(0, {batch_size}, &values_tensor)); - auto values_tensor_output = values_tensor->vec<tstring>(); - - for (int64 i = 0; i < batch_size; ++i) { - OP_REQUIRES(context, - (sequence_length(i) >= 0 && - sequence_length(i) <= max_sequence_length), - ::tensorflow::errors::InvalidArgument( - "`sequence_length` is out-of-range.")); - std::vector<U> pieces(&input_sentences(i, 0), - &input_sentences(i, 0) + sequence_length(i)); - if (reverse_) std::reverse(pieces.begin(), pieces.end()); - std::string detokenized_str; - OP_REQUIRES_OK(context, ToTFStatus(sentencepiece_processor_->Decode( - pieces, &detokenized_str))); - values_tensor_output(i) = detokenized_str; - } - } -}; - -namespace { -// The snake case of this variables are used as the function names. -constexpr char kGetPieceSizeOpName[] = "SentencepieceGetPieceSize"; -constexpr char kPieceToIdOpName[] = "SentencepiecePieceToId"; -constexpr char kIdToPieceOpName[] = "SentencepieceIdToPiece"; -constexpr char kGetPieceTypeOpName[] = "SentencepieceGetPieceType"; -constexpr char kEncodeDenseOpName[] = "SentencepieceEncodeDense"; -constexpr char kEncodeSparseOpName[] = "SentencepieceEncodeSparse"; -constexpr char kDecodeOpName[] = "SentencepieceDecode"; -} // namespace - -REGISTER_OP(kGetPieceSizeOpName) - .Output("vocab_size: int32") - .Attr("model_file: string = ''") - .Attr("model_proto: string = ''") - .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->MakeShape({})); - return ::tensorflow::Status::OK(); - }); - -REGISTER_KERNEL_BUILDER(Name(kGetPieceSizeOpName).Device(DEVICE_CPU), - SentencePieceGetPieceSizeOp); - -REGISTER_OP(kPieceToIdOpName) - .Input("input: string") - .Output("values: int32") - .Attr("model_file: string = ''") - .Attr("model_proto: string = ''") - .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); - return ::tensorflow::Status::OK(); - }); - -REGISTER_KERNEL_BUILDER(Name(kPieceToIdOpName).Device(DEVICE_CPU), - SentencePieceConvertPieceOp<tstring, int32>); - -REGISTER_OP(kIdToPieceOpName) - .Input("input: int32") - .Output("values: string") - .Attr("model_file: string = ''") - .Attr("model_proto: string = ''") - .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); - return ::tensorflow::Status::OK(); - }); - -REGISTER_KERNEL_BUILDER(Name(kIdToPieceOpName).Device(DEVICE_CPU), - SentencePieceConvertPieceOp<int32, tstring>); - -REGISTER_OP(kGetPieceTypeOpName) - .Input("input: int32") - .Output("values: bool") - .Attr("model_file: string = ''") - .Attr("model_proto: string = ''") - .Attr("piece_type: int = 0") - .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); - return ::tensorflow::Status::OK(); - }); - -REGISTER_KERNEL_BUILDER(Name(kGetPieceTypeOpName).Device(DEVICE_CPU), - SentencePieceGetPieceTypeOp); - -REGISTER_OP(kEncodeDenseOpName) - .Attr("out_type: {int32, string} = DT_INT32") - .Input("input: string") - .Input("nbest_size: int32") - .Input("alpha: float") - .Output("values: out_type") - .Output("sequence_length: int32") - .Attr("model_file: string = ''") - .Attr("model_proto: string = ''") - .Attr("reverse: bool = false") - .Attr("add_bos: bool = false") - .Attr("add_eos: bool = false") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle input, nbest, alpha; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input)); - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &nbest)); - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &alpha)); - DimensionHandle batch_size = c->Dim(input, 0); - if (c->Rank(nbest) == 1) - TF_RETURN_IF_ERROR(c->Merge(batch_size, c->Dim(nbest, 0), &batch_size)); - if (c->Rank(alpha) == 1) - TF_RETURN_IF_ERROR(c->Merge(batch_size, c->Dim(alpha, 0), &batch_size)); - c->set_output(0, c->MakeShape({batch_size, c->UnknownDim()})); - c->set_output(1, c->MakeShape({batch_size})); - return ::tensorflow::Status::OK(); - }); - -REGISTER_KERNEL_BUILDER(Name(kEncodeDenseOpName) - .Device(DEVICE_CPU) - .TypeConstraint<int32>("out_type"), - SentencePieceEncodeDenseOp<int32>); - -REGISTER_KERNEL_BUILDER(Name(kEncodeDenseOpName) - .Device(DEVICE_CPU) - .TypeConstraint<tstring>("out_type"), - SentencePieceEncodeDenseOp<tstring, std::string>); - -REGISTER_OP(kEncodeSparseOpName) - .Attr("out_type: {int32, string} = DT_INT32") - .Input("input: string") - .Input("nbest_size: int32") - .Input("alpha: float") - .Output("indices: int64") - .Output("values: out_type") - .Output("dense_shape: int64") - .Attr("model_file: string = ''") - .Attr("model_proto: string = ''") - .Attr("reverse: bool = false") - .Attr("add_bos: bool = false") - .Attr("add_eos: bool = false") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle input, nbest, alpha; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input)); - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &nbest)); - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &alpha)); - DimensionHandle batch_size = c->Dim(input, 0); - if (c->Rank(nbest) == 1) - TF_RETURN_IF_ERROR(c->Merge(batch_size, c->Dim(nbest, 0), &batch_size)); - if (c->Rank(alpha) == 1) - TF_RETURN_IF_ERROR(c->Merge(batch_size, c->Dim(alpha, 0), &batch_size)); - c->set_output(0, c->MakeShape({c->UnknownDim(), 2})); - c->set_output(1, c->MakeShape({c->UnknownDim()})); - c->set_output(2, c->MakeShape({2})); - return ::tensorflow::Status::OK(); - }); - -REGISTER_KERNEL_BUILDER(Name(kEncodeSparseOpName) - .Device(DEVICE_CPU) - .TypeConstraint<int32>("out_type"), - SentencePieceEncodeSparseOp<int32>); - -REGISTER_KERNEL_BUILDER(Name(kEncodeSparseOpName) - .Device(DEVICE_CPU) - .TypeConstraint<tstring>("out_type"), - SentencePieceEncodeSparseOp<tstring, std::string>); - -REGISTER_OP(kDecodeOpName) - .Attr("T: {int32, string}") - .Input("input: T") - .Input("sequence_length: int32") - .Output("values: string") - .Attr("model_file: string = ''") - .Attr("model_proto: string = ''") - .Attr("reverse: bool = false") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle input, sequence_length; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sequence_length)); - DimensionHandle batch_size = c->Dim(input, 0); - TF_RETURN_IF_ERROR( - c->Merge(batch_size, c->Dim(sequence_length, 0), &batch_size)); - c->set_output(0, c->MakeShape({batch_size})); - return ::tensorflow::Status::OK(); - }); - -REGISTER_KERNEL_BUILDER( - Name(kDecodeOpName).Device(DEVICE_CPU).TypeConstraint<int32>("T"), - SentencePieceDecodeOp<int32>); - -REGISTER_KERNEL_BUILDER( - Name(kDecodeOpName).Device(DEVICE_CPU).TypeConstraint<tstring>("T"), - SentencePieceDecodeOp<tstring, std::string>); -} // namespace sentencepiece |