Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarcin Junczys-Dowmunt <marcinjd@microsoft.com>2020-03-10 21:32:14 +0300
committerMarcin Junczys-Dowmunt <marcinjd@microsoft.com>2020-03-10 21:32:14 +0300
commit4b23fe76ff5db11f7ef60a735d06d9acd09efff9 (patch)
tree8cbc19fe038e2087c63a54337d0f53e958e0a4a2
parent9fd5ba99bf5cf081a17ad5fc8e96f327ae730e5c (diff)
parent8640031437103a8b78806a55e666440a9540958c (diff)
update to marian-dev
-rw-r--r--[-rwxr-xr-x].gitignore3
-rw-r--r--.gitmodules7
-rw-r--r--CHANGELOG.md94
-rw-r--r--CMakeLists.txt341
-rw-r--r--Doxyfile.in2
-rw-r--r--README.md59
-rw-r--r--VERSION2
-rw-r--r--cmake/FindCBLAS.cmake186
-rw-r--r--cmake/FindMKL.cmake15
-rw-r--r--cmake/FindSSE.cmake33
-rw-r--r--cmake/GetCacheVariables.cmake52
-rw-r--r--[-rwxr-xr-x]contrib/autoformat.sh0
m---------examples0
m---------regression-tests0
-rwxr-xr-xscripts/bert/bert4marian.py154
-rwxr-xr-xscripts/checkpoints/average.py55
-rwxr-xr-x[-rw-r--r--]scripts/contrib/fix_hard.py0
-rwxr-xr-x[-rw-r--r--]scripts/contrib/inject_ctt.py0
-rwxr-xr-x[-rw-r--r--]scripts/contrib/inject_model_params.py0
-rwxr-xr-x[-rw-r--r--]scripts/contrib/model_info.py0
-rwxr-xr-xscripts/embeddings/export_embeddings.py31
-rw-r--r--scripts/shortlist/.gitignore3
-rw-r--r--scripts/shortlist/README.md8
-rwxr-xr-xscripts/shortlist/generate_shortlists.pl97
-rw-r--r--scripts/shortlist/install.sh25
-rw-r--r--src/3rd_party/CLI/App.hpp21
-rw-r--r--src/3rd_party/CMakeLists.txt117
-rw-r--r--[-rwxr-xr-x]src/3rd_party/ExceptionWithCallStack.cpp0
-rw-r--r--src/3rd_party/ExceptionWithCallStack.h2
-rw-r--r--[-rwxr-xr-x]src/3rd_party/any_type.h0
-rw-r--r--src/3rd_party/avx_mathfun.h726
-rw-r--r--[-rwxr-xr-x]src/3rd_party/catch.hpp22316
-rw-r--r--[-rwxr-xr-x]src/3rd_party/cnpy/cnpy.cpp0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/cnpy/cnpy.h0
m---------src/3rd_party/fbgemm0
-rwxr-xr-xsrc/3rd_party/half_float/HalfPrecisionFloatTest.cpp115
-rw-r--r--src/3rd_party/half_float/Readme.md43
-rw-r--r--src/3rd_party/half_float/stdint.h222
-rwxr-xr-xsrc/3rd_party/half_float/umHalf.h294
-rw-r--r--src/3rd_party/half_float/umHalf.inl495
-rw-r--r--src/3rd_party/mio/LICENSE21
-rw-r--r--src/3rd_party/mio/README.md337
-rw-r--r--src/3rd_party/mio/mio.hpp1748
m---------src/3rd_party/nccl0
-rw-r--r--src/3rd_party/pathie-cpp/src/entry_iterator.cpp4
-rw-r--r--src/3rd_party/pathie-cpp/src/path.cpp19
-rw-r--r--src/3rd_party/pathie-cpp/src/pathie.cpp8
-rw-r--r--src/3rd_party/phf/LICENSE19
-rw-r--r--src/3rd_party/phf/README.md182
-rw-r--r--src/3rd_party/phf/phf.cc1478
-rw-r--r--src/3rd_party/phf/phf.h299
-rw-r--r--src/3rd_party/reduce_all.h525
m---------src/3rd_party/sentencepiece0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/spdlog/astyle.sh0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/spdlog/bench/latency/compare.sh0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/spdlog/details/format.h0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/spdlog/details/logger_impl.h0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/spdlog/logger.h0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/spdlog/tests/catch.hpp0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/spdlog/tests/install_libcxx.sh0
-rw-r--r--src/3rd_party/sse_mathfun.h711
-rw-r--r--[-rwxr-xr-x]src/3rd_party/threadpool.h1
-rw-r--r--[-rwxr-xr-x]src/3rd_party/yaml-cpp/binary_renamed.cpp0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/yaml-cpp/collectionstack.h0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/yaml-cpp/dll.h0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/yaml-cpp/emitterstate.cpp0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/yaml-cpp/emitterstate.h0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/yaml-cpp/node/convert.h0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/yaml-cpp/node/node.h0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/yaml-cpp/node_data.cpp0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/yaml-cpp/scanner.cpp0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/yaml-cpp/scantoken.cpp0
-rw-r--r--[-rwxr-xr-x]src/3rd_party/yaml-cpp/singledocparser.cpp0
-rw-r--r--src/3rd_party/zstr/strict_fstream.hpp4
-rw-r--r--src/CMakeLists.txt56
-rw-r--r--[-rwxr-xr-x]src/command/marian_conv.cpp31
-rw-r--r--src/command/marian_decoder.cpp1
-rw-r--r--[-rwxr-xr-x]src/command/marian_main.cpp0
-rw-r--r--src/command/marian_server.cpp4
-rw-r--r--[-rwxr-xr-x]src/command/marian_train.cpp15
-rw-r--r--[-rwxr-xr-x]src/command/marian_vocab.cpp6
-rw-r--r--src/common/aliases.cpp150
-rw-r--r--src/common/authors.h65
-rw-r--r--src/common/binary.cpp47
-rw-r--r--src/common/binary.h6
-rw-r--r--src/common/build_info.cpp.in18
-rw-r--r--src/common/build_info.h13
-rw-r--r--[-rwxr-xr-x]src/common/cli_helper.h1
-rw-r--r--[-rwxr-xr-x]src/common/cli_wrapper.cpp131
-rw-r--r--[-rwxr-xr-x]src/common/cli_wrapper.h281
-rw-r--r--[-rwxr-xr-x]src/common/compile_time_crc32.h0
-rw-r--r--[-rwxr-xr-x]src/common/config.cpp94
-rw-r--r--[-rwxr-xr-x]src/common/config.h12
-rwxr-xr-xsrc/common/config_parser.cpp451
-rw-r--r--[-rwxr-xr-x]src/common/config_parser.h69
-rw-r--r--[-rwxr-xr-x]src/common/config_validator.cpp46
-rw-r--r--src/common/config_validator.h7
-rwxr-xr-xsrc/common/definitions.h67
-rw-r--r--src/common/fastopt.cpp112
-rw-r--r--src/common/fastopt.h379
-rwxr-xr-xsrc/common/file_stream.cpp172
-rw-r--r--[-rwxr-xr-x]src/common/file_stream.h337
-rw-r--r--src/common/filesystem.cpp31
-rw-r--r--[-rwxr-xr-x]src/common/filesystem.h27
-rw-r--r--src/common/intrusive_ptr.h225
-rw-r--r--[-rwxr-xr-x]src/common/io.cpp7
-rwxr-xr-x[-rw-r--r--]src/common/io_item.h67
-rw-r--r--[-rwxr-xr-x]src/common/logging.cpp84
-rw-r--r--[-rwxr-xr-x]src/common/logging.h68
-rw-r--r--src/common/options.cpp101
-rwxr-xr-xsrc/common/options.h141
-rw-r--r--[-rwxr-xr-x]src/common/project_version.h.in4
-rw-r--r--src/common/regex.h6
-rw-r--r--[-rwxr-xr-x]src/common/shape.h36
-rw-r--r--[-rwxr-xr-x]src/common/timer.h27
-rw-r--r--src/common/types.cpp38
-rw-r--r--[-rwxr-xr-x]src/common/types.h473
-rwxr-xr-xsrc/common/utils.cpp331
-rw-r--r--[-rwxr-xr-x]src/common/utils.h40
-rw-r--r--[-rwxr-xr-x]src/common/version.cpp0
-rw-r--r--[-rwxr-xr-x]src/common/version.h0
-rw-r--r--[-rwxr-xr-x]src/data/alignment.cpp4
-rw-r--r--[-rwxr-xr-x]src/data/alignment.h4
-rwxr-xr-xsrc/data/batch.h6
-rw-r--r--[-rwxr-xr-x]src/data/batch_generator.h75
-rw-r--r--[-rwxr-xr-x]src/data/batch_stats.h16
-rwxr-xr-xsrc/data/corpus.cpp125
-rw-r--r--[-rwxr-xr-x]src/data/corpus.h51
-rwxr-xr-xsrc/data/corpus_base.cpp183
-rwxr-xr-xsrc/data/corpus_base.h193
-rw-r--r--src/data/corpus_nbest.cpp10
-rw-r--r--[-rwxr-xr-x]src/data/corpus_nbest.h1
-rw-r--r--src/data/corpus_sqlite.cpp3
-rw-r--r--[-rwxr-xr-x]src/data/corpus_sqlite.h2
-rw-r--r--[-rwxr-xr-x]src/data/dataset.h3
-rw-r--r--[-rwxr-xr-x]src/data/default_vocab.cpp200
-rwxr-xr-xsrc/data/factored_vocab.cpp778
-rwxr-xr-xsrc/data/factored_vocab.h137
-rw-r--r--[-rwxr-xr-x]src/data/rng_engine.h0
-rw-r--r--[-rwxr-xr-x]src/data/sentencepiece_vocab.cpp66
-rw-r--r--[-rwxr-xr-x]src/data/shortlist.h177
-rw-r--r--[-rwxr-xr-x]src/data/text_input.cpp52
-rw-r--r--[-rwxr-xr-x]src/data/text_input.h5
-rw-r--r--src/data/types.h41
-rwxr-xr-xsrc/data/vocab.cpp47
-rwxr-xr-xsrc/data/vocab.h32
-rwxr-xr-x[-rw-r--r--]src/data/vocab_base.h26
-rw-r--r--src/examples/CMakeLists.txt2
-rw-r--r--src/examples/iris/iris.cpp6
-rw-r--r--[-rwxr-xr-x]src/examples/mnist/dataset.h17
-rw-r--r--[-rwxr-xr-x]src/examples/mnist/download.sh0
-rwxr-xr-xsrc/examples/mnist/model.h58
-rw-r--r--src/examples/mnist/model_lenet.h12
-rw-r--r--src/examples/mnist/training.h11
-rw-r--r--src/examples/mnist/validator.h12
-rw-r--r--[-rwxr-xr-x]src/functional/approx.h20
-rw-r--r--src/functional/array.h31
-rw-r--r--[-rwxr-xr-x]src/functional/defs.h26
-rw-r--r--src/functional/floats.h2
-rw-r--r--[-rwxr-xr-x]src/functional/functional.h3
-rw-r--r--[-rwxr-xr-x]src/functional/operands.h28
-rwxr-xr-xsrc/functional/operators.h606
-rw-r--r--[-rwxr-xr-x]src/functional/predicates.h143
-rwxr-xr-xsrc/functional/shape.h235
-rwxr-xr-x[-rw-r--r--]src/functional/tensor.h206
-rwxr-xr-x[-rw-r--r--]src/functional/tmp.h191
-rwxr-xr-x[-rw-r--r--]src/graph/auto_tuner.h25
-rw-r--r--[-rwxr-xr-x]src/graph/chainable.h13
-rwxr-xr-xsrc/graph/expression_graph.cpp270
-rwxr-xr-xsrc/graph/expression_graph.h392
-rwxr-xr-xsrc/graph/expression_operators.cpp519
-rwxr-xr-x[-rw-r--r--]src/graph/expression_operators.h109
-rwxr-xr-xsrc/graph/node.cpp23
-rw-r--r--[-rwxr-xr-x]src/graph/node.h101
-rwxr-xr-xsrc/graph/node_initializers.cpp252
-rwxr-xr-xsrc/graph/node_initializers.h201
-rw-r--r--src/graph/node_operators.cpp31
-rw-r--r--src/graph/node_operators.h23
-rwxr-xr-xsrc/graph/node_operators_binary.h324
-rwxr-xr-x[-rw-r--r--]src/graph/node_operators_unary.h358
-rw-r--r--[-rwxr-xr-x]src/graph/parameters.h54
-rwxr-xr-x[-rw-r--r--]src/layers/constructors.h155
-rw-r--r--src/layers/convolution.cpp12
-rwxr-xr-x[-rw-r--r--]src/layers/factory.h94
-rwxr-xr-xsrc/layers/generic.cpp566
-rwxr-xr-xsrc/layers/generic.h396
-rwxr-xr-xsrc/layers/guided_alignment.h102
-rwxr-xr-xsrc/layers/loss.cpp136
-rwxr-xr-x[-rw-r--r--]src/layers/loss.h455
-rw-r--r--[-rwxr-xr-x]src/layers/weight.cpp14
-rw-r--r--[-rwxr-xr-x]src/layers/weight.h1
-rw-r--r--[-rwxr-xr-x]src/layers/word2vec_reader.h9
-rw-r--r--[-rwxr-xr-x]src/marian.h0
-rwxr-xr-xsrc/microsoft/quicksand.cpp148
-rwxr-xr-xsrc/microsoft/quicksand.h46
-rwxr-xr-xsrc/models/amun.h55
-rwxr-xr-xsrc/models/bert.h375
-rwxr-xr-x[-rw-r--r--]src/models/char_s2s.h21
-rwxr-xr-xsrc/models/classifier.h42
-rwxr-xr-xsrc/models/costs.h255
-rwxr-xr-xsrc/models/decoder.h116
-rwxr-xr-xsrc/models/encoder.h91
-rw-r--r--src/models/encoder_classifier.h227
-rwxr-xr-xsrc/models/encoder_decoder.cpp120
-rwxr-xr-x[-rw-r--r--]src/models/encoder_decoder.h80
-rw-r--r--src/models/model_base.h34
-rwxr-xr-x[-rw-r--r--]src/models/model_factory.cpp273
-rwxr-xr-x[-rw-r--r--]src/models/model_factory.h53
-rw-r--r--src/models/model_task.h3
-rwxr-xr-xsrc/models/nematus.h22
-rwxr-xr-xsrc/models/s2s.h164
-rwxr-xr-xsrc/models/states.h86
-rwxr-xr-xsrc/models/transformer.h443
-rwxr-xr-xsrc/models/transformer_factory.h8
-rwxr-xr-x[-rw-r--r--]src/models/transformer_stub.cpp18
-rw-r--r--src/optimizers/clippers.cpp2
-rw-r--r--src/optimizers/clippers.h1
-rwxr-xr-xsrc/optimizers/optimizers.cpp119
-rw-r--r--[-rwxr-xr-x]src/optimizers/optimizers.h94
-rwxr-xr-x[-rw-r--r--]src/rescorer/rescorer.h84
-rw-r--r--src/rescorer/score_collector.cpp5
-rw-r--r--src/rescorer/score_collector.h3
-rwxr-xr-x[-rw-r--r--]src/rnn/attention.h32
-rw-r--r--src/rnn/attention_constructors.h6
-rw-r--r--src/rnn/cells.cpp17
-rwxr-xr-x[-rw-r--r--]src/rnn/cells.h242
-rwxr-xr-xsrc/rnn/constructors.h105
-rw-r--r--[-rwxr-xr-x]src/rnn/rnn.h7
-rw-r--r--[-rwxr-xr-x]src/rnn/types.h3
-rw-r--r--[-rwxr-xr-x]src/tensors/allocator.h44
-rw-r--r--src/tensors/backend.h11
-rw-r--r--[-rwxr-xr-x]src/tensors/cpu/add.h37
-rw-r--r--src/tensors/cpu/backend.h7
-rw-r--r--src/tensors/cpu/device.cpp57
-rw-r--r--src/tensors/cpu/element.h88
-rw-r--r--src/tensors/cpu/fbgemm/expanded_gemm.h407
-rw-r--r--src/tensors/cpu/fbgemm/expression_graph_packable.h153
-rw-r--r--src/tensors/cpu/fbgemm/packed_gemm.cpp550
-rw-r--r--src/tensors/cpu/fbgemm/packed_gemm.h141
-rw-r--r--src/tensors/cpu/int16.h4
-rwxr-xr-xsrc/tensors/cpu/prod.cpp71
-rw-r--r--src/tensors/cpu/sharp/avx_gemm.cpp6
-rw-r--r--src/tensors/cpu/sharp/int_gemm.cpp11
-rw-r--r--[-rwxr-xr-x]src/tensors/cpu/sharp/int_gemm.h0
-rwxr-xr-xsrc/tensors/cpu/tensor_operators.cpp410
-rw-r--r--[-rwxr-xr-x]src/tensors/device.h2
-rwxr-xr-x[-rw-r--r--]src/tensors/gpu/add.cu127
-rwxr-xr-x[-rw-r--r--]src/tensors/gpu/add.h3
-rwxr-xr-x[-rw-r--r--]src/tensors/gpu/add.inc15
-rw-r--r--src/tensors/gpu/add_all.cu116
-rw-r--r--src/tensors/gpu/add_all.h87
-rw-r--r--src/tensors/gpu/add_all.inc71
-rw-r--r--[-rwxr-xr-x]src/tensors/gpu/algorithm.cu43
-rw-r--r--[-rwxr-xr-x]src/tensors/gpu/backend.h50
-rwxr-xr-xsrc/tensors/gpu/cuda_helpers.h24
-rw-r--r--[-rwxr-xr-x]src/tensors/gpu/device.cu8
-rwxr-xr-x[-rw-r--r--]src/tensors/gpu/element.cu42
-rwxr-xr-x[-rw-r--r--]src/tensors/gpu/element.inc26
-rwxr-xr-xsrc/tensors/gpu/prod.cpp495
-rw-r--r--src/tensors/gpu/prod.cu200
-rw-r--r--src/tensors/gpu/prod.h19
-rwxr-xr-x[-rw-r--r--]src/tensors/gpu/tensor_operators.cu1910
-rw-r--r--src/tensors/memory_piece.h17
-rwxr-xr-xsrc/tensors/rand.cpp8
-rw-r--r--src/tensors/rand.h6
-rwxr-xr-xsrc/tensors/tensor.cpp142
-rwxr-xr-xsrc/tensors/tensor.h313
-rw-r--r--[-rwxr-xr-x]src/tensors/tensor_allocator.h26
-rw-r--r--[-rwxr-xr-x]src/tensors/tensor_operators.h93
-rw-r--r--src/tests/CMakeLists.txt60
-rw-r--r--src/tests/cli.cpp (renamed from src/tests/cli_test.cpp)15
-rw-r--r--src/tests/conv.cu (renamed from src/tests/conv_test.cu)0
-rw-r--r--src/tests/conv_char.cu (renamed from src/tests/conv_char_test.cu)0
-rw-r--r--src/tests/dropout.cpp (renamed from src/tests/dropout_test.cpp)4
-rw-r--r--src/tests/graph_tests.cpp107
-rw-r--r--src/tests/logger.cpp (renamed from src/tests/logger_test.cpp)0
-rw-r--r--src/tests/operator_tests.cpp619
-rw-r--r--src/tests/pooling.cpp (renamed from src/tests/pooling_test.cpp)4
-rw-r--r--src/tests/prod.cpp37
-rw-r--r--src/tests/sqlite.cpp (renamed from src/tests/sqlite_test.cpp)2
-rw-r--r--src/tests/tensor.cu (renamed from src/tests/tensor_test.cu)0
-rw-r--r--src/tests/units/CMakeLists.txt20
-rw-r--r--src/tests/units/attention_tests.cpp (renamed from src/tests/attention_tests.cpp)54
-rw-r--r--src/tests/units/fastopt_tests.cpp82
-rw-r--r--src/tests/units/graph_tests.cpp123
-rw-r--r--src/tests/units/operator_tests.cpp847
-rw-r--r--src/tests/units/rnn_tests.cpp (renamed from src/tests/rnn_tests.cpp)116
-rw-r--r--src/tests/units/run_tests.cpp (renamed from src/tests/run_tests.cpp)0
-rw-r--r--[-rwxr-xr-x]src/training/communicator.cpp17
-rw-r--r--[-rwxr-xr-x]src/training/communicator.h28
-rw-r--r--[-rwxr-xr-x]src/training/communicator_nccl.h86
-rw-r--r--[-rwxr-xr-x]src/training/exponential_smoothing.h32
-rw-r--r--[-rwxr-xr-x]src/training/gradient_dropping/sparse_tensor.h2
-rw-r--r--[-rwxr-xr-x]src/training/graph_group.h64
-rw-r--r--[-rwxr-xr-x]src/training/graph_group_async.cpp35
-rw-r--r--[-rwxr-xr-x]src/training/graph_group_async.h11
-rw-r--r--[-rwxr-xr-x]src/training/graph_group_async_drop.cpp0
-rw-r--r--[-rwxr-xr-x]src/training/graph_group_async_drop.h4
-rw-r--r--[-rwxr-xr-x]src/training/graph_group_multinode.cpp21
-rw-r--r--[-rwxr-xr-x]src/training/graph_group_multinode.h14
-rw-r--r--[-rwxr-xr-x]src/training/graph_group_multinode_sync.cpp18
-rw-r--r--[-rwxr-xr-x]src/training/graph_group_multinode_sync.h15
-rw-r--r--[-rwxr-xr-x]src/training/graph_group_singleton.cpp6
-rw-r--r--[-rwxr-xr-x]src/training/graph_group_singleton.h18
-rwxr-xr-xsrc/training/graph_group_sync.cpp414
-rwxr-xr-xsrc/training/graph_group_sync.h24
-rw-r--r--src/training/scheduler.cpp43
-rwxr-xr-xsrc/training/scheduler.h397
-rw-r--r--[-rwxr-xr-x]src/training/training.h53
-rw-r--r--[-rwxr-xr-x]src/training/training_state.h144
-rw-r--r--src/training/validator.cpp604
-rwxr-xr-xsrc/training/validator.h510
-rwxr-xr-xsrc/translator/beam_search.h659
-rw-r--r--[-rwxr-xr-x]src/translator/helpers.cpp10
-rw-r--r--src/translator/helpers.cu27
-rw-r--r--[-rwxr-xr-x]src/translator/helpers.h6
-rw-r--r--[-rwxr-xr-x]src/translator/history.h50
-rw-r--r--[-rwxr-xr-x]src/translator/hypothesis.h103
-rw-r--r--[-rwxr-xr-x]src/translator/nth_element.cpp120
-rw-r--r--[-rwxr-xr-x]src/translator/nth_element.cu187
-rw-r--r--[-rwxr-xr-x]src/translator/nth_element.h4
-rw-r--r--[-rwxr-xr-x]src/translator/output_collector.cpp2
-rw-r--r--[-rwxr-xr-x]src/translator/output_collector.h7
-rw-r--r--[-rwxr-xr-x]src/translator/output_printer.cpp21
-rwxr-xr-xsrc/translator/output_printer.h49
-rw-r--r--[-rwxr-xr-x]src/translator/scorers.cpp30
-rwxr-xr-xsrc/translator/scorers.h43
-rwxr-xr-xsrc/translator/translator.h100
-rwxr-xr-xvs/Marian.sln376
-rwxr-xr-xvs/Marian.vcxproj1189
-rwxr-xr-xvs/Marian.vcxproj.filters1286
331 files changed, 44873 insertions, 16918 deletions
diff --git a/.gitignore b/.gitignore
index 80080441..956ce684 100755..100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
# Config files from CMake
src/common/project_version.h
src/common/git_revision.h
+src/common/build_info.cpp
*.vcxproj.user
/vs/x64
@@ -61,4 +62,4 @@ examples/mnist/*ubyte
.vs
.vscode
-
+
diff --git a/.gitmodules b/.gitmodules
index 5c3c00f1..b7c67bef 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,9 +1,16 @@
[submodule "examples"]
path = examples
url = https://github.com/marian-nmt/marian-examples
+[submodule "regression-tests"]
+ path = regression-tests
+ url = https://github.com/marian-nmt/marian-regression-tests
[submodule "src/3rd_party/sentencepiece"]
path = src/3rd_party/sentencepiece
url = https://github.com/marian-nmt/sentencepiece
[submodule "src/3rd_party/nccl"]
path = src/3rd_party/nccl
url = https://github.com/marian-nmt/nccl
+[submodule "src/3rd_party/fbgemm"]
+ path = src/3rd_party/fbgemm
+ url = https://github.com/marian-nmt/FBGEMM
+ branch = master
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 02df1217..487c07a1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,23 +5,109 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
+
## [Unreleased]
### Added
-- Automatic detection of CPU intrisics when building with -arch=native
+- An option to print cached variables from CMake
+- Add support for compiling on Mac (and clang)
+- An option for resetting stalled validation metrics
+- Add CMAKE options to disable compilation for specific GPU SM types
+- An option to print word-level translation scores
+- An option to turn off automatic detokenization from SentencePiece
+- Separate quantization types for 8-bit FBGEMM for AVX2 and AVX512
+- Sequence-level unliklihood training
+- Allow file name templated valid-translation-output files
+- Support for lexical shortlists in marian-server
+- Support for 8-bit matrix multiplication with FBGEMM
+- CMakeLists.txt now looks for SSE 4.2
+- Purging of finished hypotheses during beam-search. A lot faster for large batches.
+- Faster option look-up, up to 20-30% faster translation
+- Added --cite and --authors flag
+- Added optional support for ccache
+- Switch to change abort to exception, only to be used in library mode
+- Support for 16-bit packed models with FBGEMM
+- Multiple separated parameter types in ExpressionGraph, currently inference-only
+- Safe handling of sigterm signal
+- Automatic vectorization of elementwise operations on CPU for tensors dims that
+ are divisible by 4 (AVX) and 8 (AVX2)
+- Replacing std::shared_ptr<T> with custom IntrusivePtr<T> for small objects like
+ Tensors, Hypotheses and Expressions.
+- Fp16 inference working for translation
+- Gradient-checkpointing
+
+### Fixed
+- Replace value for INVALID_PATH_SCORE with std::numer_limits<float>::lowest()
+ to avoid overflow with long sequences
+- Break up potential circular references for GraphGroup*
+- Fix empty source batch entries with batch purging
+- Clear RNN chache in transformer model, add correct hash functions to nodes
+- Gather-operation for all index sizes
+- Fix word weighting with max length cropping
+- Fixed compilation on CPUs without support for AVX
+- FastOpt now reads "n" and "y" values as strings, not as boolean values
+- Fixed multiple reduction kernels on GPU
+- Fixed guided-alignment training with cross-entropy
+- Replace IntrusivePtr with std::uniq_ptr in FastOpt, fixes random segfaults
+ due to thread-non-safty of reference counting.
+- Make sure that items are 256-byte aligned during saving
+- Make explicit matmul functions respect setting of cublasMathMode
+- Fix memory mapping for mixed paramter models
+- Removed naked pointer and potential memory-leak from file_stream.{cpp,h}
+- Compilation for GCC >= 7 due to exception thrown in destructor
+- Sort parameters by lexicographical order during allocation to ensure consistent
+ memory-layout during allocation, loading, saving.
+- Output empty line when input is empty line. Previous behavior might result in
+ hallucinated outputs.
+- Compilation with CUDA 10.1
+
+### Changed
+- Combine two for-loops in nth_element.cpp on CPU
+- Revert LayerNorm eps to old position, i.e. sigma' = sqrt(sigma^2 + eps)
+- Downgrade NCCL to 2.3.7 as 2.4.2 is buggy (hangs with larger models)
+- Return error signal on SIGTERM
+- Dropped support for CUDA 8.0, CUDA 9.0 is now minimal requirement
+- Removed autotuner for now, will be switched back on later
+- Boost depdendency is now optional and only required for marian_server
+- Dropped support for g++-4.9
+- Simplified file stream and temporary file handling
+- Unified node intializers, same function API.
+- Remove overstuff/understuff code
+
+## [1.8.0] - 2019-09-04
+
+### Added
+- Alias options and new --task option
+- Automatic detection of CPU intrisics when building with -arch=native
+- First version of BERT-training and BERT-classifier, currently not compatible with TF models
+- New reduction operators
+- Use Cmake's ExternalProject to build NCCL and potentially other external libs
+- Code for Factored Vocabulary, currently not usable yet without outside tools
### Fixed
+- Issue with relative paths in automatically generated decoder config files
+- Bug with overlapping CXX flags and building spm_train executable
+- Compilation with gcc 8
+- Overwriting and unsetting vector options
- Windows build with recent changes
- Bug with read-ahead buffer
-- Fixed handling of "dump-config: false" in YAML config
+- Handling of "dump-config: false" in YAML config
- Errors due to warnings
-- Fixed issue concerning failed saving with single GPU training and --sync-sgd option.
+- Issue concerning failed saving with single GPU training and --sync-sgd option.
+- NaN problem when training with Tensor Cores on Volta GPUs
+- Fix pipe-handling
+- Fix compilation with GCC 9.1
+- Fix CMake build types
### Changed
+- Error message when using left-to-right and right-to-left models together in ensembles
+- Regression tests included as a submodule
+- Update NCCL to 2.4.2
- Add zlib source to Marian's source tree, builds now as object lib
- -DUSE_STATIC_LIBS=on now also looks for static versions of CUDA libraries
- Include NCCL build from github.com/marian-nmt/nccl and compile within source tree
-- Set nearly all warnings as errors for Marian's own targets. Disable warnings for 3rd party.
+- Set nearly all warnings as errors for Marian's own targets. Disable warnings for 3rd party
+- Refactored beam search
## [1.7.0] - 2018-11-27
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 28e648aa..46d9c6c9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -5,7 +5,6 @@ if (POLICY CMP0074)
cmake_policy(SET CMP0074 NEW) # CMake 3.12
endif ()
-
project(marian CXX C)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
@@ -14,14 +13,33 @@ set(BUILD_ARCH native CACHE STRING "Compile for this CPU architecture.")
# Custom CMake options
option(COMPILE_CPU "Compile CPU version" ON)
option(COMPILE_CUDA "Compile GPU version" ON)
-option(USE_SENTENCEPIECE "Download and compile SentencePiece" OFF)
-option(USE_STATIC_LIBS "Link statically against non-system libs" OFF)
-option(USE_CUDNN "Use CUDNN library" OFF)
-option(USE_NCCL "Use NCCL library" ON)
-option(USE_MPI "Use MPI library" OFF)
+option(COMPILE_CUDA_SM35 "Compile GPU version with SM35 support" ON)
+option(COMPILE_CUDA_SM50 "Compile GPU version with SM50 support" ON)
+option(COMPILE_CUDA_SM60 "Compile GPU version with SM60 support" ON)
+option(COMPILE_CUDA_SM70 "Compile GPU version with SM70 support" ON)
option(COMPILE_EXAMPLES "Compile examples" OFF)
+option(COMPILE_SERVER "Compile marian-server" OFF)
option(COMPILE_TESTS "Compile tests" OFF)
-option(COMPILE_SERVER "Compile marian-server" ON)
+option(USE_CCACHE "Use ccache compiler cache (https://ccache.dev)" OFF)
+option(USE_CUDNN "Use CUDNN library" OFF)
+option(USE_DOXYGEN "Build documentation with Doxygen" ON)
+option(USE_FBGEMM "Use FBGEMM" OFF)
+option(USE_MKL "Compile with MKL support" ON)
+option(USE_MPI "Use MPI library" OFF)
+option(USE_NCCL "Use NCCL library" ON)
+option(USE_SENTENCEPIECE "Download and compile SentencePiece" OFF)
+option(USE_STATIC_LIBS "Link statically against non-system libs" OFF)
+
+# use ccache (https://ccache.dev) for faster compilation if requested and available
+if(USE_CCACHE)
+find_program(CCACHE_PROGRAM ccache)
+if(CCACHE_PROGRAM)
+ message(STATUS "Will be using ccache for faster repeat compilation (use cmake -DUSE_CCACHE=off to disable).")
+ set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CCACHE_PROGRAM}")
+else(CCACHE_PROGRAM)
+ message(WARNING "Compilation with ccache requested but no ccache found.")
+endif(CCACHE_PROGRAM)
+endif(USE_CCACHE)
# Project versioning
find_package(Git QUIET)
@@ -32,7 +50,13 @@ message(STATUS "Project version: ${PROJECT_VERSION_STRING_FULL}")
execute_process(COMMAND git submodule update --init --recursive --no-fetch
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
-
+
+if(NOT CMAKE_BUILD_TYPE)
+ message(WARNING "CMAKE_BUILD_TYPE not set; setting to Release")
+ set(CMAKE_BUILD_TYPE "Release")
+endif()
+
+###############################################################################
# Set compilation flags
if(MSVC)
# These are used in src/CMakeLists.txt on a per-target basis
@@ -42,7 +66,7 @@ if(MSVC)
# C4310: cast truncates constant value
# C4324: 'marian::cpu::int16::`anonymous-namespace'::ScatterPut': structure was padded due to alignment specifier
set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\"")
-
+
set(INTRINSICS "/arch:AVX")
# Or maybe use these?
@@ -57,62 +81,112 @@ if(MSVC)
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DEBUG /LTCG:incremental /INCREMENTAL:NO /NODEFAULTLIB:MSVCRT /ignore:4049")
set(CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} /LTCG:incremental")
- find_library(SHLWAPI Shlwapi.lib)
+ find_library(SHLWAPI Shlwapi.lib)
set(EXT_LIBS ${EXT_LIBS} SHLWAPI)
-else()
+else(MSVC)
-# Detect support CPU instrinsics for the current platform. This will
-# only by used with BUILD_ARCH=native. For overridden BUILD_ARCH we
-# minimally use -msse4.1. This seems to work with MKL.
-set(INTRINSICS "")
-if(BUILD_ARCH STREQUAL "native")
- message(STATUS "Checking support for CPU intrinsics")
- include(FindSSE)
- if(SSE2_FOUND)
- message(STATUS "SSE2 support found")
- set(INTRINSICS "${INTRINSICS} -msse2")
- endif(SSE2_FOUND)
- if(SSE3_FOUND)
- message(STATUS "SSE3 support found")
- set(INTRINSICS "${INTRINSICS} -msse3")
- endif(SSE3_FOUND)
- if(SSE4_1_FOUND)
- message(STATUS "SSE4.1 support found")
- set(INTRINSICS "${INTRINSICS} -msse4.1")
- endif(SSE4_1_FOUND)
- if(AVX_FOUND)
- message(STATUS "AVX support found")
- set(INTRINSICS "${INTRINSICS} -mavx")
- endif(AVX_FOUND)
- if(AVX2_FOUND)
- message(STATUS "AVX2 support found")
- set(INTRINSICS "${INTRINSICS} -mavx2")
- endif(AVX2_FOUND)
-else()
- set(INTRINSICS "-msse4.1")
-endif()
+ # Check we are using at least g++ 5.0
+ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 5.0)
+ message(FATAL_ERROR "FATAL ERROR: Compiling Marian requires at least g++ 5.0, your version is ${CMAKE_CXX_COMPILER_VERSION}")
+ endif()
-set(DISABLE_GLOBALLY "-Wno-unused-result")
+ # Detect support CPU instrinsics for the current platform. This will
+ # only by used with BUILD_ARCH=native. For overridden BUILD_ARCH we
+ # minimally use -msse4.1. This seems to work with MKL.
+ set(INTRINSICS "")
+ list(APPEND INTRINSICS_NVCC)
+
+ if(BUILD_ARCH STREQUAL "native")
+ message(STATUS "Checking support for CPU intrinsics")
+ include(FindSSE)
+ if(SSE2_FOUND)
+ message(STATUS "SSE2 support found")
+ set(INTRINSICS "${INTRINSICS} -msse2")
+ list(APPEND INTRINSICS_NVCC -Xcompiler\ -msse2)
+ endif(SSE2_FOUND)
+ if(SSE3_FOUND)
+ message(STATUS "SSE3 support found")
+ set(INTRINSICS "${INTRINSICS} -msse3")
+ list(APPEND INTRINSICS_NVCC -Xcompiler\ -msse3)
+ endif(SSE3_FOUND)
+ if(SSE4_1_FOUND)
+ message(STATUS "SSE4.1 support found")
+ set(INTRINSICS "${INTRINSICS} -msse4.1")
+ list(APPEND INTRINSICS_NVCC -Xcompiler\ -msse4.1)
+ endif(SSE4_1_FOUND)
+ if(SSE4_2_FOUND)
+ message(STATUS "SSE4.2 support found")
+ set(INTRINSICS "${INTRINSICS} -msse4.2")
+ list(APPEND INTRINSICS_NVCC -Xcompiler\ -msse4.2)
+ endif(SSE4_2_FOUND)
+ if(AVX_FOUND)
+ message(STATUS "AVX support found")
+ set(INTRINSICS "${INTRINSICS} -mavx")
+ list(APPEND INTRINSICS_NVCC -Xcompiler\ -mavx)
+ endif(AVX_FOUND)
+ if(AVX2_FOUND)
+ message(STATUS "AVX2 support found")
+ set(INTRINSICS "${INTRINSICS} -mavx2")
+ list(APPEND INTRINSICS_NVCC -Xcompiler\ -mavx2)
+ endif(AVX2_FOUND)
+ if(AVX512_FOUND)
+ message(STATUS "AVX512 support found")
+ set(INTRINSICS "${INTRINSICS} -mavx512f")
+ list(APPEND INTRINSICS_NVCC -Xcompiler\ -mavx512f)
+ endif(AVX512_FOUND)
+ else()
+ set(INTRINSICS "-msse4.1")
+ endif()
-# These are used in src/CMakeLists.txt on a per-target basis
-list(APPEND ALL_WARNINGS -Wall; -Werror; -Wno-unused-result; -Wno-deprecated; -Wno-pragmas; -Wno-unused-parameter; -Wextra; -Wno-unused-function;
- -Wno-unused-value; -Wno-unknown-pragmas; -Wno-sign-compare; -Wno-missing-field-initializers;)
+ if(USE_FBGEMM)
+ set(EXT_LIBS ${EXT_LIBS} fbgemm dl)
+ add_definitions(-DUSE_FBGEMM=1)
+ endif(USE_FBGEMM)
+
+ if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
+ # Clang-10.0.0 complains when CUDA is newer than 10.1
+ set(CLANG_IGNORE_UNKNOWN_CUDA "-Wno-unknown-cuda-version")
+ endif()
+ set(DISABLE_GLOBALLY "-Wno-unused-result -Wno-unknown-warning-option ${CLANG_IGNORE_UNKNOWN_CUDA}")
+
+ # These are used in src/CMakeLists.txt on a per-target basis
+ list(APPEND ALL_WARNINGS -Wall; -Werror; -Wextra; -Wno-unused-result; -Wno-deprecated;
+ -Wno-pragmas; -Wno-unused-parameter; -Wno-unused-function;
+ -Wno-unused-value; -Wno-unknown-pragmas; -Wno-sign-compare;
+ -Wno-missing-field-initializers;)
# This warning does not exist prior to gcc 5.0
if(CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 5.0)
- list(APPEND ALL_WARNINGS -Wsuggest-override)
+ list(APPEND ALL_WARNINGS -Wsuggest-override -Wno-int-in-bool-context)
endif()
- set(CMAKE_CXX_FLAGS "-std=c++11 -O3 -Ofast -m64 -pthread -march=${BUILD_ARCH} ${INTRINSICS} -Wl,--no-as-needed -funroll-loops -ffinite-math-only -fPIC ${DISABLE_GLOBALLY}")
- set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS} -g -rdynamic")
- set(CMAKE_CXX_FLAGS_DEBUG "-std=c++11 -g -rdynamic -O0 -pthread -Wl,--no-as-needed -fPIC -Wno-unused-result -Wno-deprecated -Wno-pragmas")
- set(CMAKE_CXX_FLAGS_SLIM "${CMAKE_CXX_FLAGS} -DNDEBUG")
- set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS} -g -rdynamic")
- set(CMAKE_CXX_FLAGS_PROFILE "${CMAKE_CXX_FLAGS_RELEASE} -pg -g -rdynamic")
+ if(CMAKE_COMPILER_IS_GNUCC)
+ # these flags are not known to clang
+ set(CMAKE_GCC_FLAGS "-Wl,--no-as-needed")
+ set(CMAKE_RDYNAMIC_FLAG "-rdynamic")
+ endif(CMAKE_COMPILER_IS_GNUCC)
+
+ set(CMAKE_CXX_FLAGS "-std=c++11 -pthread ${CMAKE_GCC_FLAGS} -fPIC ${DISABLE_GLOBALLY} -march=${BUILD_ARCH} ${INTRINSICS}")
+ set(CMAKE_CXX_FLAGS_RELEASE "-Ofast -m64 -funroll-loops -ffinite-math-only -g ${CMAKE_RDYNAMIC_FLAG}")
+ set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g ${CMAKE_RDYNAMIC_FLAG}")
+ set(CMAKE_CXX_FLAGS_SLIM "-Ofast -m64 -funroll-loops -ffinite-math-only -DNDEBUG")
+ set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELEASE}")
+ set(CMAKE_CXX_FLAGS_PROFILE "${CMAKE_CXX_FLAGS_RELEASE} -pg")
set(CMAKE_CXX_FLAGS_PROFGEN "${CMAKE_CXX_FLAGS_RELEASE} -fprofile-generate -fprofile-correction")
set(CMAKE_CXX_FLAGS_PROFUSE "${CMAKE_CXX_FLAGS_RELEASE} -fprofile-use -fprofile-correction")
- endif()
+ # these need to be set separately
+ set(CMAKE_C_FLAGS "-pthread ${CMAKE_GCC_FLAGS} -fPIC ${DISABLE_GLOBALLY} -march=${BUILD_ARCH} ${INTRINSICS}")
+ set(CMAKE_C_FLAGS_RELEASE "-O3 -m64 -funroll-loops -ffinite-math-only -g ${CMAKE_RDYNAMIC_FLAG}")
+ set(CMAKE_C_FLAGS_DEBUG "-O0 -g ${CMAKE_RDYNAMIC_FLAG}")
+ set(CMAKE_C_FLAGS_SLIM "-O3 -m64 -funroll-loops -ffinite-math-only -DNDEBUG")
+ set(CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELEASE}")
+ set(CMAKE_C_FLAGS_PROFILE "${CMAKE_C_FLAGS_RELEASE} -pg")
+ set(CMAKE_C_FLAGS_PROFGEN "${CMAKE_C_FLAGS_RELEASE} -fprofile-generate -fprofile-correction")
+ set(CMAKE_C_FLAGS_PROFUSE "${CMAKE_C_FLAGS_RELEASE} -fprofile-use -fprofile-correction")
+endif(MSVC)
+
+###############################################################################
# Downloading SentencePiece if requested and set to compile with it.
# Requires all the dependencies imposed by SentencePiece
if(USE_SENTENCEPIECE)
@@ -121,10 +195,10 @@ if(USE_SENTENCEPIECE)
set(EXT_LIBS ${EXT_LIBS} sentencepiece sentencepiece_train)
endif()
-
# Find packages
set(EXT_LIBS ${EXT_LIBS} ${CMAKE_DL_LIBS})
+###############################################################################
if(COMPILE_CUDA)
if(USE_STATIC_LIBS)
@@ -140,16 +214,41 @@ if(USE_STATIC_LIBS)
endif()
endif()
-find_package(CUDA "8.0")
+find_package(CUDA "9.0") # TODO: only enable FP16-related options for compute_70 and higher.
if(CUDA_FOUND)
+ # CUDA >= 10.0 requires CMake >= 3.12.2
+ if((CUDA_VERSION VERSION_EQUAL "10.0" OR CUDA_VERSION VERSION_GREATER "10.0") AND (CMAKE_VERSION VERSION_LESS "3.12.2"))
+ message(WARNING "On some Unix systems CUDA 10.0+ requires CMake 3.12.2+; you use CMake ${CMAKE_VERSION}")
+ endif()
+
+ if(COMPILE_CUDA_SM35)
+ LIST(APPEND COMPUTE -arch=sm_35; -gencode=arch=compute_35,code=sm_35;) # Tesla K40 and above
+ endif(COMPILE_CUDA_SM35)
+ if(COMPILE_CUDA_SM50)
+ LIST(APPEND COMPUTE -gencode=arch=compute_50,code=sm_50; -gencode=arch=compute_52,code=sm_52;) # Maxwell GPUs
+ endif(COMPILE_CUDA_SM50)
+ if(COMPILE_CUDA_SM60)
+ LIST(APPEND COMPUTE -gencode=arch=compute_60,code=sm_60; -gencode=arch=compute_61,code=sm_61;) # Pascal GPUs
+ endif(COMPILE_CUDA_SM60)
+ if(COMPILE_CUDA_SM70)
+ LIST(APPEND COMPUTE -gencode=arch=compute_70,code=sm_70; -gencode=arch=compute_70,code=compute_70) # Volta GPUs
+ endif(COMPILE_CUDA_SM70)
+
if(USE_STATIC_LIBS)
find_library(CUDA_culibos_LIBRARY NAMES culibos PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_culibos_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
- message(STATUS "Found CUDA libraries: ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_culibos_LIBRARY} ${CUDA_CUBLAS_LIBRARIES}")
+ set(CUDA_LIBS ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_culibos_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
+ # CUDA 10.1 introduces cublasLt library that is required on static build
+ if ((CUDA_VERSION VERSION_EQUAL "10.1" OR CUDA_VERSION VERSION_GREATER "10.1"))
+ find_library(CUDA_cublasLt_LIBRARY NAMES cublasLt PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64)
+ set(EXT_LIBS ${EXT_LIBS} ${CUDA_cublasLt_LIBRARY})
+ set(CUDA_LIBS ${CUDA_LIBS} ${CUDA_cublasLt_LIBRARY})
+ endif()
+ message(STATUS "Found CUDA libraries: ${CUDA_LIBS}")
else(USE_STATIC_LIBS)
set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES})
message(STATUS "Found CUDA libraries: ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} ${CUDA_CUBLAS_LIBRARIES}")
-endif(USE_STATIC_LIBS)
+ endif(USE_STATIC_LIBS)
if(USE_CUDNN)
find_package(CUDNN "7.0")
@@ -168,63 +267,42 @@ endif(USE_STATIC_LIBS)
list(APPEND CUDA_NVCC_FLAGS -DBOOST_PP_VARIADICS=0; )
endif()
- # We compile NCCL ourselves, using the NVidia Makefile rather than CMake, this requires to pass a couple of parameters from
- # Cmake. This is also fairly untested, let's hope it does not explode.
- # @TODO: Make sure it does not use pre-installed NCCL headers
if(USE_NCCL)
- # define and set the include dir for the generated nccl.h header
- set(NCCL_HEADER_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/nccl/include")
- include_directories(${NCCL_HEADER_LOCATION})
-
- # set the path for the generated static lib
- set(NCCL_LIB_STATIC "${CMAKE_CURRENT_BINARY_DIR}/nccl/lib/libnccl_static.a")
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_NCCL")
-
- LIST(APPEND CUDA_NVCC_FLAGS -DUSE_NCCL; )
-
- # disables compilation for sm_30 to avoid ptxas warning... that's general Kepler support. But K80s are supported for instance by sm_35
- set(GENCODE "-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61")
-
- # We build using NVidia's custom makefile, for that we pass a number of variables from CMake.
- # Sets output to the chosen build folder, i.e. where the binaries and objects are generated.
- # Also passes CUDA location from FindCUDA, sets c++ compiler to the same one CMake uses.
- add_custom_command(OUTPUT ${NCCL_LIB_STATIC}
- COMMAND ${CMAKE_MAKE_PROGRAM} src.build
- BUILDDIR=${CMAKE_CURRENT_BINARY_DIR}/nccl
- CUDA_HOME=${CUDA_TOOLKIT_ROOT_DIR}
- CUDA8_GENCODE=${GENCODE}
- CXX=${CMAKE_CXX_COMPILER}
- WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/src/3rd_party/nccl)
- add_custom_target(nccl_target DEPENDS ${NCCL_LIB_STATIC})
add_library(nccl STATIC IMPORTED)
- set_target_properties(nccl PROPERTIES IMPORTED_LOCATION ${NCCL_LIB_STATIC})
- add_dependencies(nccl nccl_target)
set(EXT_LIBS ${EXT_LIBS} nccl)
-
- # adds the resulting files to be removed by `make clean`
- set_directory_properties(PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_CURRENT_BINARY_DIR}/nccl)
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_NCCL")
+ LIST(APPEND CUDA_NVCC_FLAGS -DUSE_NCCL; )
endif(USE_NCCL)
-if(USE_STATIC_LIBS)
- set(CMAKE_FIND_LIBRARY_SUFFIXES ${_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES})
-endif()
+ if(USE_STATIC_LIBS)
+ set(CMAKE_FIND_LIBRARY_SUFFIXES ${_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES})
+ endif()
else(CUDA_FOUND)
- message(FATAL_ERROR "CUDA has not been found, set -DCOMPILE_CUDA=off to avoid this check and to compile the CPU version only")
+ message("
+Cannot find suitable CUDA libraries. Specify the path explicitly with
+ -DCUDA_TOOLKIT_ROOT_DIR=/path/to/appropriate/cuda/installation
+ (hint: try /usr/local/$(readlink /usr/local/cuda))
+OR compile the CPU-only version of Marian with
+ -DCOMPILE_CUDA=off
+")
+ message(FATAL_ERROR "FATAL ERROR: No suitable CUDA library found.")
endif(CUDA_FOUND)
else(COMPILE_CUDA)
message(WARNING "COMPILE_CUDA=off : Building only CPU version")
endif(COMPILE_CUDA)
+# TODO: make compatible with older CUDA versions
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
- list(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -O0; -g; -arch=sm_30; -gencode=arch=compute_30,code=sm_30; -gencode=arch=compute_50,code=sm_50; -gencode=arch=compute_52,code=sm_52; -gencode=arch=compute_60,code=sm_60; -gencode=arch=compute_61,code=sm_61; -gencode=arch=compute_61,code=compute_61 ;)
+ list(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -O0; -g; --use_fast_math; ${COMPUTE})
else(CMAKE_BUILD_TYPE STREQUAL "Debug")
- list(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -O3; -g; --use_fast_math; -arch=sm_30; -gencode=arch=compute_30,code=sm_30; -gencode=arch=compute_50,code=sm_50; -gencode=arch=compute_52,code=sm_52; -gencode=arch=compute_60,code=sm_60; -gencode=arch=compute_61,code=sm_61; -gencode=arch=compute_61,code=compute_61 ;)
+ list(APPEND CUDA_NVCC_FLAGS --default-stream per-thread; -O3; -g; --use_fast_math; ${COMPUTE})
endif(CMAKE_BUILD_TYPE STREQUAL "Debug")
if(NOT MSVC)
# @TODO: add warnings here too
- list(APPEND CUDA_NVCC_FLAGS -std=c++11; -Xcompiler\ -fPIC; -Xcompiler\ -Wno-unused-result; -Xcompiler\ -Wno-deprecated; -Xcompiler\ -Wno-pragmas; -Xcompiler\ -Wno-unused-value; -Xcompiler\ -Werror;)
+ list(APPEND CUDA_NVCC_FLAGS -ccbin ${CMAKE_C_COMPILER}; -std=c++11; -Xcompiler\ -fPIC; -Xcompiler\ -Wno-unused-result; -Xcompiler\ -Wno-deprecated; -Xcompiler\ -Wno-pragmas; -Xcompiler\ -Wno-unused-value; -Xcompiler\ -Werror;)
+ list(APPEND CUDA_NVCC_FLAGS ${INTRINSICS_NVCC})
else()
list(APPEND CUDA_NVCC_FLAGS -Xcompiler\ /FS; )
endif()
@@ -241,6 +319,8 @@ if(USE_STATIC_LIBS)
endif()
endif()
+###############################################################################
+# Find Tcmalloc
if(NOT WIN32)
find_package(Tcmalloc)
if(Tcmalloc_FOUND)
@@ -251,6 +331,8 @@ if(NOT WIN32)
endif(Tcmalloc_FOUND)
endif()
+###############################################################################
+# Find MPI
if(USE_MPI)
find_package(MPI 2.0)
if(MPI_FOUND)
@@ -260,38 +342,40 @@ if(USE_MPI)
endif(MPI_FOUND)
endif(USE_MPI)
+###############################################################################
+# Find MKL
if(COMPILE_CPU)
- find_package(MKL)
+ if(USE_MKL)
+ find_package(MKL)
+ endif(USE_MKL)
if(MKL_FOUND)
include_directories(${MKL_INCLUDE_DIR})
set(EXT_LIBS ${EXT_LIBS} ${MKL_LIBRARIES})
add_definitions(-DBLAS_FOUND=1 -DMKL_FOUND=1)
else(MKL_FOUND)
- set(BLA_VENDOR "OpenBLAS")
+ set(BLAS_VENDOR "OpenBLAS")
find_package(BLAS)
if(BLAS_FOUND)
- include_directories(${BLAS_INCLUDE_DIR})
- set(EXT_LIBS ${EXT_LIBS} ${BLAS_LIBRARIES})
- add_definitions(-DBLAS_FOUND=1)
+ include(FindCBLAS)
+ if(CBLAS_FOUND)
+ include_directories(${BLAS_INCLUDE_DIR} ${CBLAS_INCLUDE_DIR})
+ set(EXT_LIBS ${EXT_LIBS} ${BLAS_LIBRARIES} ${CBLAS_LIBRARIES})
+ add_definitions(-DBLAS_FOUND=1)
+ endif(CBLAS_FOUND)
endif(BLAS_FOUND)
endif(MKL_FOUND)
endif(COMPILE_CPU)
-set(BOOST_COMPONENTS timer iostreams filesystem system chrono)
-if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 4.9)
- add_definitions(-DUSE_BOOST_REGEX=1)
- set(BOOST_COMPONENTS ${BOOST_COMPONENTS} regex)
- message(STATUS "Using boost::regex")
-else()
- message(STATUS "Using std::regex")
-endif()
-
+###############################################################################
+# Find OpenSSL
+set(BOOST_COMPONENTS "")
if(COMPILE_SERVER)
find_package(OpenSSL)
if(OpenSSL_FOUND)
message(STATUS "Found OpenSSL")
include_directories(${OPENSSL_INCLUDE_DIR})
set(EXT_LIBS ${EXT_LIBS} ${OPENSSL_CRYPTO_LIBRARY})
+ set(BOOST_COMPONENTS ${BOOST_COMPONENTS} system)
else(OpenSSL_FOUND)
message(WARNING "Cannot find OpenSSL library. Not compiling server.")
set(COMPILE_SERVER "off")
@@ -302,19 +386,25 @@ if(USE_STATIC_LIBS)
set(CMAKE_FIND_LIBRARY_SUFFIXES ${_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES})
endif()
+# TODO: move inside if(BOOST_COMPONENTS) ?
if(USE_STATIC_LIBS)
set(Boost_USE_STATIC_LIBS ON)
endif()
-find_package(Boost COMPONENTS ${BOOST_COMPONENTS})
-if(Boost_FOUND)
- include_directories(${Boost_INCLUDE_DIRS})
- set(EXT_LIBS ${EXT_LIBS} ${Boost_LIBRARIES})
- set(EXT_LIBS ${EXT_LIBS} ${ZLIB_LIBRARIES}) # hack for static compilation
-else(Boost_FOUND)
- message(SEND_ERROR "Cannot find Boost libraries. Terminating.")
-endif(Boost_FOUND)
-
+###############################################################################
+# Find Boost if required
+if(BOOST_COMPONENTS)
+ find_package(Boost COMPONENTS ${BOOST_COMPONENTS})
+ if(Boost_FOUND)
+ include_directories(${Boost_INCLUDE_DIRS})
+ set(EXT_LIBS ${EXT_LIBS} ${Boost_LIBRARIES})
+ set(EXT_LIBS ${EXT_LIBS} ${ZLIB_LIBRARIES}) # hack for static compilation
+ else(Boost_FOUND)
+ message(SEND_ERROR "Cannot find Boost libraries. Terminating.")
+ endif(Boost_FOUND)
+endif(BOOST_COMPONENTS)
+
+###############################################################################
if(COMPILE_TESTS)
enable_testing()
endif(COMPILE_TESTS)
@@ -327,11 +417,18 @@ endif(COMPILE_EXAMPLES)
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/src/common/project_version.h.in
${CMAKE_CURRENT_SOURCE_DIR}/src/common/project_version.h @ONLY)
+# Generate build_info.cpp with CMake cache variables
+include(GetCacheVariables)
+
+configure_file(${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp.in
+ ${CMAKE_CURRENT_SOURCE_DIR}/src/common/build_info.cpp @ONLY)
+
# Compile source files
include_directories(${marian_SOURCE_DIR}/src)
add_subdirectory(src)
-
+###############################################################################
+if(USE_DOXYGEN)
# Add a target to generate API documentation with Doxygen
find_package(Doxygen)
if(DOXYGEN_FOUND)
@@ -340,7 +437,7 @@ if(DOXYGEN_FOUND)
add_custom_target(doc
${DOXYGEN_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
- COMMENT "Generating API documentation with Doxygen" VERBATIM
+ COMMENT "Generating API documentation with Doxygen" VERBATIM
)
endif(DOXYGEN_FOUND)
-
+endif(USE_DOXYGEN)
diff --git a/Doxyfile.in b/Doxyfile.in
index 1761a228..ba2fec09 100644
--- a/Doxyfile.in
+++ b/Doxyfile.in
@@ -1592,7 +1592,7 @@ PAPER_TYPE = a4
# If left blank no extra packages will be included.
# This tag requires that the tag GENERATE_LATEX is set to YES.
-EXTRA_PACKAGES =
+EXTRA_PACKAGES = amsmath
# The LATEX_HEADER tag can be used to specify a personal LaTeX header for the
# generated LaTeX document. The header should contain everything until the first
diff --git a/README.md b/README.md
index d60bcff2..17a33728 100644
--- a/README.md
+++ b/README.md
@@ -1,40 +1,26 @@
Marian
======
+[![Build Status CUDA 9](https://img.shields.io/jenkins/s/http/vali.inf.ed.ac.uk/jenkins/view/marian/job/marian-dev-cuda-9.2.svg?label=CUDA%209)](http://vali.inf.ed.ac.uk/jenkins/job/marian-dev-cuda-9.2/)
[![Build Status CUDA 10](https://img.shields.io/jenkins/s/http/vali.inf.ed.ac.uk/jenkins/view/marian/job/marian-dev-cuda-10.1.svg?label=CUDA%2010)](http://vali.inf.ed.ac.uk/jenkins/job/marian-dev-cuda-10.1/)
-[![CPU Build Status](https://img.shields.io/jenkins/s/http/vali.inf.ed.ac.uk/jenkins/view/marian/job/marian-dev-cpu.svg?label=CPU)](http://vali.inf.ed.ac.uk/jenkins/job/marian-dev-cpu/)
+[![Build Status CPU](https://img.shields.io/jenkins/s/http/vali.inf.ed.ac.uk/jenkins/view/marian/job/marian-dev-cpu.svg?label=CPU)](http://vali.inf.ed.ac.uk/jenkins/job/marian-dev-cpu/)
[![Tests Status](https://img.shields.io/jenkins/s/http/vali.inf.ed.ac.uk/jenkins/view/marian/job/marian-regression-tests.svg?label=tests)](http://vali.inf.ed.ac.uk/jenkins/job/marian-regression-tests/)
[![Latest release](https://img.shields.io/github/release/marian-nmt/marian.svg?label=release)](https://github.com/marian-nmt/marian/releases)
[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE.md)
[![Twitter](https://img.shields.io/twitter/follow/marian_nmt.svg?style=social)](https://twitter.com/intent/follow?screen_name=marian_nmt)
- <p>
- <b>Marian</b> is an efficient Neural Machine Translation framework written
- in pure C++ with minimal dependencies.
-
- Named in honour of Marian Rejewski, a Polish mathematician and cryptologist.
-
- <!--It has mainly been developed at the
- Adam Mickiewicz University in Poznań (AMU) and at the University of Edinburgh.-->
- </p>
-
- <!--p>
- It is currently being deployed in
- multiple European projects and is the main translation and training engine
- behind the neural MT launch at the
- <a href="http://www.wipo.int/pressroom/en/articles/2016/article_0014.html">World Intellectual Property Organization</a>.
- </p-->
-
- <p>
- Main features:
- <ul>
- <li> Fast multi-gpu training and translation </li>
- <li> Compatible with Nematus and DL4MT </li>
- <li> Efficient pure C++ implementation </li>
- <li> Permissive open source license (MIT) </li>
- <li> <a href="https://marian-nmt.github.io/features/"> more details... </a> </li>
- </ul>
- </p>
+*Marian* is an efficient Neural Machine Translation framework written in pure
+C++ with minimal dependencies.
+
+Named in honour of Marian Rejewski, a Polish mathematician and cryptologist.
+
+Main features:
+
+- Efficient pure C++ implementation
+- Fast multi-GPU training and GPU/CPU translation
+- State-of-the-art NMT architectures: deep RNN and transformer
+- Permissive open source license (MIT)
+- [more detail...](https://marian-nmt.github.io/features)
If you use this, please cite:
@@ -59,20 +45,11 @@ Machine Translation in C++ (http://www.aclweb.org/anthology/P18-4020)
url = {http://www.aclweb.org/anthology/P18-4020}
}
-<!--
-## Compilation
-
-```
-cd marian-dev
-mkdir -p build
-cd build
-cmake .. -DCMAKE_BUILD_TYPE=Release
-make -j
-```
--->
-
## Amun
-The handwritten decoder for RNN models compatible with Marian and Nematus has been superseded by the Marian decoder. The code is available in a separate repository: https://github.com/marian-nmt/amun
+
+The handwritten decoder for RNN models compatible with Marian and Nematus has
+been superseded by the Marian decoder. The code is available in a separate
+repository: https://github.com/marian-nmt/amun
## Website
diff --git a/VERSION b/VERSION
index 12751ca7..6959dfca 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-v1.7.6
+v1.8.52
diff --git a/cmake/FindCBLAS.cmake b/cmake/FindCBLAS.cmake
new file mode 100644
index 00000000..97b0d3f8
--- /dev/null
+++ b/cmake/FindCBLAS.cmake
@@ -0,0 +1,186 @@
+# - Find CBLAS library
+#
+# This module finds an installed fortran library that implements the CBLAS
+# linear-algebra interface (see http://www.netlib.org/blas/), with CBLAS
+# interface.
+#
+# This module sets the following variables:
+# CBLAS_FOUND - set to true if a library implementing the CBLAS interface
+# is found
+# CBLAS_LINKER_FLAGS - uncached list of required linker flags (excluding -l
+# and -L).
+# CBLAS_LIBRARIES - uncached list of libraries (using full path name) to
+# link against to use CBLAS
+# CBLAS_INCLUDE_DIR - path to includes
+# CBLAS_INCLUDE_FILE - the file to be included to use CBLAS
+#
+
+## Based on https://github.com/Eyescale/CMake/blob/master/FindCBLAS.cmake
+
+INCLUDE(CheckFunctionExists)
+INCLUDE(CheckIncludeFile)
+
+MACRO(CHECK_ALL_LIBRARIES LIBRARIES INCLUDE _prefix _name _flags _list _include _search_include)
+ # This macro checks for the existence of the combination of fortran libraries
+ # given by _list. If the combination is found, this macro checks (using the
+ # Check_Fortran_Function_Exists macro) whether can link against that library
+ # combination using the name of a routine given by _name using the linker
+ # flags given by _flags. If the combination of libraries is found and passes
+ # the link test, LIBRARIES is set to the list of complete library paths that
+ # have been found. Otherwise, LIBRARIES is set to FALSE.
+
+ # N.B. _prefix is the prefix applied to the names of all cached variables that
+ # are generated internally and marked advanced by this macro.
+
+ SET(__list)
+ FOREACH(_elem ${_list})
+ IF(__list)
+ SET(__list "${__list} - ${_elem}")
+ ELSE(__list)
+ SET(__list "${_elem}")
+ ENDIF(__list)
+ ENDFOREACH(_elem)
+ MESSAGE(STATUS "Checking for [${__list}]")
+ SET(_libraries_work TRUE)
+ SET(${LIBRARIES})
+ SET(_combined_name)
+ SET(_paths)
+ FOREACH(_library ${_list})
+ SET(_combined_name ${_combined_name}_${_library})
+
+ # did we find all the libraries in the _list until now?
+ # (we stop at the first unfound one)
+ IF(_libraries_work)
+ IF(APPLE)
+ FIND_LIBRARY(${_prefix}_${_library}_LIBRARY
+ NAMES ${_library}
+ PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64 /usr/local/opt/openblas/lib ENV
+ DYLD_LIBRARY_PATH
+ )
+ ELSE(APPLE)
+ FIND_LIBRARY(${_prefix}_${_library}_LIBRARY
+ NAMES ${_library}
+ PATHS /usr/local/lib /usr/lib /usr/local/lib64 /usr/lib64 ENV
+ LD_LIBRARY_PATH
+ )
+ ENDIF(APPLE)
+ MARK_AS_ADVANCED(${_prefix}_${_library}_LIBRARY)
+ IF(${_prefix}_${_library}_LIBRARY)
+ GET_FILENAME_COMPONENT(_path ${${_prefix}_${_library}_LIBRARY} PATH)
+ LIST(APPEND _paths ${_path}/../include ${_path}/../../include)
+ ENDIF(${_prefix}_${_library}_LIBRARY)
+ SET(${LIBRARIES} ${${LIBRARIES}} ${${_prefix}_${_library}_LIBRARY})
+ SET(_libraries_work ${${_prefix}_${_library}_LIBRARY})
+ ENDIF(_libraries_work)
+ ENDFOREACH(_library ${_list})
+
+ # Test include
+ SET(_bug_search_include ${_search_include}) #CMAKE BUG!!! SHOULD NOT BE THAT
+ IF(_bug_search_include)
+ FIND_PATH(${_prefix}${_combined_name}_INCLUDE ${_include} ${_paths})
+ MARK_AS_ADVANCED(${_prefix}${_combined_name}_INCLUDE)
+ IF(${_prefix}${_combined_name}_INCLUDE)
+ MESSAGE(STATUS "Checking for [${__list}] -- includes found")
+ SET(${_prefix}_INCLUDE_DIR ${${_prefix}${_combined_name}_INCLUDE})
+ SET(${_prefix}_INCLUDE_FILE ${_include})
+ SET(${INCLUDE} ${${_prefix}_INCLUDE_DIR})
+ ELSE(${_prefix}${_combined_name}_INCLUDE)
+ MESSAGE(STATUS "Checking for [${__list}] -- includes not found")
+ SET(_libraries_work FALSE)
+ ENDIF(${_prefix}${_combined_name}_INCLUDE)
+ ELSE(_bug_search_include)
+ SET(${_prefix}_INCLUDE_DIR)
+ SET(${_prefix}_INCLUDE_FILE ${_include})
+ ENDIF(_bug_search_include)
+
+ IF(_libraries_work)
+ # Test this combination of libraries.
+ SET(CMAKE_REQUIRED_LIBRARIES ${_flags} ${${LIBRARIES}})
+ CHECK_FUNCTION_EXISTS(${_name} ${_prefix}${_combined_name}_WORKS)
+ SET(CMAKE_REQUIRED_LIBRARIES)
+ MARK_AS_ADVANCED(${_prefix}${_combined_name}_WORKS)
+ SET(_libraries_work ${${_prefix}${_combined_name}_WORKS})
+
+ IF(_libraries_work)
+ MESSAGE(STATUS "Checking for [${__list}] -- libraries found")
+ ENDIF(_libraries_work)
+
+ ENDIF(_libraries_work)
+
+
+ IF(NOT _libraries_work)
+ SET(${LIBRARIES} FALSE)
+ ENDIF(NOT _libraries_work)
+
+ENDMACRO(CHECK_ALL_LIBRARIES)
+
+SET(CBLAS_LINKER_FLAGS)
+SET(CBLAS_LIBRARIES)
+SET(CBLAS_INCLUDE_DIR)
+
+# CBLAS in openBLAS
+IF(NOT CBLAS_LIBRARIES)
+ CHECK_ALL_LIBRARIES(
+ CBLAS_LIBRARIES
+ CBLAS_INCLUDE_DIR
+ cblas
+ cblas_sgemm
+ ""
+ "openblas"
+ "cblas.h"
+ TRUE
+ )
+ENDIF(NOT CBLAS_LIBRARIES)
+
+#MESSAGE(STATUS ${openblas_INCLUDE_DIR})
+
+# CBLAS in CBLAS
+IF(NOT CBLAS_LIBRARIES)
+ CHECK_ALL_LIBRARIES(
+ CBLAS_LIBRARIES
+ CBLAS_INCLUDE_DIR
+ cblas
+ cblas_sgemm
+ ""
+ "cblas"
+ "cblas.h"
+ TRUE
+ )
+ENDIF(NOT CBLAS_LIBRARIES)
+
+#MESSAGE(STATUS ${cblas_INCLUDE_DIR})
+
+# CBLAS in lapacke
+IF(NOT CBLAS_LIBRARIES)
+ CHECK_ALL_LIBRARIES(
+ CBLAS_LIBRARIES
+ CBLAS_INCLUDE_DIR
+ cblas
+ cblas_sgemm
+ ""
+ "lapacke"
+ "cblas.h"
+ TRUE
+ )
+ENDIF(NOT CBLAS_LIBRARIES)
+
+#MESSAGE(STATUS ${lapacke_INCLUDE_DIR})
+
+IF(CBLAS_LIBRARIES)
+ SET(CBLAS_FOUND TRUE)
+ELSE(CBLAS_LIBRARIES)
+ SET(CBLAS_FOUND FALSE)
+ENDIF(CBLAS_LIBRARIES)
+
+IF(NOT CBLAS_FOUND AND CBLAS_FIND_REQUIRED)
+ MESSAGE(FATAL_ERROR "CBLAS library not found. Please specify library location")
+ENDIF(NOT CBLAS_FOUND AND CBLAS_FIND_REQUIRED)
+
+IF(NOT CBLAS_FIND_QUIETLY)
+ IF(CBLAS_FOUND)
+ MESSAGE(STATUS "CBLAS library found: " ${CBLAS_LIBRARIES})
+ MESSAGE(STATUS "cblas.h include directory: " ${CBLAS_INCLUDE_DIR})
+ ELSE(CBLAS_FOUND)
+ MESSAGE(STATUS "CBLAS library not found. Please specify library location")
+ ENDIF(CBLAS_FOUND)
+ENDIF(NOT CBLAS_FIND_QUIETLY)
diff --git a/cmake/FindMKL.cmake b/cmake/FindMKL.cmake
index 028161e3..4e8a99ee 100644
--- a/cmake/FindMKL.cmake
+++ b/cmake/FindMKL.cmake
@@ -53,11 +53,11 @@ else()
set(COR_LIB "mkl_core")
endif()
-if(MSVC)
- set(ProgramFilesx86 "ProgramFiles(x86)")
- set(INTEL_ROOT_DEFAULT $ENV{${ProgramFilesx86}}/IntelSWTools/compilers_and_libraries/windows)
-else()
- set(INTEL_ROOT_DEFAULT "/opt/intel")
+if(MSVC)
+ set(ProgramFilesx86 "ProgramFiles(x86)")
+ set(INTEL_ROOT_DEFAULT $ENV{${ProgramFilesx86}}/IntelSWTools/compilers_and_libraries/windows)
+else()
+ set(INTEL_ROOT_DEFAULT "/opt/intel")
endif()
set(INTEL_ROOT ${INTEL_ROOT_DEFAULT} CACHE PATH "Folder contains intel libs")
find_path(MKL_ROOT include/mkl.h PATHS $ENV{MKLROOT} ${INTEL_ROOT}/mkl
@@ -89,7 +89,10 @@ find_library(MKL_CORE_LIBRARY
NO_DEFAULT_PATH)
set(MKL_INCLUDE_DIRS ${MKL_INCLUDE_DIR})
-set(MKL_LIBRARIES ${MKL_INTERFACE_LIBRARY} ${MKL_SEQUENTIAL_LAYER_LIBRARY} ${MKL_CORE_LIBRARY})
+# Added -Wl block to avoid circular dependencies.
+# https://stackoverflow.com/questions/5651869/what-are-the-start-group-and-end-group-command-line-options
+# https://software.intel.com/en-us/articles/intel-mkl-link-line-advisor
+set(MKL_LIBRARIES -Wl,--start-group ${MKL_INTERFACE_LIBRARY} ${MKL_SEQUENTIAL_LAYER_LIBRARY} ${MKL_CORE_LIBRARY} -Wl,--end-group)
# message("1 ${MKL_INCLUDE_DIR}")
# message("2 ${MKL_INTERFACE_LIBRARY}")
diff --git a/cmake/FindSSE.cmake b/cmake/FindSSE.cmake
index c152dd74..82ee7f3e 100644
--- a/cmake/FindSSE.cmake
+++ b/cmake/FindSSE.cmake
@@ -41,6 +41,14 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host")
ENDIF (SSE41_TRUE)
+ STRING(REGEX REPLACE "^.*(sse4_2).*$" "\\1" SSE_THERE ${CPUINFO})
+ STRING(COMPARE EQUAL "sse4_2" "${SSE_THERE}" SSE42_TRUE)
+ IF (SSE42_TRUE)
+ set(SSE4_2_FOUND true CACHE BOOL "SSE4.2 available on host")
+ ELSE (SSE42_TRUE)
+ set(SSE4_2_FOUND false CACHE BOOL "SSE4.2 available on host")
+ ENDIF (SSE42_TRUE)
+
STRING(REGEX REPLACE "^.*(avx).*$" "\\1" SSE_THERE ${CPUINFO})
STRING(COMPARE EQUAL "avx" "${SSE_THERE}" AVX_TRUE)
IF (AVX_TRUE)
@@ -48,7 +56,7 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
ELSE (AVX_TRUE)
set(AVX_FOUND false CACHE BOOL "AVX available on host")
ENDIF (AVX_TRUE)
-
+
STRING(REGEX REPLACE "^.*(avx2).*$" "\\1" SSE_THERE ${CPUINFO})
STRING(COMPARE EQUAL "avx2" "${SSE_THERE}" AVX2_TRUE)
IF (AVX2_TRUE)
@@ -57,6 +65,14 @@ IF(CMAKE_SYSTEM_NAME MATCHES "Linux")
set(AVX2_FOUND false CACHE BOOL "AVX2 available on host")
ENDIF (AVX2_TRUE)
+ STRING(REGEX REPLACE "^.*(avx512).*$" "\\1" SSE_THERE ${CPUINFO})
+ STRING(COMPARE EQUAL "avx512" "${SSE_THERE}" AVX512_TRUE)
+ IF (AVX512_TRUE)
+ set(AVX512_FOUND true CACHE BOOL "AVX512 available on host")
+ ELSE (AVX512_TRUE)
+ set(AVX512_FOUND false CACHE BOOL "AVX512 available on host")
+ ENDIF (AVX512_TRUE)
+
ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Darwin")
EXEC_PROGRAM("/usr/sbin/sysctl -n machdep.cpu.features" OUTPUT_VARIABLE
CPUINFO)
@@ -109,6 +125,14 @@ ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Darwin")
set(AVX2_FOUND false CACHE BOOL "AVX2 available on host")
ENDIF (AVX2_TRUE)
+ STRING(REGEX REPLACE "^.*(avx512).*$" "\\1" SSE_THERE ${CPUINFO})
+ STRING(COMPARE EQUAL "avx512" "${SSE_THERE}" AVX512_TRUE)
+ IF (AVX512_TRUE)
+ set(AVX512_FOUND true CACHE BOOL "AVX512 available on host")
+ ELSE (AVX512_TRUE)
+ set(AVX512_FOUND false CACHE BOOL "AVX512 available on host")
+ ENDIF (AVX512_TRUE)
+
ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Windows")
# TODO
set(SSE2_FOUND true CACHE BOOL "SSE2 available on host")
@@ -117,6 +141,7 @@ ELSEIF(CMAKE_SYSTEM_NAME MATCHES "Windows")
set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host")
set(AVX_FOUND false CACHE BOOL "AVX available on host")
set(AVX2_FOUND false CACHE BOOL "AVX2 available on host")
+ set(AVX512_FOUND false CACHE BOOL "AVX512 available on host")
ELSE(CMAKE_SYSTEM_NAME MATCHES "Linux")
set(SSE2_FOUND true CACHE BOOL "SSE2 available on host")
set(SSE3_FOUND false CACHE BOOL "SSE3 available on host")
@@ -124,6 +149,7 @@ ELSE(CMAKE_SYSTEM_NAME MATCHES "Linux")
set(SSE4_1_FOUND false CACHE BOOL "SSE4.1 available on host")
set(AVX_FOUND false CACHE BOOL "AVX available on host")
set(AVX2_FOUND false CACHE BOOL "AVX2 available on host")
+ set(AVX512_FOUND false CACHE BOOL "AVX512 available on host")
ENDIF(CMAKE_SYSTEM_NAME MATCHES "Linux")
if(NOT SSE2_FOUND)
@@ -144,5 +170,8 @@ endif(NOT AVX_FOUND)
if(NOT AVX2_FOUND)
MESSAGE(STATUS "Could not find hardware support for AVX2 on this machine.")
endif(NOT AVX2_FOUND)
+if(NOT AVX512_FOUND)
+ MESSAGE(STATUS "Could not find hardware support for AVX512 on this machine.")
+endif(NOT AVX512_FOUND)
-mark_as_advanced(SSE2_FOUND SSE3_FOUND SSSE3_FOUND SSE4_1_FOUND, AVX_FOUND, AVX2_FOUND)
+mark_as_advanced(SSE2_FOUND SSE3_FOUND SSSE3_FOUND SSE4_1_FOUND, AVX_FOUND, AVX2_FOUND, AVX512_FOUND)
diff --git a/cmake/GetCacheVariables.cmake b/cmake/GetCacheVariables.cmake
new file mode 100644
index 00000000..563ade79
--- /dev/null
+++ b/cmake/GetCacheVariables.cmake
@@ -0,0 +1,52 @@
+##
+# This module extracts CMake cached variables into a variable.
+#
+# Author: snukky
+#
+# This module sets the following variables:
+# * PROJECT_CMAKE_CACHE - to the output of "cmake -L" - an uncached list of
+# non-advanced cached variables
+# * PROJECT_CMAKE_CACHE_ADVANCED - to the output of "cmake -LA" - an uncached
+# list of advanced cached variables
+#
+
+set(PROJECT_CMAKE_CACHE "")
+set(PROJECT_CMAKE_CACHE_ADVANCED "")
+
+# Get all CMake variables
+get_cmake_property(_variableNames VARIABLES)
+list(SORT _variableNames)
+list(REMOVE_DUPLICATES _variableNames)
+
+foreach(_variableName ${_variableNames})
+ # If it is a cache variable
+ get_property(_cachePropIsSet CACHE "${_variableName}" PROPERTY VALUE SET)
+ if(_cachePropIsSet)
+ # Get the variable's type
+ get_property(_variableType CACHE ${_variableName} PROPERTY TYPE)
+
+ # Get the variable's value
+ set(_variableValue "${${_variableName}}")
+
+ # Skip static or internal cached variables, cmake -L[A] does not print them, see
+ # https://github.com/Kitware/CMake/blob/master/Source/cmakemain.cxx#L282
+ if( (NOT "${_variableType}" STREQUAL "STATIC") AND
+ (NOT "${_variableType}" STREQUAL "INTERNAL") AND
+ (NOT "${_variableValue}" STREQUAL "") )
+
+
+ set(PROJECT_CMAKE_CACHE_ADVANCED "${PROJECT_CMAKE_CACHE_ADVANCED} \"${_variableName}=${_variableValue}\\n\"\n")
+
+ # Get the variable's advanced flag
+ get_property(_isAdvanced CACHE ${_variableName} PROPERTY ADVANCED SET)
+ if(NOT _isAdvanced)
+ set(PROJECT_CMAKE_CACHE "${PROJECT_CMAKE_CACHE} \"${_variableName}=${_variableValue}\\n\"\n")
+ endif()
+
+ # Print variables for debugging
+ #message(STATUS "${_variableName}=${${_variableName}}")
+ #message(STATUS " Type=${_variableType}")
+ #message(STATUS " Advanced=${_isAdvanced}")
+ endif()
+ endif(_cachePropIsSet)
+endforeach()
diff --git a/contrib/autoformat.sh b/contrib/autoformat.sh
index 18a5dbac..18a5dbac 100755..100644
--- a/contrib/autoformat.sh
+++ b/contrib/autoformat.sh
diff --git a/examples b/examples
-Subproject 336740065d9c23e53e912a1befff18981d9d27a
+Subproject c19b7814d71febf1053bd93af6ac314b4620409
diff --git a/regression-tests b/regression-tests
new file mode 160000
+Subproject 6a08849b23f6c14eefbe12f4eb73dc638b96258
diff --git a/scripts/bert/bert4marian.py b/scripts/bert/bert4marian.py
new file mode 100755
index 00000000..8070c0fe
--- /dev/null
+++ b/scripts/bert/bert4marian.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+"""
+This script takes a Tensorflow BERT checkpoint and a model description in a JSON file and converts
+it to a Marian weight file with numpy weights and an internal YAML description.
+
+This works with checkpoints from https://github.com/google-research/bert
+
+Assmung a BERT checkpoint like this:
+drwxr-xr-x 2 marcinjd marcinjd 4.0K Nov 23 16:39 .
+-rw-r--r-- 1 marcinjd marcinjd 521 Nov 23 16:38 bert_config.json
+-rw-r--r-- 1 marcinjd marcinjd 682M Nov 23 16:39 bert_model.ckpt.data-00000-of-00001
+-rw-r--r-- 1 marcinjd marcinjd 8.5K Nov 23 16:39 bert_model.ckpt.index
+-rw-r--r-- 1 marcinjd marcinjd 888K Nov 23 16:39 bert_model.ckpt.meta
+-rw-r--r-- 1 marcinjd marcinjd 973K Nov 23 16:37 vocab.txt
+
+usage:
+
+./bert.py --bert_prefix bert_model.ckpt --bert_config bert_config.json --marian bert.npz
+"""
+
+import tensorflow as tf
+import numpy as np
+import sys
+import yaml
+import argparse
+
+parser = argparse.ArgumentParser(description='Convert Tensorflow BERT model to Marian weight file.')
+parser.add_argument('--bert_prefix', help='Prefix for Tensorflow BERT checkpoint', required=True)
+parser.add_argument('--bert_config', help='Path to Tensorflow BERT JSON config', required=True)
+parser.add_argument('--marian', help='Output path for Marian weight file', required=True)
+args = parser.parse_args()
+
+print("Loading TensorFlow config from %s" % (args.bert_config,))
+bertConfig = yaml.load(open(args.bert_config))
+bertConfigYamlStr = yaml.dump(bertConfig, default_flow_style=False)
+print(bertConfigYamlStr)
+
+print("Loading TensorFlow model from %s" % (args.bert_prefix,))
+
+# Collect tensors from TF model as numpy matrices
+tfModel = dict()
+with tf.Session() as sess:
+ preloader = tf.train.import_meta_graph(args.bert_prefix + ".meta")
+ preloader.restore(sess, args.bert_prefix)
+ vars = tf.global_variables()
+ for v in vars:
+ if len(v.shape) > 0:
+ if "adam" not in v.name: # ignore adam parameters
+ print(v.name, v.shape)
+ tfModel[v.name] = sess.run(v.name) # get numpy matrix
+
+# Prepare Marian model config
+config = dict()
+config["type"] = "bert"
+config["input-types"] = ["sequence", "class"]
+config["tied-embeddings-all"] = True
+config["dim-emb"] = tfModel["bert/embeddings/word_embeddings:0"].shape[-1]
+config["dim-vocabs"] = [ tfModel["bert/embeddings/word_embeddings:0"].shape[0],
+ tfModel["cls/seq_relationship/output_weights:0"].shape[0] ]
+
+config["transformer-dim-ffn"] = tfModel["bert/encoder/layer_0/intermediate/dense/kernel:0"].shape[-1]
+config["transformer-ffn-activation"] = bertConfig["hidden_act"]
+config["transformer-ffn-depth"] = 2
+config["transformer-heads"] = bertConfig["num_attention_heads"]
+config["transformer-train-position-embeddings"] = True
+config["transformer-preprocess"] = ""
+config["transformer-postprocess"] = "dan"
+config["transformer-postprocess-emb"] = "nd"
+config["bert-train-type-embeddings"] = True
+config["bert-type-vocab-size"] = tfModel["bert/embeddings/token_type_embeddings:0"].shape[0]
+config["version"] = "bert4marian.py conversion"
+
+# check number of layers
+found = True
+config["enc-depth"] = 0;
+while found:
+ found = False
+ for key in tfModel:
+ if "bert/encoder/layer_" + str(config["enc-depth"]) in key:
+ config["enc-depth"] += 1
+ found = True
+ break
+
+if config["enc-depth"] != bertConfig["num_hidden_layers"]:
+ sys.exit("Number of layers in JSON config (%s) and number of layers found in checkpoint (%s) do not match!" % (config["enc-depth"], bertConfig["num_hidden_layers"]))
+
+configYamlStr = yaml.dump(config, default_flow_style=False)
+desc = list(configYamlStr)
+npDesc = np.chararray((len(desc),))
+npDesc[:] = desc
+npDesc.dtype = np.int8
+
+marianModel = dict()
+marianModel["special:model.yml"] = npDesc
+
+# Map model weights here #
+# Embedding layers
+marianModel["Wemb"] = tfModel["bert/embeddings/word_embeddings:0"]
+marianModel["Wpos"] = tfModel["bert/embeddings/position_embeddings:0"]
+marianModel["Wtype"] = tfModel["bert/embeddings/token_type_embeddings:0"]
+marianModel["encoder_emb_ln_scale_pre"] = tfModel["bert/embeddings/LayerNorm/gamma:0"]
+marianModel["encoder_emb_ln_bias_pre"] = tfModel["bert/embeddings/LayerNorm/beta:0"]
+
+for layer in range(config["enc-depth"]):
+ marianPrefix = "encoder_l%s" % (layer + 1,)
+ tfPrefix = "bert/encoder/layer_%s" % (layer,)
+
+ # Attention
+ marianModel[marianPrefix + "_self_Wq"] = tfModel[tfPrefix + "/attention/self/query/kernel:0"]
+ marianModel[marianPrefix + "_self_bq"] = tfModel[tfPrefix + "/attention/self/query/bias:0"]
+
+ marianModel[marianPrefix + "_self_Wk"] = tfModel[tfPrefix + "/attention/self/key/kernel:0"]
+ marianModel[marianPrefix + "_self_bk"] = tfModel[tfPrefix + "/attention/self/key/bias:0"]
+
+ marianModel[marianPrefix + "_self_Wv"] = tfModel[tfPrefix + "/attention/self/value/kernel:0"]
+ marianModel[marianPrefix + "_self_bv"] = tfModel[tfPrefix + "/attention/self/value/bias:0"]
+
+ marianModel[marianPrefix + "_self_Wo"] = tfModel[tfPrefix + "/attention/output/dense/kernel:0"]
+ marianModel[marianPrefix + "_self_bo"] = tfModel[tfPrefix + "/attention/output/dense/bias:0"]
+
+ marianModel[marianPrefix + "_self_Wo_ln_scale"] = tfModel[tfPrefix + "/attention/output/LayerNorm/gamma:0"]
+ marianModel[marianPrefix + "_self_Wo_ln_bias"] = tfModel[tfPrefix + "/attention/output/LayerNorm/beta:0"]
+
+ # FFN
+ marianModel[marianPrefix + "_ffn_W1"] = tfModel[tfPrefix + "/intermediate/dense/kernel:0"]
+ marianModel[marianPrefix + "_ffn_b1"] = tfModel[tfPrefix + "/intermediate/dense/bias:0"]
+
+ marianModel[marianPrefix + "_ffn_W2"] = tfModel[tfPrefix + "/output/dense/kernel:0"]
+ marianModel[marianPrefix + "_ffn_b2"] = tfModel[tfPrefix + "/output/dense/bias:0"]
+
+ marianModel[marianPrefix + "_ffn_ffn_ln_scale"] = tfModel[tfPrefix + "/output/LayerNorm/gamma:0"]
+ marianModel[marianPrefix + "_ffn_ffn_ln_bias"] = tfModel[tfPrefix + "/output/LayerNorm/beta:0"]
+
+ # Training objectives
+ # Masked-LM output layer
+ marianModel["masked-lm_ff_logit_l1_W"] = tfModel["cls/predictions/transform/dense/kernel:0"]
+ marianModel["masked-lm_ff_logit_l1_b"] = tfModel["cls/predictions/transform/dense/bias:0"]
+
+ marianModel["masked-lm_ff_ln_scale"] = tfModel["cls/predictions/transform/LayerNorm/gamma:0"]
+ marianModel["masked-lm_ff_ln_bias"] = tfModel["cls/predictions/transform/LayerNorm/beta:0"]
+
+ marianModel["masked-lm_ff_logit_l2_b"] = tfModel["cls/predictions/output_bias:0"]
+
+ # Next Sentence classifier
+ marianModel["next-sentence_ff_logit_l1_W"] = tfModel["bert/pooler/dense/kernel:0"]
+ marianModel["next-sentence_ff_logit_l1_b"] = tfModel["bert/pooler/dense/bias:0"]
+
+ marianModel["next-sentence_ff_logit_l2_W"] = np.transpose(tfModel["cls/seq_relationship/output_weights:0"]) # transpose?!
+ marianModel["next-sentence_ff_logit_l2_b"] = tfModel["cls/seq_relationship/output_bias:0"]
+
+print("\nMarian config:")
+print(configYamlStr)
+print("Saving Marian model to %s" % (args.marian,))
+np.savez(args.marian, **marianModel)
diff --git a/scripts/checkpoints/average.py b/scripts/checkpoints/average.py
new file mode 100755
index 00000000..53bff186
--- /dev/null
+++ b/scripts/checkpoints/average.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python
+"""
+This script takes multiple Marian *.npz model files and outputs an elementwise average of the model,
+meant to do check-point averaging from:
+
+https://www.aclweb.org/anthology/W16-2316
+
+usage:
+
+./average.py -m model.1.npz model.2.npz --output model.avg.npz
+"""
+
+from __future__ import print_function
+
+import os
+import sys
+import argparse
+
+import numpy as np
+
+# Parse arguments
+parser = argparse.ArgumentParser()
+parser.add_argument('-m', '--model', nargs='+', required=True,
+ help="models to average")
+parser.add_argument('-o', '--output', required=True,
+ help="output path")
+args = parser.parse_args()
+
+# *average* holds the model matrix
+average = dict()
+# No. of models.
+n = len(args.model)
+
+for filename in args.model:
+ print("Loading {}".format(filename))
+ with open(filename, "rb") as mfile:
+ # Loads matrix from model file
+ m = np.load(mfile)
+ for k in m:
+ if k != "history_errs":
+ # Initialize the key
+ if k not in average:
+ average[k] = m[k]
+ # Add to the appropriate value
+ elif average[k].shape == m[k].shape and "special" not in k:
+ average[k] += m[k]
+
+# Actual averaging
+for k in average:
+ if "special" not in k:
+ average[k] /= n
+
+# Save averaged model to file
+print("Saving to {}".format(args.output))
+np.savez(args.output, **average)
diff --git a/scripts/contrib/fix_hard.py b/scripts/contrib/fix_hard.py
index 60043c27..60043c27 100644..100755
--- a/scripts/contrib/fix_hard.py
+++ b/scripts/contrib/fix_hard.py
diff --git a/scripts/contrib/inject_ctt.py b/scripts/contrib/inject_ctt.py
index 751ee1c6..751ee1c6 100644..100755
--- a/scripts/contrib/inject_ctt.py
+++ b/scripts/contrib/inject_ctt.py
diff --git a/scripts/contrib/inject_model_params.py b/scripts/contrib/inject_model_params.py
index 46096eb8..46096eb8 100644..100755
--- a/scripts/contrib/inject_model_params.py
+++ b/scripts/contrib/inject_model_params.py
diff --git a/scripts/contrib/model_info.py b/scripts/contrib/model_info.py
index 9e9647ef..9e9647ef 100644..100755
--- a/scripts/contrib/model_info.py
+++ b/scripts/contrib/model_info.py
diff --git a/scripts/embeddings/export_embeddings.py b/scripts/embeddings/export_embeddings.py
index 1476e52c..3b4f3314 100755
--- a/scripts/embeddings/export_embeddings.py
+++ b/scripts/embeddings/export_embeddings.py
@@ -1,4 +1,4 @@
-#!/usr/bin/env python
+#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import print_function
@@ -9,18 +9,22 @@ import numpy as np
def main():
- desc = """Export word embedding from model"""
+ desc = """Export word embeddings from model"""
parser = argparse.ArgumentParser(
formatter_class=argparse.RawDescriptionHelpFormatter, description=desc)
- parser.add_argument("-m", "--model", help="Model file", required=True)
- parser.add_argument(
- "-o", "--output-prefix", help="Output files prefix", required=True)
+ parser.add_argument("-m", "--model", help="path to model.npz file", required=True)
+ parser.add_argument("-o", "--output-prefix", help="prefix for output files", required=True)
args = parser.parse_args()
print("Loading model")
model = np.load(args.model)
special = yaml.load(model["special:model.yml"][:-1].tobytes())
+ if special["tied-embeddings-all"] or special["tied-embeddings-src"]:
+ all_emb = model["Wemb"]
+ export_emb(args.output_prefix + ".all", all_emb)
+ exit()
+
if special["type"] == "amun":
enc_emb = model["Wemb"]
dec_emb = model["Wemb_dec"]
@@ -28,16 +32,15 @@ def main():
enc_emb = model["encoder_Wemb"]
dec_emb = model["decoder_Wemb"]
- with open(args.output_prefix + ".src", "w") as out:
- out.write("{0} {1}\n".format(*enc_emb.shape))
- for i in range(enc_emb.shape[0]):
- vec = " ".join("{0:.8f}".format(v) for v in enc_emb[i])
- out.write("{0} {1}\n".format(i, vec))
+ export_emb(args.output_prefix + ".src", enc_emb)
+ export_emb(args.output_prefix + ".trg", dec_emb)
+
- with open(args.output_prefix + ".trg", "w") as out:
- out.write("{0} {1}\n".format(*dec_emb.shape))
- for i in range(dec_emb.shape[0]):
- vec = " ".join("{0:.8f}".format(v) for v in dec_emb[i])
+def export_emb(filename, emb):
+ with open(filename, "w") as out:
+ out.write("{0} {1}\n".format(*emb.shape))
+ for i in range(emb.shape[0]):
+ vec = " ".join("{0:.8f}".format(v) for v in emb[i])
out.write("{0} {1}\n".format(i, vec))
diff --git a/scripts/shortlist/.gitignore b/scripts/shortlist/.gitignore
new file mode 100644
index 00000000..bf0d379e
--- /dev/null
+++ b/scripts/shortlist/.gitignore
@@ -0,0 +1,3 @@
+bin
+fast_align
+extract-lex
diff --git a/scripts/shortlist/README.md b/scripts/shortlist/README.md
new file mode 100644
index 00000000..30bf1015
--- /dev/null
+++ b/scripts/shortlist/README.md
@@ -0,0 +1,8 @@
+`install.sh` is a helper script that downloads and compiles fastalign and extract-lex, and copies
+required binaries into _./bin_.
+
+Shortlist files (_lex.s2t_ and _lex.t2s_) can be created using `generate_shortlists.pl`, for
+example:
+
+ perl generate_shortlists.pl --bindir ./bin -s corpus.bpe.src -t corpus.bpe.tgt
+
diff --git a/scripts/shortlist/generate_shortlists.pl b/scripts/shortlist/generate_shortlists.pl
new file mode 100755
index 00000000..309eeef8
--- /dev/null
+++ b/scripts/shortlist/generate_shortlists.pl
@@ -0,0 +1,97 @@
+#!/usr/bin/env perl
+
+use strict;
+use Getopt::Long;
+use FindBin qw($Bin);
+use File::Temp qw(tempdir tempfile);
+use POSIX;
+
+my $PID = $$;
+$SIG{TERM} = $SIG{INT} = $SIG{QUIT} = sub { die; };
+
+my $BINDIR = "$Bin/bin";
+my $SRC;
+my $TRG;
+my $OUTPUT = "lex";
+my $THREADS = 8;
+my $PARALLEL = 0;
+my $HELP;
+
+GetOptions(
+ "b|bindir=s" => \$BINDIR,
+ "s|source=s" => \$SRC,
+ "t|target=s" => \$TRG,
+ "o|output=s" => \$OUTPUT,
+ "threads=i" => \$THREADS,
+ "parallel" => \$PARALLEL,
+ "h|help" => \$HELP,
+);
+
+if($HELP) {
+ print "Usage: perl $0 -b bindir -s corpus.src -t corpus.tgt [-o outputprefix] [--threads 8] [--parallel]\n";
+ exit 0;
+}
+
+die "--bindir arg is required" if not defined $BINDIR;
+die "-s|--source arg is required" if not defined $SRC;
+die "-t|--target arg is required" if not defined $TRG;
+die "-o|--output arg is required" if not defined $OUTPUT;
+
+for my $app (qw(fast_align atools extract_lex)) {
+ die "Could not find $app in $BINDIR" if not -e "$BINDIR/$app";
+}
+
+my $TEMPDIR = tempdir(CLEANUP => 1);
+
+my (undef, $CORPUS) = tempfile(DIR => $TEMPDIR);
+my (undef, $ALN_S2T) = tempfile(DIR => $TEMPDIR);
+my (undef, $ALN_T2S) = tempfile(DIR => $TEMPDIR);
+my (undef, $ALN_GDF) = tempfile(DIR => $TEMPDIR);
+
+execute("paste $SRC $TRG | sed 's/\\t/ ||| /' > $CORPUS");
+
+my @COMMANDS = (
+ "OMP_NUM_THREADS=$THREADS $BINDIR/fast_align -vdo -i $CORPUS > $ALN_S2T",
+ "OMP_NUM_THREADS=$THREADS $BINDIR/fast_align -vdor -i $CORPUS > $ALN_T2S"
+);
+
+my @PIDS;
+for my $c (@COMMANDS) {
+ if ($PARALLEL) {
+ my $pid = fork();
+ if (!$pid) {
+ execute($c);
+ exit(0);
+ } else {
+ push(@PIDS, $pid);
+ print "Forked process $pid\n";
+ }
+ } else {
+ execute($c);
+ }
+}
+if ($PARALLEL) {
+ waitpid($_, 0) foreach(@PIDS);
+}
+
+execute("$BINDIR/atools -c grow-diag-final -i $ALN_S2T -j $ALN_T2S > $ALN_GDF");
+execute("$BINDIR/extract_lex $TRG $SRC $ALN_GDF $OUTPUT.s2t $OUTPUT.t2s");
+
+sub execute {
+ my $command = shift;
+ logMessage("Executing:\t$command");
+ my $ret = system($command);
+ if ($ret != 0) {
+ logMessage("Command '$command' finished with return status $ret");
+ logMessage("Aborting and killing parent process");
+ kill(2, $PID);
+ die;
+ }
+}
+
+sub logMessage {
+ my $message = shift;
+ my $time = POSIX::strftime("%m/%d/%Y %H:%M:%S", localtime());
+ my $log_message = $time."\t$message\n";
+ print STDERR $log_message;
+}
diff --git a/scripts/shortlist/install.sh b/scripts/shortlist/install.sh
new file mode 100644
index 00000000..49b2171a
--- /dev/null
+++ b/scripts/shortlist/install.sh
@@ -0,0 +1,25 @@
+#!/bin/bash -v
+
+mkdir -p bin
+
+# download and compile fast_align
+if [ ! -e bin/fast_align ]; then
+ git clone https://github.com/clab/fast_align
+ mkdir -p fast_align/build
+ cd fast_align/build
+ cmake ..
+ make -j4
+ cp fast_align atools ../../bin
+ cd ../../
+fi
+
+# download and compile extract-lex
+if [ ! -e bin/extract_lex ]; then
+ git clone https://github.com/marian-nmt/extract-lex
+ mkdir -p extract-lex/build
+ cd extract-lex/build
+ cmake ..
+ make -j4
+ cp extract_lex ../../bin
+ cd ../../
+fi
diff --git a/src/3rd_party/CLI/App.hpp b/src/3rd_party/CLI/App.hpp
index b943ef1a..14ddd1e7 100644
--- a/src/3rd_party/CLI/App.hpp
+++ b/src/3rd_party/CLI/App.hpp
@@ -1590,7 +1590,12 @@ class App {
}
// Unlimited vector parser
+ // RG: A negative number for the total number of expected values means that the option is a
+ // vector and accepts an unlimited number of values
if(num < 0) {
+ // RG: We need to keep track if the vector option is empty and handle this separately as
+ // otherwise the parser will mark the command-line option as not set
+ bool emptyVectorArgs = true;
while(!args.empty() && _recognize(args.back()) == detail::Classifer::NONE) {
if(collected >= -num) {
// We could break here for allow extras, but we don't
@@ -1603,12 +1608,28 @@ class App {
parse_order_.push_back(op.get());
args.pop_back();
collected++;
+ emptyVectorArgs = false;
}
// Allow -- to end an unlimited list and "eat" it
if(!args.empty() && _recognize(args.back()) == detail::Classifer::POSITIONAL_MARK)
args.pop_back();
+ // RG: Handle empty vector-like options
+ if(emptyVectorArgs) {
+ // RG: Set implicit value(s) if the option has it (them)
+ if(op->get_implicit()) {
+ for(const auto& ival : detail::split_up(op->get_implicitval())) {
+ op->add_result(ival);
+ parse_order_.push_back(op.get());
+ }
+ // RG: Abort if there is a minimum number of values expected. Note: get_expected()
+ // equals to -N means at least N values are expected
+ } else if (op->get_expected() < 0) {
+ parse_order_.push_back(op.get());
+ throw ArgumentMismatch(op->get_name(), op->get_expected(), 0);
+ }
+ }
} else {
while(num > 0 && !args.empty()) {
num--;
diff --git a/src/3rd_party/CMakeLists.txt b/src/3rd_party/CMakeLists.txt
index 8548d9b8..9f3981af 100644
--- a/src/3rd_party/CMakeLists.txt
+++ b/src/3rd_party/CMakeLists.txt
@@ -6,6 +6,33 @@ add_subdirectory(./SQLiteCpp)
add_subdirectory(./pathie-cpp)
add_subdirectory(./zlib)
+if(USE_FBGEMM)
+ # @TODO: find out if this is somehow harmful. This is supppressing CMake warnings for CMAKE_SUPPRESS_DEVELOPER_WARNINGS
+ # meant to silence CMakeFiles of 3rd_party tools.
+ if(NOT DEFINED CMAKE_SUPPRESS_DEVELOPER_WARNINGS)
+ set(CMAKE_SUPPRESS_DEVELOPER_WARNINGS 1 CACHE INTERNAL "No dev warnings")
+ endif()
+
+ if(NOT MSVC)
+ # only locally disabled for the 3rd_party folder
+ # set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-value -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused")
+ endif()
+
+ set(FBGEMM_BUILD_TESTS OFF CACHE BOOL "Disable fbgemm tests")
+ set(FBGEMM_BUILD_BENCHMARKS OFF CACHE BOOL "Disable fbgemm benchmark")
+ add_subdirectory(./fbgemm)
+
+ # asmjit (3rd-party submodule of fbgemm) sets -Wall -Wextra near the end of
+ # the compile options, invalidating any -Wno-... flags that we may have set
+ # earlier. Let's remove them.
+ get_property(ASMJIT_COMPILE_OPTIONS TARGET asmjit PROPERTY COMPILE_OPTIONS)
+ list(REMOVE_ITEM ASMJIT_COMPILE_OPTIONS -Wall -Wextra)
+ set_property(TARGET asmjit PROPERTY COMPILE_OPTIONS ${ASMJIT_COMPILE_OPTIONS})
+ message(" ASMJIT COMPILE FLAGS: ${ASMJIT_COMPILE_OPTIONS}")
+
+endif(USE_FBGEMM)
+
if(USE_SENTENCEPIECE)
if(USE_STATIC_LIBS)
set(_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES ${CMAKE_FIND_LIBRARY_SUFFIXES})
@@ -16,16 +43,37 @@ if(USE_SENTENCEPIECE)
endif()
endif()
- set(SPM_ENABLE_SHARED OFF CACHE BOOL "Builds shared libaries in addition to static libraries." FORCE)
set(SPM_ENABLE_TCMALLOC ON CACHE BOOL "Enable TCMalloc if available." FORCE)
- set(SPM_TCMALLOC_STATIC ON CACHE BOOL "Link static library of TCMALLOC." FORCE)
+
+ if(USE_STATIC_LIBS)
+ message(WARNING "You are compiling SentencePiece binaries with -DUSE_STATIC_LIBS=on. \
+ This will cause spm_train to segfault. No need to worry if you do not intend to use that binary. \
+ Marian support for SentencePiece will work fine.")
+
+ set(SPM_ENABLE_SHARED OFF CACHE BOOL "Builds shared libaries in addition to static libraries." FORCE)
+ set(SPM_TCMALLOC_STATIC ON CACHE BOOL "Link static library of TCMALLOC." FORCE)
+ else(USE_STATIC_LIBS)
+ set(SPM_ENABLE_SHARED ON CACHE BOOL "Builds shared libaries in addition to static libraries." FORCE)
+ set(SPM_TCMALLOC_STATIC OFF CACHE BOOL "Link static library of TCMALLOC." FORCE)
+ endif(USE_STATIC_LIBS)
add_subdirectory(./sentencepiece)
include_directories(./sentencepiece)
set_target_properties(spm_encode spm_decode spm_train spm_normalize spm_export_vocab
- PROPERTIES
- RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
+ PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
+
+ if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+ foreach(t sentencepiece sentencepiece_train sentencepiece_train-static
+ spm_decode spm_encode spm_export_vocab spm_normalize spm_train)
+ set_property(TARGET ${t} APPEND_STRING PROPERTY COMPILE_FLAGS " -Wno-tautological-compare -Wno-unused")
+ if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
+ set_property(TARGET ${t} APPEND_STRING PROPERTY COMPILE_FLAGS " -Wno-range-loop-construct")
+ endif()
+ # get_property(SENTENCEPIECE_COMPILE_FLAGS TARGET ${t} PROPERTY COMPILE_FLAGS)
+ # message("-- SENTENCPIECE: compile flags for target ${t}: ${SENTENCEPIECE_COMPILE_FLAGS}")
+ endforeach(t)
+ endif()
if(USE_STATIC_LIBS)
set(CMAKE_FIND_LIBRARY_SUFFIXES ${_ORIG_CMAKE_FIND_LIBRARY_SUFFIXES})
@@ -36,5 +84,66 @@ include_directories(./SQLiteCpp/include)
include_directories(./CLI)
include_directories(./pathie-cpp/include)
+if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+ #set_target_properties(SQLiteCpp PROPERTIES COMPILE_FLAGS
+ set_property(TARGET SQLiteCpp APPEND_STRING PROPERTY COMPILE_FLAGS
+ " -Wno-parentheses-equality -Wno-unused-value")
+ if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0)
+ set_property(TARGET SQLiteCpp APPEND_STRING PROPERTY COMPILE_FLAGS
+ " -Wno-implicit-int-float-conversion")
+ endif()
+ set_property(TARGET libyaml-cpp APPEND_STRING PROPERTY COMPILE_FLAGS
+ " -fPIC -Wno-unused-value")
+ set_property(TARGET pathie-cpp APPEND_STRING PROPERTY COMPILE_FLAGS
+ " -fPIC -Wno-unused-value")
+endif()
+
+
+
include_directories(./zlib)
+include(ExternalProject)
+
+set(INSTALLS "") # this will contain a list of 3rd part dependencies that we install locally
+if(CUDA_FOUND)
+ if(USE_NCCL)
+
+ # disables compilation for sm_30 to avoid ptxas warning... that is general Kepler support. But K80s are supported for instance by sm_35
+
+ set(GENCODE "")
+ if(COMPILE_CUDA_SM35)
+ set(GENCODE "${GENCODE} -gencode=arch=compute_35,code=sm_35")
+ endif(COMPILE_CUDA_SM35)
+ if(COMPILE_CUDA_SM50)
+ set(GENCODE "${GENCODE} -gencode=arch=compute_50,code=sm_50")
+ endif(COMPILE_CUDA_SM50)
+ if(COMPILE_CUDA_SM60)
+ set(GENCODE "${GENCODE} -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_61,code=sm_61")
+ endif(COMPILE_CUDA_SM60)
+ if(COMPILE_CUDA_SM70)
+ set(GENCODE "${GENCODE} -gencode=arch=compute_70,code=sm_70")
+ endif(COMPILE_CUDA_SM70)
+
+ # install nccl in ${CMAKE_BINARY_DIR}/local similar to /usr/local linux installation
+ ExternalProject_Add(nccl_install
+ SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/nccl
+ BINARY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/nccl
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND
+ $(MAKE) -f ${CMAKE_CURRENT_SOURCE_DIR}/nccl/Makefile src.build
+ BUILDDIR=${CMAKE_BINARY_DIR}/local CUDA_HOME=${CUDA_TOOLKIT_ROOT_DIR}
+ CUDA8_GENCODE=${GENCODE} CXX=${CMAKE_CXX_COMPILER}
+ INSTALL_COMMAND "")
+
+ set_target_properties(nccl PROPERTIES IMPORTED_LOCATION ${CMAKE_BINARY_DIR}/local/lib/libnccl_static.a)
+ add_dependencies(nccl nccl_install)
+ set(INSTALLS ${INSTALLS} nccl_install)
+
+ endif(USE_NCCL)
+endif(CUDA_FOUND)
+
+# @TODO: do the same for SentencePiece, Protobuf etc.
+# make clean will clean "${CMAKE_BINARY_DIR}/local"
+set_directory_properties(PROPERTY ADDITIONAL_MAKE_CLEAN_FILES ${CMAKE_BINARY_DIR}/local)
+
+add_custom_target(3rd_party_installs DEPENDS ${INSTALLS})
diff --git a/src/3rd_party/ExceptionWithCallStack.cpp b/src/3rd_party/ExceptionWithCallStack.cpp
index 3ce38a2a..3ce38a2a 100755..100644
--- a/src/3rd_party/ExceptionWithCallStack.cpp
+++ b/src/3rd_party/ExceptionWithCallStack.cpp
diff --git a/src/3rd_party/ExceptionWithCallStack.h b/src/3rd_party/ExceptionWithCallStack.h
index 5b961bd9..488b1277 100644
--- a/src/3rd_party/ExceptionWithCallStack.h
+++ b/src/3rd_party/ExceptionWithCallStack.h
@@ -5,6 +5,8 @@
// ExceptionWithCallStack.h - debug util functions
//
+#pragma once
+
#include <string>
namespace Microsoft { namespace MSR { namespace CNTK {
diff --git a/src/3rd_party/any_type.h b/src/3rd_party/any_type.h
index b397e053..b397e053 100755..100644
--- a/src/3rd_party/any_type.h
+++ b/src/3rd_party/any_type.h
diff --git a/src/3rd_party/avx_mathfun.h b/src/3rd_party/avx_mathfun.h
new file mode 100644
index 00000000..6840478c
--- /dev/null
+++ b/src/3rd_party/avx_mathfun.h
@@ -0,0 +1,726 @@
+/*
+ AVX implementation of sin, cos, sincos, exp and log
+
+ Based on "sse_mathfun.h", by Julien Pommier
+ http://gruntthepeon.free.fr/ssemath/
+
+ Copyright (C) 2012 Giovanni Garberoglio
+ Interdisciplinary Laboratory for Computational Science (LISC)
+ Fondazione Bruno Kessler and University of Trento
+ via Sommarive, 18
+ I-38123 Trento (Italy)
+
+ This software is provided 'as-is', without any express or implied
+ warranty. In no event will the authors be held liable for any damages
+ arising from the use of this software.
+
+ Permission is granted to anyone to use this software for any purpose,
+ including commercial applications, and to alter it and redistribute it
+ freely, subject to the following restrictions:
+
+ 1. The origin of this software must not be misrepresented; you must not
+ claim that you wrote the original software. If you use this software
+ in a product, an acknowledgment in the product documentation would be
+ appreciated but is not required.
+ 2. Altered source versions must be plainly marked as such, and must not be
+ misrepresented as being the original software.
+ 3. This notice may not be removed or altered from any source distribution.
+
+ (this is the zlib license)
+*/
+
+#include <immintrin.h>
+
+/* yes I know, the top of this file is quite ugly */
+#ifdef _MSC_VER
+# define ALIGN32_BEG __declspec(align(32))
+# define ALIGN32_END
+#else /* gcc or icc */
+# define ALIGN32_BEG
+# define ALIGN32_END __attribute__((aligned(32)))
+#endif
+
+/* __m128 is ugly to write */
+typedef __m256 v8sf; // vector of 8 float (avx)
+typedef __m256i v8si; // vector of 8 int (avx)
+typedef __m128i v4si; // vector of 8 int (avx)
+
+#define _PI32AVX_CONST(Name, Val) \
+ static const ALIGN32_BEG int _pi32avx_##Name[4] ALIGN32_END = { Val, Val, Val, Val }
+
+_PI32AVX_CONST(1, 1);
+_PI32AVX_CONST(inv1, ~1);
+_PI32AVX_CONST(2, 2);
+_PI32AVX_CONST(4, 4);
+
+
+/* declare some AVX constants -- why can't I figure a better way to do that? */
+#define _PS256_CONST(Name, Val) \
+ static const ALIGN32_BEG float _ps256_##Name[8] ALIGN32_END = { (float)Val, (float)Val, (float)Val, (float)Val, (float)Val, (float)Val, (float)Val, (float)Val }
+#define _PI32_CONST256(Name, Val) \
+ static const ALIGN32_BEG int _pi32_256_##Name[8] ALIGN32_END = { Val, Val, Val, Val, Val, Val, Val, Val }
+#define _PS256_CONST_TYPE(Name, Type, Val) \
+ static const ALIGN32_BEG Type _ps256_##Name[8] ALIGN32_END = { Val, Val, Val, Val, Val, Val, Val, Val }
+
+_PS256_CONST(1 , 1.0f);
+_PS256_CONST(0p5, 0.5f);
+/* the smallest non denormalized float number */
+_PS256_CONST_TYPE(min_norm_pos, int, 0x00800000);
+_PS256_CONST_TYPE(mant_mask, int, 0x7f800000);
+_PS256_CONST_TYPE(inv_mant_mask, int, ~0x7f800000);
+
+_PS256_CONST_TYPE(sign_mask, int, (int)0x80000000);
+_PS256_CONST_TYPE(inv_sign_mask, int, ~0x80000000);
+
+_PI32_CONST256(0, 0);
+_PI32_CONST256(1, 1);
+_PI32_CONST256(inv1, ~1);
+_PI32_CONST256(2, 2);
+_PI32_CONST256(4, 4);
+_PI32_CONST256(0x7f, 0x7f);
+
+_PS256_CONST(cephes_SQRTHF, 0.707106781186547524);
+_PS256_CONST(cephes_log_p0, 7.0376836292E-2);
+_PS256_CONST(cephes_log_p1, - 1.1514610310E-1);
+_PS256_CONST(cephes_log_p2, 1.1676998740E-1);
+_PS256_CONST(cephes_log_p3, - 1.2420140846E-1);
+_PS256_CONST(cephes_log_p4, + 1.4249322787E-1);
+_PS256_CONST(cephes_log_p5, - 1.6668057665E-1);
+_PS256_CONST(cephes_log_p6, + 2.0000714765E-1);
+_PS256_CONST(cephes_log_p7, - 2.4999993993E-1);
+_PS256_CONST(cephes_log_p8, + 3.3333331174E-1);
+_PS256_CONST(cephes_log_q1, -2.12194440e-4);
+_PS256_CONST(cephes_log_q2, 0.693359375);
+
+#ifndef __AVX2__
+
+typedef union imm_xmm_union {
+ v8si imm;
+ v4si xmm[2];
+} imm_xmm_union;
+
+#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) { \
+ ALIGN32_BEG imm_xmm_union u ALIGN32_END; \
+ u.imm = imm_; \
+ xmm0_ = u.xmm[0]; \
+ xmm1_ = u.xmm[1]; \
+}
+
+#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) { \
+ ALIGN32_BEG imm_xmm_union u ALIGN32_END; \
+ u.xmm[0]=xmm0_; u.xmm[1]=xmm1_; imm_ = u.imm; \
+}
+
+
+#define AVX2_BITOP_USING_SSE2(fn) \
+static inline v8si avx2_mm256_##fn(v8si x, int a) \
+{ \
+ /* use SSE2 instruction to perform the bitop AVX2 */ \
+ v4si x1, x2; \
+ v8si ret; \
+ COPY_IMM_TO_XMM(x, x1, x2); \
+ x1 = _mm_##fn(x1,a); \
+ x2 = _mm_##fn(x2,a); \
+ COPY_XMM_TO_IMM(x1, x2, ret); \
+ return(ret); \
+}
+
+//#warning "Using SSE2 to perform AVX2 bitshift ops"
+AVX2_BITOP_USING_SSE2(slli_epi32)
+AVX2_BITOP_USING_SSE2(srli_epi32)
+
+#define AVX2_INTOP_USING_SSE2(fn) \
+static inline v8si avx2_mm256_##fn(v8si x, v8si y) \
+{ \
+ /* use SSE2 instructions to perform the AVX2 integer operation */ \
+ v4si x1, x2; \
+ v4si y1, y2; \
+ v8si ret; \
+ COPY_IMM_TO_XMM(x, x1, x2); \
+ COPY_IMM_TO_XMM(y, y1, y2); \
+ x1 = _mm_##fn(x1,y1); \
+ x2 = _mm_##fn(x2,y2); \
+ COPY_XMM_TO_IMM(x1, x2, ret); \
+ return(ret); \
+}
+
+//#warning "Using SSE2 to perform AVX2 integer ops"
+AVX2_INTOP_USING_SSE2(and_si128)
+AVX2_INTOP_USING_SSE2(andnot_si128)
+AVX2_INTOP_USING_SSE2(cmpeq_epi32)
+AVX2_INTOP_USING_SSE2(sub_epi32)
+AVX2_INTOP_USING_SSE2(add_epi32)
+#define avx2_mm256_and_si256 avx2_mm256_and_si128
+#define avx2_mm256_andnot_si256 avx2_mm256_andnot_si128
+#else
+#define avx2_mm256_slli_epi32 _mm256_slli_epi32
+#define avx2_mm256_srli_epi32 _mm256_srli_epi32
+#define avx2_mm256_and_si256 _mm256_and_si256
+#define avx2_mm256_andnot_si256 _mm256_andnot_si256
+#define avx2_mm256_cmpeq_epi32 _mm256_cmpeq_epi32
+#define avx2_mm256_sub_epi32 _mm256_sub_epi32
+#define avx2_mm256_add_epi32 _mm256_add_epi32
+#endif /* __AVX2__ */
+
+
+/* natural logarithm computed for 8 simultaneous float
+ return NaN for x <= 0
+*/
+static inline v8sf log256_ps(v8sf x) {
+ v8si imm0;
+ v8sf one = *(v8sf*)_ps256_1;
+
+ //v8sf invalid_mask = _mm256_cmple_ps(x, _mm256_setzero_ps());
+ v8sf invalid_mask = _mm256_cmp_ps(x, _mm256_setzero_ps(), _CMP_LE_OS);
+
+ x = _mm256_max_ps(x, *(v8sf*)_ps256_min_norm_pos); /* cut off denormalized stuff */
+
+ // can be done with AVX2
+ imm0 = avx2_mm256_srli_epi32(_mm256_castps_si256(x), 23);
+
+ /* keep only the fractional part */
+ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_mant_mask);
+ x = _mm256_or_ps(x, *(v8sf*)_ps256_0p5);
+
+ // this is again another AVX2 instruction
+ imm0 = avx2_mm256_sub_epi32(imm0, *(v8si*)_pi32_256_0x7f);
+ v8sf e = _mm256_cvtepi32_ps(imm0);
+
+ e = _mm256_add_ps(e, one);
+
+ /* part2:
+ if( x < SQRTHF ) {
+ e -= 1;
+ x = x + x - 1.0;
+ } else { x = x - 1.0; }
+ */
+ //v8sf mask = _mm256_cmplt_ps(x, *(v8sf*)_ps256_cephes_SQRTHF);
+ v8sf mask = _mm256_cmp_ps(x, *(v8sf*)_ps256_cephes_SQRTHF, _CMP_LT_OS);
+ v8sf tmp = _mm256_and_ps(x, mask);
+ x = _mm256_sub_ps(x, one);
+ e = _mm256_sub_ps(e, _mm256_and_ps(one, mask));
+ x = _mm256_add_ps(x, tmp);
+
+ v8sf z = _mm256_mul_ps(x,x);
+
+ v8sf y = *(v8sf*)_ps256_cephes_log_p0;
+ y = _mm256_mul_ps(y, x);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p1);
+ y = _mm256_mul_ps(y, x);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p2);
+ y = _mm256_mul_ps(y, x);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p3);
+ y = _mm256_mul_ps(y, x);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p4);
+ y = _mm256_mul_ps(y, x);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p5);
+ y = _mm256_mul_ps(y, x);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p6);
+ y = _mm256_mul_ps(y, x);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p7);
+ y = _mm256_mul_ps(y, x);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_log_p8);
+ y = _mm256_mul_ps(y, x);
+
+ y = _mm256_mul_ps(y, z);
+
+ tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q1);
+ y = _mm256_add_ps(y, tmp);
+
+
+ tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
+ y = _mm256_sub_ps(y, tmp);
+
+ tmp = _mm256_mul_ps(e, *(v8sf*)_ps256_cephes_log_q2);
+ x = _mm256_add_ps(x, y);
+ x = _mm256_add_ps(x, tmp);
+ x = _mm256_or_ps(x, invalid_mask); // negative arg will be NAN
+ return x;
+}
+
+_PS256_CONST(exp_hi, 88.3762626647949f);
+_PS256_CONST(exp_lo, -88.3762626647949f);
+
+_PS256_CONST(cephes_LOG2EF, 1.44269504088896341);
+_PS256_CONST(cephes_exp_C1, 0.693359375);
+_PS256_CONST(cephes_exp_C2, -2.12194440e-4);
+
+_PS256_CONST(cephes_exp_p0, 1.9875691500E-4);
+_PS256_CONST(cephes_exp_p1, 1.3981999507E-3);
+_PS256_CONST(cephes_exp_p2, 8.3334519073E-3);
+_PS256_CONST(cephes_exp_p3, 4.1665795894E-2);
+_PS256_CONST(cephes_exp_p4, 1.6666665459E-1);
+_PS256_CONST(cephes_exp_p5, 5.0000001201E-1);
+
+static inline v8sf exp256_ps(v8sf x) {
+ v8sf tmp = _mm256_setzero_ps(), fx;
+ v8si imm0;
+ v8sf one = *(v8sf*)_ps256_1;
+
+ x = _mm256_min_ps(x, *(v8sf*)_ps256_exp_hi);
+ x = _mm256_max_ps(x, *(v8sf*)_ps256_exp_lo);
+
+ /* express exp(x) as exp(g + n*log(2)) */
+ fx = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_LOG2EF);
+ fx = _mm256_add_ps(fx, *(v8sf*)_ps256_0p5);
+
+ /* how to perform a floorf with SSE: just below */
+ //imm0 = _mm256_cvttps_epi32(fx);
+ //tmp = _mm256_cvtepi32_ps(imm0);
+
+ tmp = _mm256_floor_ps(fx);
+
+ /* if greater, substract 1 */
+ //v8sf mask = _mm256_cmpgt_ps(tmp, fx);
+ v8sf mask = _mm256_cmp_ps(tmp, fx, _CMP_GT_OS);
+ mask = _mm256_and_ps(mask, one);
+ fx = _mm256_sub_ps(tmp, mask);
+
+ tmp = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C1);
+ v8sf z = _mm256_mul_ps(fx, *(v8sf*)_ps256_cephes_exp_C2);
+ x = _mm256_sub_ps(x, tmp);
+ x = _mm256_sub_ps(x, z);
+
+ z = _mm256_mul_ps(x,x);
+
+ v8sf y = *(v8sf*)_ps256_cephes_exp_p0;
+ y = _mm256_mul_ps(y, x);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p1);
+ y = _mm256_mul_ps(y, x);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p2);
+ y = _mm256_mul_ps(y, x);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p3);
+ y = _mm256_mul_ps(y, x);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p4);
+ y = _mm256_mul_ps(y, x);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_cephes_exp_p5);
+ y = _mm256_mul_ps(y, z);
+ y = _mm256_add_ps(y, x);
+ y = _mm256_add_ps(y, one);
+
+ /* build 2^n */
+ imm0 = _mm256_cvttps_epi32(fx);
+ // another two AVX2 instructions
+ imm0 = avx2_mm256_add_epi32(imm0, *(v8si*)_pi32_256_0x7f);
+ imm0 = avx2_mm256_slli_epi32(imm0, 23);
+ v8sf pow2n = _mm256_castsi256_ps(imm0);
+ y = _mm256_mul_ps(y, pow2n);
+ return y;
+}
+
+_PS256_CONST(minus_cephes_DP1, -0.78515625);
+_PS256_CONST(minus_cephes_DP2, -2.4187564849853515625e-4);
+_PS256_CONST(minus_cephes_DP3, -3.77489497744594108e-8);
+_PS256_CONST(sincof_p0, -1.9515295891E-4);
+_PS256_CONST(sincof_p1, 8.3321608736E-3);
+_PS256_CONST(sincof_p2, -1.6666654611E-1);
+_PS256_CONST(coscof_p0, 2.443315711809948E-005);
+_PS256_CONST(coscof_p1, -1.388731625493765E-003);
+_PS256_CONST(coscof_p2, 4.166664568298827E-002);
+_PS256_CONST(cephes_FOPI, 1.27323954473516); // 4 / M_PI
+
+
+/* evaluation of 8 sines at onces using AVX intrisics
+
+ The code is the exact rewriting of the cephes sinf function.
+ Precision is excellent as long as x < 8192 (I did not bother to
+ take into account the special handling they have for greater values
+ -- it does not return garbage for arguments over 8192, though, but
+ the extra precision is missing).
+
+ Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
+ surprising but correct result.
+
+*/
+static inline v8sf sin256_ps(v8sf x) { // any x
+ v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y;
+ v8si imm0, imm2;
+
+#ifndef __AVX2__
+ v4si imm0_1, imm0_2;
+ v4si imm2_1, imm2_2;
+#endif
+
+ sign_bit = x;
+ /* take the absolute value */
+ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
+ /* extract the sign bit (upper one) */
+ sign_bit = _mm256_and_ps(sign_bit, *(v8sf*)_ps256_sign_mask);
+
+ /* scale by 4/Pi */
+ y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
+
+ /*
+ Here we start a series of integer operations, which are in the
+ realm of AVX2.
+ If we don't have AVX, let's perform them using SSE2 directives
+ */
+
+#ifdef __AVX2__
+ /* store the integer part of y in mm0 */
+ imm2 = _mm256_cvttps_epi32(y);
+ /* j=(j+1) & (~1) (see the cephes sources) */
+ // another two AVX2 instruction
+ imm2 = avx2_mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
+ imm2 = avx2_mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
+ y = _mm256_cvtepi32_ps(imm2);
+
+ /* get the swap sign flag */
+ imm0 = avx2_mm256_and_si256(imm2, *(v8si*)_pi32_256_4);
+ imm0 = avx2_mm256_slli_epi32(imm0, 29);
+ /* get the polynom selection mask
+ there is one polynom for 0 <= x <= Pi/4
+ and another one for Pi/4<x<=Pi/2
+
+ Both branches will be computed.
+ */
+ imm2 = avx2_mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
+ imm2 = avx2_mm256_cmpeq_epi32(imm2,*(v8si*)_pi32_256_0);
+#else
+ /* we use SSE2 routines to perform the integer ops */
+ COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2);
+
+ imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1);
+ imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1);
+
+ imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1);
+ imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1);
+
+ COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2);
+ y = _mm256_cvtepi32_ps(imm2);
+
+ imm0_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_4);
+ imm0_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_4);
+
+ imm0_1 = _mm_slli_epi32(imm0_1, 29);
+ imm0_2 = _mm_slli_epi32(imm0_2, 29);
+
+ COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);
+
+ imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2);
+ imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2);
+
+ imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
+ imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());
+
+ COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
+#endif
+
+ v8sf swap_sign_bit = _mm256_castsi256_ps(imm0);
+ v8sf poly_mask = _mm256_castsi256_ps(imm2);
+ sign_bit = _mm256_xor_ps(sign_bit, swap_sign_bit);
+
+ /* The magic pass: "Extended precision modular arithmetic"
+ x = ((x - y * DP1) - y * DP2) - y * DP3; */
+ xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
+ xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
+ xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
+ xmm1 = _mm256_mul_ps(y, xmm1);
+ xmm2 = _mm256_mul_ps(y, xmm2);
+ xmm3 = _mm256_mul_ps(y, xmm3);
+ x = _mm256_add_ps(x, xmm1);
+ x = _mm256_add_ps(x, xmm2);
+ x = _mm256_add_ps(x, xmm3);
+
+ /* Evaluate the first polynom (0 <= x <= Pi/4) */
+ y = *(v8sf*)_ps256_coscof_p0;
+ v8sf z = _mm256_mul_ps(x,x);
+
+ y = _mm256_mul_ps(y, z);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
+ y = _mm256_mul_ps(y, z);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
+ y = _mm256_mul_ps(y, z);
+ y = _mm256_mul_ps(y, z);
+ v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
+ y = _mm256_sub_ps(y, tmp);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_1);
+
+ /* Evaluate the second polynom (Pi/4 <= x <= 0) */
+
+ v8sf y2 = *(v8sf*)_ps256_sincof_p0;
+ y2 = _mm256_mul_ps(y2, z);
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
+ y2 = _mm256_mul_ps(y2, z);
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
+ y2 = _mm256_mul_ps(y2, z);
+ y2 = _mm256_mul_ps(y2, x);
+ y2 = _mm256_add_ps(y2, x);
+
+ /* select the correct result from the two polynoms */
+ xmm3 = poly_mask;
+ y2 = _mm256_and_ps(xmm3, y2); //, xmm3);
+ y = _mm256_andnot_ps(xmm3, y);
+ y = _mm256_add_ps(y,y2);
+ /* update the sign */
+ y = _mm256_xor_ps(y, sign_bit);
+
+ return y;
+}
+
+/* almost the same as sin_ps */
+static inline v8sf cos256_ps(v8sf x) { // any x
+ v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, y;
+ v8si imm0, imm2;
+
+#ifndef __AVX2__
+ v4si imm0_1, imm0_2;
+ v4si imm2_1, imm2_2;
+#endif
+
+ /* take the absolute value */
+ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
+
+ /* scale by 4/Pi */
+ y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
+
+#ifdef __AVX2__
+ /* store the integer part of y in mm0 */
+ imm2 = _mm256_cvttps_epi32(y);
+ /* j=(j+1) & (~1) (see the cephes sources) */
+ imm2 = avx2_mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
+ imm2 = avx2_mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
+ y = _mm256_cvtepi32_ps(imm2);
+ imm2 = avx2_mm256_sub_epi32(imm2, *(v8si*)_pi32_256_2);
+
+ /* get the swap sign flag */
+ imm0 = avx2_mm256_andnot_si256(imm2, *(v8si*)_pi32_256_4);
+ imm0 = avx2_mm256_slli_epi32(imm0, 29);
+ /* get the polynom selection mask */
+ imm2 = avx2_mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
+ imm2 = avx2_mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
+#else
+
+ /* we use SSE2 routines to perform the integer ops */
+ COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2);
+
+ imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1);
+ imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1);
+
+ imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1);
+ imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1);
+
+ COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2);
+ y = _mm256_cvtepi32_ps(imm2);
+
+ imm2_1 = _mm_sub_epi32(imm2_1, *(v4si*)_pi32avx_2);
+ imm2_2 = _mm_sub_epi32(imm2_2, *(v4si*)_pi32avx_2);
+
+ imm0_1 = _mm_andnot_si128(imm2_1, *(v4si*)_pi32avx_4);
+ imm0_2 = _mm_andnot_si128(imm2_2, *(v4si*)_pi32avx_4);
+
+ imm0_1 = _mm_slli_epi32(imm0_1, 29);
+ imm0_2 = _mm_slli_epi32(imm0_2, 29);
+
+ COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);
+
+ imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2);
+ imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2);
+
+ imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
+ imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());
+
+ COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
+#endif
+
+ v8sf sign_bit = _mm256_castsi256_ps(imm0);
+ v8sf poly_mask = _mm256_castsi256_ps(imm2);
+
+ /* The magic pass: "Extended precision modular arithmetic"
+ x = ((x - y * DP1) - y * DP2) - y * DP3; */
+ xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
+ xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
+ xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
+ xmm1 = _mm256_mul_ps(y, xmm1);
+ xmm2 = _mm256_mul_ps(y, xmm2);
+ xmm3 = _mm256_mul_ps(y, xmm3);
+ x = _mm256_add_ps(x, xmm1);
+ x = _mm256_add_ps(x, xmm2);
+ x = _mm256_add_ps(x, xmm3);
+
+ /* Evaluate the first polynom (0 <= x <= Pi/4) */
+ y = *(v8sf*)_ps256_coscof_p0;
+ v8sf z = _mm256_mul_ps(x,x);
+
+ y = _mm256_mul_ps(y, z);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
+ y = _mm256_mul_ps(y, z);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
+ y = _mm256_mul_ps(y, z);
+ y = _mm256_mul_ps(y, z);
+ v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
+ y = _mm256_sub_ps(y, tmp);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_1);
+
+ /* Evaluate the second polynom (Pi/4 <= x <= 0) */
+
+ v8sf y2 = *(v8sf*)_ps256_sincof_p0;
+ y2 = _mm256_mul_ps(y2, z);
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
+ y2 = _mm256_mul_ps(y2, z);
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
+ y2 = _mm256_mul_ps(y2, z);
+ y2 = _mm256_mul_ps(y2, x);
+ y2 = _mm256_add_ps(y2, x);
+
+ /* select the correct result from the two polynoms */
+ xmm3 = poly_mask;
+ y2 = _mm256_and_ps(xmm3, y2); //, xmm3);
+ y = _mm256_andnot_ps(xmm3, y);
+ y = _mm256_add_ps(y,y2);
+ /* update the sign */
+ y = _mm256_xor_ps(y, sign_bit);
+
+ return y;
+}
+
+/* since sin256_ps and cos256_ps are almost identical, sincos256_ps could replace both of them..
+ it is almost as fast, and gives you a free cosine with your sine */
+static inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) {
+
+ v8sf xmm1, xmm2, xmm3 = _mm256_setzero_ps(), sign_bit_sin, y;
+ v8si imm0, imm2, imm4;
+
+#ifndef __AVX2__
+ v4si imm0_1, imm0_2;
+ v4si imm2_1, imm2_2;
+ v4si imm4_1, imm4_2;
+#endif
+
+ sign_bit_sin = x;
+ /* take the absolute value */
+ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask);
+ /* extract the sign bit (upper one) */
+ sign_bit_sin = _mm256_and_ps(sign_bit_sin, *(v8sf*)_ps256_sign_mask);
+
+ /* scale by 4/Pi */
+ y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI);
+
+#ifdef __AVX2__
+ /* store the integer part of y in imm2 */
+ imm2 = _mm256_cvttps_epi32(y);
+
+ /* j=(j+1) & (~1) (see the cephes sources) */
+ imm2 = avx2_mm256_add_epi32(imm2, *(v8si*)_pi32_256_1);
+ imm2 = avx2_mm256_and_si256(imm2, *(v8si*)_pi32_256_inv1);
+
+ y = _mm256_cvtepi32_ps(imm2);
+ imm4 = imm2;
+
+ /* get the swap sign flag for the sine */
+ imm0 = avx2_mm256_and_si256(imm2, *(v8si*)_pi32_256_4);
+ imm0 = avx2_mm256_slli_epi32(imm0, 29);
+ //v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
+
+ /* get the polynom selection mask for the sine*/
+ imm2 = avx2_mm256_and_si256(imm2, *(v8si*)_pi32_256_2);
+ imm2 = avx2_mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0);
+ //v8sf poly_mask = _mm256_castsi256_ps(imm2);
+#else
+ /* we use SSE2 routines to perform the integer ops */
+ COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2);
+
+ imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1);
+ imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1);
+
+ imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1);
+ imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1);
+
+ COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2);
+ y = _mm256_cvtepi32_ps(imm2);
+
+ imm4_1 = imm2_1;
+ imm4_2 = imm2_2;
+
+ imm0_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_4);
+ imm0_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_4);
+
+ imm0_1 = _mm_slli_epi32(imm0_1, 29);
+ imm0_2 = _mm_slli_epi32(imm0_2, 29);
+
+ COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0);
+
+ imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2);
+ imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2);
+
+ imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128());
+ imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128());
+
+ COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2);
+#endif
+ v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0);
+ v8sf poly_mask = _mm256_castsi256_ps(imm2);
+
+ /* The magic pass: "Extended precision modular arithmetic"
+ x = ((x - y * DP1) - y * DP2) - y * DP3; */
+ xmm1 = *(v8sf*)_ps256_minus_cephes_DP1;
+ xmm2 = *(v8sf*)_ps256_minus_cephes_DP2;
+ xmm3 = *(v8sf*)_ps256_minus_cephes_DP3;
+ xmm1 = _mm256_mul_ps(y, xmm1);
+ xmm2 = _mm256_mul_ps(y, xmm2);
+ xmm3 = _mm256_mul_ps(y, xmm3);
+ x = _mm256_add_ps(x, xmm1);
+ x = _mm256_add_ps(x, xmm2);
+ x = _mm256_add_ps(x, xmm3);
+
+#ifdef __AVX2__
+ imm4 = avx2_mm256_sub_epi32(imm4, *(v8si*)_pi32_256_2);
+ imm4 = avx2_mm256_andnot_si256(imm4, *(v8si*)_pi32_256_4);
+ imm4 = avx2_mm256_slli_epi32(imm4, 29);
+#else
+ imm4_1 = _mm_sub_epi32(imm4_1, *(v4si*)_pi32avx_2);
+ imm4_2 = _mm_sub_epi32(imm4_2, *(v4si*)_pi32avx_2);
+
+ imm4_1 = _mm_andnot_si128(imm4_1, *(v4si*)_pi32avx_4);
+ imm4_2 = _mm_andnot_si128(imm4_2, *(v4si*)_pi32avx_4);
+
+ imm4_1 = _mm_slli_epi32(imm4_1, 29);
+ imm4_2 = _mm_slli_epi32(imm4_2, 29);
+
+ COPY_XMM_TO_IMM(imm4_1, imm4_2, imm4);
+#endif
+
+ v8sf sign_bit_cos = _mm256_castsi256_ps(imm4);
+
+ sign_bit_sin = _mm256_xor_ps(sign_bit_sin, swap_sign_bit_sin);
+
+ /* Evaluate the first polynom (0 <= x <= Pi/4) */
+ v8sf z = _mm256_mul_ps(x,x);
+ y = *(v8sf*)_ps256_coscof_p0;
+
+ y = _mm256_mul_ps(y, z);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p1);
+ y = _mm256_mul_ps(y, z);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_coscof_p2);
+ y = _mm256_mul_ps(y, z);
+ y = _mm256_mul_ps(y, z);
+ v8sf tmp = _mm256_mul_ps(z, *(v8sf*)_ps256_0p5);
+ y = _mm256_sub_ps(y, tmp);
+ y = _mm256_add_ps(y, *(v8sf*)_ps256_1);
+
+ /* Evaluate the second polynom (Pi/4 <= x <= 0) */
+
+ v8sf y2 = *(v8sf*)_ps256_sincof_p0;
+ y2 = _mm256_mul_ps(y2, z);
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p1);
+ y2 = _mm256_mul_ps(y2, z);
+ y2 = _mm256_add_ps(y2, *(v8sf*)_ps256_sincof_p2);
+ y2 = _mm256_mul_ps(y2, z);
+ y2 = _mm256_mul_ps(y2, x);
+ y2 = _mm256_add_ps(y2, x);
+
+ /* select the correct result from the two polynoms */
+ xmm3 = poly_mask;
+ v8sf ysin2 = _mm256_and_ps(xmm3, y2);
+ v8sf ysin1 = _mm256_andnot_ps(xmm3, y);
+ y2 = _mm256_sub_ps(y2,ysin2);
+ y = _mm256_sub_ps(y, ysin1);
+
+ xmm1 = _mm256_add_ps(ysin1,ysin2);
+ xmm2 = _mm256_add_ps(y,y2);
+
+ /* update the sign */
+ *s = _mm256_xor_ps(xmm1, sign_bit_sin);
+ *c = _mm256_xor_ps(xmm2, sign_bit_cos);
+}
+
diff --git a/src/3rd_party/catch.hpp b/src/3rd_party/catch.hpp
index 621f6e66..5d104bc4 100755..100644
--- a/src/3rd_party/catch.hpp
+++ b/src/3rd_party/catch.hpp
@@ -1,17 +1,21 @@
/*
- * Catch v1.9.4
- * Generated: 2017-05-16 13:51:55.506519
+ * Catch v2.10.1
+ * Generated: 2019-10-20 20:52:21.372334
* ----------------------------------------------------------
* This file has been merged from multiple headers. Please don't edit it directly
- * Copyright (c) 2012 Two Blue Cubes Ltd. All rights reserved.
+ * Copyright (c) 2019 Two Blue Cubes Ltd. All rights reserved.
*
* Distributed under the Boost Software License, Version 1.0. (See accompanying
* file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
*/
#ifndef TWOBLUECUBES_SINGLE_INCLUDE_CATCH_HPP_INCLUDED
#define TWOBLUECUBES_SINGLE_INCLUDE_CATCH_HPP_INCLUDED
+// start catch.hpp
-#define TWOBLUECUBES_CATCH_HPP_INCLUDED
+
+#define CATCH_VERSION_MAJOR 2
+#define CATCH_VERSION_MINOR 10
+#define CATCH_VERSION_PATCH 1
#ifdef __clang__
# pragma clang system_header
@@ -19,36 +23,66 @@
# pragma GCC system_header
#endif
-// #included from: internal/catch_suppress_warnings.h
+// start catch_suppress_warnings.h
#ifdef __clang__
# ifdef __ICC // icpc defines the __clang__ macro
# pragma warning(push)
# pragma warning(disable: 161 1682)
# else // __ICC
-# pragma clang diagnostic ignored "-Wglobal-constructors"
-# pragma clang diagnostic ignored "-Wvariadic-macros"
-# pragma clang diagnostic ignored "-Wc99-extensions"
-# pragma clang diagnostic ignored "-Wunused-variable"
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wpadded"
-# pragma clang diagnostic ignored "-Wc++98-compat"
-# pragma clang diagnostic ignored "-Wc++98-compat-pedantic"
# pragma clang diagnostic ignored "-Wswitch-enum"
# pragma clang diagnostic ignored "-Wcovered-switch-default"
# endif
#elif defined __GNUC__
-# pragma GCC diagnostic ignored "-Wvariadic-macros"
-# pragma GCC diagnostic ignored "-Wunused-variable"
-# pragma GCC diagnostic ignored "-Wparentheses"
+ // Because REQUIREs trigger GCC's -Wparentheses, and because still
+ // supported version of g++ have only buggy support for _Pragmas,
+ // Wparentheses have to be suppressed globally.
+# pragma GCC diagnostic ignored "-Wparentheses" // See #674 for details
# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wunused-variable"
# pragma GCC diagnostic ignored "-Wpadded"
#endif
+// end catch_suppress_warnings.h
#if defined(CATCH_CONFIG_MAIN) || defined(CATCH_CONFIG_RUNNER)
# define CATCH_IMPL
+# define CATCH_CONFIG_ALL_PARTS
+#endif
+
+// In the impl file, we want to have access to all parts of the headers
+// Can also be used to sanely support PCHs
+#if defined(CATCH_CONFIG_ALL_PARTS)
+# define CATCH_CONFIG_EXTERNAL_INTERFACES
+# if defined(CATCH_CONFIG_DISABLE_MATCHERS)
+# undef CATCH_CONFIG_DISABLE_MATCHERS
+# endif
+# if !defined(CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER)
+# define CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER
+# endif
+#endif
+
+#if !defined(CATCH_CONFIG_IMPL_ONLY)
+// start catch_platform.h
+
+#ifdef __APPLE__
+# include <TargetConditionals.h>
+# if TARGET_OS_OSX == 1
+# define CATCH_PLATFORM_MAC
+# elif TARGET_OS_IPHONE == 1
+# define CATCH_PLATFORM_IPHONE
+# endif
+
+#elif defined(linux) || defined(__linux) || defined(__linux__)
+# define CATCH_PLATFORM_LINUX
+
+#elif defined(WIN32) || defined(__WIN32__) || defined(_WIN32) || defined(_MSC_VER) || defined(__MINGW32__)
+# define CATCH_PLATFORM_WINDOWS
#endif
+// end catch_platform.h
+
#ifdef CATCH_IMPL
# ifndef CLARA_CONFIG_MAIN
# define CLARA_CONFIG_MAIN_NOT_DEFINED
@@ -56,74 +90,59 @@
# endif
#endif
-// #included from: internal/catch_notimplemented_exception.h
-#define TWOBLUECUBES_CATCH_NOTIMPLEMENTED_EXCEPTION_H_INCLUDED
+// start catch_user_interfaces.h
+
+namespace Catch {
+ unsigned int rngSeed();
+}
-// #included from: catch_common.h
-#define TWOBLUECUBES_CATCH_COMMON_H_INCLUDED
+// end catch_user_interfaces.h
+// start catch_tag_alias_autoregistrar.h
-// #included from: catch_compiler_capabilities.h
-#define TWOBLUECUBES_CATCH_COMPILER_CAPABILITIES_HPP_INCLUDED
+// start catch_common.h
-// Detect a number of compiler features - mostly C++11/14 conformance - by compiler
+// start catch_compiler_capabilities.h
+
+// Detect a number of compiler features - by compiler
// The following features are defined:
//
-// CATCH_CONFIG_CPP11_NULLPTR : is nullptr supported?
-// CATCH_CONFIG_CPP11_NOEXCEPT : is noexcept supported?
-// CATCH_CONFIG_CPP11_GENERATED_METHODS : The delete and default keywords for compiler generated methods
-// CATCH_CONFIG_CPP11_IS_ENUM : std::is_enum is supported?
-// CATCH_CONFIG_CPP11_TUPLE : std::tuple is supported
-// CATCH_CONFIG_CPP11_LONG_LONG : is long long supported?
-// CATCH_CONFIG_CPP11_OVERRIDE : is override supported?
-// CATCH_CONFIG_CPP11_UNIQUE_PTR : is unique_ptr supported (otherwise use auto_ptr)
-// CATCH_CONFIG_CPP11_SHUFFLE : is std::shuffle supported?
-// CATCH_CONFIG_CPP11_TYPE_TRAITS : are type_traits and enable_if supported?
-
-// CATCH_CONFIG_CPP11_OR_GREATER : Is C++11 supported?
-
-// CATCH_CONFIG_VARIADIC_MACROS : are variadic macros supported?
// CATCH_CONFIG_COUNTER : is the __COUNTER__ macro supported?
// CATCH_CONFIG_WINDOWS_SEH : is Windows SEH supported?
// CATCH_CONFIG_POSIX_SIGNALS : are POSIX signals supported?
+// CATCH_CONFIG_DISABLE_EXCEPTIONS : Are exceptions enabled?
// ****************
// Note to maintainers: if new toggles are added please document them
// in configuration.md, too
// ****************
// In general each macro has a _NO_<feature name> form
-// (e.g. CATCH_CONFIG_CPP11_NO_NULLPTR) which disables the feature.
+// (e.g. CATCH_CONFIG_NO_POSIX_SIGNALS) which disables the feature.
// Many features, at point of detection, define an _INTERNAL_ macro, so they
// can be combined, en-mass, with the _NO_ forms later.
-// All the C++11 features can be disabled with CATCH_CONFIG_NO_CPP11
-
#ifdef __cplusplus
-# if __cplusplus >= 201103L
-# define CATCH_CPP11_OR_GREATER
+# if (__cplusplus >= 201402L) || (defined(_MSVC_LANG) && _MSVC_LANG >= 201402L)
+# define CATCH_CPP14_OR_GREATER
# endif
-# if __cplusplus >= 201402L
-# define CATCH_CPP14_OR_GREATER
+# if (__cplusplus >= 201703L) || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
+# define CATCH_CPP17_OR_GREATER
# endif
#endif
-#ifdef __clang__
-
-# if __has_feature(cxx_nullptr)
-# define CATCH_INTERNAL_CONFIG_CPP11_NULLPTR
-# endif
+#if defined(CATCH_CPP17_OR_GREATER)
+# define CATCH_INTERNAL_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS
+#endif
-# if __has_feature(cxx_noexcept)
-# define CATCH_INTERNAL_CONFIG_CPP11_NOEXCEPT
-# endif
+#ifdef __clang__
-# if defined(CATCH_CPP11_OR_GREATER)
-# define CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \
+# define CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \
_Pragma( "clang diagnostic push" ) \
- _Pragma( "clang diagnostic ignored \"-Wexit-time-destructors\"" )
-# define CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS \
+ _Pragma( "clang diagnostic ignored \"-Wexit-time-destructors\"" ) \
+ _Pragma( "clang diagnostic ignored \"-Wglobal-constructors\"")
+# define CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \
_Pragma( "clang diagnostic pop" )
# define CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS \
@@ -131,238 +150,298 @@
_Pragma( "clang diagnostic ignored \"-Wparentheses\"" )
# define CATCH_INTERNAL_UNSUPPRESS_PARENTHESES_WARNINGS \
_Pragma( "clang diagnostic pop" )
-# endif
-#endif // __clang__
+# define CATCH_INTERNAL_SUPPRESS_UNUSED_WARNINGS \
+ _Pragma( "clang diagnostic push" ) \
+ _Pragma( "clang diagnostic ignored \"-Wunused-variable\"" )
+# define CATCH_INTERNAL_UNSUPPRESS_UNUSED_WARNINGS \
+ _Pragma( "clang diagnostic pop" )
-////////////////////////////////////////////////////////////////////////////////
-// We know some environments not to support full POSIX signals
-#if defined(__CYGWIN__) || defined(__QNX__)
+# define CATCH_INTERNAL_SUPPRESS_ZERO_VARIADIC_WARNINGS \
+ _Pragma( "clang diagnostic push" ) \
+ _Pragma( "clang diagnostic ignored \"-Wgnu-zero-variadic-macro-arguments\"" )
+# define CATCH_INTERNAL_UNSUPPRESS_ZERO_VARIADIC_WARNINGS \
+ _Pragma( "clang diagnostic pop" )
-# if !defined(CATCH_CONFIG_POSIX_SIGNALS)
-# define CATCH_INTERNAL_CONFIG_NO_POSIX_SIGNALS
-# endif
+# define CATCH_INTERNAL_SUPPRESS_UNUSED_TEMPLATE_WARNINGS \
+ _Pragma( "clang diagnostic push" ) \
+ _Pragma( "clang diagnostic ignored \"-Wunused-template\"" )
+# define CATCH_INTERNAL_UNSUPPRESS_UNUSED_TEMPLATE_WARNINGS \
+ _Pragma( "clang diagnostic pop" )
+#endif // __clang__
+////////////////////////////////////////////////////////////////////////////////
+// Assume that non-Windows platforms support posix signals by default
+#if !defined(CATCH_PLATFORM_WINDOWS)
+ #define CATCH_INTERNAL_CONFIG_POSIX_SIGNALS
#endif
////////////////////////////////////////////////////////////////////////////////
-// Cygwin
-#ifdef __CYGWIN__
-
-// Required for some versions of Cygwin to declare gettimeofday
-// see: http://stackoverflow.com/questions/36901803/gettimeofday-not-declared-in-this-scope-cygwin
-# define _BSD_SOURCE
+// We know some environments not to support full POSIX signals
+#if defined(__CYGWIN__) || defined(__QNX__) || defined(__EMSCRIPTEN__) || defined(__DJGPP__)
+ #define CATCH_INTERNAL_CONFIG_NO_POSIX_SIGNALS
+#endif
-#endif // __CYGWIN__
+#ifdef __OS400__
+# define CATCH_INTERNAL_CONFIG_NO_POSIX_SIGNALS
+# define CATCH_CONFIG_COLOUR_NONE
+#endif
////////////////////////////////////////////////////////////////////////////////
-// Borland
-#ifdef __BORLANDC__
-
-#endif // __BORLANDC__
+// Android somehow still does not support std::to_string
+#if defined(__ANDROID__)
+# define CATCH_INTERNAL_CONFIG_NO_CPP11_TO_STRING
+# define CATCH_INTERNAL_CONFIG_ANDROID_LOGWRITE
+#endif
////////////////////////////////////////////////////////////////////////////////
-// EDG
-#ifdef __EDG_VERSION__
-
-#endif // __EDG_VERSION__
+// Not all Windows environments support SEH properly
+#if defined(__MINGW32__)
+# define CATCH_INTERNAL_CONFIG_NO_WINDOWS_SEH
+#endif
////////////////////////////////////////////////////////////////////////////////
-// Digital Mars
-#ifdef __DMC__
-
-#endif // __DMC__
+// PS4
+#if defined(__ORBIS__)
+# define CATCH_INTERNAL_CONFIG_NO_NEW_CAPTURE
+#endif
////////////////////////////////////////////////////////////////////////////////
-// GCC
-#ifdef __GNUC__
+// Cygwin
+#ifdef __CYGWIN__
-# if __GNUC__ == 4 && __GNUC_MINOR__ >= 6 && defined(__GXX_EXPERIMENTAL_CXX0X__)
-# define CATCH_INTERNAL_CONFIG_CPP11_NULLPTR
-# endif
+// Required for some versions of Cygwin to declare gettimeofday
+// see: http://stackoverflow.com/questions/36901803/gettimeofday-not-declared-in-this-scope-cygwin
+# define _BSD_SOURCE
+// some versions of cygwin (most) do not support std::to_string. Use the libstd check.
+// https://gcc.gnu.org/onlinedocs/gcc-4.8.2/libstdc++/api/a01053_source.html line 2812-2813
+# if !((__cplusplus >= 201103L) && defined(_GLIBCXX_USE_C99) \
+ && !defined(_GLIBCXX_HAVE_BROKEN_VSWPRINTF))
-// - otherwise more recent versions define __cplusplus >= 201103L
-// and will get picked up below
+# define CATCH_INTERNAL_CONFIG_NO_CPP11_TO_STRING
-#endif // __GNUC__
+# endif
+#endif // __CYGWIN__
////////////////////////////////////////////////////////////////////////////////
// Visual C++
#ifdef _MSC_VER
-#define CATCH_INTERNAL_CONFIG_WINDOWS_SEH
+# if _MSC_VER >= 1900 // Visual Studio 2015 or newer
+# define CATCH_INTERNAL_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS
+# endif
-#if (_MSC_VER >= 1600)
-# define CATCH_INTERNAL_CONFIG_CPP11_NULLPTR
-# define CATCH_INTERNAL_CONFIG_CPP11_UNIQUE_PTR
-#endif
+// Universal Windows platform does not support SEH
+// Or console colours (or console at all...)
+# if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP)
+# define CATCH_CONFIG_COLOUR_NONE
+# else
+# define CATCH_INTERNAL_CONFIG_WINDOWS_SEH
+# endif
-#if (_MSC_VER >= 1900 ) // (VC++ 13 (VS2015))
-#define CATCH_INTERNAL_CONFIG_CPP11_NOEXCEPT
-#define CATCH_INTERNAL_CONFIG_CPP11_GENERATED_METHODS
-#define CATCH_INTERNAL_CONFIG_CPP11_SHUFFLE
-#define CATCH_INTERNAL_CONFIG_CPP11_TYPE_TRAITS
-#endif
+// MSVC traditional preprocessor needs some workaround for __VA_ARGS__
+// _MSVC_TRADITIONAL == 0 means new conformant preprocessor
+// _MSVC_TRADITIONAL == 1 means old traditional non-conformant preprocessor
+# if !defined(_MSVC_TRADITIONAL) || (defined(_MSVC_TRADITIONAL) && _MSVC_TRADITIONAL)
+# define CATCH_INTERNAL_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+# endif
+#endif // _MSC_VER
+#if defined(_REENTRANT) || defined(_MSC_VER)
+// Enable async processing, as -pthread is specified or no additional linking is required
+# define CATCH_INTERNAL_CONFIG_USE_ASYNC
#endif // _MSC_VER
////////////////////////////////////////////////////////////////////////////////
-
-// Use variadic macros if the compiler supports them
-#if ( defined _MSC_VER && _MSC_VER > 1400 && !defined __EDGE__) || \
- ( defined __WAVE__ && __WAVE_HAS_VARIADICS ) || \
- ( defined __GNUC__ && __GNUC__ >= 3 ) || \
- ( !defined __cplusplus && __STDC_VERSION__ >= 199901L || __cplusplus >= 201103L )
-
-#define CATCH_INTERNAL_CONFIG_VARIADIC_MACROS
-
+// Check if we are compiled with -fno-exceptions or equivalent
+#if defined(__EXCEPTIONS) || defined(__cpp_exceptions) || defined(_CPPUNWIND)
+# define CATCH_INTERNAL_CONFIG_EXCEPTIONS_ENABLED
#endif
-// Use __COUNTER__ if the compiler supports it
-#if ( defined _MSC_VER && _MSC_VER >= 1300 ) || \
- ( defined __GNUC__ && __GNUC__ >= 4 && __GNUC_MINOR__ >= 3 ) || \
- ( defined __clang__ && __clang_major__ >= 3 )
-
-#define CATCH_INTERNAL_CONFIG_COUNTER
+////////////////////////////////////////////////////////////////////////////////
+// DJGPP
+#ifdef __DJGPP__
+# define CATCH_INTERNAL_CONFIG_NO_WCHAR
+#endif // __DJGPP__
+////////////////////////////////////////////////////////////////////////////////
+// Embarcadero C++Build
+#if defined(__BORLANDC__)
+ #define CATCH_INTERNAL_CONFIG_POLYFILL_ISNAN
#endif
////////////////////////////////////////////////////////////////////////////////
-// C++ language feature support
-
-// catch all support for C++11
-#if defined(CATCH_CPP11_OR_GREATER)
-
-# if !defined(CATCH_INTERNAL_CONFIG_CPP11_NULLPTR)
-# define CATCH_INTERNAL_CONFIG_CPP11_NULLPTR
-# endif
-# ifndef CATCH_INTERNAL_CONFIG_CPP11_NOEXCEPT
-# define CATCH_INTERNAL_CONFIG_CPP11_NOEXCEPT
-# endif
-
-# ifndef CATCH_INTERNAL_CONFIG_CPP11_GENERATED_METHODS
-# define CATCH_INTERNAL_CONFIG_CPP11_GENERATED_METHODS
-# endif
-
-# ifndef CATCH_INTERNAL_CONFIG_CPP11_IS_ENUM
-# define CATCH_INTERNAL_CONFIG_CPP11_IS_ENUM
-# endif
-
-# ifndef CATCH_INTERNAL_CONFIG_CPP11_TUPLE
-# define CATCH_INTERNAL_CONFIG_CPP11_TUPLE
-# endif
+// Use of __COUNTER__ is suppressed during code analysis in
+// CLion/AppCode 2017.2.x and former, because __COUNTER__ is not properly
+// handled by it.
+// Otherwise all supported compilers support COUNTER macro,
+// but user still might want to turn it off
+#if ( !defined(__JETBRAINS_IDE__) || __JETBRAINS_IDE__ >= 20170300L )
+ #define CATCH_INTERNAL_CONFIG_COUNTER
+#endif
-# ifndef CATCH_INTERNAL_CONFIG_VARIADIC_MACROS
-# define CATCH_INTERNAL_CONFIG_VARIADIC_MACROS
-# endif
+////////////////////////////////////////////////////////////////////////////////
-# if !defined(CATCH_INTERNAL_CONFIG_CPP11_LONG_LONG)
-# define CATCH_INTERNAL_CONFIG_CPP11_LONG_LONG
-# endif
+// RTX is a special version of Windows that is real time.
+// This means that it is detected as Windows, but does not provide
+// the same set of capabilities as real Windows does.
+#if defined(UNDER_RTSS) || defined(RTX64_BUILD)
+ #define CATCH_INTERNAL_CONFIG_NO_WINDOWS_SEH
+ #define CATCH_INTERNAL_CONFIG_NO_ASYNC
+ #define CATCH_CONFIG_COLOUR_NONE
+#endif
-# if !defined(CATCH_INTERNAL_CONFIG_CPP11_OVERRIDE)
-# define CATCH_INTERNAL_CONFIG_CPP11_OVERRIDE
-# endif
-# if !defined(CATCH_INTERNAL_CONFIG_CPP11_UNIQUE_PTR)
-# define CATCH_INTERNAL_CONFIG_CPP11_UNIQUE_PTR
-# endif
-# if !defined(CATCH_INTERNAL_CONFIG_CPP11_SHUFFLE)
-# define CATCH_INTERNAL_CONFIG_CPP11_SHUFFLE
-# endif
-# if !defined(CATCH_INTERNAL_CONFIG_CPP11_TYPE_TRAITS)
-# define CATCH_INTERNAL_CONFIG_CPP11_TYPE_TRAITS
-# endif
+#if defined(__UCLIBC__)
+#define CATCH_INTERNAL_CONFIG_GLOBAL_NEXTAFTER
+#endif
-#endif // __cplusplus >= 201103L
+// Various stdlib support checks that require __has_include
+#if defined(__has_include)
+ // Check if string_view is available and usable
+ #if __has_include(<string_view>) && defined(CATCH_CPP17_OR_GREATER)
+ # define CATCH_INTERNAL_CONFIG_CPP17_STRING_VIEW
+ #endif
+
+ // Check if optional is available and usable
+ # if __has_include(<optional>) && defined(CATCH_CPP17_OR_GREATER)
+ # define CATCH_INTERNAL_CONFIG_CPP17_OPTIONAL
+ # endif // __has_include(<optional>) && defined(CATCH_CPP17_OR_GREATER)
+
+ // Check if byte is available and usable
+ # if __has_include(<cstddef>) && defined(CATCH_CPP17_OR_GREATER)
+ # define CATCH_INTERNAL_CONFIG_CPP17_BYTE
+ # endif // __has_include(<cstddef>) && defined(CATCH_CPP17_OR_GREATER)
+
+ // Check if variant is available and usable
+ # if __has_include(<variant>) && defined(CATCH_CPP17_OR_GREATER)
+ # if defined(__clang__) && (__clang_major__ < 8)
+ // work around clang bug with libstdc++ https://bugs.llvm.org/show_bug.cgi?id=31852
+ // fix should be in clang 8, workaround in libstdc++ 8.2
+ # include <ciso646>
+ # if defined(__GLIBCXX__) && defined(_GLIBCXX_RELEASE) && (_GLIBCXX_RELEASE < 9)
+ # define CATCH_CONFIG_NO_CPP17_VARIANT
+ # else
+ # define CATCH_INTERNAL_CONFIG_CPP17_VARIANT
+ # endif // defined(__GLIBCXX__) && defined(_GLIBCXX_RELEASE) && (_GLIBCXX_RELEASE < 9)
+ # else
+ # define CATCH_INTERNAL_CONFIG_CPP17_VARIANT
+ # endif // defined(__clang__) && (__clang_major__ < 8)
+ # endif // __has_include(<variant>) && defined(CATCH_CPP17_OR_GREATER)
+#endif // defined(__has_include)
+
+#if defined(CATCH_INTERNAL_CONFIG_COUNTER) && !defined(CATCH_CONFIG_NO_COUNTER) && !defined(CATCH_CONFIG_COUNTER)
+# define CATCH_CONFIG_COUNTER
+#endif
+#if defined(CATCH_INTERNAL_CONFIG_WINDOWS_SEH) && !defined(CATCH_CONFIG_NO_WINDOWS_SEH) && !defined(CATCH_CONFIG_WINDOWS_SEH) && !defined(CATCH_INTERNAL_CONFIG_NO_WINDOWS_SEH)
+# define CATCH_CONFIG_WINDOWS_SEH
+#endif
+// This is set by default, because we assume that unix compilers are posix-signal-compatible by default.
+#if defined(CATCH_INTERNAL_CONFIG_POSIX_SIGNALS) && !defined(CATCH_INTERNAL_CONFIG_NO_POSIX_SIGNALS) && !defined(CATCH_CONFIG_NO_POSIX_SIGNALS) && !defined(CATCH_CONFIG_POSIX_SIGNALS)
+# define CATCH_CONFIG_POSIX_SIGNALS
+#endif
+// This is set by default, because we assume that compilers with no wchar_t support are just rare exceptions.
+#if !defined(CATCH_INTERNAL_CONFIG_NO_WCHAR) && !defined(CATCH_CONFIG_NO_WCHAR) && !defined(CATCH_CONFIG_WCHAR)
+# define CATCH_CONFIG_WCHAR
+#endif
-// Now set the actual defines based on the above + anything the user has configured
-#if defined(CATCH_INTERNAL_CONFIG_CPP11_NULLPTR) && !defined(CATCH_CONFIG_CPP11_NO_NULLPTR) && !defined(CATCH_CONFIG_CPP11_NULLPTR) && !defined(CATCH_CONFIG_NO_CPP11)
-# define CATCH_CONFIG_CPP11_NULLPTR
+#if !defined(CATCH_INTERNAL_CONFIG_NO_CPP11_TO_STRING) && !defined(CATCH_CONFIG_NO_CPP11_TO_STRING) && !defined(CATCH_CONFIG_CPP11_TO_STRING)
+# define CATCH_CONFIG_CPP11_TO_STRING
#endif
-#if defined(CATCH_INTERNAL_CONFIG_CPP11_NOEXCEPT) && !defined(CATCH_CONFIG_CPP11_NO_NOEXCEPT) && !defined(CATCH_CONFIG_CPP11_NOEXCEPT) && !defined(CATCH_CONFIG_NO_CPP11)
-# define CATCH_CONFIG_CPP11_NOEXCEPT
+
+#if defined(CATCH_INTERNAL_CONFIG_CPP17_OPTIONAL) && !defined(CATCH_CONFIG_NO_CPP17_OPTIONAL) && !defined(CATCH_CONFIG_CPP17_OPTIONAL)
+# define CATCH_CONFIG_CPP17_OPTIONAL
#endif
-#if defined(CATCH_INTERNAL_CONFIG_CPP11_GENERATED_METHODS) && !defined(CATCH_CONFIG_CPP11_NO_GENERATED_METHODS) && !defined(CATCH_CONFIG_CPP11_GENERATED_METHODS) && !defined(CATCH_CONFIG_NO_CPP11)
-# define CATCH_CONFIG_CPP11_GENERATED_METHODS
+
+#if defined(CATCH_INTERNAL_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS) && !defined(CATCH_CONFIG_NO_CPP17_UNCAUGHT_EXCEPTIONS) && !defined(CATCH_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS)
+# define CATCH_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS
#endif
-#if defined(CATCH_INTERNAL_CONFIG_CPP11_IS_ENUM) && !defined(CATCH_CONFIG_CPP11_NO_IS_ENUM) && !defined(CATCH_CONFIG_CPP11_IS_ENUM) && !defined(CATCH_CONFIG_NO_CPP11)
-# define CATCH_CONFIG_CPP11_IS_ENUM
+
+#if defined(CATCH_INTERNAL_CONFIG_CPP17_STRING_VIEW) && !defined(CATCH_CONFIG_NO_CPP17_STRING_VIEW) && !defined(CATCH_CONFIG_CPP17_STRING_VIEW)
+# define CATCH_CONFIG_CPP17_STRING_VIEW
#endif
-#if defined(CATCH_INTERNAL_CONFIG_CPP11_TUPLE) && !defined(CATCH_CONFIG_CPP11_NO_TUPLE) && !defined(CATCH_CONFIG_CPP11_TUPLE) && !defined(CATCH_CONFIG_NO_CPP11)
-# define CATCH_CONFIG_CPP11_TUPLE
+
+#if defined(CATCH_INTERNAL_CONFIG_CPP17_VARIANT) && !defined(CATCH_CONFIG_NO_CPP17_VARIANT) && !defined(CATCH_CONFIG_CPP17_VARIANT)
+# define CATCH_CONFIG_CPP17_VARIANT
#endif
-#if defined(CATCH_INTERNAL_CONFIG_VARIADIC_MACROS) && !defined(CATCH_CONFIG_NO_VARIADIC_MACROS) && !defined(CATCH_CONFIG_VARIADIC_MACROS)
-# define CATCH_CONFIG_VARIADIC_MACROS
+
+#if defined(CATCH_INTERNAL_CONFIG_CPP17_BYTE) && !defined(CATCH_CONFIG_NO_CPP17_BYTE) && !defined(CATCH_CONFIG_CPP17_BYTE)
+# define CATCH_CONFIG_CPP17_BYTE
#endif
-#if defined(CATCH_INTERNAL_CONFIG_CPP11_LONG_LONG) && !defined(CATCH_CONFIG_CPP11_NO_LONG_LONG) && !defined(CATCH_CONFIG_CPP11_LONG_LONG) && !defined(CATCH_CONFIG_NO_CPP11)
-# define CATCH_CONFIG_CPP11_LONG_LONG
+
+#if defined(CATCH_CONFIG_EXPERIMENTAL_REDIRECT)
+# define CATCH_INTERNAL_CONFIG_NEW_CAPTURE
#endif
-#if defined(CATCH_INTERNAL_CONFIG_CPP11_OVERRIDE) && !defined(CATCH_CONFIG_CPP11_NO_OVERRIDE) && !defined(CATCH_CONFIG_CPP11_OVERRIDE) && !defined(CATCH_CONFIG_NO_CPP11)
-# define CATCH_CONFIG_CPP11_OVERRIDE
+
+#if defined(CATCH_INTERNAL_CONFIG_NEW_CAPTURE) && !defined(CATCH_INTERNAL_CONFIG_NO_NEW_CAPTURE) && !defined(CATCH_CONFIG_NO_NEW_CAPTURE) && !defined(CATCH_CONFIG_NEW_CAPTURE)
+# define CATCH_CONFIG_NEW_CAPTURE
#endif
-#if defined(CATCH_INTERNAL_CONFIG_CPP11_UNIQUE_PTR) && !defined(CATCH_CONFIG_CPP11_NO_UNIQUE_PTR) && !defined(CATCH_CONFIG_CPP11_UNIQUE_PTR) && !defined(CATCH_CONFIG_NO_CPP11)
-# define CATCH_CONFIG_CPP11_UNIQUE_PTR
+
+#if !defined(CATCH_INTERNAL_CONFIG_EXCEPTIONS_ENABLED) && !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS)
+# define CATCH_CONFIG_DISABLE_EXCEPTIONS
#endif
-// Use of __COUNTER__ is suppressed if __JETBRAINS_IDE__ is #defined (meaning we're being parsed by a JetBrains IDE for
-// analytics) because, at time of writing, __COUNTER__ is not properly handled by it.
-// This does not affect compilation
-#if defined(CATCH_INTERNAL_CONFIG_COUNTER) && !defined(CATCH_CONFIG_NO_COUNTER) && !defined(CATCH_CONFIG_COUNTER) && !defined(__JETBRAINS_IDE__)
-# define CATCH_CONFIG_COUNTER
+
+#if defined(CATCH_INTERNAL_CONFIG_POLYFILL_ISNAN) && !defined(CATCH_CONFIG_NO_POLYFILL_ISNAN) && !defined(CATCH_CONFIG_POLYFILL_ISNAN)
+# define CATCH_CONFIG_POLYFILL_ISNAN
#endif
-#if defined(CATCH_INTERNAL_CONFIG_CPP11_SHUFFLE) && !defined(CATCH_CONFIG_CPP11_NO_SHUFFLE) && !defined(CATCH_CONFIG_CPP11_SHUFFLE) && !defined(CATCH_CONFIG_NO_CPP11)
-# define CATCH_CONFIG_CPP11_SHUFFLE
+
+#if defined(CATCH_INTERNAL_CONFIG_USE_ASYNC) && !defined(CATCH_INTERNAL_CONFIG_NO_ASYNC) && !defined(CATCH_CONFIG_NO_USE_ASYNC) && !defined(CATCH_CONFIG_USE_ASYNC)
+# define CATCH_CONFIG_USE_ASYNC
#endif
-# if defined(CATCH_INTERNAL_CONFIG_CPP11_TYPE_TRAITS) && !defined(CATCH_CONFIG_CPP11_NO_TYPE_TRAITS) && !defined(CATCH_CONFIG_CPP11_TYPE_TRAITS) && !defined(CATCH_CONFIG_NO_CPP11)
-# define CATCH_CONFIG_CPP11_TYPE_TRAITS
-# endif
-#if defined(CATCH_INTERNAL_CONFIG_WINDOWS_SEH) && !defined(CATCH_CONFIG_NO_WINDOWS_SEH) && !defined(CATCH_CONFIG_WINDOWS_SEH)
-# define CATCH_CONFIG_WINDOWS_SEH
+
+#if defined(CATCH_INTERNAL_CONFIG_ANDROID_LOGWRITE) && !defined(CATCH_CONFIG_NO_ANDROID_LOGWRITE) && !defined(CATCH_CONFIG_ANDROID_LOGWRITE)
+# define CATCH_CONFIG_ANDROID_LOGWRITE
#endif
-// This is set by default, because we assume that unix compilers are posix-signal-compatible by default.
-#if !defined(CATCH_INTERNAL_CONFIG_NO_POSIX_SIGNALS) && !defined(CATCH_CONFIG_NO_POSIX_SIGNALS) && !defined(CATCH_CONFIG_POSIX_SIGNALS)
-# define CATCH_CONFIG_POSIX_SIGNALS
+
+#if defined(CATCH_INTERNAL_CONFIG_GLOBAL_NEXTAFTER) && !defined(CATCH_CONFIG_NO_GLOBAL_NEXTAFTER) && !defined(CATCH_CONFIG_GLOBAL_NEXTAFTER)
+# define CATCH_CONFIG_GLOBAL_NEXTAFTER
#endif
#if !defined(CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS)
# define CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS
# define CATCH_INTERNAL_UNSUPPRESS_PARENTHESES_WARNINGS
#endif
-#if !defined(CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS)
-# define CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS
-# define CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS
+#if !defined(CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS)
+# define CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS
+# define CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS
+#endif
+#if !defined(CATCH_INTERNAL_SUPPRESS_UNUSED_WARNINGS)
+# define CATCH_INTERNAL_SUPPRESS_UNUSED_WARNINGS
+# define CATCH_INTERNAL_UNSUPPRESS_UNUSED_WARNINGS
+#endif
+#if !defined(CATCH_INTERNAL_SUPPRESS_ZERO_VARIADIC_WARNINGS)
+# define CATCH_INTERNAL_SUPPRESS_ZERO_VARIADIC_WARNINGS
+# define CATCH_INTERNAL_UNSUPPRESS_ZERO_VARIADIC_WARNINGS
#endif
-// noexcept support:
-#if defined(CATCH_CONFIG_CPP11_NOEXCEPT) && !defined(CATCH_NOEXCEPT)
-# define CATCH_NOEXCEPT noexcept
-# define CATCH_NOEXCEPT_IS(x) noexcept(x)
-#else
-# define CATCH_NOEXCEPT throw()
-# define CATCH_NOEXCEPT_IS(x)
+#if defined(__APPLE__) && defined(__apple_build_version__) && (__clang_major__ < 10)
+# undef CATCH_INTERNAL_SUPPRESS_UNUSED_TEMPLATE_WARNINGS
+# undef CATCH_INTERNAL_UNSUPPRESS_UNUSED_TEMPLATE_WARNINGS
+#elif defined(__clang__) && (__clang_major__ < 5)
+# undef CATCH_INTERNAL_SUPPRESS_UNUSED_TEMPLATE_WARNINGS
+# undef CATCH_INTERNAL_UNSUPPRESS_UNUSED_TEMPLATE_WARNINGS
#endif
-// nullptr support
-#ifdef CATCH_CONFIG_CPP11_NULLPTR
-# define CATCH_NULL nullptr
-#else
-# define CATCH_NULL NULL
+#if !defined(CATCH_INTERNAL_SUPPRESS_UNUSED_TEMPLATE_WARNINGS)
+# define CATCH_INTERNAL_SUPPRESS_UNUSED_TEMPLATE_WARNINGS
+# define CATCH_INTERNAL_UNSUPPRESS_UNUSED_TEMPLATE_WARNINGS
#endif
-// override support
-#ifdef CATCH_CONFIG_CPP11_OVERRIDE
-# define CATCH_OVERRIDE override
+#if defined(CATCH_CONFIG_DISABLE_EXCEPTIONS)
+#define CATCH_TRY if ((true))
+#define CATCH_CATCH_ALL if ((false))
+#define CATCH_CATCH_ANON(type) if ((false))
#else
-# define CATCH_OVERRIDE
+#define CATCH_TRY try
+#define CATCH_CATCH_ALL catch (...)
+#define CATCH_CATCH_ANON(type) catch (type)
#endif
-// unique_ptr support
-#ifdef CATCH_CONFIG_CPP11_UNIQUE_PTR
-# define CATCH_AUTO_PTR( T ) std::unique_ptr<T>
-#else
-# define CATCH_AUTO_PTR( T ) std::auto_ptr<T>
+#if defined(CATCH_INTERNAL_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR) && !defined(CATCH_CONFIG_NO_TRADITIONAL_MSVC_PREPROCESSOR) && !defined(CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR)
+#define CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
#endif
+// end catch_compiler_capabilities.h
#define INTERNAL_CATCH_UNIQUE_NAME_LINE2( name, line ) name##line
#define INTERNAL_CATCH_UNIQUE_NAME_LINE( name, line ) INTERNAL_CATCH_UNIQUE_NAME_LINE2( name, line )
#ifdef CATCH_CONFIG_COUNTER
@@ -371,95 +450,48 @@
# define INTERNAL_CATCH_UNIQUE_NAME( name ) INTERNAL_CATCH_UNIQUE_NAME_LINE( name, __LINE__ )
#endif
-#define INTERNAL_CATCH_STRINGIFY2( expr ) #expr
-#define INTERNAL_CATCH_STRINGIFY( expr ) INTERNAL_CATCH_STRINGIFY2( expr )
+#include <iosfwd>
+#include <string>
+#include <cstdint>
-#include <sstream>
-#include <algorithm>
+// We need a dummy global operator<< so we can bring it into Catch namespace later
+struct Catch_global_namespace_dummy {};
+std::ostream& operator<<(std::ostream&, Catch_global_namespace_dummy);
namespace Catch {
- struct IConfig;
-
struct CaseSensitive { enum Choice {
Yes,
No
}; };
class NonCopyable {
-#ifdef CATCH_CONFIG_CPP11_GENERATED_METHODS
NonCopyable( NonCopyable const& ) = delete;
NonCopyable( NonCopyable && ) = delete;
NonCopyable& operator = ( NonCopyable const& ) = delete;
NonCopyable& operator = ( NonCopyable && ) = delete;
-#else
- NonCopyable( NonCopyable const& info );
- NonCopyable& operator = ( NonCopyable const& );
-#endif
protected:
- NonCopyable() {}
+ NonCopyable();
virtual ~NonCopyable();
};
- class SafeBool {
- public:
- typedef void (SafeBool::*type)() const;
-
- static type makeSafe( bool value ) {
- return value ? &SafeBool::trueValue : 0;
- }
- private:
- void trueValue() const {}
- };
-
- template<typename ContainerT>
- inline void deleteAll( ContainerT& container ) {
- typename ContainerT::const_iterator it = container.begin();
- typename ContainerT::const_iterator itEnd = container.end();
- for(; it != itEnd; ++it )
- delete *it;
- }
- template<typename AssociativeContainerT>
- inline void deleteAllValues( AssociativeContainerT& container ) {
- typename AssociativeContainerT::const_iterator it = container.begin();
- typename AssociativeContainerT::const_iterator itEnd = container.end();
- for(; it != itEnd; ++it )
- delete it->second;
- }
-
- bool startsWith( std::string const& s, std::string const& prefix );
- bool startsWith( std::string const& s, char prefix );
- bool endsWith( std::string const& s, std::string const& suffix );
- bool endsWith( std::string const& s, char suffix );
- bool contains( std::string const& s, std::string const& infix );
- void toLowerInPlace( std::string& s );
- std::string toLower( std::string const& s );
- std::string trim( std::string const& str );
- bool replaceInPlace( std::string& str, std::string const& replaceThis, std::string const& withThis );
-
- struct pluralise {
- pluralise( std::size_t count, std::string const& label );
-
- friend std::ostream& operator << ( std::ostream& os, pluralise const& pluraliser );
+ struct SourceLineInfo {
- std::size_t m_count;
- std::string m_label;
- };
+ SourceLineInfo() = delete;
+ SourceLineInfo( char const* _file, std::size_t _line ) noexcept
+ : file( _file ),
+ line( _line )
+ {}
- struct SourceLineInfo {
+ SourceLineInfo( SourceLineInfo const& other ) = default;
+ SourceLineInfo& operator = ( SourceLineInfo const& ) = default;
+ SourceLineInfo( SourceLineInfo&& ) noexcept = default;
+ SourceLineInfo& operator = ( SourceLineInfo&& ) noexcept = default;
- SourceLineInfo();
- SourceLineInfo( char const* _file, std::size_t _line );
-# ifdef CATCH_CONFIG_CPP11_GENERATED_METHODS
- SourceLineInfo(SourceLineInfo const& other) = default;
- SourceLineInfo( SourceLineInfo && ) = default;
- SourceLineInfo& operator = ( SourceLineInfo const& ) = default;
- SourceLineInfo& operator = ( SourceLineInfo && ) = default;
-# endif
- bool empty() const;
- bool operator == ( SourceLineInfo const& other ) const;
- bool operator < ( SourceLineInfo const& other ) const;
+ bool empty() const noexcept { return file[0] == '\0'; }
+ bool operator == ( SourceLineInfo const& other ) const noexcept;
+ bool operator < ( SourceLineInfo const& other ) const noexcept;
char const* file;
std::size_t line;
@@ -467,24 +499,17 @@ namespace Catch {
std::ostream& operator << ( std::ostream& os, SourceLineInfo const& info );
- // This is just here to avoid compiler warnings with macro constants and boolean literals
- inline bool isTrue( bool value ){ return value; }
- inline bool alwaysTrue() { return true; }
- inline bool alwaysFalse() { return false; }
-
- void throwLogicError( std::string const& message, SourceLineInfo const& locationInfo );
-
- void seedRng( IConfig const& config );
- unsigned int rngSeed();
+ // Bring in operator<< from global namespace into Catch namespace
+ // This is necessary because the overload of operator<< above makes
+ // lookup stop at namespace Catch
+ using ::operator<<;
// Use this in variadic streaming macros to allow
// >> +StreamEndStop
// as well as
// >> stuff +StreamEndStop
struct StreamEndStop {
- std::string operator+() {
- return std::string();
- }
+ std::string operator+() const;
};
template<typename T>
T const& operator + ( T const& value, StreamEndStop ) {
@@ -492,364 +517,834 @@ namespace Catch {
}
}
-#define CATCH_INTERNAL_LINEINFO ::Catch::SourceLineInfo( __FILE__, static_cast<std::size_t>( __LINE__ ) )
-#define CATCH_INTERNAL_ERROR( msg ) ::Catch::throwLogicError( msg, CATCH_INTERNAL_LINEINFO );
+#define CATCH_INTERNAL_LINEINFO \
+ ::Catch::SourceLineInfo( __FILE__, static_cast<std::size_t>( __LINE__ ) )
+// end catch_common.h
namespace Catch {
- class NotImplementedException : public std::exception
- {
- public:
- NotImplementedException( SourceLineInfo const& lineInfo );
- NotImplementedException( NotImplementedException const& ) {}
-
- virtual ~NotImplementedException() CATCH_NOEXCEPT {}
-
- virtual const char* what() const CATCH_NOEXCEPT;
-
- private:
- std::string m_what;
- SourceLineInfo m_lineInfo;
+ struct RegistrarForTagAliases {
+ RegistrarForTagAliases( char const* alias, char const* tag, SourceLineInfo const& lineInfo );
};
} // end namespace Catch
-///////////////////////////////////////////////////////////////////////////////
-#define CATCH_NOT_IMPLEMENTED throw Catch::NotImplementedException( CATCH_INTERNAL_LINEINFO )
+#define CATCH_REGISTER_TAG_ALIAS( alias, spec ) \
+ CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \
+ namespace{ Catch::RegistrarForTagAliases INTERNAL_CATCH_UNIQUE_NAME( AutoRegisterTagAlias )( alias, spec, CATCH_INTERNAL_LINEINFO ); } \
+ CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS
-// #included from: internal/catch_context.h
-#define TWOBLUECUBES_CATCH_CONTEXT_H_INCLUDED
+// end catch_tag_alias_autoregistrar.h
+// start catch_test_registry.h
-// #included from: catch_interfaces_generators.h
-#define TWOBLUECUBES_CATCH_INTERFACES_GENERATORS_H_INCLUDED
+// start catch_interfaces_testcase.h
-#include <string>
+#include <vector>
namespace Catch {
- struct IGeneratorInfo {
- virtual ~IGeneratorInfo();
- virtual bool moveNext() = 0;
- virtual std::size_t getCurrentIndex() const = 0;
+ class TestSpec;
+
+ struct ITestInvoker {
+ virtual void invoke () const = 0;
+ virtual ~ITestInvoker();
};
- struct IGeneratorsForTest {
- virtual ~IGeneratorsForTest();
+ class TestCase;
+ struct IConfig;
- virtual IGeneratorInfo& getGeneratorInfo( std::string const& fileInfo, std::size_t size ) = 0;
- virtual bool moveNext() = 0;
+ struct ITestCaseRegistry {
+ virtual ~ITestCaseRegistry();
+ virtual std::vector<TestCase> const& getAllTests() const = 0;
+ virtual std::vector<TestCase> const& getAllTestsSorted( IConfig const& config ) const = 0;
};
- IGeneratorsForTest* createGeneratorsForTest();
+ bool isThrowSafe( TestCase const& testCase, IConfig const& config );
+ bool matchTest( TestCase const& testCase, TestSpec const& testSpec, IConfig const& config );
+ std::vector<TestCase> filterTests( std::vector<TestCase> const& testCases, TestSpec const& testSpec, IConfig const& config );
+ std::vector<TestCase> const& getAllTestCasesSorted( IConfig const& config );
-} // end namespace Catch
+}
-// #included from: catch_ptr.hpp
-#define TWOBLUECUBES_CATCH_PTR_HPP_INCLUDED
+// end catch_interfaces_testcase.h
+// start catch_stringref.h
-#ifdef __clang__
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wpadded"
-#endif
+#include <cstddef>
+#include <string>
+#include <iosfwd>
+#include <cassert>
namespace Catch {
- // An intrusive reference counting smart pointer.
- // T must implement addRef() and release() methods
- // typically implementing the IShared interface
- template<typename T>
- class Ptr {
+ /// A non-owning string class (similar to the forthcoming std::string_view)
+ /// Note that, because a StringRef may be a substring of another string,
+ /// it may not be null terminated. c_str() must return a null terminated
+ /// string, however, and so the StringRef will internally take ownership
+ /// (taking a copy), if necessary. In theory this ownership is not externally
+ /// visible - but it does mean (substring) StringRefs should not be shared between
+ /// threads.
+ class StringRef {
public:
- Ptr() : m_p( CATCH_NULL ){}
- Ptr( T* p ) : m_p( p ){
- if( m_p )
- m_p->addRef();
- }
- Ptr( Ptr const& other ) : m_p( other.m_p ){
- if( m_p )
- m_p->addRef();
- }
- ~Ptr(){
- if( m_p )
- m_p->release();
- }
- void reset() {
- if( m_p )
- m_p->release();
- m_p = CATCH_NULL;
+ using size_type = std::size_t;
+ using const_iterator = const char*;
+
+ private:
+ friend struct StringRefTestAccess;
+
+ char const* m_start;
+ size_type m_size;
+
+ char* m_data = nullptr;
+
+ void takeOwnership();
+
+ static constexpr char const* const s_empty = "";
+
+ public: // construction/ assignment
+ StringRef() noexcept
+ : StringRef( s_empty, 0 )
+ {}
+
+ StringRef( StringRef const& other ) noexcept
+ : m_start( other.m_start ),
+ m_size( other.m_size )
+ {}
+
+ StringRef( StringRef&& other ) noexcept
+ : m_start( other.m_start ),
+ m_size( other.m_size ),
+ m_data( other.m_data )
+ {
+ other.m_data = nullptr;
}
- Ptr& operator = ( T* p ){
- Ptr temp( p );
- swap( temp );
- return *this;
+
+ StringRef( char const* rawChars ) noexcept;
+
+ StringRef( char const* rawChars, size_type size ) noexcept
+ : m_start( rawChars ),
+ m_size( size )
+ {}
+
+ StringRef( std::string const& stdString ) noexcept
+ : m_start( stdString.c_str() ),
+ m_size( stdString.size() )
+ {}
+
+ ~StringRef() noexcept {
+ delete[] m_data;
}
- Ptr& operator = ( Ptr const& other ){
- Ptr temp( other );
- swap( temp );
+
+ auto operator = ( StringRef const &other ) noexcept -> StringRef& {
+ delete[] m_data;
+ m_data = nullptr;
+ m_start = other.m_start;
+ m_size = other.m_size;
return *this;
}
- void swap( Ptr& other ) { std::swap( m_p, other.m_p ); }
- T* get() const{ return m_p; }
- T& operator*() const { return *m_p; }
- T* operator->() const { return m_p; }
- bool operator !() const { return m_p == CATCH_NULL; }
- operator SafeBool::type() const { return SafeBool::makeSafe( m_p != CATCH_NULL ); }
- private:
- T* m_p;
- };
+ explicit operator std::string() const {
+ return std::string(m_start, m_size);
+ }
- struct IShared : NonCopyable {
- virtual ~IShared();
- virtual void addRef() const = 0;
- virtual void release() const = 0;
- };
+ void swap( StringRef& other ) noexcept;
- template<typename T = IShared>
- struct SharedImpl : T {
+ public: // operators
+ auto operator == ( StringRef const& other ) const noexcept -> bool;
+ auto operator != ( StringRef const& other ) const noexcept -> bool;
- SharedImpl() : m_rc( 0 ){}
+ auto operator[] ( size_type index ) const noexcept -> char {
+ assert(index < m_size);
+ return m_start[index];
+ }
- virtual void addRef() const {
- ++m_rc;
+ public: // named queries
+ auto empty() const noexcept -> bool {
+ return m_size == 0;
}
- virtual void release() const {
- if( --m_rc == 0 )
- delete this;
+ auto size() const noexcept -> size_type {
+ return m_size;
}
- mutable unsigned int m_rc;
+ auto c_str() const -> char const*;
+
+ public: // substrings and searches
+ auto substr( size_type start, size_type size ) const noexcept -> StringRef;
+
+ // Returns the current start pointer.
+ // Note that the pointer can change when if the StringRef is a substring
+ auto currentData() const noexcept -> char const*;
+
+ public: // iterators
+ const_iterator begin() const { return m_start; }
+ const_iterator end() const { return m_start + m_size; }
+
+ private: // ownership queries - may not be consistent between calls
+ auto isOwned() const noexcept -> bool;
+ auto isSubstring() const noexcept -> bool;
};
-} // end namespace Catch
+ auto operator += ( std::string& lhs, StringRef const& sr ) -> std::string&;
+ auto operator << ( std::ostream& os, StringRef const& sr ) -> std::ostream&;
-#ifdef __clang__
-#pragma clang diagnostic pop
-#endif
+ inline auto operator "" _sr( char const* rawChars, std::size_t size ) noexcept -> StringRef {
+ return StringRef( rawChars, size );
+ }
-namespace Catch {
+} // namespace Catch
- class TestCase;
- class Stream;
- struct IResultCapture;
- struct IRunner;
- struct IGeneratorsForTest;
- struct IConfig;
+inline auto operator "" _catch_sr( char const* rawChars, std::size_t size ) noexcept -> Catch::StringRef {
+ return Catch::StringRef( rawChars, size );
+}
- struct IContext
- {
- virtual ~IContext();
+// end catch_stringref.h
+// start catch_preprocessor.hpp
- virtual IResultCapture* getResultCapture() = 0;
- virtual IRunner* getRunner() = 0;
- virtual size_t getGeneratorIndex( std::string const& fileInfo, size_t totalSize ) = 0;
- virtual bool advanceGeneratorsForCurrentTest() = 0;
- virtual Ptr<IConfig const> getConfig() const = 0;
- };
- struct IMutableContext : IContext
- {
- virtual ~IMutableContext();
- virtual void setResultCapture( IResultCapture* resultCapture ) = 0;
- virtual void setRunner( IRunner* runner ) = 0;
- virtual void setConfig( Ptr<IConfig const> const& config ) = 0;
- };
+#define CATCH_RECURSION_LEVEL0(...) __VA_ARGS__
+#define CATCH_RECURSION_LEVEL1(...) CATCH_RECURSION_LEVEL0(CATCH_RECURSION_LEVEL0(CATCH_RECURSION_LEVEL0(__VA_ARGS__)))
+#define CATCH_RECURSION_LEVEL2(...) CATCH_RECURSION_LEVEL1(CATCH_RECURSION_LEVEL1(CATCH_RECURSION_LEVEL1(__VA_ARGS__)))
+#define CATCH_RECURSION_LEVEL3(...) CATCH_RECURSION_LEVEL2(CATCH_RECURSION_LEVEL2(CATCH_RECURSION_LEVEL2(__VA_ARGS__)))
+#define CATCH_RECURSION_LEVEL4(...) CATCH_RECURSION_LEVEL3(CATCH_RECURSION_LEVEL3(CATCH_RECURSION_LEVEL3(__VA_ARGS__)))
+#define CATCH_RECURSION_LEVEL5(...) CATCH_RECURSION_LEVEL4(CATCH_RECURSION_LEVEL4(CATCH_RECURSION_LEVEL4(__VA_ARGS__)))
- IContext& getCurrentContext();
- IMutableContext& getCurrentMutableContext();
- void cleanUpContext();
- Stream createStream( std::string const& streamName );
+#ifdef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+#define INTERNAL_CATCH_EXPAND_VARGS(...) __VA_ARGS__
+// MSVC needs more evaluations
+#define CATCH_RECURSION_LEVEL6(...) CATCH_RECURSION_LEVEL5(CATCH_RECURSION_LEVEL5(CATCH_RECURSION_LEVEL5(__VA_ARGS__)))
+#define CATCH_RECURSE(...) CATCH_RECURSION_LEVEL6(CATCH_RECURSION_LEVEL6(__VA_ARGS__))
+#else
+#define CATCH_RECURSE(...) CATCH_RECURSION_LEVEL5(__VA_ARGS__)
+#endif
-}
+#define CATCH_REC_END(...)
+#define CATCH_REC_OUT
+
+#define CATCH_EMPTY()
+#define CATCH_DEFER(id) id CATCH_EMPTY()
+
+#define CATCH_REC_GET_END2() 0, CATCH_REC_END
+#define CATCH_REC_GET_END1(...) CATCH_REC_GET_END2
+#define CATCH_REC_GET_END(...) CATCH_REC_GET_END1
+#define CATCH_REC_NEXT0(test, next, ...) next CATCH_REC_OUT
+#define CATCH_REC_NEXT1(test, next) CATCH_DEFER ( CATCH_REC_NEXT0 ) ( test, next, 0)
+#define CATCH_REC_NEXT(test, next) CATCH_REC_NEXT1(CATCH_REC_GET_END test, next)
+
+#define CATCH_REC_LIST0(f, x, peek, ...) , f(x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST1) ) ( f, peek, __VA_ARGS__ )
+#define CATCH_REC_LIST1(f, x, peek, ...) , f(x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST0) ) ( f, peek, __VA_ARGS__ )
+#define CATCH_REC_LIST2(f, x, peek, ...) f(x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST1) ) ( f, peek, __VA_ARGS__ )
+
+#define CATCH_REC_LIST0_UD(f, userdata, x, peek, ...) , f(userdata, x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST1_UD) ) ( f, userdata, peek, __VA_ARGS__ )
+#define CATCH_REC_LIST1_UD(f, userdata, x, peek, ...) , f(userdata, x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST0_UD) ) ( f, userdata, peek, __VA_ARGS__ )
+#define CATCH_REC_LIST2_UD(f, userdata, x, peek, ...) f(userdata, x) CATCH_DEFER ( CATCH_REC_NEXT(peek, CATCH_REC_LIST1_UD) ) ( f, userdata, peek, __VA_ARGS__ )
+
+// Applies the function macro `f` to each of the remaining parameters, inserts commas between the results,
+// and passes userdata as the first parameter to each invocation,
+// e.g. CATCH_REC_LIST_UD(f, x, a, b, c) evaluates to f(x, a), f(x, b), f(x, c)
+#define CATCH_REC_LIST_UD(f, userdata, ...) CATCH_RECURSE(CATCH_REC_LIST2_UD(f, userdata, __VA_ARGS__, ()()(), ()()(), ()()(), 0))
+
+#define CATCH_REC_LIST(f, ...) CATCH_RECURSE(CATCH_REC_LIST2(f, __VA_ARGS__, ()()(), ()()(), ()()(), 0))
+
+#define INTERNAL_CATCH_EXPAND1(param) INTERNAL_CATCH_EXPAND2(param)
+#define INTERNAL_CATCH_EXPAND2(...) INTERNAL_CATCH_NO## __VA_ARGS__
+#define INTERNAL_CATCH_DEF(...) INTERNAL_CATCH_DEF __VA_ARGS__
+#define INTERNAL_CATCH_NOINTERNAL_CATCH_DEF
+#define INTERNAL_CATCH_STRINGIZE(...) INTERNAL_CATCH_STRINGIZE2(__VA_ARGS__)
+#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+#define INTERNAL_CATCH_STRINGIZE2(...) #__VA_ARGS__
+#define INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS(param) INTERNAL_CATCH_STRINGIZE(INTERNAL_CATCH_REMOVE_PARENS(param))
+#else
+// MSVC is adding extra space and needs another indirection to expand INTERNAL_CATCH_NOINTERNAL_CATCH_DEF
+#define INTERNAL_CATCH_STRINGIZE2(...) INTERNAL_CATCH_STRINGIZE3(__VA_ARGS__)
+#define INTERNAL_CATCH_STRINGIZE3(...) #__VA_ARGS__
+#define INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS(param) (INTERNAL_CATCH_STRINGIZE(INTERNAL_CATCH_REMOVE_PARENS(param)) + 1)
+#endif
-// #included from: internal/catch_test_registry.hpp
-#define TWOBLUECUBES_CATCH_TEST_REGISTRY_HPP_INCLUDED
+#define INTERNAL_CATCH_MAKE_NAMESPACE2(...) ns_##__VA_ARGS__
+#define INTERNAL_CATCH_MAKE_NAMESPACE(name) INTERNAL_CATCH_MAKE_NAMESPACE2(name)
-// #included from: catch_interfaces_testcase.h
-#define TWOBLUECUBES_CATCH_INTERFACES_TESTCASE_H_INCLUDED
+#define INTERNAL_CATCH_REMOVE_PARENS(...) INTERNAL_CATCH_EXPAND1(INTERNAL_CATCH_DEF __VA_ARGS__)
-#include <vector>
+#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+#define INTERNAL_CATCH_MAKE_TYPE_LIST2(...) decltype(get_wrapper<INTERNAL_CATCH_REMOVE_PARENS_GEN(__VA_ARGS__)>())
+#define INTERNAL_CATCH_MAKE_TYPE_LIST(...) INTERNAL_CATCH_MAKE_TYPE_LIST2(INTERNAL_CATCH_REMOVE_PARENS(__VA_ARGS__))
+#else
+#define INTERNAL_CATCH_MAKE_TYPE_LIST2(...) INTERNAL_CATCH_EXPAND_VARGS(decltype(get_wrapper<INTERNAL_CATCH_REMOVE_PARENS_GEN(__VA_ARGS__)>()))
+#define INTERNAL_CATCH_MAKE_TYPE_LIST(...) INTERNAL_CATCH_EXPAND_VARGS(INTERNAL_CATCH_MAKE_TYPE_LIST2(INTERNAL_CATCH_REMOVE_PARENS(__VA_ARGS__)))
+#endif
-namespace Catch {
+#define INTERNAL_CATCH_MAKE_TYPE_LISTS_FROM_TYPES(...)\
+ CATCH_REC_LIST(INTERNAL_CATCH_MAKE_TYPE_LIST,__VA_ARGS__)
+
+#define INTERNAL_CATCH_REMOVE_PARENS_1_ARG(_0) INTERNAL_CATCH_REMOVE_PARENS(_0)
+#define INTERNAL_CATCH_REMOVE_PARENS_2_ARG(_0, _1) INTERNAL_CATCH_REMOVE_PARENS(_0), INTERNAL_CATCH_REMOVE_PARENS_1_ARG(_1)
+#define INTERNAL_CATCH_REMOVE_PARENS_3_ARG(_0, _1, _2) INTERNAL_CATCH_REMOVE_PARENS(_0), INTERNAL_CATCH_REMOVE_PARENS_2_ARG(_1, _2)
+#define INTERNAL_CATCH_REMOVE_PARENS_4_ARG(_0, _1, _2, _3) INTERNAL_CATCH_REMOVE_PARENS(_0), INTERNAL_CATCH_REMOVE_PARENS_3_ARG(_1, _2, _3)
+#define INTERNAL_CATCH_REMOVE_PARENS_5_ARG(_0, _1, _2, _3, _4) INTERNAL_CATCH_REMOVE_PARENS(_0), INTERNAL_CATCH_REMOVE_PARENS_4_ARG(_1, _2, _3, _4)
+#define INTERNAL_CATCH_REMOVE_PARENS_6_ARG(_0, _1, _2, _3, _4, _5) INTERNAL_CATCH_REMOVE_PARENS(_0), INTERNAL_CATCH_REMOVE_PARENS_5_ARG(_1, _2, _3, _4, _5)
+#define INTERNAL_CATCH_REMOVE_PARENS_7_ARG(_0, _1, _2, _3, _4, _5, _6) INTERNAL_CATCH_REMOVE_PARENS(_0), INTERNAL_CATCH_REMOVE_PARENS_6_ARG(_1, _2, _4, _5, _6)
+#define INTERNAL_CATCH_REMOVE_PARENS_8_ARG(_0, _1, _2, _3, _4, _5, _6, _7) INTERNAL_CATCH_REMOVE_PARENS(_0), INTERNAL_CATCH_REMOVE_PARENS_7_ARG(_1, _2, _3, _4, _5, _6, _7)
+#define INTERNAL_CATCH_REMOVE_PARENS_9_ARG(_0, _1, _2, _3, _4, _5, _6, _7, _8) INTERNAL_CATCH_REMOVE_PARENS(_0), INTERNAL_CATCH_REMOVE_PARENS_8_ARG(_1, _2, _3, _4, _5, _6, _7, _8)
+#define INTERNAL_CATCH_REMOVE_PARENS_10_ARG(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9) INTERNAL_CATCH_REMOVE_PARENS(_0), INTERNAL_CATCH_REMOVE_PARENS_9_ARG(_1, _2, _3, _4, _5, _6, _7, _8, _9)
+#define INTERNAL_CATCH_REMOVE_PARENS_11_ARG(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10) INTERNAL_CATCH_REMOVE_PARENS(_0), INTERNAL_CATCH_REMOVE_PARENS_10_ARG(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10)
+
+#define INTERNAL_CATCH_VA_NARGS_IMPL(_0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N
+
+#define INTERNAL_CATCH_TYPE_GEN\
+ template<typename...> struct TypeList {};\
+ template<typename...Ts>\
+ constexpr auto get_wrapper() noexcept -> TypeList<Ts...> { return {}; }\
+ template<template<typename...> class...> struct TemplateTypeList{};\
+ template<template<typename...> class...Cs>\
+ constexpr auto get_wrapper() noexcept -> TemplateTypeList<Cs...> { return {}; }\
+ template<typename...>\
+ struct append;\
+ template<typename...>\
+ struct rewrap;\
+ template<template<typename...> class, typename...>\
+ struct create;\
+ template<template<typename...> class, typename>\
+ struct convert;\
+ \
+ template<typename T> \
+ struct append<T> { using type = T; };\
+ template< template<typename...> class L1, typename...E1, template<typename...> class L2, typename...E2, typename...Rest>\
+ struct append<L1<E1...>, L2<E2...>, Rest...> { using type = typename append<L1<E1...,E2...>, Rest...>::type; };\
+ template< template<typename...> class L1, typename...E1, typename...Rest>\
+ struct append<L1<E1...>, TypeList<mpl_::na>, Rest...> { using type = L1<E1...>; };\
+ \
+ template< template<typename...> class Container, template<typename...> class List, typename...elems>\
+ struct rewrap<TemplateTypeList<Container>, List<elems...>> { using type = TypeList<Container<elems...>>; };\
+ template< template<typename...> class Container, template<typename...> class List, class...Elems, typename...Elements>\
+ struct rewrap<TemplateTypeList<Container>, List<Elems...>, Elements...> { using type = typename append<TypeList<Container<Elems...>>, typename rewrap<TemplateTypeList<Container>, Elements...>::type>::type; };\
+ \
+ template<template <typename...> class Final, template< typename...> class...Containers, typename...Types>\
+ struct create<Final, TemplateTypeList<Containers...>, TypeList<Types...>> { using type = typename append<Final<>, typename rewrap<TemplateTypeList<Containers>, Types...>::type...>::type; };\
+ template<template <typename...> class Final, template <typename...> class List, typename...Ts>\
+ struct convert<Final, List<Ts...>> { using type = typename append<Final<>,TypeList<Ts>...>::type; };
+
+#define INTERNAL_CATCH_NTTP_1(signature, ...)\
+ template<INTERNAL_CATCH_REMOVE_PARENS(signature)> struct Nttp{};\
+ template<INTERNAL_CATCH_REMOVE_PARENS(signature)>\
+ constexpr auto get_wrapper() noexcept -> Nttp<__VA_ARGS__> { return {}; } \
+ template<template<INTERNAL_CATCH_REMOVE_PARENS(signature)> class...> struct NttpTemplateTypeList{};\
+ template<template<INTERNAL_CATCH_REMOVE_PARENS(signature)> class...Cs>\
+ constexpr auto get_wrapper() noexcept -> NttpTemplateTypeList<Cs...> { return {}; } \
+ \
+ template< template<INTERNAL_CATCH_REMOVE_PARENS(signature)> class Container, template<INTERNAL_CATCH_REMOVE_PARENS(signature)> class List, INTERNAL_CATCH_REMOVE_PARENS(signature)>\
+ struct rewrap<NttpTemplateTypeList<Container>, List<__VA_ARGS__>> { using type = TypeList<Container<__VA_ARGS__>>; };\
+ template< template<INTERNAL_CATCH_REMOVE_PARENS(signature)> class Container, template<INTERNAL_CATCH_REMOVE_PARENS(signature)> class List, INTERNAL_CATCH_REMOVE_PARENS(signature), typename...Elements>\
+ struct rewrap<NttpTemplateTypeList<Container>, List<__VA_ARGS__>, Elements...> { using type = typename append<TypeList<Container<__VA_ARGS__>>, typename rewrap<NttpTemplateTypeList<Container>, Elements...>::type>::type; };\
+ template<template <typename...> class Final, template<INTERNAL_CATCH_REMOVE_PARENS(signature)> class...Containers, typename...Types>\
+ struct create<Final, NttpTemplateTypeList<Containers...>, TypeList<Types...>> { using type = typename append<Final<>, typename rewrap<NttpTemplateTypeList<Containers>, Types...>::type...>::type; };
+
+#define INTERNAL_CATCH_DECLARE_SIG_TEST0(TestName)
+#define INTERNAL_CATCH_DECLARE_SIG_TEST1(TestName, signature)\
+ template<INTERNAL_CATCH_REMOVE_PARENS(signature)>\
+ static void TestName()
+#define INTERNAL_CATCH_DECLARE_SIG_TEST_X(TestName, signature, ...)\
+ template<INTERNAL_CATCH_REMOVE_PARENS(signature)>\
+ static void TestName()
+
+#define INTERNAL_CATCH_DEFINE_SIG_TEST0(TestName)
+#define INTERNAL_CATCH_DEFINE_SIG_TEST1(TestName, signature)\
+ template<INTERNAL_CATCH_REMOVE_PARENS(signature)>\
+ static void TestName()
+#define INTERNAL_CATCH_DEFINE_SIG_TEST_X(TestName, signature,...)\
+ template<INTERNAL_CATCH_REMOVE_PARENS(signature)>\
+ static void TestName()
+
+#define INTERNAL_CATCH_NTTP_REGISTER0(TestFunc, signature)\
+ template<typename Type>\
+ void reg_test(TypeList<Type>, Catch::NameAndTags nameAndTags)\
+ {\
+ Catch::AutoReg( Catch::makeTestInvoker(&TestFunc<Type>), CATCH_INTERNAL_LINEINFO, Catch::StringRef(), nameAndTags);\
+ }
+
+#define INTERNAL_CATCH_NTTP_REGISTER(TestFunc, signature, ...)\
+ template<INTERNAL_CATCH_REMOVE_PARENS(signature)>\
+ void reg_test(Nttp<__VA_ARGS__>, Catch::NameAndTags nameAndTags)\
+ {\
+ Catch::AutoReg( Catch::makeTestInvoker(&TestFunc<__VA_ARGS__>), CATCH_INTERNAL_LINEINFO, Catch::StringRef(), nameAndTags);\
+ }
+
+#define INTERNAL_CATCH_NTTP_REGISTER_METHOD0(TestName, signature, ...)\
+ template<typename Type>\
+ void reg_test(TypeList<Type>, Catch::StringRef className, Catch::NameAndTags nameAndTags)\
+ {\
+ Catch::AutoReg( Catch::makeTestInvoker(&TestName<Type>::test), CATCH_INTERNAL_LINEINFO, className, nameAndTags);\
+ }
+
+#define INTERNAL_CATCH_NTTP_REGISTER_METHOD(TestName, signature, ...)\
+ template<INTERNAL_CATCH_REMOVE_PARENS(signature)>\
+ void reg_test(Nttp<__VA_ARGS__>, Catch::StringRef className, Catch::NameAndTags nameAndTags)\
+ {\
+ Catch::AutoReg( Catch::makeTestInvoker(&TestName<__VA_ARGS__>::test), CATCH_INTERNAL_LINEINFO, className, nameAndTags);\
+ }
+
+#define INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD0(TestName, ClassName)
+#define INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD1(TestName, ClassName, signature)\
+ template<typename TestType> \
+ struct TestName : INTERNAL_CATCH_REMOVE_PARENS(ClassName)<TestType> { \
+ void test();\
+ }
+
+#define INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X(TestName, ClassName, signature, ...)\
+ template<INTERNAL_CATCH_REMOVE_PARENS(signature)> \
+ struct TestName : INTERNAL_CATCH_REMOVE_PARENS(ClassName)<__VA_ARGS__> { \
+ void test();\
+ }
+
+#define INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD0(TestName)
+#define INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD1(TestName, signature)\
+ template<typename TestType> \
+ void INTERNAL_CATCH_MAKE_NAMESPACE(TestName)::TestName<TestType>::test()
+#define INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X(TestName, signature, ...)\
+ template<INTERNAL_CATCH_REMOVE_PARENS(signature)> \
+ void INTERNAL_CATCH_MAKE_NAMESPACE(TestName)::TestName<__VA_ARGS__>::test()
+
+#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+#define INTERNAL_CATCH_NTTP_0
+#define INTERNAL_CATCH_NTTP_GEN(...) INTERNAL_CATCH_VA_NARGS_IMPL(__VA_ARGS__, INTERNAL_CATCH_NTTP_1(__VA_ARGS__), INTERNAL_CATCH_NTTP_1(__VA_ARGS__), INTERNAL_CATCH_NTTP_1(__VA_ARGS__), INTERNAL_CATCH_NTTP_1(__VA_ARGS__), INTERNAL_CATCH_NTTP_1(__VA_ARGS__), INTERNAL_CATCH_NTTP_1( __VA_ARGS__), INTERNAL_CATCH_NTTP_1( __VA_ARGS__), INTERNAL_CATCH_NTTP_1( __VA_ARGS__), INTERNAL_CATCH_NTTP_1( __VA_ARGS__),INTERNAL_CATCH_NTTP_1( __VA_ARGS__), INTERNAL_CATCH_NTTP_0)
+#define INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD(TestName, ...) INTERNAL_CATCH_VA_NARGS_IMPL( "dummy", __VA_ARGS__, INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X,INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X,INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X,INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD1, INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD0)(TestName, __VA_ARGS__)
+#define INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD(TestName, ClassName, ...) INTERNAL_CATCH_VA_NARGS_IMPL( "dummy", __VA_ARGS__, INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X,INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X,INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X,INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD1, INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD0)(TestName, ClassName, __VA_ARGS__)
+#define INTERNAL_CATCH_NTTP_REG_METHOD_GEN(TestName, ...) INTERNAL_CATCH_VA_NARGS_IMPL( "dummy", __VA_ARGS__, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD0, INTERNAL_CATCH_NTTP_REGISTER_METHOD0)(TestName, __VA_ARGS__)
+#define INTERNAL_CATCH_NTTP_REG_GEN(TestFunc, ...) INTERNAL_CATCH_VA_NARGS_IMPL( "dummy", __VA_ARGS__, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER0, INTERNAL_CATCH_NTTP_REGISTER0)(TestFunc, __VA_ARGS__)
+#define INTERNAL_CATCH_DEFINE_SIG_TEST(TestName, ...) INTERNAL_CATCH_VA_NARGS_IMPL( "dummy", __VA_ARGS__, INTERNAL_CATCH_DEFINE_SIG_TEST_X, INTERNAL_CATCH_DEFINE_SIG_TEST_X, INTERNAL_CATCH_DEFINE_SIG_TEST_X, INTERNAL_CATCH_DEFINE_SIG_TEST_X, INTERNAL_CATCH_DEFINE_SIG_TEST_X, INTERNAL_CATCH_DEFINE_SIG_TEST_X, INTERNAL_CATCH_DEFINE_SIG_TEST_X, INTERNAL_CATCH_DEFINE_SIG_TEST_X,INTERNAL_CATCH_DEFINE_SIG_TEST_X,INTERNAL_CATCH_DEFINE_SIG_TEST1, INTERNAL_CATCH_DEFINE_SIG_TEST0)(TestName, __VA_ARGS__)
+#define INTERNAL_CATCH_DECLARE_SIG_TEST(TestName, ...) INTERNAL_CATCH_VA_NARGS_IMPL( "dummy", __VA_ARGS__, INTERNAL_CATCH_DECLARE_SIG_TEST_X,INTERNAL_CATCH_DECLARE_SIG_TEST_X, INTERNAL_CATCH_DECLARE_SIG_TEST_X, INTERNAL_CATCH_DECLARE_SIG_TEST_X, INTERNAL_CATCH_DECLARE_SIG_TEST_X, INTERNAL_CATCH_DECLARE_SIG_TEST_X, INTERNAL_CATCH_DEFINE_SIG_TEST_X,INTERNAL_CATCH_DECLARE_SIG_TEST_X,INTERNAL_CATCH_DECLARE_SIG_TEST_X, INTERNAL_CATCH_DECLARE_SIG_TEST1, INTERNAL_CATCH_DECLARE_SIG_TEST0)(TestName, __VA_ARGS__)
+#define INTERNAL_CATCH_REMOVE_PARENS_GEN(...) INTERNAL_CATCH_VA_NARGS_IMPL(__VA_ARGS__, INTERNAL_CATCH_REMOVE_PARENS_11_ARG,INTERNAL_CATCH_REMOVE_PARENS_10_ARG,INTERNAL_CATCH_REMOVE_PARENS_9_ARG,INTERNAL_CATCH_REMOVE_PARENS_8_ARG,INTERNAL_CATCH_REMOVE_PARENS_7_ARG,INTERNAL_CATCH_REMOVE_PARENS_6_ARG,INTERNAL_CATCH_REMOVE_PARENS_5_ARG,INTERNAL_CATCH_REMOVE_PARENS_4_ARG,INTERNAL_CATCH_REMOVE_PARENS_3_ARG,INTERNAL_CATCH_REMOVE_PARENS_2_ARG,INTERNAL_CATCH_REMOVE_PARENS_1_ARG)(__VA_ARGS__)
+#else
+#define INTERNAL_CATCH_NTTP_0(signature)
+#define INTERNAL_CATCH_NTTP_GEN(...) INTERNAL_CATCH_EXPAND_VARGS(INTERNAL_CATCH_VA_NARGS_IMPL(__VA_ARGS__, INTERNAL_CATCH_NTTP_1, INTERNAL_CATCH_NTTP_1, INTERNAL_CATCH_NTTP_1, INTERNAL_CATCH_NTTP_1, INTERNAL_CATCH_NTTP_1, INTERNAL_CATCH_NTTP_1, INTERNAL_CATCH_NTTP_1, INTERNAL_CATCH_NTTP_1, INTERNAL_CATCH_NTTP_1,INTERNAL_CATCH_NTTP_1, INTERNAL_CATCH_NTTP_0)( __VA_ARGS__))
+#define INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD(TestName, ...) INTERNAL_CATCH_EXPAND_VARGS(INTERNAL_CATCH_VA_NARGS_IMPL( "dummy", __VA_ARGS__, INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X,INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X,INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X,INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD1, INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD0)(TestName, __VA_ARGS__))
+#define INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD(TestName, ClassName, ...) INTERNAL_CATCH_EXPAND_VARGS(INTERNAL_CATCH_VA_NARGS_IMPL( "dummy", __VA_ARGS__, INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X,INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X,INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X,INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD_X, INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD1, INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD0)(TestName, ClassName, __VA_ARGS__))
+#define INTERNAL_CATCH_NTTP_REG_METHOD_GEN(TestName, ...) INTERNAL_CATCH_EXPAND_VARGS(INTERNAL_CATCH_VA_NARGS_IMPL( "dummy", __VA_ARGS__, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD, INTERNAL_CATCH_NTTP_REGISTER_METHOD0, INTERNAL_CATCH_NTTP_REGISTER_METHOD0)(TestName, __VA_ARGS__))
+#define INTERNAL_CATCH_NTTP_REG_GEN(TestFunc, ...) INTERNAL_CATCH_EXPAND_VARGS(INTERNAL_CATCH_VA_NARGS_IMPL( "dummy", __VA_ARGS__, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER, INTERNAL_CATCH_NTTP_REGISTER0, INTERNAL_CATCH_NTTP_REGISTER0)(TestFunc, __VA_ARGS__))
+#define INTERNAL_CATCH_DEFINE_SIG_TEST(TestName, ...) INTERNAL_CATCH_EXPAND_VARGS(INTERNAL_CATCH_VA_NARGS_IMPL( "dummy", __VA_ARGS__, INTERNAL_CATCH_DEFINE_SIG_TEST_X, INTERNAL_CATCH_DEFINE_SIG_TEST_X, INTERNAL_CATCH_DEFINE_SIG_TEST_X, INTERNAL_CATCH_DEFINE_SIG_TEST_X, INTERNAL_CATCH_DEFINE_SIG_TEST_X, INTERNAL_CATCH_DEFINE_SIG_TEST_X, INTERNAL_CATCH_DEFINE_SIG_TEST_X, INTERNAL_CATCH_DEFINE_SIG_TEST_X,INTERNAL_CATCH_DEFINE_SIG_TEST_X,INTERNAL_CATCH_DEFINE_SIG_TEST1, INTERNAL_CATCH_DEFINE_SIG_TEST0)(TestName, __VA_ARGS__))
+#define INTERNAL_CATCH_DECLARE_SIG_TEST(TestName, ...) INTERNAL_CATCH_EXPAND_VARGS(INTERNAL_CATCH_VA_NARGS_IMPL( "dummy", __VA_ARGS__, INTERNAL_CATCH_DECLARE_SIG_TEST_X,INTERNAL_CATCH_DECLARE_SIG_TEST_X, INTERNAL_CATCH_DECLARE_SIG_TEST_X, INTERNAL_CATCH_DECLARE_SIG_TEST_X, INTERNAL_CATCH_DECLARE_SIG_TEST_X, INTERNAL_CATCH_DECLARE_SIG_TEST_X, INTERNAL_CATCH_DEFINE_SIG_TEST_X,INTERNAL_CATCH_DECLARE_SIG_TEST_X,INTERNAL_CATCH_DECLARE_SIG_TEST_X, INTERNAL_CATCH_DECLARE_SIG_TEST1, INTERNAL_CATCH_DECLARE_SIG_TEST0)(TestName, __VA_ARGS__))
+#define INTERNAL_CATCH_REMOVE_PARENS_GEN(...) INTERNAL_CATCH_EXPAND_VARGS(INTERNAL_CATCH_VA_NARGS_IMPL(__VA_ARGS__, INTERNAL_CATCH_REMOVE_PARENS_11_ARG,INTERNAL_CATCH_REMOVE_PARENS_10_ARG,INTERNAL_CATCH_REMOVE_PARENS_9_ARG,INTERNAL_CATCH_REMOVE_PARENS_8_ARG,INTERNAL_CATCH_REMOVE_PARENS_7_ARG,INTERNAL_CATCH_REMOVE_PARENS_6_ARG,INTERNAL_CATCH_REMOVE_PARENS_5_ARG,INTERNAL_CATCH_REMOVE_PARENS_4_ARG,INTERNAL_CATCH_REMOVE_PARENS_3_ARG,INTERNAL_CATCH_REMOVE_PARENS_2_ARG,INTERNAL_CATCH_REMOVE_PARENS_1_ARG)(__VA_ARGS__))
+#endif
- class TestSpec;
+// end catch_preprocessor.hpp
+// start catch_meta.hpp
- struct ITestCase : IShared {
- virtual void invoke () const = 0;
- protected:
- virtual ~ITestCase();
- };
- class TestCase;
- struct IConfig;
+#include <type_traits>
- struct ITestCaseRegistry {
- virtual ~ITestCaseRegistry();
- virtual std::vector<TestCase> const& getAllTests() const = 0;
- virtual std::vector<TestCase> const& getAllTestsSorted( IConfig const& config ) const = 0;
- };
+namespace Catch {
+template<typename T>
+struct always_false : std::false_type {};
+
+template <typename> struct true_given : std::true_type {};
+struct is_callable_tester {
+ template <typename Fun, typename... Args>
+ true_given<decltype(std::declval<Fun>()(std::declval<Args>()...))> static test(int);
+ template <typename...>
+ std::false_type static test(...);
+};
- bool matchTest( TestCase const& testCase, TestSpec const& testSpec, IConfig const& config );
- std::vector<TestCase> filterTests( std::vector<TestCase> const& testCases, TestSpec const& testSpec, IConfig const& config );
- std::vector<TestCase> const& getAllTestCasesSorted( IConfig const& config );
+template <typename T>
+struct is_callable;
+
+template <typename Fun, typename... Args>
+struct is_callable<Fun(Args...)> : decltype(is_callable_tester::test<Fun, Args...>(0)) {};
+
+} // namespace Catch
+namespace mpl_{
+ struct na;
}
+// end catch_meta.hpp
namespace Catch {
template<typename C>
-class MethodTestCase : public SharedImpl<ITestCase> {
-
+class TestInvokerAsMethod : public ITestInvoker {
+ void (C::*m_testAsMethod)();
public:
- MethodTestCase( void (C::*method)() ) : m_method( method ) {}
+ TestInvokerAsMethod( void (C::*testAsMethod)() ) noexcept : m_testAsMethod( testAsMethod ) {}
- virtual void invoke() const {
+ void invoke() const override {
C obj;
- (obj.*m_method)();
+ (obj.*m_testAsMethod)();
}
-
-private:
- virtual ~MethodTestCase() {}
-
- void (C::*m_method)();
};
-typedef void(*TestFunction)();
+auto makeTestInvoker( void(*testAsFunction)() ) noexcept -> ITestInvoker*;
-struct NameAndDesc {
- NameAndDesc( const char* _name = "", const char* _description= "" )
- : name( _name ), description( _description )
- {}
+template<typename C>
+auto makeTestInvoker( void (C::*testAsMethod)() ) noexcept -> ITestInvoker* {
+ return new(std::nothrow) TestInvokerAsMethod<C>( testAsMethod );
+}
- const char* name;
- const char* description;
+struct NameAndTags {
+ NameAndTags( StringRef const& name_ = StringRef(), StringRef const& tags_ = StringRef() ) noexcept;
+ StringRef name;
+ StringRef tags;
};
-void registerTestCase
- ( ITestCase* testCase,
- char const* className,
- NameAndDesc const& nameAndDesc,
- SourceLineInfo const& lineInfo );
-
-struct AutoReg {
-
- AutoReg
- ( TestFunction function,
- SourceLineInfo const& lineInfo,
- NameAndDesc const& nameAndDesc );
-
- template<typename C>
- AutoReg
- ( void (C::*method)(),
- char const* className,
- NameAndDesc const& nameAndDesc,
- SourceLineInfo const& lineInfo ) {
+struct AutoReg : NonCopyable {
+ AutoReg( ITestInvoker* invoker, SourceLineInfo const& lineInfo, StringRef const& classOrMethod, NameAndTags const& nameAndTags ) noexcept;
+ ~AutoReg();
+};
- registerTestCase
- ( new MethodTestCase<C>( method ),
- className,
- nameAndDesc,
- lineInfo );
- }
+} // end namespace Catch
- ~AutoReg();
+#if defined(CATCH_CONFIG_DISABLE)
+ #define INTERNAL_CATCH_TESTCASE_NO_REGISTRATION( TestName, ... ) \
+ static void TestName()
+ #define INTERNAL_CATCH_TESTCASE_METHOD_NO_REGISTRATION( TestName, ClassName, ... ) \
+ namespace{ \
+ struct TestName : INTERNAL_CATCH_REMOVE_PARENS(ClassName) { \
+ void test(); \
+ }; \
+ } \
+ void TestName::test()
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION_2( TestName, TestFunc, Name, Tags, Signature, ... ) \
+ INTERNAL_CATCH_DEFINE_SIG_TEST(TestFunc, INTERNAL_CATCH_REMOVE_PARENS(Signature))
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION_2( TestNameClass, TestName, ClassName, Name, Tags, Signature, ... ) \
+ namespace{ \
+ namespace INTERNAL_CATCH_MAKE_NAMESPACE(TestName) { \
+ INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD(TestName, ClassName, INTERNAL_CATCH_REMOVE_PARENS(Signature));\
+ } \
+ } \
+ INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD(TestName, INTERNAL_CATCH_REMOVE_PARENS(Signature))
+
+ #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(Name, Tags, ...) \
+ INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, typename TestType, __VA_ARGS__ )
+ #else
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(Name, Tags, ...) \
+ INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, typename TestType, __VA_ARGS__ ) )
+ #endif
-private:
- AutoReg( AutoReg const& );
- void operator= ( AutoReg const& );
-};
+ #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_SIG_NO_REGISTRATION(Name, Tags, Signature, ...) \
+ INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, Signature, __VA_ARGS__ )
+ #else
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_SIG_NO_REGISTRATION(Name, Tags, Signature, ...) \
+ INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, Signature, __VA_ARGS__ ) )
+ #endif
-void registerTestCaseFunction
- ( TestFunction function,
- SourceLineInfo const& lineInfo,
- NameAndDesc const& nameAndDesc );
+ #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION( ClassName, Name, Tags,... ) \
+ INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, typename T, __VA_ARGS__ )
+ #else
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION( ClassName, Name, Tags,... ) \
+ INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, typename T, __VA_ARGS__ ) )
+ #endif
-} // end namespace Catch
+ #ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_SIG_NO_REGISTRATION( ClassName, Name, Tags, Signature, ... ) \
+ INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, Signature, __VA_ARGS__ )
+ #else
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_SIG_NO_REGISTRATION( ClassName, Name, Tags, Signature, ... ) \
+ INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, Signature, __VA_ARGS__ ) )
+ #endif
+#endif
-#ifdef CATCH_CONFIG_VARIADIC_MACROS
///////////////////////////////////////////////////////////////////////////////
#define INTERNAL_CATCH_TESTCASE2( TestName, ... ) \
static void TestName(); \
- CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \
- namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( &TestName, CATCH_INTERNAL_LINEINFO, Catch::NameAndDesc( __VA_ARGS__ ) ); } \
- CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS \
+ CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \
+ namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( Catch::makeTestInvoker( &TestName ), CATCH_INTERNAL_LINEINFO, Catch::StringRef(), Catch::NameAndTags{ __VA_ARGS__ } ); } /* NOLINT */ \
+ CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \
static void TestName()
#define INTERNAL_CATCH_TESTCASE( ... ) \
INTERNAL_CATCH_TESTCASE2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), __VA_ARGS__ )
///////////////////////////////////////////////////////////////////////////////
#define INTERNAL_CATCH_METHOD_AS_TEST_CASE( QualifiedMethod, ... ) \
- CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \
- namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( &QualifiedMethod, "&" #QualifiedMethod, Catch::NameAndDesc( __VA_ARGS__ ), CATCH_INTERNAL_LINEINFO ); } \
- CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS
+ CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \
+ namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( Catch::makeTestInvoker( &QualifiedMethod ), CATCH_INTERNAL_LINEINFO, "&" #QualifiedMethod, Catch::NameAndTags{ __VA_ARGS__ } ); } /* NOLINT */ \
+ CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS
///////////////////////////////////////////////////////////////////////////////
#define INTERNAL_CATCH_TEST_CASE_METHOD2( TestName, ClassName, ... )\
- CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \
+ CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \
namespace{ \
- struct TestName : ClassName{ \
+ struct TestName : INTERNAL_CATCH_REMOVE_PARENS(ClassName) { \
void test(); \
}; \
- Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar ) ( &TestName::test, #ClassName, Catch::NameAndDesc( __VA_ARGS__ ), CATCH_INTERNAL_LINEINFO ); \
+ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar ) ( Catch::makeTestInvoker( &TestName::test ), CATCH_INTERNAL_LINEINFO, #ClassName, Catch::NameAndTags{ __VA_ARGS__ } ); /* NOLINT */ \
} \
- CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS \
+ CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \
void TestName::test()
#define INTERNAL_CATCH_TEST_CASE_METHOD( ClassName, ... ) \
INTERNAL_CATCH_TEST_CASE_METHOD2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), ClassName, __VA_ARGS__ )
///////////////////////////////////////////////////////////////////////////////
#define INTERNAL_CATCH_REGISTER_TESTCASE( Function, ... ) \
- CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \
- Catch::AutoReg( Function, CATCH_INTERNAL_LINEINFO, Catch::NameAndDesc( __VA_ARGS__ ) ); \
- CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS
+ CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \
+ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( Catch::makeTestInvoker( Function ), CATCH_INTERNAL_LINEINFO, Catch::StringRef(), Catch::NameAndTags{ __VA_ARGS__ } ); /* NOLINT */ \
+ CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS
-#else
///////////////////////////////////////////////////////////////////////////////
- #define INTERNAL_CATCH_TESTCASE2( TestName, Name, Desc ) \
- static void TestName(); \
- CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \
- namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( &TestName, CATCH_INTERNAL_LINEINFO, Catch::NameAndDesc( Name, Desc ) ); }\
- CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS \
- static void TestName()
- #define INTERNAL_CATCH_TESTCASE( Name, Desc ) \
- INTERNAL_CATCH_TESTCASE2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), Name, Desc )
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_2(TestName, TestFunc, Name, Tags, Signature, ... )\
+ CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \
+ CATCH_INTERNAL_SUPPRESS_ZERO_VARIADIC_WARNINGS \
+ CATCH_INTERNAL_SUPPRESS_UNUSED_TEMPLATE_WARNINGS \
+ INTERNAL_CATCH_DECLARE_SIG_TEST(TestFunc, INTERNAL_CATCH_REMOVE_PARENS(Signature));\
+ namespace {\
+ namespace INTERNAL_CATCH_MAKE_NAMESPACE(TestName){\
+ INTERNAL_CATCH_TYPE_GEN\
+ INTERNAL_CATCH_NTTP_GEN(INTERNAL_CATCH_REMOVE_PARENS(Signature))\
+ INTERNAL_CATCH_NTTP_REG_GEN(TestFunc,INTERNAL_CATCH_REMOVE_PARENS(Signature))\
+ template<typename...Types> \
+ struct TestName{\
+ TestName(){\
+ int index = 0; \
+ constexpr char const* tmpl_types[] = {CATCH_REC_LIST(INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS, __VA_ARGS__)};\
+ using expander = int[];\
+ (void)expander{(reg_test(Types{}, Catch::NameAndTags{ Name " - " + std::string(tmpl_types[index]), Tags } ), index++, 0)... };/* NOLINT */ \
+ }\
+ };\
+ static int INTERNAL_CATCH_UNIQUE_NAME( globalRegistrar ) = [](){\
+ TestName<INTERNAL_CATCH_MAKE_TYPE_LISTS_FROM_TYPES(__VA_ARGS__)>();\
+ return 0;\
+ }();\
+ }\
+ }\
+ CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \
+ CATCH_INTERNAL_UNSUPPRESS_ZERO_VARIADIC_WARNINGS \
+ CATCH_INTERNAL_UNSUPPRESS_UNUSED_TEMPLATE_WARNINGS \
+ INTERNAL_CATCH_DEFINE_SIG_TEST(TestFunc,INTERNAL_CATCH_REMOVE_PARENS(Signature))
+
+#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE(Name, Tags, ...) \
+ INTERNAL_CATCH_TEMPLATE_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, typename TestType, __VA_ARGS__ )
+#else
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE(Name, Tags, ...) \
+ INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, typename TestType, __VA_ARGS__ ) )
+#endif
- ///////////////////////////////////////////////////////////////////////////////
- #define INTERNAL_CATCH_METHOD_AS_TEST_CASE( QualifiedMethod, Name, Desc ) \
- CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \
- namespace{ Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar )( &QualifiedMethod, "&" #QualifiedMethod, Catch::NameAndDesc( Name, Desc ), CATCH_INTERNAL_LINEINFO ); } \
- CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS
+#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_SIG(Name, Tags, Signature, ...) \
+ INTERNAL_CATCH_TEMPLATE_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, Signature, __VA_ARGS__ )
+#else
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_SIG(Name, Tags, Signature, ...) \
+ INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, Signature, __VA_ARGS__ ) )
+#endif
- ///////////////////////////////////////////////////////////////////////////////
- #define INTERNAL_CATCH_TEST_CASE_METHOD2( TestCaseName, ClassName, TestName, Desc )\
- CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \
- namespace{ \
- struct TestCaseName : ClassName{ \
- void test(); \
- }; \
- Catch::AutoReg INTERNAL_CATCH_UNIQUE_NAME( autoRegistrar ) ( &TestCaseName::test, #ClassName, Catch::NameAndDesc( TestName, Desc ), CATCH_INTERNAL_LINEINFO ); \
- } \
- CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS \
- void TestCaseName::test()
- #define INTERNAL_CATCH_TEST_CASE_METHOD( ClassName, TestName, Desc )\
- INTERNAL_CATCH_TEST_CASE_METHOD2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), ClassName, TestName, Desc )
+ #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2(TestName, TestFuncName, Name, Tags, Signature, TmplTypes, TypesList) \
+ CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \
+ CATCH_INTERNAL_SUPPRESS_ZERO_VARIADIC_WARNINGS \
+ CATCH_INTERNAL_SUPPRESS_UNUSED_TEMPLATE_WARNINGS \
+ template<typename TestType> static void TestFuncName(); \
+ namespace {\
+ namespace INTERNAL_CATCH_MAKE_NAMESPACE(TestName) { \
+ INTERNAL_CATCH_TYPE_GEN \
+ INTERNAL_CATCH_NTTP_GEN(INTERNAL_CATCH_REMOVE_PARENS(Signature)) \
+ template<typename... Types> \
+ struct TestName { \
+ void reg_tests() { \
+ int index = 0; \
+ using expander = int[]; \
+ constexpr char const* tmpl_types[] = {CATCH_REC_LIST(INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS, INTERNAL_CATCH_REMOVE_PARENS(TmplTypes))};\
+ constexpr char const* types_list[] = {CATCH_REC_LIST(INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS, INTERNAL_CATCH_REMOVE_PARENS(TypesList))};\
+ constexpr auto num_types = sizeof(types_list) / sizeof(types_list[0]);\
+ (void)expander{(Catch::AutoReg( Catch::makeTestInvoker( &TestFuncName<Types> ), CATCH_INTERNAL_LINEINFO, Catch::StringRef(), Catch::NameAndTags{ Name " - " + std::string(tmpl_types[index / num_types]) + "<" + std::string(types_list[index % num_types]) + ">", Tags } ), index++, 0)... };/* NOLINT */\
+ } \
+ }; \
+ static int INTERNAL_CATCH_UNIQUE_NAME( globalRegistrar ) = [](){ \
+ using TestInit = typename create<TestName, decltype(get_wrapper<INTERNAL_CATCH_REMOVE_PARENS(TmplTypes)>()), TypeList<INTERNAL_CATCH_MAKE_TYPE_LISTS_FROM_TYPES(INTERNAL_CATCH_REMOVE_PARENS(TypesList))>>::type; \
+ TestInit t; \
+ t.reg_tests(); \
+ return 0; \
+ }(); \
+ } \
+ } \
+ CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \
+ CATCH_INTERNAL_UNSUPPRESS_ZERO_VARIADIC_WARNINGS \
+ CATCH_INTERNAL_UNSUPPRESS_UNUSED_TEMPLATE_WARNINGS \
+ template<typename TestType> \
+ static void TestFuncName()
+
+#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+ #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE(Name, Tags, ...)\
+ INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, typename T,__VA_ARGS__)
+#else
+ #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE(Name, Tags, ...)\
+ INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, typename T, __VA_ARGS__ ) )
+#endif
- ///////////////////////////////////////////////////////////////////////////////
- #define INTERNAL_CATCH_REGISTER_TESTCASE( Function, Name, Desc ) \
- CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS \
- Catch::AutoReg( Function, CATCH_INTERNAL_LINEINFO, Catch::NameAndDesc( Name, Desc ) ); \
- CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS
+#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+ #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_SIG(Name, Tags, Signature, ...)\
+ INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, Signature, __VA_ARGS__)
+#else
+ #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_SIG(Name, Tags, Signature, ...)\
+ INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, Signature, __VA_ARGS__ ) )
+#endif
+ #define INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE_2(TestName, TestFunc, Name, Tags, TmplList)\
+ CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \
+ CATCH_INTERNAL_SUPPRESS_UNUSED_TEMPLATE_WARNINGS \
+ template<typename TestType> static void TestFunc(); \
+ namespace {\
+ namespace INTERNAL_CATCH_MAKE_NAMESPACE(TestName){\
+ INTERNAL_CATCH_TYPE_GEN\
+ template<typename... Types> \
+ struct TestName { \
+ void reg_tests() { \
+ int index = 0; \
+ using expander = int[]; \
+ (void)expander{(Catch::AutoReg( Catch::makeTestInvoker( &TestFunc<Types> ), CATCH_INTERNAL_LINEINFO, Catch::StringRef(), Catch::NameAndTags{ Name " - " + std::string(INTERNAL_CATCH_STRINGIZE(TmplList)) + " - " + std::to_string(index), Tags } ), index++, 0)... };/* NOLINT */\
+ } \
+ };\
+ static int INTERNAL_CATCH_UNIQUE_NAME( globalRegistrar ) = [](){ \
+ using TestInit = typename convert<TestName, TmplList>::type; \
+ TestInit t; \
+ t.reg_tests(); \
+ return 0; \
+ }(); \
+ }}\
+ CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \
+ CATCH_INTERNAL_UNSUPPRESS_UNUSED_TEMPLATE_WARNINGS \
+ template<typename TestType> \
+ static void TestFunc()
+
+ #define INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE(Name, Tags, TmplList) \
+ INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), Name, Tags, TmplList )
+
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( TestNameClass, TestName, ClassName, Name, Tags, Signature, ... ) \
+ CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \
+ CATCH_INTERNAL_SUPPRESS_ZERO_VARIADIC_WARNINGS \
+ CATCH_INTERNAL_SUPPRESS_UNUSED_TEMPLATE_WARNINGS \
+ namespace {\
+ namespace INTERNAL_CATCH_MAKE_NAMESPACE(TestName){ \
+ INTERNAL_CATCH_TYPE_GEN\
+ INTERNAL_CATCH_NTTP_GEN(INTERNAL_CATCH_REMOVE_PARENS(Signature))\
+ INTERNAL_CATCH_DECLARE_SIG_TEST_METHOD(TestName, ClassName, INTERNAL_CATCH_REMOVE_PARENS(Signature));\
+ INTERNAL_CATCH_NTTP_REG_METHOD_GEN(TestName, INTERNAL_CATCH_REMOVE_PARENS(Signature))\
+ template<typename...Types> \
+ struct TestNameClass{\
+ TestNameClass(){\
+ int index = 0; \
+ constexpr char const* tmpl_types[] = {CATCH_REC_LIST(INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS, __VA_ARGS__)};\
+ using expander = int[];\
+ (void)expander{(reg_test(Types{}, #ClassName, Catch::NameAndTags{ Name " - " + std::string(tmpl_types[index]), Tags } ), index++, 0)... };/* NOLINT */ \
+ }\
+ };\
+ static int INTERNAL_CATCH_UNIQUE_NAME( globalRegistrar ) = [](){\
+ TestNameClass<INTERNAL_CATCH_MAKE_TYPE_LISTS_FROM_TYPES(__VA_ARGS__)>();\
+ return 0;\
+ }();\
+ }\
+ }\
+ CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS\
+ CATCH_INTERNAL_UNSUPPRESS_ZERO_VARIADIC_WARNINGS\
+ CATCH_INTERNAL_UNSUPPRESS_UNUSED_TEMPLATE_WARNINGS\
+ INTERNAL_CATCH_DEFINE_SIG_TEST_METHOD(TestName, INTERNAL_CATCH_REMOVE_PARENS(Signature))
+
+#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( ClassName, Name, Tags,... ) \
+ INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, typename T, __VA_ARGS__ )
+#else
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( ClassName, Name, Tags,... ) \
+ INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, typename T, __VA_ARGS__ ) )
#endif
-// #included from: internal/catch_capture.hpp
-#define TWOBLUECUBES_CATCH_CAPTURE_HPP_INCLUDED
+#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_SIG( ClassName, Name, Tags, Signature, ... ) \
+ INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, Signature, __VA_ARGS__ )
+#else
+ #define INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_SIG( ClassName, Name, Tags, Signature, ... ) \
+ INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____C_L_A_S_S____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ) , ClassName, Name, Tags, Signature, __VA_ARGS__ ) )
+#endif
+
+ #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2(TestNameClass, TestName, ClassName, Name, Tags, Signature, TmplTypes, TypesList)\
+ CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \
+ CATCH_INTERNAL_SUPPRESS_ZERO_VARIADIC_WARNINGS \
+ CATCH_INTERNAL_SUPPRESS_UNUSED_TEMPLATE_WARNINGS \
+ template<typename TestType> \
+ struct TestName : INTERNAL_CATCH_REMOVE_PARENS(ClassName <TestType>) { \
+ void test();\
+ };\
+ namespace {\
+ namespace INTERNAL_CATCH_MAKE_NAMESPACE(TestNameClass) {\
+ INTERNAL_CATCH_TYPE_GEN \
+ INTERNAL_CATCH_NTTP_GEN(INTERNAL_CATCH_REMOVE_PARENS(Signature))\
+ template<typename...Types>\
+ struct TestNameClass{\
+ void reg_tests(){\
+ int index = 0;\
+ using expander = int[];\
+ constexpr char const* tmpl_types[] = {CATCH_REC_LIST(INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS, INTERNAL_CATCH_REMOVE_PARENS(TmplTypes))};\
+ constexpr char const* types_list[] = {CATCH_REC_LIST(INTERNAL_CATCH_STRINGIZE_WITHOUT_PARENS, INTERNAL_CATCH_REMOVE_PARENS(TypesList))};\
+ constexpr auto num_types = sizeof(types_list) / sizeof(types_list[0]);\
+ (void)expander{(Catch::AutoReg( Catch::makeTestInvoker( &TestName<Types>::test ), CATCH_INTERNAL_LINEINFO, #ClassName, Catch::NameAndTags{ Name " - " + std::string(tmpl_types[index / num_types]) + "<" + std::string(types_list[index % num_types]) + ">", Tags } ), index++, 0)... };/* NOLINT */ \
+ }\
+ };\
+ static int INTERNAL_CATCH_UNIQUE_NAME( globalRegistrar ) = [](){\
+ using TestInit = typename create<TestNameClass, decltype(get_wrapper<INTERNAL_CATCH_REMOVE_PARENS(TmplTypes)>()), TypeList<INTERNAL_CATCH_MAKE_TYPE_LISTS_FROM_TYPES(INTERNAL_CATCH_REMOVE_PARENS(TypesList))>>::type;\
+ TestInit t;\
+ t.reg_tests();\
+ return 0;\
+ }(); \
+ }\
+ }\
+ CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \
+ CATCH_INTERNAL_UNSUPPRESS_ZERO_VARIADIC_WARNINGS \
+ CATCH_INTERNAL_UNSUPPRESS_UNUSED_TEMPLATE_WARNINGS \
+ template<typename TestType> \
+ void TestName<TestType>::test()
+
+#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+ #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( ClassName, Name, Tags, ... )\
+ INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), ClassName, Name, Tags, typename T, __VA_ARGS__ )
+#else
+ #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( ClassName, Name, Tags, ... )\
+ INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), ClassName, Name, Tags, typename T,__VA_ARGS__ ) )
+#endif
-// #included from: catch_result_builder.h
-#define TWOBLUECUBES_CATCH_RESULT_BUILDER_H_INCLUDED
+#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+ #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_SIG( ClassName, Name, Tags, Signature, ... )\
+ INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), ClassName, Name, Tags, Signature, __VA_ARGS__ )
+#else
+ #define INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_SIG( ClassName, Name, Tags, Signature, ... )\
+ INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), ClassName, Name, Tags, Signature,__VA_ARGS__ ) )
+#endif
-// #included from: catch_result_type.h
-#define TWOBLUECUBES_CATCH_RESULT_TYPE_H_INCLUDED
+ #define INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE_METHOD_2( TestNameClass, TestName, ClassName, Name, Tags, TmplList) \
+ CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \
+ CATCH_INTERNAL_SUPPRESS_UNUSED_TEMPLATE_WARNINGS \
+ template<typename TestType> \
+ struct TestName : INTERNAL_CATCH_REMOVE_PARENS(ClassName <TestType>) { \
+ void test();\
+ };\
+ namespace {\
+ namespace INTERNAL_CATCH_MAKE_NAMESPACE(TestName){ \
+ INTERNAL_CATCH_TYPE_GEN\
+ template<typename...Types>\
+ struct TestNameClass{\
+ void reg_tests(){\
+ int index = 0;\
+ using expander = int[];\
+ (void)expander{(Catch::AutoReg( Catch::makeTestInvoker( &TestName<Types>::test ), CATCH_INTERNAL_LINEINFO, #ClassName, Catch::NameAndTags{ Name " - " + std::string(INTERNAL_CATCH_STRINGIZE(TmplList)) + " - " + std::to_string(index), Tags } ), index++, 0)... };/* NOLINT */ \
+ }\
+ };\
+ static int INTERNAL_CATCH_UNIQUE_NAME( globalRegistrar ) = [](){\
+ using TestInit = typename convert<TestNameClass, TmplList>::type;\
+ TestInit t;\
+ t.reg_tests();\
+ return 0;\
+ }(); \
+ }}\
+ CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \
+ CATCH_INTERNAL_UNSUPPRESS_UNUSED_TEMPLATE_WARNINGS \
+ template<typename TestType> \
+ void TestName<TestType>::test()
+
+#define INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE_METHOD(ClassName, Name, Tags, TmplList) \
+ INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE_METHOD_2( INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____ ), INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_M_P_L_A_T_E____T_E_S_T____F_U_N_C____ ), ClassName, Name, Tags, TmplList )
+
+// end catch_test_registry.h
+// start catch_capture.hpp
+
+// start catch_assertionhandler.h
+
+// start catch_assertioninfo.h
+
+// start catch_result_type.h
namespace Catch {
@@ -874,12 +1369,8 @@ namespace Catch {
}; };
- inline bool isOk( ResultWas::OfType resultType ) {
- return ( resultType & ResultWas::FailureBit ) == 0;
- }
- inline bool isJustInfo( int flags ) {
- return flags == ResultWas::Info;
- }
+ bool isOk( ResultWas::OfType resultType );
+ bool isJustInfo( int flags );
// ResultDisposition::Flags enum
struct ResultDisposition { enum Flags {
@@ -890,601 +1381,123 @@ namespace Catch {
SuppressFail = 0x08 // Failures are reported but do not fail the test
}; };
- inline ResultDisposition::Flags operator | ( ResultDisposition::Flags lhs, ResultDisposition::Flags rhs ) {
- return static_cast<ResultDisposition::Flags>( static_cast<int>( lhs ) | static_cast<int>( rhs ) );
- }
+ ResultDisposition::Flags operator | ( ResultDisposition::Flags lhs, ResultDisposition::Flags rhs );
- inline bool shouldContinueOnFailure( int flags ) { return ( flags & ResultDisposition::ContinueOnFailure ) != 0; }
- inline bool isFalseTest( int flags ) { return ( flags & ResultDisposition::FalseTest ) != 0; }
- inline bool shouldSuppressFailure( int flags ) { return ( flags & ResultDisposition::SuppressFail ) != 0; }
+ bool shouldContinueOnFailure( int flags );
+ inline bool isFalseTest( int flags ) { return ( flags & ResultDisposition::FalseTest ) != 0; }
+ bool shouldSuppressFailure( int flags );
} // end namespace Catch
-// #included from: catch_assertionresult.h
-#define TWOBLUECUBES_CATCH_ASSERTIONRESULT_H_INCLUDED
-
-#include <string>
-
+// end catch_result_type.h
namespace Catch {
- struct STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison;
-
- struct DecomposedExpression
- {
- virtual ~DecomposedExpression() {}
- virtual bool isBinaryExpression() const {
- return false;
- }
- virtual void reconstructExpression( std::string& dest ) const = 0;
-
- // Only simple binary comparisons can be decomposed.
- // If more complex check is required then wrap sub-expressions in parentheses.
- template<typename T> STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator + ( T const& );
- template<typename T> STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator - ( T const& );
- template<typename T> STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator * ( T const& );
- template<typename T> STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator / ( T const& );
- template<typename T> STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator % ( T const& );
- template<typename T> STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator && ( T const& );
- template<typename T> STATIC_ASSERT_Expression_Too_Complex_Please_Rewrite_As_Binary_Comparison& operator || ( T const& );
-
- private:
- DecomposedExpression& operator = (DecomposedExpression const&);
- };
-
struct AssertionInfo
{
- AssertionInfo() {}
- AssertionInfo( std::string const& _macroName,
- SourceLineInfo const& _lineInfo,
- std::string const& _capturedExpression,
- ResultDisposition::Flags _resultDisposition );
-
- std::string macroName;
+ StringRef macroName;
SourceLineInfo lineInfo;
- std::string capturedExpression;
+ StringRef capturedExpression;
ResultDisposition::Flags resultDisposition;
- };
-
- struct AssertionResultData
- {
- AssertionResultData() : decomposedExpression( CATCH_NULL )
- , resultType( ResultWas::Unknown )
- , negated( false )
- , parenthesized( false ) {}
-
- void negate( bool parenthesize ) {
- negated = !negated;
- parenthesized = parenthesize;
- if( resultType == ResultWas::Ok )
- resultType = ResultWas::ExpressionFailed;
- else if( resultType == ResultWas::ExpressionFailed )
- resultType = ResultWas::Ok;
- }
-
- std::string const& reconstructExpression() const {
- if( decomposedExpression != CATCH_NULL ) {
- decomposedExpression->reconstructExpression( reconstructedExpression );
- if( parenthesized ) {
- reconstructedExpression.insert( 0, 1, '(' );
- reconstructedExpression.append( 1, ')' );
- }
- if( negated ) {
- reconstructedExpression.insert( 0, 1, '!' );
- }
- decomposedExpression = CATCH_NULL;
- }
- return reconstructedExpression;
- }
-
- mutable DecomposedExpression const* decomposedExpression;
- mutable std::string reconstructedExpression;
- std::string message;
- ResultWas::OfType resultType;
- bool negated;
- bool parenthesized;
- };
-
- class AssertionResult {
- public:
- AssertionResult();
- AssertionResult( AssertionInfo const& info, AssertionResultData const& data );
- ~AssertionResult();
-# ifdef CATCH_CONFIG_CPP11_GENERATED_METHODS
- AssertionResult( AssertionResult const& ) = default;
- AssertionResult( AssertionResult && ) = default;
- AssertionResult& operator = ( AssertionResult const& ) = default;
- AssertionResult& operator = ( AssertionResult && ) = default;
-# endif
- bool isOk() const;
- bool succeeded() const;
- ResultWas::OfType getResultType() const;
- bool hasExpression() const;
- bool hasMessage() const;
- std::string getExpression() const;
- std::string getExpressionInMacro() const;
- bool hasExpandedExpression() const;
- std::string getExpandedExpression() const;
- std::string getMessage() const;
- SourceLineInfo getSourceInfo() const;
- std::string getTestMacroName() const;
- void discardDecomposedExpression() const;
- void expandDecomposedExpression() const;
-
- protected:
- AssertionInfo m_info;
- AssertionResultData m_resultData;
+ // We want to delete this constructor but a compiler bug in 4.8 means
+ // the struct is then treated as non-aggregate
+ //AssertionInfo() = delete;
};
} // end namespace Catch
-// #included from: catch_matchers.hpp
-#define TWOBLUECUBES_CATCH_MATCHERS_HPP_INCLUDED
-
-namespace Catch {
-namespace Matchers {
- namespace Impl {
-
- template<typename ArgT> struct MatchAllOf;
- template<typename ArgT> struct MatchAnyOf;
- template<typename ArgT> struct MatchNotOf;
-
- class MatcherUntypedBase {
- public:
- std::string toString() const {
- if( m_cachedToString.empty() )
- m_cachedToString = describe();
- return m_cachedToString;
- }
-
- protected:
- virtual ~MatcherUntypedBase();
- virtual std::string describe() const = 0;
- mutable std::string m_cachedToString;
- private:
- MatcherUntypedBase& operator = ( MatcherUntypedBase const& );
- };
-
- template<typename ObjectT>
- struct MatcherMethod {
- virtual bool match( ObjectT const& arg ) const = 0;
- };
- template<typename PtrT>
- struct MatcherMethod<PtrT*> {
- virtual bool match( PtrT* arg ) const = 0;
- };
-
- template<typename ObjectT, typename ComparatorT = ObjectT>
- struct MatcherBase : MatcherUntypedBase, MatcherMethod<ObjectT> {
-
- MatchAllOf<ComparatorT> operator && ( MatcherBase const& other ) const;
- MatchAnyOf<ComparatorT> operator || ( MatcherBase const& other ) const;
- MatchNotOf<ComparatorT> operator ! () const;
- };
-
- template<typename ArgT>
- struct MatchAllOf : MatcherBase<ArgT> {
- virtual bool match( ArgT const& arg ) const CATCH_OVERRIDE {
- for( std::size_t i = 0; i < m_matchers.size(); ++i ) {
- if (!m_matchers[i]->match(arg))
- return false;
- }
- return true;
- }
- virtual std::string describe() const CATCH_OVERRIDE {
- std::string description;
- description.reserve( 4 + m_matchers.size()*32 );
- description += "( ";
- for( std::size_t i = 0; i < m_matchers.size(); ++i ) {
- if( i != 0 )
- description += " and ";
- description += m_matchers[i]->toString();
- }
- description += " )";
- return description;
- }
-
- MatchAllOf<ArgT>& operator && ( MatcherBase<ArgT> const& other ) {
- m_matchers.push_back( &other );
- return *this;
- }
-
- std::vector<MatcherBase<ArgT> const*> m_matchers;
- };
- template<typename ArgT>
- struct MatchAnyOf : MatcherBase<ArgT> {
-
- virtual bool match( ArgT const& arg ) const CATCH_OVERRIDE {
- for( std::size_t i = 0; i < m_matchers.size(); ++i ) {
- if (m_matchers[i]->match(arg))
- return true;
- }
- return false;
- }
- virtual std::string describe() const CATCH_OVERRIDE {
- std::string description;
- description.reserve( 4 + m_matchers.size()*32 );
- description += "( ";
- for( std::size_t i = 0; i < m_matchers.size(); ++i ) {
- if( i != 0 )
- description += " or ";
- description += m_matchers[i]->toString();
- }
- description += " )";
- return description;
- }
-
- MatchAnyOf<ArgT>& operator || ( MatcherBase<ArgT> const& other ) {
- m_matchers.push_back( &other );
- return *this;
- }
-
- std::vector<MatcherBase<ArgT> const*> m_matchers;
- };
-
- template<typename ArgT>
- struct MatchNotOf : MatcherBase<ArgT> {
+// end catch_assertioninfo.h
+// start catch_decomposer.h
- MatchNotOf( MatcherBase<ArgT> const& underlyingMatcher ) : m_underlyingMatcher( underlyingMatcher ) {}
+// start catch_tostring.h
- virtual bool match( ArgT const& arg ) const CATCH_OVERRIDE {
- return !m_underlyingMatcher.match( arg );
- }
-
- virtual std::string describe() const CATCH_OVERRIDE {
- return "not " + m_underlyingMatcher.toString();
- }
- MatcherBase<ArgT> const& m_underlyingMatcher;
- };
-
- template<typename ObjectT, typename ComparatorT>
- MatchAllOf<ComparatorT> MatcherBase<ObjectT, ComparatorT>::operator && ( MatcherBase const& other ) const {
- return MatchAllOf<ComparatorT>() && *this && other;
- }
- template<typename ObjectT, typename ComparatorT>
- MatchAnyOf<ComparatorT> MatcherBase<ObjectT, ComparatorT>::operator || ( MatcherBase const& other ) const {
- return MatchAnyOf<ComparatorT>() || *this || other;
- }
- template<typename ObjectT, typename ComparatorT>
- MatchNotOf<ComparatorT> MatcherBase<ObjectT, ComparatorT>::operator ! () const {
- return MatchNotOf<ComparatorT>( *this );
- }
-
- } // namespace Impl
-
- // The following functions create the actual matcher objects.
- // This allows the types to be inferred
- // - deprecated: prefer ||, && and !
- template<typename T>
- inline Impl::MatchNotOf<T> Not( Impl::MatcherBase<T> const& underlyingMatcher ) {
- return Impl::MatchNotOf<T>( underlyingMatcher );
- }
- template<typename T>
- inline Impl::MatchAllOf<T> AllOf( Impl::MatcherBase<T> const& m1, Impl::MatcherBase<T> const& m2 ) {
- return Impl::MatchAllOf<T>() && m1 && m2;
- }
- template<typename T>
- inline Impl::MatchAllOf<T> AllOf( Impl::MatcherBase<T> const& m1, Impl::MatcherBase<T> const& m2, Impl::MatcherBase<T> const& m3 ) {
- return Impl::MatchAllOf<T>() && m1 && m2 && m3;
- }
- template<typename T>
- inline Impl::MatchAnyOf<T> AnyOf( Impl::MatcherBase<T> const& m1, Impl::MatcherBase<T> const& m2 ) {
- return Impl::MatchAnyOf<T>() || m1 || m2;
- }
- template<typename T>
- inline Impl::MatchAnyOf<T> AnyOf( Impl::MatcherBase<T> const& m1, Impl::MatcherBase<T> const& m2, Impl::MatcherBase<T> const& m3 ) {
- return Impl::MatchAnyOf<T>() || m1 || m2 || m3;
- }
-
-} // namespace Matchers
-
-using namespace Matchers;
-using Matchers::Impl::MatcherBase;
+#include <vector>
+#include <cstddef>
+#include <type_traits>
+#include <string>
+// start catch_stream.h
-} // namespace Catch
+#include <iosfwd>
+#include <cstddef>
+#include <ostream>
namespace Catch {
- struct TestFailureException{};
+ std::ostream& cout();
+ std::ostream& cerr();
+ std::ostream& clog();
- template<typename T> class ExpressionLhs;
+ class StringRef;
- struct CopyableStream {
- CopyableStream() {}
- CopyableStream( CopyableStream const& other ) {
- oss << other.oss.str();
- }
- CopyableStream& operator=( CopyableStream const& other ) {
- oss.str(std::string());
- oss << other.oss.str();
- return *this;
- }
- std::ostringstream oss;
+ struct IStream {
+ virtual ~IStream();
+ virtual std::ostream& stream() const = 0;
};
- class ResultBuilder : public DecomposedExpression {
+ auto makeStream( StringRef const &filename ) -> IStream const*;
+
+ class ReusableStringStream {
+ std::size_t m_index;
+ std::ostream* m_oss;
public:
- ResultBuilder( char const* macroName,
- SourceLineInfo const& lineInfo,
- char const* capturedExpression,
- ResultDisposition::Flags resultDisposition,
- char const* secondArg = "" );
- ~ResultBuilder();
+ ReusableStringStream();
+ ~ReusableStringStream();
- template<typename T>
- ExpressionLhs<T const&> operator <= ( T const& operand );
- ExpressionLhs<bool> operator <= ( bool value );
+ auto str() const -> std::string;
template<typename T>
- ResultBuilder& operator << ( T const& value ) {
- m_stream.oss << value;
+ auto operator << ( T const& value ) -> ReusableStringStream& {
+ *m_oss << value;
return *this;
}
-
- ResultBuilder& setResultType( ResultWas::OfType result );
- ResultBuilder& setResultType( bool result );
-
- void endExpression( DecomposedExpression const& expr );
-
- virtual void reconstructExpression( std::string& dest ) const CATCH_OVERRIDE;
-
- AssertionResult build() const;
- AssertionResult build( DecomposedExpression const& expr ) const;
-
- void useActiveException( ResultDisposition::Flags resultDisposition = ResultDisposition::Normal );
- void captureResult( ResultWas::OfType resultType );
- void captureExpression();
- void captureExpectedException( std::string const& expectedMessage );
- void captureExpectedException( Matchers::Impl::MatcherBase<std::string> const& matcher );
- void handleResult( AssertionResult const& result );
- void react();
- bool shouldDebugBreak() const;
- bool allowThrows() const;
-
- template<typename ArgT, typename MatcherT>
- void captureMatch( ArgT const& arg, MatcherT const& matcher, char const* matcherString );
-
- void setExceptionGuard();
- void unsetExceptionGuard();
-
- private:
- AssertionInfo m_assertionInfo;
- AssertionResultData m_data;
- CopyableStream m_stream;
-
- bool m_shouldDebugBreak;
- bool m_shouldThrow;
- bool m_guardException;
+ auto get() -> std::ostream& { return *m_oss; }
};
+}
-} // namespace Catch
-
-// Include after due to circular dependency:
-// #included from: catch_expression_lhs.hpp
-#define TWOBLUECUBES_CATCH_EXPRESSION_LHS_HPP_INCLUDED
-
-// #included from: catch_evaluate.hpp
-#define TWOBLUECUBES_CATCH_EVALUATE_HPP_INCLUDED
-
-#ifdef _MSC_VER
-#pragma warning(push)
-#pragma warning(disable:4389) // '==' : signed/unsigned mismatch
-#pragma warning(disable:4312) // Converting int to T* using reinterpret_cast (issue on x64 platform)
-#endif
+// end catch_stream.h
+// start catch_interfaces_enum_values_registry.h
-#include <cstddef>
+#include <vector>
namespace Catch {
-namespace Internal {
- enum Operator {
- IsEqualTo,
- IsNotEqualTo,
- IsLessThan,
- IsGreaterThan,
- IsLessThanOrEqualTo,
- IsGreaterThanOrEqualTo
- };
+ namespace Detail {
+ struct EnumInfo {
+ StringRef m_name;
+ std::vector<std::pair<int, StringRef>> m_values;
- template<Operator Op> struct OperatorTraits { static const char* getName(){ return "*error*"; } };
- template<> struct OperatorTraits<IsEqualTo> { static const char* getName(){ return "=="; } };
- template<> struct OperatorTraits<IsNotEqualTo> { static const char* getName(){ return "!="; } };
- template<> struct OperatorTraits<IsLessThan> { static const char* getName(){ return "<"; } };
- template<> struct OperatorTraits<IsGreaterThan> { static const char* getName(){ return ">"; } };
- template<> struct OperatorTraits<IsLessThanOrEqualTo> { static const char* getName(){ return "<="; } };
- template<> struct OperatorTraits<IsGreaterThanOrEqualTo>{ static const char* getName(){ return ">="; } };
+ ~EnumInfo();
- template<typename T>
- inline T& opCast(T const& t) { return const_cast<T&>(t); }
+ StringRef lookup( int value ) const;
+ };
+ } // namespace Detail
-// nullptr_t support based on pull request #154 from Konstantin Baumann
-#ifdef CATCH_CONFIG_CPP11_NULLPTR
- inline std::nullptr_t opCast(std::nullptr_t) { return nullptr; }
-#endif // CATCH_CONFIG_CPP11_NULLPTR
+ struct IMutableEnumValuesRegistry {
+ virtual ~IMutableEnumValuesRegistry();
- // So the compare overloads can be operator agnostic we convey the operator as a template
- // enum, which is used to specialise an Evaluator for doing the comparison.
- template<typename T1, typename T2, Operator Op>
- class Evaluator{};
+ virtual Detail::EnumInfo const& registerEnum( StringRef enumName, StringRef allEnums, std::vector<int> const& values ) = 0;
- template<typename T1, typename T2>
- struct Evaluator<T1, T2, IsEqualTo> {
- static bool evaluate( T1 const& lhs, T2 const& rhs) {
- return bool( opCast( lhs ) == opCast( rhs ) );
- }
- };
- template<typename T1, typename T2>
- struct Evaluator<T1, T2, IsNotEqualTo> {
- static bool evaluate( T1 const& lhs, T2 const& rhs ) {
- return bool( opCast( lhs ) != opCast( rhs ) );
+ template<typename E>
+ Detail::EnumInfo const& registerEnum( StringRef enumName, StringRef allEnums, std::initializer_list<E> values ) {
+ static_assert(sizeof(int) >= sizeof(E), "Cannot serialize enum to int");
+ std::vector<int> intValues;
+ intValues.reserve( values.size() );
+ for( auto enumValue : values )
+ intValues.push_back( static_cast<int>( enumValue ) );
+ return registerEnum( enumName, allEnums, intValues );
}
};
- template<typename T1, typename T2>
- struct Evaluator<T1, T2, IsLessThan> {
- static bool evaluate( T1 const& lhs, T2 const& rhs ) {
- return bool( opCast( lhs ) < opCast( rhs ) );
- }
- };
- template<typename T1, typename T2>
- struct Evaluator<T1, T2, IsGreaterThan> {
- static bool evaluate( T1 const& lhs, T2 const& rhs ) {
- return bool( opCast( lhs ) > opCast( rhs ) );
- }
- };
- template<typename T1, typename T2>
- struct Evaluator<T1, T2, IsGreaterThanOrEqualTo> {
- static bool evaluate( T1 const& lhs, T2 const& rhs ) {
- return bool( opCast( lhs ) >= opCast( rhs ) );
- }
- };
- template<typename T1, typename T2>
- struct Evaluator<T1, T2, IsLessThanOrEqualTo> {
- static bool evaluate( T1 const& lhs, T2 const& rhs ) {
- return bool( opCast( lhs ) <= opCast( rhs ) );
- }
- };
-
- template<Operator Op, typename T1, typename T2>
- bool applyEvaluator( T1 const& lhs, T2 const& rhs ) {
- return Evaluator<T1, T2, Op>::evaluate( lhs, rhs );
- }
- // This level of indirection allows us to specialise for integer types
- // to avoid signed/ unsigned warnings
+} // Catch
- // "base" overload
- template<Operator Op, typename T1, typename T2>
- bool compare( T1 const& lhs, T2 const& rhs ) {
- return Evaluator<T1, T2, Op>::evaluate( lhs, rhs );
- }
-
- // unsigned X to int
- template<Operator Op> bool compare( unsigned int lhs, int rhs ) {
- return applyEvaluator<Op>( lhs, static_cast<unsigned int>( rhs ) );
- }
- template<Operator Op> bool compare( unsigned long lhs, int rhs ) {
- return applyEvaluator<Op>( lhs, static_cast<unsigned int>( rhs ) );
- }
- template<Operator Op> bool compare( unsigned char lhs, int rhs ) {
- return applyEvaluator<Op>( lhs, static_cast<unsigned int>( rhs ) );
- }
-
- // unsigned X to long
- template<Operator Op> bool compare( unsigned int lhs, long rhs ) {
- return applyEvaluator<Op>( lhs, static_cast<unsigned long>( rhs ) );
- }
- template<Operator Op> bool compare( unsigned long lhs, long rhs ) {
- return applyEvaluator<Op>( lhs, static_cast<unsigned long>( rhs ) );
- }
- template<Operator Op> bool compare( unsigned char lhs, long rhs ) {
- return applyEvaluator<Op>( lhs, static_cast<unsigned long>( rhs ) );
- }
-
- // int to unsigned X
- template<Operator Op> bool compare( int lhs, unsigned int rhs ) {
- return applyEvaluator<Op>( static_cast<unsigned int>( lhs ), rhs );
- }
- template<Operator Op> bool compare( int lhs, unsigned long rhs ) {
- return applyEvaluator<Op>( static_cast<unsigned int>( lhs ), rhs );
- }
- template<Operator Op> bool compare( int lhs, unsigned char rhs ) {
- return applyEvaluator<Op>( static_cast<unsigned int>( lhs ), rhs );
- }
-
- // long to unsigned X
- template<Operator Op> bool compare( long lhs, unsigned int rhs ) {
- return applyEvaluator<Op>( static_cast<unsigned long>( lhs ), rhs );
- }
- template<Operator Op> bool compare( long lhs, unsigned long rhs ) {
- return applyEvaluator<Op>( static_cast<unsigned long>( lhs ), rhs );
- }
- template<Operator Op> bool compare( long lhs, unsigned char rhs ) {
- return applyEvaluator<Op>( static_cast<unsigned long>( lhs ), rhs );
- }
-
- // pointer to long (when comparing against NULL)
- template<Operator Op, typename T> bool compare( long lhs, T* rhs ) {
- return Evaluator<T*, T*, Op>::evaluate( reinterpret_cast<T*>( lhs ), rhs );
- }
- template<Operator Op, typename T> bool compare( T* lhs, long rhs ) {
- return Evaluator<T*, T*, Op>::evaluate( lhs, reinterpret_cast<T*>( rhs ) );
- }
-
- // pointer to int (when comparing against NULL)
- template<Operator Op, typename T> bool compare( int lhs, T* rhs ) {
- return Evaluator<T*, T*, Op>::evaluate( reinterpret_cast<T*>( lhs ), rhs );
- }
- template<Operator Op, typename T> bool compare( T* lhs, int rhs ) {
- return Evaluator<T*, T*, Op>::evaluate( lhs, reinterpret_cast<T*>( rhs ) );
- }
+// end catch_interfaces_enum_values_registry.h
-#ifdef CATCH_CONFIG_CPP11_LONG_LONG
- // long long to unsigned X
- template<Operator Op> bool compare( long long lhs, unsigned int rhs ) {
- return applyEvaluator<Op>( static_cast<unsigned long>( lhs ), rhs );
- }
- template<Operator Op> bool compare( long long lhs, unsigned long rhs ) {
- return applyEvaluator<Op>( static_cast<unsigned long>( lhs ), rhs );
- }
- template<Operator Op> bool compare( long long lhs, unsigned long long rhs ) {
- return applyEvaluator<Op>( static_cast<unsigned long>( lhs ), rhs );
- }
- template<Operator Op> bool compare( long long lhs, unsigned char rhs ) {
- return applyEvaluator<Op>( static_cast<unsigned long>( lhs ), rhs );
- }
-
- // unsigned long long to X
- template<Operator Op> bool compare( unsigned long long lhs, int rhs ) {
- return applyEvaluator<Op>( static_cast<long>( lhs ), rhs );
- }
- template<Operator Op> bool compare( unsigned long long lhs, long rhs ) {
- return applyEvaluator<Op>( static_cast<long>( lhs ), rhs );
- }
- template<Operator Op> bool compare( unsigned long long lhs, long long rhs ) {
- return applyEvaluator<Op>( static_cast<long>( lhs ), rhs );
- }
- template<Operator Op> bool compare( unsigned long long lhs, char rhs ) {
- return applyEvaluator<Op>( static_cast<long>( lhs ), rhs );
- }
-
- // pointer to long long (when comparing against NULL)
- template<Operator Op, typename T> bool compare( long long lhs, T* rhs ) {
- return Evaluator<T*, T*, Op>::evaluate( reinterpret_cast<T*>( lhs ), rhs );
- }
- template<Operator Op, typename T> bool compare( T* lhs, long long rhs ) {
- return Evaluator<T*, T*, Op>::evaluate( lhs, reinterpret_cast<T*>( rhs ) );
- }
-#endif // CATCH_CONFIG_CPP11_LONG_LONG
-
-#ifdef CATCH_CONFIG_CPP11_NULLPTR
- // pointer to nullptr_t (when comparing against nullptr)
- template<Operator Op, typename T> bool compare( std::nullptr_t, T* rhs ) {
- return Evaluator<T*, T*, Op>::evaluate( nullptr, rhs );
- }
- template<Operator Op, typename T> bool compare( T* lhs, std::nullptr_t ) {
- return Evaluator<T*, T*, Op>::evaluate( lhs, nullptr );
- }
-#endif // CATCH_CONFIG_CPP11_NULLPTR
-
-} // end of namespace Internal
-} // end of namespace Catch
-
-#ifdef _MSC_VER
-#pragma warning(pop)
+#ifdef CATCH_CONFIG_CPP17_STRING_VIEW
+#include <string_view>
#endif
-// #included from: catch_tostring.h
-#define TWOBLUECUBES_CATCH_TOSTRING_H_INCLUDED
-
-#include <sstream>
-#include <iomanip>
-#include <limits>
-#include <vector>
-#include <cstddef>
-
#ifdef __OBJC__
-// #included from: catch_objc_arc.hpp
-#define TWOBLUECUBES_CATCH_OBJC_ARC_HPP_INCLUDED
+// start catch_objc_arc.hpp
#import <Foundation/Foundation.h>
@@ -1526,853 +1539,1299 @@ inline id performOptionalSelector( id obj, SEL sel ) {
#define CATCH_ARC_STRONG __strong
#endif
+// end catch_objc_arc.hpp
#endif
-#ifdef CATCH_CONFIG_CPP11_TUPLE
-#include <tuple>
-#endif
-
-#ifdef CATCH_CONFIG_CPP11_IS_ENUM
-#include <type_traits>
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable:4180) // We attempt to stream a function (address) by const&, which MSVC complains about but is harmless
#endif
namespace Catch {
+ namespace Detail {
-// Why we're here.
-template<typename T>
-std::string toString( T const& value );
+ extern const std::string unprintableString;
-// Built in overloads
+ std::string rawMemoryToString( const void *object, std::size_t size );
-std::string toString( std::string const& value );
-std::string toString( std::wstring const& value );
-std::string toString( const char* const value );
-std::string toString( char* const value );
-std::string toString( const wchar_t* const value );
-std::string toString( wchar_t* const value );
-std::string toString( int value );
-std::string toString( unsigned long value );
-std::string toString( unsigned int value );
-std::string toString( const double value );
-std::string toString( const float value );
-std::string toString( bool value );
-std::string toString( char value );
-std::string toString( signed char value );
-std::string toString( unsigned char value );
+ template<typename T>
+ std::string rawMemoryToString( const T& object ) {
+ return rawMemoryToString( &object, sizeof(object) );
+ }
-#ifdef CATCH_CONFIG_CPP11_LONG_LONG
-std::string toString( long long value );
-std::string toString( unsigned long long value );
+ template<typename T>
+ class IsStreamInsertable {
+ template<typename Stream, typename U>
+ static auto test(int)
+ -> decltype(std::declval<Stream&>() << std::declval<U>(), std::true_type());
+
+ template<typename, typename>
+ static auto test(...)->std::false_type;
+
+ public:
+ static const bool value = decltype(test<std::ostream, const T&>(0))::value;
+ };
+
+ template<typename E>
+ std::string convertUnknownEnumToString( E e );
+
+ template<typename T>
+ typename std::enable_if<
+ !std::is_enum<T>::value && !std::is_base_of<std::exception, T>::value,
+ std::string>::type convertUnstreamable( T const& ) {
+ return Detail::unprintableString;
+ }
+ template<typename T>
+ typename std::enable_if<
+ !std::is_enum<T>::value && std::is_base_of<std::exception, T>::value,
+ std::string>::type convertUnstreamable(T const& ex) {
+ return ex.what();
+ }
+
+ template<typename T>
+ typename std::enable_if<
+ std::is_enum<T>::value
+ , std::string>::type convertUnstreamable( T const& value ) {
+ return convertUnknownEnumToString( value );
+ }
+
+#if defined(_MANAGED)
+ //! Convert a CLR string to a utf8 std::string
+ template<typename T>
+ std::string clrReferenceToString( T^ ref ) {
+ if (ref == nullptr)
+ return std::string("null");
+ auto bytes = System::Text::Encoding::UTF8->GetBytes(ref->ToString());
+ cli::pin_ptr<System::Byte> p = &bytes[0];
+ return std::string(reinterpret_cast<char const *>(p), bytes->Length);
+ }
#endif
-#ifdef CATCH_CONFIG_CPP11_NULLPTR
-std::string toString( std::nullptr_t );
+ } // namespace Detail
+
+ // If we decide for C++14, change these to enable_if_ts
+ template <typename T, typename = void>
+ struct StringMaker {
+ template <typename Fake = T>
+ static
+ typename std::enable_if<::Catch::Detail::IsStreamInsertable<Fake>::value, std::string>::type
+ convert(const Fake& value) {
+ ReusableStringStream rss;
+ // NB: call using the function-like syntax to avoid ambiguity with
+ // user-defined templated operator<< under clang.
+ rss.operator<<(value);
+ return rss.str();
+ }
+
+ template <typename Fake = T>
+ static
+ typename std::enable_if<!::Catch::Detail::IsStreamInsertable<Fake>::value, std::string>::type
+ convert( const Fake& value ) {
+#if !defined(CATCH_CONFIG_FALLBACK_STRINGIFIER)
+ return Detail::convertUnstreamable(value);
+#else
+ return CATCH_CONFIG_FALLBACK_STRINGIFIER(value);
#endif
+ }
+ };
-#ifdef __OBJC__
- std::string toString( NSString const * const& nsstring );
- std::string toString( NSString * CATCH_ARC_STRONG const& nsstring );
- std::string toString( NSObject* const& nsObject );
+ namespace Detail {
+
+ // This function dispatches all stringification requests inside of Catch.
+ // Should be preferably called fully qualified, like ::Catch::Detail::stringify
+ template <typename T>
+ std::string stringify(const T& e) {
+ return ::Catch::StringMaker<typename std::remove_cv<typename std::remove_reference<T>::type>::type>::convert(e);
+ }
+
+ template<typename E>
+ std::string convertUnknownEnumToString( E e ) {
+ return ::Catch::Detail::stringify(static_cast<typename std::underlying_type<E>::type>(e));
+ }
+
+#if defined(_MANAGED)
+ template <typename T>
+ std::string stringify( T^ e ) {
+ return ::Catch::StringMaker<T^>::convert(e);
+ }
#endif
-namespace Detail {
+ } // namespace Detail
- extern const std::string unprintableString;
+ // Some predefined specializations
- #if !defined(CATCH_CONFIG_CPP11_STREAM_INSERTABLE_CHECK)
- struct BorgType {
- template<typename T> BorgType( T const& );
+ template<>
+ struct StringMaker<std::string> {
+ static std::string convert(const std::string& str);
};
- struct TrueType { char sizer[1]; };
- struct FalseType { char sizer[2]; };
-
- TrueType& testStreamable( std::ostream& );
- FalseType testStreamable( FalseType );
+#ifdef CATCH_CONFIG_CPP17_STRING_VIEW
+ template<>
+ struct StringMaker<std::string_view> {
+ static std::string convert(std::string_view str);
+ };
+#endif
- FalseType operator<<( std::ostream const&, BorgType const& );
+ template<>
+ struct StringMaker<char const *> {
+ static std::string convert(char const * str);
+ };
+ template<>
+ struct StringMaker<char *> {
+ static std::string convert(char * str);
+ };
- template<typename T>
- struct IsStreamInsertable {
- static std::ostream &s;
- static T const&t;
- enum { value = sizeof( testStreamable(s << t) ) == sizeof( TrueType ) };
+#ifdef CATCH_CONFIG_WCHAR
+ template<>
+ struct StringMaker<std::wstring> {
+ static std::string convert(const std::wstring& wstr);
};
-#else
- template<typename T>
- class IsStreamInsertable {
- template<typename SS, typename TT>
- static auto test(int)
- -> decltype( std::declval<SS&>() << std::declval<TT>(), std::true_type() );
- template<typename, typename>
- static auto test(...) -> std::false_type;
+# ifdef CATCH_CONFIG_CPP17_STRING_VIEW
+ template<>
+ struct StringMaker<std::wstring_view> {
+ static std::string convert(std::wstring_view str);
+ };
+# endif
- public:
- static const bool value = decltype(test<std::ostream,const T&>(0))::value;
+ template<>
+ struct StringMaker<wchar_t const *> {
+ static std::string convert(wchar_t const * str);
+ };
+ template<>
+ struct StringMaker<wchar_t *> {
+ static std::string convert(wchar_t * str);
};
#endif
-#if defined(CATCH_CONFIG_CPP11_IS_ENUM)
- template<typename T,
- bool IsEnum = std::is_enum<T>::value
- >
- struct EnumStringMaker
- {
- static std::string convert( T const& ) { return unprintableString; }
+ // TBD: Should we use `strnlen` to ensure that we don't go out of the buffer,
+ // while keeping string semantics?
+ template<int SZ>
+ struct StringMaker<char[SZ]> {
+ static std::string convert(char const* str) {
+ return ::Catch::Detail::stringify(std::string{ str });
+ }
};
-
- template<typename T>
- struct EnumStringMaker<T,true>
- {
- static std::string convert( T const& v )
- {
- return ::Catch::toString(
- static_cast<typename std::underlying_type<T>::type>(v)
- );
+ template<int SZ>
+ struct StringMaker<signed char[SZ]> {
+ static std::string convert(signed char const* str) {
+ return ::Catch::Detail::stringify(std::string{ reinterpret_cast<char const *>(str) });
}
};
-#endif
- template<bool C>
- struct StringMakerBase {
-#if defined(CATCH_CONFIG_CPP11_IS_ENUM)
- template<typename T>
- static std::string convert( T const& v )
- {
- return EnumStringMaker<T>::convert( v );
+ template<int SZ>
+ struct StringMaker<unsigned char[SZ]> {
+ static std::string convert(unsigned char const* str) {
+ return ::Catch::Detail::stringify(std::string{ reinterpret_cast<char const *>(str) });
}
-#else
- template<typename T>
- static std::string convert( T const& ) { return unprintableString; }
-#endif
};
+#if defined(CATCH_CONFIG_CPP17_BYTE)
template<>
- struct StringMakerBase<true> {
- template<typename T>
- static std::string convert( T const& _value ) {
- std::ostringstream oss;
- oss << _value;
- return oss.str();
- }
+ struct StringMaker<std::byte> {
+ static std::string convert(std::byte value);
+ };
+#endif // defined(CATCH_CONFIG_CPP17_BYTE)
+ template<>
+ struct StringMaker<int> {
+ static std::string convert(int value);
+ };
+ template<>
+ struct StringMaker<long> {
+ static std::string convert(long value);
+ };
+ template<>
+ struct StringMaker<long long> {
+ static std::string convert(long long value);
+ };
+ template<>
+ struct StringMaker<unsigned int> {
+ static std::string convert(unsigned int value);
+ };
+ template<>
+ struct StringMaker<unsigned long> {
+ static std::string convert(unsigned long value);
+ };
+ template<>
+ struct StringMaker<unsigned long long> {
+ static std::string convert(unsigned long long value);
};
- std::string rawMemoryToString( const void *object, std::size_t size );
+ template<>
+ struct StringMaker<bool> {
+ static std::string convert(bool b);
+ };
- template<typename T>
- inline std::string rawMemoryToString( const T& object ) {
- return rawMemoryToString( &object, sizeof(object) );
- }
+ template<>
+ struct StringMaker<char> {
+ static std::string convert(char c);
+ };
+ template<>
+ struct StringMaker<signed char> {
+ static std::string convert(signed char c);
+ };
+ template<>
+ struct StringMaker<unsigned char> {
+ static std::string convert(unsigned char c);
+ };
-} // end namespace Detail
+ template<>
+ struct StringMaker<std::nullptr_t> {
+ static std::string convert(std::nullptr_t);
+ };
-template<typename T>
-struct StringMaker :
- Detail::StringMakerBase<Detail::IsStreamInsertable<T>::value> {};
+ template<>
+ struct StringMaker<float> {
+ static std::string convert(float value);
+ static int precision;
+ };
-template<typename T>
-struct StringMaker<T*> {
- template<typename U>
- static std::string convert( U* p ) {
- if( !p )
- return "NULL";
- else
- return Detail::rawMemoryToString( p );
- }
-};
+ template<>
+ struct StringMaker<double> {
+ static std::string convert(double value);
+ static int precision;
+ };
-template<typename R, typename C>
-struct StringMaker<R C::*> {
- static std::string convert( R C::* p ) {
- if( !p )
- return "NULL";
- else
- return Detail::rawMemoryToString( p );
+ template <typename T>
+ struct StringMaker<T*> {
+ template <typename U>
+ static std::string convert(U* p) {
+ if (p) {
+ return ::Catch::Detail::rawMemoryToString(p);
+ } else {
+ return "nullptr";
+ }
+ }
+ };
+
+ template <typename R, typename C>
+ struct StringMaker<R C::*> {
+ static std::string convert(R C::* p) {
+ if (p) {
+ return ::Catch::Detail::rawMemoryToString(p);
+ } else {
+ return "nullptr";
+ }
+ }
+ };
+
+#if defined(_MANAGED)
+ template <typename T>
+ struct StringMaker<T^> {
+ static std::string convert( T^ ref ) {
+ return ::Catch::Detail::clrReferenceToString(ref);
+ }
+ };
+#endif
+
+ namespace Detail {
+ template<typename InputIterator>
+ std::string rangeToString(InputIterator first, InputIterator last) {
+ ReusableStringStream rss;
+ rss << "{ ";
+ if (first != last) {
+ rss << ::Catch::Detail::stringify(*first);
+ for (++first; first != last; ++first)
+ rss << ", " << ::Catch::Detail::stringify(*first);
+ }
+ rss << " }";
+ return rss.str();
+ }
}
-};
-namespace Detail {
- template<typename InputIterator>
- std::string rangeToString( InputIterator first, InputIterator last );
-}
-
-//template<typename T, typename Allocator>
-//struct StringMaker<std::vector<T, Allocator> > {
-// static std::string convert( std::vector<T,Allocator> const& v ) {
-// return Detail::rangeToString( v.begin(), v.end() );
-// }
-//};
-
-template<typename T, typename Allocator>
-std::string toString( std::vector<T,Allocator> const& v ) {
- return Detail::rangeToString( v.begin(), v.end() );
-}
-
-#ifdef CATCH_CONFIG_CPP11_TUPLE
-
-// toString for tuples
-namespace TupleDetail {
- template<
- typename Tuple,
- std::size_t N = 0,
- bool = (N < std::tuple_size<Tuple>::value)
- >
- struct ElementPrinter {
- static void print( const Tuple& tuple, std::ostream& os )
- {
- os << ( N ? ", " : " " )
- << Catch::toString(std::get<N>(tuple));
- ElementPrinter<Tuple,N+1>::print(tuple,os);
- }
- };
+#ifdef __OBJC__
+ template<>
+ struct StringMaker<NSString*> {
+ static std::string convert(NSString * nsstring) {
+ if (!nsstring)
+ return "nil";
+ return std::string("@") + [nsstring UTF8String];
+ }
+ };
+ template<>
+ struct StringMaker<NSObject*> {
+ static std::string convert(NSObject* nsObject) {
+ return ::Catch::Detail::stringify([nsObject description]);
+ }
- template<
- typename Tuple,
- std::size_t N
- >
- struct ElementPrinter<Tuple,N,false> {
- static void print( const Tuple&, std::ostream& ) {}
- };
+ };
+ namespace Detail {
+ inline std::string stringify( NSString* nsstring ) {
+ return StringMaker<NSString*>::convert( nsstring );
+ }
-}
+ } // namespace Detail
+#endif // __OBJC__
-template<typename ...Types>
-struct StringMaker<std::tuple<Types...>> {
+} // namespace Catch
- static std::string convert( const std::tuple<Types...>& tuple )
- {
- std::ostringstream os;
- os << '{';
- TupleDetail::ElementPrinter<std::tuple<Types...>>::print( tuple, os );
- os << " }";
- return os.str();
- }
-};
-#endif // CATCH_CONFIG_CPP11_TUPLE
+//////////////////////////////////////////////////////
+// Separate std-lib types stringification, so it can be selectively enabled
+// This means that we do not bring in
-namespace Detail {
- template<typename T>
- std::string makeString( T const& value ) {
- return StringMaker<T>::convert( value );
- }
-} // end namespace Detail
+#if defined(CATCH_CONFIG_ENABLE_ALL_STRINGMAKERS)
+# define CATCH_CONFIG_ENABLE_PAIR_STRINGMAKER
+# define CATCH_CONFIG_ENABLE_TUPLE_STRINGMAKER
+# define CATCH_CONFIG_ENABLE_VARIANT_STRINGMAKER
+# define CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER
+# define CATCH_CONFIG_ENABLE_OPTIONAL_STRINGMAKER
+#endif
-/// \brief converts any type to a string
-///
-/// The default template forwards on to ostringstream - except when an
-/// ostringstream overload does not exist - in which case it attempts to detect
-/// that and writes {?}.
-/// Overload (not specialise) this template for custom typs that you don't want
-/// to provide an ostream overload for.
-template<typename T>
-std::string toString( T const& value ) {
- return StringMaker<T>::convert( value );
+// Separate std::pair specialization
+#if defined(CATCH_CONFIG_ENABLE_PAIR_STRINGMAKER)
+#include <utility>
+namespace Catch {
+ template<typename T1, typename T2>
+ struct StringMaker<std::pair<T1, T2> > {
+ static std::string convert(const std::pair<T1, T2>& pair) {
+ ReusableStringStream rss;
+ rss << "{ "
+ << ::Catch::Detail::stringify(pair.first)
+ << ", "
+ << ::Catch::Detail::stringify(pair.second)
+ << " }";
+ return rss.str();
+ }
+ };
+}
+#endif // CATCH_CONFIG_ENABLE_PAIR_STRINGMAKER
+
+#if defined(CATCH_CONFIG_ENABLE_OPTIONAL_STRINGMAKER) && defined(CATCH_CONFIG_CPP17_OPTIONAL)
+#include <optional>
+namespace Catch {
+ template<typename T>
+ struct StringMaker<std::optional<T> > {
+ static std::string convert(const std::optional<T>& optional) {
+ ReusableStringStream rss;
+ if (optional.has_value()) {
+ rss << ::Catch::Detail::stringify(*optional);
+ } else {
+ rss << "{ }";
+ }
+ return rss.str();
+ }
+ };
}
+#endif // CATCH_CONFIG_ENABLE_OPTIONAL_STRINGMAKER
+// Separate std::tuple specialization
+#if defined(CATCH_CONFIG_ENABLE_TUPLE_STRINGMAKER)
+#include <tuple>
+namespace Catch {
namespace Detail {
- template<typename InputIterator>
- std::string rangeToString( InputIterator first, InputIterator last ) {
- std::ostringstream oss;
- oss << "{ ";
- if( first != last ) {
- oss << Catch::toString( *first );
- for( ++first ; first != last ; ++first )
- oss << ", " << Catch::toString( *first );
- }
- oss << " }";
- return oss.str();
+ template<
+ typename Tuple,
+ std::size_t N = 0,
+ bool = (N < std::tuple_size<Tuple>::value)
+ >
+ struct TupleElementPrinter {
+ static void print(const Tuple& tuple, std::ostream& os) {
+ os << (N ? ", " : " ")
+ << ::Catch::Detail::stringify(std::get<N>(tuple));
+ TupleElementPrinter<Tuple, N + 1>::print(tuple, os);
+ }
+ };
+
+ template<
+ typename Tuple,
+ std::size_t N
+ >
+ struct TupleElementPrinter<Tuple, N, false> {
+ static void print(const Tuple&, std::ostream&) {}
+ };
+
}
-}
-} // end namespace Catch
+ template<typename ...Types>
+ struct StringMaker<std::tuple<Types...>> {
+ static std::string convert(const std::tuple<Types...>& tuple) {
+ ReusableStringStream rss;
+ rss << '{';
+ Detail::TupleElementPrinter<std::tuple<Types...>>::print(tuple, rss.get());
+ rss << " }";
+ return rss.str();
+ }
+ };
+}
+#endif // CATCH_CONFIG_ENABLE_TUPLE_STRINGMAKER
+#if defined(CATCH_CONFIG_ENABLE_VARIANT_STRINGMAKER) && defined(CATCH_CONFIG_CPP17_VARIANT)
+#include <variant>
namespace Catch {
+ template<>
+ struct StringMaker<std::monostate> {
+ static std::string convert(const std::monostate&) {
+ return "{ }";
+ }
+ };
-template<typename LhsT, Internal::Operator Op, typename RhsT>
-class BinaryExpression;
+ template<typename... Elements>
+ struct StringMaker<std::variant<Elements...>> {
+ static std::string convert(const std::variant<Elements...>& variant) {
+ if (variant.valueless_by_exception()) {
+ return "{valueless variant}";
+ } else {
+ return std::visit(
+ [](const auto& value) {
+ return ::Catch::Detail::stringify(value);
+ },
+ variant
+ );
+ }
+ }
+ };
+}
+#endif // CATCH_CONFIG_ENABLE_VARIANT_STRINGMAKER
-template<typename ArgT, typename MatcherT>
-class MatchExpression;
+namespace Catch {
+ struct not_this_one {}; // Tag type for detecting which begin/ end are being selected
-// Wraps the LHS of an expression and overloads comparison operators
-// for also capturing those and RHS (if any)
-template<typename T>
-class ExpressionLhs : public DecomposedExpression {
-public:
- ExpressionLhs( ResultBuilder& rb, T lhs ) : m_rb( rb ), m_lhs( lhs ), m_truthy(false) {}
+ // Import begin/ end from std here so they are considered alongside the fallback (...) overloads in this namespace
+ using std::begin;
+ using std::end;
- ExpressionLhs& operator = ( const ExpressionLhs& );
+ not_this_one begin( ... );
+ not_this_one end( ... );
- template<typename RhsT>
- BinaryExpression<T, Internal::IsEqualTo, RhsT const&>
- operator == ( RhsT const& rhs ) {
- return captureExpression<Internal::IsEqualTo>( rhs );
- }
+ template <typename T>
+ struct is_range {
+ static const bool value =
+ !std::is_same<decltype(begin(std::declval<T>())), not_this_one>::value &&
+ !std::is_same<decltype(end(std::declval<T>())), not_this_one>::value;
+ };
- template<typename RhsT>
- BinaryExpression<T, Internal::IsNotEqualTo, RhsT const&>
- operator != ( RhsT const& rhs ) {
- return captureExpression<Internal::IsNotEqualTo>( rhs );
- }
+#if defined(_MANAGED) // Managed types are never ranges
+ template <typename T>
+ struct is_range<T^> {
+ static const bool value = false;
+ };
+#endif
- template<typename RhsT>
- BinaryExpression<T, Internal::IsLessThan, RhsT const&>
- operator < ( RhsT const& rhs ) {
- return captureExpression<Internal::IsLessThan>( rhs );
+ template<typename Range>
+ std::string rangeToString( Range const& range ) {
+ return ::Catch::Detail::rangeToString( begin( range ), end( range ) );
}
- template<typename RhsT>
- BinaryExpression<T, Internal::IsGreaterThan, RhsT const&>
- operator > ( RhsT const& rhs ) {
- return captureExpression<Internal::IsGreaterThan>( rhs );
+ // Handle vector<bool> specially
+ template<typename Allocator>
+ std::string rangeToString( std::vector<bool, Allocator> const& v ) {
+ ReusableStringStream rss;
+ rss << "{ ";
+ bool first = true;
+ for( bool b : v ) {
+ if( first )
+ first = false;
+ else
+ rss << ", ";
+ rss << ::Catch::Detail::stringify( b );
+ }
+ rss << " }";
+ return rss.str();
}
- template<typename RhsT>
- BinaryExpression<T, Internal::IsLessThanOrEqualTo, RhsT const&>
- operator <= ( RhsT const& rhs ) {
- return captureExpression<Internal::IsLessThanOrEqualTo>( rhs );
- }
+ template<typename R>
+ struct StringMaker<R, typename std::enable_if<is_range<R>::value && !::Catch::Detail::IsStreamInsertable<R>::value>::type> {
+ static std::string convert( R const& range ) {
+ return rangeToString( range );
+ }
+ };
- template<typename RhsT>
- BinaryExpression<T, Internal::IsGreaterThanOrEqualTo, RhsT const&>
- operator >= ( RhsT const& rhs ) {
- return captureExpression<Internal::IsGreaterThanOrEqualTo>( rhs );
- }
+ template <typename T, int SZ>
+ struct StringMaker<T[SZ]> {
+ static std::string convert(T const(&arr)[SZ]) {
+ return rangeToString(arr);
+ }
+ };
- BinaryExpression<T, Internal::IsEqualTo, bool> operator == ( bool rhs ) {
- return captureExpression<Internal::IsEqualTo>( rhs );
- }
+} // namespace Catch
- BinaryExpression<T, Internal::IsNotEqualTo, bool> operator != ( bool rhs ) {
- return captureExpression<Internal::IsNotEqualTo>( rhs );
- }
+// Separate std::chrono::duration specialization
+#if defined(CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER)
+#include <ctime>
+#include <ratio>
+#include <chrono>
- void endExpression() {
- m_truthy = m_lhs ? true : false;
- m_rb
- .setResultType( m_truthy )
- .endExpression( *this );
- }
+namespace Catch {
- virtual void reconstructExpression( std::string& dest ) const CATCH_OVERRIDE {
- dest = Catch::toString( m_truthy );
- }
+template <class Ratio>
+struct ratio_string {
+ static std::string symbol();
+};
-private:
- template<Internal::Operator Op, typename RhsT>
- BinaryExpression<T, Op, RhsT&> captureExpression( RhsT& rhs ) const {
- return BinaryExpression<T, Op, RhsT&>( m_rb, m_lhs, rhs );
- }
+template <class Ratio>
+std::string ratio_string<Ratio>::symbol() {
+ Catch::ReusableStringStream rss;
+ rss << '[' << Ratio::num << '/'
+ << Ratio::den << ']';
+ return rss.str();
+}
+template <>
+struct ratio_string<std::atto> {
+ static std::string symbol();
+};
+template <>
+struct ratio_string<std::femto> {
+ static std::string symbol();
+};
+template <>
+struct ratio_string<std::pico> {
+ static std::string symbol();
+};
+template <>
+struct ratio_string<std::nano> {
+ static std::string symbol();
+};
+template <>
+struct ratio_string<std::micro> {
+ static std::string symbol();
+};
+template <>
+struct ratio_string<std::milli> {
+ static std::string symbol();
+};
- template<Internal::Operator Op>
- BinaryExpression<T, Op, bool> captureExpression( bool rhs ) const {
- return BinaryExpression<T, Op, bool>( m_rb, m_lhs, rhs );
- }
+ ////////////
+ // std::chrono::duration specializations
+ template<typename Value, typename Ratio>
+ struct StringMaker<std::chrono::duration<Value, Ratio>> {
+ static std::string convert(std::chrono::duration<Value, Ratio> const& duration) {
+ ReusableStringStream rss;
+ rss << duration.count() << ' ' << ratio_string<Ratio>::symbol() << 's';
+ return rss.str();
+ }
+ };
+ template<typename Value>
+ struct StringMaker<std::chrono::duration<Value, std::ratio<1>>> {
+ static std::string convert(std::chrono::duration<Value, std::ratio<1>> const& duration) {
+ ReusableStringStream rss;
+ rss << duration.count() << " s";
+ return rss.str();
+ }
+ };
+ template<typename Value>
+ struct StringMaker<std::chrono::duration<Value, std::ratio<60>>> {
+ static std::string convert(std::chrono::duration<Value, std::ratio<60>> const& duration) {
+ ReusableStringStream rss;
+ rss << duration.count() << " m";
+ return rss.str();
+ }
+ };
+ template<typename Value>
+ struct StringMaker<std::chrono::duration<Value, std::ratio<3600>>> {
+ static std::string convert(std::chrono::duration<Value, std::ratio<3600>> const& duration) {
+ ReusableStringStream rss;
+ rss << duration.count() << " h";
+ return rss.str();
+ }
+ };
-private:
- ResultBuilder& m_rb;
- T m_lhs;
- bool m_truthy;
-};
+ ////////////
+ // std::chrono::time_point specialization
+ // Generic time_point cannot be specialized, only std::chrono::time_point<system_clock>
+ template<typename Clock, typename Duration>
+ struct StringMaker<std::chrono::time_point<Clock, Duration>> {
+ static std::string convert(std::chrono::time_point<Clock, Duration> const& time_point) {
+ return ::Catch::Detail::stringify(time_point.time_since_epoch()) + " since epoch";
+ }
+ };
+ // std::chrono::time_point<system_clock> specialization
+ template<typename Duration>
+ struct StringMaker<std::chrono::time_point<std::chrono::system_clock, Duration>> {
+ static std::string convert(std::chrono::time_point<std::chrono::system_clock, Duration> const& time_point) {
+ auto converted = std::chrono::system_clock::to_time_t(time_point);
-template<typename LhsT, Internal::Operator Op, typename RhsT>
-class BinaryExpression : public DecomposedExpression {
-public:
- BinaryExpression( ResultBuilder& rb, LhsT lhs, RhsT rhs )
- : m_rb( rb ), m_lhs( lhs ), m_rhs( rhs ) {}
+#ifdef _MSC_VER
+ std::tm timeInfo = {};
+ gmtime_s(&timeInfo, &converted);
+#else
+ std::tm* timeInfo = std::gmtime(&converted);
+#endif
- BinaryExpression& operator = ( BinaryExpression& );
+ auto const timeStampSize = sizeof("2017-01-16T17:06:45Z");
+ char timeStamp[timeStampSize];
+ const char * const fmt = "%Y-%m-%dT%H:%M:%SZ";
- void endExpression() const {
- m_rb
- .setResultType( Internal::compare<Op>( m_lhs, m_rhs ) )
- .endExpression( *this );
- }
+#ifdef _MSC_VER
+ std::strftime(timeStamp, timeStampSize, fmt, &timeInfo);
+#else
+ std::strftime(timeStamp, timeStampSize, fmt, timeInfo);
+#endif
+ return std::string(timeStamp);
+ }
+ };
+}
+#endif // CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER
+
+#define INTERNAL_CATCH_REGISTER_ENUM( enumName, ... ) \
+namespace Catch { \
+ template<> struct StringMaker<enumName> { \
+ static std::string convert( enumName value ) { \
+ static const auto& enumInfo = ::Catch::getMutableRegistryHub().getMutableEnumValuesRegistry().registerEnum( #enumName, #__VA_ARGS__, { __VA_ARGS__ } ); \
+ return static_cast<std::string>(enumInfo.lookup( static_cast<int>( value ) )); \
+ } \
+ }; \
+}
- virtual bool isBinaryExpression() const CATCH_OVERRIDE {
- return true;
- }
+#define CATCH_REGISTER_ENUM( enumName, ... ) INTERNAL_CATCH_REGISTER_ENUM( enumName, __VA_ARGS__ )
- virtual void reconstructExpression( std::string& dest ) const CATCH_OVERRIDE {
- std::string lhs = Catch::toString( m_lhs );
- std::string rhs = Catch::toString( m_rhs );
- char delim = lhs.size() + rhs.size() < 40 &&
- lhs.find('\n') == std::string::npos &&
- rhs.find('\n') == std::string::npos ? ' ' : '\n';
- dest.reserve( 7 + lhs.size() + rhs.size() );
- // 2 for spaces around operator
- // 2 for operator
- // 2 for parentheses (conditionally added later)
- // 1 for negation (conditionally added later)
- dest = lhs;
- dest += delim;
- dest += Internal::OperatorTraits<Op>::getName();
- dest += delim;
- dest += rhs;
- }
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
-private:
- ResultBuilder& m_rb;
- LhsT m_lhs;
- RhsT m_rhs;
-};
+// end catch_tostring.h
+#include <iosfwd>
-template<typename ArgT, typename MatcherT>
-class MatchExpression : public DecomposedExpression {
-public:
- MatchExpression( ArgT arg, MatcherT matcher, char const* matcherString )
- : m_arg( arg ), m_matcher( matcher ), m_matcherString( matcherString ) {}
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable:4389) // '==' : signed/unsigned mismatch
+#pragma warning(disable:4018) // more "signed/unsigned mismatch"
+#pragma warning(disable:4312) // Converting int to T* using reinterpret_cast (issue on x64 platform)
+#pragma warning(disable:4180) // qualifier applied to function type has no meaning
+#pragma warning(disable:4800) // Forcing result to true or false
+#endif
- virtual bool isBinaryExpression() const CATCH_OVERRIDE {
- return true;
- }
+namespace Catch {
- virtual void reconstructExpression( std::string& dest ) const CATCH_OVERRIDE {
- std::string matcherAsString = m_matcher.toString();
- dest = Catch::toString( m_arg );
- dest += ' ';
- if( matcherAsString == Detail::unprintableString )
- dest += m_matcherString;
- else
- dest += matcherAsString;
- }
+ struct ITransientExpression {
+ auto isBinaryExpression() const -> bool { return m_isBinaryExpression; }
+ auto getResult() const -> bool { return m_result; }
+ virtual void streamReconstructedExpression( std::ostream &os ) const = 0;
-private:
- ArgT m_arg;
- MatcherT m_matcher;
- char const* m_matcherString;
-};
+ ITransientExpression( bool isBinaryExpression, bool result )
+ : m_isBinaryExpression( isBinaryExpression ),
+ m_result( result )
+ {}
-} // end namespace Catch
+ // We don't actually need a virtual destructor, but many static analysers
+ // complain if it's not here :-(
+ virtual ~ITransientExpression();
+ bool m_isBinaryExpression;
+ bool m_result;
-namespace Catch {
+ };
- template<typename T>
- inline ExpressionLhs<T const&> ResultBuilder::operator <= ( T const& operand ) {
- return ExpressionLhs<T const&>( *this, operand );
- }
+ void formatReconstructedExpression( std::ostream &os, std::string const& lhs, StringRef op, std::string const& rhs );
- inline ExpressionLhs<bool> ResultBuilder::operator <= ( bool value ) {
- return ExpressionLhs<bool>( *this, value );
- }
+ template<typename LhsT, typename RhsT>
+ class BinaryExpr : public ITransientExpression {
+ LhsT m_lhs;
+ StringRef m_op;
+ RhsT m_rhs;
- template<typename ArgT, typename MatcherT>
- inline void ResultBuilder::captureMatch( ArgT const& arg, MatcherT const& matcher,
- char const* matcherString ) {
- MatchExpression<ArgT const&, MatcherT const&> expr( arg, matcher, matcherString );
- setResultType( matcher.match( arg ) );
- endExpression( expr );
- }
+ void streamReconstructedExpression( std::ostream &os ) const override {
+ formatReconstructedExpression
+ ( os, Catch::Detail::stringify( m_lhs ), m_op, Catch::Detail::stringify( m_rhs ) );
+ }
-} // namespace Catch
+ public:
+ BinaryExpr( bool comparisonResult, LhsT lhs, StringRef op, RhsT rhs )
+ : ITransientExpression{ true, comparisonResult },
+ m_lhs( lhs ),
+ m_op( op ),
+ m_rhs( rhs )
+ {}
-// #included from: catch_message.h
-#define TWOBLUECUBES_CATCH_MESSAGE_H_INCLUDED
+ template<typename T>
+ auto operator && ( T ) const -> BinaryExpr<LhsT, RhsT const&> const {
+ static_assert(always_false<T>::value,
+ "chained comparisons are not supported inside assertions, "
+ "wrap the expression inside parentheses, or decompose it");
+ }
-#include <string>
+ template<typename T>
+ auto operator || ( T ) const -> BinaryExpr<LhsT, RhsT const&> const {
+ static_assert(always_false<T>::value,
+ "chained comparisons are not supported inside assertions, "
+ "wrap the expression inside parentheses, or decompose it");
+ }
-namespace Catch {
+ template<typename T>
+ auto operator == ( T ) const -> BinaryExpr<LhsT, RhsT const&> const {
+ static_assert(always_false<T>::value,
+ "chained comparisons are not supported inside assertions, "
+ "wrap the expression inside parentheses, or decompose it");
+ }
- struct MessageInfo {
- MessageInfo( std::string const& _macroName,
- SourceLineInfo const& _lineInfo,
- ResultWas::OfType _type );
+ template<typename T>
+ auto operator != ( T ) const -> BinaryExpr<LhsT, RhsT const&> const {
+ static_assert(always_false<T>::value,
+ "chained comparisons are not supported inside assertions, "
+ "wrap the expression inside parentheses, or decompose it");
+ }
- std::string macroName;
- SourceLineInfo lineInfo;
- ResultWas::OfType type;
- std::string message;
- unsigned int sequence;
+ template<typename T>
+ auto operator > ( T ) const -> BinaryExpr<LhsT, RhsT const&> const {
+ static_assert(always_false<T>::value,
+ "chained comparisons are not supported inside assertions, "
+ "wrap the expression inside parentheses, or decompose it");
+ }
- bool operator == ( MessageInfo const& other ) const {
- return sequence == other.sequence;
+ template<typename T>
+ auto operator < ( T ) const -> BinaryExpr<LhsT, RhsT const&> const {
+ static_assert(always_false<T>::value,
+ "chained comparisons are not supported inside assertions, "
+ "wrap the expression inside parentheses, or decompose it");
}
- bool operator < ( MessageInfo const& other ) const {
- return sequence < other.sequence;
+
+ template<typename T>
+ auto operator >= ( T ) const -> BinaryExpr<LhsT, RhsT const&> const {
+ static_assert(always_false<T>::value,
+ "chained comparisons are not supported inside assertions, "
+ "wrap the expression inside parentheses, or decompose it");
+ }
+
+ template<typename T>
+ auto operator <= ( T ) const -> BinaryExpr<LhsT, RhsT const&> const {
+ static_assert(always_false<T>::value,
+ "chained comparisons are not supported inside assertions, "
+ "wrap the expression inside parentheses, or decompose it");
}
- private:
- static unsigned int globalCount;
};
- struct MessageBuilder {
- MessageBuilder( std::string const& macroName,
- SourceLineInfo const& lineInfo,
- ResultWas::OfType type )
- : m_info( macroName, lineInfo, type )
- {}
+ template<typename LhsT>
+ class UnaryExpr : public ITransientExpression {
+ LhsT m_lhs;
- template<typename T>
- MessageBuilder& operator << ( T const& value ) {
- m_stream << value;
- return *this;
+ void streamReconstructedExpression( std::ostream &os ) const override {
+ os << Catch::Detail::stringify( m_lhs );
}
- MessageInfo m_info;
- std::ostringstream m_stream;
+ public:
+ explicit UnaryExpr( LhsT lhs )
+ : ITransientExpression{ false, static_cast<bool>(lhs) },
+ m_lhs( lhs )
+ {}
};
- class ScopedMessage {
+ // Specialised comparison functions to handle equality comparisons between ints and pointers (NULL deduces as an int)
+ template<typename LhsT, typename RhsT>
+ auto compareEqual( LhsT const& lhs, RhsT const& rhs ) -> bool { return static_cast<bool>(lhs == rhs); }
+ template<typename T>
+ auto compareEqual( T* const& lhs, int rhs ) -> bool { return lhs == reinterpret_cast<void const*>( rhs ); }
+ template<typename T>
+ auto compareEqual( T* const& lhs, long rhs ) -> bool { return lhs == reinterpret_cast<void const*>( rhs ); }
+ template<typename T>
+ auto compareEqual( int lhs, T* const& rhs ) -> bool { return reinterpret_cast<void const*>( lhs ) == rhs; }
+ template<typename T>
+ auto compareEqual( long lhs, T* const& rhs ) -> bool { return reinterpret_cast<void const*>( lhs ) == rhs; }
+
+ template<typename LhsT, typename RhsT>
+ auto compareNotEqual( LhsT const& lhs, RhsT&& rhs ) -> bool { return static_cast<bool>(lhs != rhs); }
+ template<typename T>
+ auto compareNotEqual( T* const& lhs, int rhs ) -> bool { return lhs != reinterpret_cast<void const*>( rhs ); }
+ template<typename T>
+ auto compareNotEqual( T* const& lhs, long rhs ) -> bool { return lhs != reinterpret_cast<void const*>( rhs ); }
+ template<typename T>
+ auto compareNotEqual( int lhs, T* const& rhs ) -> bool { return reinterpret_cast<void const*>( lhs ) != rhs; }
+ template<typename T>
+ auto compareNotEqual( long lhs, T* const& rhs ) -> bool { return reinterpret_cast<void const*>( lhs ) != rhs; }
+
+ template<typename LhsT>
+ class ExprLhs {
+ LhsT m_lhs;
public:
- ScopedMessage( MessageBuilder const& builder );
- ScopedMessage( ScopedMessage const& other );
- ~ScopedMessage();
+ explicit ExprLhs( LhsT lhs ) : m_lhs( lhs ) {}
- MessageInfo m_info;
+ template<typename RhsT>
+ auto operator == ( RhsT const& rhs ) -> BinaryExpr<LhsT, RhsT const&> const {
+ return { compareEqual( m_lhs, rhs ), m_lhs, "==", rhs };
+ }
+ auto operator == ( bool rhs ) -> BinaryExpr<LhsT, bool> const {
+ return { m_lhs == rhs, m_lhs, "==", rhs };
+ }
+
+ template<typename RhsT>
+ auto operator != ( RhsT const& rhs ) -> BinaryExpr<LhsT, RhsT const&> const {
+ return { compareNotEqual( m_lhs, rhs ), m_lhs, "!=", rhs };
+ }
+ auto operator != ( bool rhs ) -> BinaryExpr<LhsT, bool> const {
+ return { m_lhs != rhs, m_lhs, "!=", rhs };
+ }
+
+ template<typename RhsT>
+ auto operator > ( RhsT const& rhs ) -> BinaryExpr<LhsT, RhsT const&> const {
+ return { static_cast<bool>(m_lhs > rhs), m_lhs, ">", rhs };
+ }
+ template<typename RhsT>
+ auto operator < ( RhsT const& rhs ) -> BinaryExpr<LhsT, RhsT const&> const {
+ return { static_cast<bool>(m_lhs < rhs), m_lhs, "<", rhs };
+ }
+ template<typename RhsT>
+ auto operator >= ( RhsT const& rhs ) -> BinaryExpr<LhsT, RhsT const&> const {
+ return { static_cast<bool>(m_lhs >= rhs), m_lhs, ">=", rhs };
+ }
+ template<typename RhsT>
+ auto operator <= ( RhsT const& rhs ) -> BinaryExpr<LhsT, RhsT const&> const {
+ return { static_cast<bool>(m_lhs <= rhs), m_lhs, "<=", rhs };
+ }
+
+ template<typename RhsT>
+ auto operator && ( RhsT const& ) -> BinaryExpr<LhsT, RhsT const&> const {
+ static_assert(always_false<RhsT>::value,
+ "operator&& is not supported inside assertions, "
+ "wrap the expression inside parentheses, or decompose it");
+ }
+
+ template<typename RhsT>
+ auto operator || ( RhsT const& ) -> BinaryExpr<LhsT, RhsT const&> const {
+ static_assert(always_false<RhsT>::value,
+ "operator|| is not supported inside assertions, "
+ "wrap the expression inside parentheses, or decompose it");
+ }
+
+ auto makeUnaryExpr() const -> UnaryExpr<LhsT> {
+ return UnaryExpr<LhsT>{ m_lhs };
+ }
+ };
+
+ void handleExpression( ITransientExpression const& expr );
+
+ template<typename T>
+ void handleExpression( ExprLhs<T> const& expr ) {
+ handleExpression( expr.makeUnaryExpr() );
+ }
+
+ struct Decomposer {
+ template<typename T>
+ auto operator <= ( T const& lhs ) -> ExprLhs<T const&> {
+ return ExprLhs<T const&>{ lhs };
+ }
+
+ auto operator <=( bool value ) -> ExprLhs<bool> {
+ return ExprLhs<bool>{ value };
+ }
};
} // end namespace Catch
-// #included from: catch_interfaces_capture.h
-#define TWOBLUECUBES_CATCH_INTERFACES_CAPTURE_H_INCLUDED
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+
+// end catch_decomposer.h
+// start catch_interfaces_capture.h
#include <string>
+#include <chrono>
namespace Catch {
- class TestCase;
class AssertionResult;
struct AssertionInfo;
struct SectionInfo;
struct SectionEndInfo;
struct MessageInfo;
- class ScopedMessageBuilder;
+ struct MessageBuilder;
struct Counts;
+ struct AssertionReaction;
+ struct SourceLineInfo;
+
+ struct ITransientExpression;
+ struct IGeneratorTracker;
+
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+ struct BenchmarkInfo;
+ template <typename Duration = std::chrono::duration<double, std::nano>>
+ struct BenchmarkStats;
+#endif // CATCH_CONFIG_ENABLE_BENCHMARKING
struct IResultCapture {
virtual ~IResultCapture();
- virtual void assertionEnded( AssertionResult const& result ) = 0;
virtual bool sectionStarted( SectionInfo const& sectionInfo,
Counts& assertions ) = 0;
virtual void sectionEnded( SectionEndInfo const& endInfo ) = 0;
virtual void sectionEndedEarly( SectionEndInfo const& endInfo ) = 0;
+
+ virtual auto acquireGeneratorTracker( SourceLineInfo const& lineInfo ) -> IGeneratorTracker& = 0;
+
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+ virtual void benchmarkPreparing( std::string const& name ) = 0;
+ virtual void benchmarkStarting( BenchmarkInfo const& info ) = 0;
+ virtual void benchmarkEnded( BenchmarkStats<> const& stats ) = 0;
+ virtual void benchmarkFailed( std::string const& error ) = 0;
+#endif // CATCH_CONFIG_ENABLE_BENCHMARKING
+
virtual void pushScopedMessage( MessageInfo const& message ) = 0;
virtual void popScopedMessage( MessageInfo const& message ) = 0;
+ virtual void emplaceUnscopedMessage( MessageBuilder const& builder ) = 0;
+
+ virtual void handleFatalErrorCondition( StringRef message ) = 0;
+
+ virtual void handleExpr
+ ( AssertionInfo const& info,
+ ITransientExpression const& expr,
+ AssertionReaction& reaction ) = 0;
+ virtual void handleMessage
+ ( AssertionInfo const& info,
+ ResultWas::OfType resultType,
+ StringRef const& message,
+ AssertionReaction& reaction ) = 0;
+ virtual void handleUnexpectedExceptionNotThrown
+ ( AssertionInfo const& info,
+ AssertionReaction& reaction ) = 0;
+ virtual void handleUnexpectedInflightException
+ ( AssertionInfo const& info,
+ std::string const& message,
+ AssertionReaction& reaction ) = 0;
+ virtual void handleIncomplete
+ ( AssertionInfo const& info ) = 0;
+ virtual void handleNonExpr
+ ( AssertionInfo const &info,
+ ResultWas::OfType resultType,
+ AssertionReaction &reaction ) = 0;
+
+ virtual bool lastAssertionPassed() = 0;
+ virtual void assertionPassed() = 0;
+
+ // Deprecated, do not use:
virtual std::string getCurrentTestName() const = 0;
virtual const AssertionResult* getLastResult() const = 0;
-
virtual void exceptionEarlyReported() = 0;
-
- virtual void handleFatalErrorCondition( std::string const& message ) = 0;
};
IResultCapture& getResultCapture();
}
-// #included from: catch_debugger.h
-#define TWOBLUECUBES_CATCH_DEBUGGER_H_INCLUDED
+// end catch_interfaces_capture.h
+namespace Catch {
-// #included from: catch_platform.h
-#define TWOBLUECUBES_CATCH_PLATFORM_H_INCLUDED
+ struct TestFailureException{};
+ struct AssertionResultData;
+ struct IResultCapture;
+ class RunContext;
-#if defined(__MAC_OS_X_VERSION_MIN_REQUIRED)
-# define CATCH_PLATFORM_MAC
-#elif defined(__IPHONE_OS_VERSION_MIN_REQUIRED)
-# define CATCH_PLATFORM_IPHONE
-#elif defined(linux) || defined(__linux) || defined(__linux__)
-# define CATCH_PLATFORM_LINUX
-#elif defined(WIN32) || defined(__WIN32__) || defined(_WIN32) || defined(_MSC_VER)
-# define CATCH_PLATFORM_WINDOWS
-# if !defined(NOMINMAX) && !defined(CATCH_CONFIG_NO_NOMINMAX)
-# define CATCH_DEFINES_NOMINMAX
-# endif
-# if !defined(WIN32_LEAN_AND_MEAN) && !defined(CATCH_CONFIG_NO_WIN32_LEAN_AND_MEAN)
-# define CATCH_DEFINES_WIN32_LEAN_AND_MEAN
-# endif
-#endif
+ class LazyExpression {
+ friend class AssertionHandler;
+ friend struct AssertionStats;
+ friend class RunContext;
-#include <string>
+ ITransientExpression const* m_transientExpression = nullptr;
+ bool m_isNegated;
+ public:
+ LazyExpression( bool isNegated );
+ LazyExpression( LazyExpression const& other );
+ LazyExpression& operator = ( LazyExpression const& ) = delete;
-namespace Catch{
+ explicit operator bool() const;
- bool isDebuggerActive();
- void writeToDebugConsole( std::string const& text );
-}
+ friend auto operator << ( std::ostream& os, LazyExpression const& lazyExpr ) -> std::ostream&;
+ };
-#ifdef CATCH_PLATFORM_MAC
+ struct AssertionReaction {
+ bool shouldDebugBreak = false;
+ bool shouldThrow = false;
+ };
- // The following code snippet based on:
- // http://cocoawithlove.com/2008/03/break-into-debugger.html
- #if defined(__ppc64__) || defined(__ppc__)
- #define CATCH_TRAP() \
- __asm__("li r0, 20\nsc\nnop\nli r0, 37\nli r4, 2\nsc\nnop\n" \
- : : : "memory","r0","r3","r4" )
- #else
- #define CATCH_TRAP() __asm__("int $3\n" : : )
- #endif
+ class AssertionHandler {
+ AssertionInfo m_assertionInfo;
+ AssertionReaction m_reaction;
+ bool m_completed = false;
+ IResultCapture& m_resultCapture;
-#elif defined(CATCH_PLATFORM_LINUX)
- // If we can use inline assembler, do it because this allows us to break
- // directly at the location of the failing check instead of breaking inside
- // raise() called from it, i.e. one stack frame below.
- #if defined(__GNUC__) && (defined(__i386) || defined(__x86_64))
- #define CATCH_TRAP() asm volatile ("int $3")
- #else // Fall back to the generic way.
- #include <signal.h>
+ public:
+ AssertionHandler
+ ( StringRef const& macroName,
+ SourceLineInfo const& lineInfo,
+ StringRef capturedExpression,
+ ResultDisposition::Flags resultDisposition );
+ ~AssertionHandler() {
+ if ( !m_completed ) {
+ m_resultCapture.handleIncomplete( m_assertionInfo );
+ }
+ }
- #define CATCH_TRAP() raise(SIGTRAP)
- #endif
-#elif defined(_MSC_VER)
- #define CATCH_TRAP() __debugbreak()
-#elif defined(__MINGW32__)
- extern "C" __declspec(dllimport) void __stdcall DebugBreak();
- #define CATCH_TRAP() DebugBreak()
-#endif
+ template<typename T>
+ void handleExpr( ExprLhs<T> const& expr ) {
+ handleExpr( expr.makeUnaryExpr() );
+ }
+ void handleExpr( ITransientExpression const& expr );
-#ifdef CATCH_TRAP
- #define CATCH_BREAK_INTO_DEBUGGER() if( Catch::isDebuggerActive() ) { CATCH_TRAP(); }
-#else
- #define CATCH_BREAK_INTO_DEBUGGER() Catch::alwaysTrue();
-#endif
+ void handleMessage(ResultWas::OfType resultType, StringRef const& message);
+
+ void handleExceptionThrownAsExpected();
+ void handleUnexpectedExceptionNotThrown();
+ void handleExceptionNotThrownAsExpected();
+ void handleThrowingCallSkipped();
+ void handleUnexpectedInflightException();
-// #included from: catch_interfaces_runner.h
-#define TWOBLUECUBES_CATCH_INTERFACES_RUNNER_H_INCLUDED
+ void complete();
+ void setCompleted();
+
+ // query
+ auto allowThrows() const -> bool;
+ };
+
+ void handleExceptionMatchExpr( AssertionHandler& handler, std::string const& str, StringRef const& matcherString );
+
+} // namespace Catch
+
+// end catch_assertionhandler.h
+// start catch_message.h
+
+#include <string>
+#include <vector>
namespace Catch {
- class TestCase;
- struct IRunner {
- virtual ~IRunner();
- virtual bool aborting() const = 0;
+ struct MessageInfo {
+ MessageInfo( StringRef const& _macroName,
+ SourceLineInfo const& _lineInfo,
+ ResultWas::OfType _type );
+
+ StringRef macroName;
+ std::string message;
+ SourceLineInfo lineInfo;
+ ResultWas::OfType type;
+ unsigned int sequence;
+
+ bool operator == ( MessageInfo const& other ) const;
+ bool operator < ( MessageInfo const& other ) const;
+ private:
+ static unsigned int globalCount;
};
-}
-#if defined(CATCH_CONFIG_FAST_COMPILE)
-///////////////////////////////////////////////////////////////////////////////
-// We can speedup compilation significantly by breaking into debugger lower in
-// the callstack, because then we don't have to expand CATCH_BREAK_INTO_DEBUGGER
-// macro in each assertion
-#define INTERNAL_CATCH_REACT( resultBuilder ) \
- resultBuilder.react();
+ struct MessageStream {
+
+ template<typename T>
+ MessageStream& operator << ( T const& value ) {
+ m_stream << value;
+ return *this;
+ }
+
+ ReusableStringStream m_stream;
+ };
+
+ struct MessageBuilder : MessageStream {
+ MessageBuilder( StringRef const& macroName,
+ SourceLineInfo const& lineInfo,
+ ResultWas::OfType type );
+
+ template<typename T>
+ MessageBuilder& operator << ( T const& value ) {
+ m_stream << value;
+ return *this;
+ }
+
+ MessageInfo m_info;
+ };
+
+ class ScopedMessage {
+ public:
+ explicit ScopedMessage( MessageBuilder const& builder );
+ ScopedMessage( ScopedMessage& duplicate ) = delete;
+ ScopedMessage( ScopedMessage&& old );
+ ~ScopedMessage();
+
+ MessageInfo m_info;
+ bool m_moved;
+ };
+
+ class Capturer {
+ std::vector<MessageInfo> m_messages;
+ IResultCapture& m_resultCapture = getResultCapture();
+ size_t m_captured = 0;
+ public:
+ Capturer( StringRef macroName, SourceLineInfo const& lineInfo, ResultWas::OfType resultType, StringRef names );
+ ~Capturer();
+
+ void captureValue( size_t index, std::string const& value );
+
+ template<typename T>
+ void captureValues( size_t index, T const& value ) {
+ captureValue( index, Catch::Detail::stringify( value ) );
+ }
+
+ template<typename T, typename... Ts>
+ void captureValues( size_t index, T const& value, Ts const&... values ) {
+ captureValue( index, Catch::Detail::stringify(value) );
+ captureValues( index+1, values... );
+ }
+ };
+
+} // end namespace Catch
+
+// end catch_message.h
+#if !defined(CATCH_CONFIG_DISABLE)
+
+#if !defined(CATCH_CONFIG_DISABLE_STRINGIFICATION)
+ #define CATCH_INTERNAL_STRINGIFY(...) #__VA_ARGS__
+#else
+ #define CATCH_INTERNAL_STRINGIFY(...) "Disabled by CATCH_CONFIG_DISABLE_STRINGIFICATION"
+#endif
+
+#if defined(CATCH_CONFIG_FAST_COMPILE) || defined(CATCH_CONFIG_DISABLE_EXCEPTIONS)
///////////////////////////////////////////////////////////////////////////////
// Another way to speed-up compilation is to omit local try-catch for REQUIRE*
// macros.
-// This can potentially cause false negative, if the test code catches
-// the exception before it propagates back up to the runner.
-#define INTERNAL_CATCH_TEST_NO_TRY( macroName, resultDisposition, expr ) \
- do { \
- Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #expr, resultDisposition ); \
- __catchResult.setExceptionGuard(); \
- CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS \
- ( __catchResult <= expr ).endExpression(); \
- CATCH_INTERNAL_UNSUPPRESS_PARENTHESES_WARNINGS \
- __catchResult.unsetExceptionGuard(); \
- INTERNAL_CATCH_REACT( __catchResult ) \
- } while( Catch::isTrue( false && static_cast<bool>( !!(expr) ) ) ) // expr here is never evaluated at runtime but it forces the compiler to give it a look
-// The double negation silences MSVC's C4800 warning, the static_cast forces short-circuit evaluation if the type has overloaded &&.
-
-#define INTERNAL_CHECK_THAT_NO_TRY( macroName, matcher, resultDisposition, arg ) \
- do { \
- Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #arg ", " #matcher, resultDisposition ); \
- __catchResult.setExceptionGuard(); \
- __catchResult.captureMatch( arg, matcher, #matcher ); \
- __catchResult.unsetExceptionGuard(); \
- INTERNAL_CATCH_REACT( __catchResult ) \
- } while( Catch::alwaysFalse() )
+#define INTERNAL_CATCH_TRY
+#define INTERNAL_CATCH_CATCH( capturer )
+
+#else // CATCH_CONFIG_FAST_COMPILE
+
+#define INTERNAL_CATCH_TRY try
+#define INTERNAL_CATCH_CATCH( handler ) catch(...) { handler.handleUnexpectedInflightException(); }
-#else
-///////////////////////////////////////////////////////////////////////////////
-// In the event of a failure works out if the debugger needs to be invoked
-// and/or an exception thrown and takes appropriate action.
-// This needs to be done as a macro so the debugger will stop in the user
-// source code rather than in Catch library code
-#define INTERNAL_CATCH_REACT( resultBuilder ) \
- if( resultBuilder.shouldDebugBreak() ) CATCH_BREAK_INTO_DEBUGGER(); \
- resultBuilder.react();
#endif
+#define INTERNAL_CATCH_REACT( handler ) handler.complete();
+
///////////////////////////////////////////////////////////////////////////////
-#define INTERNAL_CATCH_TEST( macroName, resultDisposition, expr ) \
+#define INTERNAL_CATCH_TEST( macroName, resultDisposition, ... ) \
do { \
- Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #expr, resultDisposition ); \
- try { \
+ Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(__VA_ARGS__), resultDisposition ); \
+ INTERNAL_CATCH_TRY { \
CATCH_INTERNAL_SUPPRESS_PARENTHESES_WARNINGS \
- ( __catchResult <= expr ).endExpression(); \
+ catchAssertionHandler.handleExpr( Catch::Decomposer() <= __VA_ARGS__ ); \
CATCH_INTERNAL_UNSUPPRESS_PARENTHESES_WARNINGS \
- } \
- catch( ... ) { \
- __catchResult.useActiveException( resultDisposition ); \
- } \
- INTERNAL_CATCH_REACT( __catchResult ) \
- } while( Catch::isTrue( false && static_cast<bool>( !!(expr) ) ) ) // expr here is never evaluated at runtime but it forces the compiler to give it a look
+ } INTERNAL_CATCH_CATCH( catchAssertionHandler ) \
+ INTERNAL_CATCH_REACT( catchAssertionHandler ) \
+ } while( (void)0, (false) && static_cast<bool>( !!(__VA_ARGS__) ) ) // the expression here is never evaluated at runtime but it forces the compiler to give it a look
// The double negation silences MSVC's C4800 warning, the static_cast forces short-circuit evaluation if the type has overloaded &&.
///////////////////////////////////////////////////////////////////////////////
-#define INTERNAL_CATCH_IF( macroName, resultDisposition, expr ) \
- INTERNAL_CATCH_TEST( macroName, resultDisposition, expr ); \
- if( Catch::getResultCapture().getLastResult()->succeeded() )
+#define INTERNAL_CATCH_IF( macroName, resultDisposition, ... ) \
+ INTERNAL_CATCH_TEST( macroName, resultDisposition, __VA_ARGS__ ); \
+ if( Catch::getResultCapture().lastAssertionPassed() )
///////////////////////////////////////////////////////////////////////////////
-#define INTERNAL_CATCH_ELSE( macroName, resultDisposition, expr ) \
- INTERNAL_CATCH_TEST( macroName, resultDisposition, expr ); \
- if( !Catch::getResultCapture().getLastResult()->succeeded() )
+#define INTERNAL_CATCH_ELSE( macroName, resultDisposition, ... ) \
+ INTERNAL_CATCH_TEST( macroName, resultDisposition, __VA_ARGS__ ); \
+ if( !Catch::getResultCapture().lastAssertionPassed() )
///////////////////////////////////////////////////////////////////////////////
-#define INTERNAL_CATCH_NO_THROW( macroName, resultDisposition, expr ) \
+#define INTERNAL_CATCH_NO_THROW( macroName, resultDisposition, ... ) \
do { \
- Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #expr, resultDisposition ); \
+ Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(__VA_ARGS__), resultDisposition ); \
try { \
- static_cast<void>(expr); \
- __catchResult.captureResult( Catch::ResultWas::Ok ); \
+ static_cast<void>(__VA_ARGS__); \
+ catchAssertionHandler.handleExceptionNotThrownAsExpected(); \
} \
catch( ... ) { \
- __catchResult.useActiveException( resultDisposition ); \
+ catchAssertionHandler.handleUnexpectedInflightException(); \
} \
- INTERNAL_CATCH_REACT( __catchResult ) \
- } while( Catch::alwaysFalse() )
+ INTERNAL_CATCH_REACT( catchAssertionHandler ) \
+ } while( false )
///////////////////////////////////////////////////////////////////////////////
-#define INTERNAL_CATCH_THROWS( macroName, resultDisposition, matcher, expr ) \
+#define INTERNAL_CATCH_THROWS( macroName, resultDisposition, ... ) \
do { \
- Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #expr, resultDisposition, #matcher ); \
- if( __catchResult.allowThrows() ) \
+ Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(__VA_ARGS__), resultDisposition); \
+ if( catchAssertionHandler.allowThrows() ) \
try { \
- static_cast<void>(expr); \
- __catchResult.captureResult( Catch::ResultWas::DidntThrowException ); \
+ static_cast<void>(__VA_ARGS__); \
+ catchAssertionHandler.handleUnexpectedExceptionNotThrown(); \
} \
catch( ... ) { \
- __catchResult.captureExpectedException( matcher ); \
+ catchAssertionHandler.handleExceptionThrownAsExpected(); \
} \
else \
- __catchResult.captureResult( Catch::ResultWas::Ok ); \
- INTERNAL_CATCH_REACT( __catchResult ) \
- } while( Catch::alwaysFalse() )
+ catchAssertionHandler.handleThrowingCallSkipped(); \
+ INTERNAL_CATCH_REACT( catchAssertionHandler ) \
+ } while( false )
///////////////////////////////////////////////////////////////////////////////
#define INTERNAL_CATCH_THROWS_AS( macroName, exceptionType, resultDisposition, expr ) \
do { \
- Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #expr ", " #exceptionType, resultDisposition ); \
- if( __catchResult.allowThrows() ) \
+ Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(expr) ", " CATCH_INTERNAL_STRINGIFY(exceptionType), resultDisposition ); \
+ if( catchAssertionHandler.allowThrows() ) \
try { \
static_cast<void>(expr); \
- __catchResult.captureResult( Catch::ResultWas::DidntThrowException ); \
+ catchAssertionHandler.handleUnexpectedExceptionNotThrown(); \
} \
- catch( exceptionType ) { \
- __catchResult.captureResult( Catch::ResultWas::Ok ); \
+ catch( exceptionType const& ) { \
+ catchAssertionHandler.handleExceptionThrownAsExpected(); \
} \
catch( ... ) { \
- __catchResult.useActiveException( resultDisposition ); \
+ catchAssertionHandler.handleUnexpectedInflightException(); \
} \
else \
- __catchResult.captureResult( Catch::ResultWas::Ok ); \
- INTERNAL_CATCH_REACT( __catchResult ) \
- } while( Catch::alwaysFalse() )
+ catchAssertionHandler.handleThrowingCallSkipped(); \
+ INTERNAL_CATCH_REACT( catchAssertionHandler ) \
+ } while( false )
///////////////////////////////////////////////////////////////////////////////
-#ifdef CATCH_CONFIG_VARIADIC_MACROS
- #define INTERNAL_CATCH_MSG( macroName, messageType, resultDisposition, ... ) \
- do { \
- Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, "", resultDisposition ); \
- __catchResult << __VA_ARGS__ + ::Catch::StreamEndStop(); \
- __catchResult.captureResult( messageType ); \
- INTERNAL_CATCH_REACT( __catchResult ) \
- } while( Catch::alwaysFalse() )
-#else
- #define INTERNAL_CATCH_MSG( macroName, messageType, resultDisposition, log ) \
- do { \
- Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, "", resultDisposition ); \
- __catchResult << log + ::Catch::StreamEndStop(); \
- __catchResult.captureResult( messageType ); \
- INTERNAL_CATCH_REACT( __catchResult ) \
- } while( Catch::alwaysFalse() )
-#endif
+#define INTERNAL_CATCH_MSG( macroName, messageType, resultDisposition, ... ) \
+ do { \
+ Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, Catch::StringRef(), resultDisposition ); \
+ catchAssertionHandler.handleMessage( messageType, ( Catch::MessageStream() << __VA_ARGS__ + ::Catch::StreamEndStop() ).m_stream.str() ); \
+ INTERNAL_CATCH_REACT( catchAssertionHandler ) \
+ } while( false )
+
+///////////////////////////////////////////////////////////////////////////////
+#define INTERNAL_CATCH_CAPTURE( varName, macroName, ... ) \
+ auto varName = Catch::Capturer( macroName, CATCH_INTERNAL_LINEINFO, Catch::ResultWas::Info, #__VA_ARGS__ ); \
+ varName.captureValues( 0, __VA_ARGS__ )
///////////////////////////////////////////////////////////////////////////////
#define INTERNAL_CATCH_INFO( macroName, log ) \
- Catch::ScopedMessage INTERNAL_CATCH_UNIQUE_NAME( scopedMessage ) = Catch::MessageBuilder( macroName, CATCH_INTERNAL_LINEINFO, Catch::ResultWas::Info ) << log;
+ Catch::ScopedMessage INTERNAL_CATCH_UNIQUE_NAME( scopedMessage )( Catch::MessageBuilder( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, Catch::ResultWas::Info ) << log );
///////////////////////////////////////////////////////////////////////////////
-#define INTERNAL_CHECK_THAT( macroName, matcher, resultDisposition, arg ) \
+#define INTERNAL_CATCH_UNSCOPED_INFO( macroName, log ) \
+ Catch::getResultCapture().emplaceUnscopedMessage( Catch::MessageBuilder( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, Catch::ResultWas::Info ) << log )
+
+///////////////////////////////////////////////////////////////////////////////
+// Although this is matcher-based, it can be used with just a string
+#define INTERNAL_CATCH_THROWS_STR_MATCHES( macroName, resultDisposition, matcher, ... ) \
do { \
- Catch::ResultBuilder __catchResult( macroName, CATCH_INTERNAL_LINEINFO, #arg ", " #matcher, resultDisposition ); \
- try { \
- __catchResult.captureMatch( arg, matcher, #matcher ); \
- } catch( ... ) { \
- __catchResult.useActiveException( resultDisposition | Catch::ResultDisposition::ContinueOnFailure ); \
- } \
- INTERNAL_CATCH_REACT( __catchResult ) \
- } while( Catch::alwaysFalse() )
+ Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(__VA_ARGS__) ", " CATCH_INTERNAL_STRINGIFY(matcher), resultDisposition ); \
+ if( catchAssertionHandler.allowThrows() ) \
+ try { \
+ static_cast<void>(__VA_ARGS__); \
+ catchAssertionHandler.handleUnexpectedExceptionNotThrown(); \
+ } \
+ catch( ... ) { \
+ Catch::handleExceptionMatchExpr( catchAssertionHandler, matcher, #matcher##_catch_sr ); \
+ } \
+ else \
+ catchAssertionHandler.handleThrowingCallSkipped(); \
+ INTERNAL_CATCH_REACT( catchAssertionHandler ) \
+ } while( false )
-// #included from: internal/catch_section.h
-#define TWOBLUECUBES_CATCH_SECTION_H_INCLUDED
+#endif // CATCH_CONFIG_DISABLE
-// #included from: catch_section_info.h
-#define TWOBLUECUBES_CATCH_SECTION_INFO_H_INCLUDED
+// end catch_capture.hpp
+// start catch_section.h
-// #included from: catch_totals.hpp
-#define TWOBLUECUBES_CATCH_TOTALS_HPP_INCLUDED
+// start catch_section_info.h
+
+// start catch_totals.h
#include <cstddef>
namespace Catch {
struct Counts {
- Counts() : passed( 0 ), failed( 0 ), failedButOk( 0 ) {}
-
- Counts operator - ( Counts const& other ) const {
- Counts diff;
- diff.passed = passed - other.passed;
- diff.failed = failed - other.failed;
- diff.failedButOk = failedButOk - other.failedButOk;
- return diff;
- }
- Counts& operator += ( Counts const& other ) {
- passed += other.passed;
- failed += other.failed;
- failedButOk += other.failedButOk;
- return *this;
- }
+ Counts operator - ( Counts const& other ) const;
+ Counts& operator += ( Counts const& other );
- std::size_t total() const {
- return passed + failed + failedButOk;
- }
- bool allPassed() const {
- return failed == 0 && failedButOk == 0;
- }
- bool allOk() const {
- return failed == 0;
- }
+ std::size_t total() const;
+ bool allPassed() const;
+ bool allOk() const;
- std::size_t passed;
- std::size_t failed;
- std::size_t failedButOk;
+ std::size_t passed = 0;
+ std::size_t failed = 0;
+ std::size_t failedButOk = 0;
};
struct Totals {
- Totals operator - ( Totals const& other ) const {
- Totals diff;
- diff.assertions = assertions - other.assertions;
- diff.testCases = testCases - other.testCases;
- return diff;
- }
+ Totals operator - ( Totals const& other ) const;
+ Totals& operator += ( Totals const& other );
- Totals delta( Totals const& prevTotals ) const {
- Totals diff = *this - prevTotals;
- if( diff.assertions.failed > 0 )
- ++diff.testCases.failed;
- else if( diff.assertions.failedButOk > 0 )
- ++diff.testCases.failedButOk;
- else
- ++diff.testCases.passed;
- return diff;
- }
-
- Totals& operator += ( Totals const& other ) {
- assertions += other.assertions;
- testCases += other.testCases;
- return *this;
- }
+ Totals delta( Totals const& prevTotals ) const;
+ int error = 0;
Counts assertions;
Counts testCases;
};
}
+// end catch_totals.h
#include <string>
namespace Catch {
@@ -2380,19 +2839,20 @@ namespace Catch {
struct SectionInfo {
SectionInfo
( SourceLineInfo const& _lineInfo,
+ std::string const& _name );
+
+ // Deprecated
+ SectionInfo
+ ( SourceLineInfo const& _lineInfo,
std::string const& _name,
- std::string const& _description = std::string() );
+ std::string const& ) : SectionInfo( _lineInfo, _name ) {}
std::string name;
- std::string description;
+ std::string description; // !Deprecated: this will always be empty
SourceLineInfo lineInfo;
};
struct SectionEndInfo {
- SectionEndInfo( SectionInfo const& _sectionInfo, Counts const& _prevAssertions, double _durationInSeconds )
- : sectionInfo( _sectionInfo ), prevAssertions( _prevAssertions ), durationInSeconds( _durationInSeconds )
- {}
-
SectionInfo sectionInfo;
Counts prevAssertions;
double durationInSeconds;
@@ -2400,36 +2860,29 @@ namespace Catch {
} // end namespace Catch
-// #included from: catch_timer.h
-#define TWOBLUECUBES_CATCH_TIMER_H_INCLUDED
+// end catch_section_info.h
+// start catch_timer.h
-#ifdef _MSC_VER
+#include <cstdint>
namespace Catch {
- typedef unsigned long long UInt64;
-}
-#else
-#include <stdint.h>
-namespace Catch {
- typedef uint64_t UInt64;
-}
-#endif
-namespace Catch {
+ auto getCurrentNanosecondsSinceEpoch() -> uint64_t;
+ auto getEstimatedClockResolution() -> uint64_t;
+
class Timer {
+ uint64_t m_nanoseconds = 0;
public:
- Timer() : m_ticks( 0 ) {}
void start();
- unsigned int getElapsedMicroseconds() const;
- unsigned int getElapsedMilliseconds() const;
- double getElapsedSeconds() const;
-
- private:
- UInt64 m_ticks;
+ auto getElapsedNanoseconds() const -> uint64_t;
+ auto getElapsedMicroseconds() const -> uint64_t;
+ auto getElapsedMilliseconds() const -> unsigned int;
+ auto getElapsedSeconds() const -> double;
};
} // namespace Catch
+// end catch_timer.h
#include <string>
namespace Catch {
@@ -2440,7 +2893,7 @@ namespace Catch {
~Section();
// This indicates whether the section should be executed or not
- operator bool() const;
+ explicit operator bool() const;
private:
SectionInfo m_info;
@@ -2453,203 +2906,23 @@ namespace Catch {
} // end namespace Catch
-#ifdef CATCH_CONFIG_VARIADIC_MACROS
- #define INTERNAL_CATCH_SECTION( ... ) \
- if( Catch::Section const& INTERNAL_CATCH_UNIQUE_NAME( catch_internal_Section ) = Catch::SectionInfo( CATCH_INTERNAL_LINEINFO, __VA_ARGS__ ) )
-#else
- #define INTERNAL_CATCH_SECTION( name, desc ) \
- if( Catch::Section const& INTERNAL_CATCH_UNIQUE_NAME( catch_internal_Section ) = Catch::SectionInfo( CATCH_INTERNAL_LINEINFO, name, desc ) )
-#endif
-
-// #included from: internal/catch_generators.hpp
-#define TWOBLUECUBES_CATCH_GENERATORS_HPP_INCLUDED
-
-#include <vector>
-#include <string>
-#include <stdlib.h>
-
-namespace Catch {
-
-template<typename T>
-struct IGenerator {
- virtual ~IGenerator() {}
- virtual T getValue( std::size_t index ) const = 0;
- virtual std::size_t size () const = 0;
-};
+#define INTERNAL_CATCH_SECTION( ... ) \
+ CATCH_INTERNAL_SUPPRESS_UNUSED_WARNINGS \
+ if( Catch::Section const& INTERNAL_CATCH_UNIQUE_NAME( catch_internal_Section ) = Catch::SectionInfo( CATCH_INTERNAL_LINEINFO, __VA_ARGS__ ) ) \
+ CATCH_INTERNAL_UNSUPPRESS_UNUSED_WARNINGS
-template<typename T>
-class BetweenGenerator : public IGenerator<T> {
-public:
- BetweenGenerator( T from, T to ) : m_from( from ), m_to( to ){}
+#define INTERNAL_CATCH_DYNAMIC_SECTION( ... ) \
+ CATCH_INTERNAL_SUPPRESS_UNUSED_WARNINGS \
+ if( Catch::Section const& INTERNAL_CATCH_UNIQUE_NAME( catch_internal_Section ) = Catch::SectionInfo( CATCH_INTERNAL_LINEINFO, (Catch::ReusableStringStream() << __VA_ARGS__).str() ) ) \
+ CATCH_INTERNAL_UNSUPPRESS_UNUSED_WARNINGS
- virtual T getValue( std::size_t index ) const {
- return m_from+static_cast<int>( index );
- }
+// end catch_section.h
+// start catch_interfaces_exception.h
- virtual std::size_t size() const {
- return static_cast<std::size_t>( 1+m_to-m_from );
- }
-
-private:
-
- T m_from;
- T m_to;
-};
-
-template<typename T>
-class ValuesGenerator : public IGenerator<T> {
-public:
- ValuesGenerator(){}
-
- void add( T value ) {
- m_values.push_back( value );
- }
-
- virtual T getValue( std::size_t index ) const {
- return m_values[index];
- }
-
- virtual std::size_t size() const {
- return m_values.size();
- }
-
-private:
- std::vector<T> m_values;
-};
-
-template<typename T>
-class CompositeGenerator {
-public:
- CompositeGenerator() : m_totalSize( 0 ) {}
-
- // *** Move semantics, similar to auto_ptr ***
- CompositeGenerator( CompositeGenerator& other )
- : m_fileInfo( other.m_fileInfo ),
- m_totalSize( 0 )
- {
- move( other );
- }
-
- CompositeGenerator& setFileInfo( const char* fileInfo ) {
- m_fileInfo = fileInfo;
- return *this;
- }
-
- ~CompositeGenerator() {
- deleteAll( m_composed );
- }
-
- operator T () const {
- size_t overallIndex = getCurrentContext().getGeneratorIndex( m_fileInfo, m_totalSize );
-
- typename std::vector<const IGenerator<T>*>::const_iterator it = m_composed.begin();
- typename std::vector<const IGenerator<T>*>::const_iterator itEnd = m_composed.end();
- for( size_t index = 0; it != itEnd; ++it )
- {
- const IGenerator<T>* generator = *it;
- if( overallIndex >= index && overallIndex < index + generator->size() )
- {
- return generator->getValue( overallIndex-index );
- }
- index += generator->size();
- }
- CATCH_INTERNAL_ERROR( "Indexed past end of generated range" );
- return T(); // Suppress spurious "not all control paths return a value" warning in Visual Studio - if you know how to fix this please do so
- }
-
- void add( const IGenerator<T>* generator ) {
- m_totalSize += generator->size();
- m_composed.push_back( generator );
- }
-
- CompositeGenerator& then( CompositeGenerator& other ) {
- move( other );
- return *this;
- }
-
- CompositeGenerator& then( T value ) {
- ValuesGenerator<T>* valuesGen = new ValuesGenerator<T>();
- valuesGen->add( value );
- add( valuesGen );
- return *this;
- }
-
-private:
-
- void move( CompositeGenerator& other ) {
- m_composed.insert( m_composed.end(), other.m_composed.begin(), other.m_composed.end() );
- m_totalSize += other.m_totalSize;
- other.m_composed.clear();
- }
-
- std::vector<const IGenerator<T>*> m_composed;
- std::string m_fileInfo;
- size_t m_totalSize;
-};
-
-namespace Generators
-{
- template<typename T>
- CompositeGenerator<T> between( T from, T to ) {
- CompositeGenerator<T> generators;
- generators.add( new BetweenGenerator<T>( from, to ) );
- return generators;
- }
-
- template<typename T>
- CompositeGenerator<T> values( T val1, T val2 ) {
- CompositeGenerator<T> generators;
- ValuesGenerator<T>* valuesGen = new ValuesGenerator<T>();
- valuesGen->add( val1 );
- valuesGen->add( val2 );
- generators.add( valuesGen );
- return generators;
- }
-
- template<typename T>
- CompositeGenerator<T> values( T val1, T val2, T val3 ){
- CompositeGenerator<T> generators;
- ValuesGenerator<T>* valuesGen = new ValuesGenerator<T>();
- valuesGen->add( val1 );
- valuesGen->add( val2 );
- valuesGen->add( val3 );
- generators.add( valuesGen );
- return generators;
- }
-
- template<typename T>
- CompositeGenerator<T> values( T val1, T val2, T val3, T val4 ) {
- CompositeGenerator<T> generators;
- ValuesGenerator<T>* valuesGen = new ValuesGenerator<T>();
- valuesGen->add( val1 );
- valuesGen->add( val2 );
- valuesGen->add( val3 );
- valuesGen->add( val4 );
- generators.add( valuesGen );
- return generators;
- }
-
-} // end namespace Generators
-
-using namespace Generators;
-
-} // end namespace Catch
-
-#define INTERNAL_CATCH_LINESTR2( line ) #line
-#define INTERNAL_CATCH_LINESTR( line ) INTERNAL_CATCH_LINESTR2( line )
-
-#define INTERNAL_CATCH_GENERATE( expr ) expr.setFileInfo( __FILE__ "(" INTERNAL_CATCH_LINESTR( __LINE__ ) ")" )
-
-// #included from: internal/catch_interfaces_exception.h
-#define TWOBLUECUBES_CATCH_INTERFACES_EXCEPTION_H_INCLUDED
-
-#include <string>
-#include <vector>
-
-// #included from: catch_interfaces_registry_hub.h
-#define TWOBLUECUBES_CATCH_INTERFACES_REGISTRY_HUB_H_INCLUDED
+// start catch_interfaces_registry_hub.h
#include <string>
+#include <memory>
namespace Catch {
@@ -2660,6 +2933,11 @@ namespace Catch {
struct IReporterRegistry;
struct IReporterFactory;
struct ITagAliasRegistry;
+ struct IMutableEnumValuesRegistry;
+
+ class StartupExceptionRegistry;
+
+ using IReporterFactoryPtr = std::shared_ptr<IReporterFactory>;
struct IRegistryHub {
virtual ~IRegistryHub();
@@ -2667,32 +2945,44 @@ namespace Catch {
virtual IReporterRegistry const& getReporterRegistry() const = 0;
virtual ITestCaseRegistry const& getTestCaseRegistry() const = 0;
virtual ITagAliasRegistry const& getTagAliasRegistry() const = 0;
+ virtual IExceptionTranslatorRegistry const& getExceptionTranslatorRegistry() const = 0;
- virtual IExceptionTranslatorRegistry& getExceptionTranslatorRegistry() = 0;
+ virtual StartupExceptionRegistry const& getStartupExceptionRegistry() const = 0;
};
struct IMutableRegistryHub {
virtual ~IMutableRegistryHub();
- virtual void registerReporter( std::string const& name, Ptr<IReporterFactory> const& factory ) = 0;
- virtual void registerListener( Ptr<IReporterFactory> const& factory ) = 0;
+ virtual void registerReporter( std::string const& name, IReporterFactoryPtr const& factory ) = 0;
+ virtual void registerListener( IReporterFactoryPtr const& factory ) = 0;
virtual void registerTest( TestCase const& testInfo ) = 0;
virtual void registerTranslator( const IExceptionTranslator* translator ) = 0;
virtual void registerTagAlias( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ) = 0;
+ virtual void registerStartupException() noexcept = 0;
+ virtual IMutableEnumValuesRegistry& getMutableEnumValuesRegistry() = 0;
};
- IRegistryHub& getRegistryHub();
+ IRegistryHub const& getRegistryHub();
IMutableRegistryHub& getMutableRegistryHub();
void cleanUp();
std::string translateActiveException();
}
-namespace Catch {
+// end catch_interfaces_registry_hub.h
+#if defined(CATCH_CONFIG_DISABLE)
+ #define INTERNAL_CATCH_TRANSLATE_EXCEPTION_NO_REG( translatorName, signature) \
+ static std::string translatorName( signature )
+#endif
- typedef std::string(*exceptionTranslateFunction)();
+#include <exception>
+#include <string>
+#include <vector>
+
+namespace Catch {
+ using exceptionTranslateFunction = std::string(*)();
struct IExceptionTranslator;
- typedef std::vector<const IExceptionTranslator*> ExceptionTranslators;
+ using ExceptionTranslators = std::vector<std::unique_ptr<IExceptionTranslator const>>;
struct IExceptionTranslator {
virtual ~IExceptionTranslator();
@@ -2714,10 +3004,10 @@ namespace Catch {
: m_translateFunction( translateFunction )
{}
- virtual std::string translate( ExceptionTranslators::const_iterator it, ExceptionTranslators::const_iterator itEnd ) const CATCH_OVERRIDE {
+ std::string translate( ExceptionTranslators::const_iterator it, ExceptionTranslators::const_iterator itEnd ) const override {
try {
if( it == itEnd )
- throw;
+ std::rethrow_exception(std::current_exception());
else
return (*it)->translate( it+1, itEnd );
}
@@ -2742,68 +3032,55 @@ namespace Catch {
///////////////////////////////////////////////////////////////////////////////
#define INTERNAL_CATCH_TRANSLATE_EXCEPTION2( translatorName, signature ) \
static std::string translatorName( signature ); \
- namespace{ Catch::ExceptionTranslatorRegistrar INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionRegistrar )( &translatorName ); }\
+ CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \
+ namespace{ Catch::ExceptionTranslatorRegistrar INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionRegistrar )( &translatorName ); } \
+ CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS \
static std::string translatorName( signature )
#define INTERNAL_CATCH_TRANSLATE_EXCEPTION( signature ) INTERNAL_CATCH_TRANSLATE_EXCEPTION2( INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionTranslator ), signature )
-// #included from: internal/catch_approx.hpp
-#define TWOBLUECUBES_CATCH_APPROX_HPP_INCLUDED
+// end catch_interfaces_exception.h
+// start catch_approx.h
-#include <cmath>
-#include <limits>
-
-#if defined(CATCH_CONFIG_CPP11_TYPE_TRAITS)
#include <type_traits>
-#endif
namespace Catch {
namespace Detail {
class Approx {
- public:
- explicit Approx ( double value )
- : m_epsilon( std::numeric_limits<float>::epsilon()*100 ),
- m_margin( 0.0 ),
- m_scale( 1.0 ),
- m_value( value )
- {}
+ private:
+ bool equalityComparisonImpl(double other) const;
+ // Validates the new margin (margin >= 0)
+ // out-of-line to avoid including stdexcept in the header
+ void setMargin(double margin);
+ // Validates the new epsilon (0 < epsilon < 1)
+ // out-of-line to avoid including stdexcept in the header
+ void setEpsilon(double epsilon);
- Approx( Approx const& other )
- : m_epsilon( other.m_epsilon ),
- m_margin( other.m_margin ),
- m_scale( other.m_scale ),
- m_value( other.m_value )
- {}
+ public:
+ explicit Approx ( double value );
- static Approx custom() {
- return Approx( 0 );
- }
+ static Approx custom();
-#if defined(CATCH_CONFIG_CPP11_TYPE_TRAITS)
+ Approx operator-() const;
template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type>
- Approx operator()( T value ) {
+ Approx operator()( T const& value ) {
Approx approx( static_cast<double>(value) );
- approx.epsilon( m_epsilon );
- approx.margin( m_margin );
- approx.scale( m_scale );
+ approx.m_epsilon = m_epsilon;
+ approx.m_margin = m_margin;
+ approx.m_scale = m_scale;
return approx;
}
template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type>
- explicit Approx( T value ): Approx(static_cast<double>(value))
+ explicit Approx( T const& value ): Approx(static_cast<double>(value))
{}
template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type>
friend bool operator == ( const T& lhs, Approx const& rhs ) {
- // Thanks to Richard Harris for his help refining this formula
- auto lhs_v = double(lhs);
- bool relativeOK = std::fabs(lhs_v - rhs.m_value) < rhs.m_epsilon * (rhs.m_scale + (std::max)(std::fabs(lhs_v), std::fabs(rhs.m_value)));
- if (relativeOK) {
- return true;
- }
- return std::fabs(lhs_v - rhs.m_value) < rhs.m_margin;
+ auto lhs_v = static_cast<double>(lhs);
+ return rhs.equalityComparisonImpl(lhs_v);
}
template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type>
@@ -2812,139 +3089,416 @@ namespace Detail {
}
template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type>
- friend bool operator != ( T lhs, Approx const& rhs ) {
+ friend bool operator != ( T const& lhs, Approx const& rhs ) {
return !operator==( lhs, rhs );
}
template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type>
- friend bool operator != ( Approx const& lhs, T rhs ) {
+ friend bool operator != ( Approx const& lhs, T const& rhs ) {
return !operator==( rhs, lhs );
}
template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type>
- friend bool operator <= ( T lhs, Approx const& rhs ) {
- return double(lhs) < rhs.m_value || lhs == rhs;
+ friend bool operator <= ( T const& lhs, Approx const& rhs ) {
+ return static_cast<double>(lhs) < rhs.m_value || lhs == rhs;
}
template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type>
- friend bool operator <= ( Approx const& lhs, T rhs ) {
- return lhs.m_value < double(rhs) || lhs == rhs;
+ friend bool operator <= ( Approx const& lhs, T const& rhs ) {
+ return lhs.m_value < static_cast<double>(rhs) || lhs == rhs;
}
template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type>
- friend bool operator >= ( T lhs, Approx const& rhs ) {
- return double(lhs) > rhs.m_value || lhs == rhs;
+ friend bool operator >= ( T const& lhs, Approx const& rhs ) {
+ return static_cast<double>(lhs) > rhs.m_value || lhs == rhs;
}
template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type>
- friend bool operator >= ( Approx const& lhs, T rhs ) {
- return lhs.m_value > double(rhs) || lhs == rhs;
+ friend bool operator >= ( Approx const& lhs, T const& rhs ) {
+ return lhs.m_value > static_cast<double>(rhs) || lhs == rhs;
}
template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type>
- Approx& epsilon( T newEpsilon ) {
- m_epsilon = double(newEpsilon);
+ Approx& epsilon( T const& newEpsilon ) {
+ double epsilonAsDouble = static_cast<double>(newEpsilon);
+ setEpsilon(epsilonAsDouble);
return *this;
}
template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type>
- Approx& margin( T newMargin ) {
- m_margin = double(newMargin);
+ Approx& margin( T const& newMargin ) {
+ double marginAsDouble = static_cast<double>(newMargin);
+ setMargin(marginAsDouble);
return *this;
}
template <typename T, typename = typename std::enable_if<std::is_constructible<double, T>::value>::type>
- Approx& scale( T newScale ) {
- m_scale = double(newScale);
+ Approx& scale( T const& newScale ) {
+ m_scale = static_cast<double>(newScale);
return *this;
}
-#else
+ std::string toString() const;
- Approx operator()( double value ) {
- Approx approx( value );
- approx.epsilon( m_epsilon );
- approx.margin( m_margin );
- approx.scale( m_scale );
- return approx;
- }
+ private:
+ double m_epsilon;
+ double m_margin;
+ double m_scale;
+ double m_value;
+ };
+} // end namespace Detail
+
+namespace literals {
+ Detail::Approx operator "" _a(long double val);
+ Detail::Approx operator "" _a(unsigned long long val);
+} // end namespace literals
+
+template<>
+struct StringMaker<Catch::Detail::Approx> {
+ static std::string convert(Catch::Detail::Approx const& value);
+};
+
+} // end namespace Catch
+
+// end catch_approx.h
+// start catch_string_manip.h
+
+#include <string>
+#include <iosfwd>
+#include <vector>
+
+namespace Catch {
+
+ bool startsWith( std::string const& s, std::string const& prefix );
+ bool startsWith( std::string const& s, char prefix );
+ bool endsWith( std::string const& s, std::string const& suffix );
+ bool endsWith( std::string const& s, char suffix );
+ bool contains( std::string const& s, std::string const& infix );
+ void toLowerInPlace( std::string& s );
+ std::string toLower( std::string const& s );
+ //! Returns a new string without whitespace at the start/end
+ std::string trim( std::string const& str );
+ //! Returns a substring of the original ref without whitespace. Beware lifetimes!
+ StringRef trim(StringRef ref);
+
+ // !!! Be aware, returns refs into original string - make sure original string outlives them
+ std::vector<StringRef> splitStringRef( StringRef str, char delimiter );
+ bool replaceInPlace( std::string& str, std::string const& replaceThis, std::string const& withThis );
+
+ struct pluralise {
+ pluralise( std::size_t count, std::string const& label );
+
+ friend std::ostream& operator << ( std::ostream& os, pluralise const& pluraliser );
+
+ std::size_t m_count;
+ std::string m_label;
+ };
+}
+
+// end catch_string_manip.h
+#ifndef CATCH_CONFIG_DISABLE_MATCHERS
+// start catch_capture_matchers.h
+
+// start catch_matchers.h
+
+#include <string>
+#include <vector>
+
+namespace Catch {
+namespace Matchers {
+ namespace Impl {
+
+ template<typename ArgT> struct MatchAllOf;
+ template<typename ArgT> struct MatchAnyOf;
+ template<typename ArgT> struct MatchNotOf;
+
+ class MatcherUntypedBase {
+ public:
+ MatcherUntypedBase() = default;
+ MatcherUntypedBase ( MatcherUntypedBase const& ) = default;
+ MatcherUntypedBase& operator = ( MatcherUntypedBase const& ) = delete;
+ std::string toString() const;
- friend bool operator == ( double lhs, Approx const& rhs ) {
- // Thanks to Richard Harris for his help refining this formula
- bool relativeOK = std::fabs( lhs - rhs.m_value ) < rhs.m_epsilon * (rhs.m_scale + (std::max)( std::fabs(lhs), std::fabs(rhs.m_value) ) );
- if (relativeOK) {
+ protected:
+ virtual ~MatcherUntypedBase();
+ virtual std::string describe() const = 0;
+ mutable std::string m_cachedToString;
+ };
+
+#ifdef __clang__
+# pragma clang diagnostic push
+# pragma clang diagnostic ignored "-Wnon-virtual-dtor"
+#endif
+
+ template<typename ObjectT>
+ struct MatcherMethod {
+ virtual bool match( ObjectT const& arg ) const = 0;
+ };
+
+#if defined(__OBJC__)
+ // Hack to fix Catch GH issue #1661. Could use id for generic Object support.
+ // use of const for Object pointers is very uncommon and under ARC it causes some kind of signature mismatch that breaks compilation
+ template<>
+ struct MatcherMethod<NSString*> {
+ virtual bool match( NSString* arg ) const = 0;
+ };
+#endif
+
+#ifdef __clang__
+# pragma clang diagnostic pop
+#endif
+
+ template<typename T>
+ struct MatcherBase : MatcherUntypedBase, MatcherMethod<T> {
+
+ MatchAllOf<T> operator && ( MatcherBase const& other ) const;
+ MatchAnyOf<T> operator || ( MatcherBase const& other ) const;
+ MatchNotOf<T> operator ! () const;
+ };
+
+ template<typename ArgT>
+ struct MatchAllOf : MatcherBase<ArgT> {
+ bool match( ArgT const& arg ) const override {
+ for( auto matcher : m_matchers ) {
+ if (!matcher->match(arg))
+ return false;
+ }
return true;
}
- return std::fabs(lhs - rhs.m_value) < rhs.m_margin;
- }
+ std::string describe() const override {
+ std::string description;
+ description.reserve( 4 + m_matchers.size()*32 );
+ description += "( ";
+ bool first = true;
+ for( auto matcher : m_matchers ) {
+ if( first )
+ first = false;
+ else
+ description += " and ";
+ description += matcher->toString();
+ }
+ description += " )";
+ return description;
+ }
- friend bool operator == ( Approx const& lhs, double rhs ) {
- return operator==( rhs, lhs );
- }
+ MatchAllOf<ArgT>& operator && ( MatcherBase<ArgT> const& other ) {
+ m_matchers.push_back( &other );
+ return *this;
+ }
- friend bool operator != ( double lhs, Approx const& rhs ) {
- return !operator==( lhs, rhs );
- }
+ std::vector<MatcherBase<ArgT> const*> m_matchers;
+ };
+ template<typename ArgT>
+ struct MatchAnyOf : MatcherBase<ArgT> {
- friend bool operator != ( Approx const& lhs, double rhs ) {
- return !operator==( rhs, lhs );
- }
+ bool match( ArgT const& arg ) const override {
+ for( auto matcher : m_matchers ) {
+ if (matcher->match(arg))
+ return true;
+ }
+ return false;
+ }
+ std::string describe() const override {
+ std::string description;
+ description.reserve( 4 + m_matchers.size()*32 );
+ description += "( ";
+ bool first = true;
+ for( auto matcher : m_matchers ) {
+ if( first )
+ first = false;
+ else
+ description += " or ";
+ description += matcher->toString();
+ }
+ description += " )";
+ return description;
+ }
- friend bool operator <= ( double lhs, Approx const& rhs ) {
- return lhs < rhs.m_value || lhs == rhs;
- }
+ MatchAnyOf<ArgT>& operator || ( MatcherBase<ArgT> const& other ) {
+ m_matchers.push_back( &other );
+ return *this;
+ }
- friend bool operator <= ( Approx const& lhs, double rhs ) {
- return lhs.m_value < rhs || lhs == rhs;
- }
+ std::vector<MatcherBase<ArgT> const*> m_matchers;
+ };
- friend bool operator >= ( double lhs, Approx const& rhs ) {
- return lhs > rhs.m_value || lhs == rhs;
- }
+ template<typename ArgT>
+ struct MatchNotOf : MatcherBase<ArgT> {
- friend bool operator >= ( Approx const& lhs, double rhs ) {
- return lhs.m_value > rhs || lhs == rhs;
- }
+ MatchNotOf( MatcherBase<ArgT> const& underlyingMatcher ) : m_underlyingMatcher( underlyingMatcher ) {}
- Approx& epsilon( double newEpsilon ) {
- m_epsilon = newEpsilon;
- return *this;
- }
+ bool match( ArgT const& arg ) const override {
+ return !m_underlyingMatcher.match( arg );
+ }
- Approx& margin( double newMargin ) {
- m_margin = newMargin;
- return *this;
- }
+ std::string describe() const override {
+ return "not " + m_underlyingMatcher.toString();
+ }
+ MatcherBase<ArgT> const& m_underlyingMatcher;
+ };
- Approx& scale( double newScale ) {
- m_scale = newScale;
- return *this;
+ template<typename T>
+ MatchAllOf<T> MatcherBase<T>::operator && ( MatcherBase const& other ) const {
+ return MatchAllOf<T>() && *this && other;
}
-#endif
-
- std::string toString() const {
- std::ostringstream oss;
- oss << "Approx( " << Catch::toString( m_value ) << " )";
- return oss.str();
+ template<typename T>
+ MatchAnyOf<T> MatcherBase<T>::operator || ( MatcherBase const& other ) const {
+ return MatchAnyOf<T>() || *this || other;
+ }
+ template<typename T>
+ MatchNotOf<T> MatcherBase<T>::operator ! () const {
+ return MatchNotOf<T>( *this );
}
- private:
- double m_epsilon;
- double m_margin;
- double m_scale;
- double m_value;
- };
-}
+ } // namespace Impl
-template<>
-inline std::string toString<Detail::Approx>( Detail::Approx const& value ) {
- return value.toString();
+} // namespace Matchers
+
+using namespace Matchers;
+using Matchers::Impl::MatcherBase;
+
+} // namespace Catch
+
+// end catch_matchers.h
+// start catch_matchers_exception.hpp
+
+namespace Catch {
+namespace Matchers {
+namespace Exception {
+
+class ExceptionMessageMatcher : public MatcherBase<std::exception> {
+ std::string m_message;
+public:
+
+ ExceptionMessageMatcher(std::string const& message):
+ m_message(message)
+ {}
+
+ bool match(std::exception const& ex) const override;
+
+ std::string describe() const override;
+};
+
+} // namespace Exception
+
+Exception::ExceptionMessageMatcher Message(std::string const& message);
+
+} // namespace Matchers
+} // namespace Catch
+
+// end catch_matchers_exception.hpp
+// start catch_matchers_floating.h
+
+namespace Catch {
+namespace Matchers {
+
+ namespace Floating {
+
+ enum class FloatingPointKind : uint8_t;
+
+ struct WithinAbsMatcher : MatcherBase<double> {
+ WithinAbsMatcher(double target, double margin);
+ bool match(double const& matchee) const override;
+ std::string describe() const override;
+ private:
+ double m_target;
+ double m_margin;
+ };
+
+ struct WithinUlpsMatcher : MatcherBase<double> {
+ WithinUlpsMatcher(double target, uint64_t ulps, FloatingPointKind baseType);
+ bool match(double const& matchee) const override;
+ std::string describe() const override;
+ private:
+ double m_target;
+ uint64_t m_ulps;
+ FloatingPointKind m_type;
+ };
+
+ // Given IEEE-754 format for floats and doubles, we can assume
+ // that float -> double promotion is lossless. Given this, we can
+ // assume that if we do the standard relative comparison of
+ // |lhs - rhs| <= epsilon * max(fabs(lhs), fabs(rhs)), then we get
+ // the same result if we do this for floats, as if we do this for
+ // doubles that were promoted from floats.
+ struct WithinRelMatcher : MatcherBase<double> {
+ WithinRelMatcher(double target, double epsilon);
+ bool match(double const& matchee) const override;
+ std::string describe() const override;
+ private:
+ double m_target;
+ double m_epsilon;
+ };
+
+ } // namespace Floating
+
+ // The following functions create the actual matcher objects.
+ // This allows the types to be inferred
+ Floating::WithinUlpsMatcher WithinULP(double target, uint64_t maxUlpDiff);
+ Floating::WithinUlpsMatcher WithinULP(float target, uint64_t maxUlpDiff);
+ Floating::WithinAbsMatcher WithinAbs(double target, double margin);
+ Floating::WithinRelMatcher WithinRel(double target, double eps);
+ // defaults epsilon to 100*numeric_limits<double>::epsilon()
+ Floating::WithinRelMatcher WithinRel(double target);
+ Floating::WithinRelMatcher WithinRel(float target, float eps);
+ // defaults epsilon to 100*numeric_limits<float>::epsilon()
+ Floating::WithinRelMatcher WithinRel(float target);
+
+} // namespace Matchers
+} // namespace Catch
+
+// end catch_matchers_floating.h
+// start catch_matchers_generic.hpp
+
+#include <functional>
+#include <string>
+
+namespace Catch {
+namespace Matchers {
+namespace Generic {
+
+namespace Detail {
+ std::string finalizeDescription(const std::string& desc);
}
-} // end namespace Catch
+template <typename T>
+class PredicateMatcher : public MatcherBase<T> {
+ std::function<bool(T const&)> m_predicate;
+ std::string m_description;
+public:
+
+ PredicateMatcher(std::function<bool(T const&)> const& elem, std::string const& descr)
+ :m_predicate(std::move(elem)),
+ m_description(Detail::finalizeDescription(descr))
+ {}
-// #included from: internal/catch_matchers_string.h
-#define TWOBLUECUBES_CATCH_MATCHERS_STRING_H_INCLUDED
+ bool match( T const& item ) const override {
+ return m_predicate(item);
+ }
+
+ std::string describe() const override {
+ return m_description;
+ }
+};
+
+} // namespace Generic
+
+ // The following functions create the actual matcher objects.
+ // The user has to explicitly specify type to the function, because
+ // inferring std::function<bool(T const&)> is hard (but possible) and
+ // requires a lot of TMP.
+ template<typename T>
+ Generic::PredicateMatcher<T> Predicate(std::function<bool(T const&)> const& predicate, std::string const& description = "") {
+ return Generic::PredicateMatcher<T>(predicate, description);
+ }
+
+} // namespace Matchers
+} // namespace Catch
+
+// end catch_matchers_generic.hpp
+// start catch_matchers_string.h
+
+#include <string>
namespace Catch {
namespace Matchers {
@@ -2963,7 +3517,7 @@ namespace Matchers {
struct StringMatcherBase : MatcherBase<std::string> {
StringMatcherBase( std::string const& operation, CasedString const& comparator );
- virtual std::string describe() const CATCH_OVERRIDE;
+ std::string describe() const override;
CasedString m_comparator;
std::string m_operation;
@@ -2971,19 +3525,29 @@ namespace Matchers {
struct EqualsMatcher : StringMatcherBase {
EqualsMatcher( CasedString const& comparator );
- virtual bool match( std::string const& source ) const CATCH_OVERRIDE;
+ bool match( std::string const& source ) const override;
};
struct ContainsMatcher : StringMatcherBase {
ContainsMatcher( CasedString const& comparator );
- virtual bool match( std::string const& source ) const CATCH_OVERRIDE;
+ bool match( std::string const& source ) const override;
};
struct StartsWithMatcher : StringMatcherBase {
StartsWithMatcher( CasedString const& comparator );
- virtual bool match( std::string const& source ) const CATCH_OVERRIDE;
+ bool match( std::string const& source ) const override;
};
struct EndsWithMatcher : StringMatcherBase {
EndsWithMatcher( CasedString const& comparator );
- virtual bool match( std::string const& source ) const CATCH_OVERRIDE;
+ bool match( std::string const& source ) const override;
+ };
+
+ struct RegexMatcher : MatcherBase<std::string> {
+ RegexMatcher( std::string regex, CaseSensitive::Choice caseSensitivity );
+ bool match( std::string const& matchee ) const override;
+ std::string describe() const override;
+
+ private:
+ std::string m_regex;
+ CaseSensitive::Choice m_caseSensitivity;
};
} // namespace StdString
@@ -2995,76 +3559,147 @@ namespace Matchers {
StdString::ContainsMatcher Contains( std::string const& str, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes );
StdString::EndsWithMatcher EndsWith( std::string const& str, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes );
StdString::StartsWithMatcher StartsWith( std::string const& str, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes );
+ StdString::RegexMatcher Matches( std::string const& regex, CaseSensitive::Choice caseSensitivity = CaseSensitive::Yes );
} // namespace Matchers
} // namespace Catch
-// #included from: internal/catch_matchers_vector.h
-#define TWOBLUECUBES_CATCH_MATCHERS_VECTOR_H_INCLUDED
+// end catch_matchers_string.h
+// start catch_matchers_vector.h
+
+#include <algorithm>
namespace Catch {
namespace Matchers {
namespace Vector {
-
template<typename T>
- struct ContainsElementMatcher : MatcherBase<std::vector<T>, T> {
+ struct ContainsElementMatcher : MatcherBase<std::vector<T>> {
ContainsElementMatcher(T const &comparator) : m_comparator( comparator) {}
- bool match(std::vector<T> const &v) const CATCH_OVERRIDE {
- return std::find(v.begin(), v.end(), m_comparator) != v.end();
+ bool match(std::vector<T> const &v) const override {
+ for (auto const& el : v) {
+ if (el == m_comparator) {
+ return true;
+ }
+ }
+ return false;
}
- virtual std::string describe() const CATCH_OVERRIDE {
- return "Contains: " + Catch::toString( m_comparator );
+ std::string describe() const override {
+ return "Contains: " + ::Catch::Detail::stringify( m_comparator );
}
T const& m_comparator;
};
template<typename T>
- struct ContainsMatcher : MatcherBase<std::vector<T>, std::vector<T> > {
+ struct ContainsMatcher : MatcherBase<std::vector<T>> {
ContainsMatcher(std::vector<T> const &comparator) : m_comparator( comparator ) {}
- bool match(std::vector<T> const &v) const CATCH_OVERRIDE {
+ bool match(std::vector<T> const &v) const override {
// !TBD: see note in EqualsMatcher
if (m_comparator.size() > v.size())
return false;
- for (size_t i = 0; i < m_comparator.size(); ++i)
- if (std::find(v.begin(), v.end(), m_comparator[i]) == v.end())
+ for (auto const& comparator : m_comparator) {
+ auto present = false;
+ for (const auto& el : v) {
+ if (el == comparator) {
+ present = true;
+ break;
+ }
+ }
+ if (!present) {
return false;
+ }
+ }
return true;
}
- virtual std::string describe() const CATCH_OVERRIDE {
- return "Contains: " + Catch::toString( m_comparator );
+ std::string describe() const override {
+ return "Contains: " + ::Catch::Detail::stringify( m_comparator );
}
std::vector<T> const& m_comparator;
};
template<typename T>
- struct EqualsMatcher : MatcherBase<std::vector<T>, std::vector<T> > {
+ struct EqualsMatcher : MatcherBase<std::vector<T>> {
EqualsMatcher(std::vector<T> const &comparator) : m_comparator( comparator ) {}
- bool match(std::vector<T> const &v) const CATCH_OVERRIDE {
+ bool match(std::vector<T> const &v) const override {
// !TBD: This currently works if all elements can be compared using !=
// - a more general approach would be via a compare template that defaults
// to using !=. but could be specialised for, e.g. std::vector<T> etc
// - then just call that directly
if (m_comparator.size() != v.size())
return false;
- for (size_t i = 0; i < v.size(); ++i)
+ for (std::size_t i = 0; i < v.size(); ++i)
if (m_comparator[i] != v[i])
return false;
return true;
}
- virtual std::string describe() const CATCH_OVERRIDE {
- return "Equals: " + Catch::toString( m_comparator );
+ std::string describe() const override {
+ return "Equals: " + ::Catch::Detail::stringify( m_comparator );
+ }
+ std::vector<T> const& m_comparator;
+ };
+
+ template<typename T>
+ struct ApproxMatcher : MatcherBase<std::vector<T>> {
+
+ ApproxMatcher(std::vector<T> const& comparator) : m_comparator( comparator ) {}
+
+ bool match(std::vector<T> const &v) const override {
+ if (m_comparator.size() != v.size())
+ return false;
+ for (std::size_t i = 0; i < v.size(); ++i)
+ if (m_comparator[i] != approx(v[i]))
+ return false;
+ return true;
+ }
+ std::string describe() const override {
+ return "is approx: " + ::Catch::Detail::stringify( m_comparator );
}
+ template <typename = typename std::enable_if<std::is_constructible<double, T>::value>::type>
+ ApproxMatcher& epsilon( T const& newEpsilon ) {
+ approx.epsilon(newEpsilon);
+ return *this;
+ }
+ template <typename = typename std::enable_if<std::is_constructible<double, T>::value>::type>
+ ApproxMatcher& margin( T const& newMargin ) {
+ approx.margin(newMargin);
+ return *this;
+ }
+ template <typename = typename std::enable_if<std::is_constructible<double, T>::value>::type>
+ ApproxMatcher& scale( T const& newScale ) {
+ approx.scale(newScale);
+ return *this;
+ }
+
std::vector<T> const& m_comparator;
+ mutable Catch::Detail::Approx approx = Catch::Detail::Approx::custom();
+ };
+
+ template<typename T>
+ struct UnorderedEqualsMatcher : MatcherBase<std::vector<T>> {
+ UnorderedEqualsMatcher(std::vector<T> const& target) : m_target(target) {}
+ bool match(std::vector<T> const& vec) const override {
+ // Note: This is a reimplementation of std::is_permutation,
+ // because I don't want to include <algorithm> inside the common path
+ if (m_target.size() != vec.size()) {
+ return false;
+ }
+ return std::is_permutation(m_target.begin(), m_target.end(), vec.begin());
+ }
+
+ std::string describe() const override {
+ return "UnorderedEquals: " + ::Catch::Detail::stringify(m_target);
+ }
+ private:
+ std::vector<T> const& m_target;
};
} // namespace Vector
@@ -3087,35 +3722,658 @@ namespace Matchers {
return Vector::EqualsMatcher<T>( comparator );
}
+ template<typename T>
+ Vector::ApproxMatcher<T> Approx( std::vector<T> const& comparator ) {
+ return Vector::ApproxMatcher<T>( comparator );
+ }
+
+ template<typename T>
+ Vector::UnorderedEqualsMatcher<T> UnorderedEquals(std::vector<T> const& target) {
+ return Vector::UnorderedEqualsMatcher<T>(target);
+ }
+
} // namespace Matchers
} // namespace Catch
-// #included from: internal/catch_interfaces_tag_alias_registry.h
-#define TWOBLUECUBES_CATCH_INTERFACES_TAG_ALIAS_REGISTRY_H_INCLUDED
+// end catch_matchers_vector.h
+namespace Catch {
-// #included from: catch_tag_alias.h
-#define TWOBLUECUBES_CATCH_TAG_ALIAS_H_INCLUDED
+ template<typename ArgT, typename MatcherT>
+ class MatchExpr : public ITransientExpression {
+ ArgT const& m_arg;
+ MatcherT m_matcher;
+ StringRef m_matcherString;
+ public:
+ MatchExpr( ArgT const& arg, MatcherT const& matcher, StringRef const& matcherString )
+ : ITransientExpression{ true, matcher.match( arg ) },
+ m_arg( arg ),
+ m_matcher( matcher ),
+ m_matcherString( matcherString )
+ {}
-#include <string>
+ void streamReconstructedExpression( std::ostream &os ) const override {
+ auto matcherAsString = m_matcher.toString();
+ os << Catch::Detail::stringify( m_arg ) << ' ';
+ if( matcherAsString == Detail::unprintableString )
+ os << m_matcherString;
+ else
+ os << matcherAsString;
+ }
+ };
+
+ using StringMatcher = Matchers::Impl::MatcherBase<std::string>;
+
+ void handleExceptionMatchExpr( AssertionHandler& handler, StringMatcher const& matcher, StringRef const& matcherString );
+
+ template<typename ArgT, typename MatcherT>
+ auto makeMatchExpr( ArgT const& arg, MatcherT const& matcher, StringRef const& matcherString ) -> MatchExpr<ArgT, MatcherT> {
+ return MatchExpr<ArgT, MatcherT>( arg, matcher, matcherString );
+ }
+
+} // namespace Catch
+
+///////////////////////////////////////////////////////////////////////////////
+#define INTERNAL_CHECK_THAT( macroName, matcher, resultDisposition, arg ) \
+ do { \
+ Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(arg) ", " CATCH_INTERNAL_STRINGIFY(matcher), resultDisposition ); \
+ INTERNAL_CATCH_TRY { \
+ catchAssertionHandler.handleExpr( Catch::makeMatchExpr( arg, matcher, #matcher##_catch_sr ) ); \
+ } INTERNAL_CATCH_CATCH( catchAssertionHandler ) \
+ INTERNAL_CATCH_REACT( catchAssertionHandler ) \
+ } while( false )
+
+///////////////////////////////////////////////////////////////////////////////
+#define INTERNAL_CATCH_THROWS_MATCHES( macroName, exceptionType, resultDisposition, matcher, ... ) \
+ do { \
+ Catch::AssertionHandler catchAssertionHandler( macroName##_catch_sr, CATCH_INTERNAL_LINEINFO, CATCH_INTERNAL_STRINGIFY(__VA_ARGS__) ", " CATCH_INTERNAL_STRINGIFY(exceptionType) ", " CATCH_INTERNAL_STRINGIFY(matcher), resultDisposition ); \
+ if( catchAssertionHandler.allowThrows() ) \
+ try { \
+ static_cast<void>(__VA_ARGS__ ); \
+ catchAssertionHandler.handleUnexpectedExceptionNotThrown(); \
+ } \
+ catch( exceptionType const& ex ) { \
+ catchAssertionHandler.handleExpr( Catch::makeMatchExpr( ex, matcher, #matcher##_catch_sr ) ); \
+ } \
+ catch( ... ) { \
+ catchAssertionHandler.handleUnexpectedInflightException(); \
+ } \
+ else \
+ catchAssertionHandler.handleThrowingCallSkipped(); \
+ INTERNAL_CATCH_REACT( catchAssertionHandler ) \
+ } while( false )
+
+// end catch_capture_matchers.h
+#endif
+// start catch_generators.hpp
+
+// start catch_interfaces_generatortracker.h
+
+
+#include <memory>
namespace Catch {
- struct TagAlias {
- TagAlias( std::string const& _tag, SourceLineInfo _lineInfo ) : tag( _tag ), lineInfo( _lineInfo ) {}
+ namespace Generators {
+ class GeneratorUntypedBase {
+ public:
+ GeneratorUntypedBase() = default;
+ virtual ~GeneratorUntypedBase();
+ // Attempts to move the generator to the next element
+ //
+ // Returns true iff the move succeeded (and a valid element
+ // can be retrieved).
+ virtual bool next() = 0;
+ };
+ using GeneratorBasePtr = std::unique_ptr<GeneratorUntypedBase>;
- std::string tag;
- SourceLineInfo lineInfo;
+ } // namespace Generators
+
+ struct IGeneratorTracker {
+ virtual ~IGeneratorTracker();
+ virtual auto hasGenerator() const -> bool = 0;
+ virtual auto getGenerator() const -> Generators::GeneratorBasePtr const& = 0;
+ virtual void setGenerator( Generators::GeneratorBasePtr&& generator ) = 0;
};
- struct RegistrarForTagAliases {
- RegistrarForTagAliases( char const* alias, char const* tag, SourceLineInfo const& lineInfo );
+} // namespace Catch
+
+// end catch_interfaces_generatortracker.h
+// start catch_enforce.h
+
+#include <exception>
+
+namespace Catch {
+#if !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS)
+ template <typename Ex>
+ [[noreturn]]
+ void throw_exception(Ex const& e) {
+ throw e;
+ }
+#else // ^^ Exceptions are enabled // Exceptions are disabled vv
+ [[noreturn]]
+ void throw_exception(std::exception const& e);
+#endif
+
+ [[noreturn]]
+ void throw_logic_error(std::string const& msg);
+ [[noreturn]]
+ void throw_domain_error(std::string const& msg);
+ [[noreturn]]
+ void throw_runtime_error(std::string const& msg);
+
+} // namespace Catch;
+
+#define CATCH_MAKE_MSG(...) \
+ (Catch::ReusableStringStream() << __VA_ARGS__).str()
+
+#define CATCH_INTERNAL_ERROR(...) \
+ Catch::throw_logic_error(CATCH_MAKE_MSG( CATCH_INTERNAL_LINEINFO << ": Internal Catch2 error: " << __VA_ARGS__))
+
+#define CATCH_ERROR(...) \
+ Catch::throw_domain_error(CATCH_MAKE_MSG( __VA_ARGS__ ))
+
+#define CATCH_RUNTIME_ERROR(...) \
+ Catch::throw_runtime_error(CATCH_MAKE_MSG( __VA_ARGS__ ))
+
+#define CATCH_ENFORCE( condition, ... ) \
+ do{ if( !(condition) ) CATCH_ERROR( __VA_ARGS__ ); } while(false)
+
+// end catch_enforce.h
+#include <memory>
+#include <vector>
+#include <cassert>
+
+#include <utility>
+#include <exception>
+
+namespace Catch {
+
+class GeneratorException : public std::exception {
+ const char* const m_msg = "";
+
+public:
+ GeneratorException(const char* msg):
+ m_msg(msg)
+ {}
+
+ const char* what() const noexcept override final;
+};
+
+namespace Generators {
+
+ // !TBD move this into its own location?
+ namespace pf{
+ template<typename T, typename... Args>
+ std::unique_ptr<T> make_unique( Args&&... args ) {
+ return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
+ }
+ }
+
+ template<typename T>
+ struct IGenerator : GeneratorUntypedBase {
+ virtual ~IGenerator() = default;
+
+ // Returns the current element of the generator
+ //
+ // \Precondition The generator is either freshly constructed,
+ // or the last call to `next()` returned true
+ virtual T const& get() const = 0;
+ using type = T;
};
-} // end namespace Catch
+ template<typename T>
+ class SingleValueGenerator final : public IGenerator<T> {
+ T m_value;
+ public:
+ SingleValueGenerator(T const& value) : m_value( value ) {}
+ SingleValueGenerator(T&& value) : m_value(std::move(value)) {}
-#define CATCH_REGISTER_TAG_ALIAS( alias, spec ) namespace{ Catch::RegistrarForTagAliases INTERNAL_CATCH_UNIQUE_NAME( AutoRegisterTagAlias )( alias, spec, CATCH_INTERNAL_LINEINFO ); }
-// #included from: catch_option.hpp
-#define TWOBLUECUBES_CATCH_OPTION_HPP_INCLUDED
+ T const& get() const override {
+ return m_value;
+ }
+ bool next() override {
+ return false;
+ }
+ };
+
+ template<typename T>
+ class FixedValuesGenerator final : public IGenerator<T> {
+ static_assert(!std::is_same<T, bool>::value,
+ "FixedValuesGenerator does not support bools because of std::vector<bool>"
+ "specialization, use SingleValue Generator instead.");
+ std::vector<T> m_values;
+ size_t m_idx = 0;
+ public:
+ FixedValuesGenerator( std::initializer_list<T> values ) : m_values( values ) {}
+
+ T const& get() const override {
+ return m_values[m_idx];
+ }
+ bool next() override {
+ ++m_idx;
+ return m_idx < m_values.size();
+ }
+ };
+
+ template <typename T>
+ class GeneratorWrapper final {
+ std::unique_ptr<IGenerator<T>> m_generator;
+ public:
+ GeneratorWrapper(std::unique_ptr<IGenerator<T>> generator):
+ m_generator(std::move(generator))
+ {}
+ T const& get() const {
+ return m_generator->get();
+ }
+ bool next() {
+ return m_generator->next();
+ }
+ };
+
+ template <typename T>
+ GeneratorWrapper<T> value(T&& value) {
+ return GeneratorWrapper<T>(pf::make_unique<SingleValueGenerator<T>>(std::forward<T>(value)));
+ }
+ template <typename T>
+ GeneratorWrapper<T> values(std::initializer_list<T> values) {
+ return GeneratorWrapper<T>(pf::make_unique<FixedValuesGenerator<T>>(values));
+ }
+
+ template<typename T>
+ class Generators : public IGenerator<T> {
+ std::vector<GeneratorWrapper<T>> m_generators;
+ size_t m_current = 0;
+
+ void populate(GeneratorWrapper<T>&& generator) {
+ m_generators.emplace_back(std::move(generator));
+ }
+ void populate(T&& val) {
+ m_generators.emplace_back(value(std::move(val)));
+ }
+ template<typename U>
+ void populate(U&& val) {
+ populate(T(std::move(val)));
+ }
+ template<typename U, typename... Gs>
+ void populate(U&& valueOrGenerator, Gs... moreGenerators) {
+ populate(std::forward<U>(valueOrGenerator));
+ populate(std::forward<Gs>(moreGenerators)...);
+ }
+
+ public:
+ template <typename... Gs>
+ Generators(Gs... moreGenerators) {
+ m_generators.reserve(sizeof...(Gs));
+ populate(std::forward<Gs>(moreGenerators)...);
+ }
+
+ T const& get() const override {
+ return m_generators[m_current].get();
+ }
+
+ bool next() override {
+ if (m_current >= m_generators.size()) {
+ return false;
+ }
+ const bool current_status = m_generators[m_current].next();
+ if (!current_status) {
+ ++m_current;
+ }
+ return m_current < m_generators.size();
+ }
+ };
+
+ template<typename... Ts>
+ GeneratorWrapper<std::tuple<Ts...>> table( std::initializer_list<std::tuple<typename std::decay<Ts>::type...>> tuples ) {
+ return values<std::tuple<Ts...>>( tuples );
+ }
+
+ // Tag type to signal that a generator sequence should convert arguments to a specific type
+ template <typename T>
+ struct as {};
+
+ template<typename T, typename... Gs>
+ auto makeGenerators( GeneratorWrapper<T>&& generator, Gs... moreGenerators ) -> Generators<T> {
+ return Generators<T>(std::move(generator), std::forward<Gs>(moreGenerators)...);
+ }
+ template<typename T>
+ auto makeGenerators( GeneratorWrapper<T>&& generator ) -> Generators<T> {
+ return Generators<T>(std::move(generator));
+ }
+ template<typename T, typename... Gs>
+ auto makeGenerators( T&& val, Gs... moreGenerators ) -> Generators<T> {
+ return makeGenerators( value( std::forward<T>( val ) ), std::forward<Gs>( moreGenerators )... );
+ }
+ template<typename T, typename U, typename... Gs>
+ auto makeGenerators( as<T>, U&& val, Gs... moreGenerators ) -> Generators<T> {
+ return makeGenerators( value( T( std::forward<U>( val ) ) ), std::forward<Gs>( moreGenerators )... );
+ }
+
+ auto acquireGeneratorTracker( SourceLineInfo const& lineInfo ) -> IGeneratorTracker&;
+
+ template<typename L>
+ // Note: The type after -> is weird, because VS2015 cannot parse
+ // the expression used in the typedef inside, when it is in
+ // return type. Yeah.
+ auto generate( SourceLineInfo const& lineInfo, L const& generatorExpression ) -> decltype(std::declval<decltype(generatorExpression())>().get()) {
+ using UnderlyingType = typename decltype(generatorExpression())::type;
+
+ IGeneratorTracker& tracker = acquireGeneratorTracker( lineInfo );
+ if (!tracker.hasGenerator()) {
+ tracker.setGenerator(pf::make_unique<Generators<UnderlyingType>>(generatorExpression()));
+ }
+
+ auto const& generator = static_cast<IGenerator<UnderlyingType> const&>( *tracker.getGenerator() );
+ return generator.get();
+ }
+
+} // namespace Generators
+} // namespace Catch
+
+#define GENERATE( ... ) \
+ Catch::Generators::generate( CATCH_INTERNAL_LINEINFO, [ ]{ using namespace Catch::Generators; return makeGenerators( __VA_ARGS__ ); } )
+#define GENERATE_COPY( ... ) \
+ Catch::Generators::generate( CATCH_INTERNAL_LINEINFO, [=]{ using namespace Catch::Generators; return makeGenerators( __VA_ARGS__ ); } )
+#define GENERATE_REF( ... ) \
+ Catch::Generators::generate( CATCH_INTERNAL_LINEINFO, [&]{ using namespace Catch::Generators; return makeGenerators( __VA_ARGS__ ); } )
+
+// end catch_generators.hpp
+// start catch_generators_generic.hpp
+
+namespace Catch {
+namespace Generators {
+
+ template <typename T>
+ class TakeGenerator : public IGenerator<T> {
+ GeneratorWrapper<T> m_generator;
+ size_t m_returned = 0;
+ size_t m_target;
+ public:
+ TakeGenerator(size_t target, GeneratorWrapper<T>&& generator):
+ m_generator(std::move(generator)),
+ m_target(target)
+ {
+ assert(target != 0 && "Empty generators are not allowed");
+ }
+ T const& get() const override {
+ return m_generator.get();
+ }
+ bool next() override {
+ ++m_returned;
+ if (m_returned >= m_target) {
+ return false;
+ }
+
+ const auto success = m_generator.next();
+ // If the underlying generator does not contain enough values
+ // then we cut short as well
+ if (!success) {
+ m_returned = m_target;
+ }
+ return success;
+ }
+ };
+
+ template <typename T>
+ GeneratorWrapper<T> take(size_t target, GeneratorWrapper<T>&& generator) {
+ return GeneratorWrapper<T>(pf::make_unique<TakeGenerator<T>>(target, std::move(generator)));
+ }
+
+ template <typename T, typename Predicate>
+ class FilterGenerator : public IGenerator<T> {
+ GeneratorWrapper<T> m_generator;
+ Predicate m_predicate;
+ public:
+ template <typename P = Predicate>
+ FilterGenerator(P&& pred, GeneratorWrapper<T>&& generator):
+ m_generator(std::move(generator)),
+ m_predicate(std::forward<P>(pred))
+ {
+ if (!m_predicate(m_generator.get())) {
+ // It might happen that there are no values that pass the
+ // filter. In that case we throw an exception.
+ auto has_initial_value = next();
+ if (!has_initial_value) {
+ Catch::throw_exception(GeneratorException("No valid value found in filtered generator"));
+ }
+ }
+ }
+
+ T const& get() const override {
+ return m_generator.get();
+ }
+
+ bool next() override {
+ bool success = m_generator.next();
+ if (!success) {
+ return false;
+ }
+ while (!m_predicate(m_generator.get()) && (success = m_generator.next()) == true);
+ return success;
+ }
+ };
+
+ template <typename T, typename Predicate>
+ GeneratorWrapper<T> filter(Predicate&& pred, GeneratorWrapper<T>&& generator) {
+ return GeneratorWrapper<T>(std::unique_ptr<IGenerator<T>>(pf::make_unique<FilterGenerator<T, Predicate>>(std::forward<Predicate>(pred), std::move(generator))));
+ }
+
+ template <typename T>
+ class RepeatGenerator : public IGenerator<T> {
+ static_assert(!std::is_same<T, bool>::value,
+ "RepeatGenerator currently does not support bools"
+ "because of std::vector<bool> specialization");
+ GeneratorWrapper<T> m_generator;
+ mutable std::vector<T> m_returned;
+ size_t m_target_repeats;
+ size_t m_current_repeat = 0;
+ size_t m_repeat_index = 0;
+ public:
+ RepeatGenerator(size_t repeats, GeneratorWrapper<T>&& generator):
+ m_generator(std::move(generator)),
+ m_target_repeats(repeats)
+ {
+ assert(m_target_repeats > 0 && "Repeat generator must repeat at least once");
+ }
+
+ T const& get() const override {
+ if (m_current_repeat == 0) {
+ m_returned.push_back(m_generator.get());
+ return m_returned.back();
+ }
+ return m_returned[m_repeat_index];
+ }
+
+ bool next() override {
+ // There are 2 basic cases:
+ // 1) We are still reading the generator
+ // 2) We are reading our own cache
+
+ // In the first case, we need to poke the underlying generator.
+ // If it happily moves, we are left in that state, otherwise it is time to start reading from our cache
+ if (m_current_repeat == 0) {
+ const auto success = m_generator.next();
+ if (!success) {
+ ++m_current_repeat;
+ }
+ return m_current_repeat < m_target_repeats;
+ }
+
+ // In the second case, we need to move indices forward and check that we haven't run up against the end
+ ++m_repeat_index;
+ if (m_repeat_index == m_returned.size()) {
+ m_repeat_index = 0;
+ ++m_current_repeat;
+ }
+ return m_current_repeat < m_target_repeats;
+ }
+ };
+
+ template <typename T>
+ GeneratorWrapper<T> repeat(size_t repeats, GeneratorWrapper<T>&& generator) {
+ return GeneratorWrapper<T>(pf::make_unique<RepeatGenerator<T>>(repeats, std::move(generator)));
+ }
+
+ template <typename T, typename U, typename Func>
+ class MapGenerator : public IGenerator<T> {
+ // TBD: provide static assert for mapping function, for friendly error message
+ GeneratorWrapper<U> m_generator;
+ Func m_function;
+ // To avoid returning dangling reference, we have to save the values
+ T m_cache;
+ public:
+ template <typename F2 = Func>
+ MapGenerator(F2&& function, GeneratorWrapper<U>&& generator) :
+ m_generator(std::move(generator)),
+ m_function(std::forward<F2>(function)),
+ m_cache(m_function(m_generator.get()))
+ {}
+
+ T const& get() const override {
+ return m_cache;
+ }
+ bool next() override {
+ const auto success = m_generator.next();
+ if (success) {
+ m_cache = m_function(m_generator.get());
+ }
+ return success;
+ }
+ };
+
+#if defined(__cpp_lib_is_invocable) && __cpp_lib_is_invocable >= 201703
+ // std::result_of is deprecated in C++17 and removed in C++20. Hence, it is
+ // replaced with std::invoke_result here. Also *_t format is preferred over
+ // typename *::type format.
+ template <typename Func, typename U>
+ using MapFunctionReturnType = std::remove_reference_t<std::remove_cv_t<std::invoke_result_t<Func, U>>>;
+#else
+ template <typename Func, typename U>
+ using MapFunctionReturnType = typename std::remove_reference<typename std::remove_cv<typename std::result_of<Func(U)>::type>::type>::type;
+#endif
+
+ template <typename Func, typename U, typename T = MapFunctionReturnType<Func, U>>
+ GeneratorWrapper<T> map(Func&& function, GeneratorWrapper<U>&& generator) {
+ return GeneratorWrapper<T>(
+ pf::make_unique<MapGenerator<T, U, Func>>(std::forward<Func>(function), std::move(generator))
+ );
+ }
+
+ template <typename T, typename U, typename Func>
+ GeneratorWrapper<T> map(Func&& function, GeneratorWrapper<U>&& generator) {
+ return GeneratorWrapper<T>(
+ pf::make_unique<MapGenerator<T, U, Func>>(std::forward<Func>(function), std::move(generator))
+ );
+ }
+
+ template <typename T>
+ class ChunkGenerator final : public IGenerator<std::vector<T>> {
+ std::vector<T> m_chunk;
+ size_t m_chunk_size;
+ GeneratorWrapper<T> m_generator;
+ bool m_used_up = false;
+ public:
+ ChunkGenerator(size_t size, GeneratorWrapper<T> generator) :
+ m_chunk_size(size), m_generator(std::move(generator))
+ {
+ m_chunk.reserve(m_chunk_size);
+ if (m_chunk_size != 0) {
+ m_chunk.push_back(m_generator.get());
+ for (size_t i = 1; i < m_chunk_size; ++i) {
+ if (!m_generator.next()) {
+ Catch::throw_exception(GeneratorException("Not enough values to initialize the first chunk"));
+ }
+ m_chunk.push_back(m_generator.get());
+ }
+ }
+ }
+ std::vector<T> const& get() const override {
+ return m_chunk;
+ }
+ bool next() override {
+ m_chunk.clear();
+ for (size_t idx = 0; idx < m_chunk_size; ++idx) {
+ if (!m_generator.next()) {
+ return false;
+ }
+ m_chunk.push_back(m_generator.get());
+ }
+ return true;
+ }
+ };
+
+ template <typename T>
+ GeneratorWrapper<std::vector<T>> chunk(size_t size, GeneratorWrapper<T>&& generator) {
+ return GeneratorWrapper<std::vector<T>>(
+ pf::make_unique<ChunkGenerator<T>>(size, std::move(generator))
+ );
+ }
+
+} // namespace Generators
+} // namespace Catch
+
+// end catch_generators_generic.hpp
+// start catch_generators_specific.hpp
+
+// start catch_context.h
+
+#include <memory>
+
+namespace Catch {
+
+ struct IResultCapture;
+ struct IRunner;
+ struct IConfig;
+ struct IMutableContext;
+
+ using IConfigPtr = std::shared_ptr<IConfig const>;
+
+ struct IContext
+ {
+ virtual ~IContext();
+
+ virtual IResultCapture* getResultCapture() = 0;
+ virtual IRunner* getRunner() = 0;
+ virtual IConfigPtr const& getConfig() const = 0;
+ };
+
+ struct IMutableContext : IContext
+ {
+ virtual ~IMutableContext();
+ virtual void setResultCapture( IResultCapture* resultCapture ) = 0;
+ virtual void setRunner( IRunner* runner ) = 0;
+ virtual void setConfig( IConfigPtr const& config ) = 0;
+
+ private:
+ static IMutableContext *currentContext;
+ friend IMutableContext& getCurrentMutableContext();
+ friend void cleanUpContext();
+ static void createContext();
+ };
+
+ inline IMutableContext& getCurrentMutableContext()
+ {
+ if( !IMutableContext::currentContext )
+ IMutableContext::createContext();
+ // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.UndefReturn)
+ return *IMutableContext::currentContext;
+ }
+
+ inline IContext& getCurrentContext()
+ {
+ return getCurrentMutableContext();
+ }
+
+ void cleanUpContext();
+
+ class SimplePcg32;
+ SimplePcg32& rng();
+}
+
+// end catch_context.h
+// start catch_interfaces_config.h
+
+// start catch_option.hpp
namespace Catch {
@@ -3123,12 +4381,12 @@ namespace Catch {
template<typename T>
class Option {
public:
- Option() : nullableValue( CATCH_NULL ) {}
+ Option() : nullableValue( nullptr ) {}
Option( T const& _value )
: nullableValue( new( storage ) T( _value ) )
{}
Option( Option const& _other )
- : nullableValue( _other ? new( storage ) T( *_other ) : CATCH_NULL )
+ : nullableValue( _other ? new( storage ) T( *_other ) : nullptr )
{}
~Option() {
@@ -3152,7 +4410,7 @@ namespace Catch {
void reset() {
if( nullableValue )
nullableValue->~T();
- nullableValue = CATCH_NULL;
+ nullableValue = nullptr;
}
T& operator*() { return *nullableValue; }
@@ -3164,50 +4422,313 @@ namespace Catch {
return nullableValue ? *nullableValue : defaultValue;
}
- bool some() const { return nullableValue != CATCH_NULL; }
- bool none() const { return nullableValue == CATCH_NULL; }
+ bool some() const { return nullableValue != nullptr; }
+ bool none() const { return nullableValue == nullptr; }
- bool operator !() const { return nullableValue == CATCH_NULL; }
- operator SafeBool::type() const {
- return SafeBool::makeSafe( some() );
+ bool operator !() const { return nullableValue == nullptr; }
+ explicit operator bool() const {
+ return some();
}
private:
T *nullableValue;
- union {
- char storage[sizeof(T)];
-
- // These are here to force alignment for the storage
- long double dummy1;
- void (*dummy2)();
- long double dummy3;
-#ifdef CATCH_CONFIG_CPP11_LONG_LONG
- long long dummy4;
-#endif
- };
+ alignas(alignof(T)) char storage[sizeof(T)];
};
} // end namespace Catch
+// end catch_option.hpp
+#include <iosfwd>
+#include <string>
+#include <vector>
+#include <memory>
+
namespace Catch {
- struct ITagAliasRegistry {
- virtual ~ITagAliasRegistry();
- virtual Option<TagAlias> find( std::string const& alias ) const = 0;
- virtual std::string expandAliases( std::string const& unexpandedTestSpec ) const = 0;
+ enum class Verbosity {
+ Quiet = 0,
+ Normal,
+ High
+ };
- static ITagAliasRegistry const& get();
+ struct WarnAbout { enum What {
+ Nothing = 0x00,
+ NoAssertions = 0x01,
+ NoTests = 0x02
+ }; };
+
+ struct ShowDurations { enum OrNot {
+ DefaultForReporter,
+ Always,
+ Never
+ }; };
+ struct RunTests { enum InWhatOrder {
+ InDeclarationOrder,
+ InLexicographicalOrder,
+ InRandomOrder
+ }; };
+ struct UseColour { enum YesOrNo {
+ Auto,
+ Yes,
+ No
+ }; };
+ struct WaitForKeypress { enum When {
+ Never,
+ BeforeStart = 1,
+ BeforeExit = 2,
+ BeforeStartAndExit = BeforeStart | BeforeExit
+ }; };
+
+ class TestSpec;
+
+ struct IConfig : NonCopyable {
+
+ virtual ~IConfig();
+
+ virtual bool allowThrows() const = 0;
+ virtual std::ostream& stream() const = 0;
+ virtual std::string name() const = 0;
+ virtual bool includeSuccessfulResults() const = 0;
+ virtual bool shouldDebugBreak() const = 0;
+ virtual bool warnAboutMissingAssertions() const = 0;
+ virtual bool warnAboutNoTests() const = 0;
+ virtual int abortAfter() const = 0;
+ virtual bool showInvisibles() const = 0;
+ virtual ShowDurations::OrNot showDurations() const = 0;
+ virtual TestSpec const& testSpec() const = 0;
+ virtual bool hasTestFilters() const = 0;
+ virtual std::vector<std::string> const& getTestsOrTags() const = 0;
+ virtual RunTests::InWhatOrder runOrder() const = 0;
+ virtual unsigned int rngSeed() const = 0;
+ virtual UseColour::YesOrNo useColour() const = 0;
+ virtual std::vector<std::string> const& getSectionsToRun() const = 0;
+ virtual Verbosity verbosity() const = 0;
+
+ virtual bool benchmarkNoAnalysis() const = 0;
+ virtual int benchmarkSamples() const = 0;
+ virtual double benchmarkConfidenceInterval() const = 0;
+ virtual unsigned int benchmarkResamples() const = 0;
+ };
+
+ using IConfigPtr = std::shared_ptr<IConfig const>;
+}
+
+// end catch_interfaces_config.h
+// start catch_random_number_generator.h
+
+#include <cstdint>
+
+namespace Catch {
+
+ // This is a simple implementation of C++11 Uniform Random Number
+ // Generator. It does not provide all operators, because Catch2
+ // does not use it, but it should behave as expected inside stdlib's
+ // distributions.
+ // The implementation is based on the PCG family (http://pcg-random.org)
+ class SimplePcg32 {
+ using state_type = std::uint64_t;
+ public:
+ using result_type = std::uint32_t;
+ static constexpr result_type (min)() {
+ return 0;
+ }
+ static constexpr result_type (max)() {
+ return static_cast<result_type>(-1);
+ }
+
+ // Provide some default initial state for the default constructor
+ SimplePcg32():SimplePcg32(0xed743cc4U) {}
+
+ explicit SimplePcg32(result_type seed_);
+
+ void seed(result_type seed_);
+ void discard(uint64_t skip);
+
+ result_type operator()();
+
+ private:
+ friend bool operator==(SimplePcg32 const& lhs, SimplePcg32 const& rhs);
+ friend bool operator!=(SimplePcg32 const& lhs, SimplePcg32 const& rhs);
+
+ // In theory we also need operator<< and operator>>
+ // In practice we do not use them, so we will skip them for now
+
+ std::uint64_t m_state;
+ // This part of the state determines which "stream" of the numbers
+ // is chosen -- we take it as a constant for Catch2, so we only
+ // need to deal with seeding the main state.
+ // Picked by reading 8 bytes from `/dev/random` :-)
+ static const std::uint64_t s_inc = (0x13ed0cc53f939476ULL << 1ULL) | 1ULL;
};
} // end namespace Catch
+// end catch_random_number_generator.h
+#include <random>
+
+namespace Catch {
+namespace Generators {
+
+template <typename Float>
+class RandomFloatingGenerator final : public IGenerator<Float> {
+ Catch::SimplePcg32& m_rng;
+ std::uniform_real_distribution<Float> m_dist;
+ Float m_current_number;
+public:
+
+ RandomFloatingGenerator(Float a, Float b):
+ m_rng(rng()),
+ m_dist(a, b) {
+ static_cast<void>(next());
+ }
+
+ Float const& get() const override {
+ return m_current_number;
+ }
+ bool next() override {
+ m_current_number = m_dist(m_rng);
+ return true;
+ }
+};
+
+template <typename Integer>
+class RandomIntegerGenerator final : public IGenerator<Integer> {
+ Catch::SimplePcg32& m_rng;
+ std::uniform_int_distribution<Integer> m_dist;
+ Integer m_current_number;
+public:
+
+ RandomIntegerGenerator(Integer a, Integer b):
+ m_rng(rng()),
+ m_dist(a, b) {
+ static_cast<void>(next());
+ }
+
+ Integer const& get() const override {
+ return m_current_number;
+ }
+ bool next() override {
+ m_current_number = m_dist(m_rng);
+ return true;
+ }
+};
+
+// TODO: Ideally this would be also constrained against the various char types,
+// but I don't expect users to run into that in practice.
+template <typename T>
+typename std::enable_if<std::is_integral<T>::value && !std::is_same<T, bool>::value,
+GeneratorWrapper<T>>::type
+random(T a, T b) {
+ return GeneratorWrapper<T>(
+ pf::make_unique<RandomIntegerGenerator<T>>(a, b)
+ );
+}
+
+template <typename T>
+typename std::enable_if<std::is_floating_point<T>::value,
+GeneratorWrapper<T>>::type
+random(T a, T b) {
+ return GeneratorWrapper<T>(
+ pf::make_unique<RandomFloatingGenerator<T>>(a, b)
+ );
+}
+
+template <typename T>
+class RangeGenerator final : public IGenerator<T> {
+ T m_current;
+ T m_end;
+ T m_step;
+ bool m_positive;
+
+public:
+ RangeGenerator(T const& start, T const& end, T const& step):
+ m_current(start),
+ m_end(end),
+ m_step(step),
+ m_positive(m_step > T(0))
+ {
+ assert(m_current != m_end && "Range start and end cannot be equal");
+ assert(m_step != T(0) && "Step size cannot be zero");
+ assert(((m_positive && m_current <= m_end) || (!m_positive && m_current >= m_end)) && "Step moves away from end");
+ }
+
+ RangeGenerator(T const& start, T const& end):
+ RangeGenerator(start, end, (start < end) ? T(1) : T(-1))
+ {}
+
+ T const& get() const override {
+ return m_current;
+ }
+
+ bool next() override {
+ m_current += m_step;
+ return (m_positive) ? (m_current < m_end) : (m_current > m_end);
+ }
+};
+
+template <typename T>
+GeneratorWrapper<T> range(T const& start, T const& end, T const& step) {
+ static_assert(std::is_integral<T>::value && !std::is_same<T, bool>::value, "Type must be an integer");
+ return GeneratorWrapper<T>(pf::make_unique<RangeGenerator<T>>(start, end, step));
+}
+
+template <typename T>
+GeneratorWrapper<T> range(T const& start, T const& end) {
+ static_assert(std::is_integral<T>::value && !std::is_same<T, bool>::value, "Type must be an integer");
+ return GeneratorWrapper<T>(pf::make_unique<RangeGenerator<T>>(start, end));
+}
+
+template <typename T>
+class IteratorGenerator final : public IGenerator<T> {
+ static_assert(!std::is_same<T, bool>::value,
+ "IteratorGenerator currently does not support bools"
+ "because of std::vector<bool> specialization");
+
+ std::vector<T> m_elems;
+ size_t m_current = 0;
+public:
+ template <typename InputIterator, typename InputSentinel>
+ IteratorGenerator(InputIterator first, InputSentinel last):m_elems(first, last) {
+ if (m_elems.empty()) {
+ Catch::throw_exception(GeneratorException("IteratorGenerator received no valid values"));
+ }
+ }
+
+ T const& get() const override {
+ return m_elems[m_current];
+ }
+
+ bool next() override {
+ ++m_current;
+ return m_current != m_elems.size();
+ }
+};
+
+template <typename InputIterator,
+ typename InputSentinel,
+ typename ResultType = typename std::iterator_traits<InputIterator>::value_type>
+GeneratorWrapper<ResultType> from_range(InputIterator from, InputSentinel to) {
+ return GeneratorWrapper<ResultType>(pf::make_unique<IteratorGenerator<ResultType>>(from, to));
+}
+
+template <typename Container,
+ typename ResultType = typename Container::value_type>
+GeneratorWrapper<ResultType> from_range(Container const& cnt) {
+ return GeneratorWrapper<ResultType>(pf::make_unique<IteratorGenerator<ResultType>>(cnt.begin(), cnt.end()));
+}
+
+} // namespace Generators
+} // namespace Catch
+
+// end catch_generators_specific.hpp
+
// These files are included here so the single_include script doesn't put them
// in the conditionally compiled sections
-// #included from: internal/catch_test_case_info.h
-#define TWOBLUECUBES_CATCH_TEST_CASE_INFO_H_INCLUDED
+// start catch_test_case_info.h
#include <string>
-#include <set>
+#include <vector>
+#include <memory>
#ifdef __clang__
#pragma clang diagnostic push
@@ -3216,7 +4737,7 @@ namespace Catch {
namespace Catch {
- struct ITestCase;
+ struct ITestInvoker;
struct TestCaseInfo {
enum SpecialProperties{
@@ -3225,30 +4746,30 @@ namespace Catch {
ShouldFail = 1 << 2,
MayFail = 1 << 3,
Throws = 1 << 4,
- NonPortable = 1 << 5
+ NonPortable = 1 << 5,
+ Benchmark = 1 << 6
};
TestCaseInfo( std::string const& _name,
std::string const& _className,
std::string const& _description,
- std::set<std::string> const& _tags,
+ std::vector<std::string> const& _tags,
SourceLineInfo const& _lineInfo );
- TestCaseInfo( TestCaseInfo const& other );
-
- friend void setTags( TestCaseInfo& testCaseInfo, std::set<std::string> const& tags );
+ friend void setTags( TestCaseInfo& testCaseInfo, std::vector<std::string> tags );
bool isHidden() const;
bool throws() const;
bool okToFail() const;
bool expectedToFail() const;
+ std::string tagsAsString() const;
+
std::string name;
std::string className;
std::string description;
- std::set<std::string> tags;
- std::set<std::string> lcaseTags;
- std::string tagsAsString;
+ std::vector<std::string> tags;
+ std::vector<std::string> lcaseTags;
SourceLineInfo lineInfo;
SpecialProperties properties;
};
@@ -3256,8 +4777,7 @@ namespace Catch {
class TestCase : public TestCaseInfo {
public:
- TestCase( ITestCase* testCase, TestCaseInfo const& info );
- TestCase( TestCase const& other );
+ TestCase( ITestInvoker* testCase, TestCaseInfo&& info );
TestCase withName( std::string const& _newName ) const;
@@ -3265,19 +4785,16 @@ namespace Catch {
TestCaseInfo const& getTestCaseInfo() const;
- void swap( TestCase& other );
bool operator == ( TestCase const& other ) const;
bool operator < ( TestCase const& other ) const;
- TestCase& operator = ( TestCase const& other );
private:
- Ptr<ITestCase> test;
+ std::shared_ptr<ITestInvoker> test;
};
- TestCase makeTestCase( ITestCase* testCase,
+ TestCase makeTestCase( ITestInvoker* testCase,
std::string const& className,
- std::string const& name,
- std::string const& description,
+ NameAndTags const& nameAndTags,
SourceLineInfo const& lineInfo );
}
@@ -3285,10 +4802,21 @@ namespace Catch {
#pragma clang diagnostic pop
#endif
+// end catch_test_case_info.h
+// start catch_interfaces_runner.h
+
+namespace Catch {
+
+ struct IRunner {
+ virtual ~IRunner();
+ virtual bool aborting() const = 0;
+ };
+}
+
+// end catch_interfaces_runner.h
#ifdef __OBJC__
-// #included from: internal/catch_objc.hpp
-#define TWOBLUECUBES_CATCH_OBJC_HPP_INCLUDED
+// start catch_objc.hpp
#import <objc/runtime.h>
@@ -3312,7 +4840,7 @@ namespace Catch {
namespace Catch {
- class OcMethod : public SharedImpl<ITestCase> {
+ class OcMethod : public ITestInvoker {
public:
OcMethod( Class cls, SEL sel ) : m_cls( cls ), m_sel( sel ) {}
@@ -3348,9 +4876,9 @@ namespace Catch {
}
}
- inline size_t registerTestMethods() {
- size_t noTestMethods = 0;
- int noClasses = objc_getClassList( CATCH_NULL, 0 );
+ inline std::size_t registerTestMethods() {
+ std::size_t noTestMethods = 0;
+ int noClasses = objc_getClassList( nullptr, 0 );
Class* classes = (CATCH_UNSAFE_UNRETAINED Class *)malloc( sizeof(Class) * noClasses);
objc_getClassList( classes, noClasses );
@@ -3369,7 +4897,7 @@ namespace Catch {
std::string desc = Detail::getAnnotation( cls, "Description", testCaseName );
const char* className = class_getName( cls );
- getMutableRegistryHub().registerTest( makeTestCase( new OcMethod( cls, selector ), className, name.c_str(), desc.c_str(), SourceLineInfo() ) );
+ getMutableRegistryHub().registerTest( makeTestCase( new OcMethod( cls, selector ), className, NameAndTags( name.c_str(), desc.c_str() ), SourceLineInfo("",0) ) );
noTestMethods++;
}
}
@@ -3379,6 +4907,8 @@ namespace Catch {
return noTestMethods;
}
+#if !defined(CATCH_CONFIG_DISABLE_MATCHERS)
+
namespace Matchers {
namespace Impl {
namespace NSStringMatchers {
@@ -3390,61 +4920,61 @@ namespace Catch {
arcSafeRelease( m_substr );
}
- virtual bool match( NSString* arg ) const CATCH_OVERRIDE {
+ bool match( NSString* str ) const override {
return false;
}
- NSString* m_substr;
+ NSString* CATCH_ARC_STRONG m_substr;
};
struct Equals : StringHolder {
Equals( NSString* substr ) : StringHolder( substr ){}
- virtual bool match( NSString* str ) const CATCH_OVERRIDE {
+ bool match( NSString* str ) const override {
return (str != nil || m_substr == nil ) &&
[str isEqualToString:m_substr];
}
- virtual std::string describe() const CATCH_OVERRIDE {
- return "equals string: " + Catch::toString( m_substr );
+ std::string describe() const override {
+ return "equals string: " + Catch::Detail::stringify( m_substr );
}
};
struct Contains : StringHolder {
Contains( NSString* substr ) : StringHolder( substr ){}
- virtual bool match( NSString* str ) const {
+ bool match( NSString* str ) const override {
return (str != nil || m_substr == nil ) &&
[str rangeOfString:m_substr].location != NSNotFound;
}
- virtual std::string describe() const CATCH_OVERRIDE {
- return "contains string: " + Catch::toString( m_substr );
+ std::string describe() const override {
+ return "contains string: " + Catch::Detail::stringify( m_substr );
}
};
struct StartsWith : StringHolder {
StartsWith( NSString* substr ) : StringHolder( substr ){}
- virtual bool match( NSString* str ) const {
+ bool match( NSString* str ) const override {
return (str != nil || m_substr == nil ) &&
[str rangeOfString:m_substr].location == 0;
}
- virtual std::string describe() const CATCH_OVERRIDE {
- return "starts with: " + Catch::toString( m_substr );
+ std::string describe() const override {
+ return "starts with: " + Catch::Detail::stringify( m_substr );
}
};
struct EndsWith : StringHolder {
EndsWith( NSString* substr ) : StringHolder( substr ){}
- virtual bool match( NSString* str ) const {
+ bool match( NSString* str ) const override {
return (str != nil || m_substr == nil ) &&
[str rangeOfString:m_substr].location == [str length] - [m_substr length];
}
- virtual std::string describe() const CATCH_OVERRIDE {
- return "ends with: " + Catch::toString( m_substr );
+ std::string describe() const override {
+ return "ends with: " + Catch::Detail::stringify( m_substr );
}
};
@@ -3467,86 +4997,53 @@ namespace Catch {
using namespace Matchers;
+#endif // CATCH_CONFIG_DISABLE_MATCHERS
+
} // namespace Catch
///////////////////////////////////////////////////////////////////////////////
-#define OC_TEST_CASE( name, desc )\
-+(NSString*) INTERNAL_CATCH_UNIQUE_NAME( Catch_Name_test ) \
-{\
+#define OC_MAKE_UNIQUE_NAME( root, uniqueSuffix ) root##uniqueSuffix
+#define OC_TEST_CASE2( name, desc, uniqueSuffix ) \
++(NSString*) OC_MAKE_UNIQUE_NAME( Catch_Name_test_, uniqueSuffix ) \
+{ \
return @ name; \
-}\
-+(NSString*) INTERNAL_CATCH_UNIQUE_NAME( Catch_Description_test ) \
+} \
++(NSString*) OC_MAKE_UNIQUE_NAME( Catch_Description_test_, uniqueSuffix ) \
{ \
return @ desc; \
} \
--(void) INTERNAL_CATCH_UNIQUE_NAME( Catch_TestCase_test )
+-(void) OC_MAKE_UNIQUE_NAME( Catch_TestCase_test_, uniqueSuffix )
-#endif
+#define OC_TEST_CASE( name, desc ) OC_TEST_CASE2( name, desc, __LINE__ )
-#ifdef CATCH_IMPL
-
-// !TBD: Move the leak detector code into a separate header
-#ifdef CATCH_CONFIG_WINDOWS_CRTDBG
-#include <crtdbg.h>
-class LeakDetector {
-public:
- LeakDetector() {
- int flag = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG);
- flag |= _CRTDBG_LEAK_CHECK_DF;
- flag |= _CRTDBG_ALLOC_MEM_DF;
- _CrtSetDbgFlag(flag);
- _CrtSetReportMode(_CRT_WARN, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG);
- _CrtSetReportFile(_CRT_WARN, _CRTDBG_FILE_STDERR);
- // Change this to leaking allocation's number to break there
- _CrtSetBreakAlloc(-1);
- }
-};
-#else
-class LeakDetector {};
+// end catch_objc.hpp
#endif
-LeakDetector leakDetector;
-
-// #included from: internal/catch_impl.hpp
-#define TWOBLUECUBES_CATCH_IMPL_HPP_INCLUDED
-
-// Collect all the implementation files together here
-// These are the equivalent of what would usually be cpp files
-
-#ifdef __clang__
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wweak-vtables"
-#endif
+// Benchmarking needs the externally-facing parts of reporters to work
+#if defined(CATCH_CONFIG_EXTERNAL_INTERFACES) || defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+// start catch_external_interfaces.h
-// #included from: ../catch_session.hpp
-#define TWOBLUECUBES_CATCH_RUNNER_HPP_INCLUDED
+// start catch_reporter_bases.hpp
-// #included from: internal/catch_commandline.hpp
-#define TWOBLUECUBES_CATCH_COMMANDLINE_HPP_INCLUDED
+// start catch_interfaces_reporter.h
-// #included from: catch_config.hpp
-#define TWOBLUECUBES_CATCH_CONFIG_HPP_INCLUDED
+// start catch_config.hpp
-// #included from: catch_test_spec_parser.hpp
-#define TWOBLUECUBES_CATCH_TEST_SPEC_PARSER_HPP_INCLUDED
+// start catch_test_spec_parser.h
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpadded"
#endif
-// #included from: catch_test_spec.hpp
-#define TWOBLUECUBES_CATCH_TEST_SPEC_HPP_INCLUDED
+// start catch_test_spec.h
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpadded"
#endif
-// #included from: catch_wildcard_pattern.hpp
-#define TWOBLUECUBES_CATCH_WILDCARD_PATTERN_HPP_INCLUDED
-
-#include <stdexcept>
+// start catch_wildcard_pattern.h
namespace Catch
{
@@ -3560,123 +5057,86 @@ namespace Catch
public:
- WildcardPattern( std::string const& pattern, CaseSensitive::Choice caseSensitivity )
- : m_caseSensitivity( caseSensitivity ),
- m_wildcard( NoWildcard ),
- m_pattern( adjustCase( pattern ) )
- {
- if( startsWith( m_pattern, '*' ) ) {
- m_pattern = m_pattern.substr( 1 );
- m_wildcard = WildcardAtStart;
- }
- if( endsWith( m_pattern, '*' ) ) {
- m_pattern = m_pattern.substr( 0, m_pattern.size()-1 );
- m_wildcard = static_cast<WildcardPosition>( m_wildcard | WildcardAtEnd );
- }
- }
- virtual ~WildcardPattern();
- virtual bool matches( std::string const& str ) const {
- switch( m_wildcard ) {
- case NoWildcard:
- return m_pattern == adjustCase( str );
- case WildcardAtStart:
- return endsWith( adjustCase( str ), m_pattern );
- case WildcardAtEnd:
- return startsWith( adjustCase( str ), m_pattern );
- case WildcardAtBothEnds:
- return contains( adjustCase( str ), m_pattern );
- }
+ WildcardPattern( std::string const& pattern, CaseSensitive::Choice caseSensitivity );
+ virtual ~WildcardPattern() = default;
+ virtual bool matches( std::string const& str ) const;
-#ifdef __clang__
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wunreachable-code"
-#endif
- throw std::logic_error( "Unknown enum" );
-#ifdef __clang__
-#pragma clang diagnostic pop
-#endif
- }
private:
- std::string adjustCase( std::string const& str ) const {
- return m_caseSensitivity == CaseSensitive::No ? toLower( str ) : str;
- }
+ std::string normaliseString( std::string const& str ) const;
CaseSensitive::Choice m_caseSensitivity;
- WildcardPosition m_wildcard;
+ WildcardPosition m_wildcard = NoWildcard;
std::string m_pattern;
};
}
+// end catch_wildcard_pattern.h
#include <string>
#include <vector>
+#include <memory>
namespace Catch {
+ struct IConfig;
+
class TestSpec {
- struct Pattern : SharedImpl<> {
+ class Pattern {
+ public:
+ explicit Pattern( std::string const& name );
virtual ~Pattern();
virtual bool matches( TestCaseInfo const& testCase ) const = 0;
+ std::string const& name() const;
+ private:
+ std::string const m_name;
};
+ using PatternPtr = std::shared_ptr<Pattern>;
+
class NamePattern : public Pattern {
public:
- NamePattern( std::string const& name )
- : m_wildcardPattern( toLower( name ), CaseSensitive::No )
- {}
- virtual ~NamePattern();
- virtual bool matches( TestCaseInfo const& testCase ) const {
- return m_wildcardPattern.matches( toLower( testCase.name ) );
- }
+ explicit NamePattern( std::string const& name, std::string const& filterString );
+ bool matches( TestCaseInfo const& testCase ) const override;
private:
WildcardPattern m_wildcardPattern;
};
class TagPattern : public Pattern {
public:
- TagPattern( std::string const& tag ) : m_tag( toLower( tag ) ) {}
- virtual ~TagPattern();
- virtual bool matches( TestCaseInfo const& testCase ) const {
- return testCase.lcaseTags.find( m_tag ) != testCase.lcaseTags.end();
- }
+ explicit TagPattern( std::string const& tag, std::string const& filterString );
+ bool matches( TestCaseInfo const& testCase ) const override;
private:
std::string m_tag;
};
class ExcludedPattern : public Pattern {
public:
- ExcludedPattern( Ptr<Pattern> const& underlyingPattern ) : m_underlyingPattern( underlyingPattern ) {}
- virtual ~ExcludedPattern();
- virtual bool matches( TestCaseInfo const& testCase ) const { return !m_underlyingPattern->matches( testCase ); }
+ explicit ExcludedPattern( PatternPtr const& underlyingPattern );
+ bool matches( TestCaseInfo const& testCase ) const override;
private:
- Ptr<Pattern> m_underlyingPattern;
+ PatternPtr m_underlyingPattern;
};
struct Filter {
- std::vector<Ptr<Pattern> > m_patterns;
+ std::vector<PatternPtr> m_patterns;
- bool matches( TestCaseInfo const& testCase ) const {
- // All patterns in a filter must match for the filter to be a match
- for( std::vector<Ptr<Pattern> >::const_iterator it = m_patterns.begin(), itEnd = m_patterns.end(); it != itEnd; ++it ) {
- if( !(*it)->matches( testCase ) )
- return false;
- }
- return true;
- }
+ bool matches( TestCaseInfo const& testCase ) const;
+ std::string name() const;
};
public:
- bool hasFilters() const {
- return !m_filters.empty();
- }
- bool matches( TestCaseInfo const& testCase ) const {
- // A TestSpec matches if any filter matches
- for( std::vector<Filter>::const_iterator it = m_filters.begin(), itEnd = m_filters.end(); it != itEnd; ++it )
- if( it->matches( testCase ) )
- return true;
- return false;
- }
+ struct FilterMatch {
+ std::string name;
+ std::vector<TestCase const*> tests;
+ };
+ using Matches = std::vector<FilterMatch>;
+ using vectorStrings = std::vector<std::string>;
+
+ bool hasFilters() const;
+ bool matches( TestCaseInfo const& testCase ) const;
+ Matches matchesByFilter( std::vector<TestCase> const& testCases, IConfig const& config ) const;
+ const vectorStrings & getInvalidArgs() const;
private:
std::vector<Filter> m_filters;
-
+ std::vector<std::string> m_invalidArgs;
friend class TestSpecParser;
};
}
@@ -3685,112 +5145,94 @@ namespace Catch {
#pragma clang diagnostic pop
#endif
+// end catch_test_spec.h
+// start catch_interfaces_tag_alias_registry.h
+
+#include <string>
+
+namespace Catch {
+
+ struct TagAlias;
+
+ struct ITagAliasRegistry {
+ virtual ~ITagAliasRegistry();
+ // Nullptr if not present
+ virtual TagAlias const* find( std::string const& alias ) const = 0;
+ virtual std::string expandAliases( std::string const& unexpandedTestSpec ) const = 0;
+
+ static ITagAliasRegistry const& get();
+ };
+
+} // end namespace Catch
+
+// end catch_interfaces_tag_alias_registry.h
namespace Catch {
class TestSpecParser {
enum Mode{ None, Name, QuotedName, Tag, EscapedName };
- Mode m_mode;
- bool m_exclusion;
- std::size_t m_start, m_pos;
+ Mode m_mode = None;
+ Mode lastMode = None;
+ bool m_exclusion = false;
+ std::size_t m_pos = 0;
+ std::size_t m_realPatternPos = 0;
std::string m_arg;
+ std::string m_substring;
+ std::string m_patternName;
std::vector<std::size_t> m_escapeChars;
TestSpec::Filter m_currentFilter;
TestSpec m_testSpec;
- ITagAliasRegistry const* m_tagAliases;
+ ITagAliasRegistry const* m_tagAliases = nullptr;
public:
- TestSpecParser( ITagAliasRegistry const& tagAliases ) : m_tagAliases( &tagAliases ) {}
+ TestSpecParser( ITagAliasRegistry const& tagAliases );
+
+ TestSpecParser& parse( std::string const& arg );
+ TestSpec testSpec();
- TestSpecParser& parse( std::string const& arg ) {
- m_mode = None;
- m_exclusion = false;
- m_start = std::string::npos;
- m_arg = m_tagAliases->expandAliases( arg );
- m_escapeChars.clear();
- for( m_pos = 0; m_pos < m_arg.size(); ++m_pos )
- visitChar( m_arg[m_pos] );
- if( m_mode == Name )
- addPattern<TestSpec::NamePattern>();
- return *this;
- }
- TestSpec testSpec() {
- addFilter();
- return m_testSpec;
- }
private:
- void visitChar( char c ) {
- if( m_mode == None ) {
- switch( c ) {
- case ' ': return;
- case '~': m_exclusion = true; return;
- case '[': return startNewMode( Tag, ++m_pos );
- case '"': return startNewMode( QuotedName, ++m_pos );
- case '\\': return escape();
- default: startNewMode( Name, m_pos ); break;
- }
- }
- if( m_mode == Name ) {
- if( c == ',' ) {
- addPattern<TestSpec::NamePattern>();
- addFilter();
- }
- else if( c == '[' ) {
- if( subString() == "exclude:" )
- m_exclusion = true;
- else
- addPattern<TestSpec::NamePattern>();
- startNewMode( Tag, ++m_pos );
- }
- else if( c == '\\' )
- escape();
- }
- else if( m_mode == EscapedName )
- m_mode = Name;
- else if( m_mode == QuotedName && c == '"' )
- addPattern<TestSpec::NamePattern>();
- else if( m_mode == Tag && c == ']' )
- addPattern<TestSpec::TagPattern>();
- }
- void startNewMode( Mode mode, std::size_t start ) {
- m_mode = mode;
- m_start = start;
- }
- void escape() {
- if( m_mode == None )
- m_start = m_pos;
- m_mode = EscapedName;
- m_escapeChars.push_back( m_pos );
- }
- std::string subString() const { return m_arg.substr( m_start, m_pos - m_start ); }
+ bool visitChar( char c );
+ void startNewMode( Mode mode );
+ bool processNoneChar( char c );
+ void processNameChar( char c );
+ bool processOtherChar( char c );
+ void endMode();
+ void escape();
+ bool isControlChar( char c ) const;
+ void saveLastMode();
+ void revertBackToLastMode();
+ void addFilter();
+ bool separate();
+
template<typename T>
void addPattern() {
- std::string token = subString();
- for( size_t i = 0; i < m_escapeChars.size(); ++i )
- token = token.substr( 0, m_escapeChars[i]-m_start-i ) + token.substr( m_escapeChars[i]-m_start-i+1 );
+ std::string token = m_patternName;
+ for( std::size_t i = 0; i < m_escapeChars.size(); ++i )
+ token = token.substr( 0, m_escapeChars[i] - i ) + token.substr( m_escapeChars[i] -i +1 );
m_escapeChars.clear();
if( startsWith( token, "exclude:" ) ) {
m_exclusion = true;
token = token.substr( 8 );
}
if( !token.empty() ) {
- Ptr<TestSpec::Pattern> pattern = new T( token );
+ TestSpec::PatternPtr pattern = std::make_shared<T>( token, m_substring );
if( m_exclusion )
- pattern = new TestSpec::ExcludedPattern( pattern );
+ pattern = std::make_shared<TestSpec::ExcludedPattern>( pattern );
m_currentFilter.m_patterns.push_back( pattern );
}
+ m_substring.clear();
+ m_patternName.clear();
m_exclusion = false;
m_mode = None;
}
- void addFilter() {
- if( !m_currentFilter.m_patterns.empty() ) {
- m_testSpec.m_filters.push_back( m_currentFilter );
- m_currentFilter = TestSpec::Filter();
- }
+
+ inline void addCharToPattern(char c) {
+ m_substring += c;
+ m_patternName += c;
+ m_realPatternPos++;
}
+
};
- inline TestSpec parseTestSpec( std::string const& arg ) {
- return TestSpecParser( ITagAliasRegistry::get() ).parse( arg ).testSpec();
- }
+ TestSpec parseTestSpec( std::string const& arg );
} // namespace Catch
@@ -3798,2171 +5240,2121 @@ namespace Catch {
#pragma clang diagnostic pop
#endif
-// #included from: catch_interfaces_config.h
-#define TWOBLUECUBES_CATCH_INTERFACES_CONFIG_H_INCLUDED
+// end catch_test_spec_parser.h
+// Libstdc++ doesn't like incomplete classes for unique_ptr
-#include <iosfwd>
-#include <string>
+#include <memory>
#include <vector>
+#include <string>
+
+#ifndef CATCH_CONFIG_CONSOLE_WIDTH
+#define CATCH_CONFIG_CONSOLE_WIDTH 80
+#endif
namespace Catch {
- struct Verbosity { enum Level {
- NoOutput = 0,
- Quiet,
- Normal
- }; };
+ struct IStream;
- struct WarnAbout { enum What {
- Nothing = 0x00,
- NoAssertions = 0x01
- }; };
+ struct ConfigData {
+ bool listTests = false;
+ bool listTags = false;
+ bool listReporters = false;
+ bool listTestNamesOnly = false;
+
+ bool showSuccessfulTests = false;
+ bool shouldDebugBreak = false;
+ bool noThrow = false;
+ bool showHelp = false;
+ bool showInvisibles = false;
+ bool filenamesAsTags = false;
+ bool libIdentify = false;
+
+ int abortAfter = -1;
+ unsigned int rngSeed = 0;
+
+ bool benchmarkNoAnalysis = false;
+ unsigned int benchmarkSamples = 100;
+ double benchmarkConfidenceInterval = 0.95;
+ unsigned int benchmarkResamples = 100000;
+
+ Verbosity verbosity = Verbosity::Normal;
+ WarnAbout::What warnings = WarnAbout::Nothing;
+ ShowDurations::OrNot showDurations = ShowDurations::DefaultForReporter;
+ RunTests::InWhatOrder runOrder = RunTests::InDeclarationOrder;
+ UseColour::YesOrNo useColour = UseColour::Auto;
+ WaitForKeypress::When waitForKeypress = WaitForKeypress::Never;
- struct ShowDurations { enum OrNot {
- DefaultForReporter,
- Always,
- Never
- }; };
- struct RunTests { enum InWhatOrder {
- InDeclarationOrder,
- InLexicographicalOrder,
- InRandomOrder
- }; };
- struct UseColour { enum YesOrNo {
- Auto,
- Yes,
- No
- }; };
+ std::string outputFilename;
+ std::string name;
+ std::string processName;
+#ifndef CATCH_CONFIG_DEFAULT_REPORTER
+#define CATCH_CONFIG_DEFAULT_REPORTER "console"
+#endif
+ std::string reporterName = CATCH_CONFIG_DEFAULT_REPORTER;
+#undef CATCH_CONFIG_DEFAULT_REPORTER
- class TestSpec;
+ std::vector<std::string> testsOrTags;
+ std::vector<std::string> sectionsToRun;
+ };
- struct IConfig : IShared {
+ class Config : public IConfig {
+ public:
- virtual ~IConfig();
+ Config() = default;
+ Config( ConfigData const& data );
+ virtual ~Config() = default;
- virtual bool allowThrows() const = 0;
- virtual std::ostream& stream() const = 0;
- virtual std::string name() const = 0;
- virtual bool includeSuccessfulResults() const = 0;
- virtual bool shouldDebugBreak() const = 0;
- virtual bool warnAboutMissingAssertions() const = 0;
- virtual int abortAfter() const = 0;
- virtual bool showInvisibles() const = 0;
- virtual ShowDurations::OrNot showDurations() const = 0;
- virtual TestSpec const& testSpec() const = 0;
- virtual RunTests::InWhatOrder runOrder() const = 0;
- virtual unsigned int rngSeed() const = 0;
- virtual UseColour::YesOrNo useColour() const = 0;
- virtual std::vector<std::string> const& getSectionsToRun() const = 0;
+ std::string const& getFilename() const;
+
+ bool listTests() const;
+ bool listTestNamesOnly() const;
+ bool listTags() const;
+ bool listReporters() const;
+
+ std::string getProcessName() const;
+ std::string const& getReporterName() const;
+
+ std::vector<std::string> const& getTestsOrTags() const override;
+ std::vector<std::string> const& getSectionsToRun() const override;
+ TestSpec const& testSpec() const override;
+ bool hasTestFilters() const override;
+
+ bool showHelp() const;
+
+ // IConfig interface
+ bool allowThrows() const override;
+ std::ostream& stream() const override;
+ std::string name() const override;
+ bool includeSuccessfulResults() const override;
+ bool warnAboutMissingAssertions() const override;
+ bool warnAboutNoTests() const override;
+ ShowDurations::OrNot showDurations() const override;
+ RunTests::InWhatOrder runOrder() const override;
+ unsigned int rngSeed() const override;
+ UseColour::YesOrNo useColour() const override;
+ bool shouldDebugBreak() const override;
+ int abortAfter() const override;
+ bool showInvisibles() const override;
+ Verbosity verbosity() const override;
+ bool benchmarkNoAnalysis() const override;
+ int benchmarkSamples() const override;
+ double benchmarkConfidenceInterval() const override;
+ unsigned int benchmarkResamples() const override;
+
+ private:
+
+ IStream const* openStream();
+ ConfigData m_data;
+
+ std::unique_ptr<IStream const> m_stream;
+ TestSpec m_testSpec;
+ bool m_hasTestFilters = false;
};
-}
-// #included from: catch_stream.h
-#define TWOBLUECUBES_CATCH_STREAM_H_INCLUDED
+} // end namespace Catch
-// #included from: catch_streambuf.h
-#define TWOBLUECUBES_CATCH_STREAMBUF_H_INCLUDED
+// end catch_config.hpp
+// start catch_assertionresult.h
-#include <streambuf>
+#include <string>
namespace Catch {
- class StreamBufBase : public std::streambuf {
+ struct AssertionResultData
+ {
+ AssertionResultData() = delete;
+
+ AssertionResultData( ResultWas::OfType _resultType, LazyExpression const& _lazyExpression );
+
+ std::string message;
+ mutable std::string reconstructedExpression;
+ LazyExpression lazyExpression;
+ ResultWas::OfType resultType;
+
+ std::string reconstructExpression() const;
+ };
+
+ class AssertionResult {
public:
- virtual ~StreamBufBase() CATCH_NOEXCEPT;
+ AssertionResult() = delete;
+ AssertionResult( AssertionInfo const& info, AssertionResultData const& data );
+
+ bool isOk() const;
+ bool succeeded() const;
+ ResultWas::OfType getResultType() const;
+ bool hasExpression() const;
+ bool hasMessage() const;
+ std::string getExpression() const;
+ std::string getExpressionInMacro() const;
+ bool hasExpandedExpression() const;
+ std::string getExpandedExpression() const;
+ std::string getMessage() const;
+ SourceLineInfo getSourceInfo() const;
+ StringRef getTestMacroName() const;
+
+ //protected:
+ AssertionInfo m_info;
+ AssertionResultData m_resultData;
};
-}
-#include <streambuf>
-#include <ostream>
-#include <fstream>
+} // end namespace Catch
+
+// end catch_assertionresult.h
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+// start catch_estimate.hpp
+
+ // Statistics estimates
+
+
+namespace Catch {
+ namespace Benchmark {
+ template <typename Duration>
+ struct Estimate {
+ Duration point;
+ Duration lower_bound;
+ Duration upper_bound;
+ double confidence_interval;
+
+ template <typename Duration2>
+ operator Estimate<Duration2>() const {
+ return { point, lower_bound, upper_bound, confidence_interval };
+ }
+ };
+ } // namespace Benchmark
+} // namespace Catch
+
+// end catch_estimate.hpp
+// start catch_outlier_classification.hpp
+
+// Outlier information
+
+namespace Catch {
+ namespace Benchmark {
+ struct OutlierClassification {
+ int samples_seen = 0;
+ int low_severe = 0; // more than 3 times IQR below Q1
+ int low_mild = 0; // 1.5 to 3 times IQR below Q1
+ int high_mild = 0; // 1.5 to 3 times IQR above Q3
+ int high_severe = 0; // more than 3 times IQR above Q3
+
+ int total() const {
+ return low_severe + low_mild + high_mild + high_severe;
+ }
+ };
+ } // namespace Benchmark
+} // namespace Catch
+
+// end catch_outlier_classification.hpp
+#endif // CATCH_CONFIG_ENABLE_BENCHMARKING
+
+#include <string>
+#include <iosfwd>
+#include <map>
+#include <set>
#include <memory>
+#include <algorithm>
namespace Catch {
- std::ostream& cout();
- std::ostream& cerr();
+ struct ReporterConfig {
+ explicit ReporterConfig( IConfigPtr const& _fullConfig );
- struct IStream {
- virtual ~IStream() CATCH_NOEXCEPT;
- virtual std::ostream& stream() const = 0;
+ ReporterConfig( IConfigPtr const& _fullConfig, std::ostream& _stream );
+
+ std::ostream& stream() const;
+ IConfigPtr fullConfig() const;
+
+ private:
+ std::ostream* m_stream;
+ IConfigPtr m_fullConfig;
};
- class FileStream : public IStream {
- mutable std::ofstream m_ofs;
- public:
- FileStream( std::string const& filename );
- virtual ~FileStream() CATCH_NOEXCEPT;
- public: // IStream
- virtual std::ostream& stream() const CATCH_OVERRIDE;
+ struct ReporterPreferences {
+ bool shouldRedirectStdOut = false;
+ bool shouldReportAllAssertions = false;
};
- class CoutStream : public IStream {
- mutable std::ostream m_os;
- public:
- CoutStream();
- virtual ~CoutStream() CATCH_NOEXCEPT;
+ template<typename T>
+ struct LazyStat : Option<T> {
+ LazyStat& operator=( T const& _value ) {
+ Option<T>::operator=( _value );
+ used = false;
+ return *this;
+ }
+ void reset() {
+ Option<T>::reset();
+ used = false;
+ }
+ bool used = false;
+ };
- public: // IStream
- virtual std::ostream& stream() const CATCH_OVERRIDE;
+ struct TestRunInfo {
+ TestRunInfo( std::string const& _name );
+ std::string name;
};
+ struct GroupInfo {
+ GroupInfo( std::string const& _name,
+ std::size_t _groupIndex,
+ std::size_t _groupsCount );
- class DebugOutStream : public IStream {
- CATCH_AUTO_PTR( StreamBufBase ) m_streamBuf;
- mutable std::ostream m_os;
- public:
- DebugOutStream();
- virtual ~DebugOutStream() CATCH_NOEXCEPT;
+ std::string name;
+ std::size_t groupIndex;
+ std::size_t groupsCounts;
+ };
- public: // IStream
- virtual std::ostream& stream() const CATCH_OVERRIDE;
+ struct AssertionStats {
+ AssertionStats( AssertionResult const& _assertionResult,
+ std::vector<MessageInfo> const& _infoMessages,
+ Totals const& _totals );
+
+ AssertionStats( AssertionStats const& ) = default;
+ AssertionStats( AssertionStats && ) = default;
+ AssertionStats& operator = ( AssertionStats const& ) = delete;
+ AssertionStats& operator = ( AssertionStats && ) = delete;
+ virtual ~AssertionStats();
+
+ AssertionResult assertionResult;
+ std::vector<MessageInfo> infoMessages;
+ Totals totals;
};
-}
-#include <memory>
-#include <vector>
-#include <string>
-#include <stdexcept>
+ struct SectionStats {
+ SectionStats( SectionInfo const& _sectionInfo,
+ Counts const& _assertions,
+ double _durationInSeconds,
+ bool _missingAssertions );
+ SectionStats( SectionStats const& ) = default;
+ SectionStats( SectionStats && ) = default;
+ SectionStats& operator = ( SectionStats const& ) = default;
+ SectionStats& operator = ( SectionStats && ) = default;
+ virtual ~SectionStats();
-#ifndef CATCH_CONFIG_CONSOLE_WIDTH
-#define CATCH_CONFIG_CONSOLE_WIDTH 80
-#endif
+ SectionInfo sectionInfo;
+ Counts assertions;
+ double durationInSeconds;
+ bool missingAssertions;
+ };
-namespace Catch {
+ struct TestCaseStats {
+ TestCaseStats( TestCaseInfo const& _testInfo,
+ Totals const& _totals,
+ std::string const& _stdOut,
+ std::string const& _stdErr,
+ bool _aborting );
- struct ConfigData {
+ TestCaseStats( TestCaseStats const& ) = default;
+ TestCaseStats( TestCaseStats && ) = default;
+ TestCaseStats& operator = ( TestCaseStats const& ) = default;
+ TestCaseStats& operator = ( TestCaseStats && ) = default;
+ virtual ~TestCaseStats();
- ConfigData()
- : listTests( false ),
- listTags( false ),
- listReporters( false ),
- listTestNamesOnly( false ),
- showSuccessfulTests( false ),
- shouldDebugBreak( false ),
- noThrow( false ),
- showHelp( false ),
- showInvisibles( false ),
- filenamesAsTags( false ),
- abortAfter( -1 ),
- rngSeed( 0 ),
- verbosity( Verbosity::Normal ),
- warnings( WarnAbout::Nothing ),
- showDurations( ShowDurations::DefaultForReporter ),
- runOrder( RunTests::InDeclarationOrder ),
- useColour( UseColour::Auto )
- {}
+ TestCaseInfo testInfo;
+ Totals totals;
+ std::string stdOut;
+ std::string stdErr;
+ bool aborting;
+ };
- bool listTests;
- bool listTags;
- bool listReporters;
- bool listTestNamesOnly;
+ struct TestGroupStats {
+ TestGroupStats( GroupInfo const& _groupInfo,
+ Totals const& _totals,
+ bool _aborting );
+ TestGroupStats( GroupInfo const& _groupInfo );
- bool showSuccessfulTests;
- bool shouldDebugBreak;
- bool noThrow;
- bool showHelp;
- bool showInvisibles;
- bool filenamesAsTags;
+ TestGroupStats( TestGroupStats const& ) = default;
+ TestGroupStats( TestGroupStats && ) = default;
+ TestGroupStats& operator = ( TestGroupStats const& ) = default;
+ TestGroupStats& operator = ( TestGroupStats && ) = default;
+ virtual ~TestGroupStats();
- int abortAfter;
- unsigned int rngSeed;
+ GroupInfo groupInfo;
+ Totals totals;
+ bool aborting;
+ };
- Verbosity::Level verbosity;
- WarnAbout::What warnings;
- ShowDurations::OrNot showDurations;
- RunTests::InWhatOrder runOrder;
- UseColour::YesOrNo useColour;
+ struct TestRunStats {
+ TestRunStats( TestRunInfo const& _runInfo,
+ Totals const& _totals,
+ bool _aborting );
- std::string outputFilename;
+ TestRunStats( TestRunStats const& ) = default;
+ TestRunStats( TestRunStats && ) = default;
+ TestRunStats& operator = ( TestRunStats const& ) = default;
+ TestRunStats& operator = ( TestRunStats && ) = default;
+ virtual ~TestRunStats();
+
+ TestRunInfo runInfo;
+ Totals totals;
+ bool aborting;
+ };
+
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+ struct BenchmarkInfo {
std::string name;
- std::string processName;
+ double estimatedDuration;
+ int iterations;
+ int samples;
+ unsigned int resamples;
+ double clockResolution;
+ double clockCost;
+ };
- std::vector<std::string> reporterNames;
- std::vector<std::string> testsOrTags;
- std::vector<std::string> sectionsToRun;
+ template <class Duration>
+ struct BenchmarkStats {
+ BenchmarkInfo info;
+
+ std::vector<Duration> samples;
+ Benchmark::Estimate<Duration> mean;
+ Benchmark::Estimate<Duration> standardDeviation;
+ Benchmark::OutlierClassification outliers;
+ double outlierVariance;
+
+ template <typename Duration2>
+ operator BenchmarkStats<Duration2>() const {
+ std::vector<Duration2> samples2;
+ samples2.reserve(samples.size());
+ std::transform(samples.begin(), samples.end(), std::back_inserter(samples2), [](Duration d) { return Duration2(d); });
+ return {
+ info,
+ std::move(samples2),
+ mean,
+ standardDeviation,
+ outliers,
+ outlierVariance,
+ };
+ }
};
+#endif // CATCH_CONFIG_ENABLE_BENCHMARKING
- class Config : public SharedImpl<IConfig> {
- private:
- Config( Config const& other );
- Config& operator = ( Config const& other );
- virtual void dummy();
- public:
+ struct IStreamingReporter {
+ virtual ~IStreamingReporter() = default;
- Config()
- {}
+ // Implementing class must also provide the following static methods:
+ // static std::string getDescription();
+ // static std::set<Verbosity> getSupportedVerbosities()
- Config( ConfigData const& data )
- : m_data( data ),
- m_stream( openStream() )
- {
- if( !data.testsOrTags.empty() ) {
- TestSpecParser parser( ITagAliasRegistry::get() );
- for( std::size_t i = 0; i < data.testsOrTags.size(); ++i )
- parser.parse( data.testsOrTags[i] );
- m_testSpec = parser.testSpec();
- }
- }
+ virtual ReporterPreferences getPreferences() const = 0;
- virtual ~Config() {}
+ virtual void noMatchingTestCases( std::string const& spec ) = 0;
- std::string const& getFilename() const {
- return m_data.outputFilename ;
- }
+ virtual void reportInvalidArguments(std::string const&) {}
- bool listTests() const { return m_data.listTests; }
- bool listTestNamesOnly() const { return m_data.listTestNamesOnly; }
- bool listTags() const { return m_data.listTags; }
- bool listReporters() const { return m_data.listReporters; }
+ virtual void testRunStarting( TestRunInfo const& testRunInfo ) = 0;
+ virtual void testGroupStarting( GroupInfo const& groupInfo ) = 0;
- std::string getProcessName() const { return m_data.processName; }
+ virtual void testCaseStarting( TestCaseInfo const& testInfo ) = 0;
+ virtual void sectionStarting( SectionInfo const& sectionInfo ) = 0;
- std::vector<std::string> const& getReporterNames() const { return m_data.reporterNames; }
- std::vector<std::string> const& getSectionsToRun() const CATCH_OVERRIDE { return m_data.sectionsToRun; }
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+ virtual void benchmarkPreparing( std::string const& ) {}
+ virtual void benchmarkStarting( BenchmarkInfo const& ) {}
+ virtual void benchmarkEnded( BenchmarkStats<> const& ) {}
+ virtual void benchmarkFailed( std::string const& ) {}
+#endif // CATCH_CONFIG_ENABLE_BENCHMARKING
- virtual TestSpec const& testSpec() const CATCH_OVERRIDE { return m_testSpec; }
+ virtual void assertionStarting( AssertionInfo const& assertionInfo ) = 0;
- bool showHelp() const { return m_data.showHelp; }
+ // The return value indicates if the messages buffer should be cleared:
+ virtual bool assertionEnded( AssertionStats const& assertionStats ) = 0;
- // IConfig interface
- virtual bool allowThrows() const CATCH_OVERRIDE { return !m_data.noThrow; }
- virtual std::ostream& stream() const CATCH_OVERRIDE { return m_stream->stream(); }
- virtual std::string name() const CATCH_OVERRIDE { return m_data.name.empty() ? m_data.processName : m_data.name; }
- virtual bool includeSuccessfulResults() const CATCH_OVERRIDE { return m_data.showSuccessfulTests; }
- virtual bool warnAboutMissingAssertions() const CATCH_OVERRIDE { return m_data.warnings & WarnAbout::NoAssertions; }
- virtual ShowDurations::OrNot showDurations() const CATCH_OVERRIDE { return m_data.showDurations; }
- virtual RunTests::InWhatOrder runOrder() const CATCH_OVERRIDE { return m_data.runOrder; }
- virtual unsigned int rngSeed() const CATCH_OVERRIDE { return m_data.rngSeed; }
- virtual UseColour::YesOrNo useColour() const CATCH_OVERRIDE { return m_data.useColour; }
- virtual bool shouldDebugBreak() const CATCH_OVERRIDE { return m_data.shouldDebugBreak; }
- virtual int abortAfter() const CATCH_OVERRIDE { return m_data.abortAfter; }
- virtual bool showInvisibles() const CATCH_OVERRIDE { return m_data.showInvisibles; }
+ virtual void sectionEnded( SectionStats const& sectionStats ) = 0;
+ virtual void testCaseEnded( TestCaseStats const& testCaseStats ) = 0;
+ virtual void testGroupEnded( TestGroupStats const& testGroupStats ) = 0;
+ virtual void testRunEnded( TestRunStats const& testRunStats ) = 0;
- private:
+ virtual void skipTest( TestCaseInfo const& testInfo ) = 0;
- IStream const* openStream() {
- if( m_data.outputFilename.empty() )
- return new CoutStream();
- else if( m_data.outputFilename[0] == '%' ) {
- if( m_data.outputFilename == "%debug" )
- return new DebugOutStream();
- else
- throw std::domain_error( "Unrecognised stream: " + m_data.outputFilename );
- }
- else
- return new FileStream( m_data.outputFilename );
- }
- ConfigData m_data;
+ // Default empty implementation provided
+ virtual void fatalErrorEncountered( StringRef name );
- CATCH_AUTO_PTR( IStream const ) m_stream;
- TestSpec m_testSpec;
+ virtual bool isMulti() const;
+ };
+ using IStreamingReporterPtr = std::unique_ptr<IStreamingReporter>;
+
+ struct IReporterFactory {
+ virtual ~IReporterFactory();
+ virtual IStreamingReporterPtr create( ReporterConfig const& config ) const = 0;
+ virtual std::string getDescription() const = 0;
+ };
+ using IReporterFactoryPtr = std::shared_ptr<IReporterFactory>;
+
+ struct IReporterRegistry {
+ using FactoryMap = std::map<std::string, IReporterFactoryPtr>;
+ using Listeners = std::vector<IReporterFactoryPtr>;
+
+ virtual ~IReporterRegistry();
+ virtual IStreamingReporterPtr create( std::string const& name, IConfigPtr const& config ) const = 0;
+ virtual FactoryMap const& getFactories() const = 0;
+ virtual Listeners const& getListeners() const = 0;
};
} // end namespace Catch
-// #included from: catch_clara.h
-#define TWOBLUECUBES_CATCH_CLARA_H_INCLUDED
+// end catch_interfaces_reporter.h
+#include <algorithm>
+#include <cstring>
+#include <cfloat>
+#include <cstdio>
+#include <cassert>
+#include <memory>
+#include <ostream>
-// Use Catch's value for console width (store Clara's off to the side, if present)
-#ifdef CLARA_CONFIG_CONSOLE_WIDTH
-#define CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH CLARA_CONFIG_CONSOLE_WIDTH
-#undef CLARA_CONFIG_CONSOLE_WIDTH
-#endif
-#define CLARA_CONFIG_CONSOLE_WIDTH CATCH_CONFIG_CONSOLE_WIDTH
+namespace Catch {
+ void prepareExpandedExpression(AssertionResult& result);
-// Declare Clara inside the Catch namespace
-#define STITCH_CLARA_OPEN_NAMESPACE namespace Catch {
-// #included from: ../external/clara.h
+ // Returns double formatted as %.3f (format expected on output)
+ std::string getFormattedDuration( double duration );
-// Version 0.0.2.4
+ std::string serializeFilters( std::vector<std::string> const& container );
-// Only use header guard if we are not using an outer namespace
-#if !defined(TWOBLUECUBES_CLARA_H_INCLUDED) || defined(STITCH_CLARA_OPEN_NAMESPACE)
+ template<typename DerivedT>
+ struct StreamingReporterBase : IStreamingReporter {
-#ifndef STITCH_CLARA_OPEN_NAMESPACE
-#define TWOBLUECUBES_CLARA_H_INCLUDED
-#define STITCH_CLARA_OPEN_NAMESPACE
-#define STITCH_CLARA_CLOSE_NAMESPACE
-#else
-#define STITCH_CLARA_CLOSE_NAMESPACE }
-#endif
+ StreamingReporterBase( ReporterConfig const& _config )
+ : m_config( _config.fullConfig() ),
+ stream( _config.stream() )
+ {
+ m_reporterPrefs.shouldRedirectStdOut = false;
+ if( !DerivedT::getSupportedVerbosities().count( m_config->verbosity() ) )
+ CATCH_ERROR( "Verbosity level not supported by this reporter" );
+ }
-#define STITCH_TBC_TEXT_FORMAT_OPEN_NAMESPACE STITCH_CLARA_OPEN_NAMESPACE
+ ReporterPreferences getPreferences() const override {
+ return m_reporterPrefs;
+ }
-// ----------- #included from tbc_text_format.h -----------
+ static std::set<Verbosity> getSupportedVerbosities() {
+ return { Verbosity::Normal };
+ }
-// Only use header guard if we are not using an outer namespace
-#if !defined(TBC_TEXT_FORMAT_H_INCLUDED) || defined(STITCH_TBC_TEXT_FORMAT_OUTER_NAMESPACE)
-#ifndef STITCH_TBC_TEXT_FORMAT_OUTER_NAMESPACE
-#define TBC_TEXT_FORMAT_H_INCLUDED
-#endif
+ ~StreamingReporterBase() override = default;
-#include <string>
-#include <vector>
-#include <sstream>
-#include <algorithm>
-#include <cctype>
+ void noMatchingTestCases(std::string const&) override {}
-// Use optional outer namespace
-#ifdef STITCH_TBC_TEXT_FORMAT_OUTER_NAMESPACE
-namespace STITCH_TBC_TEXT_FORMAT_OUTER_NAMESPACE {
-#endif
+ void reportInvalidArguments(std::string const&) override {}
-namespace Tbc {
+ void testRunStarting(TestRunInfo const& _testRunInfo) override {
+ currentTestRunInfo = _testRunInfo;
+ }
-#ifdef TBC_TEXT_FORMAT_CONSOLE_WIDTH
- const unsigned int consoleWidth = TBC_TEXT_FORMAT_CONSOLE_WIDTH;
-#else
- const unsigned int consoleWidth = 80;
-#endif
+ void testGroupStarting(GroupInfo const& _groupInfo) override {
+ currentGroupInfo = _groupInfo;
+ }
- struct TextAttributes {
- TextAttributes()
- : initialIndent( std::string::npos ),
- indent( 0 ),
- width( consoleWidth-1 ),
- tabChar( '\t' )
- {}
+ void testCaseStarting(TestCaseInfo const& _testInfo) override {
+ currentTestCaseInfo = _testInfo;
+ }
+ void sectionStarting(SectionInfo const& _sectionInfo) override {
+ m_sectionStack.push_back(_sectionInfo);
+ }
- TextAttributes& setInitialIndent( std::size_t _value ) { initialIndent = _value; return *this; }
- TextAttributes& setIndent( std::size_t _value ) { indent = _value; return *this; }
- TextAttributes& setWidth( std::size_t _value ) { width = _value; return *this; }
- TextAttributes& setTabChar( char _value ) { tabChar = _value; return *this; }
+ void sectionEnded(SectionStats const& /* _sectionStats */) override {
+ m_sectionStack.pop_back();
+ }
+ void testCaseEnded(TestCaseStats const& /* _testCaseStats */) override {
+ currentTestCaseInfo.reset();
+ }
+ void testGroupEnded(TestGroupStats const& /* _testGroupStats */) override {
+ currentGroupInfo.reset();
+ }
+ void testRunEnded(TestRunStats const& /* _testRunStats */) override {
+ currentTestCaseInfo.reset();
+ currentGroupInfo.reset();
+ currentTestRunInfo.reset();
+ }
- std::size_t initialIndent; // indent of first line, or npos
- std::size_t indent; // indent of subsequent lines, or all if initialIndent is npos
- std::size_t width; // maximum width of text, including indent. Longer text will wrap
- char tabChar; // If this char is seen the indent is changed to current pos
+ void skipTest(TestCaseInfo const&) override {
+ // Don't do anything with this by default.
+ // It can optionally be overridden in the derived class.
+ }
+
+ IConfigPtr m_config;
+ std::ostream& stream;
+
+ LazyStat<TestRunInfo> currentTestRunInfo;
+ LazyStat<GroupInfo> currentGroupInfo;
+ LazyStat<TestCaseInfo> currentTestCaseInfo;
+
+ std::vector<SectionInfo> m_sectionStack;
+ ReporterPreferences m_reporterPrefs;
};
- class Text {
- public:
- Text( std::string const& _str, TextAttributes const& _attr = TextAttributes() )
- : attr( _attr )
- {
- std::string wrappableChars = " [({.,/|\\-";
- std::size_t indent = _attr.initialIndent != std::string::npos
- ? _attr.initialIndent
- : _attr.indent;
- std::string remainder = _str;
-
- while( !remainder.empty() ) {
- if( lines.size() >= 1000 ) {
- lines.push_back( "... message truncated due to excessive size" );
- return;
- }
- std::size_t tabPos = std::string::npos;
- std::size_t width = (std::min)( remainder.size(), _attr.width - indent );
- std::size_t pos = remainder.find_first_of( '\n' );
- if( pos <= width ) {
- width = pos;
- }
- pos = remainder.find_last_of( _attr.tabChar, width );
- if( pos != std::string::npos ) {
- tabPos = pos;
- if( remainder[width] == '\n' )
- width--;
- remainder = remainder.substr( 0, tabPos ) + remainder.substr( tabPos+1 );
- }
+ template<typename DerivedT>
+ struct CumulativeReporterBase : IStreamingReporter {
+ template<typename T, typename ChildNodeT>
+ struct Node {
+ explicit Node( T const& _value ) : value( _value ) {}
+ virtual ~Node() {}
- if( width == remainder.size() ) {
- spliceLine( indent, remainder, width );
- }
- else if( remainder[width] == '\n' ) {
- spliceLine( indent, remainder, width );
- if( width <= 1 || remainder.size() != 1 )
- remainder = remainder.substr( 1 );
- indent = _attr.indent;
- }
- else {
- pos = remainder.find_last_of( wrappableChars, width );
- if( pos != std::string::npos && pos > 0 ) {
- spliceLine( indent, remainder, pos );
- if( remainder[0] == ' ' )
- remainder = remainder.substr( 1 );
- }
- else {
- spliceLine( indent, remainder, width-1 );
- lines.back() += "-";
- }
- if( lines.size() == 1 )
- indent = _attr.indent;
- if( tabPos != std::string::npos )
- indent += tabPos;
- }
+ using ChildNodes = std::vector<std::shared_ptr<ChildNodeT>>;
+ T value;
+ ChildNodes children;
+ };
+ struct SectionNode {
+ explicit SectionNode(SectionStats const& _stats) : stats(_stats) {}
+ virtual ~SectionNode() = default;
+
+ bool operator == (SectionNode const& other) const {
+ return stats.sectionInfo.lineInfo == other.stats.sectionInfo.lineInfo;
}
- }
+ bool operator == (std::shared_ptr<SectionNode> const& other) const {
+ return operator==(*other);
+ }
+
+ SectionStats stats;
+ using ChildSections = std::vector<std::shared_ptr<SectionNode>>;
+ using Assertions = std::vector<AssertionStats>;
+ ChildSections childSections;
+ Assertions assertions;
+ std::string stdOut;
+ std::string stdErr;
+ };
+
+ struct BySectionInfo {
+ BySectionInfo( SectionInfo const& other ) : m_other( other ) {}
+ BySectionInfo( BySectionInfo const& other ) : m_other( other.m_other ) {}
+ bool operator() (std::shared_ptr<SectionNode> const& node) const {
+ return ((node->stats.sectionInfo.name == m_other.name) &&
+ (node->stats.sectionInfo.lineInfo == m_other.lineInfo));
+ }
+ void operator=(BySectionInfo const&) = delete;
+
+ private:
+ SectionInfo const& m_other;
+ };
+
+ using TestCaseNode = Node<TestCaseStats, SectionNode>;
+ using TestGroupNode = Node<TestGroupStats, TestCaseNode>;
+ using TestRunNode = Node<TestRunStats, TestGroupNode>;
- void spliceLine( std::size_t _indent, std::string& _remainder, std::size_t _pos ) {
- lines.push_back( std::string( _indent, ' ' ) + _remainder.substr( 0, _pos ) );
- _remainder = _remainder.substr( _pos );
+ CumulativeReporterBase( ReporterConfig const& _config )
+ : m_config( _config.fullConfig() ),
+ stream( _config.stream() )
+ {
+ m_reporterPrefs.shouldRedirectStdOut = false;
+ if( !DerivedT::getSupportedVerbosities().count( m_config->verbosity() ) )
+ CATCH_ERROR( "Verbosity level not supported by this reporter" );
}
+ ~CumulativeReporterBase() override = default;
- typedef std::vector<std::string>::const_iterator const_iterator;
+ ReporterPreferences getPreferences() const override {
+ return m_reporterPrefs;
+ }
- const_iterator begin() const { return lines.begin(); }
- const_iterator end() const { return lines.end(); }
- std::string const& last() const { return lines.back(); }
- std::size_t size() const { return lines.size(); }
- std::string const& operator[]( std::size_t _index ) const { return lines[_index]; }
- std::string toString() const {
- std::ostringstream oss;
- oss << *this;
- return oss.str();
+ static std::set<Verbosity> getSupportedVerbosities() {
+ return { Verbosity::Normal };
}
- inline friend std::ostream& operator << ( std::ostream& _stream, Text const& _text ) {
- for( Text::const_iterator it = _text.begin(), itEnd = _text.end();
- it != itEnd; ++it ) {
- if( it != _text.begin() )
- _stream << "\n";
- _stream << *it;
+ void testRunStarting( TestRunInfo const& ) override {}
+ void testGroupStarting( GroupInfo const& ) override {}
+
+ void testCaseStarting( TestCaseInfo const& ) override {}
+
+ void sectionStarting( SectionInfo const& sectionInfo ) override {
+ SectionStats incompleteStats( sectionInfo, Counts(), 0, false );
+ std::shared_ptr<SectionNode> node;
+ if( m_sectionStack.empty() ) {
+ if( !m_rootSection )
+ m_rootSection = std::make_shared<SectionNode>( incompleteStats );
+ node = m_rootSection;
}
- return _stream;
+ else {
+ SectionNode& parentNode = *m_sectionStack.back();
+ auto it =
+ std::find_if( parentNode.childSections.begin(),
+ parentNode.childSections.end(),
+ BySectionInfo( sectionInfo ) );
+ if( it == parentNode.childSections.end() ) {
+ node = std::make_shared<SectionNode>( incompleteStats );
+ parentNode.childSections.push_back( node );
+ }
+ else
+ node = *it;
+ }
+ m_sectionStack.push_back( node );
+ m_deepestSection = std::move(node);
}
- private:
- std::string str;
- TextAttributes attr;
- std::vector<std::string> lines;
+ void assertionStarting(AssertionInfo const&) override {}
+
+ bool assertionEnded(AssertionStats const& assertionStats) override {
+ assert(!m_sectionStack.empty());
+ // AssertionResult holds a pointer to a temporary DecomposedExpression,
+ // which getExpandedExpression() calls to build the expression string.
+ // Our section stack copy of the assertionResult will likely outlive the
+ // temporary, so it must be expanded or discarded now to avoid calling
+ // a destroyed object later.
+ prepareExpandedExpression(const_cast<AssertionResult&>( assertionStats.assertionResult ) );
+ SectionNode& sectionNode = *m_sectionStack.back();
+ sectionNode.assertions.push_back(assertionStats);
+ return true;
+ }
+ void sectionEnded(SectionStats const& sectionStats) override {
+ assert(!m_sectionStack.empty());
+ SectionNode& node = *m_sectionStack.back();
+ node.stats = sectionStats;
+ m_sectionStack.pop_back();
+ }
+ void testCaseEnded(TestCaseStats const& testCaseStats) override {
+ auto node = std::make_shared<TestCaseNode>(testCaseStats);
+ assert(m_sectionStack.size() == 0);
+ node->children.push_back(m_rootSection);
+ m_testCases.push_back(node);
+ m_rootSection.reset();
+
+ assert(m_deepestSection);
+ m_deepestSection->stdOut = testCaseStats.stdOut;
+ m_deepestSection->stdErr = testCaseStats.stdErr;
+ }
+ void testGroupEnded(TestGroupStats const& testGroupStats) override {
+ auto node = std::make_shared<TestGroupNode>(testGroupStats);
+ node->children.swap(m_testCases);
+ m_testGroups.push_back(node);
+ }
+ void testRunEnded(TestRunStats const& testRunStats) override {
+ auto node = std::make_shared<TestRunNode>(testRunStats);
+ node->children.swap(m_testGroups);
+ m_testRuns.push_back(node);
+ testRunEndedCumulative();
+ }
+ virtual void testRunEndedCumulative() = 0;
+
+ void skipTest(TestCaseInfo const&) override {}
+
+ IConfigPtr m_config;
+ std::ostream& stream;
+ std::vector<AssertionStats> m_assertions;
+ std::vector<std::vector<std::shared_ptr<SectionNode>>> m_sections;
+ std::vector<std::shared_ptr<TestCaseNode>> m_testCases;
+ std::vector<std::shared_ptr<TestGroupNode>> m_testGroups;
+
+ std::vector<std::shared_ptr<TestRunNode>> m_testRuns;
+
+ std::shared_ptr<SectionNode> m_rootSection;
+ std::shared_ptr<SectionNode> m_deepestSection;
+ std::vector<std::shared_ptr<SectionNode>> m_sectionStack;
+ ReporterPreferences m_reporterPrefs;
};
-} // end namespace Tbc
+ template<char C>
+ char const* getLineOfChars() {
+ static char line[CATCH_CONFIG_CONSOLE_WIDTH] = {0};
+ if( !*line ) {
+ std::memset( line, C, CATCH_CONFIG_CONSOLE_WIDTH-1 );
+ line[CATCH_CONFIG_CONSOLE_WIDTH-1] = 0;
+ }
+ return line;
+ }
-#ifdef STITCH_TBC_TEXT_FORMAT_OUTER_NAMESPACE
-} // end outer namespace
-#endif
+ struct TestEventListenerBase : StreamingReporterBase<TestEventListenerBase> {
+ TestEventListenerBase( ReporterConfig const& _config );
-#endif // TBC_TEXT_FORMAT_H_INCLUDED
+ static std::set<Verbosity> getSupportedVerbosities();
-// ----------- end of #include from tbc_text_format.h -----------
-// ........... back in clara.h
+ void assertionStarting(AssertionInfo const&) override;
+ bool assertionEnded(AssertionStats const&) override;
+ };
-#undef STITCH_TBC_TEXT_FORMAT_OPEN_NAMESPACE
+} // end namespace Catch
-// ----------- #included from clara_compilers.h -----------
+// end catch_reporter_bases.hpp
+// start catch_console_colour.h
-#ifndef TWOBLUECUBES_CLARA_COMPILERS_H_INCLUDED
-#define TWOBLUECUBES_CLARA_COMPILERS_H_INCLUDED
+namespace Catch {
-// Detect a number of compiler features - mostly C++11/14 conformance - by compiler
-// The following features are defined:
-//
-// CLARA_CONFIG_CPP11_NULLPTR : is nullptr supported?
-// CLARA_CONFIG_CPP11_NOEXCEPT : is noexcept supported?
-// CLARA_CONFIG_CPP11_GENERATED_METHODS : The delete and default keywords for compiler generated methods
-// CLARA_CONFIG_CPP11_OVERRIDE : is override supported?
-// CLARA_CONFIG_CPP11_UNIQUE_PTR : is unique_ptr supported (otherwise use auto_ptr)
+ struct Colour {
+ enum Code {
+ None = 0,
-// CLARA_CONFIG_CPP11_OR_GREATER : Is C++11 supported?
+ White,
+ Red,
+ Green,
+ Blue,
+ Cyan,
+ Yellow,
+ Grey,
-// CLARA_CONFIG_VARIADIC_MACROS : are variadic macros supported?
+ Bright = 0x10,
-// In general each macro has a _NO_<feature name> form
-// (e.g. CLARA_CONFIG_CPP11_NO_NULLPTR) which disables the feature.
-// Many features, at point of detection, define an _INTERNAL_ macro, so they
-// can be combined, en-mass, with the _NO_ forms later.
+ BrightRed = Bright | Red,
+ BrightGreen = Bright | Green,
+ LightGrey = Bright | Grey,
+ BrightWhite = Bright | White,
+ BrightYellow = Bright | Yellow,
-// All the C++11 features can be disabled with CLARA_CONFIG_NO_CPP11
+ // By intention
+ FileName = LightGrey,
+ Warning = BrightYellow,
+ ResultError = BrightRed,
+ ResultSuccess = BrightGreen,
+ ResultExpectedFailure = Warning,
-#ifdef __clang__
+ Error = BrightRed,
+ Success = Green,
-#if __has_feature(cxx_nullptr)
-#define CLARA_INTERNAL_CONFIG_CPP11_NULLPTR
-#endif
+ OriginalExpression = Cyan,
+ ReconstructedExpression = BrightYellow,
-#if __has_feature(cxx_noexcept)
-#define CLARA_INTERNAL_CONFIG_CPP11_NOEXCEPT
-#endif
+ SecondaryText = LightGrey,
+ Headers = White
+ };
-#endif // __clang__
+ // Use constructed object for RAII guard
+ Colour( Code _colourCode );
+ Colour( Colour&& other ) noexcept;
+ Colour& operator=( Colour&& other ) noexcept;
+ ~Colour();
-////////////////////////////////////////////////////////////////////////////////
-// GCC
-#ifdef __GNUC__
+ // Use static method for one-shot changes
+ static void use( Code _colourCode );
-#if __GNUC__ == 4 && __GNUC_MINOR__ >= 6 && defined(__GXX_EXPERIMENTAL_CXX0X__)
-#define CLARA_INTERNAL_CONFIG_CPP11_NULLPTR
-#endif
+ private:
+ bool m_moved = false;
+ };
-// - otherwise more recent versions define __cplusplus >= 201103L
-// and will get picked up below
+ std::ostream& operator << ( std::ostream& os, Colour const& );
-#endif // __GNUC__
+} // end namespace Catch
-////////////////////////////////////////////////////////////////////////////////
-// Visual C++
-#ifdef _MSC_VER
+// end catch_console_colour.h
+// start catch_reporter_registrars.hpp
-#if (_MSC_VER >= 1600)
-#define CLARA_INTERNAL_CONFIG_CPP11_NULLPTR
-#define CLARA_INTERNAL_CONFIG_CPP11_UNIQUE_PTR
-#endif
-#if (_MSC_VER >= 1900 ) // (VC++ 13 (VS2015))
-#define CLARA_INTERNAL_CONFIG_CPP11_NOEXCEPT
-#define CLARA_INTERNAL_CONFIG_CPP11_GENERATED_METHODS
-#endif
+namespace Catch {
-#endif // _MSC_VER
+ template<typename T>
+ class ReporterRegistrar {
-////////////////////////////////////////////////////////////////////////////////
-// C++ language feature support
+ class ReporterFactory : public IReporterFactory {
-// catch all support for C++11
-#if defined(__cplusplus) && __cplusplus >= 201103L
+ IStreamingReporterPtr create( ReporterConfig const& config ) const override {
+ return std::unique_ptr<T>( new T( config ) );
+ }
-#define CLARA_CPP11_OR_GREATER
+ std::string getDescription() const override {
+ return T::getDescription();
+ }
+ };
-#if !defined(CLARA_INTERNAL_CONFIG_CPP11_NULLPTR)
-#define CLARA_INTERNAL_CONFIG_CPP11_NULLPTR
-#endif
+ public:
-#ifndef CLARA_INTERNAL_CONFIG_CPP11_NOEXCEPT
-#define CLARA_INTERNAL_CONFIG_CPP11_NOEXCEPT
-#endif
+ explicit ReporterRegistrar( std::string const& name ) {
+ getMutableRegistryHub().registerReporter( name, std::make_shared<ReporterFactory>() );
+ }
+ };
-#ifndef CLARA_INTERNAL_CONFIG_CPP11_GENERATED_METHODS
-#define CLARA_INTERNAL_CONFIG_CPP11_GENERATED_METHODS
-#endif
+ template<typename T>
+ class ListenerRegistrar {
-#if !defined(CLARA_INTERNAL_CONFIG_CPP11_OVERRIDE)
-#define CLARA_INTERNAL_CONFIG_CPP11_OVERRIDE
-#endif
-#if !defined(CLARA_INTERNAL_CONFIG_CPP11_UNIQUE_PTR)
-#define CLARA_INTERNAL_CONFIG_CPP11_UNIQUE_PTR
-#endif
+ class ListenerFactory : public IReporterFactory {
-#endif // __cplusplus >= 201103L
+ IStreamingReporterPtr create( ReporterConfig const& config ) const override {
+ return std::unique_ptr<T>( new T( config ) );
+ }
+ std::string getDescription() const override {
+ return std::string();
+ }
+ };
-// Now set the actual defines based on the above + anything the user has configured
-#if defined(CLARA_INTERNAL_CONFIG_CPP11_NULLPTR) && !defined(CLARA_CONFIG_CPP11_NO_NULLPTR) && !defined(CLARA_CONFIG_CPP11_NULLPTR) && !defined(CLARA_CONFIG_NO_CPP11)
-#define CLARA_CONFIG_CPP11_NULLPTR
-#endif
-#if defined(CLARA_INTERNAL_CONFIG_CPP11_NOEXCEPT) && !defined(CLARA_CONFIG_CPP11_NO_NOEXCEPT) && !defined(CLARA_CONFIG_CPP11_NOEXCEPT) && !defined(CLARA_CONFIG_NO_CPP11)
-#define CLARA_CONFIG_CPP11_NOEXCEPT
-#endif
-#if defined(CLARA_INTERNAL_CONFIG_CPP11_GENERATED_METHODS) && !defined(CLARA_CONFIG_CPP11_NO_GENERATED_METHODS) && !defined(CLARA_CONFIG_CPP11_GENERATED_METHODS) && !defined(CLARA_CONFIG_NO_CPP11)
-#define CLARA_CONFIG_CPP11_GENERATED_METHODS
-#endif
-#if defined(CLARA_INTERNAL_CONFIG_CPP11_OVERRIDE) && !defined(CLARA_CONFIG_NO_OVERRIDE) && !defined(CLARA_CONFIG_CPP11_OVERRIDE) && !defined(CLARA_CONFIG_NO_CPP11)
-#define CLARA_CONFIG_CPP11_OVERRIDE
-#endif
-#if defined(CLARA_INTERNAL_CONFIG_CPP11_UNIQUE_PTR) && !defined(CLARA_CONFIG_NO_UNIQUE_PTR) && !defined(CLARA_CONFIG_CPP11_UNIQUE_PTR) && !defined(CLARA_CONFIG_NO_CPP11)
-#define CLARA_CONFIG_CPP11_UNIQUE_PTR
-#endif
+ public:
-// noexcept support:
-#if defined(CLARA_CONFIG_CPP11_NOEXCEPT) && !defined(CLARA_NOEXCEPT)
-#define CLARA_NOEXCEPT noexcept
-# define CLARA_NOEXCEPT_IS(x) noexcept(x)
-#else
-#define CLARA_NOEXCEPT throw()
-# define CLARA_NOEXCEPT_IS(x)
-#endif
+ ListenerRegistrar() {
+ getMutableRegistryHub().registerListener( std::make_shared<ListenerFactory>() );
+ }
+ };
+}
-// nullptr support
-#ifdef CLARA_CONFIG_CPP11_NULLPTR
-#define CLARA_NULL nullptr
-#else
-#define CLARA_NULL NULL
-#endif
+#if !defined(CATCH_CONFIG_DISABLE)
-// override support
-#ifdef CLARA_CONFIG_CPP11_OVERRIDE
-#define CLARA_OVERRIDE override
-#else
-#define CLARA_OVERRIDE
-#endif
+#define CATCH_REGISTER_REPORTER( name, reporterType ) \
+ CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \
+ namespace{ Catch::ReporterRegistrar<reporterType> catch_internal_RegistrarFor##reporterType( name ); } \
+ CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS
-// unique_ptr support
-#ifdef CLARA_CONFIG_CPP11_UNIQUE_PTR
-# define CLARA_AUTO_PTR( T ) std::unique_ptr<T>
-#else
-# define CLARA_AUTO_PTR( T ) std::auto_ptr<T>
-#endif
+#define CATCH_REGISTER_LISTENER( listenerType ) \
+ CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS \
+ namespace{ Catch::ListenerRegistrar<listenerType> catch_internal_RegistrarFor##listenerType; } \
+ CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS
+#else // CATCH_CONFIG_DISABLE
-#endif // TWOBLUECUBES_CLARA_COMPILERS_H_INCLUDED
+#define CATCH_REGISTER_REPORTER(name, reporterType)
+#define CATCH_REGISTER_LISTENER(listenerType)
-// ----------- end of #include from clara_compilers.h -----------
-// ........... back in clara.h
+#endif // CATCH_CONFIG_DISABLE
-#include <map>
-#include <stdexcept>
-#include <memory>
+// end catch_reporter_registrars.hpp
+// Allow users to base their work off existing reporters
+// start catch_reporter_compact.h
-#if defined(WIN32) || defined(__WIN32__) || defined(_WIN32) || defined(_MSC_VER)
-#define CLARA_PLATFORM_WINDOWS
-#endif
+namespace Catch {
-// Use optional outer namespace
-#ifdef STITCH_CLARA_OPEN_NAMESPACE
-STITCH_CLARA_OPEN_NAMESPACE
-#endif
+ struct CompactReporter : StreamingReporterBase<CompactReporter> {
-namespace Clara {
+ using StreamingReporterBase::StreamingReporterBase;
- struct UnpositionalTag {};
+ ~CompactReporter() override;
- extern UnpositionalTag _;
+ static std::string getDescription();
-#ifdef CLARA_CONFIG_MAIN
- UnpositionalTag _;
-#endif
+ ReporterPreferences getPreferences() const override;
- namespace Detail {
+ void noMatchingTestCases(std::string const& spec) override;
-#ifdef CLARA_CONSOLE_WIDTH
- const unsigned int consoleWidth = CLARA_CONFIG_CONSOLE_WIDTH;
-#else
- const unsigned int consoleWidth = 80;
-#endif
+ void assertionStarting(AssertionInfo const&) override;
- using namespace Tbc;
+ bool assertionEnded(AssertionStats const& _assertionStats) override;
- inline bool startsWith( std::string const& str, std::string const& prefix ) {
- return str.size() >= prefix.size() && str.substr( 0, prefix.size() ) == prefix;
- }
+ void sectionEnded(SectionStats const& _sectionStats) override;
- template<typename T> struct RemoveConstRef{ typedef T type; };
- template<typename T> struct RemoveConstRef<T&>{ typedef T type; };
- template<typename T> struct RemoveConstRef<T const&>{ typedef T type; };
- template<typename T> struct RemoveConstRef<T const>{ typedef T type; };
+ void testRunEnded(TestRunStats const& _testRunStats) override;
- template<typename T> struct IsBool { static const bool value = false; };
- template<> struct IsBool<bool> { static const bool value = true; };
+ };
- template<typename T>
- void convertInto( std::string const& _source, T& _dest ) {
- std::stringstream ss;
- ss << _source;
- ss >> _dest;
- if( ss.fail() )
- throw std::runtime_error( "Unable to convert " + _source + " to destination type" );
- }
- inline void convertInto( std::string const& _source, std::string& _dest ) {
- _dest = _source;
- }
- char toLowerCh(char c) {
- return static_cast<char>( std::tolower( c ) );
- }
- inline void convertInto( std::string const& _source, bool& _dest ) {
- std::string sourceLC = _source;
- std::transform( sourceLC.begin(), sourceLC.end(), sourceLC.begin(), toLowerCh );
- if( sourceLC == "y" || sourceLC == "1" || sourceLC == "true" || sourceLC == "yes" || sourceLC == "on" )
- _dest = true;
- else if( sourceLC == "n" || sourceLC == "0" || sourceLC == "false" || sourceLC == "no" || sourceLC == "off" )
- _dest = false;
- else
- throw std::runtime_error( "Expected a boolean value but did not recognise:\n '" + _source + "'" );
- }
+} // end namespace Catch
- template<typename ConfigT>
- struct IArgFunction {
- virtual ~IArgFunction() {}
-#ifdef CLARA_CONFIG_CPP11_GENERATED_METHODS
- IArgFunction() = default;
- IArgFunction( IArgFunction const& ) = default;
+// end catch_reporter_compact.h
+// start catch_reporter_console.h
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable:4061) // Not all labels are EXPLICITLY handled in switch
+ // Note that 4062 (not all labels are handled
+ // and default is missing) is enabled
#endif
- virtual void set( ConfigT& config, std::string const& value ) const = 0;
- virtual bool takesArg() const = 0;
- virtual IArgFunction* clone() const = 0;
- };
- template<typename ConfigT>
- class BoundArgFunction {
- public:
- BoundArgFunction() : functionObj( CLARA_NULL ) {}
- BoundArgFunction( IArgFunction<ConfigT>* _functionObj ) : functionObj( _functionObj ) {}
- BoundArgFunction( BoundArgFunction const& other ) : functionObj( other.functionObj ? other.functionObj->clone() : CLARA_NULL ) {}
- BoundArgFunction& operator = ( BoundArgFunction const& other ) {
- IArgFunction<ConfigT>* newFunctionObj = other.functionObj ? other.functionObj->clone() : CLARA_NULL;
- delete functionObj;
- functionObj = newFunctionObj;
- return *this;
- }
- ~BoundArgFunction() { delete functionObj; }
+namespace Catch {
+ // Fwd decls
+ struct SummaryColumn;
+ class TablePrinter;
- void set( ConfigT& config, std::string const& value ) const {
- functionObj->set( config, value );
- }
- bool takesArg() const { return functionObj->takesArg(); }
+ struct ConsoleReporter : StreamingReporterBase<ConsoleReporter> {
+ std::unique_ptr<TablePrinter> m_tablePrinter;
- bool isSet() const {
- return functionObj != CLARA_NULL;
- }
- private:
- IArgFunction<ConfigT>* functionObj;
- };
+ ConsoleReporter(ReporterConfig const& config);
+ ~ConsoleReporter() override;
+ static std::string getDescription();
- template<typename C>
- struct NullBinder : IArgFunction<C>{
- virtual void set( C&, std::string const& ) const {}
- virtual bool takesArg() const { return true; }
- virtual IArgFunction<C>* clone() const { return new NullBinder( *this ); }
- };
+ void noMatchingTestCases(std::string const& spec) override;
- template<typename C, typename M>
- struct BoundDataMember : IArgFunction<C>{
- BoundDataMember( M C::* _member ) : member( _member ) {}
- virtual void set( C& p, std::string const& stringValue ) const {
- convertInto( stringValue, p.*member );
- }
- virtual bool takesArg() const { return !IsBool<M>::value; }
- virtual IArgFunction<C>* clone() const { return new BoundDataMember( *this ); }
- M C::* member;
- };
- template<typename C, typename M>
- struct BoundUnaryMethod : IArgFunction<C>{
- BoundUnaryMethod( void (C::*_member)( M ) ) : member( _member ) {}
- virtual void set( C& p, std::string const& stringValue ) const {
- typename RemoveConstRef<M>::type value;
- convertInto( stringValue, value );
- (p.*member)( value );
- }
- virtual bool takesArg() const { return !IsBool<M>::value; }
- virtual IArgFunction<C>* clone() const { return new BoundUnaryMethod( *this ); }
- void (C::*member)( M );
- };
- template<typename C>
- struct BoundNullaryMethod : IArgFunction<C>{
- BoundNullaryMethod( void (C::*_member)() ) : member( _member ) {}
- virtual void set( C& p, std::string const& stringValue ) const {
- bool value;
- convertInto( stringValue, value );
- if( value )
- (p.*member)();
- }
- virtual bool takesArg() const { return false; }
- virtual IArgFunction<C>* clone() const { return new BoundNullaryMethod( *this ); }
- void (C::*member)();
- };
+ void reportInvalidArguments(std::string const&arg) override;
- template<typename C>
- struct BoundUnaryFunction : IArgFunction<C>{
- BoundUnaryFunction( void (*_function)( C& ) ) : function( _function ) {}
- virtual void set( C& obj, std::string const& stringValue ) const {
- bool value;
- convertInto( stringValue, value );
- if( value )
- function( obj );
- }
- virtual bool takesArg() const { return false; }
- virtual IArgFunction<C>* clone() const { return new BoundUnaryFunction( *this ); }
- void (*function)( C& );
- };
+ void assertionStarting(AssertionInfo const&) override;
- template<typename C, typename T>
- struct BoundBinaryFunction : IArgFunction<C>{
- BoundBinaryFunction( void (*_function)( C&, T ) ) : function( _function ) {}
- virtual void set( C& obj, std::string const& stringValue ) const {
- typename RemoveConstRef<T>::type value;
- convertInto( stringValue, value );
- function( obj, value );
- }
- virtual bool takesArg() const { return !IsBool<T>::value; }
- virtual IArgFunction<C>* clone() const { return new BoundBinaryFunction( *this ); }
- void (*function)( C&, T );
- };
+ bool assertionEnded(AssertionStats const& _assertionStats) override;
- } // namespace Detail
+ void sectionStarting(SectionInfo const& _sectionInfo) override;
+ void sectionEnded(SectionStats const& _sectionStats) override;
- inline std::vector<std::string> argsToVector( int argc, char const* const* const argv ) {
- std::vector<std::string> args( static_cast<std::size_t>( argc ) );
- for( std::size_t i = 0; i < static_cast<std::size_t>( argc ); ++i )
- args[i] = argv[i];
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+ void benchmarkPreparing(std::string const& name) override;
+ void benchmarkStarting(BenchmarkInfo const& info) override;
+ void benchmarkEnded(BenchmarkStats<> const& stats) override;
+ void benchmarkFailed(std::string const& error) override;
+#endif // CATCH_CONFIG_ENABLE_BENCHMARKING
- return args;
- }
+ void testCaseEnded(TestCaseStats const& _testCaseStats) override;
+ void testGroupEnded(TestGroupStats const& _testGroupStats) override;
+ void testRunEnded(TestRunStats const& _testRunStats) override;
+ void testRunStarting(TestRunInfo const& _testRunInfo) override;
+ private:
- class Parser {
- enum Mode { None, MaybeShortOpt, SlashOpt, ShortOpt, LongOpt, Positional };
- Mode mode;
- std::size_t from;
- bool inQuotes;
- public:
+ void lazyPrint();
- struct Token {
- enum Type { Positional, ShortOpt, LongOpt };
- Token( Type _type, std::string const& _data ) : type( _type ), data( _data ) {}
- Type type;
- std::string data;
- };
+ void lazyPrintWithoutClosingBenchmarkTable();
+ void lazyPrintRunInfo();
+ void lazyPrintGroupInfo();
+ void printTestCaseAndSectionHeader();
- Parser() : mode( None ), from( 0 ), inQuotes( false ){}
+ void printClosedHeader(std::string const& _name);
+ void printOpenHeader(std::string const& _name);
- void parseIntoTokens( std::vector<std::string> const& args, std::vector<Token>& tokens ) {
- const std::string doubleDash = "--";
- for( std::size_t i = 1; i < args.size() && args[i] != doubleDash; ++i )
- parseIntoTokens( args[i], tokens);
- }
+ // if string has a : in first line will set indent to follow it on
+ // subsequent lines
+ void printHeaderString(std::string const& _string, std::size_t indent = 0);
- void parseIntoTokens( std::string const& arg, std::vector<Token>& tokens ) {
- for( std::size_t i = 0; i < arg.size(); ++i ) {
- char c = arg[i];
- if( c == '"' )
- inQuotes = !inQuotes;
- mode = handleMode( i, c, arg, tokens );
- }
- mode = handleMode( arg.size(), '\0', arg, tokens );
- }
- Mode handleMode( std::size_t i, char c, std::string const& arg, std::vector<Token>& tokens ) {
- switch( mode ) {
- case None: return handleNone( i, c );
- case MaybeShortOpt: return handleMaybeShortOpt( i, c );
- case ShortOpt:
- case LongOpt:
- case SlashOpt: return handleOpt( i, c, arg, tokens );
- case Positional: return handlePositional( i, c, arg, tokens );
- default: throw std::logic_error( "Unknown mode" );
- }
- }
+ void printTotals(Totals const& totals);
+ void printSummaryRow(std::string const& label, std::vector<SummaryColumn> const& cols, std::size_t row);
- Mode handleNone( std::size_t i, char c ) {
- if( inQuotes ) {
- from = i;
- return Positional;
- }
- switch( c ) {
- case '-': return MaybeShortOpt;
-#ifdef CLARA_PLATFORM_WINDOWS
- case '/': from = i+1; return SlashOpt;
+ void printTotalsDivider(Totals const& totals);
+ void printSummaryDivider();
+ void printTestFilters();
+
+ private:
+ bool m_headerPrinted = false;
+ };
+
+} // end namespace Catch
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
#endif
- default: from = i; return Positional;
- }
- }
- Mode handleMaybeShortOpt( std::size_t i, char c ) {
- switch( c ) {
- case '-': from = i+1; return LongOpt;
- default: from = i; return ShortOpt;
- }
- }
- Mode handleOpt( std::size_t i, char c, std::string const& arg, std::vector<Token>& tokens ) {
- if( std::string( ":=\0", 3 ).find( c ) == std::string::npos )
- return mode;
+// end catch_reporter_console.h
+// start catch_reporter_junit.h
- std::string optName = arg.substr( from, i-from );
- if( mode == ShortOpt )
- for( std::size_t j = 0; j < optName.size(); ++j )
- tokens.push_back( Token( Token::ShortOpt, optName.substr( j, 1 ) ) );
- else if( mode == SlashOpt && optName.size() == 1 )
- tokens.push_back( Token( Token::ShortOpt, optName ) );
- else
- tokens.push_back( Token( Token::LongOpt, optName ) );
- return None;
- }
- Mode handlePositional( std::size_t i, char c, std::string const& arg, std::vector<Token>& tokens ) {
- if( inQuotes || std::string( "\0", 1 ).find( c ) == std::string::npos )
- return mode;
+// start catch_xmlwriter.h
- std::string data = arg.substr( from, i-from );
- tokens.push_back( Token( Token::Positional, data ) );
- return None;
- }
- };
+#include <vector>
- template<typename ConfigT>
- struct CommonArgProperties {
- CommonArgProperties() {}
- CommonArgProperties( Detail::BoundArgFunction<ConfigT> const& _boundField ) : boundField( _boundField ) {}
+namespace Catch {
- Detail::BoundArgFunction<ConfigT> boundField;
- std::string description;
- std::string detail;
- std::string placeholder; // Only value if boundField takes an arg
+ class XmlEncode {
+ public:
+ enum ForWhat { ForTextNodes, ForAttributes };
- bool takesArg() const {
- return !placeholder.empty();
- }
- void validate() const {
- if( !boundField.isSet() )
- throw std::logic_error( "option not bound" );
- }
- };
- struct OptionArgProperties {
- std::vector<std::string> shortNames;
- std::string longName;
+ XmlEncode( std::string const& str, ForWhat forWhat = ForTextNodes );
- bool hasShortName( std::string const& shortName ) const {
- return std::find( shortNames.begin(), shortNames.end(), shortName ) != shortNames.end();
- }
- bool hasLongName( std::string const& _longName ) const {
- return _longName == longName;
- }
- };
- struct PositionalArgProperties {
- PositionalArgProperties() : position( -1 ) {}
- int position; // -1 means non-positional (floating)
+ void encodeTo( std::ostream& os ) const;
- bool isFixedPositional() const {
- return position != -1;
- }
+ friend std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode );
+
+ private:
+ std::string m_str;
+ ForWhat m_forWhat;
};
- template<typename ConfigT>
- class CommandLine {
+ class XmlWriter {
+ public:
- struct Arg : CommonArgProperties<ConfigT>, OptionArgProperties, PositionalArgProperties {
- Arg() {}
- Arg( Detail::BoundArgFunction<ConfigT> const& _boundField ) : CommonArgProperties<ConfigT>( _boundField ) {}
+ class ScopedElement {
+ public:
+ ScopedElement( XmlWriter* writer );
- using CommonArgProperties<ConfigT>::placeholder; // !TBD
+ ScopedElement( ScopedElement&& other ) noexcept;
+ ScopedElement& operator=( ScopedElement&& other ) noexcept;
- std::string dbgName() const {
- if( !longName.empty() )
- return "--" + longName;
- if( !shortNames.empty() )
- return "-" + shortNames[0];
- return "positional args";
- }
- std::string commands() const {
- std::ostringstream oss;
- bool first = true;
- std::vector<std::string>::const_iterator it = shortNames.begin(), itEnd = shortNames.end();
- for(; it != itEnd; ++it ) {
- if( first )
- first = false;
- else
- oss << ", ";
- oss << "-" << *it;
- }
- if( !longName.empty() ) {
- if( !first )
- oss << ", ";
- oss << "--" << longName;
- }
- if( !placeholder.empty() )
- oss << " <" << placeholder << ">";
- return oss.str();
+ ~ScopedElement();
+
+ ScopedElement& writeText( std::string const& text, bool indent = true );
+
+ template<typename T>
+ ScopedElement& writeAttribute( std::string const& name, T const& attribute ) {
+ m_writer->writeAttribute( name, attribute );
+ return *this;
}
+
+ private:
+ mutable XmlWriter* m_writer = nullptr;
};
- typedef CLARA_AUTO_PTR( Arg ) ArgAutoPtr;
+ XmlWriter( std::ostream& os = Catch::cout() );
+ ~XmlWriter();
- friend void addOptName( Arg& arg, std::string const& optName )
- {
- if( optName.empty() )
- return;
- if( Detail::startsWith( optName, "--" ) ) {
- if( !arg.longName.empty() )
- throw std::logic_error( "Only one long opt may be specified. '"
- + arg.longName
- + "' already specified, now attempting to add '"
- + optName + "'" );
- arg.longName = optName.substr( 2 );
- }
- else if( Detail::startsWith( optName, "-" ) )
- arg.shortNames.push_back( optName.substr( 1 ) );
- else
- throw std::logic_error( "option must begin with - or --. Option was: '" + optName + "'" );
- }
- friend void setPositionalArg( Arg& arg, int position )
- {
- arg.position = position;
+ XmlWriter( XmlWriter const& ) = delete;
+ XmlWriter& operator=( XmlWriter const& ) = delete;
+
+ XmlWriter& startElement( std::string const& name );
+
+ ScopedElement scopedElement( std::string const& name );
+
+ XmlWriter& endElement();
+
+ XmlWriter& writeAttribute( std::string const& name, std::string const& attribute );
+
+ XmlWriter& writeAttribute( std::string const& name, bool attribute );
+
+ template<typename T>
+ XmlWriter& writeAttribute( std::string const& name, T const& attribute ) {
+ ReusableStringStream rss;
+ rss << attribute;
+ return writeAttribute( name, rss.str() );
}
- class ArgBuilder {
- public:
- ArgBuilder( Arg* arg ) : m_arg( arg ) {}
+ XmlWriter& writeText( std::string const& text, bool indent = true );
- // Bind a non-boolean data member (requires placeholder string)
- template<typename C, typename M>
- void bind( M C::* field, std::string const& placeholder ) {
- m_arg->boundField = new Detail::BoundDataMember<C,M>( field );
- m_arg->placeholder = placeholder;
- }
- // Bind a boolean data member (no placeholder required)
- template<typename C>
- void bind( bool C::* field ) {
- m_arg->boundField = new Detail::BoundDataMember<C,bool>( field );
- }
+ XmlWriter& writeComment( std::string const& text );
- // Bind a method taking a single, non-boolean argument (requires a placeholder string)
- template<typename C, typename M>
- void bind( void (C::* unaryMethod)( M ), std::string const& placeholder ) {
- m_arg->boundField = new Detail::BoundUnaryMethod<C,M>( unaryMethod );
- m_arg->placeholder = placeholder;
- }
+ void writeStylesheetRef( std::string const& url );
- // Bind a method taking a single, boolean argument (no placeholder string required)
- template<typename C>
- void bind( void (C::* unaryMethod)( bool ) ) {
- m_arg->boundField = new Detail::BoundUnaryMethod<C,bool>( unaryMethod );
- }
+ XmlWriter& writeBlankLine();
- // Bind a method that takes no arguments (will be called if opt is present)
- template<typename C>
- void bind( void (C::* nullaryMethod)() ) {
- m_arg->boundField = new Detail::BoundNullaryMethod<C>( nullaryMethod );
- }
+ void ensureTagClosed();
- // Bind a free function taking a single argument - the object to operate on (no placeholder string required)
- template<typename C>
- void bind( void (* unaryFunction)( C& ) ) {
- m_arg->boundField = new Detail::BoundUnaryFunction<C>( unaryFunction );
- }
+ private:
- // Bind a free function taking a single argument - the object to operate on (requires a placeholder string)
- template<typename C, typename T>
- void bind( void (* binaryFunction)( C&, T ), std::string const& placeholder ) {
- m_arg->boundField = new Detail::BoundBinaryFunction<C, T>( binaryFunction );
- m_arg->placeholder = placeholder;
- }
+ void writeDeclaration();
- ArgBuilder& describe( std::string const& description ) {
- m_arg->description = description;
- return *this;
- }
- ArgBuilder& detail( std::string const& detail ) {
- m_arg->detail = detail;
- return *this;
- }
+ void newlineIfNecessary();
- protected:
- Arg* m_arg;
- };
+ bool m_tagIsOpen = false;
+ bool m_needsNewline = false;
+ std::vector<std::string> m_tags;
+ std::string m_indent;
+ std::ostream& m_os;
+ };
- class OptBuilder : public ArgBuilder {
- public:
- OptBuilder( Arg* arg ) : ArgBuilder( arg ) {}
- OptBuilder( OptBuilder& other ) : ArgBuilder( other ) {}
+}
- OptBuilder& operator[]( std::string const& optName ) {
- addOptName( *ArgBuilder::m_arg, optName );
- return *this;
- }
- };
+// end catch_xmlwriter.h
+namespace Catch {
+ class JunitReporter : public CumulativeReporterBase<JunitReporter> {
public:
+ JunitReporter(ReporterConfig const& _config);
- CommandLine()
- : m_boundProcessName( new Detail::NullBinder<ConfigT>() ),
- m_highestSpecifiedArgPosition( 0 ),
- m_throwOnUnrecognisedTokens( false )
- {}
- CommandLine( CommandLine const& other )
- : m_boundProcessName( other.m_boundProcessName ),
- m_options ( other.m_options ),
- m_positionalArgs( other.m_positionalArgs ),
- m_highestSpecifiedArgPosition( other.m_highestSpecifiedArgPosition ),
- m_throwOnUnrecognisedTokens( other.m_throwOnUnrecognisedTokens )
- {
- if( other.m_floatingArg.get() )
- m_floatingArg.reset( new Arg( *other.m_floatingArg ) );
- }
+ ~JunitReporter() override;
- CommandLine& setThrowOnUnrecognisedTokens( bool shouldThrow = true ) {
- m_throwOnUnrecognisedTokens = shouldThrow;
- return *this;
- }
+ static std::string getDescription();
- OptBuilder operator[]( std::string const& optName ) {
- m_options.push_back( Arg() );
- addOptName( m_options.back(), optName );
- OptBuilder builder( &m_options.back() );
- return builder;
- }
+ void noMatchingTestCases(std::string const& /*spec*/) override;
- ArgBuilder operator[]( int position ) {
- m_positionalArgs.insert( std::make_pair( position, Arg() ) );
- if( position > m_highestSpecifiedArgPosition )
- m_highestSpecifiedArgPosition = position;
- setPositionalArg( m_positionalArgs[position], position );
- ArgBuilder builder( &m_positionalArgs[position] );
- return builder;
- }
+ void testRunStarting(TestRunInfo const& runInfo) override;
- // Invoke this with the _ instance
- ArgBuilder operator[]( UnpositionalTag ) {
- if( m_floatingArg.get() )
- throw std::logic_error( "Only one unpositional argument can be added" );
- m_floatingArg.reset( new Arg() );
- ArgBuilder builder( m_floatingArg.get() );
- return builder;
- }
+ void testGroupStarting(GroupInfo const& groupInfo) override;
- template<typename C, typename M>
- void bindProcessName( M C::* field ) {
- m_boundProcessName = new Detail::BoundDataMember<C,M>( field );
- }
- template<typename C, typename M>
- void bindProcessName( void (C::*_unaryMethod)( M ) ) {
- m_boundProcessName = new Detail::BoundUnaryMethod<C,M>( _unaryMethod );
- }
+ void testCaseStarting(TestCaseInfo const& testCaseInfo) override;
+ bool assertionEnded(AssertionStats const& assertionStats) override;
- void optUsage( std::ostream& os, std::size_t indent = 0, std::size_t width = Detail::consoleWidth ) const {
- typename std::vector<Arg>::const_iterator itBegin = m_options.begin(), itEnd = m_options.end(), it;
- std::size_t maxWidth = 0;
- for( it = itBegin; it != itEnd; ++it )
- maxWidth = (std::max)( maxWidth, it->commands().size() );
+ void testCaseEnded(TestCaseStats const& testCaseStats) override;
- for( it = itBegin; it != itEnd; ++it ) {
- Detail::Text usage( it->commands(), Detail::TextAttributes()
- .setWidth( maxWidth+indent )
- .setIndent( indent ) );
- Detail::Text desc( it->description, Detail::TextAttributes()
- .setWidth( width - maxWidth - 3 ) );
+ void testGroupEnded(TestGroupStats const& testGroupStats) override;
- for( std::size_t i = 0; i < (std::max)( usage.size(), desc.size() ); ++i ) {
- std::string usageCol = i < usage.size() ? usage[i] : "";
- os << usageCol;
+ void testRunEndedCumulative() override;
- if( i < desc.size() && !desc[i].empty() )
- os << std::string( indent + 2 + maxWidth - usageCol.size(), ' ' )
- << desc[i];
- os << "\n";
- }
- }
- }
- std::string optUsage() const {
- std::ostringstream oss;
- optUsage( oss );
- return oss.str();
- }
-
- void argSynopsis( std::ostream& os ) const {
- for( int i = 1; i <= m_highestSpecifiedArgPosition; ++i ) {
- if( i > 1 )
- os << " ";
- typename std::map<int, Arg>::const_iterator it = m_positionalArgs.find( i );
- if( it != m_positionalArgs.end() )
- os << "<" << it->second.placeholder << ">";
- else if( m_floatingArg.get() )
- os << "<" << m_floatingArg->placeholder << ">";
- else
- throw std::logic_error( "non consecutive positional arguments with no floating args" );
- }
- // !TBD No indication of mandatory args
- if( m_floatingArg.get() ) {
- if( m_highestSpecifiedArgPosition > 1 )
- os << " ";
- os << "[<" << m_floatingArg->placeholder << "> ...]";
- }
- }
- std::string argSynopsis() const {
- std::ostringstream oss;
- argSynopsis( oss );
- return oss.str();
- }
+ void writeGroup(TestGroupNode const& groupNode, double suiteTime);
- void usage( std::ostream& os, std::string const& procName ) const {
- validate();
- os << "usage:\n " << procName << " ";
- argSynopsis( os );
- if( !m_options.empty() ) {
- os << " [options]\n\nwhere options are: \n";
- optUsage( os, 2 );
- }
- os << "\n";
- }
- std::string usage( std::string const& procName ) const {
- std::ostringstream oss;
- usage( oss, procName );
- return oss.str();
- }
-
- ConfigT parse( std::vector<std::string> const& args ) const {
- ConfigT config;
- parseInto( args, config );
- return config;
- }
-
- std::vector<Parser::Token> parseInto( std::vector<std::string> const& args, ConfigT& config ) const {
- std::string processName = args.empty() ? std::string() : args[0];
- std::size_t lastSlash = processName.find_last_of( "/\\" );
- if( lastSlash != std::string::npos )
- processName = processName.substr( lastSlash+1 );
- m_boundProcessName.set( config, processName );
- std::vector<Parser::Token> tokens;
- Parser parser;
- parser.parseIntoTokens( args, tokens );
- return populate( tokens, config );
- }
-
- std::vector<Parser::Token> populate( std::vector<Parser::Token> const& tokens, ConfigT& config ) const {
- validate();
- std::vector<Parser::Token> unusedTokens = populateOptions( tokens, config );
- unusedTokens = populateFixedArgs( unusedTokens, config );
- unusedTokens = populateFloatingArgs( unusedTokens, config );
- return unusedTokens;
- }
-
- std::vector<Parser::Token> populateOptions( std::vector<Parser::Token> const& tokens, ConfigT& config ) const {
- std::vector<Parser::Token> unusedTokens;
- std::vector<std::string> errors;
- for( std::size_t i = 0; i < tokens.size(); ++i ) {
- Parser::Token const& token = tokens[i];
- typename std::vector<Arg>::const_iterator it = m_options.begin(), itEnd = m_options.end();
- for(; it != itEnd; ++it ) {
- Arg const& arg = *it;
-
- try {
- if( ( token.type == Parser::Token::ShortOpt && arg.hasShortName( token.data ) ) ||
- ( token.type == Parser::Token::LongOpt && arg.hasLongName( token.data ) ) ) {
- if( arg.takesArg() ) {
- if( i == tokens.size()-1 || tokens[i+1].type != Parser::Token::Positional )
- errors.push_back( "Expected argument to option: " + token.data );
- else
- arg.boundField.set( config, tokens[++i].data );
- }
- else {
- arg.boundField.set( config, "true" );
- }
- break;
- }
- }
- catch( std::exception& ex ) {
- errors.push_back( std::string( ex.what() ) + "\n- while parsing: (" + arg.commands() + ")" );
- }
- }
- if( it == itEnd ) {
- if( token.type == Parser::Token::Positional || !m_throwOnUnrecognisedTokens )
- unusedTokens.push_back( token );
- else if( errors.empty() && m_throwOnUnrecognisedTokens )
- errors.push_back( "unrecognised option: " + token.data );
- }
- }
- if( !errors.empty() ) {
- std::ostringstream oss;
- for( std::vector<std::string>::const_iterator it = errors.begin(), itEnd = errors.end();
- it != itEnd;
- ++it ) {
- if( it != errors.begin() )
- oss << "\n";
- oss << *it;
- }
- throw std::runtime_error( oss.str() );
- }
- return unusedTokens;
- }
- std::vector<Parser::Token> populateFixedArgs( std::vector<Parser::Token> const& tokens, ConfigT& config ) const {
- std::vector<Parser::Token> unusedTokens;
- int position = 1;
- for( std::size_t i = 0; i < tokens.size(); ++i ) {
- Parser::Token const& token = tokens[i];
- typename std::map<int, Arg>::const_iterator it = m_positionalArgs.find( position );
- if( it != m_positionalArgs.end() )
- it->second.boundField.set( config, token.data );
- else
- unusedTokens.push_back( token );
- if( token.type == Parser::Token::Positional )
- position++;
- }
- return unusedTokens;
- }
- std::vector<Parser::Token> populateFloatingArgs( std::vector<Parser::Token> const& tokens, ConfigT& config ) const {
- if( !m_floatingArg.get() )
- return tokens;
- std::vector<Parser::Token> unusedTokens;
- for( std::size_t i = 0; i < tokens.size(); ++i ) {
- Parser::Token const& token = tokens[i];
- if( token.type == Parser::Token::Positional )
- m_floatingArg->boundField.set( config, token.data );
- else
- unusedTokens.push_back( token );
- }
- return unusedTokens;
- }
+ void writeTestCase(TestCaseNode const& testCaseNode);
- void validate() const
- {
- if( m_options.empty() && m_positionalArgs.empty() && !m_floatingArg.get() )
- throw std::logic_error( "No options or arguments specified" );
+ void writeSection(std::string const& className,
+ std::string const& rootName,
+ SectionNode const& sectionNode);
- for( typename std::vector<Arg>::const_iterator it = m_options.begin(),
- itEnd = m_options.end();
- it != itEnd; ++it )
- it->validate();
- }
+ void writeAssertions(SectionNode const& sectionNode);
+ void writeAssertion(AssertionStats const& stats);
- private:
- Detail::BoundArgFunction<ConfigT> m_boundProcessName;
- std::vector<Arg> m_options;
- std::map<int, Arg> m_positionalArgs;
- ArgAutoPtr m_floatingArg;
- int m_highestSpecifiedArgPosition;
- bool m_throwOnUnrecognisedTokens;
+ XmlWriter xml;
+ Timer suiteTimer;
+ std::string stdOutForSuite;
+ std::string stdErrForSuite;
+ unsigned int unexpectedExceptions = 0;
+ bool m_okToFail = false;
};
-} // end namespace Clara
+} // end namespace Catch
-STITCH_CLARA_CLOSE_NAMESPACE
-#undef STITCH_CLARA_OPEN_NAMESPACE
-#undef STITCH_CLARA_CLOSE_NAMESPACE
+// end catch_reporter_junit.h
+// start catch_reporter_xml.h
-#endif // TWOBLUECUBES_CLARA_H_INCLUDED
-#undef STITCH_CLARA_OPEN_NAMESPACE
+namespace Catch {
+ class XmlReporter : public StreamingReporterBase<XmlReporter> {
+ public:
+ XmlReporter(ReporterConfig const& _config);
-// Restore Clara's value for console width, if present
-#ifdef CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH
-#define CLARA_CONFIG_CONSOLE_WIDTH CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH
-#undef CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH
-#endif
+ ~XmlReporter() override;
-#include <fstream>
-#include <ctime>
+ static std::string getDescription();
-namespace Catch {
+ virtual std::string getStylesheetRef() const;
- inline void abortAfterFirst( ConfigData& config ) { config.abortAfter = 1; }
- inline void abortAfterX( ConfigData& config, int x ) {
- if( x < 1 )
- throw std::runtime_error( "Value after -x or --abortAfter must be greater than zero" );
- config.abortAfter = x;
- }
- inline void addTestOrTags( ConfigData& config, std::string const& _testSpec ) { config.testsOrTags.push_back( _testSpec ); }
- inline void addSectionToRun( ConfigData& config, std::string const& sectionName ) { config.sectionsToRun.push_back( sectionName ); }
- inline void addReporterName( ConfigData& config, std::string const& _reporterName ) { config.reporterNames.push_back( _reporterName ); }
+ void writeSourceInfo(SourceLineInfo const& sourceInfo);
- inline void addWarning( ConfigData& config, std::string const& _warning ) {
- if( _warning == "NoAssertions" )
- config.warnings = static_cast<WarnAbout::What>( config.warnings | WarnAbout::NoAssertions );
- else
- throw std::runtime_error( "Unrecognised warning: '" + _warning + '\'' );
- }
- inline void setOrder( ConfigData& config, std::string const& order ) {
- if( startsWith( "declared", order ) )
- config.runOrder = RunTests::InDeclarationOrder;
- else if( startsWith( "lexical", order ) )
- config.runOrder = RunTests::InLexicographicalOrder;
- else if( startsWith( "random", order ) )
- config.runOrder = RunTests::InRandomOrder;
- else
- throw std::runtime_error( "Unrecognised ordering: '" + order + '\'' );
- }
- inline void setRngSeed( ConfigData& config, std::string const& seed ) {
- if( seed == "time" ) {
- config.rngSeed = static_cast<unsigned int>( std::time(0) );
- }
- else {
- std::stringstream ss;
- ss << seed;
- ss >> config.rngSeed;
- if( ss.fail() )
- throw std::runtime_error( "Argument to --rng-seed should be the word 'time' or a number" );
- }
- }
- inline void setVerbosity( ConfigData& config, int level ) {
- // !TBD: accept strings?
- config.verbosity = static_cast<Verbosity::Level>( level );
- }
- inline void setShowDurations( ConfigData& config, bool _showDurations ) {
- config.showDurations = _showDurations
- ? ShowDurations::Always
- : ShowDurations::Never;
- }
- inline void setUseColour( ConfigData& config, std::string const& value ) {
- std::string mode = toLower( value );
-
- if( mode == "yes" )
- config.useColour = UseColour::Yes;
- else if( mode == "no" )
- config.useColour = UseColour::No;
- else if( mode == "auto" )
- config.useColour = UseColour::Auto;
- else
- throw std::runtime_error( "colour mode must be one of: auto, yes or no" );
- }
- inline void forceColour( ConfigData& config ) {
- config.useColour = UseColour::Yes;
- }
- inline void loadTestNamesFromFile( ConfigData& config, std::string const& _filename ) {
- std::ifstream f( _filename.c_str() );
- if( !f.is_open() )
- throw std::domain_error( "Unable to load input file: " + _filename );
+ public: // StreamingReporterBase
- std::string line;
- while( std::getline( f, line ) ) {
- line = trim(line);
- if( !line.empty() && !startsWith( line, '#' ) ) {
- if( !startsWith( line, '"' ) )
- line = '"' + line + '"';
- addTestOrTags( config, line + ',' );
- }
- }
- }
+ void noMatchingTestCases(std::string const& s) override;
- inline Clara::CommandLine<ConfigData> makeCommandLineParser() {
+ void testRunStarting(TestRunInfo const& testInfo) override;
- using namespace Clara;
- CommandLine<ConfigData> cli;
+ void testGroupStarting(GroupInfo const& groupInfo) override;
- cli.bindProcessName( &ConfigData::processName );
+ void testCaseStarting(TestCaseInfo const& testInfo) override;
- cli["-?"]["-h"]["--help"]
- .describe( "display usage information" )
- .bind( &ConfigData::showHelp );
+ void sectionStarting(SectionInfo const& sectionInfo) override;
- cli["-l"]["--list-tests"]
- .describe( "list all/matching test cases" )
- .bind( &ConfigData::listTests );
+ void assertionStarting(AssertionInfo const&) override;
- cli["-t"]["--list-tags"]
- .describe( "list all/matching tags" )
- .bind( &ConfigData::listTags );
+ bool assertionEnded(AssertionStats const& assertionStats) override;
- cli["-s"]["--success"]
- .describe( "include successful tests in output" )
- .bind( &ConfigData::showSuccessfulTests );
+ void sectionEnded(SectionStats const& sectionStats) override;
- cli["-b"]["--break"]
- .describe( "break into debugger on failure" )
- .bind( &ConfigData::shouldDebugBreak );
+ void testCaseEnded(TestCaseStats const& testCaseStats) override;
- cli["-e"]["--nothrow"]
- .describe( "skip exception tests" )
- .bind( &ConfigData::noThrow );
+ void testGroupEnded(TestGroupStats const& testGroupStats) override;
- cli["-i"]["--invisibles"]
- .describe( "show invisibles (tabs, newlines)" )
- .bind( &ConfigData::showInvisibles );
+ void testRunEnded(TestRunStats const& testRunStats) override;
- cli["-o"]["--out"]
- .describe( "output filename" )
- .bind( &ConfigData::outputFilename, "filename" );
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+ void benchmarkPreparing(std::string const& name) override;
+ void benchmarkStarting(BenchmarkInfo const&) override;
+ void benchmarkEnded(BenchmarkStats<> const&) override;
+ void benchmarkFailed(std::string const&) override;
+#endif // CATCH_CONFIG_ENABLE_BENCHMARKING
- cli["-r"]["--reporter"]
-// .placeholder( "name[:filename]" )
- .describe( "reporter to use (defaults to console)" )
- .bind( &addReporterName, "name" );
+ private:
+ Timer m_testCaseTimer;
+ XmlWriter m_xml;
+ int m_sectionDepth = 0;
+ };
- cli["-n"]["--name"]
- .describe( "suite name" )
- .bind( &ConfigData::name, "name" );
+} // end namespace Catch
- cli["-a"]["--abort"]
- .describe( "abort at first failure" )
- .bind( &abortAfterFirst );
+// end catch_reporter_xml.h
- cli["-x"]["--abortx"]
- .describe( "abort after x failures" )
- .bind( &abortAfterX, "no. failures" );
+// end catch_external_interfaces.h
+#endif
- cli["-w"]["--warn"]
- .describe( "enable warnings" )
- .bind( &addWarning, "warning name" );
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+// start catch_benchmark.hpp
-// - needs updating if reinstated
-// cli.into( &setVerbosity )
-// .describe( "level of verbosity (0=no output)" )
-// .shortOpt( "v")
-// .longOpt( "verbosity" )
-// .placeholder( "level" );
+ // Benchmark
- cli[_]
- .describe( "which test or tests to use" )
- .bind( &addTestOrTags, "test name, pattern or tags" );
+// start catch_chronometer.hpp
- cli["-d"]["--durations"]
- .describe( "show test durations" )
- .bind( &setShowDurations, "yes|no" );
+// User-facing chronometer
- cli["-f"]["--input-file"]
- .describe( "load test names to run from a file" )
- .bind( &loadTestNamesFromFile, "filename" );
- cli["-#"]["--filenames-as-tags"]
- .describe( "adds a tag for the filename" )
- .bind( &ConfigData::filenamesAsTags );
+// start catch_clock.hpp
- cli["-c"]["--section"]
- .describe( "specify section to run" )
- .bind( &addSectionToRun, "section name" );
+// Clocks
- // Less common commands which don't have a short form
- cli["--list-test-names-only"]
- .describe( "list all/matching test cases names only" )
- .bind( &ConfigData::listTestNamesOnly );
- cli["--list-reporters"]
- .describe( "list all reporters" )
- .bind( &ConfigData::listReporters );
+#include <chrono>
+#include <ratio>
- cli["--order"]
- .describe( "test case order (defaults to decl)" )
- .bind( &setOrder, "decl|lex|rand" );
+namespace Catch {
+ namespace Benchmark {
+ template <typename Clock>
+ using ClockDuration = typename Clock::duration;
+ template <typename Clock>
+ using FloatDuration = std::chrono::duration<double, typename Clock::period>;
- cli["--rng-seed"]
- .describe( "set a specific seed for random numbers" )
- .bind( &setRngSeed, "'time'|number" );
+ template <typename Clock>
+ using TimePoint = typename Clock::time_point;
- cli["--force-colour"]
- .describe( "force colourised output (deprecated)" )
- .bind( &forceColour );
-
- cli["--use-colour"]
- .describe( "should output be colourised" )
- .bind( &setUseColour, "yes|no" );
+ using default_clock = std::chrono::steady_clock;
- return cli;
- }
+ template <typename Clock>
+ struct now {
+ TimePoint<Clock> operator()() const {
+ return Clock::now();
+ }
+ };
-} // end namespace Catch
+ using fp_seconds = std::chrono::duration<double, std::ratio<1>>;
+ } // namespace Benchmark
+} // namespace Catch
-// #included from: internal/catch_list.hpp
-#define TWOBLUECUBES_CATCH_LIST_HPP_INCLUDED
+// end catch_clock.hpp
+// start catch_optimizer.hpp
-// #included from: catch_text.h
-#define TWOBLUECUBES_CATCH_TEXT_H_INCLUDED
+ // Hinting the optimizer
-#define TBC_TEXT_FORMAT_CONSOLE_WIDTH CATCH_CONFIG_CONSOLE_WIDTH
-#define CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE Catch
-// #included from: ../external/tbc_text_format.h
-// Only use header guard if we are not using an outer namespace
-#ifndef CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE
-# ifdef TWOBLUECUBES_TEXT_FORMAT_H_INCLUDED
-# ifndef TWOBLUECUBES_TEXT_FORMAT_H_ALREADY_INCLUDED
-# define TWOBLUECUBES_TEXT_FORMAT_H_ALREADY_INCLUDED
-# endif
-# else
-# define TWOBLUECUBES_TEXT_FORMAT_H_INCLUDED
-# endif
+#if defined(_MSC_VER)
+# include <atomic> // atomic_thread_fence
#endif
-#ifndef TWOBLUECUBES_TEXT_FORMAT_H_ALREADY_INCLUDED
-#include <string>
-#include <vector>
-#include <sstream>
-// Use optional outer namespace
-#ifdef CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE
-namespace CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE {
-#endif
+namespace Catch {
+ namespace Benchmark {
+#if defined(__GNUC__) || defined(__clang__)
+ template <typename T>
+ inline void keep_memory(T* p) {
+ asm volatile("" : : "g"(p) : "memory");
+ }
+ inline void keep_memory() {
+ asm volatile("" : : : "memory");
+ }
-namespace Tbc {
+ namespace Detail {
+ inline void optimizer_barrier() { keep_memory(); }
+ } // namespace Detail
+#elif defined(_MSC_VER)
+
+#pragma optimize("", off)
+ template <typename T>
+ inline void keep_memory(T* p) {
+ // thanks @milleniumbug
+ *reinterpret_cast<char volatile*>(p) = *reinterpret_cast<char const volatile*>(p);
+ }
+ // TODO equivalent keep_memory()
+#pragma optimize("", on)
+
+ namespace Detail {
+ inline void optimizer_barrier() {
+ std::atomic_thread_fence(std::memory_order_seq_cst);
+ }
+ } // namespace Detail
-#ifdef TBC_TEXT_FORMAT_CONSOLE_WIDTH
- const unsigned int consoleWidth = TBC_TEXT_FORMAT_CONSOLE_WIDTH;
-#else
- const unsigned int consoleWidth = 80;
#endif
- struct TextAttributes {
- TextAttributes()
- : initialIndent( std::string::npos ),
- indent( 0 ),
- width( consoleWidth-1 )
- {}
+ template <typename T>
+ inline void deoptimize_value(T&& x) {
+ keep_memory(&x);
+ }
- TextAttributes& setInitialIndent( std::size_t _value ) { initialIndent = _value; return *this; }
- TextAttributes& setIndent( std::size_t _value ) { indent = _value; return *this; }
- TextAttributes& setWidth( std::size_t _value ) { width = _value; return *this; }
+ template <typename Fn, typename... Args>
+ inline auto invoke_deoptimized(Fn&& fn, Args&&... args) -> typename std::enable_if<!std::is_same<void, decltype(fn(args...))>::value>::type {
+ deoptimize_value(std::forward<Fn>(fn) (std::forward<Args...>(args...)));
+ }
- std::size_t initialIndent; // indent of first line, or npos
- std::size_t indent; // indent of subsequent lines, or all if initialIndent is npos
- std::size_t width; // maximum width of text, including indent. Longer text will wrap
- };
+ template <typename Fn, typename... Args>
+ inline auto invoke_deoptimized(Fn&& fn, Args&&... args) -> typename std::enable_if<std::is_same<void, decltype(fn(args...))>::value>::type {
+ std::forward<Fn>(fn) (std::forward<Args...>(args...));
+ }
+ } // namespace Benchmark
+} // namespace Catch
- class Text {
- public:
- Text( std::string const& _str, TextAttributes const& _attr = TextAttributes() )
- : attr( _attr )
- {
- const std::string wrappableBeforeChars = "[({<\t";
- const std::string wrappableAfterChars = "])}>-,./|\\";
- const std::string wrappableInsteadOfChars = " \n\r";
- std::string indent = _attr.initialIndent != std::string::npos
- ? std::string( _attr.initialIndent, ' ' )
- : std::string( _attr.indent, ' ' );
-
- typedef std::string::const_iterator iterator;
- iterator it = _str.begin();
- const iterator strEnd = _str.end();
-
- while( it != strEnd ) {
-
- if( lines.size() >= 1000 ) {
- lines.push_back( "... message truncated due to excessive size" );
- return;
- }
+// end catch_optimizer.hpp
+// start catch_complete_invoke.hpp
- std::string suffix;
- std::size_t width = (std::min)( static_cast<size_t>( strEnd-it ), _attr.width-static_cast<size_t>( indent.size() ) );
- iterator itEnd = it+width;
- iterator itNext = _str.end();
-
- iterator itNewLine = std::find( it, itEnd, '\n' );
- if( itNewLine != itEnd )
- itEnd = itNewLine;
-
- if( itEnd != strEnd ) {
- bool foundWrapPoint = false;
- iterator findIt = itEnd;
- do {
- if( wrappableAfterChars.find( *findIt ) != std::string::npos && findIt != itEnd ) {
- itEnd = findIt+1;
- itNext = findIt+1;
- foundWrapPoint = true;
- }
- else if( findIt > it && wrappableBeforeChars.find( *findIt ) != std::string::npos ) {
- itEnd = findIt;
- itNext = findIt;
- foundWrapPoint = true;
- }
- else if( wrappableInsteadOfChars.find( *findIt ) != std::string::npos ) {
- itNext = findIt+1;
- itEnd = findIt;
- foundWrapPoint = true;
- }
- if( findIt == it )
- break;
- else
- --findIt;
- }
- while( !foundWrapPoint );
+// Invoke with a special case for void
- if( !foundWrapPoint ) {
- // No good wrap char, so we'll break mid word and add a hyphen
- --itEnd;
- itNext = itEnd;
- suffix = "-";
- }
- else {
- while( itEnd > it && wrappableInsteadOfChars.find( *(itEnd-1) ) != std::string::npos )
- --itEnd;
- }
+
+#include <type_traits>
+#include <utility>
+
+namespace Catch {
+ namespace Benchmark {
+ namespace Detail {
+ template <typename T>
+ struct CompleteType { using type = T; };
+ template <>
+ struct CompleteType<void> { struct type {}; };
+
+ template <typename T>
+ using CompleteType_t = typename CompleteType<T>::type;
+
+ template <typename Result>
+ struct CompleteInvoker {
+ template <typename Fun, typename... Args>
+ static Result invoke(Fun&& fun, Args&&... args) {
+ return std::forward<Fun>(fun)(std::forward<Args>(args)...);
}
- lines.push_back( indent + std::string( it, itEnd ) + suffix );
+ };
+ template <>
+ struct CompleteInvoker<void> {
+ template <typename Fun, typename... Args>
+ static CompleteType_t<void> invoke(Fun&& fun, Args&&... args) {
+ std::forward<Fun>(fun)(std::forward<Args>(args)...);
+ return {};
+ }
+ };
+ template <typename Sig>
+ using ResultOf_t = typename std::result_of<Sig>::type;
- if( indent.size() != _attr.indent )
- indent = std::string( _attr.indent, ' ' );
- it = itNext;
+ // invoke and not return void :(
+ template <typename Fun, typename... Args>
+ CompleteType_t<ResultOf_t<Fun(Args...)>> complete_invoke(Fun&& fun, Args&&... args) {
+ return CompleteInvoker<ResultOf_t<Fun(Args...)>>::invoke(std::forward<Fun>(fun), std::forward<Args>(args)...);
}
- }
- typedef std::vector<std::string>::const_iterator const_iterator;
+ const std::string benchmarkErrorMsg = "a benchmark failed to run successfully";
+ } // namespace Detail
- const_iterator begin() const { return lines.begin(); }
- const_iterator end() const { return lines.end(); }
- std::string const& last() const { return lines.back(); }
- std::size_t size() const { return lines.size(); }
- std::string const& operator[]( std::size_t _index ) const { return lines[_index]; }
- std::string toString() const {
- std::ostringstream oss;
- oss << *this;
- return oss.str();
+ template <typename Fun>
+ Detail::CompleteType_t<Detail::ResultOf_t<Fun()>> user_code(Fun&& fun) {
+ CATCH_TRY{
+ return Detail::complete_invoke(std::forward<Fun>(fun));
+ } CATCH_CATCH_ALL{
+ getResultCapture().benchmarkFailed(translateActiveException());
+ CATCH_RUNTIME_ERROR(Detail::benchmarkErrorMsg);
+ }
}
+ } // namespace Benchmark
+} // namespace Catch
- inline friend std::ostream& operator << ( std::ostream& _stream, Text const& _text ) {
- for( Text::const_iterator it = _text.begin(), itEnd = _text.end();
- it != itEnd; ++it ) {
- if( it != _text.begin() )
- _stream << "\n";
- _stream << *it;
+// end catch_complete_invoke.hpp
+namespace Catch {
+ namespace Benchmark {
+ namespace Detail {
+ struct ChronometerConcept {
+ virtual void start() = 0;
+ virtual void finish() = 0;
+ virtual ~ChronometerConcept() = default;
+ };
+ template <typename Clock>
+ struct ChronometerModel final : public ChronometerConcept {
+ void start() override { started = Clock::now(); }
+ void finish() override { finished = Clock::now(); }
+
+ ClockDuration<Clock> elapsed() const { return finished - started; }
+
+ TimePoint<Clock> started;
+ TimePoint<Clock> finished;
+ };
+ } // namespace Detail
+
+ struct Chronometer {
+ public:
+ template <typename Fun>
+ void measure(Fun&& fun) { measure(std::forward<Fun>(fun), is_callable<Fun(int)>()); }
+
+ int runs() const { return k; }
+
+ Chronometer(Detail::ChronometerConcept& meter, int k)
+ : impl(&meter)
+ , k(k) {}
+
+ private:
+ template <typename Fun>
+ void measure(Fun&& fun, std::false_type) {
+ measure([&fun](int) { return fun(); }, std::true_type());
}
- return _stream;
- }
- private:
- std::string str;
- TextAttributes attr;
- std::vector<std::string> lines;
- };
+ template <typename Fun>
+ void measure(Fun&& fun, std::true_type) {
+ Detail::optimizer_barrier();
+ impl->start();
+ for (int i = 0; i < k; ++i) invoke_deoptimized(fun, i);
+ impl->finish();
+ Detail::optimizer_barrier();
+ }
-} // end namespace Tbc
+ Detail::ChronometerConcept* impl;
+ int k;
+ };
+ } // namespace Benchmark
+} // namespace Catch
-#ifdef CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE
-} // end outer namespace
-#endif
+// end catch_chronometer.hpp
+// start catch_environment.hpp
+
+// Environment information
-#endif // TWOBLUECUBES_TEXT_FORMAT_H_ALREADY_INCLUDED
-#undef CLICHE_TBC_TEXT_FORMAT_OUTER_NAMESPACE
namespace Catch {
- using Tbc::Text;
- using Tbc::TextAttributes;
-}
+ namespace Benchmark {
+ template <typename Duration>
+ struct EnvironmentEstimate {
+ Duration mean;
+ OutlierClassification outliers;
+
+ template <typename Duration2>
+ operator EnvironmentEstimate<Duration2>() const {
+ return { mean, outliers };
+ }
+ };
+ template <typename Clock>
+ struct Environment {
+ using clock_type = Clock;
+ EnvironmentEstimate<FloatDuration<Clock>> clock_resolution;
+ EnvironmentEstimate<FloatDuration<Clock>> clock_cost;
+ };
+ } // namespace Benchmark
+} // namespace Catch
+
+// end catch_environment.hpp
+// start catch_execution_plan.hpp
+
+ // Execution plan
+
-// #included from: catch_console_colour.hpp
-#define TWOBLUECUBES_CATCH_CONSOLE_COLOUR_HPP_INCLUDED
+// start catch_benchmark_function.hpp
+
+ // Dumb std::function implementation for consistent call overhead
+
+
+#include <cassert>
+#include <type_traits>
+#include <utility>
+#include <memory>
namespace Catch {
+ namespace Benchmark {
+ namespace Detail {
+ template <typename T>
+ using Decay = typename std::decay<T>::type;
+ template <typename T, typename U>
+ struct is_related
+ : std::is_same<Decay<T>, Decay<U>> {};
+
+ /// We need to reinvent std::function because every piece of code that might add overhead
+ /// in a measurement context needs to have consistent performance characteristics so that we
+ /// can account for it in the measurement.
+ /// Implementations of std::function with optimizations that aren't always applicable, like
+ /// small buffer optimizations, are not uncommon.
+ /// This is effectively an implementation of std::function without any such optimizations;
+ /// it may be slow, but it is consistently slow.
+ struct BenchmarkFunction {
+ private:
+ struct callable {
+ virtual void call(Chronometer meter) const = 0;
+ virtual callable* clone() const = 0;
+ virtual ~callable() = default;
+ };
+ template <typename Fun>
+ struct model : public callable {
+ model(Fun&& fun) : fun(std::move(fun)) {}
+ model(Fun const& fun) : fun(fun) {}
+
+ model<Fun>* clone() const override { return new model<Fun>(*this); }
+
+ void call(Chronometer meter) const override {
+ call(meter, is_callable<Fun(Chronometer)>());
+ }
+ void call(Chronometer meter, std::true_type) const {
+ fun(meter);
+ }
+ void call(Chronometer meter, std::false_type) const {
+ meter.measure(fun);
+ }
- struct Colour {
- enum Code {
- None = 0,
+ Fun fun;
+ };
- White,
- Red,
- Green,
- Blue,
- Cyan,
- Yellow,
- Grey,
+ struct do_nothing { void operator()() const {} };
- Bright = 0x10,
+ template <typename T>
+ BenchmarkFunction(model<T>* c) : f(c) {}
- BrightRed = Bright | Red,
- BrightGreen = Bright | Green,
- LightGrey = Bright | Grey,
- BrightWhite = Bright | White,
+ public:
+ BenchmarkFunction()
+ : f(new model<do_nothing>{ {} }) {}
- // By intention
- FileName = LightGrey,
- Warning = Yellow,
- ResultError = BrightRed,
- ResultSuccess = BrightGreen,
- ResultExpectedFailure = Warning,
+ template <typename Fun,
+ typename std::enable_if<!is_related<Fun, BenchmarkFunction>::value, int>::type = 0>
+ BenchmarkFunction(Fun&& fun)
+ : f(new model<typename std::decay<Fun>::type>(std::forward<Fun>(fun))) {}
- Error = BrightRed,
- Success = Green,
+ BenchmarkFunction(BenchmarkFunction&& that)
+ : f(std::move(that.f)) {}
- OriginalExpression = Cyan,
- ReconstructedExpression = Yellow,
+ BenchmarkFunction(BenchmarkFunction const& that)
+ : f(that.f->clone()) {}
- SecondaryText = LightGrey,
- Headers = White
- };
+ BenchmarkFunction& operator=(BenchmarkFunction&& that) {
+ f = std::move(that.f);
+ return *this;
+ }
- // Use constructed object for RAII guard
- Colour( Code _colourCode );
- Colour( Colour const& other );
- ~Colour();
+ BenchmarkFunction& operator=(BenchmarkFunction const& that) {
+ f.reset(that.f->clone());
+ return *this;
+ }
- // Use static method for one-shot changes
- static void use( Code _colourCode );
+ void operator()(Chronometer meter) const { f->call(meter); }
- private:
- bool m_moved;
- };
+ private:
+ std::unique_ptr<callable> f;
+ };
+ } // namespace Detail
+ } // namespace Benchmark
+} // namespace Catch
- inline std::ostream& operator << ( std::ostream& os, Colour const& ) { return os; }
+// end catch_benchmark_function.hpp
+// start catch_repeat.hpp
-} // end namespace Catch
+// repeat algorithm
-// #included from: catch_interfaces_reporter.h
-#define TWOBLUECUBES_CATCH_INTERFACES_REPORTER_H_INCLUDED
-#include <string>
-#include <ostream>
-#include <map>
+#include <type_traits>
+#include <utility>
-namespace Catch
-{
- struct ReporterConfig {
- explicit ReporterConfig( Ptr<IConfig const> const& _fullConfig )
- : m_stream( &_fullConfig->stream() ), m_fullConfig( _fullConfig ) {}
+namespace Catch {
+ namespace Benchmark {
+ namespace Detail {
+ template <typename Fun>
+ struct repeater {
+ void operator()(int k) const {
+ for (int i = 0; i < k; ++i) {
+ fun();
+ }
+ }
+ Fun fun;
+ };
+ template <typename Fun>
+ repeater<typename std::decay<Fun>::type> repeat(Fun&& fun) {
+ return { std::forward<Fun>(fun) };
+ }
+ } // namespace Detail
+ } // namespace Benchmark
+} // namespace Catch
- ReporterConfig( Ptr<IConfig const> const& _fullConfig, std::ostream& _stream )
- : m_stream( &_stream ), m_fullConfig( _fullConfig ) {}
+// end catch_repeat.hpp
+// start catch_run_for_at_least.hpp
- std::ostream& stream() const { return *m_stream; }
- Ptr<IConfig const> fullConfig() const { return m_fullConfig; }
+// Run a function for a minimum amount of time
- private:
- std::ostream* m_stream;
- Ptr<IConfig const> m_fullConfig;
- };
- struct ReporterPreferences {
- ReporterPreferences()
- : shouldRedirectStdOut( false )
- {}
+// start catch_measure.hpp
- bool shouldRedirectStdOut;
- };
+// Measure
- template<typename T>
- struct LazyStat : Option<T> {
- LazyStat() : used( false ) {}
- LazyStat& operator=( T const& _value ) {
- Option<T>::operator=( _value );
- used = false;
- return *this;
- }
- void reset() {
- Option<T>::reset();
- used = false;
- }
- bool used;
- };
- struct TestRunInfo {
- TestRunInfo( std::string const& _name ) : name( _name ) {}
- std::string name;
- };
- struct GroupInfo {
- GroupInfo( std::string const& _name,
- std::size_t _groupIndex,
- std::size_t _groupsCount )
- : name( _name ),
- groupIndex( _groupIndex ),
- groupsCounts( _groupsCount )
- {}
+// start catch_timing.hpp
- std::string name;
- std::size_t groupIndex;
- std::size_t groupsCounts;
- };
+// Timing
- struct AssertionStats {
- AssertionStats( AssertionResult const& _assertionResult,
- std::vector<MessageInfo> const& _infoMessages,
- Totals const& _totals )
- : assertionResult( _assertionResult ),
- infoMessages( _infoMessages ),
- totals( _totals )
- {
- if( assertionResult.hasMessage() ) {
- // Copy message into messages list.
- // !TBD This should have been done earlier, somewhere
- MessageBuilder builder( assertionResult.getTestMacroName(), assertionResult.getSourceInfo(), assertionResult.getResultType() );
- builder << assertionResult.getMessage();
- builder.m_info.message = builder.m_stream.str();
- infoMessages.push_back( builder.m_info );
- }
- }
- virtual ~AssertionStats();
+#include <tuple>
+#include <type_traits>
-# ifdef CATCH_CONFIG_CPP11_GENERATED_METHODS
- AssertionStats( AssertionStats const& ) = default;
- AssertionStats( AssertionStats && ) = default;
- AssertionStats& operator = ( AssertionStats const& ) = default;
- AssertionStats& operator = ( AssertionStats && ) = default;
-# endif
+namespace Catch {
+ namespace Benchmark {
+ template <typename Duration, typename Result>
+ struct Timing {
+ Duration elapsed;
+ Result result;
+ int iterations;
+ };
+ template <typename Clock, typename Sig>
+ using TimingOf = Timing<ClockDuration<Clock>, Detail::CompleteType_t<Detail::ResultOf_t<Sig>>>;
+ } // namespace Benchmark
+} // namespace Catch
- AssertionResult assertionResult;
- std::vector<MessageInfo> infoMessages;
- Totals totals;
- };
+// end catch_timing.hpp
+#include <utility>
- struct SectionStats {
- SectionStats( SectionInfo const& _sectionInfo,
- Counts const& _assertions,
- double _durationInSeconds,
- bool _missingAssertions )
- : sectionInfo( _sectionInfo ),
- assertions( _assertions ),
- durationInSeconds( _durationInSeconds ),
- missingAssertions( _missingAssertions )
- {}
- virtual ~SectionStats();
-# ifdef CATCH_CONFIG_CPP11_GENERATED_METHODS
- SectionStats( SectionStats const& ) = default;
- SectionStats( SectionStats && ) = default;
- SectionStats& operator = ( SectionStats const& ) = default;
- SectionStats& operator = ( SectionStats && ) = default;
-# endif
+namespace Catch {
+ namespace Benchmark {
+ namespace Detail {
+ template <typename Clock, typename Fun, typename... Args>
+ TimingOf<Clock, Fun(Args...)> measure(Fun&& fun, Args&&... args) {
+ auto start = Clock::now();
+ auto&& r = Detail::complete_invoke(fun, std::forward<Args>(args)...);
+ auto end = Clock::now();
+ auto delta = end - start;
+ return { delta, std::forward<decltype(r)>(r), 1 };
+ }
+ } // namespace Detail
+ } // namespace Benchmark
+} // namespace Catch
- SectionInfo sectionInfo;
- Counts assertions;
- double durationInSeconds;
- bool missingAssertions;
- };
+// end catch_measure.hpp
+#include <utility>
+#include <type_traits>
- struct TestCaseStats {
- TestCaseStats( TestCaseInfo const& _testInfo,
- Totals const& _totals,
- std::string const& _stdOut,
- std::string const& _stdErr,
- bool _aborting )
- : testInfo( _testInfo ),
- totals( _totals ),
- stdOut( _stdOut ),
- stdErr( _stdErr ),
- aborting( _aborting )
- {}
- virtual ~TestCaseStats();
+namespace Catch {
+ namespace Benchmark {
+ namespace Detail {
+ template <typename Clock, typename Fun>
+ TimingOf<Clock, Fun(int)> measure_one(Fun&& fun, int iters, std::false_type) {
+ return Detail::measure<Clock>(fun, iters);
+ }
+ template <typename Clock, typename Fun>
+ TimingOf<Clock, Fun(Chronometer)> measure_one(Fun&& fun, int iters, std::true_type) {
+ Detail::ChronometerModel<Clock> meter;
+ auto&& result = Detail::complete_invoke(fun, Chronometer(meter, iters));
-# ifdef CATCH_CONFIG_CPP11_GENERATED_METHODS
- TestCaseStats( TestCaseStats const& ) = default;
- TestCaseStats( TestCaseStats && ) = default;
- TestCaseStats& operator = ( TestCaseStats const& ) = default;
- TestCaseStats& operator = ( TestCaseStats && ) = default;
-# endif
+ return { meter.elapsed(), std::move(result), iters };
+ }
- TestCaseInfo testInfo;
- Totals totals;
- std::string stdOut;
- std::string stdErr;
- bool aborting;
- };
+ template <typename Clock, typename Fun>
+ using run_for_at_least_argument_t = typename std::conditional<is_callable<Fun(Chronometer)>::value, Chronometer, int>::type;
- struct TestGroupStats {
- TestGroupStats( GroupInfo const& _groupInfo,
- Totals const& _totals,
- bool _aborting )
- : groupInfo( _groupInfo ),
- totals( _totals ),
- aborting( _aborting )
- {}
- TestGroupStats( GroupInfo const& _groupInfo )
- : groupInfo( _groupInfo ),
- aborting( false )
- {}
- virtual ~TestGroupStats();
+ struct optimized_away_error : std::exception {
+ const char* what() const noexcept override {
+ return "could not measure benchmark, maybe it was optimized away";
+ }
+ };
-# ifdef CATCH_CONFIG_CPP11_GENERATED_METHODS
- TestGroupStats( TestGroupStats const& ) = default;
- TestGroupStats( TestGroupStats && ) = default;
- TestGroupStats& operator = ( TestGroupStats const& ) = default;
- TestGroupStats& operator = ( TestGroupStats && ) = default;
-# endif
+ template <typename Clock, typename Fun>
+ TimingOf<Clock, Fun(run_for_at_least_argument_t<Clock, Fun>)> run_for_at_least(ClockDuration<Clock> how_long, int seed, Fun&& fun) {
+ auto iters = seed;
+ while (iters < (1 << 30)) {
+ auto&& Timing = measure_one<Clock>(fun, iters, is_callable<Fun(Chronometer)>());
- GroupInfo groupInfo;
- Totals totals;
- bool aborting;
- };
+ if (Timing.elapsed >= how_long) {
+ return { Timing.elapsed, std::move(Timing.result), iters };
+ }
+ iters *= 2;
+ }
+ throw optimized_away_error{};
+ }
+ } // namespace Detail
+ } // namespace Benchmark
+} // namespace Catch
- struct TestRunStats {
- TestRunStats( TestRunInfo const& _runInfo,
- Totals const& _totals,
- bool _aborting )
- : runInfo( _runInfo ),
- totals( _totals ),
- aborting( _aborting )
- {}
- virtual ~TestRunStats();
+// end catch_run_for_at_least.hpp
+#include <algorithm>
-# ifndef CATCH_CONFIG_CPP11_GENERATED_METHODS
- TestRunStats( TestRunStats const& _other )
- : runInfo( _other.runInfo ),
- totals( _other.totals ),
- aborting( _other.aborting )
- {}
-# else
- TestRunStats( TestRunStats const& ) = default;
- TestRunStats( TestRunStats && ) = default;
- TestRunStats& operator = ( TestRunStats const& ) = default;
- TestRunStats& operator = ( TestRunStats && ) = default;
-# endif
+namespace Catch {
+ namespace Benchmark {
+ template <typename Duration>
+ struct ExecutionPlan {
+ int iterations_per_sample;
+ Duration estimated_duration;
+ Detail::BenchmarkFunction benchmark;
+ Duration warmup_time;
+ int warmup_iterations;
+
+ template <typename Duration2>
+ operator ExecutionPlan<Duration2>() const {
+ return { iterations_per_sample, estimated_duration, benchmark, warmup_time, warmup_iterations };
+ }
- TestRunInfo runInfo;
- Totals totals;
- bool aborting;
- };
+ template <typename Clock>
+ std::vector<FloatDuration<Clock>> run(const IConfig &cfg, Environment<FloatDuration<Clock>> env) const {
+ // warmup a bit
+ Detail::run_for_at_least<Clock>(std::chrono::duration_cast<ClockDuration<Clock>>(warmup_time), warmup_iterations, Detail::repeat(now<Clock>{}));
+
+ std::vector<FloatDuration<Clock>> times;
+ times.reserve(cfg.benchmarkSamples());
+ std::generate_n(std::back_inserter(times), cfg.benchmarkSamples(), [this, env] {
+ Detail::ChronometerModel<Clock> model;
+ this->benchmark(Chronometer(model, iterations_per_sample));
+ auto sample_time = model.elapsed() - env.clock_cost.mean;
+ if (sample_time < FloatDuration<Clock>::zero()) sample_time = FloatDuration<Clock>::zero();
+ return sample_time / iterations_per_sample;
+ });
+ return times;
+ }
+ };
+ } // namespace Benchmark
+} // namespace Catch
- class MultipleReporters;
+// end catch_execution_plan.hpp
+// start catch_estimate_clock.hpp
- struct IStreamingReporter : IShared {
- virtual ~IStreamingReporter();
+ // Environment measurement
- // Implementing class must also provide the following static method:
- // static std::string getDescription();
- virtual ReporterPreferences getPreferences() const = 0;
+// start catch_stats.hpp
- virtual void noMatchingTestCases( std::string const& spec ) = 0;
+// Statistical analysis tools
- virtual void testRunStarting( TestRunInfo const& testRunInfo ) = 0;
- virtual void testGroupStarting( GroupInfo const& groupInfo ) = 0;
- virtual void testCaseStarting( TestCaseInfo const& testInfo ) = 0;
- virtual void sectionStarting( SectionInfo const& sectionInfo ) = 0;
+#include <algorithm>
+#include <functional>
+#include <vector>
+#include <numeric>
+#include <tuple>
+#include <cmath>
+#include <utility>
+#include <cstddef>
- virtual void assertionStarting( AssertionInfo const& assertionInfo ) = 0;
+namespace Catch {
+ namespace Benchmark {
+ namespace Detail {
+ using sample = std::vector<double>;
+
+ double weighted_average_quantile(int k, int q, std::vector<double>::iterator first, std::vector<double>::iterator last);
+
+ template <typename Iterator>
+ OutlierClassification classify_outliers(Iterator first, Iterator last) {
+ std::vector<double> copy(first, last);
+
+ auto q1 = weighted_average_quantile(1, 4, copy.begin(), copy.end());
+ auto q3 = weighted_average_quantile(3, 4, copy.begin(), copy.end());
+ auto iqr = q3 - q1;
+ auto los = q1 - (iqr * 3.);
+ auto lom = q1 - (iqr * 1.5);
+ auto him = q3 + (iqr * 1.5);
+ auto his = q3 + (iqr * 3.);
+
+ OutlierClassification o;
+ for (; first != last; ++first) {
+ auto&& t = *first;
+ if (t < los) ++o.low_severe;
+ else if (t < lom) ++o.low_mild;
+ else if (t > his) ++o.high_severe;
+ else if (t > him) ++o.high_mild;
+ ++o.samples_seen;
+ }
+ return o;
+ }
- // The return value indicates if the messages buffer should be cleared:
- virtual bool assertionEnded( AssertionStats const& assertionStats ) = 0;
+ template <typename Iterator>
+ double mean(Iterator first, Iterator last) {
+ auto count = last - first;
+ double sum = std::accumulate(first, last, 0.);
+ return sum / count;
+ }
- virtual void sectionEnded( SectionStats const& sectionStats ) = 0;
- virtual void testCaseEnded( TestCaseStats const& testCaseStats ) = 0;
- virtual void testGroupEnded( TestGroupStats const& testGroupStats ) = 0;
- virtual void testRunEnded( TestRunStats const& testRunStats ) = 0;
+ template <typename URng, typename Iterator, typename Estimator>
+ sample resample(URng& rng, int resamples, Iterator first, Iterator last, Estimator& estimator) {
+ auto n = last - first;
+ std::uniform_int_distribution<decltype(n)> dist(0, n - 1);
+
+ sample out;
+ out.reserve(resamples);
+ std::generate_n(std::back_inserter(out), resamples, [n, first, &estimator, &dist, &rng] {
+ std::vector<double> resampled;
+ resampled.reserve(n);
+ std::generate_n(std::back_inserter(resampled), n, [first, &dist, &rng] { return first[dist(rng)]; });
+ return estimator(resampled.begin(), resampled.end());
+ });
+ std::sort(out.begin(), out.end());
+ return out;
+ }
- virtual void skipTest( TestCaseInfo const& testInfo ) = 0;
+ template <typename Estimator, typename Iterator>
+ sample jackknife(Estimator&& estimator, Iterator first, Iterator last) {
+ auto n = last - first;
+ auto second = std::next(first);
+ sample results;
+ results.reserve(n);
- virtual MultipleReporters* tryAsMulti() { return CATCH_NULL; }
- };
+ for (auto it = first; it != last; ++it) {
+ std::iter_swap(it, first);
+ results.push_back(estimator(second, last));
+ }
- struct IReporterFactory : IShared {
- virtual ~IReporterFactory();
- virtual IStreamingReporter* create( ReporterConfig const& config ) const = 0;
- virtual std::string getDescription() const = 0;
- };
+ return results;
+ }
- struct IReporterRegistry {
- typedef std::map<std::string, Ptr<IReporterFactory> > FactoryMap;
- typedef std::vector<Ptr<IReporterFactory> > Listeners;
+ inline double normal_cdf(double x) {
+ return std::erfc(-x / std::sqrt(2.0)) / 2.0;
+ }
- virtual ~IReporterRegistry();
- virtual IStreamingReporter* create( std::string const& name, Ptr<IConfig const> const& config ) const = 0;
- virtual FactoryMap const& getFactories() const = 0;
- virtual Listeners const& getListeners() const = 0;
- };
+ double erfc_inv(double x);
+
+ double normal_quantile(double p);
+
+ template <typename Iterator, typename Estimator>
+ Estimate<double> bootstrap(double confidence_level, Iterator first, Iterator last, sample const& resample, Estimator&& estimator) {
+ auto n_samples = last - first;
+
+ double point = estimator(first, last);
+ // Degenerate case with a single sample
+ if (n_samples == 1) return { point, point, point, confidence_level };
+
+ sample jack = jackknife(estimator, first, last);
+ double jack_mean = mean(jack.begin(), jack.end());
+ double sum_squares, sum_cubes;
+ std::tie(sum_squares, sum_cubes) = std::accumulate(jack.begin(), jack.end(), std::make_pair(0., 0.), [jack_mean](std::pair<double, double> sqcb, double x) -> std::pair<double, double> {
+ auto d = jack_mean - x;
+ auto d2 = d * d;
+ auto d3 = d2 * d;
+ return { sqcb.first + d2, sqcb.second + d3 };
+ });
+
+ double accel = sum_cubes / (6 * std::pow(sum_squares, 1.5));
+ int n = static_cast<int>(resample.size());
+ double prob_n = std::count_if(resample.begin(), resample.end(), [point](double x) { return x < point; }) / (double)n;
+ // degenerate case with uniform samples
+ if (prob_n == 0) return { point, point, point, confidence_level };
+
+ double bias = normal_quantile(prob_n);
+ double z1 = normal_quantile((1. - confidence_level) / 2.);
+
+ auto cumn = [n](double x) -> int {
+ return std::lround(normal_cdf(x) * n); };
+ auto a = [bias, accel](double b) { return bias + b / (1. - accel * b); };
+ double b1 = bias + z1;
+ double b2 = bias - z1;
+ double a1 = a(b1);
+ double a2 = a(b2);
+ auto lo = std::max(cumn(a1), 0);
+ auto hi = std::min(cumn(a2), n - 1);
+
+ return { point, resample[lo], resample[hi], confidence_level };
+ }
- Ptr<IStreamingReporter> addReporter( Ptr<IStreamingReporter> const& existingReporter, Ptr<IStreamingReporter> const& additionalReporter );
+ double outlier_variance(Estimate<double> mean, Estimate<double> stddev, int n);
-}
+ struct bootstrap_analysis {
+ Estimate<double> mean;
+ Estimate<double> standard_deviation;
+ double outlier_variance;
+ };
-#include <limits>
+ bootstrap_analysis analyse_samples(double confidence_level, int n_resamples, std::vector<double>::iterator first, std::vector<double>::iterator last);
+ } // namespace Detail
+ } // namespace Benchmark
+} // namespace Catch
+
+// end catch_stats.hpp
#include <algorithm>
+#include <iterator>
+#include <tuple>
+#include <vector>
+#include <cmath>
namespace Catch {
+ namespace Benchmark {
+ namespace Detail {
+ template <typename Clock>
+ std::vector<double> resolution(int k) {
+ std::vector<TimePoint<Clock>> times;
+ times.reserve(k + 1);
+ std::generate_n(std::back_inserter(times), k + 1, now<Clock>{});
+
+ std::vector<double> deltas;
+ deltas.reserve(k);
+ std::transform(std::next(times.begin()), times.end(), times.begin(),
+ std::back_inserter(deltas),
+ [](TimePoint<Clock> a, TimePoint<Clock> b) { return static_cast<double>((a - b).count()); });
+
+ return deltas;
+ }
- inline std::size_t listTests( Config const& config ) {
+ const auto warmup_iterations = 10000;
+ const auto warmup_time = std::chrono::milliseconds(100);
+ const auto minimum_ticks = 1000;
+ const auto warmup_seed = 10000;
+ const auto clock_resolution_estimation_time = std::chrono::milliseconds(500);
+ const auto clock_cost_estimation_time_limit = std::chrono::seconds(1);
+ const auto clock_cost_estimation_tick_limit = 100000;
+ const auto clock_cost_estimation_time = std::chrono::milliseconds(10);
+ const auto clock_cost_estimation_iterations = 10000;
+
+ template <typename Clock>
+ int warmup() {
+ return run_for_at_least<Clock>(std::chrono::duration_cast<ClockDuration<Clock>>(warmup_time), warmup_seed, &resolution<Clock>)
+ .iterations;
+ }
+ template <typename Clock>
+ EnvironmentEstimate<FloatDuration<Clock>> estimate_clock_resolution(int iterations) {
+ auto r = run_for_at_least<Clock>(std::chrono::duration_cast<ClockDuration<Clock>>(clock_resolution_estimation_time), iterations, &resolution<Clock>)
+ .result;
+ return {
+ FloatDuration<Clock>(mean(r.begin(), r.end())),
+ classify_outliers(r.begin(), r.end()),
+ };
+ }
+ template <typename Clock>
+ EnvironmentEstimate<FloatDuration<Clock>> estimate_clock_cost(FloatDuration<Clock> resolution) {
+ auto time_limit = std::min(resolution * clock_cost_estimation_tick_limit, FloatDuration<Clock>(clock_cost_estimation_time_limit));
+ auto time_clock = [](int k) {
+ return Detail::measure<Clock>([k] {
+ for (int i = 0; i < k; ++i) {
+ volatile auto ignored = Clock::now();
+ (void)ignored;
+ }
+ }).elapsed;
+ };
+ time_clock(1);
+ int iters = clock_cost_estimation_iterations;
+ auto&& r = run_for_at_least<Clock>(std::chrono::duration_cast<ClockDuration<Clock>>(clock_cost_estimation_time), iters, time_clock);
+ std::vector<double> times;
+ int nsamples = static_cast<int>(std::ceil(time_limit / r.elapsed));
+ times.reserve(nsamples);
+ std::generate_n(std::back_inserter(times), nsamples, [time_clock, &r] {
+ return static_cast<double>((time_clock(r.iterations) / r.iterations).count());
+ });
+ return {
+ FloatDuration<Clock>(mean(times.begin(), times.end())),
+ classify_outliers(times.begin(), times.end()),
+ };
+ }
- TestSpec testSpec = config.testSpec();
- if( config.testSpec().hasFilters() )
- Catch::cout() << "Matching test cases:\n";
- else {
- Catch::cout() << "All available test cases:\n";
- testSpec = TestSpecParser( ITagAliasRegistry::get() ).parse( "*" ).testSpec();
- }
+ template <typename Clock>
+ Environment<FloatDuration<Clock>> measure_environment() {
+ static Environment<FloatDuration<Clock>>* env = nullptr;
+ if (env) {
+ return *env;
+ }
- std::size_t matchedTests = 0;
- TextAttributes nameAttr, tagsAttr;
- nameAttr.setInitialIndent( 2 ).setIndent( 4 );
- tagsAttr.setIndent( 6 );
+ auto iters = Detail::warmup<Clock>();
+ auto resolution = Detail::estimate_clock_resolution<Clock>(iters);
+ auto cost = Detail::estimate_clock_cost<Clock>(resolution.mean);
- std::vector<TestCase> matchedTestCases = filterTests( getAllTestCasesSorted( config ), testSpec, config );
- for( std::vector<TestCase>::const_iterator it = matchedTestCases.begin(), itEnd = matchedTestCases.end();
- it != itEnd;
- ++it ) {
- matchedTests++;
- TestCaseInfo const& testCaseInfo = it->getTestCaseInfo();
- Colour::Code colour = testCaseInfo.isHidden()
- ? Colour::SecondaryText
- : Colour::None;
- Colour colourGuard( colour );
+ env = new Environment<FloatDuration<Clock>>{ resolution, cost };
+ return *env;
+ }
+ } // namespace Detail
+ } // namespace Benchmark
+} // namespace Catch
- Catch::cout() << Text( testCaseInfo.name, nameAttr ) << std::endl;
- if( !testCaseInfo.tags.empty() )
- Catch::cout() << Text( testCaseInfo.tagsAsString, tagsAttr ) << std::endl;
- }
+// end catch_estimate_clock.hpp
+// start catch_analyse.hpp
- if( !config.testSpec().hasFilters() )
- Catch::cout() << pluralise( matchedTests, "test case" ) << '\n' << std::endl;
- else
- Catch::cout() << pluralise( matchedTests, "matching test case" ) << '\n' << std::endl;
- return matchedTests;
- }
+ // Run and analyse one benchmark
- inline std::size_t listTestsNamesOnly( Config const& config ) {
- TestSpec testSpec = config.testSpec();
- if( !config.testSpec().hasFilters() )
- testSpec = TestSpecParser( ITagAliasRegistry::get() ).parse( "*" ).testSpec();
- std::size_t matchedTests = 0;
- std::vector<TestCase> matchedTestCases = filterTests( getAllTestCasesSorted( config ), testSpec, config );
- for( std::vector<TestCase>::const_iterator it = matchedTestCases.begin(), itEnd = matchedTestCases.end();
- it != itEnd;
- ++it ) {
- matchedTests++;
- TestCaseInfo const& testCaseInfo = it->getTestCaseInfo();
- if( startsWith( testCaseInfo.name, '#' ) )
- Catch::cout() << '"' << testCaseInfo.name << '"' << std::endl;
- else
- Catch::cout() << testCaseInfo.name << std::endl;
- }
- return matchedTests;
- }
- struct TagInfo {
- TagInfo() : count ( 0 ) {}
- void add( std::string const& spelling ) {
- ++count;
- spellings.insert( spelling );
- }
- std::string all() const {
- std::string out;
- for( std::set<std::string>::const_iterator it = spellings.begin(), itEnd = spellings.end();
- it != itEnd;
- ++it )
- out += "[" + *it + "]";
- return out;
- }
- std::set<std::string> spellings;
- std::size_t count;
- };
+// start catch_sample_analysis.hpp
- inline std::size_t listTags( Config const& config ) {
- TestSpec testSpec = config.testSpec();
- if( config.testSpec().hasFilters() )
- Catch::cout() << "Tags for matching test cases:\n";
- else {
- Catch::cout() << "All available tags:\n";
- testSpec = TestSpecParser( ITagAliasRegistry::get() ).parse( "*" ).testSpec();
- }
+// Benchmark results
- std::map<std::string, TagInfo> tagCounts;
- std::vector<TestCase> matchedTestCases = filterTests( getAllTestCasesSorted( config ), testSpec, config );
- for( std::vector<TestCase>::const_iterator it = matchedTestCases.begin(), itEnd = matchedTestCases.end();
- it != itEnd;
- ++it ) {
- for( std::set<std::string>::const_iterator tagIt = it->getTestCaseInfo().tags.begin(),
- tagItEnd = it->getTestCaseInfo().tags.end();
- tagIt != tagItEnd;
- ++tagIt ) {
- std::string tagName = *tagIt;
- std::string lcaseTagName = toLower( tagName );
- std::map<std::string, TagInfo>::iterator countIt = tagCounts.find( lcaseTagName );
- if( countIt == tagCounts.end() )
- countIt = tagCounts.insert( std::make_pair( lcaseTagName, TagInfo() ) ).first;
- countIt->second.add( tagName );
+#include <algorithm>
+#include <vector>
+#include <string>
+#include <iterator>
+
+namespace Catch {
+ namespace Benchmark {
+ template <typename Duration>
+ struct SampleAnalysis {
+ std::vector<Duration> samples;
+ Estimate<Duration> mean;
+ Estimate<Duration> standard_deviation;
+ OutlierClassification outliers;
+ double outlier_variance;
+
+ template <typename Duration2>
+ operator SampleAnalysis<Duration2>() const {
+ std::vector<Duration2> samples2;
+ samples2.reserve(samples.size());
+ std::transform(samples.begin(), samples.end(), std::back_inserter(samples2), [](Duration d) { return Duration2(d); });
+ return {
+ std::move(samples2),
+ mean,
+ standard_deviation,
+ outliers,
+ outlier_variance,
+ };
}
- }
+ };
+ } // namespace Benchmark
+} // namespace Catch
- for( std::map<std::string, TagInfo>::const_iterator countIt = tagCounts.begin(),
- countItEnd = tagCounts.end();
- countIt != countItEnd;
- ++countIt ) {
- std::ostringstream oss;
- oss << " " << std::setw(2) << countIt->second.count << " ";
- Text wrapper( countIt->second.all(), TextAttributes()
- .setInitialIndent( 0 )
- .setIndent( oss.str().size() )
- .setWidth( CATCH_CONFIG_CONSOLE_WIDTH-10 ) );
- Catch::cout() << oss.str() << wrapper << '\n';
- }
- Catch::cout() << pluralise( tagCounts.size(), "tag" ) << '\n' << std::endl;
- return tagCounts.size();
- }
+// end catch_sample_analysis.hpp
+#include <algorithm>
+#include <iterator>
+#include <vector>
- inline std::size_t listReporters( Config const& /*config*/ ) {
- Catch::cout() << "Available reporters:\n";
- IReporterRegistry::FactoryMap const& factories = getRegistryHub().getReporterRegistry().getFactories();
- IReporterRegistry::FactoryMap::const_iterator itBegin = factories.begin(), itEnd = factories.end(), it;
- std::size_t maxNameLen = 0;
- for(it = itBegin; it != itEnd; ++it )
- maxNameLen = (std::max)( maxNameLen, it->first.size() );
-
- for(it = itBegin; it != itEnd; ++it ) {
- Text wrapper( it->second->getDescription(), TextAttributes()
- .setInitialIndent( 0 )
- .setIndent( 7+maxNameLen )
- .setWidth( CATCH_CONFIG_CONSOLE_WIDTH - maxNameLen-8 ) );
- Catch::cout() << " "
- << it->first
- << ':'
- << std::string( maxNameLen - it->first.size() + 2, ' ' )
- << wrapper << '\n';
- }
- Catch::cout() << std::endl;
- return factories.size();
- }
+namespace Catch {
+ namespace Benchmark {
+ namespace Detail {
+ template <typename Duration, typename Iterator>
+ SampleAnalysis<Duration> analyse(const IConfig &cfg, Environment<Duration>, Iterator first, Iterator last) {
+ if (!cfg.benchmarkNoAnalysis()) {
+ std::vector<double> samples;
+ samples.reserve(last - first);
+ std::transform(first, last, std::back_inserter(samples), [](Duration d) { return d.count(); });
+
+ auto analysis = Catch::Benchmark::Detail::analyse_samples(cfg.benchmarkConfidenceInterval(), cfg.benchmarkResamples(), samples.begin(), samples.end());
+ auto outliers = Catch::Benchmark::Detail::classify_outliers(samples.begin(), samples.end());
+
+ auto wrap_estimate = [](Estimate<double> e) {
+ return Estimate<Duration> {
+ Duration(e.point),
+ Duration(e.lower_bound),
+ Duration(e.upper_bound),
+ e.confidence_interval,
+ };
+ };
+ std::vector<Duration> samples2;
+ samples2.reserve(samples.size());
+ std::transform(samples.begin(), samples.end(), std::back_inserter(samples2), [](double d) { return Duration(d); });
+ return {
+ std::move(samples2),
+ wrap_estimate(analysis.mean),
+ wrap_estimate(analysis.standard_deviation),
+ outliers,
+ analysis.outlier_variance,
+ };
+ } else {
+ std::vector<Duration> samples;
+ samples.reserve(last - first);
+
+ Duration mean = Duration(0);
+ int i = 0;
+ for (auto it = first; it < last; ++it, ++i) {
+ samples.push_back(Duration(*it));
+ mean += Duration(*it);
+ }
+ mean /= i;
+
+ return {
+ std::move(samples),
+ Estimate<Duration>{mean, mean, mean, 0.0},
+ Estimate<Duration>{Duration(0), Duration(0), Duration(0), 0.0},
+ OutlierClassification{},
+ 0.0
+ };
+ }
+ }
+ } // namespace Detail
+ } // namespace Benchmark
+} // namespace Catch
- inline Option<std::size_t> list( Config const& config ) {
- Option<std::size_t> listedCount;
- if( config.listTests() )
- listedCount = listedCount.valueOr(0) + listTests( config );
- if( config.listTestNamesOnly() )
- listedCount = listedCount.valueOr(0) + listTestsNamesOnly( config );
- if( config.listTags() )
- listedCount = listedCount.valueOr(0) + listTags( config );
- if( config.listReporters() )
- listedCount = listedCount.valueOr(0) + listReporters( config );
- return listedCount;
+// end catch_analyse.hpp
+#include <algorithm>
+#include <functional>
+#include <string>
+#include <vector>
+#include <cmath>
+
+namespace Catch {
+ namespace Benchmark {
+ struct Benchmark {
+ Benchmark(std::string &&name)
+ : name(std::move(name)) {}
+
+ template <class FUN>
+ Benchmark(std::string &&name, FUN &&func)
+ : fun(std::move(func)), name(std::move(name)) {}
+
+ template <typename Clock>
+ ExecutionPlan<FloatDuration<Clock>> prepare(const IConfig &cfg, Environment<FloatDuration<Clock>> env) const {
+ auto min_time = env.clock_resolution.mean * Detail::minimum_ticks;
+ auto run_time = std::max(min_time, std::chrono::duration_cast<decltype(min_time)>(Detail::warmup_time));
+ auto&& test = Detail::run_for_at_least<Clock>(std::chrono::duration_cast<ClockDuration<Clock>>(run_time), 1, fun);
+ int new_iters = static_cast<int>(std::ceil(min_time * test.iterations / test.elapsed));
+ return { new_iters, test.elapsed / test.iterations * new_iters * cfg.benchmarkSamples(), fun, std::chrono::duration_cast<FloatDuration<Clock>>(Detail::warmup_time), Detail::warmup_iterations };
+ }
+
+ template <typename Clock = default_clock>
+ void run() {
+ IConfigPtr cfg = getCurrentContext().getConfig();
+
+ auto env = Detail::measure_environment<Clock>();
+
+ getResultCapture().benchmarkPreparing(name);
+ CATCH_TRY{
+ auto plan = user_code([&] {
+ return prepare<Clock>(*cfg, env);
+ });
+
+ BenchmarkInfo info {
+ name,
+ plan.estimated_duration.count(),
+ plan.iterations_per_sample,
+ cfg->benchmarkSamples(),
+ cfg->benchmarkResamples(),
+ env.clock_resolution.mean.count(),
+ env.clock_cost.mean.count()
+ };
+
+ getResultCapture().benchmarkStarting(info);
+
+ auto samples = user_code([&] {
+ return plan.template run<Clock>(*cfg, env);
+ });
+
+ auto analysis = Detail::analyse(*cfg, env, samples.begin(), samples.end());
+ BenchmarkStats<std::chrono::duration<double, std::nano>> stats{ info, analysis.samples, analysis.mean, analysis.standard_deviation, analysis.outliers, analysis.outlier_variance };
+ getResultCapture().benchmarkEnded(stats);
+
+ } CATCH_CATCH_ALL{
+ if (translateActiveException() != Detail::benchmarkErrorMsg) // benchmark errors have been reported, otherwise rethrow.
+ std::rethrow_exception(std::current_exception());
+ }
+ }
+
+ // sets lambda to be used in fun *and* executes benchmark!
+ template <typename Fun,
+ typename std::enable_if<!Detail::is_related<Fun, Benchmark>::value, int>::type = 0>
+ Benchmark & operator=(Fun func) {
+ fun = Detail::BenchmarkFunction(func);
+ run();
+ return *this;
+ }
+
+ explicit operator bool() {
+ return true;
+ }
+
+ private:
+ Detail::BenchmarkFunction fun;
+ std::string name;
+ };
}
+} // namespace Catch
-} // end namespace Catch
+#define INTERNAL_CATCH_GET_1_ARG(arg1, arg2, ...) arg1
+#define INTERNAL_CATCH_GET_2_ARG(arg1, arg2, ...) arg2
-// #included from: internal/catch_run_context.hpp
-#define TWOBLUECUBES_CATCH_RUNNER_IMPL_HPP_INCLUDED
+#define INTERNAL_CATCH_BENCHMARK(BenchmarkName, name, benchmarkIndex)\
+ if( Catch::Benchmark::Benchmark BenchmarkName{name} ) \
+ BenchmarkName = [&](int benchmarkIndex)
-// #included from: catch_test_case_tracker.hpp
-#define TWOBLUECUBES_CATCH_TEST_CASE_TRACKER_HPP_INCLUDED
+#define INTERNAL_CATCH_BENCHMARK_ADVANCED(BenchmarkName, name)\
+ if( Catch::Benchmark::Benchmark BenchmarkName{name} ) \
+ BenchmarkName = [&]
+
+// end catch_benchmark.hpp
+#endif
+
+#endif // ! CATCH_CONFIG_IMPL_ONLY
+
+#ifdef CATCH_IMPL
+// start catch_impl.hpp
+
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wweak-vtables"
+#endif
+
+// Keep these here for external reporters
+// start catch_test_case_tracker.h
-#include <algorithm>
#include <string>
-#include <assert.h>
#include <vector>
-#include <stdexcept>
-
-CATCH_INTERNAL_SUPPRESS_ETD_WARNINGS
+#include <memory>
namespace Catch {
namespace TestCaseTracking {
@@ -5971,13 +7363,14 @@ namespace TestCaseTracking {
std::string name;
SourceLineInfo location;
- NameAndLocation( std::string const& _name, SourceLineInfo const& _location )
- : name( _name ),
- location( _location )
- {}
+ NameAndLocation( std::string const& _name, SourceLineInfo const& _location );
};
- struct ITracker : SharedImpl<> {
+ struct ITracker;
+
+ using ITrackerPtr = std::shared_ptr<ITracker>;
+
+ struct ITracker {
virtual ~ITracker();
// static queries
@@ -5996,16 +7389,16 @@ namespace TestCaseTracking {
virtual void fail() = 0;
virtual void markAsNeedingAnotherRun() = 0;
- virtual void addChild( Ptr<ITracker> const& child ) = 0;
- virtual ITracker* findChild( NameAndLocation const& nameAndLocation ) = 0;
+ virtual void addChild( ITrackerPtr const& child ) = 0;
+ virtual ITrackerPtr findChild( NameAndLocation const& nameAndLocation ) = 0;
virtual void openChild() = 0;
// Debug/ checking
virtual bool isSectionTracker() const = 0;
- virtual bool isIndexTracker() const = 0;
+ virtual bool isGeneratorTracker() const = 0;
};
- class TrackerContext {
+ class TrackerContext {
enum RunState {
NotStarted,
@@ -6013,47 +7406,21 @@ namespace TestCaseTracking {
CompletedCycle
};
- Ptr<ITracker> m_rootTracker;
- ITracker* m_currentTracker;
- RunState m_runState;
+ ITrackerPtr m_rootTracker;
+ ITracker* m_currentTracker = nullptr;
+ RunState m_runState = NotStarted;
public:
- static TrackerContext& instance() {
- static TrackerContext s_instance;
- return s_instance;
- }
-
- TrackerContext()
- : m_currentTracker( CATCH_NULL ),
- m_runState( NotStarted )
- {}
-
ITracker& startRun();
+ void endRun();
- void endRun() {
- m_rootTracker.reset();
- m_currentTracker = CATCH_NULL;
- m_runState = NotStarted;
- }
-
- void startCycle() {
- m_currentTracker = m_rootTracker.get();
- m_runState = Executing;
- }
- void completeCycle() {
- m_runState = CompletedCycle;
- }
+ void startCycle();
+ void completeCycle();
- bool completedCycle() const {
- return m_runState == CompletedCycle;
- }
- ITracker& currentTracker() {
- return *m_currentTracker;
- }
- void setCurrentTracker( ITracker* tracker ) {
- m_currentTracker = tracker;
- }
+ bool completedCycle() const;
+ ITracker& currentTracker();
+ void setCurrentTracker( ITracker* tracker );
};
class TrackerBase : public ITracker {
@@ -6066,274 +7433,425 @@ namespace TestCaseTracking {
CompletedSuccessfully,
Failed
};
- class TrackerHasName {
- NameAndLocation m_nameAndLocation;
- public:
- TrackerHasName( NameAndLocation const& nameAndLocation ) : m_nameAndLocation( nameAndLocation ) {}
- bool operator ()( Ptr<ITracker> const& tracker ) {
- return
- tracker->nameAndLocation().name == m_nameAndLocation.name &&
- tracker->nameAndLocation().location == m_nameAndLocation.location;
- }
- };
- typedef std::vector<Ptr<ITracker> > Children;
+
+ using Children = std::vector<ITrackerPtr>;
NameAndLocation m_nameAndLocation;
TrackerContext& m_ctx;
ITracker* m_parent;
Children m_children;
- CycleState m_runState;
- public:
- TrackerBase( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent )
- : m_nameAndLocation( nameAndLocation ),
- m_ctx( ctx ),
- m_parent( parent ),
- m_runState( NotStarted )
- {}
- virtual ~TrackerBase();
-
- virtual NameAndLocation const& nameAndLocation() const CATCH_OVERRIDE {
- return m_nameAndLocation;
- }
- virtual bool isComplete() const CATCH_OVERRIDE {
- return m_runState == CompletedSuccessfully || m_runState == Failed;
- }
- virtual bool isSuccessfullyCompleted() const CATCH_OVERRIDE {
- return m_runState == CompletedSuccessfully;
- }
- virtual bool isOpen() const CATCH_OVERRIDE {
- return m_runState != NotStarted && !isComplete();
- }
- virtual bool hasChildren() const CATCH_OVERRIDE {
- return !m_children.empty();
- }
-
- virtual void addChild( Ptr<ITracker> const& child ) CATCH_OVERRIDE {
- m_children.push_back( child );
- }
+ CycleState m_runState = NotStarted;
- virtual ITracker* findChild( NameAndLocation const& nameAndLocation ) CATCH_OVERRIDE {
- Children::const_iterator it = std::find_if( m_children.begin(), m_children.end(), TrackerHasName( nameAndLocation ) );
- return( it != m_children.end() )
- ? it->get()
- : CATCH_NULL;
- }
- virtual ITracker& parent() CATCH_OVERRIDE {
- assert( m_parent ); // Should always be non-null except for root
- return *m_parent;
- }
+ public:
+ TrackerBase( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent );
- virtual void openChild() CATCH_OVERRIDE {
- if( m_runState != ExecutingChildren ) {
- m_runState = ExecutingChildren;
- if( m_parent )
- m_parent->openChild();
- }
- }
+ NameAndLocation const& nameAndLocation() const override;
+ bool isComplete() const override;
+ bool isSuccessfullyCompleted() const override;
+ bool isOpen() const override;
+ bool hasChildren() const override;
- virtual bool isSectionTracker() const CATCH_OVERRIDE { return false; }
- virtual bool isIndexTracker() const CATCH_OVERRIDE { return false; }
+ void addChild( ITrackerPtr const& child ) override;
- void open() {
- m_runState = Executing;
- moveToThis();
- if( m_parent )
- m_parent->openChild();
- }
+ ITrackerPtr findChild( NameAndLocation const& nameAndLocation ) override;
+ ITracker& parent() override;
- virtual void close() CATCH_OVERRIDE {
+ void openChild() override;
- // Close any still open children (e.g. generators)
- while( &m_ctx.currentTracker() != this )
- m_ctx.currentTracker().close();
+ bool isSectionTracker() const override;
+ bool isGeneratorTracker() const override;
- switch( m_runState ) {
- case NotStarted:
- case CompletedSuccessfully:
- case Failed:
- throw std::logic_error( "Illogical state" );
+ void open();
- case NeedsAnotherRun:
- break;;
+ void close() override;
+ void fail() override;
+ void markAsNeedingAnotherRun() override;
- case Executing:
- m_runState = CompletedSuccessfully;
- break;
- case ExecutingChildren:
- if( m_children.empty() || m_children.back()->isComplete() )
- m_runState = CompletedSuccessfully;
- break;
-
- default:
- throw std::logic_error( "Unexpected state" );
- }
- moveToParent();
- m_ctx.completeCycle();
- }
- virtual void fail() CATCH_OVERRIDE {
- m_runState = Failed;
- if( m_parent )
- m_parent->markAsNeedingAnotherRun();
- moveToParent();
- m_ctx.completeCycle();
- }
- virtual void markAsNeedingAnotherRun() CATCH_OVERRIDE {
- m_runState = NeedsAnotherRun;
- }
private:
- void moveToParent() {
- assert( m_parent );
- m_ctx.setCurrentTracker( m_parent );
- }
- void moveToThis() {
- m_ctx.setCurrentTracker( this );
- }
+ void moveToParent();
+ void moveToThis();
};
class SectionTracker : public TrackerBase {
std::vector<std::string> m_filters;
+ std::string m_trimmed_name;
public:
- SectionTracker( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent )
- : TrackerBase( nameAndLocation, ctx, parent )
- {
- if( parent ) {
- while( !parent->isSectionTracker() )
- parent = &parent->parent();
+ SectionTracker( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent );
- SectionTracker& parentSection = static_cast<SectionTracker&>( *parent );
- addNextFilters( parentSection.m_filters );
- }
- }
- virtual ~SectionTracker();
+ bool isSectionTracker() const override;
- virtual bool isSectionTracker() const CATCH_OVERRIDE { return true; }
+ bool isComplete() const override;
- static SectionTracker& acquire( TrackerContext& ctx, NameAndLocation const& nameAndLocation ) {
- SectionTracker* section = CATCH_NULL;
+ static SectionTracker& acquire( TrackerContext& ctx, NameAndLocation const& nameAndLocation );
- ITracker& currentTracker = ctx.currentTracker();
- if( ITracker* childTracker = currentTracker.findChild( nameAndLocation ) ) {
- assert( childTracker );
- assert( childTracker->isSectionTracker() );
- section = static_cast<SectionTracker*>( childTracker );
- }
- else {
- section = new SectionTracker( nameAndLocation, ctx, &currentTracker );
- currentTracker.addChild( section );
- }
- if( !ctx.completedCycle() )
- section->tryOpen();
- return *section;
- }
+ void tryOpen();
- void tryOpen() {
- if( !isComplete() && (m_filters.empty() || m_filters[0].empty() || m_filters[0] == m_nameAndLocation.name ) )
- open();
- }
+ void addInitialFilters( std::vector<std::string> const& filters );
+ void addNextFilters( std::vector<std::string> const& filters );
+ };
- void addInitialFilters( std::vector<std::string> const& filters ) {
- if( !filters.empty() ) {
- m_filters.push_back(""); // Root - should never be consulted
- m_filters.push_back(""); // Test Case - not a section filter
- m_filters.insert( m_filters.end(), filters.begin(), filters.end() );
- }
- }
- void addNextFilters( std::vector<std::string> const& filters ) {
- if( filters.size() > 1 )
- m_filters.insert( m_filters.end(), ++filters.begin(), filters.end() );
- }
+} // namespace TestCaseTracking
+
+using TestCaseTracking::ITracker;
+using TestCaseTracking::TrackerContext;
+using TestCaseTracking::SectionTracker;
+
+} // namespace Catch
+
+// end catch_test_case_tracker.h
+
+// start catch_leak_detector.h
+
+namespace Catch {
+
+ struct LeakDetector {
+ LeakDetector();
+ ~LeakDetector();
};
- class IndexTracker : public TrackerBase {
- int m_size;
- int m_index;
- public:
- IndexTracker( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent, int size )
- : TrackerBase( nameAndLocation, ctx, parent ),
- m_size( size ),
- m_index( -1 )
- {}
- virtual ~IndexTracker();
+}
+// end catch_leak_detector.h
+// Cpp files will be included in the single-header file here
+// start catch_stats.cpp
+
+// Statistical analysis tools
- virtual bool isIndexTracker() const CATCH_OVERRIDE { return true; }
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
- static IndexTracker& acquire( TrackerContext& ctx, NameAndLocation const& nameAndLocation, int size ) {
- IndexTracker* tracker = CATCH_NULL;
+#include <cassert>
+#include <random>
- ITracker& currentTracker = ctx.currentTracker();
- if( ITracker* childTracker = currentTracker.findChild( nameAndLocation ) ) {
- assert( childTracker );
- assert( childTracker->isIndexTracker() );
- tracker = static_cast<IndexTracker*>( childTracker );
+#if defined(CATCH_CONFIG_USE_ASYNC)
+#include <future>
+#endif
+
+namespace {
+ double erf_inv(double x) {
+ // Code accompanying the article "Approximating the erfinv function" in GPU Computing Gems, Volume 2
+ double w, p;
+
+ w = -log((1.0 - x) * (1.0 + x));
+
+ if (w < 6.250000) {
+ w = w - 3.125000;
+ p = -3.6444120640178196996e-21;
+ p = -1.685059138182016589e-19 + p * w;
+ p = 1.2858480715256400167e-18 + p * w;
+ p = 1.115787767802518096e-17 + p * w;
+ p = -1.333171662854620906e-16 + p * w;
+ p = 2.0972767875968561637e-17 + p * w;
+ p = 6.6376381343583238325e-15 + p * w;
+ p = -4.0545662729752068639e-14 + p * w;
+ p = -8.1519341976054721522e-14 + p * w;
+ p = 2.6335093153082322977e-12 + p * w;
+ p = -1.2975133253453532498e-11 + p * w;
+ p = -5.4154120542946279317e-11 + p * w;
+ p = 1.051212273321532285e-09 + p * w;
+ p = -4.1126339803469836976e-09 + p * w;
+ p = -2.9070369957882005086e-08 + p * w;
+ p = 4.2347877827932403518e-07 + p * w;
+ p = -1.3654692000834678645e-06 + p * w;
+ p = -1.3882523362786468719e-05 + p * w;
+ p = 0.0001867342080340571352 + p * w;
+ p = -0.00074070253416626697512 + p * w;
+ p = -0.0060336708714301490533 + p * w;
+ p = 0.24015818242558961693 + p * w;
+ p = 1.6536545626831027356 + p * w;
+ } else if (w < 16.000000) {
+ w = sqrt(w) - 3.250000;
+ p = 2.2137376921775787049e-09;
+ p = 9.0756561938885390979e-08 + p * w;
+ p = -2.7517406297064545428e-07 + p * w;
+ p = 1.8239629214389227755e-08 + p * w;
+ p = 1.5027403968909827627e-06 + p * w;
+ p = -4.013867526981545969e-06 + p * w;
+ p = 2.9234449089955446044e-06 + p * w;
+ p = 1.2475304481671778723e-05 + p * w;
+ p = -4.7318229009055733981e-05 + p * w;
+ p = 6.8284851459573175448e-05 + p * w;
+ p = 2.4031110387097893999e-05 + p * w;
+ p = -0.0003550375203628474796 + p * w;
+ p = 0.00095328937973738049703 + p * w;
+ p = -0.0016882755560235047313 + p * w;
+ p = 0.0024914420961078508066 + p * w;
+ p = -0.0037512085075692412107 + p * w;
+ p = 0.005370914553590063617 + p * w;
+ p = 1.0052589676941592334 + p * w;
+ p = 3.0838856104922207635 + p * w;
+ } else {
+ w = sqrt(w) - 5.000000;
+ p = -2.7109920616438573243e-11;
+ p = -2.5556418169965252055e-10 + p * w;
+ p = 1.5076572693500548083e-09 + p * w;
+ p = -3.7894654401267369937e-09 + p * w;
+ p = 7.6157012080783393804e-09 + p * w;
+ p = -1.4960026627149240478e-08 + p * w;
+ p = 2.9147953450901080826e-08 + p * w;
+ p = -6.7711997758452339498e-08 + p * w;
+ p = 2.2900482228026654717e-07 + p * w;
+ p = -9.9298272942317002539e-07 + p * w;
+ p = 4.5260625972231537039e-06 + p * w;
+ p = -1.9681778105531670567e-05 + p * w;
+ p = 7.5995277030017761139e-05 + p * w;
+ p = -0.00021503011930044477347 + p * w;
+ p = -0.00013871931833623122026 + p * w;
+ p = 1.0103004648645343977 + p * w;
+ p = 4.8499064014085844221 + p * w;
+ }
+ return p * x;
+ }
+
+ double standard_deviation(std::vector<double>::iterator first, std::vector<double>::iterator last) {
+ auto m = Catch::Benchmark::Detail::mean(first, last);
+ double variance = std::accumulate(first, last, 0., [m](double a, double b) {
+ double diff = b - m;
+ return a + diff * diff;
+ }) / (last - first);
+ return std::sqrt(variance);
+ }
+
+}
+
+namespace Catch {
+ namespace Benchmark {
+ namespace Detail {
+
+ double weighted_average_quantile(int k, int q, std::vector<double>::iterator first, std::vector<double>::iterator last) {
+ auto count = last - first;
+ double idx = (count - 1) * k / static_cast<double>(q);
+ int j = static_cast<int>(idx);
+ double g = idx - j;
+ std::nth_element(first, first + j, last);
+ auto xj = first[j];
+ if (g == 0) return xj;
+
+ auto xj1 = *std::min_element(first + (j + 1), last);
+ return xj + g * (xj1 - xj);
}
- else {
- tracker = new IndexTracker( nameAndLocation, ctx, &currentTracker, size );
- currentTracker.addChild( tracker );
+
+ double erfc_inv(double x) {
+ return erf_inv(1.0 - x);
}
- if( !ctx.completedCycle() && !tracker->isComplete() ) {
- if( tracker->m_runState != ExecutingChildren && tracker->m_runState != NeedsAnotherRun )
- tracker->moveNext();
- tracker->open();
+ double normal_quantile(double p) {
+ static const double ROOT_TWO = std::sqrt(2.0);
+
+ double result = 0.0;
+ assert(p >= 0 && p <= 1);
+ if (p < 0 || p > 1) {
+ return result;
+ }
+
+ result = -erfc_inv(2.0 * p);
+ // result *= normal distribution standard deviation (1.0) * sqrt(2)
+ result *= /*sd * */ ROOT_TWO;
+ // result += normal disttribution mean (0)
+ return result;
}
- return *tracker;
- }
+ double outlier_variance(Estimate<double> mean, Estimate<double> stddev, int n) {
+ double sb = stddev.point;
+ double mn = mean.point / n;
+ double mg_min = mn / 2.;
+ double sg = std::min(mg_min / 4., sb / std::sqrt(n));
+ double sg2 = sg * sg;
+ double sb2 = sb * sb;
+
+ auto c_max = [n, mn, sb2, sg2](double x) -> double {
+ double k = mn - x;
+ double d = k * k;
+ double nd = n * d;
+ double k0 = -n * nd;
+ double k1 = sb2 - n * sg2 + nd;
+ double det = k1 * k1 - 4 * sg2 * k0;
+ return (int)(-2. * k0 / (k1 + std::sqrt(det)));
+ };
+
+ auto var_out = [n, sb2, sg2](double c) {
+ double nc = n - c;
+ return (nc / n) * (sb2 - nc * sg2);
+ };
+
+ return std::min(var_out(1), var_out(std::min(c_max(0.), c_max(mg_min)))) / sb2;
+ }
- int index() const { return m_index; }
+ bootstrap_analysis analyse_samples(double confidence_level, int n_resamples, std::vector<double>::iterator first, std::vector<double>::iterator last) {
+ CATCH_INTERNAL_SUPPRESS_GLOBALS_WARNINGS
+ static std::random_device entropy;
+ CATCH_INTERNAL_UNSUPPRESS_GLOBALS_WARNINGS
- void moveNext() {
- m_index++;
- m_children.clear();
- }
+ auto n = static_cast<int>(last - first); // seriously, one can't use integral types without hell in C++
- virtual void close() CATCH_OVERRIDE {
- TrackerBase::close();
- if( m_runState == CompletedSuccessfully && m_index < m_size-1 )
- m_runState = Executing;
- }
- };
+ auto mean = &Detail::mean<std::vector<double>::iterator>;
+ auto stddev = &standard_deviation;
- inline ITracker& TrackerContext::startRun() {
- m_rootTracker = new SectionTracker( NameAndLocation( "{root}", CATCH_INTERNAL_LINEINFO ), *this, CATCH_NULL );
- m_currentTracker = CATCH_NULL;
- m_runState = Executing;
- return *m_rootTracker;
- }
+#if defined(CATCH_CONFIG_USE_ASYNC)
+ auto Estimate = [=](double(*f)(std::vector<double>::iterator, std::vector<double>::iterator)) {
+ auto seed = entropy();
+ return std::async(std::launch::async, [=] {
+ std::mt19937 rng(seed);
+ auto resampled = resample(rng, n_resamples, first, last, f);
+ return bootstrap(confidence_level, first, last, resampled, f);
+ });
+ };
-} // namespace TestCaseTracking
+ auto mean_future = Estimate(mean);
+ auto stddev_future = Estimate(stddev);
-using TestCaseTracking::ITracker;
-using TestCaseTracking::TrackerContext;
-using TestCaseTracking::SectionTracker;
-using TestCaseTracking::IndexTracker;
+ auto mean_estimate = mean_future.get();
+ auto stddev_estimate = stddev_future.get();
+#else
+ auto Estimate = [=](double(*f)(std::vector<double>::iterator, std::vector<double>::iterator)) {
+ auto seed = entropy();
+ std::mt19937 rng(seed);
+ auto resampled = resample(rng, n_resamples, first, last, f);
+ return bootstrap(confidence_level, first, last, resampled, f);
+ };
+
+ auto mean_estimate = Estimate(mean);
+ auto stddev_estimate = Estimate(stddev);
+#endif // CATCH_USE_ASYNC
+ double outlier_variance = Detail::outlier_variance(mean_estimate, stddev_estimate, n);
+
+ return { mean_estimate, stddev_estimate, outlier_variance };
+ }
+ } // namespace Detail
+ } // namespace Benchmark
} // namespace Catch
-CATCH_INTERNAL_UNSUPPRESS_ETD_WARNINGS
+#endif // CATCH_CONFIG_ENABLE_BENCHMARKING
+// end catch_stats.cpp
+// start catch_approx.cpp
-// #included from: catch_fatal_condition.hpp
-#define TWOBLUECUBES_CATCH_FATAL_CONDITION_H_INCLUDED
+#include <cmath>
+#include <limits>
+
+namespace {
+
+// Performs equivalent check of std::fabs(lhs - rhs) <= margin
+// But without the subtraction to allow for INFINITY in comparison
+bool marginComparison(double lhs, double rhs, double margin) {
+ return (lhs + margin >= rhs) && (rhs + margin >= lhs);
+}
+
+}
namespace Catch {
+namespace Detail {
- // Report the error condition
- inline void reportFatal( std::string const& message ) {
- IContext& context = Catch::getCurrentContext();
- IResultCapture* resultCapture = context.getResultCapture();
- resultCapture->handleFatalErrorCondition( message );
+ Approx::Approx ( double value )
+ : m_epsilon( std::numeric_limits<float>::epsilon()*100 ),
+ m_margin( 0.0 ),
+ m_scale( 0.0 ),
+ m_value( value )
+ {}
+
+ Approx Approx::custom() {
+ return Approx( 0 );
}
-} // namespace Catch
+ Approx Approx::operator-() const {
+ auto temp(*this);
+ temp.m_value = -temp.m_value;
+ return temp;
+ }
+
+ std::string Approx::toString() const {
+ ReusableStringStream rss;
+ rss << "Approx( " << ::Catch::Detail::stringify( m_value ) << " )";
+ return rss.str();
+ }
+
+ bool Approx::equalityComparisonImpl(const double other) const {
+ // First try with fixed margin, then compute margin based on epsilon, scale and Approx's value
+ // Thanks to Richard Harris for his help refining the scaled margin value
+ return marginComparison(m_value, other, m_margin)
+ || marginComparison(m_value, other, m_epsilon * (m_scale + std::fabs(std::isinf(m_value)? 0 : m_value)));
+ }
-#if defined ( CATCH_PLATFORM_WINDOWS ) /////////////////////////////////////////
-// #included from: catch_windows_h_proxy.h
+ void Approx::setMargin(double newMargin) {
+ CATCH_ENFORCE(newMargin >= 0,
+ "Invalid Approx::margin: " << newMargin << '.'
+ << " Approx::Margin has to be non-negative.");
+ m_margin = newMargin;
+ }
+
+ void Approx::setEpsilon(double newEpsilon) {
+ CATCH_ENFORCE(newEpsilon >= 0 && newEpsilon <= 1.0,
+ "Invalid Approx::epsilon: " << newEpsilon << '.'
+ << " Approx::epsilon has to be in [0, 1]");
+ m_epsilon = newEpsilon;
+ }
-#define TWOBLUECUBES_CATCH_WINDOWS_H_PROXY_H_INCLUDED
+} // end namespace Detail
+
+namespace literals {
+ Detail::Approx operator "" _a(long double val) {
+ return Detail::Approx(val);
+ }
+ Detail::Approx operator "" _a(unsigned long long val) {
+ return Detail::Approx(val);
+ }
+} // end namespace literals
-#ifdef CATCH_DEFINES_NOMINMAX
+std::string StringMaker<Catch::Detail::Approx>::convert(Catch::Detail::Approx const& value) {
+ return value.toString();
+}
+
+} // end namespace Catch
+// end catch_approx.cpp
+// start catch_assertionhandler.cpp
+
+// start catch_debugger.h
+
+namespace Catch {
+ bool isDebuggerActive();
+}
+
+#ifdef CATCH_PLATFORM_MAC
+
+ #define CATCH_TRAP() __asm__("int $3\n" : : ) /* NOLINT */
+
+#elif defined(CATCH_PLATFORM_LINUX)
+ // If we can use inline assembler, do it because this allows us to break
+ // directly at the location of the failing check instead of breaking inside
+ // raise() called from it, i.e. one stack frame below.
+ #if defined(__GNUC__) && (defined(__i386) || defined(__x86_64))
+ #define CATCH_TRAP() asm volatile ("int $3") /* NOLINT */
+ #else // Fall back to the generic way.
+ #include <signal.h>
+
+ #define CATCH_TRAP() raise(SIGTRAP)
+ #endif
+#elif defined(_MSC_VER)
+ #define CATCH_TRAP() __debugbreak()
+#elif defined(__MINGW32__)
+ extern "C" __declspec(dllimport) void __stdcall DebugBreak();
+ #define CATCH_TRAP() DebugBreak()
+#endif
+
+#ifdef CATCH_TRAP
+ #define CATCH_BREAK_INTO_DEBUGGER() []{ if( Catch::isDebuggerActive() ) { CATCH_TRAP(); } }()
+#else
+ #define CATCH_BREAK_INTO_DEBUGGER() []{}()
+#endif
+
+// end catch_debugger.h
+// start catch_run_context.h
+
+// start catch_fatal_condition.h
+
+// start catch_windows_h_proxy.h
+
+
+#if defined(CATCH_PLATFORM_WINDOWS)
+
+#if !defined(NOMINMAX) && !defined(CATCH_CONFIG_NO_NOMINMAX)
+# define CATCH_DEFINED_NOMINMAX
# define NOMINMAX
#endif
-#ifdef CATCH_DEFINES_WIN32_LEAN_AND_MEAN
+#if !defined(WIN32_LEAN_AND_MEAN) && !defined(CATCH_CONFIG_NO_WIN32_LEAN_AND_MEAN)
+# define CATCH_DEFINED_WIN32_LEAN_AND_MEAN
# define WIN32_LEAN_AND_MEAN
#endif
@@ -6343,1422 +7861,2061 @@ namespace Catch {
#include <windows.h>
#endif
-#ifdef CATCH_DEFINES_NOMINMAX
+#ifdef CATCH_DEFINED_NOMINMAX
# undef NOMINMAX
#endif
-#ifdef CATCH_DEFINES_WIN32_LEAN_AND_MEAN
+#ifdef CATCH_DEFINED_WIN32_LEAN_AND_MEAN
# undef WIN32_LEAN_AND_MEAN
#endif
+#endif // defined(CATCH_PLATFORM_WINDOWS)
-# if !defined ( CATCH_CONFIG_WINDOWS_SEH )
+// end catch_windows_h_proxy.h
+#if defined( CATCH_CONFIG_WINDOWS_SEH )
namespace Catch {
- struct FatalConditionHandler {
- void reset() {}
- };
-}
-
-# else // CATCH_CONFIG_WINDOWS_SEH is defined
-
-namespace Catch {
-
- struct SignalDefs { DWORD id; const char* name; };
- extern SignalDefs signalDefs[];
- // There is no 1-1 mapping between signals and windows exceptions.
- // Windows can easily distinguish between SO and SigSegV,
- // but SigInt, SigTerm, etc are handled differently.
- SignalDefs signalDefs[] = {
- { EXCEPTION_ILLEGAL_INSTRUCTION, "SIGILL - Illegal instruction signal" },
- { EXCEPTION_STACK_OVERFLOW, "SIGSEGV - Stack overflow" },
- { EXCEPTION_ACCESS_VIOLATION, "SIGSEGV - Segmentation violation signal" },
- { EXCEPTION_INT_DIVIDE_BY_ZERO, "Divide by zero error" },
- };
struct FatalConditionHandler {
- static LONG CALLBACK handleVectoredException(PEXCEPTION_POINTERS ExceptionInfo) {
- for (int i = 0; i < sizeof(signalDefs) / sizeof(SignalDefs); ++i) {
- if (ExceptionInfo->ExceptionRecord->ExceptionCode == signalDefs[i].id) {
- reportFatal(signalDefs[i].name);
- }
- }
- // If its not an exception we care about, pass it along.
- // This stops us from eating debugger breaks etc.
- return EXCEPTION_CONTINUE_SEARCH;
- }
-
- FatalConditionHandler() {
- isSet = true;
- // 32k seems enough for Catch to handle stack overflow,
- // but the value was found experimentally, so there is no strong guarantee
- guaranteeSize = 32 * 1024;
- exceptionHandlerHandle = CATCH_NULL;
- // Register as first handler in current chain
- exceptionHandlerHandle = AddVectoredExceptionHandler(1, handleVectoredException);
- // Pass in guarantee size to be filled
- SetThreadStackGuarantee(&guaranteeSize);
- }
-
- static void reset() {
- if (isSet) {
- // Unregister handler and restore the old guarantee
- RemoveVectoredExceptionHandler(exceptionHandlerHandle);
- SetThreadStackGuarantee(&guaranteeSize);
- exceptionHandlerHandle = CATCH_NULL;
- isSet = false;
- }
- }
+ static LONG CALLBACK handleVectoredException(PEXCEPTION_POINTERS ExceptionInfo);
+ FatalConditionHandler();
+ static void reset();
+ ~FatalConditionHandler();
- ~FatalConditionHandler() {
- reset();
- }
private:
static bool isSet;
static ULONG guaranteeSize;
static PVOID exceptionHandlerHandle;
};
- bool FatalConditionHandler::isSet = false;
- ULONG FatalConditionHandler::guaranteeSize = 0;
- PVOID FatalConditionHandler::exceptionHandlerHandle = CATCH_NULL;
-
} // namespace Catch
-# endif // CATCH_CONFIG_WINDOWS_SEH
+#elif defined ( CATCH_CONFIG_POSIX_SIGNALS )
-#else // Not Windows - assumed to be POSIX compatible //////////////////////////
+#include <signal.h>
-# if !defined(CATCH_CONFIG_POSIX_SIGNALS)
+namespace Catch {
+
+ struct FatalConditionHandler {
+
+ static bool isSet;
+ static struct sigaction oldSigActions[];
+ static stack_t oldSigStack;
+ static char altStackMem[];
+
+ static void handleSignal( int sig );
+
+ FatalConditionHandler();
+ ~FatalConditionHandler();
+ static void reset();
+ };
+
+} // namespace Catch
+
+#else
namespace Catch {
struct FatalConditionHandler {
- void reset() {}
+ void reset();
};
}
-# else // CATCH_CONFIG_POSIX_SIGNALS is defined
+#endif
-#include <signal.h>
+// end catch_fatal_condition.h
+#include <string>
namespace Catch {
- struct SignalDefs {
- int id;
- const char* name;
- };
- extern SignalDefs signalDefs[];
- SignalDefs signalDefs[] = {
- { SIGINT, "SIGINT - Terminal interrupt signal" },
- { SIGILL, "SIGILL - Illegal instruction signal" },
- { SIGFPE, "SIGFPE - Floating point error signal" },
- { SIGSEGV, "SIGSEGV - Segmentation violation signal" },
- { SIGTERM, "SIGTERM - Termination request signal" },
- { SIGABRT, "SIGABRT - Abort (abnormal termination) signal" }
- };
+ struct IMutableContext;
- struct FatalConditionHandler {
+ ///////////////////////////////////////////////////////////////////////////
- static bool isSet;
- static struct sigaction oldSigActions [sizeof(signalDefs)/sizeof(SignalDefs)];
- static stack_t oldSigStack;
- static char altStackMem[SIGSTKSZ];
-
- static void handleSignal( int sig ) {
- std::string name = "<unknown signal>";
- for (std::size_t i = 0; i < sizeof(signalDefs) / sizeof(SignalDefs); ++i) {
- SignalDefs &def = signalDefs[i];
- if (sig == def.id) {
- name = def.name;
- break;
- }
- }
- reset();
- reportFatal(name);
- raise( sig );
- }
+ class RunContext : public IResultCapture, public IRunner {
- FatalConditionHandler() {
- isSet = true;
- stack_t sigStack;
- sigStack.ss_sp = altStackMem;
- sigStack.ss_size = SIGSTKSZ;
- sigStack.ss_flags = 0;
- sigaltstack(&sigStack, &oldSigStack);
- struct sigaction sa = { 0 };
+ public:
+ RunContext( RunContext const& ) = delete;
+ RunContext& operator =( RunContext const& ) = delete;
- sa.sa_handler = handleSignal;
- sa.sa_flags = SA_ONSTACK;
- for (std::size_t i = 0; i < sizeof(signalDefs)/sizeof(SignalDefs); ++i) {
- sigaction(signalDefs[i].id, &sa, &oldSigActions[i]);
- }
- }
+ explicit RunContext( IConfigPtr const& _config, IStreamingReporterPtr&& reporter );
- ~FatalConditionHandler() {
- reset();
- }
- static void reset() {
- if( isSet ) {
- // Set signals back to previous values -- hopefully nobody overwrote them in the meantime
- for( std::size_t i = 0; i < sizeof(signalDefs)/sizeof(SignalDefs); ++i ) {
- sigaction(signalDefs[i].id, &oldSigActions[i], CATCH_NULL);
- }
- // Return the old stack
- sigaltstack(&oldSigStack, CATCH_NULL);
- isSet = false;
- }
- }
- };
+ ~RunContext() override;
- bool FatalConditionHandler::isSet = false;
- struct sigaction FatalConditionHandler::oldSigActions[sizeof(signalDefs)/sizeof(SignalDefs)] = {};
- stack_t FatalConditionHandler::oldSigStack = {};
- char FatalConditionHandler::altStackMem[SIGSTKSZ] = {};
+ void testGroupStarting( std::string const& testSpec, std::size_t groupIndex, std::size_t groupsCount );
+ void testGroupEnded( std::string const& testSpec, Totals const& totals, std::size_t groupIndex, std::size_t groupsCount );
-} // namespace Catch
+ Totals runTest(TestCase const& testCase);
-# endif // CATCH_CONFIG_POSIX_SIGNALS
+ IConfigPtr config() const;
+ IStreamingReporter& reporter() const;
-#endif // not Windows
+ public: // IResultCapture
-#include <set>
-#include <string>
+ // Assertion handlers
+ void handleExpr
+ ( AssertionInfo const& info,
+ ITransientExpression const& expr,
+ AssertionReaction& reaction ) override;
+ void handleMessage
+ ( AssertionInfo const& info,
+ ResultWas::OfType resultType,
+ StringRef const& message,
+ AssertionReaction& reaction ) override;
+ void handleUnexpectedExceptionNotThrown
+ ( AssertionInfo const& info,
+ AssertionReaction& reaction ) override;
+ void handleUnexpectedInflightException
+ ( AssertionInfo const& info,
+ std::string const& message,
+ AssertionReaction& reaction ) override;
+ void handleIncomplete
+ ( AssertionInfo const& info ) override;
+ void handleNonExpr
+ ( AssertionInfo const &info,
+ ResultWas::OfType resultType,
+ AssertionReaction &reaction ) override;
-namespace Catch {
+ bool sectionStarted( SectionInfo const& sectionInfo, Counts& assertions ) override;
- class StreamRedirect {
+ void sectionEnded( SectionEndInfo const& endInfo ) override;
+ void sectionEndedEarly( SectionEndInfo const& endInfo ) override;
- public:
- StreamRedirect( std::ostream& stream, std::string& targetString )
- : m_stream( stream ),
- m_prevBuf( stream.rdbuf() ),
- m_targetString( targetString )
- {
- stream.rdbuf( m_oss.rdbuf() );
- }
+ auto acquireGeneratorTracker( SourceLineInfo const& lineInfo ) -> IGeneratorTracker& override;
- ~StreamRedirect() {
- m_targetString += m_oss.str();
- m_stream.rdbuf( m_prevBuf );
- }
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+ void benchmarkPreparing( std::string const& name ) override;
+ void benchmarkStarting( BenchmarkInfo const& info ) override;
+ void benchmarkEnded( BenchmarkStats<> const& stats ) override;
+ void benchmarkFailed( std::string const& error ) override;
+#endif // CATCH_CONFIG_ENABLE_BENCHMARKING
- private:
- std::ostream& m_stream;
- std::streambuf* m_prevBuf;
- std::ostringstream m_oss;
- std::string& m_targetString;
- };
+ void pushScopedMessage( MessageInfo const& message ) override;
+ void popScopedMessage( MessageInfo const& message ) override;
- ///////////////////////////////////////////////////////////////////////////
+ void emplaceUnscopedMessage( MessageBuilder const& builder ) override;
- class RunContext : public IResultCapture, public IRunner {
+ std::string getCurrentTestName() const override;
+
+ const AssertionResult* getLastResult() const override;
+
+ void exceptionEarlyReported() override;
+
+ void handleFatalErrorCondition( StringRef message ) override;
- RunContext( RunContext const& );
- void operator =( RunContext const& );
+ bool lastAssertionPassed() override;
+
+ void assertionPassed() override;
public:
+ // !TBD We need to do this another way!
+ bool aborting() const final;
- explicit RunContext( Ptr<IConfig const> const& _config, Ptr<IStreamingReporter> const& reporter )
- : m_runInfo( _config->name() ),
- m_context( getCurrentMutableContext() ),
- m_activeTestCase( CATCH_NULL ),
- m_config( _config ),
- m_reporter( reporter ),
- m_shouldReportUnexpected ( true )
- {
- m_context.setRunner( this );
- m_context.setConfig( m_config );
- m_context.setResultCapture( this );
- m_reporter->testRunStarting( m_runInfo );
- }
+ private:
- virtual ~RunContext() {
- m_reporter->testRunEnded( TestRunStats( m_runInfo, m_totals, aborting() ) );
- }
+ void runCurrentTest( std::string& redirectedCout, std::string& redirectedCerr );
+ void invokeActiveTestCase();
- void testGroupStarting( std::string const& testSpec, std::size_t groupIndex, std::size_t groupsCount ) {
- m_reporter->testGroupStarting( GroupInfo( testSpec, groupIndex, groupsCount ) );
- }
- void testGroupEnded( std::string const& testSpec, Totals const& totals, std::size_t groupIndex, std::size_t groupsCount ) {
- m_reporter->testGroupEnded( TestGroupStats( GroupInfo( testSpec, groupIndex, groupsCount ), totals, aborting() ) );
- }
+ void resetAssertionInfo();
+ bool testForMissingAssertions( Counts& assertions );
- Totals runTest( TestCase const& testCase ) {
- Totals prevTotals = m_totals;
+ void assertionEnded( AssertionResult const& result );
+ void reportExpr
+ ( AssertionInfo const &info,
+ ResultWas::OfType resultType,
+ ITransientExpression const *expr,
+ bool negated );
- std::string redirectedCout;
- std::string redirectedCerr;
+ void populateReaction( AssertionReaction& reaction );
- TestCaseInfo testInfo = testCase.getTestCaseInfo();
+ private:
- m_reporter->testCaseStarting( testInfo );
+ void handleUnfinishedSections();
- m_activeTestCase = &testCase;
+ TestRunInfo m_runInfo;
+ IMutableContext& m_context;
+ TestCase const* m_activeTestCase = nullptr;
+ ITracker* m_testCaseTracker = nullptr;
+ Option<AssertionResult> m_lastResult;
- do {
- ITracker& rootTracker = m_trackerContext.startRun();
- assert( rootTracker.isSectionTracker() );
- static_cast<SectionTracker&>( rootTracker ).addInitialFilters( m_config->getSectionsToRun() );
- do {
- m_trackerContext.startCycle();
- m_testCaseTracker = &SectionTracker::acquire( m_trackerContext, TestCaseTracking::NameAndLocation( testInfo.name, testInfo.lineInfo ) );
- runCurrentTest( redirectedCout, redirectedCerr );
- }
- while( !m_testCaseTracker->isSuccessfullyCompleted() && !aborting() );
- }
- // !TBD: deprecated - this will be replaced by indexed trackers
- while( getCurrentContext().advanceGeneratorsForCurrentTest() && !aborting() );
+ IConfigPtr m_config;
+ Totals m_totals;
+ IStreamingReporterPtr m_reporter;
+ std::vector<MessageInfo> m_messages;
+ std::vector<ScopedMessage> m_messageScopes; /* Keeps owners of so-called unscoped messages. */
+ AssertionInfo m_lastAssertionInfo;
+ std::vector<SectionEndInfo> m_unfinishedSections;
+ std::vector<ITracker*> m_activeSections;
+ TrackerContext m_trackerContext;
+ bool m_lastAssertionPassed = false;
+ bool m_shouldReportUnexpected = true;
+ bool m_includeSuccessfulResults;
+ };
- Totals deltaTotals = m_totals.delta( prevTotals );
- if( testInfo.expectedToFail() && deltaTotals.testCases.passed > 0 ) {
- deltaTotals.assertions.failed++;
- deltaTotals.testCases.passed--;
- deltaTotals.testCases.failed++;
- }
- m_totals.testCases += deltaTotals.testCases;
- m_reporter->testCaseEnded( TestCaseStats( testInfo,
- deltaTotals,
- redirectedCout,
- redirectedCerr,
- aborting() ) );
+ void seedRng(IConfig const& config);
+ unsigned int rngSeed();
+} // end namespace Catch
- m_activeTestCase = CATCH_NULL;
- m_testCaseTracker = CATCH_NULL;
+// end catch_run_context.h
+namespace Catch {
- return deltaTotals;
+ namespace {
+ auto operator <<( std::ostream& os, ITransientExpression const& expr ) -> std::ostream& {
+ expr.streamReconstructedExpression( os );
+ return os;
}
+ }
- Ptr<IConfig const> config() const {
- return m_config;
- }
+ LazyExpression::LazyExpression( bool isNegated )
+ : m_isNegated( isNegated )
+ {}
- private: // IResultCapture
+ LazyExpression::LazyExpression( LazyExpression const& other ) : m_isNegated( other.m_isNegated ) {}
- virtual void assertionEnded( AssertionResult const& result ) {
- if( result.getResultType() == ResultWas::Ok ) {
- m_totals.assertions.passed++;
- }
- else if( !result.isOk() ) {
- m_totals.assertions.failed++;
- }
+ LazyExpression::operator bool() const {
+ return m_transientExpression != nullptr;
+ }
- // We have no use for the return value (whether messages should be cleared), because messages were made scoped
- // and should be let to clear themselves out.
- static_cast<void>(m_reporter->assertionEnded(AssertionStats(result, m_messages, m_totals)));
+ auto operator << ( std::ostream& os, LazyExpression const& lazyExpr ) -> std::ostream& {
+ if( lazyExpr.m_isNegated )
+ os << "!";
- // Reset working state
- m_lastAssertionInfo = AssertionInfo( std::string(), m_lastAssertionInfo.lineInfo, "{Unknown expression after the reported line}" , m_lastAssertionInfo.resultDisposition );
- m_lastResult = result;
+ if( lazyExpr ) {
+ if( lazyExpr.m_isNegated && lazyExpr.m_transientExpression->isBinaryExpression() )
+ os << "(" << *lazyExpr.m_transientExpression << ")";
+ else
+ os << *lazyExpr.m_transientExpression;
}
+ else {
+ os << "{** error - unchecked empty expression requested **}";
+ }
+ return os;
+ }
- virtual bool sectionStarted (
- SectionInfo const& sectionInfo,
- Counts& assertions
- )
- {
- ITracker& sectionTracker = SectionTracker::acquire( m_trackerContext, TestCaseTracking::NameAndLocation( sectionInfo.name, sectionInfo.lineInfo ) );
- if( !sectionTracker.isOpen() )
- return false;
- m_activeSections.push_back( &sectionTracker );
+ AssertionHandler::AssertionHandler
+ ( StringRef const& macroName,
+ SourceLineInfo const& lineInfo,
+ StringRef capturedExpression,
+ ResultDisposition::Flags resultDisposition )
+ : m_assertionInfo{ macroName, lineInfo, capturedExpression, resultDisposition },
+ m_resultCapture( getResultCapture() )
+ {}
- m_lastAssertionInfo.lineInfo = sectionInfo.lineInfo;
+ void AssertionHandler::handleExpr( ITransientExpression const& expr ) {
+ m_resultCapture.handleExpr( m_assertionInfo, expr, m_reaction );
+ }
+ void AssertionHandler::handleMessage(ResultWas::OfType resultType, StringRef const& message) {
+ m_resultCapture.handleMessage( m_assertionInfo, resultType, message, m_reaction );
+ }
- m_reporter->sectionStarting( sectionInfo );
+ auto AssertionHandler::allowThrows() const -> bool {
+ return getCurrentContext().getConfig()->allowThrows();
+ }
- assertions = m_totals.assertions;
+ void AssertionHandler::complete() {
+ setCompleted();
+ if( m_reaction.shouldDebugBreak ) {
- return true;
+ // If you find your debugger stopping you here then go one level up on the
+ // call-stack for the code that caused it (typically a failed assertion)
+
+ // (To go back to the test and change execution, jump over the throw, next)
+ CATCH_BREAK_INTO_DEBUGGER();
}
- bool testForMissingAssertions( Counts& assertions ) {
- if( assertions.total() != 0 )
- return false;
- if( !m_config->warnAboutMissingAssertions() )
- return false;
- if( m_trackerContext.currentTracker().hasChildren() )
- return false;
- m_totals.assertions.failed++;
- assertions.failed++;
- return true;
+ if (m_reaction.shouldThrow) {
+#if !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS)
+ throw Catch::TestFailureException();
+#else
+ CATCH_ERROR( "Test failure requires aborting test!" );
+#endif
}
+ }
+ void AssertionHandler::setCompleted() {
+ m_completed = true;
+ }
- virtual void sectionEnded( SectionEndInfo const& endInfo ) {
- Counts assertions = m_totals.assertions - endInfo.prevAssertions;
- bool missingAssertions = testForMissingAssertions( assertions );
-
- if( !m_activeSections.empty() ) {
- m_activeSections.back()->close();
- m_activeSections.pop_back();
- }
+ void AssertionHandler::handleUnexpectedInflightException() {
+ m_resultCapture.handleUnexpectedInflightException( m_assertionInfo, Catch::translateActiveException(), m_reaction );
+ }
- m_reporter->sectionEnded( SectionStats( endInfo.sectionInfo, assertions, endInfo.durationInSeconds, missingAssertions ) );
- m_messages.clear();
- }
+ void AssertionHandler::handleExceptionThrownAsExpected() {
+ m_resultCapture.handleNonExpr(m_assertionInfo, ResultWas::Ok, m_reaction);
+ }
+ void AssertionHandler::handleExceptionNotThrownAsExpected() {
+ m_resultCapture.handleNonExpr(m_assertionInfo, ResultWas::Ok, m_reaction);
+ }
- virtual void sectionEndedEarly( SectionEndInfo const& endInfo ) {
- if( m_unfinishedSections.empty() )
- m_activeSections.back()->fail();
- else
- m_activeSections.back()->close();
- m_activeSections.pop_back();
+ void AssertionHandler::handleUnexpectedExceptionNotThrown() {
+ m_resultCapture.handleUnexpectedExceptionNotThrown( m_assertionInfo, m_reaction );
+ }
- m_unfinishedSections.push_back( endInfo );
- }
+ void AssertionHandler::handleThrowingCallSkipped() {
+ m_resultCapture.handleNonExpr(m_assertionInfo, ResultWas::Ok, m_reaction);
+ }
- virtual void pushScopedMessage( MessageInfo const& message ) {
- m_messages.push_back( message );
- }
+ // This is the overload that takes a string and infers the Equals matcher from it
+ // The more general overload, that takes any string matcher, is in catch_capture_matchers.cpp
+ void handleExceptionMatchExpr( AssertionHandler& handler, std::string const& str, StringRef const& matcherString ) {
+ handleExceptionMatchExpr( handler, Matchers::Equals( str ), matcherString );
+ }
- virtual void popScopedMessage( MessageInfo const& message ) {
- m_messages.erase( std::remove( m_messages.begin(), m_messages.end(), message ), m_messages.end() );
- }
+} // namespace Catch
+// end catch_assertionhandler.cpp
+// start catch_assertionresult.cpp
- virtual std::string getCurrentTestName() const {
- return m_activeTestCase
- ? m_activeTestCase->getTestCaseInfo().name
- : std::string();
- }
+namespace Catch {
+ AssertionResultData::AssertionResultData(ResultWas::OfType _resultType, LazyExpression const & _lazyExpression):
+ lazyExpression(_lazyExpression),
+ resultType(_resultType) {}
- virtual const AssertionResult* getLastResult() const {
- return &m_lastResult;
- }
+ std::string AssertionResultData::reconstructExpression() const {
- virtual void exceptionEarlyReported() {
- m_shouldReportUnexpected = false;
+ if( reconstructedExpression.empty() ) {
+ if( lazyExpression ) {
+ ReusableStringStream rss;
+ rss << lazyExpression;
+ reconstructedExpression = rss.str();
+ }
}
+ return reconstructedExpression;
+ }
- virtual void handleFatalErrorCondition( std::string const& message ) {
- // Don't rebuild the result -- the stringification itself can cause more fatal errors
- // Instead, fake a result data.
- AssertionResultData tempResult;
- tempResult.resultType = ResultWas::FatalErrorCondition;
- tempResult.message = message;
- AssertionResult result(m_lastAssertionInfo, tempResult);
+ AssertionResult::AssertionResult( AssertionInfo const& info, AssertionResultData const& data )
+ : m_info( info ),
+ m_resultData( data )
+ {}
- getResultCapture().assertionEnded(result);
+ // Result was a success
+ bool AssertionResult::succeeded() const {
+ return Catch::isOk( m_resultData.resultType );
+ }
- handleUnfinishedSections();
+ // Result was a success, or failure is suppressed
+ bool AssertionResult::isOk() const {
+ return Catch::isOk( m_resultData.resultType ) || shouldSuppressFailure( m_info.resultDisposition );
+ }
- // Recreate section for test case (as we will lose the one that was in scope)
- TestCaseInfo const& testCaseInfo = m_activeTestCase->getTestCaseInfo();
- SectionInfo testCaseSection( testCaseInfo.lineInfo, testCaseInfo.name, testCaseInfo.description );
+ ResultWas::OfType AssertionResult::getResultType() const {
+ return m_resultData.resultType;
+ }
- Counts assertions;
- assertions.failed = 1;
- SectionStats testCaseSectionStats( testCaseSection, assertions, 0, false );
- m_reporter->sectionEnded( testCaseSectionStats );
+ bool AssertionResult::hasExpression() const {
+ return !m_info.capturedExpression.empty();
+ }
- TestCaseInfo testInfo = m_activeTestCase->getTestCaseInfo();
+ bool AssertionResult::hasMessage() const {
+ return !m_resultData.message.empty();
+ }
- Totals deltaTotals;
- deltaTotals.testCases.failed = 1;
- m_reporter->testCaseEnded( TestCaseStats( testInfo,
- deltaTotals,
- std::string(),
- std::string(),
- false ) );
- m_totals.testCases.failed++;
- testGroupEnded( std::string(), m_totals, 1, 1 );
- m_reporter->testRunEnded( TestRunStats( m_runInfo, m_totals, false ) );
+ std::string AssertionResult::getExpression() const {
+ // Possibly overallocating by 3 characters should be basically free
+ std::string expr; expr.reserve(m_info.capturedExpression.size() + 3);
+ if (isFalseTest(m_info.resultDisposition)) {
+ expr += "!(";
+ }
+ expr += m_info.capturedExpression;
+ if (isFalseTest(m_info.resultDisposition)) {
+ expr += ')';
}
+ return expr;
+ }
- public:
- // !TBD We need to do this another way!
- bool aborting() const {
- return m_totals.assertions.failed == static_cast<std::size_t>( m_config->abortAfter() );
+ std::string AssertionResult::getExpressionInMacro() const {
+ std::string expr;
+ if( m_info.macroName.empty() )
+ expr = static_cast<std::string>(m_info.capturedExpression);
+ else {
+ expr.reserve( m_info.macroName.size() + m_info.capturedExpression.size() + 4 );
+ expr += m_info.macroName;
+ expr += "( ";
+ expr += m_info.capturedExpression;
+ expr += " )";
}
+ return expr;
+ }
- private:
+ bool AssertionResult::hasExpandedExpression() const {
+ return hasExpression() && getExpandedExpression() != getExpression();
+ }
- void runCurrentTest( std::string& redirectedCout, std::string& redirectedCerr ) {
- TestCaseInfo const& testCaseInfo = m_activeTestCase->getTestCaseInfo();
- SectionInfo testCaseSection( testCaseInfo.lineInfo, testCaseInfo.name, testCaseInfo.description );
- m_reporter->sectionStarting( testCaseSection );
- Counts prevAssertions = m_totals.assertions;
- double duration = 0;
- m_shouldReportUnexpected = true;
- try {
- m_lastAssertionInfo = AssertionInfo( "TEST_CASE", testCaseInfo.lineInfo, std::string(), ResultDisposition::Normal );
+ std::string AssertionResult::getExpandedExpression() const {
+ std::string expr = m_resultData.reconstructExpression();
+ return expr.empty()
+ ? getExpression()
+ : expr;
+ }
- seedRng( *m_config );
+ std::string AssertionResult::getMessage() const {
+ return m_resultData.message;
+ }
+ SourceLineInfo AssertionResult::getSourceInfo() const {
+ return m_info.lineInfo;
+ }
- Timer timer;
- timer.start();
- if( m_reporter->getPreferences().shouldRedirectStdOut ) {
- StreamRedirect coutRedir( Catch::cout(), redirectedCout );
- StreamRedirect cerrRedir( Catch::cerr(), redirectedCerr );
- invokeActiveTestCase();
- }
- else {
- invokeActiveTestCase();
- }
- duration = timer.getElapsedSeconds();
- }
- catch( TestFailureException& ) {
- // This just means the test was aborted due to failure
- }
- catch(...) {
- // Under CATCH_CONFIG_FAST_COMPILE, unexpected exceptions under REQUIRE assertions
- // are reported without translation at the point of origin.
- if (m_shouldReportUnexpected) {
- makeUnexpectedResultBuilder().useActiveException();
- }
- }
- m_testCaseTracker->close();
- handleUnfinishedSections();
- m_messages.clear();
+ StringRef AssertionResult::getTestMacroName() const {
+ return m_info.macroName;
+ }
- Counts assertions = m_totals.assertions - prevAssertions;
- bool missingAssertions = testForMissingAssertions( assertions );
+} // end namespace Catch
+// end catch_assertionresult.cpp
+// start catch_capture_matchers.cpp
- if( testCaseInfo.okToFail() ) {
- std::swap( assertions.failedButOk, assertions.failed );
- m_totals.assertions.failed -= assertions.failedButOk;
- m_totals.assertions.failedButOk += assertions.failedButOk;
- }
+namespace Catch {
- SectionStats testCaseSectionStats( testCaseSection, assertions, duration, missingAssertions );
- m_reporter->sectionEnded( testCaseSectionStats );
- }
+ using StringMatcher = Matchers::Impl::MatcherBase<std::string>;
- void invokeActiveTestCase() {
- FatalConditionHandler fatalConditionHandler; // Handle signals
- m_activeTestCase->invoke();
- fatalConditionHandler.reset();
- }
+ // This is the general overload that takes a any string matcher
+ // There is another overload, in catch_assertionhandler.h/.cpp, that only takes a string and infers
+ // the Equals matcher (so the header does not mention matchers)
+ void handleExceptionMatchExpr( AssertionHandler& handler, StringMatcher const& matcher, StringRef const& matcherString ) {
+ std::string exceptionMessage = Catch::translateActiveException();
+ MatchExpr<std::string, StringMatcher const&> expr( exceptionMessage, matcher, matcherString );
+ handler.handleExpr( expr );
+ }
- private:
+} // namespace Catch
+// end catch_capture_matchers.cpp
+// start catch_commandline.cpp
- ResultBuilder makeUnexpectedResultBuilder() const {
- return ResultBuilder( m_lastAssertionInfo.macroName.c_str(),
- m_lastAssertionInfo.lineInfo,
- m_lastAssertionInfo.capturedExpression.c_str(),
- m_lastAssertionInfo.resultDisposition );
- }
+// start catch_commandline.h
- void handleUnfinishedSections() {
- // If sections ended prematurely due to an exception we stored their
- // infos here so we can tear them down outside the unwind process.
- for( std::vector<SectionEndInfo>::const_reverse_iterator it = m_unfinishedSections.rbegin(),
- itEnd = m_unfinishedSections.rend();
- it != itEnd;
- ++it )
- sectionEnded( *it );
- m_unfinishedSections.clear();
- }
+// start catch_clara.h
- TestRunInfo m_runInfo;
- IMutableContext& m_context;
- TestCase const* m_activeTestCase;
- ITracker* m_testCaseTracker;
- ITracker* m_currentSectionTracker;
- AssertionResult m_lastResult;
+// Use Catch's value for console width (store Clara's off to the side, if present)
+#ifdef CLARA_CONFIG_CONSOLE_WIDTH
+#define CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH
+#undef CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH
+#endif
+#define CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH CATCH_CONFIG_CONSOLE_WIDTH-1
- Ptr<IConfig const> m_config;
- Totals m_totals;
- Ptr<IStreamingReporter> m_reporter;
- std::vector<MessageInfo> m_messages;
- AssertionInfo m_lastAssertionInfo;
- std::vector<SectionEndInfo> m_unfinishedSections;
- std::vector<ITracker*> m_activeSections;
- TrackerContext m_trackerContext;
- bool m_shouldReportUnexpected;
- };
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wweak-vtables"
+#pragma clang diagnostic ignored "-Wexit-time-destructors"
+#pragma clang diagnostic ignored "-Wshadow"
+#endif
- IResultCapture& getResultCapture() {
- if( IResultCapture* capture = getCurrentContext().getResultCapture() )
- return *capture;
- else
- throw std::logic_error( "No result capture instance" );
- }
+// start clara.hpp
+// Copyright 2017 Two Blue Cubes Ltd. All rights reserved.
+//
+// Distributed under the Boost Software License, Version 1.0. (See accompanying
+// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
+//
+// See https://github.com/philsquared/Clara for more details
-} // end namespace Catch
+// Clara v1.1.5
-// #included from: internal/catch_version.h
-#define TWOBLUECUBES_CATCH_VERSION_H_INCLUDED
-namespace Catch {
+#ifndef CATCH_CLARA_CONFIG_CONSOLE_WIDTH
+#define CATCH_CLARA_CONFIG_CONSOLE_WIDTH 80
+#endif
- // Versioning information
- struct Version {
- Version( unsigned int _majorVersion,
- unsigned int _minorVersion,
- unsigned int _patchNumber,
- char const * const _branchName,
- unsigned int _buildNumber );
+#ifndef CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH
+#define CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH CATCH_CLARA_CONFIG_CONSOLE_WIDTH
+#endif
- unsigned int const majorVersion;
- unsigned int const minorVersion;
- unsigned int const patchNumber;
+#ifndef CLARA_CONFIG_OPTIONAL_TYPE
+#ifdef __has_include
+#if __has_include(<optional>) && __cplusplus >= 201703L
+#include <optional>
+#define CLARA_CONFIG_OPTIONAL_TYPE std::optional
+#endif
+#endif
+#endif
- // buildNumber is only used if branchName is not null
- char const * const branchName;
- unsigned int const buildNumber;
+// ----------- #included from clara_textflow.hpp -----------
- friend std::ostream& operator << ( std::ostream& os, Version const& version );
+// TextFlowCpp
+//
+// A single-header library for wrapping and laying out basic text, by Phil Nash
+//
+// Distributed under the Boost Software License, Version 1.0. (See accompanying
+// file LICENSE.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
+//
+// This project is hosted at https://github.com/philsquared/textflowcpp
- private:
- void operator=( Version const& );
- };
- inline Version libraryVersion();
-}
+#include <cassert>
+#include <ostream>
+#include <sstream>
+#include <vector>
-#include <fstream>
-#include <stdlib.h>
-#include <limits>
+#ifndef CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH
+#define CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH 80
+#endif
namespace Catch {
+namespace clara {
+namespace TextFlow {
- Ptr<IStreamingReporter> createReporter( std::string const& reporterName, Ptr<Config> const& config ) {
- Ptr<IStreamingReporter> reporter = getRegistryHub().getReporterRegistry().create( reporterName, config.get() );
- if( !reporter ) {
- std::ostringstream oss;
- oss << "No reporter registered with name: '" << reporterName << "'";
- throw std::domain_error( oss.str() );
- }
- return reporter;
- }
+inline auto isWhitespace(char c) -> bool {
+ static std::string chars = " \t\n\r";
+ return chars.find(c) != std::string::npos;
+}
+inline auto isBreakableBefore(char c) -> bool {
+ static std::string chars = "[({<|";
+ return chars.find(c) != std::string::npos;
+}
+inline auto isBreakableAfter(char c) -> bool {
+ static std::string chars = "])}>.,:;*+-=&/\\";
+ return chars.find(c) != std::string::npos;
+}
- Ptr<IStreamingReporter> makeReporter( Ptr<Config> const& config ) {
- std::vector<std::string> reporters = config->getReporterNames();
- if( reporters.empty() )
- reporters.push_back( "console" );
+class Columns;
- Ptr<IStreamingReporter> reporter;
- for( std::vector<std::string>::const_iterator it = reporters.begin(), itEnd = reporters.end();
- it != itEnd;
- ++it )
- reporter = addReporter( reporter, createReporter( *it, config ) );
- return reporter;
- }
- Ptr<IStreamingReporter> addListeners( Ptr<IConfig const> const& config, Ptr<IStreamingReporter> reporters ) {
- IReporterRegistry::Listeners listeners = getRegistryHub().getReporterRegistry().getListeners();
- for( IReporterRegistry::Listeners::const_iterator it = listeners.begin(), itEnd = listeners.end();
- it != itEnd;
- ++it )
- reporters = addReporter(reporters, (*it)->create( ReporterConfig( config ) ) );
- return reporters;
- }
+class Column {
+ std::vector<std::string> m_strings;
+ size_t m_width = CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH;
+ size_t m_indent = 0;
+ size_t m_initialIndent = std::string::npos;
- Totals runTests( Ptr<Config> const& config ) {
+public:
+ class iterator {
+ friend Column;
+
+ Column const& m_column;
+ size_t m_stringIndex = 0;
+ size_t m_pos = 0;
+
+ size_t m_len = 0;
+ size_t m_end = 0;
+ bool m_suffix = false;
+
+ iterator(Column const& column, size_t stringIndex)
+ : m_column(column),
+ m_stringIndex(stringIndex) {}
+
+ auto line() const -> std::string const& { return m_column.m_strings[m_stringIndex]; }
+
+ auto isBoundary(size_t at) const -> bool {
+ assert(at > 0);
+ assert(at <= line().size());
+
+ return at == line().size() ||
+ (isWhitespace(line()[at]) && !isWhitespace(line()[at - 1])) ||
+ isBreakableBefore(line()[at]) ||
+ isBreakableAfter(line()[at - 1]);
+ }
+
+ void calcLength() {
+ assert(m_stringIndex < m_column.m_strings.size());
+
+ m_suffix = false;
+ auto width = m_column.m_width - indent();
+ m_end = m_pos;
+ if (line()[m_pos] == '\n') {
+ ++m_end;
+ }
+ while (m_end < line().size() && line()[m_end] != '\n')
+ ++m_end;
+
+ if (m_end < m_pos + width) {
+ m_len = m_end - m_pos;
+ } else {
+ size_t len = width;
+ while (len > 0 && !isBoundary(m_pos + len))
+ --len;
+ while (len > 0 && isWhitespace(line()[m_pos + len - 1]))
+ --len;
+
+ if (len > 0) {
+ m_len = len;
+ } else {
+ m_suffix = true;
+ m_len = width - 1;
+ }
+ }
+ }
+
+ auto indent() const -> size_t {
+ auto initial = m_pos == 0 && m_stringIndex == 0 ? m_column.m_initialIndent : std::string::npos;
+ return initial == std::string::npos ? m_column.m_indent : initial;
+ }
+
+ auto addIndentAndSuffix(std::string const &plain) const -> std::string {
+ return std::string(indent(), ' ') + (m_suffix ? plain + "-" : plain);
+ }
+
+ public:
+ using difference_type = std::ptrdiff_t;
+ using value_type = std::string;
+ using pointer = value_type * ;
+ using reference = value_type & ;
+ using iterator_category = std::forward_iterator_tag;
+
+ explicit iterator(Column const& column) : m_column(column) {
+ assert(m_column.m_width > m_column.m_indent);
+ assert(m_column.m_initialIndent == std::string::npos || m_column.m_width > m_column.m_initialIndent);
+ calcLength();
+ if (m_len == 0)
+ m_stringIndex++; // Empty string
+ }
+
+ auto operator *() const -> std::string {
+ assert(m_stringIndex < m_column.m_strings.size());
+ assert(m_pos <= m_end);
+ return addIndentAndSuffix(line().substr(m_pos, m_len));
+ }
+
+ auto operator ++() -> iterator& {
+ m_pos += m_len;
+ if (m_pos < line().size() && line()[m_pos] == '\n')
+ m_pos += 1;
+ else
+ while (m_pos < line().size() && isWhitespace(line()[m_pos]))
+ ++m_pos;
+
+ if (m_pos == line().size()) {
+ m_pos = 0;
+ ++m_stringIndex;
+ }
+ if (m_stringIndex < m_column.m_strings.size())
+ calcLength();
+ return *this;
+ }
+ auto operator ++(int) -> iterator {
+ iterator prev(*this);
+ operator++();
+ return prev;
+ }
+
+ auto operator ==(iterator const& other) const -> bool {
+ return
+ m_pos == other.m_pos &&
+ m_stringIndex == other.m_stringIndex &&
+ &m_column == &other.m_column;
+ }
+ auto operator !=(iterator const& other) const -> bool {
+ return !operator==(other);
+ }
+ };
+ using const_iterator = iterator;
+
+ explicit Column(std::string const& text) { m_strings.push_back(text); }
+
+ auto width(size_t newWidth) -> Column& {
+ assert(newWidth > 0);
+ m_width = newWidth;
+ return *this;
+ }
+ auto indent(size_t newIndent) -> Column& {
+ m_indent = newIndent;
+ return *this;
+ }
+ auto initialIndent(size_t newIndent) -> Column& {
+ m_initialIndent = newIndent;
+ return *this;
+ }
+
+ auto width() const -> size_t { return m_width; }
+ auto begin() const -> iterator { return iterator(*this); }
+ auto end() const -> iterator { return { *this, m_strings.size() }; }
+
+ inline friend std::ostream& operator << (std::ostream& os, Column const& col) {
+ bool first = true;
+ for (auto line : col) {
+ if (first)
+ first = false;
+ else
+ os << "\n";
+ os << line;
+ }
+ return os;
+ }
+
+ auto operator + (Column const& other)->Columns;
+
+ auto toString() const -> std::string {
+ std::ostringstream oss;
+ oss << *this;
+ return oss.str();
+ }
+};
- Ptr<IConfig const> iconfig = config.get();
+class Spacer : public Column {
- Ptr<IStreamingReporter> reporter = makeReporter( config );
- reporter = addListeners( iconfig, reporter );
+public:
+ explicit Spacer(size_t spaceWidth) : Column("") {
+ width(spaceWidth);
+ }
+};
- RunContext context( iconfig, reporter );
+class Columns {
+ std::vector<Column> m_columns;
- Totals totals;
+public:
- context.testGroupStarting( config->name(), 1, 1 );
+ class iterator {
+ friend Columns;
+ struct EndTag {};
+
+ std::vector<Column> const& m_columns;
+ std::vector<Column::iterator> m_iterators;
+ size_t m_activeIterators;
+
+ iterator(Columns const& columns, EndTag)
+ : m_columns(columns.m_columns),
+ m_activeIterators(0) {
+ m_iterators.reserve(m_columns.size());
+
+ for (auto const& col : m_columns)
+ m_iterators.push_back(col.end());
+ }
+
+ public:
+ using difference_type = std::ptrdiff_t;
+ using value_type = std::string;
+ using pointer = value_type * ;
+ using reference = value_type & ;
+ using iterator_category = std::forward_iterator_tag;
+
+ explicit iterator(Columns const& columns)
+ : m_columns(columns.m_columns),
+ m_activeIterators(m_columns.size()) {
+ m_iterators.reserve(m_columns.size());
+
+ for (auto const& col : m_columns)
+ m_iterators.push_back(col.begin());
+ }
+
+ auto operator ==(iterator const& other) const -> bool {
+ return m_iterators == other.m_iterators;
+ }
+ auto operator !=(iterator const& other) const -> bool {
+ return m_iterators != other.m_iterators;
+ }
+ auto operator *() const -> std::string {
+ std::string row, padding;
+
+ for (size_t i = 0; i < m_columns.size(); ++i) {
+ auto width = m_columns[i].width();
+ if (m_iterators[i] != m_columns[i].end()) {
+ std::string col = *m_iterators[i];
+ row += padding + col;
+ if (col.size() < width)
+ padding = std::string(width - col.size(), ' ');
+ else
+ padding = "";
+ } else {
+ padding += std::string(width, ' ');
+ }
+ }
+ return row;
+ }
+ auto operator ++() -> iterator& {
+ for (size_t i = 0; i < m_columns.size(); ++i) {
+ if (m_iterators[i] != m_columns[i].end())
+ ++m_iterators[i];
+ }
+ return *this;
+ }
+ auto operator ++(int) -> iterator {
+ iterator prev(*this);
+ operator++();
+ return prev;
+ }
+ };
+ using const_iterator = iterator;
+
+ auto begin() const -> iterator { return iterator(*this); }
+ auto end() const -> iterator { return { *this, iterator::EndTag() }; }
+
+ auto operator += (Column const& col) -> Columns& {
+ m_columns.push_back(col);
+ return *this;
+ }
+ auto operator + (Column const& col) -> Columns {
+ Columns combined = *this;
+ combined += col;
+ return combined;
+ }
+
+ inline friend std::ostream& operator << (std::ostream& os, Columns const& cols) {
+
+ bool first = true;
+ for (auto line : cols) {
+ if (first)
+ first = false;
+ else
+ os << "\n";
+ os << line;
+ }
+ return os;
+ }
+
+ auto toString() const -> std::string {
+ std::ostringstream oss;
+ oss << *this;
+ return oss.str();
+ }
+};
- TestSpec testSpec = config->testSpec();
- if( !testSpec.hasFilters() )
- testSpec = TestSpecParser( ITagAliasRegistry::get() ).parse( "~[.]" ).testSpec(); // All not hidden tests
+inline auto Column::operator + (Column const& other) -> Columns {
+ Columns cols;
+ cols += *this;
+ cols += other;
+ return cols;
+}
+}
- std::vector<TestCase> const& allTestCases = getAllTestCasesSorted( *iconfig );
- for( std::vector<TestCase>::const_iterator it = allTestCases.begin(), itEnd = allTestCases.end();
- it != itEnd;
- ++it ) {
- if( !context.aborting() && matchTest( *it, testSpec, *iconfig ) )
- totals += context.runTest( *it );
- else
- reporter->skipTest( *it );
- }
+}
+}
- context.testGroupEnded( iconfig->name(), totals, 1, 1 );
- return totals;
- }
+// ----------- end of #include from clara_textflow.hpp -----------
+// ........... back in clara.hpp
- void applyFilenamesAsTags( IConfig const& config ) {
- std::vector<TestCase> const& tests = getAllTestCasesSorted( config );
- for(std::size_t i = 0; i < tests.size(); ++i ) {
- TestCase& test = const_cast<TestCase&>( tests[i] );
- std::set<std::string> tags = test.tags;
+#include <cctype>
+#include <string>
+#include <memory>
+#include <set>
+#include <algorithm>
- std::string filename = test.lineInfo.file;
- std::string::size_type lastSlash = filename.find_last_of( "\\/" );
- if( lastSlash != std::string::npos )
- filename = filename.substr( lastSlash+1 );
+#if !defined(CATCH_PLATFORM_WINDOWS) && ( defined(WIN32) || defined(__WIN32__) || defined(_WIN32) || defined(_MSC_VER) )
+#define CATCH_PLATFORM_WINDOWS
+#endif
- std::string::size_type lastDot = filename.find_last_of( "." );
- if( lastDot != std::string::npos )
- filename = filename.substr( 0, lastDot );
+namespace Catch { namespace clara {
+namespace detail {
- tags.insert( "#" + filename );
- setTags( test, tags );
- }
- }
+ // Traits for extracting arg and return type of lambdas (for single argument lambdas)
+ template<typename L>
+ struct UnaryLambdaTraits : UnaryLambdaTraits<decltype( &L::operator() )> {};
- class Session : NonCopyable {
- static bool alreadyInstantiated;
+ template<typename ClassT, typename ReturnT, typename... Args>
+ struct UnaryLambdaTraits<ReturnT( ClassT::* )( Args... ) const> {
+ static const bool isValid = false;
+ };
- public:
+ template<typename ClassT, typename ReturnT, typename ArgT>
+ struct UnaryLambdaTraits<ReturnT( ClassT::* )( ArgT ) const> {
+ static const bool isValid = true;
+ using ArgType = typename std::remove_const<typename std::remove_reference<ArgT>::type>::type;
+ using ReturnType = ReturnT;
+ };
- struct OnUnusedOptions { enum DoWhat { Ignore, Fail }; };
+ class TokenStream;
- Session()
- : m_cli( makeCommandLineParser() ) {
- if( alreadyInstantiated ) {
- std::string msg = "Only one instance of Catch::Session can ever be used";
- Catch::cerr() << msg << std::endl;
- throw std::logic_error( msg );
- }
- alreadyInstantiated = true;
- }
- ~Session() {
- Catch::cleanUp();
- }
+ // Transport for raw args (copied from main args, or supplied via init list for testing)
+ class Args {
+ friend TokenStream;
+ std::string m_exeName;
+ std::vector<std::string> m_args;
- void showHelp( std::string const& processName ) {
- Catch::cout() << "\nCatch v" << libraryVersion() << "\n";
+ public:
+ Args( int argc, char const* const* argv )
+ : m_exeName(argv[0]),
+ m_args(argv + 1, argv + argc) {}
+
+ Args( std::initializer_list<std::string> args )
+ : m_exeName( *args.begin() ),
+ m_args( args.begin()+1, args.end() )
+ {}
- m_cli.usage( Catch::cout(), processName );
- Catch::cout() << "For more detail usage please see the project docs\n" << std::endl;
+ auto exeName() const -> std::string {
+ return m_exeName;
}
+ };
- int applyCommandLine( int argc, char const* const* const argv, OnUnusedOptions::DoWhat unusedOptionBehaviour = OnUnusedOptions::Fail ) {
- try {
- m_cli.setThrowOnUnrecognisedTokens( unusedOptionBehaviour == OnUnusedOptions::Fail );
- m_unusedTokens = m_cli.parseInto( Clara::argsToVector( argc, argv ), m_configData );
- if( m_configData.showHelp )
- showHelp( m_configData.processName );
- m_config.reset();
- }
- catch( std::exception& ex ) {
- {
- Colour colourGuard( Colour::Red );
- Catch::cerr()
- << "\nError(s) in input:\n"
- << Text( ex.what(), TextAttributes().setIndent(2) )
- << "\n\n";
+ // Wraps a token coming from a token stream. These may not directly correspond to strings as a single string
+ // may encode an option + its argument if the : or = form is used
+ enum class TokenType {
+ Option, Argument
+ };
+ struct Token {
+ TokenType type;
+ std::string token;
+ };
+
+ inline auto isOptPrefix( char c ) -> bool {
+ return c == '-'
+#ifdef CATCH_PLATFORM_WINDOWS
+ || c == '/'
+#endif
+ ;
+ }
+
+ // Abstracts iterators into args as a stream of tokens, with option arguments uniformly handled
+ class TokenStream {
+ using Iterator = std::vector<std::string>::const_iterator;
+ Iterator it;
+ Iterator itEnd;
+ std::vector<Token> m_tokenBuffer;
+
+ void loadBuffer() {
+ m_tokenBuffer.resize( 0 );
+
+ // Skip any empty strings
+ while( it != itEnd && it->empty() )
+ ++it;
+
+ if( it != itEnd ) {
+ auto const &next = *it;
+ if( isOptPrefix( next[0] ) ) {
+ auto delimiterPos = next.find_first_of( " :=" );
+ if( delimiterPos != std::string::npos ) {
+ m_tokenBuffer.push_back( { TokenType::Option, next.substr( 0, delimiterPos ) } );
+ m_tokenBuffer.push_back( { TokenType::Argument, next.substr( delimiterPos + 1 ) } );
+ } else {
+ if( next[1] != '-' && next.size() > 2 ) {
+ std::string opt = "- ";
+ for( size_t i = 1; i < next.size(); ++i ) {
+ opt[1] = next[i];
+ m_tokenBuffer.push_back( { TokenType::Option, opt } );
+ }
+ } else {
+ m_tokenBuffer.push_back( { TokenType::Option, next } );
+ }
+ }
+ } else {
+ m_tokenBuffer.push_back( { TokenType::Argument, next } );
}
- m_cli.usage( Catch::cout(), m_configData.processName );
- return (std::numeric_limits<int>::max)();
}
- return 0;
}
- void useConfigData( ConfigData const& _configData ) {
- m_configData = _configData;
- m_config.reset();
+ public:
+ explicit TokenStream( Args const &args ) : TokenStream( args.m_args.begin(), args.m_args.end() ) {}
+
+ TokenStream( Iterator it, Iterator itEnd ) : it( it ), itEnd( itEnd ) {
+ loadBuffer();
}
- int run( int argc, char const* const* const argv ) {
+ explicit operator bool() const {
+ return !m_tokenBuffer.empty() || it != itEnd;
+ }
- int returnCode = applyCommandLine( argc, argv );
- if( returnCode == 0 )
- returnCode = run();
- return returnCode;
+ auto count() const -> size_t { return m_tokenBuffer.size() + (itEnd - it); }
+
+ auto operator*() const -> Token {
+ assert( !m_tokenBuffer.empty() );
+ return m_tokenBuffer.front();
}
- #if defined(WIN32) && defined(UNICODE)
- int run( int argc, wchar_t const* const* const argv ) {
+ auto operator->() const -> Token const * {
+ assert( !m_tokenBuffer.empty() );
+ return &m_tokenBuffer.front();
+ }
- char **utf8Argv = new char *[ argc ];
+ auto operator++() -> TokenStream & {
+ if( m_tokenBuffer.size() >= 2 ) {
+ m_tokenBuffer.erase( m_tokenBuffer.begin() );
+ } else {
+ if( it != itEnd )
+ ++it;
+ loadBuffer();
+ }
+ return *this;
+ }
+ };
- for ( int i = 0; i < argc; ++i ) {
- int bufSize = WideCharToMultiByte( CP_UTF8, 0, argv[i], -1, NULL, 0, NULL, NULL );
+ class ResultBase {
+ public:
+ enum Type {
+ Ok, LogicError, RuntimeError
+ };
- utf8Argv[ i ] = new char[ bufSize ];
+ protected:
+ ResultBase( Type type ) : m_type( type ) {}
+ virtual ~ResultBase() = default;
- WideCharToMultiByte( CP_UTF8, 0, argv[i], -1, utf8Argv[i], bufSize, NULL, NULL );
- }
+ virtual void enforceOk() const = 0;
- int returnCode = applyCommandLine( argc, utf8Argv );
- if( returnCode == 0 )
- returnCode = run();
+ Type m_type;
+ };
- for ( int i = 0; i < argc; ++i )
- delete [] utf8Argv[ i ];
+ template<typename T>
+ class ResultValueBase : public ResultBase {
+ public:
+ auto value() const -> T const & {
+ enforceOk();
+ return m_value;
+ }
- delete [] utf8Argv;
+ protected:
+ ResultValueBase( Type type ) : ResultBase( type ) {}
- return returnCode;
+ ResultValueBase( ResultValueBase const &other ) : ResultBase( other ) {
+ if( m_type == ResultBase::Ok )
+ new( &m_value ) T( other.m_value );
}
- #endif
- int run() {
- if( m_configData.showHelp )
- return 0;
+ ResultValueBase( Type, T const &value ) : ResultBase( Ok ) {
+ new( &m_value ) T( value );
+ }
- try
- {
- config(); // Force config to be constructed
+ auto operator=( ResultValueBase const &other ) -> ResultValueBase & {
+ if( m_type == ResultBase::Ok )
+ m_value.~T();
+ ResultBase::operator=(other);
+ if( m_type == ResultBase::Ok )
+ new( &m_value ) T( other.m_value );
+ return *this;
+ }
- seedRng( *m_config );
+ ~ResultValueBase() override {
+ if( m_type == Ok )
+ m_value.~T();
+ }
- if( m_configData.filenamesAsTags )
- applyFilenamesAsTags( *m_config );
+ union {
+ T m_value;
+ };
+ };
- // Handle list request
- if( Option<std::size_t> listed = list( config() ) )
- return static_cast<int>( *listed );
+ template<>
+ class ResultValueBase<void> : public ResultBase {
+ protected:
+ using ResultBase::ResultBase;
+ };
- return static_cast<int>( runTests( m_config ).assertions.failed );
- }
- catch( std::exception& ex ) {
- Catch::cerr() << ex.what() << std::endl;
- return (std::numeric_limits<int>::max)();
- }
+ template<typename T = void>
+ class BasicResult : public ResultValueBase<T> {
+ public:
+ template<typename U>
+ explicit BasicResult( BasicResult<U> const &other )
+ : ResultValueBase<T>( other.type() ),
+ m_errorMessage( other.errorMessage() )
+ {
+ assert( type() != ResultBase::Ok );
}
- Clara::CommandLine<ConfigData> const& cli() const {
- return m_cli;
- }
- std::vector<Clara::Parser::Token> const& unusedTokens() const {
- return m_unusedTokens;
- }
- ConfigData& configData() {
- return m_configData;
+ template<typename U>
+ static auto ok( U const &value ) -> BasicResult { return { ResultBase::Ok, value }; }
+ static auto ok() -> BasicResult { return { ResultBase::Ok }; }
+ static auto logicError( std::string const &message ) -> BasicResult { return { ResultBase::LogicError, message }; }
+ static auto runtimeError( std::string const &message ) -> BasicResult { return { ResultBase::RuntimeError, message }; }
+
+ explicit operator bool() const { return m_type == ResultBase::Ok; }
+ auto type() const -> ResultBase::Type { return m_type; }
+ auto errorMessage() const -> std::string { return m_errorMessage; }
+
+ protected:
+ void enforceOk() const override {
+
+ // Errors shouldn't reach this point, but if they do
+ // the actual error message will be in m_errorMessage
+ assert( m_type != ResultBase::LogicError );
+ assert( m_type != ResultBase::RuntimeError );
+ if( m_type != ResultBase::Ok )
+ std::abort();
}
- Config& config() {
- if( !m_config )
- m_config = new Config( m_configData );
- return *m_config;
+
+ std::string m_errorMessage; // Only populated if resultType is an error
+
+ BasicResult( ResultBase::Type type, std::string const &message )
+ : ResultValueBase<T>(type),
+ m_errorMessage(message)
+ {
+ assert( m_type != ResultBase::Ok );
}
- private:
- Clara::CommandLine<ConfigData> m_cli;
- std::vector<Clara::Parser::Token> m_unusedTokens;
- ConfigData m_configData;
- Ptr<Config> m_config;
+
+ using ResultValueBase<T>::ResultValueBase;
+ using ResultBase::m_type;
};
- bool Session::alreadyInstantiated = false;
+ enum class ParseResultType {
+ Matched, NoMatch, ShortCircuitAll, ShortCircuitSame
+ };
-} // end namespace Catch
+ class ParseState {
+ public:
-// #included from: catch_registry_hub.hpp
-#define TWOBLUECUBES_CATCH_REGISTRY_HUB_HPP_INCLUDED
+ ParseState( ParseResultType type, TokenStream const &remainingTokens )
+ : m_type(type),
+ m_remainingTokens( remainingTokens )
+ {}
-// #included from: catch_test_case_registry_impl.hpp
-#define TWOBLUECUBES_CATCH_TEST_CASE_REGISTRY_IMPL_HPP_INCLUDED
+ auto type() const -> ParseResultType { return m_type; }
+ auto remainingTokens() const -> TokenStream { return m_remainingTokens; }
-#include <vector>
-#include <set>
-#include <sstream>
-#include <algorithm>
+ private:
+ ParseResultType m_type;
+ TokenStream m_remainingTokens;
+ };
-namespace Catch {
+ using Result = BasicResult<void>;
+ using ParserResult = BasicResult<ParseResultType>;
+ using InternalParseResult = BasicResult<ParseState>;
- struct RandomNumberGenerator {
- typedef std::ptrdiff_t result_type;
+ struct HelpColumns {
+ std::string left;
+ std::string right;
+ };
- result_type operator()( result_type n ) const { return std::rand() % n; }
+ template<typename T>
+ inline auto convertInto( std::string const &source, T& target ) -> ParserResult {
+ std::stringstream ss;
+ ss << source;
+ ss >> target;
+ if( ss.fail() )
+ return ParserResult::runtimeError( "Unable to convert '" + source + "' to destination type" );
+ else
+ return ParserResult::ok( ParseResultType::Matched );
+ }
+ inline auto convertInto( std::string const &source, std::string& target ) -> ParserResult {
+ target = source;
+ return ParserResult::ok( ParseResultType::Matched );
+ }
+ inline auto convertInto( std::string const &source, bool &target ) -> ParserResult {
+ std::string srcLC = source;
+ std::transform( srcLC.begin(), srcLC.end(), srcLC.begin(), []( char c ) { return static_cast<char>( std::tolower(c) ); } );
+ if (srcLC == "y" || srcLC == "1" || srcLC == "true" || srcLC == "yes" || srcLC == "on")
+ target = true;
+ else if (srcLC == "n" || srcLC == "0" || srcLC == "false" || srcLC == "no" || srcLC == "off")
+ target = false;
+ else
+ return ParserResult::runtimeError( "Expected a boolean value but did not recognise: '" + source + "'" );
+ return ParserResult::ok( ParseResultType::Matched );
+ }
+#ifdef CLARA_CONFIG_OPTIONAL_TYPE
+ template<typename T>
+ inline auto convertInto( std::string const &source, CLARA_CONFIG_OPTIONAL_TYPE<T>& target ) -> ParserResult {
+ T temp;
+ auto result = convertInto( source, temp );
+ if( result )
+ target = std::move(temp);
+ return result;
+ }
+#endif // CLARA_CONFIG_OPTIONAL_TYPE
+
+ struct NonCopyable {
+ NonCopyable() = default;
+ NonCopyable( NonCopyable const & ) = delete;
+ NonCopyable( NonCopyable && ) = delete;
+ NonCopyable &operator=( NonCopyable const & ) = delete;
+ NonCopyable &operator=( NonCopyable && ) = delete;
+ };
-#ifdef CATCH_CONFIG_CPP11_SHUFFLE
- static constexpr result_type min() { return 0; }
- static constexpr result_type max() { return 1000000; }
- result_type operator()() const { return std::rand() % max(); }
-#endif
- template<typename V>
- static void shuffle( V& vector ) {
- RandomNumberGenerator rng;
-#ifdef CATCH_CONFIG_CPP11_SHUFFLE
- std::shuffle( vector.begin(), vector.end(), rng );
-#else
- std::random_shuffle( vector.begin(), vector.end(), rng );
-#endif
- }
+ struct BoundRef : NonCopyable {
+ virtual ~BoundRef() = default;
+ virtual auto isContainer() const -> bool { return false; }
+ virtual auto isFlag() const -> bool { return false; }
+ };
+ struct BoundValueRefBase : BoundRef {
+ virtual auto setValue( std::string const &arg ) -> ParserResult = 0;
+ };
+ struct BoundFlagRefBase : BoundRef {
+ virtual auto setFlag( bool flag ) -> ParserResult = 0;
+ virtual auto isFlag() const -> bool { return true; }
};
- inline std::vector<TestCase> sortTests( IConfig const& config, std::vector<TestCase> const& unsortedTestCases ) {
+ template<typename T>
+ struct BoundValueRef : BoundValueRefBase {
+ T &m_ref;
- std::vector<TestCase> sorted = unsortedTestCases;
+ explicit BoundValueRef( T &ref ) : m_ref( ref ) {}
- switch( config.runOrder() ) {
- case RunTests::InLexicographicalOrder:
- std::sort( sorted.begin(), sorted.end() );
- break;
- case RunTests::InRandomOrder:
- {
- seedRng( config );
- RandomNumberGenerator::shuffle( sorted );
- }
- break;
- case RunTests::InDeclarationOrder:
- // already in declaration order
- break;
+ auto setValue( std::string const &arg ) -> ParserResult override {
+ return convertInto( arg, m_ref );
}
- return sorted;
- }
- bool matchTest( TestCase const& testCase, TestSpec const& testSpec, IConfig const& config ) {
- return testSpec.matches( testCase ) && ( config.allowThrows() || !testCase.throws() );
- }
+ };
- void enforceNoDuplicateTestCases( std::vector<TestCase> const& functions ) {
- std::set<TestCase> seenFunctions;
- for( std::vector<TestCase>::const_iterator it = functions.begin(), itEnd = functions.end();
- it != itEnd;
- ++it ) {
- std::pair<std::set<TestCase>::const_iterator, bool> prev = seenFunctions.insert( *it );
- if( !prev.second ) {
- std::ostringstream ss;
+ template<typename T>
+ struct BoundValueRef<std::vector<T>> : BoundValueRefBase {
+ std::vector<T> &m_ref;
- ss << Colour( Colour::Red )
- << "error: TEST_CASE( \"" << it->name << "\" ) already defined.\n"
- << "\tFirst seen at " << prev.first->getTestCaseInfo().lineInfo << '\n'
- << "\tRedefined at " << it->getTestCaseInfo().lineInfo << std::endl;
+ explicit BoundValueRef( std::vector<T> &ref ) : m_ref( ref ) {}
- throw std::runtime_error(ss.str());
- }
+ auto isContainer() const -> bool override { return true; }
+
+ auto setValue( std::string const &arg ) -> ParserResult override {
+ T temp;
+ auto result = convertInto( arg, temp );
+ if( result )
+ m_ref.push_back( temp );
+ return result;
}
- }
+ };
- std::vector<TestCase> filterTests( std::vector<TestCase> const& testCases, TestSpec const& testSpec, IConfig const& config ) {
- std::vector<TestCase> filtered;
- filtered.reserve( testCases.size() );
- for( std::vector<TestCase>::const_iterator it = testCases.begin(), itEnd = testCases.end();
- it != itEnd;
- ++it )
- if( matchTest( *it, testSpec, config ) )
- filtered.push_back( *it );
- return filtered;
- }
- std::vector<TestCase> const& getAllTestCasesSorted( IConfig const& config ) {
- return getRegistryHub().getTestCaseRegistry().getAllTestsSorted( config );
- }
+ struct BoundFlagRef : BoundFlagRefBase {
+ bool &m_ref;
- class TestRegistry : public ITestCaseRegistry {
- public:
- TestRegistry()
- : m_currentSortOrder( RunTests::InDeclarationOrder ),
- m_unnamedCount( 0 )
- {}
- virtual ~TestRegistry();
+ explicit BoundFlagRef( bool &ref ) : m_ref( ref ) {}
- virtual void registerTest( TestCase const& testCase ) {
- std::string name = testCase.getTestCaseInfo().name;
- if( name.empty() ) {
- std::ostringstream oss;
- oss << "Anonymous test case " << ++m_unnamedCount;
- return registerTest( testCase.withName( oss.str() ) );
- }
- m_functions.push_back( testCase );
+ auto setFlag( bool flag ) -> ParserResult override {
+ m_ref = flag;
+ return ParserResult::ok( ParseResultType::Matched );
}
+ };
- virtual std::vector<TestCase> const& getAllTests() const {
- return m_functions;
- }
- virtual std::vector<TestCase> const& getAllTestsSorted( IConfig const& config ) const {
- if( m_sortedFunctions.empty() )
- enforceNoDuplicateTestCases( m_functions );
+ template<typename ReturnType>
+ struct LambdaInvoker {
+ static_assert( std::is_same<ReturnType, ParserResult>::value, "Lambda must return void or clara::ParserResult" );
- if( m_currentSortOrder != config.runOrder() || m_sortedFunctions.empty() ) {
- m_sortedFunctions = sortTests( config, m_functions );
- m_currentSortOrder = config.runOrder();
- }
- return m_sortedFunctions;
+ template<typename L, typename ArgType>
+ static auto invoke( L const &lambda, ArgType const &arg ) -> ParserResult {
+ return lambda( arg );
}
+ };
- private:
- std::vector<TestCase> m_functions;
- mutable RunTests::InWhatOrder m_currentSortOrder;
- mutable std::vector<TestCase> m_sortedFunctions;
- size_t m_unnamedCount;
- std::ios_base::Init m_ostreamInit; // Forces cout/ cerr to be initialised
+ template<>
+ struct LambdaInvoker<void> {
+ template<typename L, typename ArgType>
+ static auto invoke( L const &lambda, ArgType const &arg ) -> ParserResult {
+ lambda( arg );
+ return ParserResult::ok( ParseResultType::Matched );
+ }
};
- ///////////////////////////////////////////////////////////////////////////
+ template<typename ArgType, typename L>
+ inline auto invokeLambda( L const &lambda, std::string const &arg ) -> ParserResult {
+ ArgType temp{};
+ auto result = convertInto( arg, temp );
+ return !result
+ ? result
+ : LambdaInvoker<typename UnaryLambdaTraits<L>::ReturnType>::invoke( lambda, temp );
+ }
- class FreeFunctionTestCase : public SharedImpl<ITestCase> {
- public:
+ template<typename L>
+ struct BoundLambda : BoundValueRefBase {
+ L m_lambda;
- FreeFunctionTestCase( TestFunction fun ) : m_fun( fun ) {}
+ static_assert( UnaryLambdaTraits<L>::isValid, "Supplied lambda must take exactly one argument" );
+ explicit BoundLambda( L const &lambda ) : m_lambda( lambda ) {}
- virtual void invoke() const {
- m_fun();
+ auto setValue( std::string const &arg ) -> ParserResult override {
+ return invokeLambda<typename UnaryLambdaTraits<L>::ArgType>( m_lambda, arg );
}
+ };
- private:
- virtual ~FreeFunctionTestCase();
+ template<typename L>
+ struct BoundFlagLambda : BoundFlagRefBase {
+ L m_lambda;
- TestFunction m_fun;
- };
+ static_assert( UnaryLambdaTraits<L>::isValid, "Supplied lambda must take exactly one argument" );
+ static_assert( std::is_same<typename UnaryLambdaTraits<L>::ArgType, bool>::value, "flags must be boolean" );
- inline std::string extractClassName( std::string const& classOrQualifiedMethodName ) {
- std::string className = classOrQualifiedMethodName;
- if( startsWith( className, '&' ) )
- {
- std::size_t lastColons = className.rfind( "::" );
- std::size_t penultimateColons = className.rfind( "::", lastColons-1 );
- if( penultimateColons == std::string::npos )
- penultimateColons = 1;
- className = className.substr( penultimateColons, lastColons-penultimateColons );
+ explicit BoundFlagLambda( L const &lambda ) : m_lambda( lambda ) {}
+
+ auto setFlag( bool flag ) -> ParserResult override {
+ return LambdaInvoker<typename UnaryLambdaTraits<L>::ReturnType>::invoke( m_lambda, flag );
}
- return className;
- }
+ };
- void registerTestCase
- ( ITestCase* testCase,
- char const* classOrQualifiedMethodName,
- NameAndDesc const& nameAndDesc,
- SourceLineInfo const& lineInfo ) {
+ enum class Optionality { Optional, Required };
- getMutableRegistryHub().registerTest
- ( makeTestCase
- ( testCase,
- extractClassName( classOrQualifiedMethodName ),
- nameAndDesc.name,
- nameAndDesc.description,
- lineInfo ) );
- }
- void registerTestCaseFunction
- ( TestFunction function,
- SourceLineInfo const& lineInfo,
- NameAndDesc const& nameAndDesc ) {
- registerTestCase( new FreeFunctionTestCase( function ), "", nameAndDesc, lineInfo );
- }
+ struct Parser;
- ///////////////////////////////////////////////////////////////////////////
+ class ParserBase {
+ public:
+ virtual ~ParserBase() = default;
+ virtual auto validate() const -> Result { return Result::ok(); }
+ virtual auto parse( std::string const& exeName, TokenStream const &tokens) const -> InternalParseResult = 0;
+ virtual auto cardinality() const -> size_t { return 1; }
- AutoReg::AutoReg
- ( TestFunction function,
- SourceLineInfo const& lineInfo,
- NameAndDesc const& nameAndDesc ) {
- registerTestCaseFunction( function, lineInfo, nameAndDesc );
- }
+ auto parse( Args const &args ) const -> InternalParseResult {
+ return parse( args.exeName(), TokenStream( args ) );
+ }
+ };
- AutoReg::~AutoReg() {}
+ template<typename DerivedT>
+ class ComposableParserImpl : public ParserBase {
+ public:
+ template<typename T>
+ auto operator|( T const &other ) const -> Parser;
-} // end namespace Catch
+ template<typename T>
+ auto operator+( T const &other ) const -> Parser;
+ };
-// #included from: catch_reporter_registry.hpp
-#define TWOBLUECUBES_CATCH_REPORTER_REGISTRY_HPP_INCLUDED
+ // Common code and state for Args and Opts
+ template<typename DerivedT>
+ class ParserRefImpl : public ComposableParserImpl<DerivedT> {
+ protected:
+ Optionality m_optionality = Optionality::Optional;
+ std::shared_ptr<BoundRef> m_ref;
+ std::string m_hint;
+ std::string m_description;
-#include <map>
+ explicit ParserRefImpl( std::shared_ptr<BoundRef> const &ref ) : m_ref( ref ) {}
-namespace Catch {
+ public:
+ template<typename T>
+ ParserRefImpl( T &ref, std::string const &hint )
+ : m_ref( std::make_shared<BoundValueRef<T>>( ref ) ),
+ m_hint( hint )
+ {}
- class ReporterRegistry : public IReporterRegistry {
+ template<typename LambdaT>
+ ParserRefImpl( LambdaT const &ref, std::string const &hint )
+ : m_ref( std::make_shared<BoundLambda<LambdaT>>( ref ) ),
+ m_hint(hint)
+ {}
- public:
+ auto operator()( std::string const &description ) -> DerivedT & {
+ m_description = description;
+ return static_cast<DerivedT &>( *this );
+ }
- virtual ~ReporterRegistry() CATCH_OVERRIDE {}
+ auto optional() -> DerivedT & {
+ m_optionality = Optionality::Optional;
+ return static_cast<DerivedT &>( *this );
+ };
+
+ auto required() -> DerivedT & {
+ m_optionality = Optionality::Required;
+ return static_cast<DerivedT &>( *this );
+ };
- virtual IStreamingReporter* create( std::string const& name, Ptr<IConfig const> const& config ) const CATCH_OVERRIDE {
- FactoryMap::const_iterator it = m_factories.find( name );
- if( it == m_factories.end() )
- return CATCH_NULL;
- return it->second->create( ReporterConfig( config ) );
+ auto isOptional() const -> bool {
+ return m_optionality == Optionality::Optional;
}
- void registerReporter( std::string const& name, Ptr<IReporterFactory> const& factory ) {
- m_factories.insert( std::make_pair( name, factory ) );
+ auto cardinality() const -> size_t override {
+ if( m_ref->isContainer() )
+ return 0;
+ else
+ return 1;
}
- void registerListener( Ptr<IReporterFactory> const& factory ) {
- m_listeners.push_back( factory );
+
+ auto hint() const -> std::string { return m_hint; }
+ };
+
+ class ExeName : public ComposableParserImpl<ExeName> {
+ std::shared_ptr<std::string> m_name;
+ std::shared_ptr<BoundValueRefBase> m_ref;
+
+ template<typename LambdaT>
+ static auto makeRef(LambdaT const &lambda) -> std::shared_ptr<BoundValueRefBase> {
+ return std::make_shared<BoundLambda<LambdaT>>( lambda) ;
}
- virtual FactoryMap const& getFactories() const CATCH_OVERRIDE {
- return m_factories;
+ public:
+ ExeName() : m_name( std::make_shared<std::string>( "<executable>" ) ) {}
+
+ explicit ExeName( std::string &ref ) : ExeName() {
+ m_ref = std::make_shared<BoundValueRef<std::string>>( ref );
}
- virtual Listeners const& getListeners() const CATCH_OVERRIDE {
- return m_listeners;
+
+ template<typename LambdaT>
+ explicit ExeName( LambdaT const& lambda ) : ExeName() {
+ m_ref = std::make_shared<BoundLambda<LambdaT>>( lambda );
}
- private:
- FactoryMap m_factories;
- Listeners m_listeners;
+ // The exe name is not parsed out of the normal tokens, but is handled specially
+ auto parse( std::string const&, TokenStream const &tokens ) const -> InternalParseResult override {
+ return InternalParseResult::ok( ParseState( ParseResultType::NoMatch, tokens ) );
+ }
+
+ auto name() const -> std::string { return *m_name; }
+ auto set( std::string const& newName ) -> ParserResult {
+
+ auto lastSlash = newName.find_last_of( "\\/" );
+ auto filename = ( lastSlash == std::string::npos )
+ ? newName
+ : newName.substr( lastSlash+1 );
+
+ *m_name = filename;
+ if( m_ref )
+ return m_ref->setValue( filename );
+ else
+ return ParserResult::ok( ParseResultType::Matched );
+ }
};
-}
-// #included from: catch_exception_translator_registry.hpp
-#define TWOBLUECUBES_CATCH_EXCEPTION_TRANSLATOR_REGISTRY_HPP_INCLUDED
+ class Arg : public ParserRefImpl<Arg> {
+ public:
+ using ParserRefImpl::ParserRefImpl;
-#ifdef __OBJC__
-#import "Foundation/Foundation.h"
+ auto parse( std::string const &, TokenStream const &tokens ) const -> InternalParseResult override {
+ auto validationResult = validate();
+ if( !validationResult )
+ return InternalParseResult( validationResult );
+
+ auto remainingTokens = tokens;
+ auto const &token = *remainingTokens;
+ if( token.type != TokenType::Argument )
+ return InternalParseResult::ok( ParseState( ParseResultType::NoMatch, remainingTokens ) );
+
+ assert( !m_ref->isFlag() );
+ auto valueRef = static_cast<detail::BoundValueRefBase*>( m_ref.get() );
+
+ auto result = valueRef->setValue( remainingTokens->token );
+ if( !result )
+ return InternalParseResult( result );
+ else
+ return InternalParseResult::ok( ParseState( ParseResultType::Matched, ++remainingTokens ) );
+ }
+ };
+
+ inline auto normaliseOpt( std::string const &optName ) -> std::string {
+#ifdef CATCH_PLATFORM_WINDOWS
+ if( optName[0] == '/' )
+ return "-" + optName.substr( 1 );
+ else
#endif
+ return optName;
+ }
-namespace Catch {
+ class Opt : public ParserRefImpl<Opt> {
+ protected:
+ std::vector<std::string> m_optNames;
- class ExceptionTranslatorRegistry : public IExceptionTranslatorRegistry {
public:
- ~ExceptionTranslatorRegistry() {
- deleteAll( m_translators );
+ template<typename LambdaT>
+ explicit Opt( LambdaT const &ref ) : ParserRefImpl( std::make_shared<BoundFlagLambda<LambdaT>>( ref ) ) {}
+
+ explicit Opt( bool &ref ) : ParserRefImpl( std::make_shared<BoundFlagRef>( ref ) ) {}
+
+ template<typename LambdaT>
+ Opt( LambdaT const &ref, std::string const &hint ) : ParserRefImpl( ref, hint ) {}
+
+ template<typename T>
+ Opt( T &ref, std::string const &hint ) : ParserRefImpl( ref, hint ) {}
+
+ auto operator[]( std::string const &optName ) -> Opt & {
+ m_optNames.push_back( optName );
+ return *this;
}
- virtual void registerTranslator( const IExceptionTranslator* translator ) {
- m_translators.push_back( translator );
+ auto getHelpColumns() const -> std::vector<HelpColumns> {
+ std::ostringstream oss;
+ bool first = true;
+ for( auto const &opt : m_optNames ) {
+ if (first)
+ first = false;
+ else
+ oss << ", ";
+ oss << opt;
+ }
+ if( !m_hint.empty() )
+ oss << " <" << m_hint << ">";
+ return { { oss.str(), m_description } };
}
- virtual std::string translateActiveException() const {
- try {
-#ifdef __OBJC__
- // In Objective-C try objective-c exceptions first
- @try {
- return tryTranslators();
- }
- @catch (NSException *exception) {
- return Catch::toString( [exception description] );
+ auto isMatch( std::string const &optToken ) const -> bool {
+ auto normalisedToken = normaliseOpt( optToken );
+ for( auto const &name : m_optNames ) {
+ if( normaliseOpt( name ) == normalisedToken )
+ return true;
+ }
+ return false;
+ }
+
+ using ParserBase::parse;
+
+ auto parse( std::string const&, TokenStream const &tokens ) const -> InternalParseResult override {
+ auto validationResult = validate();
+ if( !validationResult )
+ return InternalParseResult( validationResult );
+
+ auto remainingTokens = tokens;
+ if( remainingTokens && remainingTokens->type == TokenType::Option ) {
+ auto const &token = *remainingTokens;
+ if( isMatch(token.token ) ) {
+ if( m_ref->isFlag() ) {
+ auto flagRef = static_cast<detail::BoundFlagRefBase*>( m_ref.get() );
+ auto result = flagRef->setFlag( true );
+ if( !result )
+ return InternalParseResult( result );
+ if( result.value() == ParseResultType::ShortCircuitAll )
+ return InternalParseResult::ok( ParseState( result.value(), remainingTokens ) );
+ } else {
+ auto valueRef = static_cast<detail::BoundValueRefBase*>( m_ref.get() );
+ ++remainingTokens;
+ if( !remainingTokens )
+ return InternalParseResult::runtimeError( "Expected argument following " + token.token );
+ auto const &argToken = *remainingTokens;
+ if( argToken.type != TokenType::Argument )
+ return InternalParseResult::runtimeError( "Expected argument following " + token.token );
+ auto result = valueRef->setValue( argToken.token );
+ if( !result )
+ return InternalParseResult( result );
+ if( result.value() == ParseResultType::ShortCircuitAll )
+ return InternalParseResult::ok( ParseState( result.value(), remainingTokens ) );
+ }
+ return InternalParseResult::ok( ParseState( ParseResultType::Matched, ++remainingTokens ) );
}
+ }
+ return InternalParseResult::ok( ParseState( ParseResultType::NoMatch, remainingTokens ) );
+ }
+
+ auto validate() const -> Result override {
+ if( m_optNames.empty() )
+ return Result::logicError( "No options supplied to Opt" );
+ for( auto const &name : m_optNames ) {
+ if( name.empty() )
+ return Result::logicError( "Option name cannot be empty" );
+#ifdef CATCH_PLATFORM_WINDOWS
+ if( name[0] != '-' && name[0] != '/' )
+ return Result::logicError( "Option name must begin with '-' or '/'" );
#else
- return tryTranslators();
+ if( name[0] != '-' )
+ return Result::logicError( "Option name must begin with '-'" );
#endif
}
- catch( TestFailureException& ) {
- throw;
- }
- catch( std::exception& ex ) {
- return ex.what();
- }
- catch( std::string& msg ) {
- return msg;
- }
- catch( const char* msg ) {
- return msg;
- }
- catch(...) {
- return "Unknown exception";
- }
+ return ParserRefImpl::validate();
}
+ };
- std::string tryTranslators() const {
- if( m_translators.empty() )
- throw;
- else
- return m_translators[0]->translate( m_translators.begin()+1, m_translators.end() );
+ struct Help : Opt {
+ Help( bool &showHelpFlag )
+ : Opt([&]( bool flag ) {
+ showHelpFlag = flag;
+ return ParserResult::ok( ParseResultType::ShortCircuitAll );
+ })
+ {
+ static_cast<Opt &>( *this )
+ ("display usage information")
+ ["-?"]["-h"]["--help"]
+ .optional();
}
-
- private:
- std::vector<const IExceptionTranslator*> m_translators;
};
-}
-// #included from: catch_tag_alias_registry.h
-#define TWOBLUECUBES_CATCH_TAG_ALIAS_REGISTRY_H_INCLUDED
+ struct Parser : ParserBase {
-#include <map>
+ mutable ExeName m_exeName;
+ std::vector<Opt> m_options;
+ std::vector<Arg> m_args;
-namespace Catch {
-
- class TagAliasRegistry : public ITagAliasRegistry {
- public:
- virtual ~TagAliasRegistry();
- virtual Option<TagAlias> find( std::string const& alias ) const;
- virtual std::string expandAliases( std::string const& unexpandedTestSpec ) const;
- void add( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo );
-
- private:
- std::map<std::string, TagAlias> m_registry;
- };
+ auto operator|=( ExeName const &exeName ) -> Parser & {
+ m_exeName = exeName;
+ return *this;
+ }
-} // end namespace Catch
+ auto operator|=( Arg const &arg ) -> Parser & {
+ m_args.push_back(arg);
+ return *this;
+ }
-namespace Catch {
+ auto operator|=( Opt const &opt ) -> Parser & {
+ m_options.push_back(opt);
+ return *this;
+ }
- namespace {
+ auto operator|=( Parser const &other ) -> Parser & {
+ m_options.insert(m_options.end(), other.m_options.begin(), other.m_options.end());
+ m_args.insert(m_args.end(), other.m_args.begin(), other.m_args.end());
+ return *this;
+ }
- class RegistryHub : public IRegistryHub, public IMutableRegistryHub {
+ template<typename T>
+ auto operator|( T const &other ) const -> Parser {
+ return Parser( *this ) |= other;
+ }
- RegistryHub( RegistryHub const& );
- void operator=( RegistryHub const& );
+ // Forward deprecated interface with '+' instead of '|'
+ template<typename T>
+ auto operator+=( T const &other ) -> Parser & { return operator|=( other ); }
+ template<typename T>
+ auto operator+( T const &other ) const -> Parser { return operator|( other ); }
- public: // IRegistryHub
- RegistryHub() {
- }
- virtual IReporterRegistry const& getReporterRegistry() const CATCH_OVERRIDE {
- return m_reporterRegistry;
- }
- virtual ITestCaseRegistry const& getTestCaseRegistry() const CATCH_OVERRIDE {
- return m_testCaseRegistry;
- }
- virtual IExceptionTranslatorRegistry& getExceptionTranslatorRegistry() CATCH_OVERRIDE {
- return m_exceptionTranslatorRegistry;
- }
- virtual ITagAliasRegistry const& getTagAliasRegistry() const CATCH_OVERRIDE {
- return m_tagAliasRegistry;
+ auto getHelpColumns() const -> std::vector<HelpColumns> {
+ std::vector<HelpColumns> cols;
+ for (auto const &o : m_options) {
+ auto childCols = o.getHelpColumns();
+ cols.insert( cols.end(), childCols.begin(), childCols.end() );
}
+ return cols;
+ }
- public: // IMutableRegistryHub
- virtual void registerReporter( std::string const& name, Ptr<IReporterFactory> const& factory ) CATCH_OVERRIDE {
- m_reporterRegistry.registerReporter( name, factory );
+ void writeToStream( std::ostream &os ) const {
+ if (!m_exeName.name().empty()) {
+ os << "usage:\n" << " " << m_exeName.name() << " ";
+ bool required = true, first = true;
+ for( auto const &arg : m_args ) {
+ if (first)
+ first = false;
+ else
+ os << " ";
+ if( arg.isOptional() && required ) {
+ os << "[";
+ required = false;
+ }
+ os << "<" << arg.hint() << ">";
+ if( arg.cardinality() == 0 )
+ os << " ... ";
+ }
+ if( !required )
+ os << "]";
+ if( !m_options.empty() )
+ os << " options";
+ os << "\n\nwhere options are:" << std::endl;
}
- virtual void registerListener( Ptr<IReporterFactory> const& factory ) CATCH_OVERRIDE {
- m_reporterRegistry.registerListener( factory );
+
+ auto rows = getHelpColumns();
+ size_t consoleWidth = CATCH_CLARA_CONFIG_CONSOLE_WIDTH;
+ size_t optWidth = 0;
+ for( auto const &cols : rows )
+ optWidth = (std::max)(optWidth, cols.left.size() + 2);
+
+ optWidth = (std::min)(optWidth, consoleWidth/2);
+
+ for( auto const &cols : rows ) {
+ auto row =
+ TextFlow::Column( cols.left ).width( optWidth ).indent( 2 ) +
+ TextFlow::Spacer(4) +
+ TextFlow::Column( cols.right ).width( consoleWidth - 7 - optWidth );
+ os << row << std::endl;
}
- virtual void registerTest( TestCase const& testInfo ) CATCH_OVERRIDE {
- m_testCaseRegistry.registerTest( testInfo );
+ }
+
+ friend auto operator<<( std::ostream &os, Parser const &parser ) -> std::ostream& {
+ parser.writeToStream( os );
+ return os;
+ }
+
+ auto validate() const -> Result override {
+ for( auto const &opt : m_options ) {
+ auto result = opt.validate();
+ if( !result )
+ return result;
}
- virtual void registerTranslator( const IExceptionTranslator* translator ) CATCH_OVERRIDE {
- m_exceptionTranslatorRegistry.registerTranslator( translator );
+ for( auto const &arg : m_args ) {
+ auto result = arg.validate();
+ if( !result )
+ return result;
}
- virtual void registerTagAlias( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ) CATCH_OVERRIDE {
- m_tagAliasRegistry.add( alias, tag, lineInfo );
+ return Result::ok();
+ }
+
+ using ParserBase::parse;
+
+ auto parse( std::string const& exeName, TokenStream const &tokens ) const -> InternalParseResult override {
+
+ struct ParserInfo {
+ ParserBase const* parser = nullptr;
+ size_t count = 0;
+ };
+ const size_t totalParsers = m_options.size() + m_args.size();
+ assert( totalParsers < 512 );
+ // ParserInfo parseInfos[totalParsers]; // <-- this is what we really want to do
+ ParserInfo parseInfos[512];
+
+ {
+ size_t i = 0;
+ for (auto const &opt : m_options) parseInfos[i++].parser = &opt;
+ for (auto const &arg : m_args) parseInfos[i++].parser = &arg;
}
- private:
- TestRegistry m_testCaseRegistry;
- ReporterRegistry m_reporterRegistry;
- ExceptionTranslatorRegistry m_exceptionTranslatorRegistry;
- TagAliasRegistry m_tagAliasRegistry;
- };
+ m_exeName.set( exeName );
+
+ auto result = InternalParseResult::ok( ParseState( ParseResultType::NoMatch, tokens ) );
+ while( result.value().remainingTokens() ) {
+ bool tokenParsed = false;
+
+ for( size_t i = 0; i < totalParsers; ++i ) {
+ auto& parseInfo = parseInfos[i];
+ if( parseInfo.parser->cardinality() == 0 || parseInfo.count < parseInfo.parser->cardinality() ) {
+ result = parseInfo.parser->parse(exeName, result.value().remainingTokens());
+ if (!result)
+ return result;
+ if (result.value().type() != ParseResultType::NoMatch) {
+ tokenParsed = true;
+ ++parseInfo.count;
+ break;
+ }
+ }
+ }
- // Single, global, instance
- inline RegistryHub*& getTheRegistryHub() {
- static RegistryHub* theRegistryHub = CATCH_NULL;
- if( !theRegistryHub )
- theRegistryHub = new RegistryHub();
- return theRegistryHub;
+ if( result.value().type() == ParseResultType::ShortCircuitAll )
+ return result;
+ if( !tokenParsed )
+ return InternalParseResult::runtimeError( "Unrecognised token: " + result.value().remainingTokens()->token );
+ }
+ // !TBD Check missing required options
+ return result;
}
- }
+ };
- IRegistryHub& getRegistryHub() {
- return *getTheRegistryHub();
- }
- IMutableRegistryHub& getMutableRegistryHub() {
- return *getTheRegistryHub();
- }
- void cleanUp() {
- delete getTheRegistryHub();
- getTheRegistryHub() = CATCH_NULL;
- cleanUpContext();
- }
- std::string translateActiveException() {
- return getRegistryHub().getExceptionTranslatorRegistry().translateActiveException();
+ template<typename DerivedT>
+ template<typename T>
+ auto ComposableParserImpl<DerivedT>::operator|( T const &other ) const -> Parser {
+ return Parser() | static_cast<DerivedT const &>( *this ) | other;
}
+} // namespace detail
-} // end namespace Catch
+// A Combined parser
+using detail::Parser;
-// #included from: catch_notimplemented_exception.hpp
-#define TWOBLUECUBES_CATCH_NOTIMPLEMENTED_EXCEPTION_HPP_INCLUDED
+// A parser for options
+using detail::Opt;
-#include <sstream>
+// A parser for arguments
+using detail::Arg;
-namespace Catch {
+// Wrapper for argc, argv from main()
+using detail::Args;
- NotImplementedException::NotImplementedException( SourceLineInfo const& lineInfo )
- : m_lineInfo( lineInfo ) {
- std::ostringstream oss;
- oss << lineInfo << ": function ";
- oss << "not implemented";
- m_what = oss.str();
- }
+// Specifies the name of the executable
+using detail::ExeName;
- const char* NotImplementedException::what() const CATCH_NOEXCEPT {
- return m_what.c_str();
- }
+// Convenience wrapper for option parser that specifies the help option
+using detail::Help;
-} // end namespace Catch
+// enum of result types from a parse
+using detail::ParseResultType;
-// #included from: catch_context_impl.hpp
-#define TWOBLUECUBES_CATCH_CONTEXT_IMPL_HPP_INCLUDED
+// Result type for parser operation
+using detail::ParserResult;
-// #included from: catch_stream.hpp
-#define TWOBLUECUBES_CATCH_STREAM_HPP_INCLUDED
+}} // namespace Catch::clara
-#include <stdexcept>
-#include <cstdio>
-#include <iostream>
+// end clara.hpp
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif
+// Restore Clara's value for console width, if present
+#ifdef CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH
+#define CATCH_CLARA_TEXTFLOW_CONFIG_CONSOLE_WIDTH CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH
+#undef CATCH_TEMP_CLARA_CONFIG_CONSOLE_WIDTH
+#endif
+
+// end catch_clara.h
namespace Catch {
- template<typename WriterF, size_t bufferSize=256>
- class StreamBufImpl : public StreamBufBase {
- char data[bufferSize];
- WriterF m_writer;
+ clara::Parser makeCommandLineParser( ConfigData& config );
- public:
- StreamBufImpl() {
- setp( data, data + sizeof(data) );
- }
+} // end namespace Catch
- ~StreamBufImpl() CATCH_NOEXCEPT {
- sync();
- }
+// end catch_commandline.h
+#include <fstream>
+#include <ctime>
- private:
- int overflow( int c ) {
- sync();
+namespace Catch {
- if( c != EOF ) {
- if( pbase() == epptr() )
- m_writer( std::string( 1, static_cast<char>( c ) ) );
- else
- sputc( static_cast<char>( c ) );
- }
- return 0;
- }
+ clara::Parser makeCommandLineParser( ConfigData& config ) {
- int sync() {
- if( pbase() != pptr() ) {
- m_writer( std::string( pbase(), static_cast<std::string::size_type>( pptr() - pbase() ) ) );
- setp( pbase(), epptr() );
- }
- return 0;
- }
- };
+ using namespace clara;
- ///////////////////////////////////////////////////////////////////////////
+ auto const setWarning = [&]( std::string const& warning ) {
+ auto warningSet = [&]() {
+ if( warning == "NoAssertions" )
+ return WarnAbout::NoAssertions;
- FileStream::FileStream( std::string const& filename ) {
- m_ofs.open( filename.c_str() );
- if( m_ofs.fail() ) {
- std::ostringstream oss;
- oss << "Unable to open file: '" << filename << '\'';
- throw std::domain_error( oss.str() );
- }
- }
+ if ( warning == "NoTests" )
+ return WarnAbout::NoTests;
- std::ostream& FileStream::stream() const {
- return m_ofs;
- }
+ return WarnAbout::Nothing;
+ }();
- struct OutputDebugWriter {
+ if (warningSet == WarnAbout::Nothing)
+ return ParserResult::runtimeError( "Unrecognised warning: '" + warning + "'" );
+ config.warnings = static_cast<WarnAbout::What>( config.warnings | warningSet );
+ return ParserResult::ok( ParseResultType::Matched );
+ };
+ auto const loadTestNamesFromFile = [&]( std::string const& filename ) {
+ std::ifstream f( filename.c_str() );
+ if( !f.is_open() )
+ return ParserResult::runtimeError( "Unable to load input file: '" + filename + "'" );
+
+ std::string line;
+ while( std::getline( f, line ) ) {
+ line = trim(line);
+ if( !line.empty() && !startsWith( line, '#' ) ) {
+ if( !startsWith( line, '"' ) )
+ line = '"' + line + '"';
+ config.testsOrTags.push_back( line );
+ config.testsOrTags.push_back( "," );
- void operator()( std::string const&str ) {
- writeToDebugConsole( str );
- }
- };
+ }
+ }
+ //Remove comma in the end
+ if(!config.testsOrTags.empty())
+ config.testsOrTags.erase( config.testsOrTags.end()-1 );
- DebugOutStream::DebugOutStream()
- : m_streamBuf( new StreamBufImpl<OutputDebugWriter>() ),
- m_os( m_streamBuf.get() )
- {}
+ return ParserResult::ok( ParseResultType::Matched );
+ };
+ auto const setTestOrder = [&]( std::string const& order ) {
+ if( startsWith( "declared", order ) )
+ config.runOrder = RunTests::InDeclarationOrder;
+ else if( startsWith( "lexical", order ) )
+ config.runOrder = RunTests::InLexicographicalOrder;
+ else if( startsWith( "random", order ) )
+ config.runOrder = RunTests::InRandomOrder;
+ else
+ return clara::ParserResult::runtimeError( "Unrecognised ordering: '" + order + "'" );
+ return ParserResult::ok( ParseResultType::Matched );
+ };
+ auto const setRngSeed = [&]( std::string const& seed ) {
+ if( seed != "time" )
+ return clara::detail::convertInto( seed, config.rngSeed );
+ config.rngSeed = static_cast<unsigned int>( std::time(nullptr) );
+ return ParserResult::ok( ParseResultType::Matched );
+ };
+ auto const setColourUsage = [&]( std::string const& useColour ) {
+ auto mode = toLower( useColour );
+
+ if( mode == "yes" )
+ config.useColour = UseColour::Yes;
+ else if( mode == "no" )
+ config.useColour = UseColour::No;
+ else if( mode == "auto" )
+ config.useColour = UseColour::Auto;
+ else
+ return ParserResult::runtimeError( "colour mode must be one of: auto, yes or no. '" + useColour + "' not recognised" );
+ return ParserResult::ok( ParseResultType::Matched );
+ };
+ auto const setWaitForKeypress = [&]( std::string const& keypress ) {
+ auto keypressLc = toLower( keypress );
+ if( keypressLc == "start" )
+ config.waitForKeypress = WaitForKeypress::BeforeStart;
+ else if( keypressLc == "exit" )
+ config.waitForKeypress = WaitForKeypress::BeforeExit;
+ else if( keypressLc == "both" )
+ config.waitForKeypress = WaitForKeypress::BeforeStartAndExit;
+ else
+ return ParserResult::runtimeError( "keypress argument must be one of: start, exit or both. '" + keypress + "' not recognised" );
+ return ParserResult::ok( ParseResultType::Matched );
+ };
+ auto const setVerbosity = [&]( std::string const& verbosity ) {
+ auto lcVerbosity = toLower( verbosity );
+ if( lcVerbosity == "quiet" )
+ config.verbosity = Verbosity::Quiet;
+ else if( lcVerbosity == "normal" )
+ config.verbosity = Verbosity::Normal;
+ else if( lcVerbosity == "high" )
+ config.verbosity = Verbosity::High;
+ else
+ return ParserResult::runtimeError( "Unrecognised verbosity, '" + verbosity + "'" );
+ return ParserResult::ok( ParseResultType::Matched );
+ };
+ auto const setReporter = [&]( std::string const& reporter ) {
+ IReporterRegistry::FactoryMap const& factories = getRegistryHub().getReporterRegistry().getFactories();
- std::ostream& DebugOutStream::stream() const {
- return m_os;
- }
+ auto lcReporter = toLower( reporter );
+ auto result = factories.find( lcReporter );
- // Store the streambuf from cout up-front because
- // cout may get redirected when running tests
- CoutStream::CoutStream()
- : m_os( Catch::cout().rdbuf() )
- {}
+ if( factories.end() != result )
+ config.reporterName = lcReporter;
+ else
+ return ParserResult::runtimeError( "Unrecognized reporter, '" + reporter + "'. Check available with --list-reporters" );
+ return ParserResult::ok( ParseResultType::Matched );
+ };
- std::ostream& CoutStream::stream() const {
- return m_os;
+ auto cli
+ = ExeName( config.processName )
+ | Help( config.showHelp )
+ | Opt( config.listTests )
+ ["-l"]["--list-tests"]
+ ( "list all/matching test cases" )
+ | Opt( config.listTags )
+ ["-t"]["--list-tags"]
+ ( "list all/matching tags" )
+ | Opt( config.showSuccessfulTests )
+ ["-s"]["--success"]
+ ( "include successful tests in output" )
+ | Opt( config.shouldDebugBreak )
+ ["-b"]["--break"]
+ ( "break into debugger on failure" )
+ | Opt( config.noThrow )
+ ["-e"]["--nothrow"]
+ ( "skip exception tests" )
+ | Opt( config.showInvisibles )
+ ["-i"]["--invisibles"]
+ ( "show invisibles (tabs, newlines)" )
+ | Opt( config.outputFilename, "filename" )
+ ["-o"]["--out"]
+ ( "output filename" )
+ | Opt( setReporter, "name" )
+ ["-r"]["--reporter"]
+ ( "reporter to use (defaults to console)" )
+ | Opt( config.name, "name" )
+ ["-n"]["--name"]
+ ( "suite name" )
+ | Opt( [&]( bool ){ config.abortAfter = 1; } )
+ ["-a"]["--abort"]
+ ( "abort at first failure" )
+ | Opt( [&]( int x ){ config.abortAfter = x; }, "no. failures" )
+ ["-x"]["--abortx"]
+ ( "abort after x failures" )
+ | Opt( setWarning, "warning name" )
+ ["-w"]["--warn"]
+ ( "enable warnings" )
+ | Opt( [&]( bool flag ) { config.showDurations = flag ? ShowDurations::Always : ShowDurations::Never; }, "yes|no" )
+ ["-d"]["--durations"]
+ ( "show test durations" )
+ | Opt( loadTestNamesFromFile, "filename" )
+ ["-f"]["--input-file"]
+ ( "load test names to run from a file" )
+ | Opt( config.filenamesAsTags )
+ ["-#"]["--filenames-as-tags"]
+ ( "adds a tag for the filename" )
+ | Opt( config.sectionsToRun, "section name" )
+ ["-c"]["--section"]
+ ( "specify section to run" )
+ | Opt( setVerbosity, "quiet|normal|high" )
+ ["-v"]["--verbosity"]
+ ( "set output verbosity" )
+ | Opt( config.listTestNamesOnly )
+ ["--list-test-names-only"]
+ ( "list all/matching test cases names only" )
+ | Opt( config.listReporters )
+ ["--list-reporters"]
+ ( "list all reporters" )
+ | Opt( setTestOrder, "decl|lex|rand" )
+ ["--order"]
+ ( "test case order (defaults to decl)" )
+ | Opt( setRngSeed, "'time'|number" )
+ ["--rng-seed"]
+ ( "set a specific seed for random numbers" )
+ | Opt( setColourUsage, "yes|no" )
+ ["--use-colour"]
+ ( "should output be colourised" )
+ | Opt( config.libIdentify )
+ ["--libidentify"]
+ ( "report name and version according to libidentify standard" )
+ | Opt( setWaitForKeypress, "start|exit|both" )
+ ["--wait-for-keypress"]
+ ( "waits for a keypress before exiting" )
+ | Opt( config.benchmarkSamples, "samples" )
+ ["--benchmark-samples"]
+ ( "number of samples to collect (default: 100)" )
+ | Opt( config.benchmarkResamples, "resamples" )
+ ["--benchmark-resamples"]
+ ( "number of resamples for the bootstrap (default: 100000)" )
+ | Opt( config.benchmarkConfidenceInterval, "confidence interval" )
+ ["--benchmark-confidence-interval"]
+ ( "confidence interval for the bootstrap (between 0 and 1, default: 0.95)" )
+ | Opt( config.benchmarkNoAnalysis )
+ ["--benchmark-no-analysis"]
+ ( "perform only measurements; do not perform any analysis" )
+ | Arg( config.testsOrTags, "test name|pattern|tags" )
+ ( "which test or tests to use" );
+
+ return cli;
}
-#ifndef CATCH_CONFIG_NOSTDOUT // If you #define this you must implement these functions
- std::ostream& cout() {
- return std::cout;
+} // end namespace Catch
+// end catch_commandline.cpp
+// start catch_common.cpp
+
+#include <cstring>
+#include <ostream>
+
+namespace Catch {
+
+ bool SourceLineInfo::operator == ( SourceLineInfo const& other ) const noexcept {
+ return line == other.line && (file == other.file || std::strcmp(file, other.file) == 0);
}
- std::ostream& cerr() {
- return std::cerr;
+ bool SourceLineInfo::operator < ( SourceLineInfo const& other ) const noexcept {
+ // We can assume that the same file will usually have the same pointer.
+ // Thus, if the pointers are the same, there is no point in calling the strcmp
+ return line < other.line || ( line == other.line && file != other.file && (std::strcmp(file, other.file) < 0));
}
+
+ std::ostream& operator << ( std::ostream& os, SourceLineInfo const& info ) {
+#ifndef __GNUG__
+ os << info.file << '(' << info.line << ')';
+#else
+ os << info.file << ':' << info.line;
#endif
-}
+ return os;
+ }
-namespace Catch {
+ std::string StreamEndStop::operator+() const {
+ return std::string();
+ }
- class Context : public IMutableContext {
+ NonCopyable::NonCopyable() = default;
+ NonCopyable::~NonCopyable() = default;
- Context() : m_config( CATCH_NULL ), m_runner( CATCH_NULL ), m_resultCapture( CATCH_NULL ) {}
- Context( Context const& );
- void operator=( Context const& );
+}
+// end catch_common.cpp
+// start catch_config.cpp
- public:
- virtual ~Context() {
- deleteAllValues( m_generatorsByTestName );
- }
+namespace Catch {
- public: // IContext
- virtual IResultCapture* getResultCapture() {
- return m_resultCapture;
- }
- virtual IRunner* getRunner() {
- return m_runner;
- }
- virtual size_t getGeneratorIndex( std::string const& fileInfo, size_t totalSize ) {
- return getGeneratorsForCurrentTest()
- .getGeneratorInfo( fileInfo, totalSize )
- .getCurrentIndex();
+ Config::Config( ConfigData const& data )
+ : m_data( data ),
+ m_stream( openStream() )
+ {
+ // We need to trim filter specs to avoid trouble with superfluous
+ // whitespace (esp. important for bdd macros, as those are manually
+ // aligned with whitespace).
+
+ for (auto& elem : m_data.testsOrTags) {
+ elem = trim(elem);
}
- virtual bool advanceGeneratorsForCurrentTest() {
- IGeneratorsForTest* generators = findGeneratorsForCurrentTest();
- return generators && generators->moveNext();
+ for (auto& elem : m_data.sectionsToRun) {
+ elem = trim(elem);
}
- virtual Ptr<IConfig const> getConfig() const {
- return m_config;
+ TestSpecParser parser(ITagAliasRegistry::get());
+ if (!m_data.testsOrTags.empty()) {
+ m_hasTestFilters = true;
+ for (auto const& testOrTags : m_data.testsOrTags) {
+ parser.parse(testOrTags);
+ }
}
+ m_testSpec = parser.testSpec();
+ }
- public: // IMutableContext
- virtual void setResultCapture( IResultCapture* resultCapture ) {
- m_resultCapture = resultCapture;
- }
- virtual void setRunner( IRunner* runner ) {
- m_runner = runner;
- }
- virtual void setConfig( Ptr<IConfig const> const& config ) {
- m_config = config;
- }
+ std::string const& Config::getFilename() const {
+ return m_data.outputFilename ;
+ }
- friend IMutableContext& getCurrentMutableContext();
+ bool Config::listTests() const { return m_data.listTests; }
+ bool Config::listTestNamesOnly() const { return m_data.listTestNamesOnly; }
+ bool Config::listTags() const { return m_data.listTags; }
+ bool Config::listReporters() const { return m_data.listReporters; }
- private:
- IGeneratorsForTest* findGeneratorsForCurrentTest() {
- std::string testName = getResultCapture()->getCurrentTestName();
+ std::string Config::getProcessName() const { return m_data.processName; }
+ std::string const& Config::getReporterName() const { return m_data.reporterName; }
- std::map<std::string, IGeneratorsForTest*>::const_iterator it =
- m_generatorsByTestName.find( testName );
- return it != m_generatorsByTestName.end()
- ? it->second
- : CATCH_NULL;
- }
+ std::vector<std::string> const& Config::getTestsOrTags() const { return m_data.testsOrTags; }
+ std::vector<std::string> const& Config::getSectionsToRun() const { return m_data.sectionsToRun; }
- IGeneratorsForTest& getGeneratorsForCurrentTest() {
- IGeneratorsForTest* generators = findGeneratorsForCurrentTest();
- if( !generators ) {
- std::string testName = getResultCapture()->getCurrentTestName();
- generators = createGeneratorsForTest();
- m_generatorsByTestName.insert( std::make_pair( testName, generators ) );
- }
- return *generators;
- }
+ TestSpec const& Config::testSpec() const { return m_testSpec; }
+ bool Config::hasTestFilters() const { return m_hasTestFilters; }
- private:
- Ptr<IConfig const> m_config;
- IRunner* m_runner;
- IResultCapture* m_resultCapture;
- std::map<std::string, IGeneratorsForTest*> m_generatorsByTestName;
- };
+ bool Config::showHelp() const { return m_data.showHelp; }
- namespace {
- Context* currentContext = CATCH_NULL;
- }
- IMutableContext& getCurrentMutableContext() {
- if( !currentContext )
- currentContext = new Context();
- return *currentContext;
- }
- IContext& getCurrentContext() {
- return getCurrentMutableContext();
- }
+ // IConfig interface
+ bool Config::allowThrows() const { return !m_data.noThrow; }
+ std::ostream& Config::stream() const { return m_stream->stream(); }
+ std::string Config::name() const { return m_data.name.empty() ? m_data.processName : m_data.name; }
+ bool Config::includeSuccessfulResults() const { return m_data.showSuccessfulTests; }
+ bool Config::warnAboutMissingAssertions() const { return !!(m_data.warnings & WarnAbout::NoAssertions); }
+ bool Config::warnAboutNoTests() const { return !!(m_data.warnings & WarnAbout::NoTests); }
+ ShowDurations::OrNot Config::showDurations() const { return m_data.showDurations; }
+ RunTests::InWhatOrder Config::runOrder() const { return m_data.runOrder; }
+ unsigned int Config::rngSeed() const { return m_data.rngSeed; }
+ UseColour::YesOrNo Config::useColour() const { return m_data.useColour; }
+ bool Config::shouldDebugBreak() const { return m_data.shouldDebugBreak; }
+ int Config::abortAfter() const { return m_data.abortAfter; }
+ bool Config::showInvisibles() const { return m_data.showInvisibles; }
+ Verbosity Config::verbosity() const { return m_data.verbosity; }
- void cleanUpContext() {
- delete currentContext;
- currentContext = CATCH_NULL;
+ bool Config::benchmarkNoAnalysis() const { return m_data.benchmarkNoAnalysis; }
+ int Config::benchmarkSamples() const { return m_data.benchmarkSamples; }
+ double Config::benchmarkConfidenceInterval() const { return m_data.benchmarkConfidenceInterval; }
+ unsigned int Config::benchmarkResamples() const { return m_data.benchmarkResamples; }
+
+ IStream const* Config::openStream() {
+ return Catch::makeStream(m_data.outputFilename);
}
-}
-// #included from: catch_console_colour_impl.hpp
-#define TWOBLUECUBES_CATCH_CONSOLE_COLOUR_IMPL_HPP_INCLUDED
+} // end namespace Catch
+// end catch_config.cpp
+// start catch_console_colour.cpp
-// #included from: catch_errno_guard.hpp
-#define TWOBLUECUBES_CATCH_ERRNO_GUARD_HPP_INCLUDED
+#if defined(__clang__)
+# pragma clang diagnostic push
+# pragma clang diagnostic ignored "-Wexit-time-destructors"
+#endif
-#include <cerrno>
+// start catch_errno_guard.h
namespace Catch {
class ErrnoGuard {
public:
- ErrnoGuard():m_oldErrno(errno){}
- ~ErrnoGuard() { errno = m_oldErrno; }
+ ErrnoGuard();
+ ~ErrnoGuard();
private:
int m_oldErrno;
};
}
+// end catch_errno_guard.h
+#include <sstream>
+
namespace Catch {
namespace {
struct IColourImpl {
- virtual ~IColourImpl() {}
+ virtual ~IColourImpl() = default;
virtual void use( Colour::Code _colourCode ) = 0;
};
@@ -7797,7 +9954,7 @@ namespace {
originalBackgroundAttributes = csbiInfo.wAttributes & ~( FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE | FOREGROUND_INTENSITY );
}
- virtual void use( Colour::Code _colourCode ) {
+ void use( Colour::Code _colourCode ) override {
switch( _colourCode ) {
case Colour::None: return setTextAttribute( originalForegroundAttributes );
case Colour::White: return setTextAttribute( FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE );
@@ -7812,8 +9969,12 @@ namespace {
case Colour::BrightRed: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_RED );
case Colour::BrightGreen: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_GREEN );
case Colour::BrightWhite: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE );
+ case Colour::BrightYellow: return setTextAttribute( FOREGROUND_INTENSITY | FOREGROUND_RED | FOREGROUND_GREEN );
+
+ case Colour::Bright: CATCH_INTERNAL_ERROR( "not a colour" );
- case Colour::Bright: throw std::logic_error( "not a colour" );
+ default:
+ CATCH_ERROR( "Unknown colour requested" );
}
}
@@ -7829,14 +9990,12 @@ namespace {
IColourImpl* platformColourInstance() {
static Win32ColourImpl s_instance;
- Ptr<IConfig const> config = getCurrentContext().getConfig();
+ IConfigPtr config = getCurrentContext().getConfig();
UseColour::YesOrNo colourMode = config
? config->useColour()
: UseColour::Auto;
if( colourMode == UseColour::Auto )
- colourMode = !isDebuggerActive()
- ? UseColour::Yes
- : UseColour::No;
+ colourMode = UseColour::Yes;
return colourMode == UseColour::Yes
? &s_instance
: NoColourImpl::instance();
@@ -7858,7 +10017,7 @@ namespace {
// https://github.com/philsquared/Catch/pull/131
class PosixColourImpl : public IColourImpl {
public:
- virtual void use( Colour::Code _colourCode ) {
+ void use( Colour::Code _colourCode ) override {
switch( _colourCode ) {
case Colour::None:
case Colour::White: return setColour( "[0m" );
@@ -7873,8 +10032,10 @@ namespace {
case Colour::BrightRed: return setColour( "[1;31m" );
case Colour::BrightGreen: return setColour( "[1;32m" );
case Colour::BrightWhite: return setColour( "[1;37m" );
+ case Colour::BrightYellow: return setColour( "[1;33m" );
- case Colour::Bright: throw std::logic_error( "not a colour" );
+ case Colour::Bright: CATCH_INTERNAL_ERROR( "not a colour" );
+ default: CATCH_INTERNAL_ERROR( "Unknown colour requested" );
}
}
static IColourImpl* instance() {
@@ -7884,18 +10045,31 @@ namespace {
private:
void setColour( const char* _escapeCode ) {
- Catch::cout() << '\033' << _escapeCode;
+ getCurrentContext().getConfig()->stream()
+ << '\033' << _escapeCode;
}
};
+ bool useColourOnPlatform() {
+ return
+#ifdef CATCH_PLATFORM_MAC
+ !isDebuggerActive() &&
+#endif
+#if !(defined(__DJGPP__) && defined(__STRICT_ANSI__))
+ isatty(STDOUT_FILENO)
+#else
+ false
+#endif
+ ;
+ }
IColourImpl* platformColourInstance() {
ErrnoGuard guard;
- Ptr<IConfig const> config = getCurrentContext().getConfig();
+ IConfigPtr config = getCurrentContext().getConfig();
UseColour::YesOrNo colourMode = config
? config->useColour()
: UseColour::Auto;
if( colourMode == UseColour::Auto )
- colourMode = (!isDebuggerActive() && isatty(STDOUT_FILENO) )
+ colourMode = useColourOnPlatform()
? UseColour::Yes
: UseColour::No;
return colourMode == UseColour::Yes
@@ -7918,409 +10092,1534 @@ namespace Catch {
namespace Catch {
- Colour::Colour( Code _colourCode ) : m_moved( false ) { use( _colourCode ); }
- Colour::Colour( Colour const& _other ) : m_moved( false ) { const_cast<Colour&>( _other ).m_moved = true; }
+ Colour::Colour( Code _colourCode ) { use( _colourCode ); }
+ Colour::Colour( Colour&& rhs ) noexcept {
+ m_moved = rhs.m_moved;
+ rhs.m_moved = true;
+ }
+ Colour& Colour::operator=( Colour&& rhs ) noexcept {
+ m_moved = rhs.m_moved;
+ rhs.m_moved = true;
+ return *this;
+ }
+
Colour::~Colour(){ if( !m_moved ) use( None ); }
void Colour::use( Code _colourCode ) {
static IColourImpl* impl = platformColourInstance();
- impl->use( _colourCode );
+ // Strictly speaking, this cannot possibly happen.
+ // However, under some conditions it does happen (see #1626),
+ // and this change is small enough that we can let practicality
+ // triumph over purity in this case.
+ if (impl != NULL) {
+ impl->use( _colourCode );
+ }
+ }
+
+ std::ostream& operator << ( std::ostream& os, Colour const& ) {
+ return os;
}
} // end namespace Catch
-// #included from: catch_generators_impl.hpp
-#define TWOBLUECUBES_CATCH_GENERATORS_IMPL_HPP_INCLUDED
+#if defined(__clang__)
+# pragma clang diagnostic pop
+#endif
-#include <vector>
-#include <string>
-#include <map>
+// end catch_console_colour.cpp
+// start catch_context.cpp
namespace Catch {
- struct GeneratorInfo : IGeneratorInfo {
+ class Context : public IMutableContext, NonCopyable {
- GeneratorInfo( std::size_t size )
- : m_size( size ),
- m_currentIndex( 0 )
- {}
+ public: // IContext
+ IResultCapture* getResultCapture() override {
+ return m_resultCapture;
+ }
+ IRunner* getRunner() override {
+ return m_runner;
+ }
- bool moveNext() {
- if( ++m_currentIndex == m_size ) {
- m_currentIndex = 0;
- return false;
- }
- return true;
+ IConfigPtr const& getConfig() const override {
+ return m_config;
}
- std::size_t getCurrentIndex() const {
- return m_currentIndex;
+ ~Context() override;
+
+ public: // IMutableContext
+ void setResultCapture( IResultCapture* resultCapture ) override {
+ m_resultCapture = resultCapture;
+ }
+ void setRunner( IRunner* runner ) override {
+ m_runner = runner;
+ }
+ void setConfig( IConfigPtr const& config ) override {
+ m_config = config;
}
- std::size_t m_size;
- std::size_t m_currentIndex;
+ friend IMutableContext& getCurrentMutableContext();
+
+ private:
+ IConfigPtr m_config;
+ IRunner* m_runner = nullptr;
+ IResultCapture* m_resultCapture = nullptr;
};
- ///////////////////////////////////////////////////////////////////////////
+ IMutableContext *IMutableContext::currentContext = nullptr;
- class GeneratorsForTest : public IGeneratorsForTest {
+ void IMutableContext::createContext()
+ {
+ currentContext = new Context();
+ }
- public:
- ~GeneratorsForTest() {
- deleteAll( m_generatorsInOrder );
+ void cleanUpContext() {
+ delete IMutableContext::currentContext;
+ IMutableContext::currentContext = nullptr;
+ }
+ IContext::~IContext() = default;
+ IMutableContext::~IMutableContext() = default;
+ Context::~Context() = default;
+
+ SimplePcg32& rng() {
+ static SimplePcg32 s_rng;
+ return s_rng;
+ }
+
+}
+// end catch_context.cpp
+// start catch_debug_console.cpp
+
+// start catch_debug_console.h
+
+#include <string>
+
+namespace Catch {
+ void writeToDebugConsole( std::string const& text );
+}
+
+// end catch_debug_console.h
+#if defined(CATCH_CONFIG_ANDROID_LOGWRITE)
+#include <android/log.h>
+
+ namespace Catch {
+ void writeToDebugConsole( std::string const& text ) {
+ __android_log_write( ANDROID_LOG_DEBUG, "Catch", text.c_str() );
}
+ }
+
+#elif defined(CATCH_PLATFORM_WINDOWS)
- IGeneratorInfo& getGeneratorInfo( std::string const& fileInfo, std::size_t size ) {
- std::map<std::string, IGeneratorInfo*>::const_iterator it = m_generatorsByName.find( fileInfo );
- if( it == m_generatorsByName.end() ) {
- IGeneratorInfo* info = new GeneratorInfo( size );
- m_generatorsByName.insert( std::make_pair( fileInfo, info ) );
- m_generatorsInOrder.push_back( info );
- return *info;
+ namespace Catch {
+ void writeToDebugConsole( std::string const& text ) {
+ ::OutputDebugStringA( text.c_str() );
+ }
+ }
+
+#else
+
+ namespace Catch {
+ void writeToDebugConsole( std::string const& text ) {
+ // !TBD: Need a version for Mac/ XCode and other IDEs
+ Catch::cout() << text;
+ }
+ }
+
+#endif // Platform
+// end catch_debug_console.cpp
+// start catch_debugger.cpp
+
+#ifdef CATCH_PLATFORM_MAC
+
+# include <assert.h>
+# include <stdbool.h>
+# include <sys/types.h>
+# include <unistd.h>
+# include <cstddef>
+# include <ostream>
+
+#ifdef __apple_build_version__
+ // These headers will only compile with AppleClang (XCode)
+ // For other compilers (Clang, GCC, ... ) we need to exclude them
+# include <sys/sysctl.h>
+#endif
+
+ namespace Catch {
+ #ifdef __apple_build_version__
+ // The following function is taken directly from the following technical note:
+ // https://developer.apple.com/library/archive/qa/qa1361/_index.html
+
+ // Returns true if the current process is being debugged (either
+ // running under the debugger or has a debugger attached post facto).
+ bool isDebuggerActive(){
+ int mib[4];
+ struct kinfo_proc info;
+ std::size_t size;
+
+ // Initialize the flags so that, if sysctl fails for some bizarre
+ // reason, we get a predictable result.
+
+ info.kp_proc.p_flag = 0;
+
+ // Initialize mib, which tells sysctl the info we want, in this case
+ // we're looking for information about a specific process ID.
+
+ mib[0] = CTL_KERN;
+ mib[1] = KERN_PROC;
+ mib[2] = KERN_PROC_PID;
+ mib[3] = getpid();
+
+ // Call sysctl.
+
+ size = sizeof(info);
+ if( sysctl(mib, sizeof(mib) / sizeof(*mib), &info, &size, nullptr, 0) != 0 ) {
+ Catch::cerr() << "\n** Call to sysctl failed - unable to determine if debugger is active **\n" << std::endl;
+ return false;
}
- return *it->second;
+
+ // We're being debugged if the P_TRACED flag is set.
+
+ return ( (info.kp_proc.p_flag & P_TRACED) != 0 );
+ }
+ #else
+ bool isDebuggerActive() {
+ // We need to find another way to determine this for non-appleclang compilers on macOS
+ return false;
}
+ #endif
+ } // namespace Catch
- bool moveNext() {
- std::vector<IGeneratorInfo*>::const_iterator it = m_generatorsInOrder.begin();
- std::vector<IGeneratorInfo*>::const_iterator itEnd = m_generatorsInOrder.end();
- for(; it != itEnd; ++it ) {
- if( (*it)->moveNext() )
- return true;
+#elif defined(CATCH_PLATFORM_LINUX)
+ #include <fstream>
+ #include <string>
+
+ namespace Catch{
+ // The standard POSIX way of detecting a debugger is to attempt to
+ // ptrace() the process, but this needs to be done from a child and not
+ // this process itself to still allow attaching to this process later
+ // if wanted, so is rather heavy. Under Linux we have the PID of the
+ // "debugger" (which doesn't need to be gdb, of course, it could also
+ // be strace, for example) in /proc/$PID/status, so just get it from
+ // there instead.
+ bool isDebuggerActive(){
+ // Libstdc++ has a bug, where std::ifstream sets errno to 0
+ // This way our users can properly assert over errno values
+ ErrnoGuard guard;
+ std::ifstream in("/proc/self/status");
+ for( std::string line; std::getline(in, line); ) {
+ static const int PREFIX_LEN = 11;
+ if( line.compare(0, PREFIX_LEN, "TracerPid:\t") == 0 ) {
+ // We're traced if the PID is not 0 and no other PID starts
+ // with 0 digit, so it's enough to check for just a single
+ // character.
+ return line.length() > PREFIX_LEN && line[PREFIX_LEN] != '0';
+ }
}
+
return false;
}
+ } // namespace Catch
+#elif defined(_MSC_VER)
+ extern "C" __declspec(dllimport) int __stdcall IsDebuggerPresent();
+ namespace Catch {
+ bool isDebuggerActive() {
+ return IsDebuggerPresent() != 0;
+ }
+ }
+#elif defined(__MINGW32__)
+ extern "C" __declspec(dllimport) int __stdcall IsDebuggerPresent();
+ namespace Catch {
+ bool isDebuggerActive() {
+ return IsDebuggerPresent() != 0;
+ }
+ }
+#else
+ namespace Catch {
+ bool isDebuggerActive() { return false; }
+ }
+#endif // Platform
+// end catch_debugger.cpp
+// start catch_decomposer.cpp
- private:
- std::map<std::string, IGeneratorInfo*> m_generatorsByName;
- std::vector<IGeneratorInfo*> m_generatorsInOrder;
- };
+namespace Catch {
- IGeneratorsForTest* createGeneratorsForTest()
- {
- return new GeneratorsForTest();
+ ITransientExpression::~ITransientExpression() = default;
+
+ void formatReconstructedExpression( std::ostream &os, std::string const& lhs, StringRef op, std::string const& rhs ) {
+ if( lhs.size() + rhs.size() < 40 &&
+ lhs.find('\n') == std::string::npos &&
+ rhs.find('\n') == std::string::npos )
+ os << lhs << " " << op << " " << rhs;
+ else
+ os << lhs << "\n" << op << "\n" << rhs;
}
+}
+// end catch_decomposer.cpp
+// start catch_enforce.cpp
-} // end namespace Catch
+#include <stdexcept>
+
+namespace Catch {
+#if defined(CATCH_CONFIG_DISABLE_EXCEPTIONS) && !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS_CUSTOM_HANDLER)
+ [[noreturn]]
+ void throw_exception(std::exception const& e) {
+ Catch::cerr() << "Catch will terminate because it needed to throw an exception.\n"
+ << "The message was: " << e.what() << '\n';
+ std::terminate();
+ }
+#endif
+
+ [[noreturn]]
+ void throw_logic_error(std::string const& msg) {
+ throw_exception(std::logic_error(msg));
+ }
+
+ [[noreturn]]
+ void throw_domain_error(std::string const& msg) {
+ throw_exception(std::domain_error(msg));
+ }
-// #included from: catch_assertionresult.hpp
-#define TWOBLUECUBES_CATCH_ASSERTIONRESULT_HPP_INCLUDED
+ [[noreturn]]
+ void throw_runtime_error(std::string const& msg) {
+ throw_exception(std::runtime_error(msg));
+ }
+
+} // namespace Catch;
+// end catch_enforce.cpp
+// start catch_enum_values_registry.cpp
+// start catch_enum_values_registry.h
+
+#include <vector>
+#include <memory>
namespace Catch {
- AssertionInfo::AssertionInfo( std::string const& _macroName,
- SourceLineInfo const& _lineInfo,
- std::string const& _capturedExpression,
- ResultDisposition::Flags _resultDisposition )
- : macroName( _macroName ),
- lineInfo( _lineInfo ),
- capturedExpression( _capturedExpression ),
- resultDisposition( _resultDisposition )
- {}
+ namespace Detail {
- AssertionResult::AssertionResult() {}
+ std::unique_ptr<EnumInfo> makeEnumInfo( StringRef enumName, StringRef allValueNames, std::vector<int> const& values );
- AssertionResult::AssertionResult( AssertionInfo const& info, AssertionResultData const& data )
- : m_info( info ),
- m_resultData( data )
- {}
+ class EnumValuesRegistry : public IMutableEnumValuesRegistry {
- AssertionResult::~AssertionResult() {}
+ std::vector<std::unique_ptr<EnumInfo>> m_enumInfos;
- // Result was a success
- bool AssertionResult::succeeded() const {
- return Catch::isOk( m_resultData.resultType );
+ EnumInfo const& registerEnum( StringRef enumName, StringRef allEnums, std::vector<int> const& values) override;
+ };
+
+ std::vector<StringRef> parseEnums( StringRef enums );
+
+ } // Detail
+
+} // Catch
+
+// end catch_enum_values_registry.h
+
+#include <map>
+#include <cassert>
+
+namespace Catch {
+
+ IMutableEnumValuesRegistry::~IMutableEnumValuesRegistry() {}
+
+ namespace Detail {
+
+ namespace {
+ // Extracts the actual name part of an enum instance
+ // In other words, it returns the Blue part of Bikeshed::Colour::Blue
+ StringRef extractInstanceName(StringRef enumInstance) {
+ // Find last occurence of ":"
+ size_t name_start = enumInstance.size();
+ while (name_start > 0 && enumInstance[name_start - 1] != ':') {
+ --name_start;
+ }
+ return enumInstance.substr(name_start, enumInstance.size() - name_start);
+ }
+ }
+
+ std::vector<StringRef> parseEnums( StringRef enums ) {
+ auto enumValues = splitStringRef( enums, ',' );
+ std::vector<StringRef> parsed;
+ parsed.reserve( enumValues.size() );
+ for( auto const& enumValue : enumValues ) {
+ parsed.push_back(trim(extractInstanceName(enumValue)));
+ }
+ return parsed;
+ }
+
+ EnumInfo::~EnumInfo() {}
+
+ StringRef EnumInfo::lookup( int value ) const {
+ for( auto const& valueToName : m_values ) {
+ if( valueToName.first == value )
+ return valueToName.second;
+ }
+ return "{** unexpected enum value **}"_sr;
+ }
+
+ std::unique_ptr<EnumInfo> makeEnumInfo( StringRef enumName, StringRef allValueNames, std::vector<int> const& values ) {
+ std::unique_ptr<EnumInfo> enumInfo( new EnumInfo );
+ enumInfo->m_name = enumName;
+ enumInfo->m_values.reserve( values.size() );
+
+ const auto valueNames = Catch::Detail::parseEnums( allValueNames );
+ assert( valueNames.size() == values.size() );
+ std::size_t i = 0;
+ for( auto value : values )
+ enumInfo->m_values.push_back({ value, valueNames[i++] });
+
+ return enumInfo;
+ }
+
+ EnumInfo const& EnumValuesRegistry::registerEnum( StringRef enumName, StringRef allValueNames, std::vector<int> const& values ) {
+ m_enumInfos.push_back(makeEnumInfo(enumName, allValueNames, values));
+ return *m_enumInfos.back();
+ }
+
+ } // Detail
+} // Catch
+
+// end catch_enum_values_registry.cpp
+// start catch_errno_guard.cpp
+
+#include <cerrno>
+
+namespace Catch {
+ ErrnoGuard::ErrnoGuard():m_oldErrno(errno){}
+ ErrnoGuard::~ErrnoGuard() { errno = m_oldErrno; }
+}
+// end catch_errno_guard.cpp
+// start catch_exception_translator_registry.cpp
+
+// start catch_exception_translator_registry.h
+
+#include <vector>
+#include <string>
+#include <memory>
+
+namespace Catch {
+
+ class ExceptionTranslatorRegistry : public IExceptionTranslatorRegistry {
+ public:
+ ~ExceptionTranslatorRegistry();
+ virtual void registerTranslator( const IExceptionTranslator* translator );
+ std::string translateActiveException() const override;
+ std::string tryTranslators() const;
+
+ private:
+ std::vector<std::unique_ptr<IExceptionTranslator const>> m_translators;
+ };
+}
+
+// end catch_exception_translator_registry.h
+#ifdef __OBJC__
+#import "Foundation/Foundation.h"
+#endif
+
+namespace Catch {
+
+ ExceptionTranslatorRegistry::~ExceptionTranslatorRegistry() {
}
- // Result was a success, or failure is suppressed
- bool AssertionResult::isOk() const {
- return Catch::isOk( m_resultData.resultType ) || shouldSuppressFailure( m_info.resultDisposition );
+ void ExceptionTranslatorRegistry::registerTranslator( const IExceptionTranslator* translator ) {
+ m_translators.push_back( std::unique_ptr<const IExceptionTranslator>( translator ) );
}
- ResultWas::OfType AssertionResult::getResultType() const {
- return m_resultData.resultType;
+#if !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS)
+ std::string ExceptionTranslatorRegistry::translateActiveException() const {
+ try {
+#ifdef __OBJC__
+ // In Objective-C try objective-c exceptions first
+ @try {
+ return tryTranslators();
+ }
+ @catch (NSException *exception) {
+ return Catch::Detail::stringify( [exception description] );
+ }
+#else
+ // Compiling a mixed mode project with MSVC means that CLR
+ // exceptions will be caught in (...) as well. However, these
+ // do not fill-in std::current_exception and thus lead to crash
+ // when attempting rethrow.
+ // /EHa switch also causes structured exceptions to be caught
+ // here, but they fill-in current_exception properly, so
+ // at worst the output should be a little weird, instead of
+ // causing a crash.
+ if (std::current_exception() == nullptr) {
+ return "Non C++ exception. Possibly a CLR exception.";
+ }
+ return tryTranslators();
+#endif
+ }
+ catch( TestFailureException& ) {
+ std::rethrow_exception(std::current_exception());
+ }
+ catch( std::exception& ex ) {
+ return ex.what();
+ }
+ catch( std::string& msg ) {
+ return msg;
+ }
+ catch( const char* msg ) {
+ return msg;
+ }
+ catch(...) {
+ return "Unknown exception";
+ }
}
- bool AssertionResult::hasExpression() const {
- return !m_info.capturedExpression.empty();
+ std::string ExceptionTranslatorRegistry::tryTranslators() const {
+ if (m_translators.empty()) {
+ std::rethrow_exception(std::current_exception());
+ } else {
+ return m_translators[0]->translate(m_translators.begin() + 1, m_translators.end());
+ }
}
- bool AssertionResult::hasMessage() const {
- return !m_resultData.message.empty();
+#else // ^^ Exceptions are enabled // Exceptions are disabled vv
+ std::string ExceptionTranslatorRegistry::translateActiveException() const {
+ CATCH_INTERNAL_ERROR("Attempted to translate active exception under CATCH_CONFIG_DISABLE_EXCEPTIONS!");
}
- std::string AssertionResult::getExpression() const {
- if( isFalseTest( m_info.resultDisposition ) )
- return '!' + m_info.capturedExpression;
- else
- return m_info.capturedExpression;
+ std::string ExceptionTranslatorRegistry::tryTranslators() const {
+ CATCH_INTERNAL_ERROR("Attempted to use exception translators under CATCH_CONFIG_DISABLE_EXCEPTIONS!");
}
- std::string AssertionResult::getExpressionInMacro() const {
- if( m_info.macroName.empty() )
- return m_info.capturedExpression;
- else
- return m_info.macroName + "( " + m_info.capturedExpression + " )";
+#endif
+
+}
+// end catch_exception_translator_registry.cpp
+// start catch_fatal_condition.cpp
+
+#if defined(__GNUC__)
+# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wmissing-field-initializers"
+#endif
+
+#if defined( CATCH_CONFIG_WINDOWS_SEH ) || defined( CATCH_CONFIG_POSIX_SIGNALS )
+
+namespace {
+ // Report the error condition
+ void reportFatal( char const * const message ) {
+ Catch::getCurrentContext().getResultCapture()->handleFatalErrorCondition( message );
}
+}
- bool AssertionResult::hasExpandedExpression() const {
- return hasExpression() && getExpandedExpression() != getExpression();
+#endif // signals/SEH handling
+
+#if defined( CATCH_CONFIG_WINDOWS_SEH )
+
+namespace Catch {
+ struct SignalDefs { DWORD id; const char* name; };
+
+ // There is no 1-1 mapping between signals and windows exceptions.
+ // Windows can easily distinguish between SO and SigSegV,
+ // but SigInt, SigTerm, etc are handled differently.
+ static SignalDefs signalDefs[] = {
+ { static_cast<DWORD>(EXCEPTION_ILLEGAL_INSTRUCTION), "SIGILL - Illegal instruction signal" },
+ { static_cast<DWORD>(EXCEPTION_STACK_OVERFLOW), "SIGSEGV - Stack overflow" },
+ { static_cast<DWORD>(EXCEPTION_ACCESS_VIOLATION), "SIGSEGV - Segmentation violation signal" },
+ { static_cast<DWORD>(EXCEPTION_INT_DIVIDE_BY_ZERO), "Divide by zero error" },
+ };
+
+ LONG CALLBACK FatalConditionHandler::handleVectoredException(PEXCEPTION_POINTERS ExceptionInfo) {
+ for (auto const& def : signalDefs) {
+ if (ExceptionInfo->ExceptionRecord->ExceptionCode == def.id) {
+ reportFatal(def.name);
+ }
+ }
+ // If its not an exception we care about, pass it along.
+ // This stops us from eating debugger breaks etc.
+ return EXCEPTION_CONTINUE_SEARCH;
}
- std::string AssertionResult::getExpandedExpression() const {
- return m_resultData.reconstructExpression();
+ FatalConditionHandler::FatalConditionHandler() {
+ isSet = true;
+ // 32k seems enough for Catch to handle stack overflow,
+ // but the value was found experimentally, so there is no strong guarantee
+ guaranteeSize = 32 * 1024;
+ exceptionHandlerHandle = nullptr;
+ // Register as first handler in current chain
+ exceptionHandlerHandle = AddVectoredExceptionHandler(1, handleVectoredException);
+ // Pass in guarantee size to be filled
+ SetThreadStackGuarantee(&guaranteeSize);
}
- std::string AssertionResult::getMessage() const {
- return m_resultData.message;
+ void FatalConditionHandler::reset() {
+ if (isSet) {
+ RemoveVectoredExceptionHandler(exceptionHandlerHandle);
+ SetThreadStackGuarantee(&guaranteeSize);
+ exceptionHandlerHandle = nullptr;
+ isSet = false;
+ }
}
- SourceLineInfo AssertionResult::getSourceInfo() const {
- return m_info.lineInfo;
+
+ FatalConditionHandler::~FatalConditionHandler() {
+ reset();
}
- std::string AssertionResult::getTestMacroName() const {
- return m_info.macroName;
+bool FatalConditionHandler::isSet = false;
+ULONG FatalConditionHandler::guaranteeSize = 0;
+PVOID FatalConditionHandler::exceptionHandlerHandle = nullptr;
+
+} // namespace Catch
+
+#elif defined( CATCH_CONFIG_POSIX_SIGNALS )
+
+namespace Catch {
+
+ struct SignalDefs {
+ int id;
+ const char* name;
+ };
+
+ // 32kb for the alternate stack seems to be sufficient. However, this value
+ // is experimentally determined, so that's not guaranteed.
+ static constexpr std::size_t sigStackSize = 32768 >= MINSIGSTKSZ ? 32768 : MINSIGSTKSZ;
+
+ static SignalDefs signalDefs[] = {
+ { SIGINT, "SIGINT - Terminal interrupt signal" },
+ { SIGILL, "SIGILL - Illegal instruction signal" },
+ { SIGFPE, "SIGFPE - Floating point error signal" },
+ { SIGSEGV, "SIGSEGV - Segmentation violation signal" },
+ { SIGTERM, "SIGTERM - Termination request signal" },
+ { SIGABRT, "SIGABRT - Abort (abnormal termination) signal" }
+ };
+
+ void FatalConditionHandler::handleSignal( int sig ) {
+ char const * name = "<unknown signal>";
+ for (auto const& def : signalDefs) {
+ if (sig == def.id) {
+ name = def.name;
+ break;
+ }
+ }
+ reset();
+ reportFatal(name);
+ raise( sig );
}
- void AssertionResult::discardDecomposedExpression() const {
- m_resultData.decomposedExpression = CATCH_NULL;
+ FatalConditionHandler::FatalConditionHandler() {
+ isSet = true;
+ stack_t sigStack;
+ sigStack.ss_sp = altStackMem;
+ sigStack.ss_size = sigStackSize;
+ sigStack.ss_flags = 0;
+ sigaltstack(&sigStack, &oldSigStack);
+ struct sigaction sa = { };
+
+ sa.sa_handler = handleSignal;
+ sa.sa_flags = SA_ONSTACK;
+ for (std::size_t i = 0; i < sizeof(signalDefs)/sizeof(SignalDefs); ++i) {
+ sigaction(signalDefs[i].id, &sa, &oldSigActions[i]);
+ }
}
- void AssertionResult::expandDecomposedExpression() const {
- m_resultData.reconstructExpression();
+ FatalConditionHandler::~FatalConditionHandler() {
+ reset();
}
+ void FatalConditionHandler::reset() {
+ if( isSet ) {
+ // Set signals back to previous values -- hopefully nobody overwrote them in the meantime
+ for( std::size_t i = 0; i < sizeof(signalDefs)/sizeof(SignalDefs); ++i ) {
+ sigaction(signalDefs[i].id, &oldSigActions[i], nullptr);
+ }
+ // Return the old stack
+ sigaltstack(&oldSigStack, nullptr);
+ isSet = false;
+ }
+ }
+
+ bool FatalConditionHandler::isSet = false;
+ struct sigaction FatalConditionHandler::oldSigActions[sizeof(signalDefs)/sizeof(SignalDefs)] = {};
+ stack_t FatalConditionHandler::oldSigStack = {};
+ char FatalConditionHandler::altStackMem[sigStackSize] = {};
+
+} // namespace Catch
+
+#else
+
+namespace Catch {
+ void FatalConditionHandler::reset() {}
+}
+
+#endif // signals/SEH handling
+
+#if defined(__GNUC__)
+# pragma GCC diagnostic pop
+#endif
+// end catch_fatal_condition.cpp
+// start catch_generators.cpp
+
+#include <limits>
+#include <set>
+
+namespace Catch {
+
+IGeneratorTracker::~IGeneratorTracker() {}
+
+const char* GeneratorException::what() const noexcept {
+ return m_msg;
+}
+
+namespace Generators {
+
+ GeneratorUntypedBase::~GeneratorUntypedBase() {}
+
+ auto acquireGeneratorTracker( SourceLineInfo const& lineInfo ) -> IGeneratorTracker& {
+ return getResultCapture().acquireGeneratorTracker( lineInfo );
+ }
+
+} // namespace Generators
+} // namespace Catch
+// end catch_generators.cpp
+// start catch_interfaces_capture.cpp
+
+namespace Catch {
+ IResultCapture::~IResultCapture() = default;
+}
+// end catch_interfaces_capture.cpp
+// start catch_interfaces_config.cpp
+
+namespace Catch {
+ IConfig::~IConfig() = default;
+}
+// end catch_interfaces_config.cpp
+// start catch_interfaces_exception.cpp
+
+namespace Catch {
+ IExceptionTranslator::~IExceptionTranslator() = default;
+ IExceptionTranslatorRegistry::~IExceptionTranslatorRegistry() = default;
+}
+// end catch_interfaces_exception.cpp
+// start catch_interfaces_registry_hub.cpp
+
+namespace Catch {
+ IRegistryHub::~IRegistryHub() = default;
+ IMutableRegistryHub::~IMutableRegistryHub() = default;
+}
+// end catch_interfaces_registry_hub.cpp
+// start catch_interfaces_reporter.cpp
+
+// start catch_reporter_listening.h
+
+namespace Catch {
+
+ class ListeningReporter : public IStreamingReporter {
+ using Reporters = std::vector<IStreamingReporterPtr>;
+ Reporters m_listeners;
+ IStreamingReporterPtr m_reporter = nullptr;
+ ReporterPreferences m_preferences;
+
+ public:
+ ListeningReporter();
+
+ void addListener( IStreamingReporterPtr&& listener );
+ void addReporter( IStreamingReporterPtr&& reporter );
+
+ public: // IStreamingReporter
+
+ ReporterPreferences getPreferences() const override;
+
+ void noMatchingTestCases( std::string const& spec ) override;
+
+ void reportInvalidArguments(std::string const&arg) override;
+
+ static std::set<Verbosity> getSupportedVerbosities();
+
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+ void benchmarkPreparing(std::string const& name) override;
+ void benchmarkStarting( BenchmarkInfo const& benchmarkInfo ) override;
+ void benchmarkEnded( BenchmarkStats<> const& benchmarkStats ) override;
+ void benchmarkFailed(std::string const&) override;
+#endif // CATCH_CONFIG_ENABLE_BENCHMARKING
+
+ void testRunStarting( TestRunInfo const& testRunInfo ) override;
+ void testGroupStarting( GroupInfo const& groupInfo ) override;
+ void testCaseStarting( TestCaseInfo const& testInfo ) override;
+ void sectionStarting( SectionInfo const& sectionInfo ) override;
+ void assertionStarting( AssertionInfo const& assertionInfo ) override;
+
+ // The return value indicates if the messages buffer should be cleared:
+ bool assertionEnded( AssertionStats const& assertionStats ) override;
+ void sectionEnded( SectionStats const& sectionStats ) override;
+ void testCaseEnded( TestCaseStats const& testCaseStats ) override;
+ void testGroupEnded( TestGroupStats const& testGroupStats ) override;
+ void testRunEnded( TestRunStats const& testRunStats ) override;
+
+ void skipTest( TestCaseInfo const& testInfo ) override;
+ bool isMulti() const override;
+
+ };
+
} // end namespace Catch
-// #included from: catch_test_case_info.hpp
-#define TWOBLUECUBES_CATCH_TEST_CASE_INFO_HPP_INCLUDED
+// end catch_reporter_listening.h
+namespace Catch {
-#include <cctype>
+ ReporterConfig::ReporterConfig( IConfigPtr const& _fullConfig )
+ : m_stream( &_fullConfig->stream() ), m_fullConfig( _fullConfig ) {}
+
+ ReporterConfig::ReporterConfig( IConfigPtr const& _fullConfig, std::ostream& _stream )
+ : m_stream( &_stream ), m_fullConfig( _fullConfig ) {}
+
+ std::ostream& ReporterConfig::stream() const { return *m_stream; }
+ IConfigPtr ReporterConfig::fullConfig() const { return m_fullConfig; }
+
+ TestRunInfo::TestRunInfo( std::string const& _name ) : name( _name ) {}
+
+ GroupInfo::GroupInfo( std::string const& _name,
+ std::size_t _groupIndex,
+ std::size_t _groupsCount )
+ : name( _name ),
+ groupIndex( _groupIndex ),
+ groupsCounts( _groupsCount )
+ {}
+
+ AssertionStats::AssertionStats( AssertionResult const& _assertionResult,
+ std::vector<MessageInfo> const& _infoMessages,
+ Totals const& _totals )
+ : assertionResult( _assertionResult ),
+ infoMessages( _infoMessages ),
+ totals( _totals )
+ {
+ assertionResult.m_resultData.lazyExpression.m_transientExpression = _assertionResult.m_resultData.lazyExpression.m_transientExpression;
+
+ if( assertionResult.hasMessage() ) {
+ // Copy message into messages list.
+ // !TBD This should have been done earlier, somewhere
+ MessageBuilder builder( assertionResult.getTestMacroName(), assertionResult.getSourceInfo(), assertionResult.getResultType() );
+ builder << assertionResult.getMessage();
+ builder.m_info.message = builder.m_stream.str();
+
+ infoMessages.push_back( builder.m_info );
+ }
+ }
+
+ AssertionStats::~AssertionStats() = default;
+
+ SectionStats::SectionStats( SectionInfo const& _sectionInfo,
+ Counts const& _assertions,
+ double _durationInSeconds,
+ bool _missingAssertions )
+ : sectionInfo( _sectionInfo ),
+ assertions( _assertions ),
+ durationInSeconds( _durationInSeconds ),
+ missingAssertions( _missingAssertions )
+ {}
+
+ SectionStats::~SectionStats() = default;
+
+ TestCaseStats::TestCaseStats( TestCaseInfo const& _testInfo,
+ Totals const& _totals,
+ std::string const& _stdOut,
+ std::string const& _stdErr,
+ bool _aborting )
+ : testInfo( _testInfo ),
+ totals( _totals ),
+ stdOut( _stdOut ),
+ stdErr( _stdErr ),
+ aborting( _aborting )
+ {}
+
+ TestCaseStats::~TestCaseStats() = default;
+
+ TestGroupStats::TestGroupStats( GroupInfo const& _groupInfo,
+ Totals const& _totals,
+ bool _aborting )
+ : groupInfo( _groupInfo ),
+ totals( _totals ),
+ aborting( _aborting )
+ {}
+
+ TestGroupStats::TestGroupStats( GroupInfo const& _groupInfo )
+ : groupInfo( _groupInfo ),
+ aborting( false )
+ {}
+
+ TestGroupStats::~TestGroupStats() = default;
+
+ TestRunStats::TestRunStats( TestRunInfo const& _runInfo,
+ Totals const& _totals,
+ bool _aborting )
+ : runInfo( _runInfo ),
+ totals( _totals ),
+ aborting( _aborting )
+ {}
+
+ TestRunStats::~TestRunStats() = default;
+
+ void IStreamingReporter::fatalErrorEncountered( StringRef ) {}
+ bool IStreamingReporter::isMulti() const { return false; }
+
+ IReporterFactory::~IReporterFactory() = default;
+ IReporterRegistry::~IReporterRegistry() = default;
+
+} // end namespace Catch
+// end catch_interfaces_reporter.cpp
+// start catch_interfaces_runner.cpp
+
+namespace Catch {
+ IRunner::~IRunner() = default;
+}
+// end catch_interfaces_runner.cpp
+// start catch_interfaces_testcase.cpp
+
+namespace Catch {
+ ITestInvoker::~ITestInvoker() = default;
+ ITestCaseRegistry::~ITestCaseRegistry() = default;
+}
+// end catch_interfaces_testcase.cpp
+// start catch_leak_detector.cpp
+
+#ifdef CATCH_CONFIG_WINDOWS_CRTDBG
+#include <crtdbg.h>
+
+namespace Catch {
+
+ LeakDetector::LeakDetector() {
+ int flag = _CrtSetDbgFlag(_CRTDBG_REPORT_FLAG);
+ flag |= _CRTDBG_LEAK_CHECK_DF;
+ flag |= _CRTDBG_ALLOC_MEM_DF;
+ _CrtSetDbgFlag(flag);
+ _CrtSetReportMode(_CRT_WARN, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG);
+ _CrtSetReportFile(_CRT_WARN, _CRTDBG_FILE_STDERR);
+ // Change this to leaking allocation's number to break there
+ _CrtSetBreakAlloc(-1);
+ }
+}
+
+#else
+
+ Catch::LeakDetector::LeakDetector() {}
+
+#endif
+
+Catch::LeakDetector::~LeakDetector() {
+ Catch::cleanUp();
+}
+// end catch_leak_detector.cpp
+// start catch_list.cpp
+
+// start catch_list.h
+
+#include <set>
+
+namespace Catch {
+
+ std::size_t listTests( Config const& config );
+
+ std::size_t listTestsNamesOnly( Config const& config );
+
+ struct TagInfo {
+ void add( std::string const& spelling );
+ std::string all() const;
+
+ std::set<std::string> spellings;
+ std::size_t count = 0;
+ };
+
+ std::size_t listTags( Config const& config );
+
+ std::size_t listReporters();
+
+ Option<std::size_t> list( std::shared_ptr<Config> const& config );
+
+} // end namespace Catch
+
+// end catch_list.h
+// start catch_text.h
namespace Catch {
+ using namespace clara::TextFlow;
+}
+
+// end catch_text.h
+#include <limits>
+#include <algorithm>
+#include <iomanip>
+
+namespace Catch {
+
+ std::size_t listTests( Config const& config ) {
+ TestSpec testSpec = config.testSpec();
+ if( config.hasTestFilters() )
+ Catch::cout() << "Matching test cases:\n";
+ else {
+ Catch::cout() << "All available test cases:\n";
+ }
+
+ auto matchedTestCases = filterTests( getAllTestCasesSorted( config ), testSpec, config );
+ for( auto const& testCaseInfo : matchedTestCases ) {
+ Colour::Code colour = testCaseInfo.isHidden()
+ ? Colour::SecondaryText
+ : Colour::None;
+ Colour colourGuard( colour );
- inline TestCaseInfo::SpecialProperties parseSpecialTag( std::string const& tag ) {
- if( startsWith( tag, '.' ) ||
- tag == "hide" ||
- tag == "!hide" )
- return TestCaseInfo::IsHidden;
- else if( tag == "!throws" )
- return TestCaseInfo::Throws;
- else if( tag == "!shouldfail" )
- return TestCaseInfo::ShouldFail;
- else if( tag == "!mayfail" )
- return TestCaseInfo::MayFail;
- else if( tag == "!nonportable" )
- return TestCaseInfo::NonPortable;
+ Catch::cout() << Column( testCaseInfo.name ).initialIndent( 2 ).indent( 4 ) << "\n";
+ if( config.verbosity() >= Verbosity::High ) {
+ Catch::cout() << Column( Catch::Detail::stringify( testCaseInfo.lineInfo ) ).indent(4) << std::endl;
+ std::string description = testCaseInfo.description;
+ if( description.empty() )
+ description = "(NO DESCRIPTION)";
+ Catch::cout() << Column( description ).indent(4) << std::endl;
+ }
+ if( !testCaseInfo.tags.empty() )
+ Catch::cout() << Column( testCaseInfo.tagsAsString() ).indent( 6 ) << "\n";
+ }
+
+ if( !config.hasTestFilters() )
+ Catch::cout() << pluralise( matchedTestCases.size(), "test case" ) << '\n' << std::endl;
else
- return TestCaseInfo::None;
+ Catch::cout() << pluralise( matchedTestCases.size(), "matching test case" ) << '\n' << std::endl;
+ return matchedTestCases.size();
}
- inline bool isReservedTag( std::string const& tag ) {
- return parseSpecialTag( tag ) == TestCaseInfo::None && tag.size() > 0 && !std::isalnum( tag[0] );
+
+ std::size_t listTestsNamesOnly( Config const& config ) {
+ TestSpec testSpec = config.testSpec();
+ std::size_t matchedTests = 0;
+ std::vector<TestCase> matchedTestCases = filterTests( getAllTestCasesSorted( config ), testSpec, config );
+ for( auto const& testCaseInfo : matchedTestCases ) {
+ matchedTests++;
+ if( startsWith( testCaseInfo.name, '#' ) )
+ Catch::cout() << '"' << testCaseInfo.name << '"';
+ else
+ Catch::cout() << testCaseInfo.name;
+ if ( config.verbosity() >= Verbosity::High )
+ Catch::cout() << "\t@" << testCaseInfo.lineInfo;
+ Catch::cout() << std::endl;
+ }
+ return matchedTests;
+ }
+
+ void TagInfo::add( std::string const& spelling ) {
+ ++count;
+ spellings.insert( spelling );
}
- inline void enforceNotReservedTag( std::string const& tag, SourceLineInfo const& _lineInfo ) {
- if( isReservedTag( tag ) ) {
- std::ostringstream ss;
- ss << Colour(Colour::Red)
- << "Tag name [" << tag << "] not allowed.\n"
- << "Tag names starting with non alpha-numeric characters are reserved\n"
- << Colour(Colour::FileName)
- << _lineInfo << '\n';
- throw std::runtime_error(ss.str());
+
+ std::string TagInfo::all() const {
+ size_t size = 0;
+ for (auto const& spelling : spellings) {
+ // Add 2 for the brackes
+ size += spelling.size() + 2;
+ }
+
+ std::string out; out.reserve(size);
+ for (auto const& spelling : spellings) {
+ out += '[';
+ out += spelling;
+ out += ']';
}
+ return out;
}
- TestCase makeTestCase( ITestCase* _testCase,
- std::string const& _className,
- std::string const& _name,
- std::string const& _descOrTags,
- SourceLineInfo const& _lineInfo )
- {
- bool isHidden( startsWith( _name, "./" ) ); // Legacy support
+ std::size_t listTags( Config const& config ) {
+ TestSpec testSpec = config.testSpec();
+ if( config.hasTestFilters() )
+ Catch::cout() << "Tags for matching test cases:\n";
+ else {
+ Catch::cout() << "All available tags:\n";
+ }
- // Parse out tags
- std::set<std::string> tags;
- std::string desc, tag;
- bool inTag = false;
- for( std::size_t i = 0; i < _descOrTags.size(); ++i ) {
- char c = _descOrTags[i];
- if( !inTag ) {
- if( c == '[' )
- inTag = true;
- else
- desc += c;
- }
- else {
- if( c == ']' ) {
- TestCaseInfo::SpecialProperties prop = parseSpecialTag( tag );
- if( prop == TestCaseInfo::IsHidden )
- isHidden = true;
- else if( prop == TestCaseInfo::None )
- enforceNotReservedTag( tag, _lineInfo );
+ std::map<std::string, TagInfo> tagCounts;
- tags.insert( tag );
- tag.clear();
- inTag = false;
- }
- else
- tag += c;
+ std::vector<TestCase> matchedTestCases = filterTests( getAllTestCasesSorted( config ), testSpec, config );
+ for( auto const& testCase : matchedTestCases ) {
+ for( auto const& tagName : testCase.getTestCaseInfo().tags ) {
+ std::string lcaseTagName = toLower( tagName );
+ auto countIt = tagCounts.find( lcaseTagName );
+ if( countIt == tagCounts.end() )
+ countIt = tagCounts.insert( std::make_pair( lcaseTagName, TagInfo() ) ).first;
+ countIt->second.add( tagName );
}
}
- if( isHidden ) {
- tags.insert( "hide" );
- tags.insert( "." );
+
+ for( auto const& tagCount : tagCounts ) {
+ ReusableStringStream rss;
+ rss << " " << std::setw(2) << tagCount.second.count << " ";
+ auto str = rss.str();
+ auto wrapper = Column( tagCount.second.all() )
+ .initialIndent( 0 )
+ .indent( str.size() )
+ .width( CATCH_CONFIG_CONSOLE_WIDTH-10 );
+ Catch::cout() << str << wrapper << '\n';
}
+ Catch::cout() << pluralise( tagCounts.size(), "tag" ) << '\n' << std::endl;
+ return tagCounts.size();
+ }
- TestCaseInfo info( _name, _className, desc, tags, _lineInfo );
- return TestCase( _testCase, info );
+ std::size_t listReporters() {
+ Catch::cout() << "Available reporters:\n";
+ IReporterRegistry::FactoryMap const& factories = getRegistryHub().getReporterRegistry().getFactories();
+ std::size_t maxNameLen = 0;
+ for( auto const& factoryKvp : factories )
+ maxNameLen = (std::max)( maxNameLen, factoryKvp.first.size() );
+
+ for( auto const& factoryKvp : factories ) {
+ Catch::cout()
+ << Column( factoryKvp.first + ":" )
+ .indent(2)
+ .width( 5+maxNameLen )
+ + Column( factoryKvp.second->getDescription() )
+ .initialIndent(0)
+ .indent(2)
+ .width( CATCH_CONFIG_CONSOLE_WIDTH - maxNameLen-8 )
+ << "\n";
+ }
+ Catch::cout() << std::endl;
+ return factories.size();
}
- void setTags( TestCaseInfo& testCaseInfo, std::set<std::string> const& tags )
- {
- testCaseInfo.tags = tags;
- testCaseInfo.lcaseTags.clear();
+ Option<std::size_t> list( std::shared_ptr<Config> const& config ) {
+ Option<std::size_t> listedCount;
+ getCurrentMutableContext().setConfig( config );
+ if( config->listTests() )
+ listedCount = listedCount.valueOr(0) + listTests( *config );
+ if( config->listTestNamesOnly() )
+ listedCount = listedCount.valueOr(0) + listTestsNamesOnly( *config );
+ if( config->listTags() )
+ listedCount = listedCount.valueOr(0) + listTags( *config );
+ if( config->listReporters() )
+ listedCount = listedCount.valueOr(0) + listReporters();
+ return listedCount;
+ }
- std::ostringstream oss;
- for( std::set<std::string>::const_iterator it = tags.begin(), itEnd = tags.end(); it != itEnd; ++it ) {
- oss << '[' << *it << ']';
- std::string lcaseTag = toLower( *it );
- testCaseInfo.properties = static_cast<TestCaseInfo::SpecialProperties>( testCaseInfo.properties | parseSpecialTag( lcaseTag ) );
- testCaseInfo.lcaseTags.insert( lcaseTag );
+} // end namespace Catch
+// end catch_list.cpp
+// start catch_matchers.cpp
+
+namespace Catch {
+namespace Matchers {
+ namespace Impl {
+
+ std::string MatcherUntypedBase::toString() const {
+ if( m_cachedToString.empty() )
+ m_cachedToString = describe();
+ return m_cachedToString;
}
- testCaseInfo.tagsAsString = oss.str();
+
+ MatcherUntypedBase::~MatcherUntypedBase() = default;
+
+ } // namespace Impl
+} // namespace Matchers
+
+using namespace Matchers;
+using Matchers::Impl::MatcherBase;
+
+} // namespace Catch
+// end catch_matchers.cpp
+// start catch_matchers_exception.cpp
+
+namespace Catch {
+namespace Matchers {
+namespace Exception {
+
+bool ExceptionMessageMatcher::match(std::exception const& ex) const {
+ return ex.what() == m_message;
+}
+
+std::string ExceptionMessageMatcher::describe() const {
+ return "exception message matches \"" + m_message + "\"";
+}
+
+}
+Exception::ExceptionMessageMatcher Message(std::string const& message) {
+ return Exception::ExceptionMessageMatcher(message);
+}
+
+// namespace Exception
+} // namespace Matchers
+} // namespace Catch
+// end catch_matchers_exception.cpp
+// start catch_matchers_floating.cpp
+
+// start catch_polyfills.hpp
+
+namespace Catch {
+ bool isnan(float f);
+ bool isnan(double d);
+}
+
+// end catch_polyfills.hpp
+// start catch_to_string.hpp
+
+#include <string>
+
+namespace Catch {
+ template <typename T>
+ std::string to_string(T const& t) {
+#if defined(CATCH_CONFIG_CPP11_TO_STRING)
+ return std::to_string(t);
+#else
+ ReusableStringStream rss;
+ rss << t;
+ return rss.str();
+#endif
}
+} // end namespace Catch
- TestCaseInfo::TestCaseInfo( std::string const& _name,
- std::string const& _className,
- std::string const& _description,
- std::set<std::string> const& _tags,
- SourceLineInfo const& _lineInfo )
- : name( _name ),
- className( _className ),
- description( _description ),
- lineInfo( _lineInfo ),
- properties( None )
- {
- setTags( *this, _tags );
+// end catch_to_string.hpp
+#include <algorithm>
+#include <cmath>
+#include <cstdlib>
+#include <cstdint>
+#include <cstring>
+#include <sstream>
+#include <type_traits>
+#include <iomanip>
+#include <limits>
+
+namespace Catch {
+namespace {
+
+ int32_t convert(float f) {
+ static_assert(sizeof(float) == sizeof(int32_t), "Important ULP matcher assumption violated");
+ int32_t i;
+ std::memcpy(&i, &f, sizeof(f));
+ return i;
}
- TestCaseInfo::TestCaseInfo( TestCaseInfo const& other )
- : name( other.name ),
- className( other.className ),
- description( other.description ),
- tags( other.tags ),
- lcaseTags( other.lcaseTags ),
- tagsAsString( other.tagsAsString ),
- lineInfo( other.lineInfo ),
- properties( other.properties )
- {}
+ int64_t convert(double d) {
+ static_assert(sizeof(double) == sizeof(int64_t), "Important ULP matcher assumption violated");
+ int64_t i;
+ std::memcpy(&i, &d, sizeof(d));
+ return i;
+ }
- bool TestCaseInfo::isHidden() const {
- return ( properties & IsHidden ) != 0;
+ template <typename FP>
+ bool almostEqualUlps(FP lhs, FP rhs, uint64_t maxUlpDiff) {
+ // Comparison with NaN should always be false.
+ // This way we can rule it out before getting into the ugly details
+ if (Catch::isnan(lhs) || Catch::isnan(rhs)) {
+ return false;
+ }
+
+ auto lc = convert(lhs);
+ auto rc = convert(rhs);
+
+ if ((lc < 0) != (rc < 0)) {
+ // Potentially we can have +0 and -0
+ return lhs == rhs;
+ }
+
+ auto ulpDiff = std::abs(lc - rc);
+ return static_cast<uint64_t>(ulpDiff) <= maxUlpDiff;
}
- bool TestCaseInfo::throws() const {
- return ( properties & Throws ) != 0;
+
+} //end anonymous namespace
+
+#if defined(CATCH_CONFIG_GLOBAL_NEXTAFTER)
+
+#if defined(__clang__)
+#pragma clang diagnostic push
+// The long double overload is currently unused
+#pragma clang diagnostic ignored "-Wunused-function"
+#endif
+
+ float nextafter(float x, float y) {
+ return ::nextafterf(x, y);
}
- bool TestCaseInfo::okToFail() const {
- return ( properties & (ShouldFail | MayFail ) ) != 0;
+
+ double nextafter(double x, double y) {
+ return ::nextafter(x, y);
}
- bool TestCaseInfo::expectedToFail() const {
- return ( properties & (ShouldFail ) ) != 0;
+
+ long double nextafter(long double x, long double y) {
+ return ::nextafterl(x, y);
}
- TestCase::TestCase( ITestCase* testCase, TestCaseInfo const& info ) : TestCaseInfo( info ), test( testCase ) {}
+#if defined(__clang__)
+#pragma clang diagnostic pop
+#endif
- TestCase::TestCase( TestCase const& other )
- : TestCaseInfo( other ),
- test( other.test )
- {}
+#endif // ^^^ CATCH_CONFIG_GLOBAL_NEXTAFTER ^^^
- TestCase TestCase::withName( std::string const& _newName ) const {
- TestCase other( *this );
- other.name = _newName;
- return other;
+namespace {
+
+template <typename FP>
+FP step(FP start, FP direction, uint64_t steps) {
+ for (uint64_t i = 0; i < steps; ++i) {
+#if defined(CATCH_CONFIG_GLOBAL_NEXTAFTER)
+ start = Catch::nextafter(start, direction);
+#else
+ start = std::nextafter(start, direction);
+#endif
}
+ return start;
+}
- void TestCase::swap( TestCase& other ) {
- test.swap( other.test );
- name.swap( other.name );
- className.swap( other.className );
- description.swap( other.description );
- tags.swap( other.tags );
- lcaseTags.swap( other.lcaseTags );
- tagsAsString.swap( other.tagsAsString );
- std::swap( TestCaseInfo::properties, static_cast<TestCaseInfo&>( other ).properties );
- std::swap( lineInfo, other.lineInfo );
+// Performs equivalent check of std::fabs(lhs - rhs) <= margin
+// But without the subtraction to allow for INFINITY in comparison
+bool marginComparison(double lhs, double rhs, double margin) {
+ return (lhs + margin >= rhs) && (rhs + margin >= lhs);
+}
+
+template <typename FloatingPoint>
+void write(std::ostream& out, FloatingPoint num) {
+ out << std::scientific
+ << std::setprecision(std::numeric_limits<FloatingPoint>::max_digits10 - 1)
+ << num;
+}
+
+} // end anonymous namespace
+
+namespace Matchers {
+namespace Floating {
+
+ enum class FloatingPointKind : uint8_t {
+ Float,
+ Double
+ };
+
+ WithinAbsMatcher::WithinAbsMatcher(double target, double margin)
+ :m_target{ target }, m_margin{ margin } {
+ CATCH_ENFORCE(margin >= 0, "Invalid margin: " << margin << '.'
+ << " Margin has to be non-negative.");
}
- void TestCase::invoke() const {
- test->invoke();
+ // Performs equivalent check of std::fabs(lhs - rhs) <= margin
+ // But without the subtraction to allow for INFINITY in comparison
+ bool WithinAbsMatcher::match(double const& matchee) const {
+ return (matchee + m_margin >= m_target) && (m_target + m_margin >= matchee);
}
- bool TestCase::operator == ( TestCase const& other ) const {
- return test.get() == other.test.get() &&
- name == other.name &&
- className == other.className;
+ std::string WithinAbsMatcher::describe() const {
+ return "is within " + ::Catch::Detail::stringify(m_margin) + " of " + ::Catch::Detail::stringify(m_target);
}
- bool TestCase::operator < ( TestCase const& other ) const {
- return name < other.name;
+ WithinUlpsMatcher::WithinUlpsMatcher(double target, uint64_t ulps, FloatingPointKind baseType)
+ :m_target{ target }, m_ulps{ ulps }, m_type{ baseType } {
+ CATCH_ENFORCE(m_type == FloatingPointKind::Double
+ || m_ulps < (std::numeric_limits<uint32_t>::max)(),
+ "Provided ULP is impossibly large for a float comparison.");
}
- TestCase& TestCase::operator = ( TestCase const& other ) {
- TestCase temp( other );
- swap( temp );
- return *this;
+
+#if defined(__clang__)
+#pragma clang diagnostic push
+// Clang <3.5 reports on the default branch in the switch below
+#pragma clang diagnostic ignored "-Wunreachable-code"
+#endif
+
+ bool WithinUlpsMatcher::match(double const& matchee) const {
+ switch (m_type) {
+ case FloatingPointKind::Float:
+ return almostEqualUlps<float>(static_cast<float>(matchee), static_cast<float>(m_target), m_ulps);
+ case FloatingPointKind::Double:
+ return almostEqualUlps<double>(matchee, m_target, m_ulps);
+ default:
+ CATCH_INTERNAL_ERROR( "Unknown FloatingPointKind value" );
+ }
}
- TestCaseInfo const& TestCase::getTestCaseInfo() const
- {
- return *this;
+#if defined(__clang__)
+#pragma clang diagnostic pop
+#endif
+
+ std::string WithinUlpsMatcher::describe() const {
+ std::stringstream ret;
+
+ ret << "is within " << m_ulps << " ULPs of ";
+
+ if (m_type == FloatingPointKind::Float) {
+ write(ret, static_cast<float>(m_target));
+ ret << 'f';
+ } else {
+ write(ret, m_target);
+ }
+
+ ret << " ([";
+ if (m_type == FloatingPointKind::Double) {
+ write(ret, step(m_target, static_cast<double>(-INFINITY), m_ulps));
+ ret << ", ";
+ write(ret, step(m_target, static_cast<double>( INFINITY), m_ulps));
+ } else {
+ write(ret, step(static_cast<float>(m_target), -INFINITY, m_ulps));
+ ret << ", ";
+ write(ret, step(static_cast<float>(m_target), INFINITY, m_ulps));
+ }
+ ret << "])";
+
+ return ret.str();
}
-} // end namespace Catch
+ WithinRelMatcher::WithinRelMatcher(double target, double epsilon):
+ m_target(target),
+ m_epsilon(epsilon){
+ CATCH_ENFORCE(m_epsilon >= 0., "Relative comparison with epsilon < 0 does not make sense.");
+ CATCH_ENFORCE(m_epsilon < 1., "Relative comparison with epsilon >= 1 does not make sense.");
+ }
+
+ bool WithinRelMatcher::match(double const& matchee) const {
+ const auto relMargin = m_epsilon * (std::max)(std::fabs(matchee), std::fabs(m_target));
+ return marginComparison(matchee, m_target,
+ std::isinf(relMargin)? 0 : relMargin);
+ }
-// #included from: catch_version.hpp
-#define TWOBLUECUBES_CATCH_VERSION_HPP_INCLUDED
+ std::string WithinRelMatcher::describe() const {
+ Catch::ReusableStringStream sstr;
+ sstr << "and " << m_target << " are within " << m_epsilon * 100. << "% of each other";
+ return sstr.str();
+ }
+
+}// namespace Floating
+
+Floating::WithinUlpsMatcher WithinULP(double target, uint64_t maxUlpDiff) {
+ return Floating::WithinUlpsMatcher(target, maxUlpDiff, Floating::FloatingPointKind::Double);
+}
+
+Floating::WithinUlpsMatcher WithinULP(float target, uint64_t maxUlpDiff) {
+ return Floating::WithinUlpsMatcher(target, maxUlpDiff, Floating::FloatingPointKind::Float);
+}
+
+Floating::WithinAbsMatcher WithinAbs(double target, double margin) {
+ return Floating::WithinAbsMatcher(target, margin);
+}
+
+Floating::WithinRelMatcher WithinRel(double target, double eps) {
+ return Floating::WithinRelMatcher(target, eps);
+}
+
+Floating::WithinRelMatcher WithinRel(double target) {
+ return Floating::WithinRelMatcher(target, std::numeric_limits<double>::epsilon() * 100);
+}
+
+Floating::WithinRelMatcher WithinRel(float target, float eps) {
+ return Floating::WithinRelMatcher(target, eps);
+}
+
+Floating::WithinRelMatcher WithinRel(float target) {
+ return Floating::WithinRelMatcher(target, std::numeric_limits<float>::epsilon() * 100);
+}
+
+} // namespace Matchers
+} // namespace Catch
+
+// end catch_matchers_floating.cpp
+// start catch_matchers_generic.cpp
+
+std::string Catch::Matchers::Generic::Detail::finalizeDescription(const std::string& desc) {
+ if (desc.empty()) {
+ return "matches undescribed predicate";
+ } else {
+ return "matches predicate: \"" + desc + '"';
+ }
+}
+// end catch_matchers_generic.cpp
+// start catch_matchers_string.cpp
+
+#include <regex>
namespace Catch {
+namespace Matchers {
- Version::Version
- ( unsigned int _majorVersion,
- unsigned int _minorVersion,
- unsigned int _patchNumber,
- char const * const _branchName,
- unsigned int _buildNumber )
- : majorVersion( _majorVersion ),
- minorVersion( _minorVersion ),
- patchNumber( _patchNumber ),
- branchName( _branchName ),
- buildNumber( _buildNumber )
- {}
+ namespace StdString {
- std::ostream& operator << ( std::ostream& os, Version const& version ) {
- os << version.majorVersion << '.'
- << version.minorVersion << '.'
- << version.patchNumber;
- // branchName is never null -> 0th char is \0 if it is empty
- if (version.branchName[0]) {
- os << '-' << version.branchName
- << '.' << version.buildNumber;
+ CasedString::CasedString( std::string const& str, CaseSensitive::Choice caseSensitivity )
+ : m_caseSensitivity( caseSensitivity ),
+ m_str( adjustString( str ) )
+ {}
+ std::string CasedString::adjustString( std::string const& str ) const {
+ return m_caseSensitivity == CaseSensitive::No
+ ? toLower( str )
+ : str;
}
- return os;
+ std::string CasedString::caseSensitivitySuffix() const {
+ return m_caseSensitivity == CaseSensitive::No
+ ? " (case insensitive)"
+ : std::string();
+ }
+
+ StringMatcherBase::StringMatcherBase( std::string const& operation, CasedString const& comparator )
+ : m_comparator( comparator ),
+ m_operation( operation ) {
+ }
+
+ std::string StringMatcherBase::describe() const {
+ std::string description;
+ description.reserve(5 + m_operation.size() + m_comparator.m_str.size() +
+ m_comparator.caseSensitivitySuffix().size());
+ description += m_operation;
+ description += ": \"";
+ description += m_comparator.m_str;
+ description += "\"";
+ description += m_comparator.caseSensitivitySuffix();
+ return description;
+ }
+
+ EqualsMatcher::EqualsMatcher( CasedString const& comparator ) : StringMatcherBase( "equals", comparator ) {}
+
+ bool EqualsMatcher::match( std::string const& source ) const {
+ return m_comparator.adjustString( source ) == m_comparator.m_str;
+ }
+
+ ContainsMatcher::ContainsMatcher( CasedString const& comparator ) : StringMatcherBase( "contains", comparator ) {}
+
+ bool ContainsMatcher::match( std::string const& source ) const {
+ return contains( m_comparator.adjustString( source ), m_comparator.m_str );
+ }
+
+ StartsWithMatcher::StartsWithMatcher( CasedString const& comparator ) : StringMatcherBase( "starts with", comparator ) {}
+
+ bool StartsWithMatcher::match( std::string const& source ) const {
+ return startsWith( m_comparator.adjustString( source ), m_comparator.m_str );
+ }
+
+ EndsWithMatcher::EndsWithMatcher( CasedString const& comparator ) : StringMatcherBase( "ends with", comparator ) {}
+
+ bool EndsWithMatcher::match( std::string const& source ) const {
+ return endsWith( m_comparator.adjustString( source ), m_comparator.m_str );
+ }
+
+ RegexMatcher::RegexMatcher(std::string regex, CaseSensitive::Choice caseSensitivity): m_regex(std::move(regex)), m_caseSensitivity(caseSensitivity) {}
+
+ bool RegexMatcher::match(std::string const& matchee) const {
+ auto flags = std::regex::ECMAScript; // ECMAScript is the default syntax option anyway
+ if (m_caseSensitivity == CaseSensitive::Choice::No) {
+ flags |= std::regex::icase;
+ }
+ auto reg = std::regex(m_regex, flags);
+ return std::regex_match(matchee, reg);
+ }
+
+ std::string RegexMatcher::describe() const {
+ return "matches " + ::Catch::Detail::stringify(m_regex) + ((m_caseSensitivity == CaseSensitive::Choice::Yes)? " case sensitively" : " case insensitively");
+ }
+
+ } // namespace StdString
+
+ StdString::EqualsMatcher Equals( std::string const& str, CaseSensitive::Choice caseSensitivity ) {
+ return StdString::EqualsMatcher( StdString::CasedString( str, caseSensitivity) );
+ }
+ StdString::ContainsMatcher Contains( std::string const& str, CaseSensitive::Choice caseSensitivity ) {
+ return StdString::ContainsMatcher( StdString::CasedString( str, caseSensitivity) );
+ }
+ StdString::EndsWithMatcher EndsWith( std::string const& str, CaseSensitive::Choice caseSensitivity ) {
+ return StdString::EndsWithMatcher( StdString::CasedString( str, caseSensitivity) );
+ }
+ StdString::StartsWithMatcher StartsWith( std::string const& str, CaseSensitive::Choice caseSensitivity ) {
+ return StdString::StartsWithMatcher( StdString::CasedString( str, caseSensitivity) );
}
- inline Version libraryVersion() {
- static Version version( 1, 9, 4, "", 0 );
- return version;
+ StdString::RegexMatcher Matches(std::string const& regex, CaseSensitive::Choice caseSensitivity) {
+ return StdString::RegexMatcher(regex, caseSensitivity);
}
-}
+} // namespace Matchers
+} // namespace Catch
+// end catch_matchers_string.cpp
+// start catch_message.cpp
+
+// start catch_uncaught_exceptions.h
+
+namespace Catch {
+ bool uncaught_exceptions();
+} // end namespace Catch
-// #included from: catch_message.hpp
-#define TWOBLUECUBES_CATCH_MESSAGE_HPP_INCLUDED
+// end catch_uncaught_exceptions.h
+#include <cassert>
+#include <stack>
namespace Catch {
- MessageInfo::MessageInfo( std::string const& _macroName,
+ MessageInfo::MessageInfo( StringRef const& _macroName,
SourceLineInfo const& _lineInfo,
ResultWas::OfType _type )
: macroName( _macroName ),
@@ -8329,217 +11628,1912 @@ namespace Catch {
sequence( ++globalCount )
{}
+ bool MessageInfo::operator==( MessageInfo const& other ) const {
+ return sequence == other.sequence;
+ }
+
+ bool MessageInfo::operator<( MessageInfo const& other ) const {
+ return sequence < other.sequence;
+ }
+
// This may need protecting if threading support is added
unsigned int MessageInfo::globalCount = 0;
////////////////////////////////////////////////////////////////////////////
+ Catch::MessageBuilder::MessageBuilder( StringRef const& macroName,
+ SourceLineInfo const& lineInfo,
+ ResultWas::OfType type )
+ :m_info(macroName, lineInfo, type) {}
+
+ ////////////////////////////////////////////////////////////////////////////
+
ScopedMessage::ScopedMessage( MessageBuilder const& builder )
- : m_info( builder.m_info )
+ : m_info( builder.m_info ), m_moved()
{
m_info.message = builder.m_stream.str();
getResultCapture().pushScopedMessage( m_info );
}
- ScopedMessage::ScopedMessage( ScopedMessage const& other )
- : m_info( other.m_info )
- {}
+
+ ScopedMessage::ScopedMessage( ScopedMessage&& old )
+ : m_info( old.m_info ), m_moved()
+ {
+ old.m_moved = true;
+ }
ScopedMessage::~ScopedMessage() {
- if ( !std::uncaught_exception() ){
+ if ( !uncaught_exceptions() && !m_moved ){
getResultCapture().popScopedMessage(m_info);
}
}
+ Capturer::Capturer( StringRef macroName, SourceLineInfo const& lineInfo, ResultWas::OfType resultType, StringRef names ) {
+ auto trimmed = [&] (size_t start, size_t end) {
+ while (names[start] == ',' || isspace(names[start])) {
+ ++start;
+ }
+ while (names[end] == ',' || isspace(names[end])) {
+ --end;
+ }
+ return names.substr(start, end - start + 1);
+ };
+ auto skipq = [&] (size_t start, char quote) {
+ for (auto i = start + 1; i < names.size() ; ++i) {
+ if (names[i] == quote)
+ return i;
+ if (names[i] == '\\')
+ ++i;
+ }
+ CATCH_INTERNAL_ERROR("CAPTURE parsing encountered unmatched quote");
+ };
+
+ size_t start = 0;
+ std::stack<char> openings;
+ for (size_t pos = 0; pos < names.size(); ++pos) {
+ char c = names[pos];
+ switch (c) {
+ case '[':
+ case '{':
+ case '(':
+ // It is basically impossible to disambiguate between
+ // comparison and start of template args in this context
+// case '<':
+ openings.push(c);
+ break;
+ case ']':
+ case '}':
+ case ')':
+// case '>':
+ openings.pop();
+ break;
+ case '"':
+ case '\'':
+ pos = skipq(pos, c);
+ break;
+ case ',':
+ if (start != pos && openings.size() == 0) {
+ m_messages.emplace_back(macroName, lineInfo, resultType);
+ m_messages.back().message = static_cast<std::string>(trimmed(start, pos));
+ m_messages.back().message += " := ";
+ start = pos;
+ }
+ }
+ }
+ assert(openings.size() == 0 && "Mismatched openings");
+ m_messages.emplace_back(macroName, lineInfo, resultType);
+ m_messages.back().message = static_cast<std::string>(trimmed(start, names.size() - 1));
+ m_messages.back().message += " := ";
+ }
+ Capturer::~Capturer() {
+ if ( !uncaught_exceptions() ){
+ assert( m_captured == m_messages.size() );
+ for( size_t i = 0; i < m_captured; ++i )
+ m_resultCapture.popScopedMessage( m_messages[i] );
+ }
+ }
+
+ void Capturer::captureValue( size_t index, std::string const& value ) {
+ assert( index < m_messages.size() );
+ m_messages[index].message += value;
+ m_resultCapture.pushScopedMessage( m_messages[index] );
+ m_captured++;
+ }
+
} // end namespace Catch
+// end catch_message.cpp
+// start catch_output_redirect.cpp
-// #included from: catch_legacy_reporter_adapter.hpp
-#define TWOBLUECUBES_CATCH_LEGACY_REPORTER_ADAPTER_HPP_INCLUDED
+// start catch_output_redirect.h
+#ifndef TWOBLUECUBES_CATCH_OUTPUT_REDIRECT_H
+#define TWOBLUECUBES_CATCH_OUTPUT_REDIRECT_H
-// #included from: catch_legacy_reporter_adapter.h
-#define TWOBLUECUBES_CATCH_LEGACY_REPORTER_ADAPTER_H_INCLUDED
+#include <cstdio>
+#include <iosfwd>
+#include <string>
-namespace Catch
-{
- // Deprecated
- struct IReporter : IShared {
- virtual ~IReporter();
-
- virtual bool shouldRedirectStdout() const = 0;
-
- virtual void StartTesting() = 0;
- virtual void EndTesting( Totals const& totals ) = 0;
- virtual void StartGroup( std::string const& groupName ) = 0;
- virtual void EndGroup( std::string const& groupName, Totals const& totals ) = 0;
- virtual void StartTestCase( TestCaseInfo const& testInfo ) = 0;
- virtual void EndTestCase( TestCaseInfo const& testInfo, Totals const& totals, std::string const& stdOut, std::string const& stdErr ) = 0;
- virtual void StartSection( std::string const& sectionName, std::string const& description ) = 0;
- virtual void EndSection( std::string const& sectionName, Counts const& assertions ) = 0;
- virtual void NoAssertionsInSection( std::string const& sectionName ) = 0;
- virtual void NoAssertionsInTestCase( std::string const& testName ) = 0;
- virtual void Aborted() = 0;
- virtual void Result( AssertionResult const& result ) = 0;
- };
-
- class LegacyReporterAdapter : public SharedImpl<IStreamingReporter>
+namespace Catch {
+
+ class RedirectedStream {
+ std::ostream& m_originalStream;
+ std::ostream& m_redirectionStream;
+ std::streambuf* m_prevBuf;
+
+ public:
+ RedirectedStream( std::ostream& originalStream, std::ostream& redirectionStream );
+ ~RedirectedStream();
+ };
+
+ class RedirectedStdOut {
+ ReusableStringStream m_rss;
+ RedirectedStream m_cout;
+ public:
+ RedirectedStdOut();
+ auto str() const -> std::string;
+ };
+
+ // StdErr has two constituent streams in C++, std::cerr and std::clog
+ // This means that we need to redirect 2 streams into 1 to keep proper
+ // order of writes
+ class RedirectedStdErr {
+ ReusableStringStream m_rss;
+ RedirectedStream m_cerr;
+ RedirectedStream m_clog;
+ public:
+ RedirectedStdErr();
+ auto str() const -> std::string;
+ };
+
+ class RedirectedStreams {
+ public:
+ RedirectedStreams(RedirectedStreams const&) = delete;
+ RedirectedStreams& operator=(RedirectedStreams const&) = delete;
+ RedirectedStreams(RedirectedStreams&&) = delete;
+ RedirectedStreams& operator=(RedirectedStreams&&) = delete;
+
+ RedirectedStreams(std::string& redirectedCout, std::string& redirectedCerr);
+ ~RedirectedStreams();
+ private:
+ std::string& m_redirectedCout;
+ std::string& m_redirectedCerr;
+ RedirectedStdOut m_redirectedStdOut;
+ RedirectedStdErr m_redirectedStdErr;
+ };
+
+#if defined(CATCH_CONFIG_NEW_CAPTURE)
+
+ // Windows's implementation of std::tmpfile is terrible (it tries
+ // to create a file inside system folder, thus requiring elevated
+ // privileges for the binary), so we have to use tmpnam(_s) and
+ // create the file ourselves there.
+ class TempFile {
+ public:
+ TempFile(TempFile const&) = delete;
+ TempFile& operator=(TempFile const&) = delete;
+ TempFile(TempFile&&) = delete;
+ TempFile& operator=(TempFile&&) = delete;
+
+ TempFile();
+ ~TempFile();
+
+ std::FILE* getFile();
+ std::string getContents();
+
+ private:
+ std::FILE* m_file = nullptr;
+ #if defined(_MSC_VER)
+ char m_buffer[L_tmpnam] = { 0 };
+ #endif
+ };
+
+ class OutputRedirect {
+ public:
+ OutputRedirect(OutputRedirect const&) = delete;
+ OutputRedirect& operator=(OutputRedirect const&) = delete;
+ OutputRedirect(OutputRedirect&&) = delete;
+ OutputRedirect& operator=(OutputRedirect&&) = delete;
+
+ OutputRedirect(std::string& stdout_dest, std::string& stderr_dest);
+ ~OutputRedirect();
+
+ private:
+ int m_originalStdout = -1;
+ int m_originalStderr = -1;
+ TempFile m_stdoutFile;
+ TempFile m_stderrFile;
+ std::string& m_stdoutDest;
+ std::string& m_stderrDest;
+ };
+
+#endif
+
+} // end namespace Catch
+
+#endif // TWOBLUECUBES_CATCH_OUTPUT_REDIRECT_H
+// end catch_output_redirect.h
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <sstream>
+#include <stdexcept>
+
+#if defined(CATCH_CONFIG_NEW_CAPTURE)
+ #if defined(_MSC_VER)
+ #include <io.h> //_dup and _dup2
+ #define dup _dup
+ #define dup2 _dup2
+ #define fileno _fileno
+ #else
+ #include <unistd.h> // dup and dup2
+ #endif
+#endif
+
+namespace Catch {
+
+ RedirectedStream::RedirectedStream( std::ostream& originalStream, std::ostream& redirectionStream )
+ : m_originalStream( originalStream ),
+ m_redirectionStream( redirectionStream ),
+ m_prevBuf( m_originalStream.rdbuf() )
{
+ m_originalStream.rdbuf( m_redirectionStream.rdbuf() );
+ }
+
+ RedirectedStream::~RedirectedStream() {
+ m_originalStream.rdbuf( m_prevBuf );
+ }
+
+ RedirectedStdOut::RedirectedStdOut() : m_cout( Catch::cout(), m_rss.get() ) {}
+ auto RedirectedStdOut::str() const -> std::string { return m_rss.str(); }
+
+ RedirectedStdErr::RedirectedStdErr()
+ : m_cerr( Catch::cerr(), m_rss.get() ),
+ m_clog( Catch::clog(), m_rss.get() )
+ {}
+ auto RedirectedStdErr::str() const -> std::string { return m_rss.str(); }
+
+ RedirectedStreams::RedirectedStreams(std::string& redirectedCout, std::string& redirectedCerr)
+ : m_redirectedCout(redirectedCout),
+ m_redirectedCerr(redirectedCerr)
+ {}
+
+ RedirectedStreams::~RedirectedStreams() {
+ m_redirectedCout += m_redirectedStdOut.str();
+ m_redirectedCerr += m_redirectedStdErr.str();
+ }
+
+#if defined(CATCH_CONFIG_NEW_CAPTURE)
+
+#if defined(_MSC_VER)
+ TempFile::TempFile() {
+ if (tmpnam_s(m_buffer)) {
+ CATCH_RUNTIME_ERROR("Could not get a temp filename");
+ }
+ if (fopen_s(&m_file, m_buffer, "w")) {
+ char buffer[100];
+ if (strerror_s(buffer, errno)) {
+ CATCH_RUNTIME_ERROR("Could not translate errno to a string");
+ }
+ CATCH_RUNTIME_ERROR("Could not open the temp file: '" << m_buffer << "' because: " << buffer);
+ }
+ }
+#else
+ TempFile::TempFile() {
+ m_file = std::tmpfile();
+ if (!m_file) {
+ CATCH_RUNTIME_ERROR("Could not create a temp file.");
+ }
+ }
+
+#endif
+
+ TempFile::~TempFile() {
+ // TBD: What to do about errors here?
+ std::fclose(m_file);
+ // We manually create the file on Windows only, on Linux
+ // it will be autodeleted
+#if defined(_MSC_VER)
+ std::remove(m_buffer);
+#endif
+ }
+
+ FILE* TempFile::getFile() {
+ return m_file;
+ }
+
+ std::string TempFile::getContents() {
+ std::stringstream sstr;
+ char buffer[100] = {};
+ std::rewind(m_file);
+ while (std::fgets(buffer, sizeof(buffer), m_file)) {
+ sstr << buffer;
+ }
+ return sstr.str();
+ }
+
+ OutputRedirect::OutputRedirect(std::string& stdout_dest, std::string& stderr_dest) :
+ m_originalStdout(dup(1)),
+ m_originalStderr(dup(2)),
+ m_stdoutDest(stdout_dest),
+ m_stderrDest(stderr_dest) {
+ dup2(fileno(m_stdoutFile.getFile()), 1);
+ dup2(fileno(m_stderrFile.getFile()), 2);
+ }
+
+ OutputRedirect::~OutputRedirect() {
+ Catch::cout() << std::flush;
+ fflush(stdout);
+ // Since we support overriding these streams, we flush cerr
+ // even though std::cerr is unbuffered
+ Catch::cerr() << std::flush;
+ Catch::clog() << std::flush;
+ fflush(stderr);
+
+ dup2(m_originalStdout, 1);
+ dup2(m_originalStderr, 2);
+
+ m_stdoutDest += m_stdoutFile.getContents();
+ m_stderrDest += m_stderrFile.getContents();
+ }
+
+#endif // CATCH_CONFIG_NEW_CAPTURE
+
+} // namespace Catch
+
+#if defined(CATCH_CONFIG_NEW_CAPTURE)
+ #if defined(_MSC_VER)
+ #undef dup
+ #undef dup2
+ #undef fileno
+ #endif
+#endif
+// end catch_output_redirect.cpp
+// start catch_polyfills.cpp
+
+#include <cmath>
+
+namespace Catch {
+
+#if !defined(CATCH_CONFIG_POLYFILL_ISNAN)
+ bool isnan(float f) {
+ return std::isnan(f);
+ }
+ bool isnan(double d) {
+ return std::isnan(d);
+ }
+#else
+ // For now we only use this for embarcadero
+ bool isnan(float f) {
+ return std::_isnan(f);
+ }
+ bool isnan(double d) {
+ return std::_isnan(d);
+ }
+#endif
+
+} // end namespace Catch
+// end catch_polyfills.cpp
+// start catch_random_number_generator.cpp
+
+namespace Catch {
+
+namespace {
+
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable:4146) // we negate uint32 during the rotate
+#endif
+ // Safe rotr implementation thanks to John Regehr
+ uint32_t rotate_right(uint32_t val, uint32_t count) {
+ const uint32_t mask = 31;
+ count &= mask;
+ return (val >> count) | (val << (-count & mask));
+ }
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+
+}
+
+ SimplePcg32::SimplePcg32(result_type seed_) {
+ seed(seed_);
+ }
+
+ void SimplePcg32::seed(result_type seed_) {
+ m_state = 0;
+ (*this)();
+ m_state += seed_;
+ (*this)();
+ }
+
+ void SimplePcg32::discard(uint64_t skip) {
+ // We could implement this to run in O(log n) steps, but this
+ // should suffice for our use case.
+ for (uint64_t s = 0; s < skip; ++s) {
+ static_cast<void>((*this)());
+ }
+ }
+
+ SimplePcg32::result_type SimplePcg32::operator()() {
+ // prepare the output value
+ const uint32_t xorshifted = static_cast<uint32_t>(((m_state >> 18u) ^ m_state) >> 27u);
+ const auto output = rotate_right(xorshifted, m_state >> 59u);
+
+ // advance state
+ m_state = m_state * 6364136223846793005ULL + s_inc;
+
+ return output;
+ }
+
+ bool operator==(SimplePcg32 const& lhs, SimplePcg32 const& rhs) {
+ return lhs.m_state == rhs.m_state;
+ }
+
+ bool operator!=(SimplePcg32 const& lhs, SimplePcg32 const& rhs) {
+ return lhs.m_state != rhs.m_state;
+ }
+}
+// end catch_random_number_generator.cpp
+// start catch_registry_hub.cpp
+
+// start catch_test_case_registry_impl.h
+
+#include <vector>
+#include <set>
+#include <algorithm>
+#include <ios>
+
+namespace Catch {
+
+ class TestCase;
+ struct IConfig;
+
+ std::vector<TestCase> sortTests( IConfig const& config, std::vector<TestCase> const& unsortedTestCases );
+
+ bool isThrowSafe( TestCase const& testCase, IConfig const& config );
+ bool matchTest( TestCase const& testCase, TestSpec const& testSpec, IConfig const& config );
+
+ void enforceNoDuplicateTestCases( std::vector<TestCase> const& functions );
+
+ std::vector<TestCase> filterTests( std::vector<TestCase> const& testCases, TestSpec const& testSpec, IConfig const& config );
+ std::vector<TestCase> const& getAllTestCasesSorted( IConfig const& config );
+
+ class TestRegistry : public ITestCaseRegistry {
+ public:
+ virtual ~TestRegistry() = default;
+
+ virtual void registerTest( TestCase const& testCase );
+
+ std::vector<TestCase> const& getAllTests() const override;
+ std::vector<TestCase> const& getAllTestsSorted( IConfig const& config ) const override;
+
+ private:
+ std::vector<TestCase> m_functions;
+ mutable RunTests::InWhatOrder m_currentSortOrder = RunTests::InDeclarationOrder;
+ mutable std::vector<TestCase> m_sortedFunctions;
+ std::size_t m_unnamedCount = 0;
+ std::ios_base::Init m_ostreamInit; // Forces cout/ cerr to be initialised
+ };
+
+ ///////////////////////////////////////////////////////////////////////////
+
+ class TestInvokerAsFunction : public ITestInvoker {
+ void(*m_testAsFunction)();
public:
- LegacyReporterAdapter( Ptr<IReporter> const& legacyReporter );
- virtual ~LegacyReporterAdapter();
-
- virtual ReporterPreferences getPreferences() const;
- virtual void noMatchingTestCases( std::string const& );
- virtual void testRunStarting( TestRunInfo const& );
- virtual void testGroupStarting( GroupInfo const& groupInfo );
- virtual void testCaseStarting( TestCaseInfo const& testInfo );
- virtual void sectionStarting( SectionInfo const& sectionInfo );
- virtual void assertionStarting( AssertionInfo const& );
- virtual bool assertionEnded( AssertionStats const& assertionStats );
- virtual void sectionEnded( SectionStats const& sectionStats );
- virtual void testCaseEnded( TestCaseStats const& testCaseStats );
- virtual void testGroupEnded( TestGroupStats const& testGroupStats );
- virtual void testRunEnded( TestRunStats const& testRunStats );
- virtual void skipTest( TestCaseInfo const& );
+ TestInvokerAsFunction( void(*testAsFunction)() ) noexcept;
+
+ void invoke() const override;
+ };
+
+ std::string extractClassName( StringRef const& classOrQualifiedMethodName );
+
+ ///////////////////////////////////////////////////////////////////////////
+
+} // end namespace Catch
+
+// end catch_test_case_registry_impl.h
+// start catch_reporter_registry.h
+
+#include <map>
+
+namespace Catch {
+
+ class ReporterRegistry : public IReporterRegistry {
+
+ public:
+
+ ~ReporterRegistry() override;
+
+ IStreamingReporterPtr create( std::string const& name, IConfigPtr const& config ) const override;
+
+ void registerReporter( std::string const& name, IReporterFactoryPtr const& factory );
+ void registerListener( IReporterFactoryPtr const& factory );
+
+ FactoryMap const& getFactories() const override;
+ Listeners const& getListeners() const override;
private:
- Ptr<IReporter> m_legacyReporter;
+ FactoryMap m_factories;
+ Listeners m_listeners;
};
}
-namespace Catch
-{
- LegacyReporterAdapter::LegacyReporterAdapter( Ptr<IReporter> const& legacyReporter )
- : m_legacyReporter( legacyReporter )
- {}
- LegacyReporterAdapter::~LegacyReporterAdapter() {}
+// end catch_reporter_registry.h
+// start catch_tag_alias_registry.h
+
+// start catch_tag_alias.h
+
+#include <string>
- ReporterPreferences LegacyReporterAdapter::getPreferences() const {
- ReporterPreferences prefs;
- prefs.shouldRedirectStdOut = m_legacyReporter->shouldRedirectStdout();
- return prefs;
+namespace Catch {
+
+ struct TagAlias {
+ TagAlias(std::string const& _tag, SourceLineInfo _lineInfo);
+
+ std::string tag;
+ SourceLineInfo lineInfo;
+ };
+
+} // end namespace Catch
+
+// end catch_tag_alias.h
+#include <map>
+
+namespace Catch {
+
+ class TagAliasRegistry : public ITagAliasRegistry {
+ public:
+ ~TagAliasRegistry() override;
+ TagAlias const* find( std::string const& alias ) const override;
+ std::string expandAliases( std::string const& unexpandedTestSpec ) const override;
+ void add( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo );
+
+ private:
+ std::map<std::string, TagAlias> m_registry;
+ };
+
+} // end namespace Catch
+
+// end catch_tag_alias_registry.h
+// start catch_startup_exception_registry.h
+
+#include <vector>
+#include <exception>
+
+namespace Catch {
+
+ class StartupExceptionRegistry {
+ public:
+ void add(std::exception_ptr const& exception) noexcept;
+ std::vector<std::exception_ptr> const& getExceptions() const noexcept;
+ private:
+ std::vector<std::exception_ptr> m_exceptions;
+ };
+
+} // end namespace Catch
+
+// end catch_startup_exception_registry.h
+// start catch_singletons.hpp
+
+namespace Catch {
+
+ struct ISingleton {
+ virtual ~ISingleton();
+ };
+
+ void addSingleton( ISingleton* singleton );
+ void cleanupSingletons();
+
+ template<typename SingletonImplT, typename InterfaceT = SingletonImplT, typename MutableInterfaceT = InterfaceT>
+ class Singleton : SingletonImplT, public ISingleton {
+
+ static auto getInternal() -> Singleton* {
+ static Singleton* s_instance = nullptr;
+ if( !s_instance ) {
+ s_instance = new Singleton;
+ addSingleton( s_instance );
+ }
+ return s_instance;
+ }
+
+ public:
+ static auto get() -> InterfaceT const& {
+ return *getInternal();
+ }
+ static auto getMutable() -> MutableInterfaceT& {
+ return *getInternal();
+ }
+ };
+
+} // namespace Catch
+
+// end catch_singletons.hpp
+namespace Catch {
+
+ namespace {
+
+ class RegistryHub : public IRegistryHub, public IMutableRegistryHub,
+ private NonCopyable {
+
+ public: // IRegistryHub
+ RegistryHub() = default;
+ IReporterRegistry const& getReporterRegistry() const override {
+ return m_reporterRegistry;
+ }
+ ITestCaseRegistry const& getTestCaseRegistry() const override {
+ return m_testCaseRegistry;
+ }
+ IExceptionTranslatorRegistry const& getExceptionTranslatorRegistry() const override {
+ return m_exceptionTranslatorRegistry;
+ }
+ ITagAliasRegistry const& getTagAliasRegistry() const override {
+ return m_tagAliasRegistry;
+ }
+ StartupExceptionRegistry const& getStartupExceptionRegistry() const override {
+ return m_exceptionRegistry;
+ }
+
+ public: // IMutableRegistryHub
+ void registerReporter( std::string const& name, IReporterFactoryPtr const& factory ) override {
+ m_reporterRegistry.registerReporter( name, factory );
+ }
+ void registerListener( IReporterFactoryPtr const& factory ) override {
+ m_reporterRegistry.registerListener( factory );
+ }
+ void registerTest( TestCase const& testInfo ) override {
+ m_testCaseRegistry.registerTest( testInfo );
+ }
+ void registerTranslator( const IExceptionTranslator* translator ) override {
+ m_exceptionTranslatorRegistry.registerTranslator( translator );
+ }
+ void registerTagAlias( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ) override {
+ m_tagAliasRegistry.add( alias, tag, lineInfo );
+ }
+ void registerStartupException() noexcept override {
+ m_exceptionRegistry.add(std::current_exception());
+ }
+ IMutableEnumValuesRegistry& getMutableEnumValuesRegistry() override {
+ return m_enumValuesRegistry;
+ }
+
+ private:
+ TestRegistry m_testCaseRegistry;
+ ReporterRegistry m_reporterRegistry;
+ ExceptionTranslatorRegistry m_exceptionTranslatorRegistry;
+ TagAliasRegistry m_tagAliasRegistry;
+ StartupExceptionRegistry m_exceptionRegistry;
+ Detail::EnumValuesRegistry m_enumValuesRegistry;
+ };
}
- void LegacyReporterAdapter::noMatchingTestCases( std::string const& ) {}
- void LegacyReporterAdapter::testRunStarting( TestRunInfo const& ) {
- m_legacyReporter->StartTesting();
+ using RegistryHubSingleton = Singleton<RegistryHub, IRegistryHub, IMutableRegistryHub>;
+
+ IRegistryHub const& getRegistryHub() {
+ return RegistryHubSingleton::get();
+ }
+ IMutableRegistryHub& getMutableRegistryHub() {
+ return RegistryHubSingleton::getMutable();
}
- void LegacyReporterAdapter::testGroupStarting( GroupInfo const& groupInfo ) {
- m_legacyReporter->StartGroup( groupInfo.name );
+ void cleanUp() {
+ cleanupSingletons();
+ cleanUpContext();
}
- void LegacyReporterAdapter::testCaseStarting( TestCaseInfo const& testInfo ) {
- m_legacyReporter->StartTestCase( testInfo );
+ std::string translateActiveException() {
+ return getRegistryHub().getExceptionTranslatorRegistry().translateActiveException();
+ }
+
+} // end namespace Catch
+// end catch_registry_hub.cpp
+// start catch_reporter_registry.cpp
+
+namespace Catch {
+
+ ReporterRegistry::~ReporterRegistry() = default;
+
+ IStreamingReporterPtr ReporterRegistry::create( std::string const& name, IConfigPtr const& config ) const {
+ auto it = m_factories.find( name );
+ if( it == m_factories.end() )
+ return nullptr;
+ return it->second->create( ReporterConfig( config ) );
+ }
+
+ void ReporterRegistry::registerReporter( std::string const& name, IReporterFactoryPtr const& factory ) {
+ m_factories.emplace(name, factory);
+ }
+ void ReporterRegistry::registerListener( IReporterFactoryPtr const& factory ) {
+ m_listeners.push_back( factory );
+ }
+
+ IReporterRegistry::FactoryMap const& ReporterRegistry::getFactories() const {
+ return m_factories;
+ }
+ IReporterRegistry::Listeners const& ReporterRegistry::getListeners() const {
+ return m_listeners;
+ }
+
+}
+// end catch_reporter_registry.cpp
+// start catch_result_type.cpp
+
+namespace Catch {
+
+ bool isOk( ResultWas::OfType resultType ) {
+ return ( resultType & ResultWas::FailureBit ) == 0;
}
- void LegacyReporterAdapter::sectionStarting( SectionInfo const& sectionInfo ) {
- m_legacyReporter->StartSection( sectionInfo.name, sectionInfo.description );
+ bool isJustInfo( int flags ) {
+ return flags == ResultWas::Info;
}
- void LegacyReporterAdapter::assertionStarting( AssertionInfo const& ) {
- // Not on legacy interface
+
+ ResultDisposition::Flags operator | ( ResultDisposition::Flags lhs, ResultDisposition::Flags rhs ) {
+ return static_cast<ResultDisposition::Flags>( static_cast<int>( lhs ) | static_cast<int>( rhs ) );
}
- bool LegacyReporterAdapter::assertionEnded( AssertionStats const& assertionStats ) {
- if( assertionStats.assertionResult.getResultType() != ResultWas::Ok ) {
- for( std::vector<MessageInfo>::const_iterator it = assertionStats.infoMessages.begin(), itEnd = assertionStats.infoMessages.end();
- it != itEnd;
- ++it ) {
- if( it->type == ResultWas::Info ) {
- ResultBuilder rb( it->macroName.c_str(), it->lineInfo, "", ResultDisposition::Normal );
- rb << it->message;
- rb.setResultType( ResultWas::Info );
- AssertionResult result = rb.build();
- m_legacyReporter->Result( result );
+ bool shouldContinueOnFailure( int flags ) { return ( flags & ResultDisposition::ContinueOnFailure ) != 0; }
+ bool shouldSuppressFailure( int flags ) { return ( flags & ResultDisposition::SuppressFail ) != 0; }
+
+} // end namespace Catch
+// end catch_result_type.cpp
+// start catch_run_context.cpp
+
+#include <cassert>
+#include <algorithm>
+#include <sstream>
+
+namespace Catch {
+
+ namespace Generators {
+ struct GeneratorTracker : TestCaseTracking::TrackerBase, IGeneratorTracker {
+ GeneratorBasePtr m_generator;
+
+ GeneratorTracker( TestCaseTracking::NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent )
+ : TrackerBase( nameAndLocation, ctx, parent )
+ {}
+ ~GeneratorTracker();
+
+ static GeneratorTracker& acquire( TrackerContext& ctx, TestCaseTracking::NameAndLocation const& nameAndLocation ) {
+ std::shared_ptr<GeneratorTracker> tracker;
+
+ ITracker& currentTracker = ctx.currentTracker();
+ if( TestCaseTracking::ITrackerPtr childTracker = currentTracker.findChild( nameAndLocation ) ) {
+ assert( childTracker );
+ assert( childTracker->isGeneratorTracker() );
+ tracker = std::static_pointer_cast<GeneratorTracker>( childTracker );
}
+ else {
+ tracker = std::make_shared<GeneratorTracker>( nameAndLocation, ctx, &currentTracker );
+ currentTracker.addChild( tracker );
+ }
+
+ if( !ctx.completedCycle() && !tracker->isComplete() ) {
+ tracker->open();
+ }
+
+ return *tracker;
+ }
+
+ // TrackerBase interface
+ bool isGeneratorTracker() const override { return true; }
+ auto hasGenerator() const -> bool override {
+ return !!m_generator;
}
+ void close() override {
+ TrackerBase::close();
+ // Generator interface only finds out if it has another item on atual move
+ if (m_runState == CompletedSuccessfully && m_generator->next()) {
+ m_children.clear();
+ m_runState = Executing;
+ }
+ }
+
+ // IGeneratorTracker interface
+ auto getGenerator() const -> GeneratorBasePtr const& override {
+ return m_generator;
+ }
+ void setGenerator( GeneratorBasePtr&& generator ) override {
+ m_generator = std::move( generator );
+ }
+ };
+ GeneratorTracker::~GeneratorTracker() {}
+ }
+
+ RunContext::RunContext(IConfigPtr const& _config, IStreamingReporterPtr&& reporter)
+ : m_runInfo(_config->name()),
+ m_context(getCurrentMutableContext()),
+ m_config(_config),
+ m_reporter(std::move(reporter)),
+ m_lastAssertionInfo{ StringRef(), SourceLineInfo("",0), StringRef(), ResultDisposition::Normal },
+ m_includeSuccessfulResults( m_config->includeSuccessfulResults() || m_reporter->getPreferences().shouldReportAllAssertions )
+ {
+ m_context.setRunner(this);
+ m_context.setConfig(m_config);
+ m_context.setResultCapture(this);
+ m_reporter->testRunStarting(m_runInfo);
+ }
+
+ RunContext::~RunContext() {
+ m_reporter->testRunEnded(TestRunStats(m_runInfo, m_totals, aborting()));
+ }
+
+ void RunContext::testGroupStarting(std::string const& testSpec, std::size_t groupIndex, std::size_t groupsCount) {
+ m_reporter->testGroupStarting(GroupInfo(testSpec, groupIndex, groupsCount));
+ }
+
+ void RunContext::testGroupEnded(std::string const& testSpec, Totals const& totals, std::size_t groupIndex, std::size_t groupsCount) {
+ m_reporter->testGroupEnded(TestGroupStats(GroupInfo(testSpec, groupIndex, groupsCount), totals, aborting()));
+ }
+
+ Totals RunContext::runTest(TestCase const& testCase) {
+ Totals prevTotals = m_totals;
+
+ std::string redirectedCout;
+ std::string redirectedCerr;
+
+ auto const& testInfo = testCase.getTestCaseInfo();
+
+ m_reporter->testCaseStarting(testInfo);
+
+ m_activeTestCase = &testCase;
+
+ ITracker& rootTracker = m_trackerContext.startRun();
+ assert(rootTracker.isSectionTracker());
+ static_cast<SectionTracker&>(rootTracker).addInitialFilters(m_config->getSectionsToRun());
+ do {
+ m_trackerContext.startCycle();
+ m_testCaseTracker = &SectionTracker::acquire(m_trackerContext, TestCaseTracking::NameAndLocation(testInfo.name, testInfo.lineInfo));
+ runCurrentTest(redirectedCout, redirectedCerr);
+ } while (!m_testCaseTracker->isSuccessfullyCompleted() && !aborting());
+
+ Totals deltaTotals = m_totals.delta(prevTotals);
+ if (testInfo.expectedToFail() && deltaTotals.testCases.passed > 0) {
+ deltaTotals.assertions.failed++;
+ deltaTotals.testCases.passed--;
+ deltaTotals.testCases.failed++;
}
- m_legacyReporter->Result( assertionStats.assertionResult );
+ m_totals.testCases += deltaTotals.testCases;
+ m_reporter->testCaseEnded(TestCaseStats(testInfo,
+ deltaTotals,
+ redirectedCout,
+ redirectedCerr,
+ aborting()));
+
+ m_activeTestCase = nullptr;
+ m_testCaseTracker = nullptr;
+
+ return deltaTotals;
+ }
+
+ IConfigPtr RunContext::config() const {
+ return m_config;
+ }
+
+ IStreamingReporter& RunContext::reporter() const {
+ return *m_reporter;
+ }
+
+ void RunContext::assertionEnded(AssertionResult const & result) {
+ if (result.getResultType() == ResultWas::Ok) {
+ m_totals.assertions.passed++;
+ m_lastAssertionPassed = true;
+ } else if (!result.isOk()) {
+ m_lastAssertionPassed = false;
+ if( m_activeTestCase->getTestCaseInfo().okToFail() )
+ m_totals.assertions.failedButOk++;
+ else
+ m_totals.assertions.failed++;
+ }
+ else {
+ m_lastAssertionPassed = true;
+ }
+
+ // We have no use for the return value (whether messages should be cleared), because messages were made scoped
+ // and should be let to clear themselves out.
+ static_cast<void>(m_reporter->assertionEnded(AssertionStats(result, m_messages, m_totals)));
+
+ if (result.getResultType() != ResultWas::Warning)
+ m_messageScopes.clear();
+
+ // Reset working state
+ resetAssertionInfo();
+ m_lastResult = result;
+ }
+ void RunContext::resetAssertionInfo() {
+ m_lastAssertionInfo.macroName = StringRef();
+ m_lastAssertionInfo.capturedExpression = "{Unknown expression after the reported line}"_sr;
+ }
+
+ bool RunContext::sectionStarted(SectionInfo const & sectionInfo, Counts & assertions) {
+ ITracker& sectionTracker = SectionTracker::acquire(m_trackerContext, TestCaseTracking::NameAndLocation(sectionInfo.name, sectionInfo.lineInfo));
+ if (!sectionTracker.isOpen())
+ return false;
+ m_activeSections.push_back(&sectionTracker);
+
+ m_lastAssertionInfo.lineInfo = sectionInfo.lineInfo;
+
+ m_reporter->sectionStarting(sectionInfo);
+
+ assertions = m_totals.assertions;
+
return true;
}
- void LegacyReporterAdapter::sectionEnded( SectionStats const& sectionStats ) {
- if( sectionStats.missingAssertions )
- m_legacyReporter->NoAssertionsInSection( sectionStats.sectionInfo.name );
- m_legacyReporter->EndSection( sectionStats.sectionInfo.name, sectionStats.assertions );
+ auto RunContext::acquireGeneratorTracker( SourceLineInfo const& lineInfo ) -> IGeneratorTracker& {
+ using namespace Generators;
+ GeneratorTracker& tracker = GeneratorTracker::acquire( m_trackerContext, TestCaseTracking::NameAndLocation( "generator", lineInfo ) );
+ assert( tracker.isOpen() );
+ m_lastAssertionInfo.lineInfo = lineInfo;
+ return tracker;
}
- void LegacyReporterAdapter::testCaseEnded( TestCaseStats const& testCaseStats ) {
- m_legacyReporter->EndTestCase
- ( testCaseStats.testInfo,
- testCaseStats.totals,
- testCaseStats.stdOut,
- testCaseStats.stdErr );
+
+ bool RunContext::testForMissingAssertions(Counts& assertions) {
+ if (assertions.total() != 0)
+ return false;
+ if (!m_config->warnAboutMissingAssertions())
+ return false;
+ if (m_trackerContext.currentTracker().hasChildren())
+ return false;
+ m_totals.assertions.failed++;
+ assertions.failed++;
+ return true;
}
- void LegacyReporterAdapter::testGroupEnded( TestGroupStats const& testGroupStats ) {
- if( testGroupStats.aborting )
- m_legacyReporter->Aborted();
- m_legacyReporter->EndGroup( testGroupStats.groupInfo.name, testGroupStats.totals );
+
+ void RunContext::sectionEnded(SectionEndInfo const & endInfo) {
+ Counts assertions = m_totals.assertions - endInfo.prevAssertions;
+ bool missingAssertions = testForMissingAssertions(assertions);
+
+ if (!m_activeSections.empty()) {
+ m_activeSections.back()->close();
+ m_activeSections.pop_back();
+ }
+
+ m_reporter->sectionEnded(SectionStats(endInfo.sectionInfo, assertions, endInfo.durationInSeconds, missingAssertions));
+ m_messages.clear();
+ m_messageScopes.clear();
+ }
+
+ void RunContext::sectionEndedEarly(SectionEndInfo const & endInfo) {
+ if (m_unfinishedSections.empty())
+ m_activeSections.back()->fail();
+ else
+ m_activeSections.back()->close();
+ m_activeSections.pop_back();
+
+ m_unfinishedSections.push_back(endInfo);
}
- void LegacyReporterAdapter::testRunEnded( TestRunStats const& testRunStats ) {
- m_legacyReporter->EndTesting( testRunStats.totals );
+
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+ void RunContext::benchmarkPreparing(std::string const& name) {
+ m_reporter->benchmarkPreparing(name);
+ }
+ void RunContext::benchmarkStarting( BenchmarkInfo const& info ) {
+ m_reporter->benchmarkStarting( info );
}
- void LegacyReporterAdapter::skipTest( TestCaseInfo const& ) {
+ void RunContext::benchmarkEnded( BenchmarkStats<> const& stats ) {
+ m_reporter->benchmarkEnded( stats );
}
-}
+ void RunContext::benchmarkFailed(std::string const & error) {
+ m_reporter->benchmarkFailed(error);
+ }
+#endif // CATCH_CONFIG_ENABLE_BENCHMARKING
-// #included from: catch_timer.hpp
+ void RunContext::pushScopedMessage(MessageInfo const & message) {
+ m_messages.push_back(message);
+ }
-#ifdef __clang__
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wc++11-long-long"
-#endif
+ void RunContext::popScopedMessage(MessageInfo const & message) {
+ m_messages.erase(std::remove(m_messages.begin(), m_messages.end(), message), m_messages.end());
+ }
-#ifdef CATCH_PLATFORM_WINDOWS
+ void RunContext::emplaceUnscopedMessage( MessageBuilder const& builder ) {
+ m_messageScopes.emplace_back( builder );
+ }
-#else
+ std::string RunContext::getCurrentTestName() const {
+ return m_activeTestCase
+ ? m_activeTestCase->getTestCaseInfo().name
+ : std::string();
+ }
-#include <sys/time.h>
+ const AssertionResult * RunContext::getLastResult() const {
+ return &(*m_lastResult);
+ }
+ void RunContext::exceptionEarlyReported() {
+ m_shouldReportUnexpected = false;
+ }
+
+ void RunContext::handleFatalErrorCondition( StringRef message ) {
+ // First notify reporter that bad things happened
+ m_reporter->fatalErrorEncountered(message);
+
+ // Don't rebuild the result -- the stringification itself can cause more fatal errors
+ // Instead, fake a result data.
+ AssertionResultData tempResult( ResultWas::FatalErrorCondition, { false } );
+ tempResult.message = static_cast<std::string>(message);
+ AssertionResult result(m_lastAssertionInfo, tempResult);
+
+ assertionEnded(result);
+
+ handleUnfinishedSections();
+
+ // Recreate section for test case (as we will lose the one that was in scope)
+ auto const& testCaseInfo = m_activeTestCase->getTestCaseInfo();
+ SectionInfo testCaseSection(testCaseInfo.lineInfo, testCaseInfo.name);
+
+ Counts assertions;
+ assertions.failed = 1;
+ SectionStats testCaseSectionStats(testCaseSection, assertions, 0, false);
+ m_reporter->sectionEnded(testCaseSectionStats);
+
+ auto const& testInfo = m_activeTestCase->getTestCaseInfo();
+
+ Totals deltaTotals;
+ deltaTotals.testCases.failed = 1;
+ deltaTotals.assertions.failed = 1;
+ m_reporter->testCaseEnded(TestCaseStats(testInfo,
+ deltaTotals,
+ std::string(),
+ std::string(),
+ false));
+ m_totals.testCases.failed++;
+ testGroupEnded(std::string(), m_totals, 1, 1);
+ m_reporter->testRunEnded(TestRunStats(m_runInfo, m_totals, false));
+ }
+
+ bool RunContext::lastAssertionPassed() {
+ return m_lastAssertionPassed;
+ }
+
+ void RunContext::assertionPassed() {
+ m_lastAssertionPassed = true;
+ ++m_totals.assertions.passed;
+ resetAssertionInfo();
+ m_messageScopes.clear();
+ }
+
+ bool RunContext::aborting() const {
+ return m_totals.assertions.failed >= static_cast<std::size_t>(m_config->abortAfter());
+ }
+
+ void RunContext::runCurrentTest(std::string & redirectedCout, std::string & redirectedCerr) {
+ auto const& testCaseInfo = m_activeTestCase->getTestCaseInfo();
+ SectionInfo testCaseSection(testCaseInfo.lineInfo, testCaseInfo.name);
+ m_reporter->sectionStarting(testCaseSection);
+ Counts prevAssertions = m_totals.assertions;
+ double duration = 0;
+ m_shouldReportUnexpected = true;
+ m_lastAssertionInfo = { "TEST_CASE"_sr, testCaseInfo.lineInfo, StringRef(), ResultDisposition::Normal };
+
+ seedRng(*m_config);
+
+ Timer timer;
+ CATCH_TRY {
+ if (m_reporter->getPreferences().shouldRedirectStdOut) {
+#if !defined(CATCH_CONFIG_EXPERIMENTAL_REDIRECT)
+ RedirectedStreams redirectedStreams(redirectedCout, redirectedCerr);
+
+ timer.start();
+ invokeActiveTestCase();
+#else
+ OutputRedirect r(redirectedCout, redirectedCerr);
+ timer.start();
+ invokeActiveTestCase();
#endif
+ } else {
+ timer.start();
+ invokeActiveTestCase();
+ }
+ duration = timer.getElapsedSeconds();
+ } CATCH_CATCH_ANON (TestFailureException&) {
+ // This just means the test was aborted due to failure
+ } CATCH_CATCH_ALL {
+ // Under CATCH_CONFIG_FAST_COMPILE, unexpected exceptions under REQUIRE assertions
+ // are reported without translation at the point of origin.
+ if( m_shouldReportUnexpected ) {
+ AssertionReaction dummyReaction;
+ handleUnexpectedInflightException( m_lastAssertionInfo, translateActiveException(), dummyReaction );
+ }
+ }
+ Counts assertions = m_totals.assertions - prevAssertions;
+ bool missingAssertions = testForMissingAssertions(assertions);
+
+ m_testCaseTracker->close();
+ handleUnfinishedSections();
+ m_messages.clear();
+ m_messageScopes.clear();
+
+ SectionStats testCaseSectionStats(testCaseSection, assertions, duration, missingAssertions);
+ m_reporter->sectionEnded(testCaseSectionStats);
+ }
+
+ void RunContext::invokeActiveTestCase() {
+ FatalConditionHandler fatalConditionHandler; // Handle signals
+ m_activeTestCase->invoke();
+ fatalConditionHandler.reset();
+ }
+
+ void RunContext::handleUnfinishedSections() {
+ // If sections ended prematurely due to an exception we stored their
+ // infos here so we can tear them down outside the unwind process.
+ for (auto it = m_unfinishedSections.rbegin(),
+ itEnd = m_unfinishedSections.rend();
+ it != itEnd;
+ ++it)
+ sectionEnded(*it);
+ m_unfinishedSections.clear();
+ }
+
+ void RunContext::handleExpr(
+ AssertionInfo const& info,
+ ITransientExpression const& expr,
+ AssertionReaction& reaction
+ ) {
+ m_reporter->assertionStarting( info );
+
+ bool negated = isFalseTest( info.resultDisposition );
+ bool result = expr.getResult() != negated;
+
+ if( result ) {
+ if (!m_includeSuccessfulResults) {
+ assertionPassed();
+ }
+ else {
+ reportExpr(info, ResultWas::Ok, &expr, negated);
+ }
+ }
+ else {
+ reportExpr(info, ResultWas::ExpressionFailed, &expr, negated );
+ populateReaction( reaction );
+ }
+ }
+ void RunContext::reportExpr(
+ AssertionInfo const &info,
+ ResultWas::OfType resultType,
+ ITransientExpression const *expr,
+ bool negated ) {
+
+ m_lastAssertionInfo = info;
+ AssertionResultData data( resultType, LazyExpression( negated ) );
+
+ AssertionResult assertionResult{ info, data };
+ assertionResult.m_resultData.lazyExpression.m_transientExpression = expr;
+
+ assertionEnded( assertionResult );
+ }
+
+ void RunContext::handleMessage(
+ AssertionInfo const& info,
+ ResultWas::OfType resultType,
+ StringRef const& message,
+ AssertionReaction& reaction
+ ) {
+ m_reporter->assertionStarting( info );
+
+ m_lastAssertionInfo = info;
+
+ AssertionResultData data( resultType, LazyExpression( false ) );
+ data.message = static_cast<std::string>(message);
+ AssertionResult assertionResult{ m_lastAssertionInfo, data };
+ assertionEnded( assertionResult );
+ if( !assertionResult.isOk() )
+ populateReaction( reaction );
+ }
+ void RunContext::handleUnexpectedExceptionNotThrown(
+ AssertionInfo const& info,
+ AssertionReaction& reaction
+ ) {
+ handleNonExpr(info, Catch::ResultWas::DidntThrowException, reaction);
+ }
+
+ void RunContext::handleUnexpectedInflightException(
+ AssertionInfo const& info,
+ std::string const& message,
+ AssertionReaction& reaction
+ ) {
+ m_lastAssertionInfo = info;
+
+ AssertionResultData data( ResultWas::ThrewException, LazyExpression( false ) );
+ data.message = message;
+ AssertionResult assertionResult{ info, data };
+ assertionEnded( assertionResult );
+ populateReaction( reaction );
+ }
+
+ void RunContext::populateReaction( AssertionReaction& reaction ) {
+ reaction.shouldDebugBreak = m_config->shouldDebugBreak();
+ reaction.shouldThrow = aborting() || (m_lastAssertionInfo.resultDisposition & ResultDisposition::Normal);
+ }
+
+ void RunContext::handleIncomplete(
+ AssertionInfo const& info
+ ) {
+ m_lastAssertionInfo = info;
+
+ AssertionResultData data( ResultWas::ThrewException, LazyExpression( false ) );
+ data.message = "Exception translation was disabled by CATCH_CONFIG_FAST_COMPILE";
+ AssertionResult assertionResult{ info, data };
+ assertionEnded( assertionResult );
+ }
+ void RunContext::handleNonExpr(
+ AssertionInfo const &info,
+ ResultWas::OfType resultType,
+ AssertionReaction &reaction
+ ) {
+ m_lastAssertionInfo = info;
+
+ AssertionResultData data( resultType, LazyExpression( false ) );
+ AssertionResult assertionResult{ info, data };
+ assertionEnded( assertionResult );
+
+ if( !assertionResult.isOk() )
+ populateReaction( reaction );
+ }
+
+ IResultCapture& getResultCapture() {
+ if (auto* capture = getCurrentContext().getResultCapture())
+ return *capture;
+ else
+ CATCH_INTERNAL_ERROR("No result capture instance");
+ }
+
+ void seedRng(IConfig const& config) {
+ if (config.rngSeed() != 0) {
+ std::srand(config.rngSeed());
+ rng().seed(config.rngSeed());
+ }
+ }
+
+ unsigned int rngSeed() {
+ return getCurrentContext().getConfig()->rngSeed();
+ }
+
+}
+// end catch_run_context.cpp
+// start catch_section.cpp
+
+namespace Catch {
+
+ Section::Section( SectionInfo const& info )
+ : m_info( info ),
+ m_sectionIncluded( getResultCapture().sectionStarted( m_info, m_assertions ) )
+ {
+ m_timer.start();
+ }
+
+ Section::~Section() {
+ if( m_sectionIncluded ) {
+ SectionEndInfo endInfo{ m_info, m_assertions, m_timer.getElapsedSeconds() };
+ if( uncaught_exceptions() )
+ getResultCapture().sectionEndedEarly( endInfo );
+ else
+ getResultCapture().sectionEnded( endInfo );
+ }
+ }
+
+ // This indicates whether the section should be executed or not
+ Section::operator bool() const {
+ return m_sectionIncluded;
+ }
+
+} // end namespace Catch
+// end catch_section.cpp
+// start catch_section_info.cpp
+
+namespace Catch {
+
+ SectionInfo::SectionInfo
+ ( SourceLineInfo const& _lineInfo,
+ std::string const& _name )
+ : name( _name ),
+ lineInfo( _lineInfo )
+ {}
+
+} // end namespace Catch
+// end catch_section_info.cpp
+// start catch_session.cpp
+
+// start catch_session.h
+
+#include <memory>
+
+namespace Catch {
+
+ class Session : NonCopyable {
+ public:
+
+ Session();
+ ~Session() override;
+
+ void showHelp() const;
+ void libIdentify();
+
+ int applyCommandLine( int argc, char const * const * argv );
+ #if defined(CATCH_CONFIG_WCHAR) && defined(_WIN32) && defined(UNICODE)
+ int applyCommandLine( int argc, wchar_t const * const * argv );
+ #endif
+
+ void useConfigData( ConfigData const& configData );
+
+ template<typename CharT>
+ int run(int argc, CharT const * const argv[]) {
+ if (m_startupExceptions)
+ return 1;
+ int returnCode = applyCommandLine(argc, argv);
+ if (returnCode == 0)
+ returnCode = run();
+ return returnCode;
+ }
+
+ int run();
+
+ clara::Parser const& cli() const;
+ void cli( clara::Parser const& newParser );
+ ConfigData& configData();
+ Config& config();
+ private:
+ int runInternal();
+
+ clara::Parser m_cli;
+ ConfigData m_configData;
+ std::shared_ptr<Config> m_config;
+ bool m_startupExceptions = false;
+ };
+
+} // end namespace Catch
+
+// end catch_session.h
+// start catch_version.h
+
+#include <iosfwd>
+
+namespace Catch {
+
+ // Versioning information
+ struct Version {
+ Version( Version const& ) = delete;
+ Version& operator=( Version const& ) = delete;
+ Version( unsigned int _majorVersion,
+ unsigned int _minorVersion,
+ unsigned int _patchNumber,
+ char const * const _branchName,
+ unsigned int _buildNumber );
+
+ unsigned int const majorVersion;
+ unsigned int const minorVersion;
+ unsigned int const patchNumber;
+
+ // buildNumber is only used if branchName is not null
+ char const * const branchName;
+ unsigned int const buildNumber;
+
+ friend std::ostream& operator << ( std::ostream& os, Version const& version );
+ };
+
+ Version const& libraryVersion();
+}
+
+// end catch_version.h
+#include <cstdlib>
+#include <iomanip>
+#include <set>
+#include <iterator>
namespace Catch {
namespace {
-#ifdef CATCH_PLATFORM_WINDOWS
- UInt64 getCurrentTicks() {
- static UInt64 hz=0, hzo=0;
- if (!hz) {
- QueryPerformanceFrequency( reinterpret_cast<LARGE_INTEGER*>( &hz ) );
- QueryPerformanceCounter( reinterpret_cast<LARGE_INTEGER*>( &hzo ) );
+ const int MaxExitCode = 255;
+
+ IStreamingReporterPtr createReporter(std::string const& reporterName, IConfigPtr const& config) {
+ auto reporter = Catch::getRegistryHub().getReporterRegistry().create(reporterName, config);
+ CATCH_ENFORCE(reporter, "No reporter registered with name: '" << reporterName << "'");
+
+ return reporter;
+ }
+
+ IStreamingReporterPtr makeReporter(std::shared_ptr<Config> const& config) {
+ if (Catch::getRegistryHub().getReporterRegistry().getListeners().empty()) {
+ return createReporter(config->getReporterName(), config);
}
- UInt64 t;
- QueryPerformanceCounter( reinterpret_cast<LARGE_INTEGER*>( &t ) );
- return ((t-hzo)*1000000)/hz;
+
+ // On older platforms, returning std::unique_ptr<ListeningReporter>
+ // when the return type is std::unique_ptr<IStreamingReporter>
+ // doesn't compile without a std::move call. However, this causes
+ // a warning on newer platforms. Thus, we have to work around
+ // it a bit and downcast the pointer manually.
+ auto ret = std::unique_ptr<IStreamingReporter>(new ListeningReporter);
+ auto& multi = static_cast<ListeningReporter&>(*ret);
+ auto const& listeners = Catch::getRegistryHub().getReporterRegistry().getListeners();
+ for (auto const& listener : listeners) {
+ multi.addListener(listener->create(Catch::ReporterConfig(config)));
+ }
+ multi.addReporter(createReporter(config->getReporterName(), config));
+ return ret;
}
-#else
- UInt64 getCurrentTicks() {
- timeval t;
- gettimeofday(&t,CATCH_NULL);
- return static_cast<UInt64>( t.tv_sec ) * 1000000ull + static_cast<UInt64>( t.tv_usec );
+
+ class TestGroup {
+ public:
+ explicit TestGroup(std::shared_ptr<Config> const& config)
+ : m_config{config}
+ , m_context{config, makeReporter(config)}
+ {
+ auto const& allTestCases = getAllTestCasesSorted(*m_config);
+ m_matches = m_config->testSpec().matchesByFilter(allTestCases, *m_config);
+ auto const& invalidArgs = m_config->testSpec().getInvalidArgs();
+
+ if (m_matches.empty() && invalidArgs.empty()) {
+ for (auto const& test : allTestCases)
+ if (!test.isHidden())
+ m_tests.emplace(&test);
+ } else {
+ for (auto const& match : m_matches)
+ m_tests.insert(match.tests.begin(), match.tests.end());
+ }
+ }
+
+ Totals execute() {
+ auto const& invalidArgs = m_config->testSpec().getInvalidArgs();
+ Totals totals;
+ m_context.testGroupStarting(m_config->name(), 1, 1);
+ for (auto const& testCase : m_tests) {
+ if (!m_context.aborting())
+ totals += m_context.runTest(*testCase);
+ else
+ m_context.reporter().skipTest(*testCase);
+ }
+
+ for (auto const& match : m_matches) {
+ if (match.tests.empty()) {
+ m_context.reporter().noMatchingTestCases(match.name);
+ totals.error = -1;
+ }
+ }
+
+ if (!invalidArgs.empty()) {
+ for (auto const& invalidArg: invalidArgs)
+ m_context.reporter().reportInvalidArguments(invalidArg);
+ }
+
+ m_context.testGroupEnded(m_config->name(), totals, 1, 1);
+ return totals;
+ }
+
+ private:
+ using Tests = std::set<TestCase const*>;
+
+ std::shared_ptr<Config> m_config;
+ RunContext m_context;
+ Tests m_tests;
+ TestSpec::Matches m_matches;
+ };
+
+ void applyFilenamesAsTags(Catch::IConfig const& config) {
+ auto& tests = const_cast<std::vector<TestCase>&>(getAllTestCasesSorted(config));
+ for (auto& testCase : tests) {
+ auto tags = testCase.tags;
+
+ std::string filename = testCase.lineInfo.file;
+ auto lastSlash = filename.find_last_of("\\/");
+ if (lastSlash != std::string::npos) {
+ filename.erase(0, lastSlash);
+ filename[0] = '#';
+ }
+
+ auto lastDot = filename.find_last_of('.');
+ if (lastDot != std::string::npos) {
+ filename.erase(lastDot);
+ }
+
+ tags.push_back(std::move(filename));
+ setTags(testCase, tags);
+ }
+ }
+
+ } // anon namespace
+
+ Session::Session() {
+ static bool alreadyInstantiated = false;
+ if( alreadyInstantiated ) {
+ CATCH_TRY { CATCH_INTERNAL_ERROR( "Only one instance of Catch::Session can ever be used" ); }
+ CATCH_CATCH_ALL { getMutableRegistryHub().registerStartupException(); }
+ }
+
+ // There cannot be exceptions at startup in no-exception mode.
+#if !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS)
+ const auto& exceptions = getRegistryHub().getStartupExceptionRegistry().getExceptions();
+ if ( !exceptions.empty() ) {
+ config();
+ getCurrentMutableContext().setConfig(m_config);
+
+ m_startupExceptions = true;
+ Colour colourGuard( Colour::Red );
+ Catch::cerr() << "Errors occurred during startup!" << '\n';
+ // iterate over all exceptions and notify user
+ for ( const auto& ex_ptr : exceptions ) {
+ try {
+ std::rethrow_exception(ex_ptr);
+ } catch ( std::exception const& ex ) {
+ Catch::cerr() << Column( ex.what() ).indent(2) << '\n';
+ }
+ }
}
#endif
+
+ alreadyInstantiated = true;
+ m_cli = makeCommandLineParser( m_configData );
+ }
+ Session::~Session() {
+ Catch::cleanUp();
}
- void Timer::start() {
- m_ticks = getCurrentTicks();
+ void Session::showHelp() const {
+ Catch::cout()
+ << "\nCatch v" << libraryVersion() << "\n"
+ << m_cli << std::endl
+ << "For more detailed usage please see the project docs\n" << std::endl;
}
- unsigned int Timer::getElapsedMicroseconds() const {
- return static_cast<unsigned int>(getCurrentTicks() - m_ticks);
+ void Session::libIdentify() {
+ Catch::cout()
+ << std::left << std::setw(16) << "description: " << "A Catch2 test executable\n"
+ << std::left << std::setw(16) << "category: " << "testframework\n"
+ << std::left << std::setw(16) << "framework: " << "Catch Test\n"
+ << std::left << std::setw(16) << "version: " << libraryVersion() << std::endl;
}
- unsigned int Timer::getElapsedMilliseconds() const {
- return static_cast<unsigned int>(getElapsedMicroseconds()/1000);
+
+ int Session::applyCommandLine( int argc, char const * const * argv ) {
+ if( m_startupExceptions )
+ return 1;
+
+ auto result = m_cli.parse( clara::Args( argc, argv ) );
+ if( !result ) {
+ config();
+ getCurrentMutableContext().setConfig(m_config);
+ Catch::cerr()
+ << Colour( Colour::Red )
+ << "\nError(s) in input:\n"
+ << Column( result.errorMessage() ).indent( 2 )
+ << "\n\n";
+ Catch::cerr() << "Run with -? for usage\n" << std::endl;
+ return MaxExitCode;
+ }
+
+ if( m_configData.showHelp )
+ showHelp();
+ if( m_configData.libIdentify )
+ libIdentify();
+ m_config.reset();
+ return 0;
}
- double Timer::getElapsedSeconds() const {
- return getElapsedMicroseconds()/1000000.0;
+
+#if defined(CATCH_CONFIG_WCHAR) && defined(_WIN32) && defined(UNICODE)
+ int Session::applyCommandLine( int argc, wchar_t const * const * argv ) {
+
+ char **utf8Argv = new char *[ argc ];
+
+ for ( int i = 0; i < argc; ++i ) {
+ int bufSize = WideCharToMultiByte( CP_UTF8, 0, argv[i], -1, NULL, 0, NULL, NULL );
+
+ utf8Argv[ i ] = new char[ bufSize ];
+
+ WideCharToMultiByte( CP_UTF8, 0, argv[i], -1, utf8Argv[i], bufSize, NULL, NULL );
+ }
+
+ int returnCode = applyCommandLine( argc, utf8Argv );
+
+ for ( int i = 0; i < argc; ++i )
+ delete [] utf8Argv[ i ];
+
+ delete [] utf8Argv;
+
+ return returnCode;
+ }
+#endif
+
+ void Session::useConfigData( ConfigData const& configData ) {
+ m_configData = configData;
+ m_config.reset();
+ }
+
+ int Session::run() {
+ if( ( m_configData.waitForKeypress & WaitForKeypress::BeforeStart ) != 0 ) {
+ Catch::cout() << "...waiting for enter/ return before starting" << std::endl;
+ static_cast<void>(std::getchar());
+ }
+ int exitCode = runInternal();
+ if( ( m_configData.waitForKeypress & WaitForKeypress::BeforeExit ) != 0 ) {
+ Catch::cout() << "...waiting for enter/ return before exiting, with code: " << exitCode << std::endl;
+ static_cast<void>(std::getchar());
+ }
+ return exitCode;
+ }
+
+ clara::Parser const& Session::cli() const {
+ return m_cli;
+ }
+ void Session::cli( clara::Parser const& newParser ) {
+ m_cli = newParser;
+ }
+ ConfigData& Session::configData() {
+ return m_configData;
+ }
+ Config& Session::config() {
+ if( !m_config )
+ m_config = std::make_shared<Config>( m_configData );
+ return *m_config;
+ }
+
+ int Session::runInternal() {
+ if( m_startupExceptions )
+ return 1;
+
+ if (m_configData.showHelp || m_configData.libIdentify) {
+ return 0;
+ }
+
+ CATCH_TRY {
+ config(); // Force config to be constructed
+
+ seedRng( *m_config );
+
+ if( m_configData.filenamesAsTags )
+ applyFilenamesAsTags( *m_config );
+
+ // Handle list request
+ if( Option<std::size_t> listed = list( m_config ) )
+ return static_cast<int>( *listed );
+
+ TestGroup tests { m_config };
+ auto const totals = tests.execute();
+
+ if( m_config->warnAboutNoTests() && totals.error == -1 )
+ return 2;
+
+ // Note that on unices only the lower 8 bits are usually used, clamping
+ // the return value to 255 prevents false negative when some multiple
+ // of 256 tests has failed
+ return (std::min) (MaxExitCode, (std::max) (totals.error, static_cast<int>(totals.assertions.failed)));
+ }
+#if !defined(CATCH_CONFIG_DISABLE_EXCEPTIONS)
+ catch( std::exception& ex ) {
+ Catch::cerr() << ex.what() << std::endl;
+ return MaxExitCode;
+ }
+#endif
+ }
+
+} // end namespace Catch
+// end catch_session.cpp
+// start catch_singletons.cpp
+
+#include <vector>
+
+namespace Catch {
+
+ namespace {
+ static auto getSingletons() -> std::vector<ISingleton*>*& {
+ static std::vector<ISingleton*>* g_singletons = nullptr;
+ if( !g_singletons )
+ g_singletons = new std::vector<ISingleton*>();
+ return g_singletons;
+ }
+ }
+
+ ISingleton::~ISingleton() {}
+
+ void addSingleton(ISingleton* singleton ) {
+ getSingletons()->push_back( singleton );
+ }
+ void cleanupSingletons() {
+ auto& singletons = getSingletons();
+ for( auto singleton : *singletons )
+ delete singleton;
+ delete singletons;
+ singletons = nullptr;
}
} // namespace Catch
+// end catch_singletons.cpp
+// start catch_startup_exception_registry.cpp
-#ifdef __clang__
-#pragma clang diagnostic pop
+namespace Catch {
+void StartupExceptionRegistry::add( std::exception_ptr const& exception ) noexcept {
+ CATCH_TRY {
+ m_exceptions.push_back(exception);
+ } CATCH_CATCH_ALL {
+ // If we run out of memory during start-up there's really not a lot more we can do about it
+ std::terminate();
+ }
+ }
+
+ std::vector<std::exception_ptr> const& StartupExceptionRegistry::getExceptions() const noexcept {
+ return m_exceptions;
+ }
+
+} // end namespace Catch
+// end catch_startup_exception_registry.cpp
+// start catch_stream.cpp
+
+#include <cstdio>
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <vector>
+#include <memory>
+
+namespace Catch {
+
+ Catch::IStream::~IStream() = default;
+
+ namespace Detail { namespace {
+ template<typename WriterF, std::size_t bufferSize=256>
+ class StreamBufImpl : public std::streambuf {
+ char data[bufferSize];
+ WriterF m_writer;
+
+ public:
+ StreamBufImpl() {
+ setp( data, data + sizeof(data) );
+ }
+
+ ~StreamBufImpl() noexcept {
+ StreamBufImpl::sync();
+ }
+
+ private:
+ int overflow( int c ) override {
+ sync();
+
+ if( c != EOF ) {
+ if( pbase() == epptr() )
+ m_writer( std::string( 1, static_cast<char>( c ) ) );
+ else
+ sputc( static_cast<char>( c ) );
+ }
+ return 0;
+ }
+
+ int sync() override {
+ if( pbase() != pptr() ) {
+ m_writer( std::string( pbase(), static_cast<std::string::size_type>( pptr() - pbase() ) ) );
+ setp( pbase(), epptr() );
+ }
+ return 0;
+ }
+ };
+
+ ///////////////////////////////////////////////////////////////////////////
+
+ struct OutputDebugWriter {
+
+ void operator()( std::string const&str ) {
+ writeToDebugConsole( str );
+ }
+ };
+
+ ///////////////////////////////////////////////////////////////////////////
+
+ class FileStream : public IStream {
+ mutable std::ofstream m_ofs;
+ public:
+ FileStream( StringRef filename ) {
+ m_ofs.open( filename.c_str() );
+ CATCH_ENFORCE( !m_ofs.fail(), "Unable to open file: '" << filename << "'" );
+ }
+ ~FileStream() override = default;
+ public: // IStream
+ std::ostream& stream() const override {
+ return m_ofs;
+ }
+ };
+
+ ///////////////////////////////////////////////////////////////////////////
+
+ class CoutStream : public IStream {
+ mutable std::ostream m_os;
+ public:
+ // Store the streambuf from cout up-front because
+ // cout may get redirected when running tests
+ CoutStream() : m_os( Catch::cout().rdbuf() ) {}
+ ~CoutStream() override = default;
+
+ public: // IStream
+ std::ostream& stream() const override { return m_os; }
+ };
+
+ ///////////////////////////////////////////////////////////////////////////
+
+ class DebugOutStream : public IStream {
+ std::unique_ptr<StreamBufImpl<OutputDebugWriter>> m_streamBuf;
+ mutable std::ostream m_os;
+ public:
+ DebugOutStream()
+ : m_streamBuf( new StreamBufImpl<OutputDebugWriter>() ),
+ m_os( m_streamBuf.get() )
+ {}
+
+ ~DebugOutStream() override = default;
+
+ public: // IStream
+ std::ostream& stream() const override { return m_os; }
+ };
+
+ }} // namespace anon::detail
+
+ ///////////////////////////////////////////////////////////////////////////
+
+ auto makeStream( StringRef const &filename ) -> IStream const* {
+ if( filename.empty() )
+ return new Detail::CoutStream();
+ else if( filename[0] == '%' ) {
+ if( filename == "%debug" )
+ return new Detail::DebugOutStream();
+ else
+ CATCH_ERROR( "Unrecognised stream: '" << filename << "'" );
+ }
+ else
+ return new Detail::FileStream( filename );
+ }
+
+ // This class encapsulates the idea of a pool of ostringstreams that can be reused.
+ struct StringStreams {
+ std::vector<std::unique_ptr<std::ostringstream>> m_streams;
+ std::vector<std::size_t> m_unused;
+ std::ostringstream m_referenceStream; // Used for copy state/ flags from
+
+ auto add() -> std::size_t {
+ if( m_unused.empty() ) {
+ m_streams.push_back( std::unique_ptr<std::ostringstream>( new std::ostringstream ) );
+ return m_streams.size()-1;
+ }
+ else {
+ auto index = m_unused.back();
+ m_unused.pop_back();
+ return index;
+ }
+ }
+
+ void release( std::size_t index ) {
+ m_streams[index]->copyfmt( m_referenceStream ); // Restore initial flags and other state
+ m_unused.push_back(index);
+ }
+ };
+
+ ReusableStringStream::ReusableStringStream()
+ : m_index( Singleton<StringStreams>::getMutable().add() ),
+ m_oss( Singleton<StringStreams>::getMutable().m_streams[m_index].get() )
+ {}
+
+ ReusableStringStream::~ReusableStringStream() {
+ static_cast<std::ostringstream*>( m_oss )->str("");
+ m_oss->clear();
+ Singleton<StringStreams>::getMutable().release( m_index );
+ }
+
+ auto ReusableStringStream::str() const -> std::string {
+ return static_cast<std::ostringstream*>( m_oss )->str();
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+
+#ifndef CATCH_CONFIG_NOSTDOUT // If you #define this you must implement these functions
+ std::ostream& cout() { return std::cout; }
+ std::ostream& cerr() { return std::cerr; }
+ std::ostream& clog() { return std::clog; }
#endif
-// #included from: catch_common.hpp
-#define TWOBLUECUBES_CATCH_COMMON_HPP_INCLUDED
+}
+// end catch_stream.cpp
+// start catch_string_manip.cpp
+#include <algorithm>
+#include <ostream>
#include <cstring>
#include <cctype>
+#include <vector>
namespace Catch {
+ namespace {
+ char toLowerCh(char c) {
+ return static_cast<char>( std::tolower( c ) );
+ }
+ }
+
bool startsWith( std::string const& s, std::string const& prefix ) {
return s.size() >= prefix.size() && std::equal(prefix.begin(), prefix.end(), s.begin());
}
@@ -8555,9 +13549,6 @@ namespace Catch {
bool contains( std::string const& s, std::string const& infix ) {
return s.find( infix ) != std::string::npos;
}
- char toLowerCh(char c) {
- return static_cast<char>( std::tolower( c ) );
- }
void toLowerInPlace( std::string& s ) {
std::transform( s.begin(), s.end(), s.begin(), toLowerCh );
}
@@ -8574,6 +13565,18 @@ namespace Catch {
return start != std::string::npos ? str.substr( start, 1+end-start ) : std::string();
}
+ StringRef trim(StringRef ref) {
+ const auto is_ws = [](char c) {
+ return c == ' ' || c == '\t' || c == '\n' || c == '\r';
+ };
+ size_t real_begin = 0;
+ while (real_begin < ref.size() && is_ws(ref[real_begin])) { ++real_begin; }
+ size_t real_end = ref.size();
+ while (real_end > real_begin && is_ws(ref[real_end - 1])) { --real_end; }
+
+ return ref.substr(real_begin, real_end - real_begin);
+ }
+
bool replaceInPlace( std::string& str, std::string const& replaceThis, std::string const& withThis ) {
bool replaced = false;
std::size_t i = str.find( replaceThis );
@@ -8588,6 +13591,21 @@ namespace Catch {
return replaced;
}
+ std::vector<StringRef> splitStringRef( StringRef str, char delimiter ) {
+ std::vector<StringRef> subStrings;
+ std::size_t start = 0;
+ for(std::size_t pos = 0; pos < str.size(); ++pos ) {
+ if( str[pos] == delimiter ) {
+ if( pos - start > 1 )
+ subStrings.push_back( str.substr( start, pos-start ) );
+ start = pos+1;
+ }
+ }
+ if( start < str.size() )
+ subStrings.push_back( str.substr( start, str.size()-start ) );
+ return subStrings;
+ }
+
pluralise::pluralise( std::size_t count, std::string const& label )
: m_count( count ),
m_label( label )
@@ -8600,203 +13618,1035 @@ namespace Catch {
return os;
}
- SourceLineInfo::SourceLineInfo() : file(""), line( 0 ){}
- SourceLineInfo::SourceLineInfo( char const* _file, std::size_t _line )
- : file( _file ),
- line( _line )
+}
+// end catch_string_manip.cpp
+// start catch_stringref.cpp
+
+#if defined(__clang__)
+# pragma clang diagnostic push
+# pragma clang diagnostic ignored "-Wexit-time-destructors"
+#endif
+
+#include <ostream>
+#include <cstring>
+#include <cstdint>
+
+namespace Catch {
+ StringRef::StringRef( char const* rawChars ) noexcept
+ : StringRef( rawChars, static_cast<StringRef::size_type>(std::strlen(rawChars) ) )
{}
- bool SourceLineInfo::empty() const {
- return file[0] == '\0';
+
+ void StringRef::swap( StringRef& other ) noexcept {
+ std::swap( m_start, other.m_start );
+ std::swap( m_size, other.m_size );
+ std::swap( m_data, other.m_data );
}
- bool SourceLineInfo::operator == ( SourceLineInfo const& other ) const {
- return line == other.line && (file == other.file || std::strcmp(file, other.file) == 0);
+
+ auto StringRef::c_str() const -> char const* {
+ if( !isSubstring() )
+ return m_start;
+
+ const_cast<StringRef *>( this )->takeOwnership();
+ return m_data;
}
- bool SourceLineInfo::operator < ( SourceLineInfo const& other ) const {
- return line < other.line || ( line == other.line && (std::strcmp(file, other.file) < 0));
+ auto StringRef::currentData() const noexcept -> char const* {
+ return m_start;
}
- void seedRng( IConfig const& config ) {
- if( config.rngSeed() != 0 )
- std::srand( config.rngSeed() );
+ auto StringRef::isOwned() const noexcept -> bool {
+ return m_data != nullptr;
}
- unsigned int rngSeed() {
- return getCurrentContext().getConfig()->rngSeed();
+ auto StringRef::isSubstring() const noexcept -> bool {
+ return m_start[m_size] != '\0';
}
- std::ostream& operator << ( std::ostream& os, SourceLineInfo const& info ) {
-#ifndef __GNUG__
- os << info.file << '(' << info.line << ')';
-#else
- os << info.file << ':' << info.line;
-#endif
- return os;
+ void StringRef::takeOwnership() {
+ if( !isOwned() ) {
+ m_data = new char[m_size+1];
+ memcpy( m_data, m_start, m_size );
+ m_data[m_size] = '\0';
+ }
+ }
+ auto StringRef::substr( size_type start, size_type size ) const noexcept -> StringRef {
+ if( start < m_size )
+ return StringRef( m_start+start, size );
+ else
+ return StringRef();
+ }
+ auto StringRef::operator == ( StringRef const& other ) const noexcept -> bool {
+ return
+ size() == other.size() &&
+ (std::strncmp( m_start, other.m_start, size() ) == 0);
+ }
+ auto StringRef::operator != ( StringRef const& other ) const noexcept -> bool {
+ return !operator==( other );
+ }
+
+ auto operator << ( std::ostream& os, StringRef const& str ) -> std::ostream& {
+ return os.write(str.currentData(), str.size());
}
- void throwLogicError( std::string const& message, SourceLineInfo const& locationInfo ) {
- std::ostringstream oss;
- oss << locationInfo << ": Internal Catch error: '" << message << '\'';
- if( alwaysTrue() )
- throw std::logic_error( oss.str() );
+ auto operator+=( std::string& lhs, StringRef const& rhs ) -> std::string& {
+ lhs.append(rhs.currentData(), rhs.size());
+ return lhs;
}
+
+} // namespace Catch
+
+#if defined(__clang__)
+# pragma clang diagnostic pop
+#endif
+// end catch_stringref.cpp
+// start catch_tag_alias.cpp
+
+namespace Catch {
+ TagAlias::TagAlias(std::string const & _tag, SourceLineInfo _lineInfo): tag(_tag), lineInfo(_lineInfo) {}
}
+// end catch_tag_alias.cpp
+// start catch_tag_alias_autoregistrar.cpp
-// #included from: catch_section.hpp
-#define TWOBLUECUBES_CATCH_SECTION_HPP_INCLUDED
+namespace Catch {
+
+ RegistrarForTagAliases::RegistrarForTagAliases(char const* alias, char const* tag, SourceLineInfo const& lineInfo) {
+ CATCH_TRY {
+ getMutableRegistryHub().registerTagAlias(alias, tag, lineInfo);
+ } CATCH_CATCH_ALL {
+ // Do not throw when constructing global objects, instead register the exception to be processed later
+ getMutableRegistryHub().registerStartupException();
+ }
+ }
+
+}
+// end catch_tag_alias_autoregistrar.cpp
+// start catch_tag_alias_registry.cpp
+
+#include <sstream>
namespace Catch {
- SectionInfo::SectionInfo
- ( SourceLineInfo const& _lineInfo,
- std::string const& _name,
- std::string const& _description )
+ TagAliasRegistry::~TagAliasRegistry() {}
+
+ TagAlias const* TagAliasRegistry::find( std::string const& alias ) const {
+ auto it = m_registry.find( alias );
+ if( it != m_registry.end() )
+ return &(it->second);
+ else
+ return nullptr;
+ }
+
+ std::string TagAliasRegistry::expandAliases( std::string const& unexpandedTestSpec ) const {
+ std::string expandedTestSpec = unexpandedTestSpec;
+ for( auto const& registryKvp : m_registry ) {
+ std::size_t pos = expandedTestSpec.find( registryKvp.first );
+ if( pos != std::string::npos ) {
+ expandedTestSpec = expandedTestSpec.substr( 0, pos ) +
+ registryKvp.second.tag +
+ expandedTestSpec.substr( pos + registryKvp.first.size() );
+ }
+ }
+ return expandedTestSpec;
+ }
+
+ void TagAliasRegistry::add( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ) {
+ CATCH_ENFORCE( startsWith(alias, "[@") && endsWith(alias, ']'),
+ "error: tag alias, '" << alias << "' is not of the form [@alias name].\n" << lineInfo );
+
+ CATCH_ENFORCE( m_registry.insert(std::make_pair(alias, TagAlias(tag, lineInfo))).second,
+ "error: tag alias, '" << alias << "' already registered.\n"
+ << "\tFirst seen at: " << find(alias)->lineInfo << "\n"
+ << "\tRedefined at: " << lineInfo );
+ }
+
+ ITagAliasRegistry::~ITagAliasRegistry() {}
+
+ ITagAliasRegistry const& ITagAliasRegistry::get() {
+ return getRegistryHub().getTagAliasRegistry();
+ }
+
+} // end namespace Catch
+// end catch_tag_alias_registry.cpp
+// start catch_test_case_info.cpp
+
+#include <cctype>
+#include <exception>
+#include <algorithm>
+#include <sstream>
+
+namespace Catch {
+
+ namespace {
+ TestCaseInfo::SpecialProperties parseSpecialTag( std::string const& tag ) {
+ if( startsWith( tag, '.' ) ||
+ tag == "!hide" )
+ return TestCaseInfo::IsHidden;
+ else if( tag == "!throws" )
+ return TestCaseInfo::Throws;
+ else if( tag == "!shouldfail" )
+ return TestCaseInfo::ShouldFail;
+ else if( tag == "!mayfail" )
+ return TestCaseInfo::MayFail;
+ else if( tag == "!nonportable" )
+ return TestCaseInfo::NonPortable;
+ else if( tag == "!benchmark" )
+ return static_cast<TestCaseInfo::SpecialProperties>( TestCaseInfo::Benchmark | TestCaseInfo::IsHidden );
+ else
+ return TestCaseInfo::None;
+ }
+ bool isReservedTag( std::string const& tag ) {
+ return parseSpecialTag( tag ) == TestCaseInfo::None && tag.size() > 0 && !std::isalnum( static_cast<unsigned char>(tag[0]) );
+ }
+ void enforceNotReservedTag( std::string const& tag, SourceLineInfo const& _lineInfo ) {
+ CATCH_ENFORCE( !isReservedTag(tag),
+ "Tag name: [" << tag << "] is not allowed.\n"
+ << "Tag names starting with non alphanumeric characters are reserved\n"
+ << _lineInfo );
+ }
+ }
+
+ TestCase makeTestCase( ITestInvoker* _testCase,
+ std::string const& _className,
+ NameAndTags const& nameAndTags,
+ SourceLineInfo const& _lineInfo )
+ {
+ bool isHidden = false;
+
+ // Parse out tags
+ std::vector<std::string> tags;
+ std::string desc, tag;
+ bool inTag = false;
+ for (char c : nameAndTags.tags) {
+ if( !inTag ) {
+ if( c == '[' )
+ inTag = true;
+ else
+ desc += c;
+ }
+ else {
+ if( c == ']' ) {
+ TestCaseInfo::SpecialProperties prop = parseSpecialTag( tag );
+ if( ( prop & TestCaseInfo::IsHidden ) != 0 )
+ isHidden = true;
+ else if( prop == TestCaseInfo::None )
+ enforceNotReservedTag( tag, _lineInfo );
+
+ // Merged hide tags like `[.approvals]` should be added as
+ // `[.][approvals]`. The `[.]` is added at later point, so
+ // we only strip the prefix
+ if (startsWith(tag, '.') && tag.size() > 1) {
+ tag.erase(0, 1);
+ }
+ tags.push_back( tag );
+ tag.clear();
+ inTag = false;
+ }
+ else
+ tag += c;
+ }
+ }
+ if( isHidden ) {
+ tags.push_back( "." );
+ }
+
+ TestCaseInfo info( static_cast<std::string>(nameAndTags.name), _className, desc, tags, _lineInfo );
+ return TestCase( _testCase, std::move(info) );
+ }
+
+ void setTags( TestCaseInfo& testCaseInfo, std::vector<std::string> tags ) {
+ std::sort(begin(tags), end(tags));
+ tags.erase(std::unique(begin(tags), end(tags)), end(tags));
+ testCaseInfo.lcaseTags.clear();
+
+ for( auto const& tag : tags ) {
+ std::string lcaseTag = toLower( tag );
+ testCaseInfo.properties = static_cast<TestCaseInfo::SpecialProperties>( testCaseInfo.properties | parseSpecialTag( lcaseTag ) );
+ testCaseInfo.lcaseTags.push_back( lcaseTag );
+ }
+ testCaseInfo.tags = std::move(tags);
+ }
+
+ TestCaseInfo::TestCaseInfo( std::string const& _name,
+ std::string const& _className,
+ std::string const& _description,
+ std::vector<std::string> const& _tags,
+ SourceLineInfo const& _lineInfo )
: name( _name ),
+ className( _className ),
description( _description ),
- lineInfo( _lineInfo )
- {}
+ lineInfo( _lineInfo ),
+ properties( None )
+ {
+ setTags( *this, _tags );
+ }
- Section::Section( SectionInfo const& info )
- : m_info( info ),
- m_sectionIncluded( getResultCapture().sectionStarted( m_info, m_assertions ) )
+ bool TestCaseInfo::isHidden() const {
+ return ( properties & IsHidden ) != 0;
+ }
+ bool TestCaseInfo::throws() const {
+ return ( properties & Throws ) != 0;
+ }
+ bool TestCaseInfo::okToFail() const {
+ return ( properties & (ShouldFail | MayFail ) ) != 0;
+ }
+ bool TestCaseInfo::expectedToFail() const {
+ return ( properties & (ShouldFail ) ) != 0;
+ }
+
+ std::string TestCaseInfo::tagsAsString() const {
+ std::string ret;
+ // '[' and ']' per tag
+ std::size_t full_size = 2 * tags.size();
+ for (const auto& tag : tags) {
+ full_size += tag.size();
+ }
+ ret.reserve(full_size);
+ for (const auto& tag : tags) {
+ ret.push_back('[');
+ ret.append(tag);
+ ret.push_back(']');
+ }
+
+ return ret;
+ }
+
+ TestCase::TestCase( ITestInvoker* testCase, TestCaseInfo&& info ) : TestCaseInfo( std::move(info) ), test( testCase ) {}
+
+ TestCase TestCase::withName( std::string const& _newName ) const {
+ TestCase other( *this );
+ other.name = _newName;
+ return other;
+ }
+
+ void TestCase::invoke() const {
+ test->invoke();
+ }
+
+ bool TestCase::operator == ( TestCase const& other ) const {
+ return test.get() == other.test.get() &&
+ name == other.name &&
+ className == other.className;
+ }
+
+ bool TestCase::operator < ( TestCase const& other ) const {
+ return name < other.name;
+ }
+
+ TestCaseInfo const& TestCase::getTestCaseInfo() const
{
- m_timer.start();
+ return *this;
}
- Section::~Section() {
- if( m_sectionIncluded ) {
- SectionEndInfo endInfo( m_info, m_assertions, m_timer.getElapsedSeconds() );
- if( std::uncaught_exception() )
- getResultCapture().sectionEndedEarly( endInfo );
- else
- getResultCapture().sectionEnded( endInfo );
+} // end namespace Catch
+// end catch_test_case_info.cpp
+// start catch_test_case_registry_impl.cpp
+
+#include <sstream>
+
+namespace Catch {
+
+ std::vector<TestCase> sortTests( IConfig const& config, std::vector<TestCase> const& unsortedTestCases ) {
+
+ std::vector<TestCase> sorted = unsortedTestCases;
+
+ switch( config.runOrder() ) {
+ case RunTests::InLexicographicalOrder:
+ std::sort( sorted.begin(), sorted.end() );
+ break;
+ case RunTests::InRandomOrder:
+ seedRng( config );
+ std::shuffle( sorted.begin(), sorted.end(), rng() );
+ break;
+ case RunTests::InDeclarationOrder:
+ // already in declaration order
+ break;
}
+ return sorted;
}
- // This indicates whether the section should be executed or not
- Section::operator bool() const {
- return m_sectionIncluded;
+ bool isThrowSafe( TestCase const& testCase, IConfig const& config ) {
+ return !testCase.throws() || config.allowThrows();
+ }
+
+ bool matchTest( TestCase const& testCase, TestSpec const& testSpec, IConfig const& config ) {
+ return testSpec.matches( testCase ) && isThrowSafe( testCase, config );
+ }
+
+ void enforceNoDuplicateTestCases( std::vector<TestCase> const& functions ) {
+ std::set<TestCase> seenFunctions;
+ for( auto const& function : functions ) {
+ auto prev = seenFunctions.insert( function );
+ CATCH_ENFORCE( prev.second,
+ "error: TEST_CASE( \"" << function.name << "\" ) already defined.\n"
+ << "\tFirst seen at " << prev.first->getTestCaseInfo().lineInfo << "\n"
+ << "\tRedefined at " << function.getTestCaseInfo().lineInfo );
+ }
+ }
+
+ std::vector<TestCase> filterTests( std::vector<TestCase> const& testCases, TestSpec const& testSpec, IConfig const& config ) {
+ std::vector<TestCase> filtered;
+ filtered.reserve( testCases.size() );
+ for (auto const& testCase : testCases) {
+ if ((!testSpec.hasFilters() && !testCase.isHidden()) ||
+ (testSpec.hasFilters() && matchTest(testCase, testSpec, config))) {
+ filtered.push_back(testCase);
+ }
+ }
+ return filtered;
+ }
+ std::vector<TestCase> const& getAllTestCasesSorted( IConfig const& config ) {
+ return getRegistryHub().getTestCaseRegistry().getAllTestsSorted( config );
+ }
+
+ void TestRegistry::registerTest( TestCase const& testCase ) {
+ std::string name = testCase.getTestCaseInfo().name;
+ if( name.empty() ) {
+ ReusableStringStream rss;
+ rss << "Anonymous test case " << ++m_unnamedCount;
+ return registerTest( testCase.withName( rss.str() ) );
+ }
+ m_functions.push_back( testCase );
+ }
+
+ std::vector<TestCase> const& TestRegistry::getAllTests() const {
+ return m_functions;
+ }
+ std::vector<TestCase> const& TestRegistry::getAllTestsSorted( IConfig const& config ) const {
+ if( m_sortedFunctions.empty() )
+ enforceNoDuplicateTestCases( m_functions );
+
+ if( m_currentSortOrder != config.runOrder() || m_sortedFunctions.empty() ) {
+ m_sortedFunctions = sortTests( config, m_functions );
+ m_currentSortOrder = config.runOrder();
+ }
+ return m_sortedFunctions;
+ }
+
+ ///////////////////////////////////////////////////////////////////////////
+ TestInvokerAsFunction::TestInvokerAsFunction( void(*testAsFunction)() ) noexcept : m_testAsFunction( testAsFunction ) {}
+
+ void TestInvokerAsFunction::invoke() const {
+ m_testAsFunction();
+ }
+
+ std::string extractClassName( StringRef const& classOrQualifiedMethodName ) {
+ std::string className(classOrQualifiedMethodName);
+ if( startsWith( className, '&' ) )
+ {
+ std::size_t lastColons = className.rfind( "::" );
+ std::size_t penultimateColons = className.rfind( "::", lastColons-1 );
+ if( penultimateColons == std::string::npos )
+ penultimateColons = 1;
+ className = className.substr( penultimateColons, lastColons-penultimateColons );
+ }
+ return className;
}
} // end namespace Catch
+// end catch_test_case_registry_impl.cpp
+// start catch_test_case_tracker.cpp
-// #included from: catch_debugger.hpp
-#define TWOBLUECUBES_CATCH_DEBUGGER_HPP_INCLUDED
+#include <algorithm>
+#include <cassert>
+#include <stdexcept>
+#include <memory>
+#include <sstream>
-#ifdef CATCH_PLATFORM_MAC
+#if defined(__clang__)
+# pragma clang diagnostic push
+# pragma clang diagnostic ignored "-Wexit-time-destructors"
+#endif
- #include <assert.h>
- #include <stdbool.h>
- #include <sys/types.h>
- #include <unistd.h>
- #include <sys/sysctl.h>
+namespace Catch {
+namespace TestCaseTracking {
- namespace Catch{
+ NameAndLocation::NameAndLocation( std::string const& _name, SourceLineInfo const& _location )
+ : name( _name ),
+ location( _location )
+ {}
- // The following function is taken directly from the following technical note:
- // http://developer.apple.com/library/mac/#qa/qa2004/qa1361.html
+ ITracker::~ITracker() = default;
- // Returns true if the current process is being debugged (either
- // running under the debugger or has a debugger attached post facto).
- bool isDebuggerActive(){
+ ITracker& TrackerContext::startRun() {
+ m_rootTracker = std::make_shared<SectionTracker>( NameAndLocation( "{root}", CATCH_INTERNAL_LINEINFO ), *this, nullptr );
+ m_currentTracker = nullptr;
+ m_runState = Executing;
+ return *m_rootTracker;
+ }
- int mib[4];
- struct kinfo_proc info;
- size_t size;
+ void TrackerContext::endRun() {
+ m_rootTracker.reset();
+ m_currentTracker = nullptr;
+ m_runState = NotStarted;
+ }
- // Initialize the flags so that, if sysctl fails for some bizarre
- // reason, we get a predictable result.
+ void TrackerContext::startCycle() {
+ m_currentTracker = m_rootTracker.get();
+ m_runState = Executing;
+ }
+ void TrackerContext::completeCycle() {
+ m_runState = CompletedCycle;
+ }
- info.kp_proc.p_flag = 0;
+ bool TrackerContext::completedCycle() const {
+ return m_runState == CompletedCycle;
+ }
+ ITracker& TrackerContext::currentTracker() {
+ return *m_currentTracker;
+ }
+ void TrackerContext::setCurrentTracker( ITracker* tracker ) {
+ m_currentTracker = tracker;
+ }
- // Initialize mib, which tells sysctl the info we want, in this case
- // we're looking for information about a specific process ID.
+ TrackerBase::TrackerBase( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent )
+ : m_nameAndLocation( nameAndLocation ),
+ m_ctx( ctx ),
+ m_parent( parent )
+ {}
- mib[0] = CTL_KERN;
- mib[1] = KERN_PROC;
- mib[2] = KERN_PROC_PID;
- mib[3] = getpid();
+ NameAndLocation const& TrackerBase::nameAndLocation() const {
+ return m_nameAndLocation;
+ }
+ bool TrackerBase::isComplete() const {
+ return m_runState == CompletedSuccessfully || m_runState == Failed;
+ }
+ bool TrackerBase::isSuccessfullyCompleted() const {
+ return m_runState == CompletedSuccessfully;
+ }
+ bool TrackerBase::isOpen() const {
+ return m_runState != NotStarted && !isComplete();
+ }
+ bool TrackerBase::hasChildren() const {
+ return !m_children.empty();
+ }
- // Call sysctl.
+ void TrackerBase::addChild( ITrackerPtr const& child ) {
+ m_children.push_back( child );
+ }
- size = sizeof(info);
- if( sysctl(mib, sizeof(mib) / sizeof(*mib), &info, &size, CATCH_NULL, 0) != 0 ) {
- Catch::cerr() << "\n** Call to sysctl failed - unable to determine if debugger is active **\n" << std::endl;
- return false;
- }
+ ITrackerPtr TrackerBase::findChild( NameAndLocation const& nameAndLocation ) {
+ auto it = std::find_if( m_children.begin(), m_children.end(),
+ [&nameAndLocation]( ITrackerPtr const& tracker ){
+ return
+ tracker->nameAndLocation().location == nameAndLocation.location &&
+ tracker->nameAndLocation().name == nameAndLocation.name;
+ } );
+ return( it != m_children.end() )
+ ? *it
+ : nullptr;
+ }
+ ITracker& TrackerBase::parent() {
+ assert( m_parent ); // Should always be non-null except for root
+ return *m_parent;
+ }
- // We're being debugged if the P_TRACED flag is set.
+ void TrackerBase::openChild() {
+ if( m_runState != ExecutingChildren ) {
+ m_runState = ExecutingChildren;
+ if( m_parent )
+ m_parent->openChild();
+ }
+ }
- return ( (info.kp_proc.p_flag & P_TRACED) != 0 );
+ bool TrackerBase::isSectionTracker() const { return false; }
+ bool TrackerBase::isGeneratorTracker() const { return false; }
+
+ void TrackerBase::open() {
+ m_runState = Executing;
+ moveToThis();
+ if( m_parent )
+ m_parent->openChild();
+ }
+
+ void TrackerBase::close() {
+
+ // Close any still open children (e.g. generators)
+ while( &m_ctx.currentTracker() != this )
+ m_ctx.currentTracker().close();
+
+ switch( m_runState ) {
+ case NeedsAnotherRun:
+ break;
+
+ case Executing:
+ m_runState = CompletedSuccessfully;
+ break;
+ case ExecutingChildren:
+ if( std::all_of(m_children.begin(), m_children.end(), [](ITrackerPtr const& t){ return t->isComplete(); }) )
+ m_runState = CompletedSuccessfully;
+ break;
+
+ case NotStarted:
+ case CompletedSuccessfully:
+ case Failed:
+ CATCH_INTERNAL_ERROR( "Illogical state: " << m_runState );
+
+ default:
+ CATCH_INTERNAL_ERROR( "Unknown state: " << m_runState );
}
- } // namespace Catch
+ moveToParent();
+ m_ctx.completeCycle();
+ }
+ void TrackerBase::fail() {
+ m_runState = Failed;
+ if( m_parent )
+ m_parent->markAsNeedingAnotherRun();
+ moveToParent();
+ m_ctx.completeCycle();
+ }
+ void TrackerBase::markAsNeedingAnotherRun() {
+ m_runState = NeedsAnotherRun;
+ }
-#elif defined(CATCH_PLATFORM_LINUX)
- #include <fstream>
- #include <string>
+ void TrackerBase::moveToParent() {
+ assert( m_parent );
+ m_ctx.setCurrentTracker( m_parent );
+ }
+ void TrackerBase::moveToThis() {
+ m_ctx.setCurrentTracker( this );
+ }
- namespace Catch{
- // The standard POSIX way of detecting a debugger is to attempt to
- // ptrace() the process, but this needs to be done from a child and not
- // this process itself to still allow attaching to this process later
- // if wanted, so is rather heavy. Under Linux we have the PID of the
- // "debugger" (which doesn't need to be gdb, of course, it could also
- // be strace, for example) in /proc/$PID/status, so just get it from
- // there instead.
- bool isDebuggerActive(){
- // Libstdc++ has a bug, where std::ifstream sets errno to 0
- // This way our users can properly assert over errno values
- ErrnoGuard guard;
- std::ifstream in("/proc/self/status");
- for( std::string line; std::getline(in, line); ) {
- static const int PREFIX_LEN = 11;
- if( line.compare(0, PREFIX_LEN, "TracerPid:\t") == 0 ) {
- // We're traced if the PID is not 0 and no other PID starts
- // with 0 digit, so it's enough to check for just a single
- // character.
- return line.length() > PREFIX_LEN && line[PREFIX_LEN] != '0';
- }
- }
+ SectionTracker::SectionTracker( NameAndLocation const& nameAndLocation, TrackerContext& ctx, ITracker* parent )
+ : TrackerBase( nameAndLocation, ctx, parent ),
+ m_trimmed_name(trim(nameAndLocation.name))
+ {
+ if( parent ) {
+ while( !parent->isSectionTracker() )
+ parent = &parent->parent();
- return false;
+ SectionTracker& parentSection = static_cast<SectionTracker&>( *parent );
+ addNextFilters( parentSection.m_filters );
}
- } // namespace Catch
-#elif defined(_MSC_VER)
- extern "C" __declspec(dllimport) int __stdcall IsDebuggerPresent();
- namespace Catch {
- bool isDebuggerActive() {
- return IsDebuggerPresent() != 0;
+ }
+
+ bool SectionTracker::isComplete() const {
+ bool complete = true;
+
+ if ((m_filters.empty() || m_filters[0] == "")
+ || std::find(m_filters.begin(), m_filters.end(), m_trimmed_name) != m_filters.end()) {
+ complete = TrackerBase::isComplete();
}
+ return complete;
}
-#elif defined(__MINGW32__)
- extern "C" __declspec(dllimport) int __stdcall IsDebuggerPresent();
- namespace Catch {
- bool isDebuggerActive() {
- return IsDebuggerPresent() != 0;
+
+ bool SectionTracker::isSectionTracker() const { return true; }
+
+ SectionTracker& SectionTracker::acquire( TrackerContext& ctx, NameAndLocation const& nameAndLocation ) {
+ std::shared_ptr<SectionTracker> section;
+
+ ITracker& currentTracker = ctx.currentTracker();
+ if( ITrackerPtr childTracker = currentTracker.findChild( nameAndLocation ) ) {
+ assert( childTracker );
+ assert( childTracker->isSectionTracker() );
+ section = std::static_pointer_cast<SectionTracker>( childTracker );
+ }
+ else {
+ section = std::make_shared<SectionTracker>( nameAndLocation, ctx, &currentTracker );
+ currentTracker.addChild( section );
}
+ if( !ctx.completedCycle() )
+ section->tryOpen();
+ return *section;
}
-#else
- namespace Catch {
- inline bool isDebuggerActive() { return false; }
+
+ void SectionTracker::tryOpen() {
+ if( !isComplete() )
+ open();
}
-#endif // Platform
-#ifdef CATCH_PLATFORM_WINDOWS
+ void SectionTracker::addInitialFilters( std::vector<std::string> const& filters ) {
+ if( !filters.empty() ) {
+ m_filters.reserve( m_filters.size() + filters.size() + 2 );
+ m_filters.push_back(""); // Root - should never be consulted
+ m_filters.push_back(""); // Test Case - not a section filter
+ m_filters.insert( m_filters.end(), filters.begin(), filters.end() );
+ }
+ }
+ void SectionTracker::addNextFilters( std::vector<std::string> const& filters ) {
+ if( filters.size() > 1 )
+ m_filters.insert( m_filters.end(), filters.begin()+1, filters.end() );
+ }
- namespace Catch {
- void writeToDebugConsole( std::string const& text ) {
- ::OutputDebugStringA( text.c_str() );
+} // namespace TestCaseTracking
+
+using TestCaseTracking::ITracker;
+using TestCaseTracking::TrackerContext;
+using TestCaseTracking::SectionTracker;
+
+} // namespace Catch
+
+#if defined(__clang__)
+# pragma clang diagnostic pop
+#endif
+// end catch_test_case_tracker.cpp
+// start catch_test_registry.cpp
+
+namespace Catch {
+
+ auto makeTestInvoker( void(*testAsFunction)() ) noexcept -> ITestInvoker* {
+ return new(std::nothrow) TestInvokerAsFunction( testAsFunction );
+ }
+
+ NameAndTags::NameAndTags( StringRef const& name_ , StringRef const& tags_ ) noexcept : name( name_ ), tags( tags_ ) {}
+
+ AutoReg::AutoReg( ITestInvoker* invoker, SourceLineInfo const& lineInfo, StringRef const& classOrMethod, NameAndTags const& nameAndTags ) noexcept {
+ CATCH_TRY {
+ getMutableRegistryHub()
+ .registerTest(
+ makeTestCase(
+ invoker,
+ extractClassName( classOrMethod ),
+ nameAndTags,
+ lineInfo));
+ } CATCH_CATCH_ALL {
+ // Do not throw when constructing global objects, instead register the exception to be processed later
+ getMutableRegistryHub().registerStartupException();
}
}
-#else
- namespace Catch {
- void writeToDebugConsole( std::string const& text ) {
- // !TBD: Need a version for Mac/ XCode and other IDEs
- Catch::cout() << text;
+
+ AutoReg::~AutoReg() = default;
+}
+// end catch_test_registry.cpp
+// start catch_test_spec.cpp
+
+#include <algorithm>
+#include <string>
+#include <vector>
+#include <memory>
+
+namespace Catch {
+
+ TestSpec::Pattern::Pattern( std::string const& name )
+ : m_name( name )
+ {}
+
+ TestSpec::Pattern::~Pattern() = default;
+
+ std::string const& TestSpec::Pattern::name() const {
+ return m_name;
+ }
+
+ TestSpec::NamePattern::NamePattern( std::string const& name, std::string const& filterString )
+ : Pattern( filterString )
+ , m_wildcardPattern( toLower( name ), CaseSensitive::No )
+ {}
+
+ bool TestSpec::NamePattern::matches( TestCaseInfo const& testCase ) const {
+ return m_wildcardPattern.matches( testCase.name );
+ }
+
+ TestSpec::TagPattern::TagPattern( std::string const& tag, std::string const& filterString )
+ : Pattern( filterString )
+ , m_tag( toLower( tag ) )
+ {}
+
+ bool TestSpec::TagPattern::matches( TestCaseInfo const& testCase ) const {
+ return std::find(begin(testCase.lcaseTags),
+ end(testCase.lcaseTags),
+ m_tag) != end(testCase.lcaseTags);
+ }
+
+ TestSpec::ExcludedPattern::ExcludedPattern( PatternPtr const& underlyingPattern )
+ : Pattern( underlyingPattern->name() )
+ , m_underlyingPattern( underlyingPattern )
+ {}
+
+ bool TestSpec::ExcludedPattern::matches( TestCaseInfo const& testCase ) const {
+ return !m_underlyingPattern->matches( testCase );
+ }
+
+ bool TestSpec::Filter::matches( TestCaseInfo const& testCase ) const {
+ return std::all_of( m_patterns.begin(), m_patterns.end(), [&]( PatternPtr const& p ){ return p->matches( testCase ); } );
+ }
+
+ std::string TestSpec::Filter::name() const {
+ std::string name;
+ for( auto const& p : m_patterns )
+ name += p->name();
+ return name;
+ }
+
+ bool TestSpec::hasFilters() const {
+ return !m_filters.empty();
+ }
+
+ bool TestSpec::matches( TestCaseInfo const& testCase ) const {
+ return std::any_of( m_filters.begin(), m_filters.end(), [&]( Filter const& f ){ return f.matches( testCase ); } );
+ }
+
+ TestSpec::Matches TestSpec::matchesByFilter( std::vector<TestCase> const& testCases, IConfig const& config ) const
+ {
+ Matches matches( m_filters.size() );
+ std::transform( m_filters.begin(), m_filters.end(), matches.begin(), [&]( Filter const& filter ){
+ std::vector<TestCase const*> currentMatches;
+ for( auto const& test : testCases )
+ if( isThrowSafe( test, config ) && filter.matches( test ) )
+ currentMatches.emplace_back( &test );
+ return FilterMatch{ filter.name(), currentMatches };
+ } );
+ return matches;
+ }
+
+ const TestSpec::vectorStrings& TestSpec::getInvalidArgs() const{
+ return (m_invalidArgs);
+ }
+
+}
+// end catch_test_spec.cpp
+// start catch_test_spec_parser.cpp
+
+namespace Catch {
+
+ TestSpecParser::TestSpecParser( ITagAliasRegistry const& tagAliases ) : m_tagAliases( &tagAliases ) {}
+
+ TestSpecParser& TestSpecParser::parse( std::string const& arg ) {
+ m_mode = None;
+ m_exclusion = false;
+ m_arg = m_tagAliases->expandAliases( arg );
+ m_escapeChars.clear();
+ m_substring.reserve(m_arg.size());
+ m_patternName.reserve(m_arg.size());
+ m_realPatternPos = 0;
+
+ for( m_pos = 0; m_pos < m_arg.size(); ++m_pos )
+ //if visitChar fails
+ if( !visitChar( m_arg[m_pos] ) ){
+ m_testSpec.m_invalidArgs.push_back(arg);
+ break;
+ }
+ endMode();
+ return *this;
+ }
+ TestSpec TestSpecParser::testSpec() {
+ addFilter();
+ return m_testSpec;
+ }
+ bool TestSpecParser::visitChar( char c ) {
+ if( (m_mode != EscapedName) && (c == '\\') ) {
+ escape();
+ addCharToPattern(c);
+ return true;
+ }else if((m_mode != EscapedName) && (c == ',') ) {
+ return separate();
+ }
+
+ switch( m_mode ) {
+ case None:
+ if( processNoneChar( c ) )
+ return true;
+ break;
+ case Name:
+ processNameChar( c );
+ break;
+ case EscapedName:
+ endMode();
+ addCharToPattern(c);
+ return true;
+ default:
+ case Tag:
+ case QuotedName:
+ if( processOtherChar( c ) )
+ return true;
+ break;
+ }
+
+ m_substring += c;
+ if( !isControlChar( c ) ) {
+ m_patternName += c;
+ m_realPatternPos++;
}
+ return true;
}
-#endif // Platform
+ // Two of the processing methods return true to signal the caller to return
+ // without adding the given character to the current pattern strings
+ bool TestSpecParser::processNoneChar( char c ) {
+ switch( c ) {
+ case ' ':
+ return true;
+ case '~':
+ m_exclusion = true;
+ return false;
+ case '[':
+ startNewMode( Tag );
+ return false;
+ case '"':
+ startNewMode( QuotedName );
+ return false;
+ default:
+ startNewMode( Name );
+ return false;
+ }
+ }
+ void TestSpecParser::processNameChar( char c ) {
+ if( c == '[' ) {
+ if( m_substring == "exclude:" )
+ m_exclusion = true;
+ else
+ endMode();
+ startNewMode( Tag );
+ }
+ }
+ bool TestSpecParser::processOtherChar( char c ) {
+ if( !isControlChar( c ) )
+ return false;
+ m_substring += c;
+ endMode();
+ return true;
+ }
+ void TestSpecParser::startNewMode( Mode mode ) {
+ m_mode = mode;
+ }
+ void TestSpecParser::endMode() {
+ switch( m_mode ) {
+ case Name:
+ case QuotedName:
+ return addPattern<TestSpec::NamePattern>();
+ case Tag:
+ return addPattern<TestSpec::TagPattern>();
+ case EscapedName:
+ revertBackToLastMode();
+ return;
+ case None:
+ default:
+ return startNewMode( None );
+ }
+ }
+ void TestSpecParser::escape() {
+ saveLastMode();
+ m_mode = EscapedName;
+ m_escapeChars.push_back(m_realPatternPos);
+ }
+ bool TestSpecParser::isControlChar( char c ) const {
+ switch( m_mode ) {
+ default:
+ return false;
+ case None:
+ return c == '~';
+ case Name:
+ return c == '[';
+ case EscapedName:
+ return true;
+ case QuotedName:
+ return c == '"';
+ case Tag:
+ return c == '[' || c == ']';
+ }
+ }
+
+ void TestSpecParser::addFilter() {
+ if( !m_currentFilter.m_patterns.empty() ) {
+ m_testSpec.m_filters.push_back( m_currentFilter );
+ m_currentFilter = TestSpec::Filter();
+ }
+ }
+
+ void TestSpecParser::saveLastMode() {
+ lastMode = m_mode;
+ }
+
+ void TestSpecParser::revertBackToLastMode() {
+ m_mode = lastMode;
+ }
+
+ bool TestSpecParser::separate() {
+ if( (m_mode==QuotedName) || (m_mode==Tag) ){
+ //invalid argument, signal failure to previous scope.
+ m_mode = None;
+ m_pos = m_arg.size();
+ m_substring.clear();
+ m_patternName.clear();
+ return false;
+ }
+ endMode();
+ addFilter();
+ return true; //success
+ }
+
+ TestSpec parseTestSpec( std::string const& arg ) {
+ return TestSpecParser( ITagAliasRegistry::get() ).parse( arg ).testSpec();
+ }
+
+} // namespace Catch
+// end catch_test_spec_parser.cpp
+// start catch_timer.cpp
+
+#include <chrono>
-// #included from: catch_tostring.hpp
-#define TWOBLUECUBES_CATCH_TOSTRING_HPP_INCLUDED
+static const uint64_t nanosecondsInSecond = 1000000000;
+
+namespace Catch {
+
+ auto getCurrentNanosecondsSinceEpoch() -> uint64_t {
+ return std::chrono::duration_cast<std::chrono::nanoseconds>( std::chrono::high_resolution_clock::now().time_since_epoch() ).count();
+ }
+
+ namespace {
+ auto estimateClockResolution() -> uint64_t {
+ uint64_t sum = 0;
+ static const uint64_t iterations = 1000000;
+
+ auto startTime = getCurrentNanosecondsSinceEpoch();
+
+ for( std::size_t i = 0; i < iterations; ++i ) {
+
+ uint64_t ticks;
+ uint64_t baseTicks = getCurrentNanosecondsSinceEpoch();
+ do {
+ ticks = getCurrentNanosecondsSinceEpoch();
+ } while( ticks == baseTicks );
+
+ auto delta = ticks - baseTicks;
+ sum += delta;
+
+ // If we have been calibrating for over 3 seconds -- the clock
+ // is terrible and we should move on.
+ // TBD: How to signal that the measured resolution is probably wrong?
+ if (ticks > startTime + 3 * nanosecondsInSecond) {
+ return sum / ( i + 1u );
+ }
+ }
+
+ // We're just taking the mean, here. To do better we could take the std. dev and exclude outliers
+ // - and potentially do more iterations if there's a high variance.
+ return sum/iterations;
+ }
+ }
+ auto getEstimatedClockResolution() -> uint64_t {
+ static auto s_resolution = estimateClockResolution();
+ return s_resolution;
+ }
+
+ void Timer::start() {
+ m_nanoseconds = getCurrentNanosecondsSinceEpoch();
+ }
+ auto Timer::getElapsedNanoseconds() const -> uint64_t {
+ return getCurrentNanosecondsSinceEpoch() - m_nanoseconds;
+ }
+ auto Timer::getElapsedMicroseconds() const -> uint64_t {
+ return getElapsedNanoseconds()/1000;
+ }
+ auto Timer::getElapsedMilliseconds() const -> unsigned int {
+ return static_cast<unsigned int>(getElapsedMicroseconds()/1000);
+ }
+ auto Timer::getElapsedSeconds() const -> double {
+ return getElapsedMicroseconds()/1000000.0;
+ }
+
+} // namespace Catch
+// end catch_timer.cpp
+// start catch_tostring.cpp
+
+#if defined(__clang__)
+# pragma clang diagnostic push
+# pragma clang diagnostic ignored "-Wexit-time-destructors"
+# pragma clang diagnostic ignored "-Wglobal-constructors"
+#endif
+
+// Enable specific decls locally
+#if !defined(CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER)
+#define CATCH_CONFIG_ENABLE_CHRONO_STRINGMAKER
+#endif
+
+#include <cmath>
+#include <iomanip>
namespace Catch {
@@ -8822,8 +14672,7 @@ namespace Detail {
};
}
- std::string rawMemoryToString( const void *object, std::size_t size )
- {
+ std::string rawMemoryToString( const void *object, std::size_t size ) {
// Reverse order for little endian architectures
int i = 0, end = static_cast<int>( size ), inc = 1;
if( Endianness::which() == Endianness::Little ) {
@@ -8832,1433 +14681,1671 @@ namespace Detail {
}
unsigned char const *bytes = static_cast<unsigned char const *>(object);
- std::ostringstream os;
- os << "0x" << std::setfill('0') << std::hex;
+ ReusableStringStream rss;
+ rss << "0x" << std::setfill('0') << std::hex;
for( ; i != end; i += inc )
- os << std::setw(2) << static_cast<unsigned>(bytes[i]);
- return os.str();
+ rss << std::setw(2) << static_cast<unsigned>(bytes[i]);
+ return rss.str();
}
}
-std::string toString( std::string const& value ) {
- std::string s = value;
- if( getCurrentContext().getConfig()->showInvisibles() ) {
- for(size_t i = 0; i < s.size(); ++i ) {
- std::string subs;
- switch( s[i] ) {
- case '\n': subs = "\\n"; break;
- case '\t': subs = "\\t"; break;
- default: break;
- }
- if( !subs.empty() ) {
- s = s.substr( 0, i ) + subs + s.substr( i+1 );
- ++i;
- }
- }
+template<typename T>
+std::string fpToString( T value, int precision ) {
+ if (Catch::isnan(value)) {
+ return "nan";
}
- return '"' + s + '"';
-}
-std::string toString( std::wstring const& value ) {
- std::string s;
- s.reserve( value.size() );
- for(size_t i = 0; i < value.size(); ++i )
- s += value[i] <= 0xff ? static_cast<char>( value[i] ) : '?';
- return Catch::toString( s );
+ ReusableStringStream rss;
+ rss << std::setprecision( precision )
+ << std::fixed
+ << value;
+ std::string d = rss.str();
+ std::size_t i = d.find_last_not_of( '0' );
+ if( i != std::string::npos && i != d.size()-1 ) {
+ if( d[i] == '.' )
+ i++;
+ d = d.substr( 0, i+1 );
+ }
+ return d;
}
-std::string toString( const char* const value ) {
- return value ? Catch::toString( std::string( value ) ) : std::string( "{null string}" );
+//// ======================================================= ////
+//
+// Out-of-line defs for full specialization of StringMaker
+//
+//// ======================================================= ////
+
+std::string StringMaker<std::string>::convert(const std::string& str) {
+ if (!getCurrentContext().getConfig()->showInvisibles()) {
+ return '"' + str + '"';
+ }
+
+ std::string s("\"");
+ for (char c : str) {
+ switch (c) {
+ case '\n':
+ s.append("\\n");
+ break;
+ case '\t':
+ s.append("\\t");
+ break;
+ default:
+ s.push_back(c);
+ break;
+ }
+ }
+ s.append("\"");
+ return s;
}
-std::string toString( char* const value ) {
- return Catch::toString( static_cast<const char*>( value ) );
+#ifdef CATCH_CONFIG_CPP17_STRING_VIEW
+std::string StringMaker<std::string_view>::convert(std::string_view str) {
+ return ::Catch::Detail::stringify(std::string{ str });
}
+#endif
-std::string toString( const wchar_t* const value )
-{
- return value ? Catch::toString( std::wstring(value) ) : std::string( "{null string}" );
+std::string StringMaker<char const*>::convert(char const* str) {
+ if (str) {
+ return ::Catch::Detail::stringify(std::string{ str });
+ } else {
+ return{ "{null string}" };
+ }
+}
+std::string StringMaker<char*>::convert(char* str) {
+ if (str) {
+ return ::Catch::Detail::stringify(std::string{ str });
+ } else {
+ return{ "{null string}" };
+ }
}
-std::string toString( wchar_t* const value )
-{
- return Catch::toString( static_cast<const wchar_t*>( value ) );
+#ifdef CATCH_CONFIG_WCHAR
+std::string StringMaker<std::wstring>::convert(const std::wstring& wstr) {
+ std::string s;
+ s.reserve(wstr.size());
+ for (auto c : wstr) {
+ s += (c <= 0xff) ? static_cast<char>(c) : '?';
+ }
+ return ::Catch::Detail::stringify(s);
}
-std::string toString( int value ) {
- std::ostringstream oss;
- oss << value;
- if( value > Detail::hexThreshold )
- oss << " (0x" << std::hex << value << ')';
- return oss.str();
+# ifdef CATCH_CONFIG_CPP17_STRING_VIEW
+std::string StringMaker<std::wstring_view>::convert(std::wstring_view str) {
+ return StringMaker<std::wstring>::convert(std::wstring(str));
}
+# endif
-std::string toString( unsigned long value ) {
- std::ostringstream oss;
- oss << value;
- if( value > Detail::hexThreshold )
- oss << " (0x" << std::hex << value << ')';
- return oss.str();
+std::string StringMaker<wchar_t const*>::convert(wchar_t const * str) {
+ if (str) {
+ return ::Catch::Detail::stringify(std::wstring{ str });
+ } else {
+ return{ "{null string}" };
+ }
}
+std::string StringMaker<wchar_t *>::convert(wchar_t * str) {
+ if (str) {
+ return ::Catch::Detail::stringify(std::wstring{ str });
+ } else {
+ return{ "{null string}" };
+ }
+}
+#endif
-std::string toString( unsigned int value ) {
- return Catch::toString( static_cast<unsigned long>( value ) );
+#if defined(CATCH_CONFIG_CPP17_BYTE)
+#include <cstddef>
+std::string StringMaker<std::byte>::convert(std::byte value) {
+ return ::Catch::Detail::stringify(std::to_integer<unsigned long long>(value));
}
+#endif // defined(CATCH_CONFIG_CPP17_BYTE)
-template<typename T>
-std::string fpToString( T value, int precision ) {
- std::ostringstream oss;
- oss << std::setprecision( precision )
- << std::fixed
- << value;
- std::string d = oss.str();
- std::size_t i = d.find_last_not_of( '0' );
- if( i != std::string::npos && i != d.size()-1 ) {
- if( d[i] == '.' )
- i++;
- d = d.substr( 0, i+1 );
+std::string StringMaker<int>::convert(int value) {
+ return ::Catch::Detail::stringify(static_cast<long long>(value));
+}
+std::string StringMaker<long>::convert(long value) {
+ return ::Catch::Detail::stringify(static_cast<long long>(value));
+}
+std::string StringMaker<long long>::convert(long long value) {
+ ReusableStringStream rss;
+ rss << value;
+ if (value > Detail::hexThreshold) {
+ rss << " (0x" << std::hex << value << ')';
}
- return d;
+ return rss.str();
}
-std::string toString( const double value ) {
- return fpToString( value, 10 );
+std::string StringMaker<unsigned int>::convert(unsigned int value) {
+ return ::Catch::Detail::stringify(static_cast<unsigned long long>(value));
}
-std::string toString( const float value ) {
- return fpToString( value, 5 ) + 'f';
+std::string StringMaker<unsigned long>::convert(unsigned long value) {
+ return ::Catch::Detail::stringify(static_cast<unsigned long long>(value));
+}
+std::string StringMaker<unsigned long long>::convert(unsigned long long value) {
+ ReusableStringStream rss;
+ rss << value;
+ if (value > Detail::hexThreshold) {
+ rss << " (0x" << std::hex << value << ')';
+ }
+ return rss.str();
}
-std::string toString( bool value ) {
- return value ? "true" : "false";
+std::string StringMaker<bool>::convert(bool b) {
+ return b ? "true" : "false";
}
-std::string toString( char value ) {
- if ( value == '\r' )
+std::string StringMaker<signed char>::convert(signed char value) {
+ if (value == '\r') {
return "'\\r'";
- if ( value == '\f' )
+ } else if (value == '\f') {
return "'\\f'";
- if ( value == '\n' )
+ } else if (value == '\n') {
return "'\\n'";
- if ( value == '\t' )
+ } else if (value == '\t') {
return "'\\t'";
- if ( '\0' <= value && value < ' ' )
- return toString( static_cast<unsigned int>( value ) );
- char chstr[] = "' '";
- chstr[1] = value;
- return chstr;
+ } else if ('\0' <= value && value < ' ') {
+ return ::Catch::Detail::stringify(static_cast<unsigned int>(value));
+ } else {
+ char chstr[] = "' '";
+ chstr[1] = value;
+ return chstr;
+ }
}
-
-std::string toString( signed char value ) {
- return toString( static_cast<char>( value ) );
+std::string StringMaker<char>::convert(char c) {
+ return ::Catch::Detail::stringify(static_cast<signed char>(c));
}
-
-std::string toString( unsigned char value ) {
- return toString( static_cast<char>( value ) );
+std::string StringMaker<unsigned char>::convert(unsigned char c) {
+ return ::Catch::Detail::stringify(static_cast<char>(c));
}
-#ifdef CATCH_CONFIG_CPP11_LONG_LONG
-std::string toString( long long value ) {
- std::ostringstream oss;
- oss << value;
- if( value > Detail::hexThreshold )
- oss << " (0x" << std::hex << value << ')';
- return oss.str();
+std::string StringMaker<std::nullptr_t>::convert(std::nullptr_t) {
+ return "nullptr";
}
-std::string toString( unsigned long long value ) {
- std::ostringstream oss;
- oss << value;
- if( value > Detail::hexThreshold )
- oss << " (0x" << std::hex << value << ')';
- return oss.str();
+
+int StringMaker<float>::precision = 5;
+
+std::string StringMaker<float>::convert(float value) {
+ return fpToString(value, precision) + 'f';
}
-#endif
-#ifdef CATCH_CONFIG_CPP11_NULLPTR
-std::string toString( std::nullptr_t ) {
- return "nullptr";
+int StringMaker<double>::precision = 10;
+
+std::string StringMaker<double>::convert(double value) {
+ return fpToString(value, precision);
}
-#endif
-#ifdef __OBJC__
- std::string toString( NSString const * const& nsstring ) {
- if( !nsstring )
- return "nil";
- return "@" + toString([nsstring UTF8String]);
- }
- std::string toString( NSString * CATCH_ARC_STRONG const& nsstring ) {
- if( !nsstring )
- return "nil";
- return "@" + toString([nsstring UTF8String]);
- }
- std::string toString( NSObject* const& nsObject ) {
- return toString( [nsObject description] );
- }
-#endif
+std::string ratio_string<std::atto>::symbol() { return "a"; }
+std::string ratio_string<std::femto>::symbol() { return "f"; }
+std::string ratio_string<std::pico>::symbol() { return "p"; }
+std::string ratio_string<std::nano>::symbol() { return "n"; }
+std::string ratio_string<std::micro>::symbol() { return "u"; }
+std::string ratio_string<std::milli>::symbol() { return "m"; }
} // end namespace Catch
-// #included from: catch_result_builder.hpp
-#define TWOBLUECUBES_CATCH_RESULT_BUILDER_HPP_INCLUDED
+#if defined(__clang__)
+# pragma clang diagnostic pop
+#endif
+
+// end catch_tostring.cpp
+// start catch_totals.cpp
namespace Catch {
- std::string capturedExpressionWithSecondArgument( std::string const& capturedExpression, std::string const& secondArg ) {
- return secondArg.empty() || secondArg == "\"\""
- ? capturedExpression
- : capturedExpression + ", " + secondArg;
+ Counts Counts::operator - ( Counts const& other ) const {
+ Counts diff;
+ diff.passed = passed - other.passed;
+ diff.failed = failed - other.failed;
+ diff.failedButOk = failedButOk - other.failedButOk;
+ return diff;
}
- ResultBuilder::ResultBuilder( char const* macroName,
- SourceLineInfo const& lineInfo,
- char const* capturedExpression,
- ResultDisposition::Flags resultDisposition,
- char const* secondArg )
- : m_assertionInfo( macroName, lineInfo, capturedExpressionWithSecondArgument( capturedExpression, secondArg ), resultDisposition ),
- m_shouldDebugBreak( false ),
- m_shouldThrow( false ),
- m_guardException( false )
- {}
- ResultBuilder::~ResultBuilder() {
-#if defined(CATCH_CONFIG_FAST_COMPILE)
- if ( m_guardException ) {
- m_stream.oss << "Exception translation was disabled by CATCH_CONFIG_FAST_COMPILE";
- captureResult( ResultWas::ThrewException );
- getCurrentContext().getResultCapture()->exceptionEarlyReported();
- }
-#endif
+ Counts& Counts::operator += ( Counts const& other ) {
+ passed += other.passed;
+ failed += other.failed;
+ failedButOk += other.failedButOk;
+ return *this;
}
- ResultBuilder& ResultBuilder::setResultType( ResultWas::OfType result ) {
- m_data.resultType = result;
- return *this;
+ std::size_t Counts::total() const {
+ return passed + failed + failedButOk;
}
- ResultBuilder& ResultBuilder::setResultType( bool result ) {
- m_data.resultType = result ? ResultWas::Ok : ResultWas::ExpressionFailed;
- return *this;
+ bool Counts::allPassed() const {
+ return failed == 0 && failedButOk == 0;
}
-
- void ResultBuilder::endExpression( DecomposedExpression const& expr ) {
- AssertionResult result = build( expr );
- handleResult( result );
+ bool Counts::allOk() const {
+ return failed == 0;
}
- void ResultBuilder::useActiveException( ResultDisposition::Flags resultDisposition ) {
- m_assertionInfo.resultDisposition = resultDisposition;
- m_stream.oss << Catch::translateActiveException();
- captureResult( ResultWas::ThrewException );
+ Totals Totals::operator - ( Totals const& other ) const {
+ Totals diff;
+ diff.assertions = assertions - other.assertions;
+ diff.testCases = testCases - other.testCases;
+ return diff;
}
- void ResultBuilder::captureResult( ResultWas::OfType resultType ) {
- setResultType( resultType );
- captureExpression();
+ Totals& Totals::operator += ( Totals const& other ) {
+ assertions += other.assertions;
+ testCases += other.testCases;
+ return *this;
}
- void ResultBuilder::captureExpectedException( std::string const& expectedMessage ) {
- if( expectedMessage.empty() )
- captureExpectedException( Matchers::Impl::MatchAllOf<std::string>() );
+ Totals Totals::delta( Totals const& prevTotals ) const {
+ Totals diff = *this - prevTotals;
+ if( diff.assertions.failed > 0 )
+ ++diff.testCases.failed;
+ else if( diff.assertions.failedButOk > 0 )
+ ++diff.testCases.failedButOk;
else
- captureExpectedException( Matchers::Equals( expectedMessage ) );
+ ++diff.testCases.passed;
+ return diff;
}
- void ResultBuilder::captureExpectedException( Matchers::Impl::MatcherBase<std::string> const& matcher ) {
+}
+// end catch_totals.cpp
+// start catch_uncaught_exceptions.cpp
- assert( !isFalseTest( m_assertionInfo.resultDisposition ) );
- AssertionResultData data = m_data;
- data.resultType = ResultWas::Ok;
- data.reconstructedExpression = m_assertionInfo.capturedExpression;
+#include <exception>
- std::string actualMessage = Catch::translateActiveException();
- if( !matcher.match( actualMessage ) ) {
- data.resultType = ResultWas::ExpressionFailed;
- data.reconstructedExpression = actualMessage;
- }
- AssertionResult result( m_assertionInfo, data );
- handleResult( result );
- }
+namespace Catch {
+ bool uncaught_exceptions() {
+#if defined(CATCH_CONFIG_CPP17_UNCAUGHT_EXCEPTIONS)
+ return std::uncaught_exceptions() > 0;
+#else
+ return std::uncaught_exception();
+#endif
+ }
+} // end namespace Catch
+// end catch_uncaught_exceptions.cpp
+// start catch_version.cpp
- void ResultBuilder::captureExpression() {
- AssertionResult result = build();
- handleResult( result );
- }
+#include <ostream>
- void ResultBuilder::handleResult( AssertionResult const& result )
- {
- getResultCapture().assertionEnded( result );
+namespace Catch {
- if( !result.isOk() ) {
- if( getCurrentContext().getConfig()->shouldDebugBreak() )
- m_shouldDebugBreak = true;
- if( getCurrentContext().getRunner()->aborting() || (m_assertionInfo.resultDisposition & ResultDisposition::Normal) )
- m_shouldThrow = true;
+ Version::Version
+ ( unsigned int _majorVersion,
+ unsigned int _minorVersion,
+ unsigned int _patchNumber,
+ char const * const _branchName,
+ unsigned int _buildNumber )
+ : majorVersion( _majorVersion ),
+ minorVersion( _minorVersion ),
+ patchNumber( _patchNumber ),
+ branchName( _branchName ),
+ buildNumber( _buildNumber )
+ {}
+
+ std::ostream& operator << ( std::ostream& os, Version const& version ) {
+ os << version.majorVersion << '.'
+ << version.minorVersion << '.'
+ << version.patchNumber;
+ // branchName is never null -> 0th char is \0 if it is empty
+ if (version.branchName[0]) {
+ os << '-' << version.branchName
+ << '.' << version.buildNumber;
}
+ return os;
}
- void ResultBuilder::react() {
-#if defined(CATCH_CONFIG_FAST_COMPILE)
- if (m_shouldDebugBreak) {
- ///////////////////////////////////////////////////////////////////
- // To inspect the state during test, you need to go one level up the callstack
- // To go back to the test and change execution, jump over the throw statement
- ///////////////////////////////////////////////////////////////////
- CATCH_BREAK_INTO_DEBUGGER();
- }
-#endif
- if( m_shouldThrow )
- throw Catch::TestFailureException();
+ Version const& libraryVersion() {
+ static Version version( 2, 10, 1, "", 0 );
+ return version;
}
- bool ResultBuilder::shouldDebugBreak() const { return m_shouldDebugBreak; }
- bool ResultBuilder::allowThrows() const { return getCurrentContext().getConfig()->allowThrows(); }
+}
+// end catch_version.cpp
+// start catch_wildcard_pattern.cpp
- AssertionResult ResultBuilder::build() const
- {
- return build( *this );
- }
+namespace Catch {
- // CAVEAT: The returned AssertionResult stores a pointer to the argument expr,
- // a temporary DecomposedExpression, which in turn holds references to
- // operands, possibly temporary as well.
- // It should immediately be passed to handleResult; if the expression
- // needs to be reported, its string expansion must be composed before
- // the temporaries are destroyed.
- AssertionResult ResultBuilder::build( DecomposedExpression const& expr ) const
+ WildcardPattern::WildcardPattern( std::string const& pattern,
+ CaseSensitive::Choice caseSensitivity )
+ : m_caseSensitivity( caseSensitivity ),
+ m_pattern( normaliseString( pattern ) )
{
- assert( m_data.resultType != ResultWas::Unknown );
- AssertionResultData data = m_data;
-
- // Flip bool results if FalseTest flag is set
- if( isFalseTest( m_assertionInfo.resultDisposition ) ) {
- data.negate( expr.isBinaryExpression() );
+ if( startsWith( m_pattern, '*' ) ) {
+ m_pattern = m_pattern.substr( 1 );
+ m_wildcard = WildcardAtStart;
+ }
+ if( endsWith( m_pattern, '*' ) ) {
+ m_pattern = m_pattern.substr( 0, m_pattern.size()-1 );
+ m_wildcard = static_cast<WildcardPosition>( m_wildcard | WildcardAtEnd );
}
-
- data.message = m_stream.oss.str();
- data.decomposedExpression = &expr; // for lazy reconstruction
- return AssertionResult( m_assertionInfo, data );
}
- void ResultBuilder::reconstructExpression( std::string& dest ) const {
- dest = m_assertionInfo.capturedExpression;
+ bool WildcardPattern::matches( std::string const& str ) const {
+ switch( m_wildcard ) {
+ case NoWildcard:
+ return m_pattern == normaliseString( str );
+ case WildcardAtStart:
+ return endsWith( normaliseString( str ), m_pattern );
+ case WildcardAtEnd:
+ return startsWith( normaliseString( str ), m_pattern );
+ case WildcardAtBothEnds:
+ return contains( normaliseString( str ), m_pattern );
+ default:
+ CATCH_INTERNAL_ERROR( "Unknown enum" );
+ }
}
- void ResultBuilder::setExceptionGuard() {
- m_guardException = true;
- }
- void ResultBuilder::unsetExceptionGuard() {
- m_guardException = false;
+ std::string WildcardPattern::normaliseString( std::string const& str ) const {
+ return trim( m_caseSensitivity == CaseSensitive::No ? toLower( str ) : str );
}
+}
+// end catch_wildcard_pattern.cpp
+// start catch_xmlwriter.cpp
-} // end namespace Catch
+#include <iomanip>
-// #included from: catch_tag_alias_registry.hpp
-#define TWOBLUECUBES_CATCH_TAG_ALIAS_REGISTRY_HPP_INCLUDED
+using uchar = unsigned char;
namespace Catch {
- TagAliasRegistry::~TagAliasRegistry() {}
-
- Option<TagAlias> TagAliasRegistry::find( std::string const& alias ) const {
- std::map<std::string, TagAlias>::const_iterator it = m_registry.find( alias );
- if( it != m_registry.end() )
- return it->second;
- else
- return Option<TagAlias>();
- }
+namespace {
- std::string TagAliasRegistry::expandAliases( std::string const& unexpandedTestSpec ) const {
- std::string expandedTestSpec = unexpandedTestSpec;
- for( std::map<std::string, TagAlias>::const_iterator it = m_registry.begin(), itEnd = m_registry.end();
- it != itEnd;
- ++it ) {
- std::size_t pos = expandedTestSpec.find( it->first );
- if( pos != std::string::npos ) {
- expandedTestSpec = expandedTestSpec.substr( 0, pos ) +
- it->second.tag +
- expandedTestSpec.substr( pos + it->first.size() );
- }
+ size_t trailingBytes(unsigned char c) {
+ if ((c & 0xE0) == 0xC0) {
+ return 2;
}
- return expandedTestSpec;
+ if ((c & 0xF0) == 0xE0) {
+ return 3;
+ }
+ if ((c & 0xF8) == 0xF0) {
+ return 4;
+ }
+ CATCH_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered");
}
- void TagAliasRegistry::add( std::string const& alias, std::string const& tag, SourceLineInfo const& lineInfo ) {
-
- if( !startsWith( alias, "[@" ) || !endsWith( alias, ']' ) ) {
- std::ostringstream oss;
- oss << Colour( Colour::Red )
- << "error: tag alias, \"" << alias << "\" is not of the form [@alias name].\n"
- << Colour( Colour::FileName )
- << lineInfo << '\n';
- throw std::domain_error( oss.str().c_str() );
+ uint32_t headerValue(unsigned char c) {
+ if ((c & 0xE0) == 0xC0) {
+ return c & 0x1F;
}
- if( !m_registry.insert( std::make_pair( alias, TagAlias( tag, lineInfo ) ) ).second ) {
- std::ostringstream oss;
- oss << Colour( Colour::Red )
- << "error: tag alias, \"" << alias << "\" already registered.\n"
- << "\tFirst seen at "
- << Colour( Colour::Red ) << find(alias)->lineInfo << '\n'
- << Colour( Colour::Red ) << "\tRedefined at "
- << Colour( Colour::FileName) << lineInfo << '\n';
- throw std::domain_error( oss.str().c_str() );
+ if ((c & 0xF0) == 0xE0) {
+ return c & 0x0F;
}
+ if ((c & 0xF8) == 0xF0) {
+ return c & 0x07;
+ }
+ CATCH_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered");
}
- ITagAliasRegistry::~ITagAliasRegistry() {}
-
- ITagAliasRegistry const& ITagAliasRegistry::get() {
- return getRegistryHub().getTagAliasRegistry();
+ void hexEscapeChar(std::ostream& os, unsigned char c) {
+ std::ios_base::fmtflags f(os.flags());
+ os << "\\x"
+ << std::uppercase << std::hex << std::setfill('0') << std::setw(2)
+ << static_cast<int>(c);
+ os.flags(f);
}
- RegistrarForTagAliases::RegistrarForTagAliases( char const* alias, char const* tag, SourceLineInfo const& lineInfo ) {
- getMutableRegistryHub().registerTagAlias( alias, tag, lineInfo );
- }
+} // anonymous namespace
-} // end namespace Catch
+ XmlEncode::XmlEncode( std::string const& str, ForWhat forWhat )
+ : m_str( str ),
+ m_forWhat( forWhat )
+ {}
-// #included from: catch_matchers_string.hpp
+ void XmlEncode::encodeTo( std::ostream& os ) const {
+ // Apostrophe escaping not necessary if we always use " to write attributes
+ // (see: http://www.w3.org/TR/xml/#syntax)
-namespace Catch {
-namespace Matchers {
+ for( std::size_t idx = 0; idx < m_str.size(); ++ idx ) {
+ uchar c = m_str[idx];
+ switch (c) {
+ case '<': os << "&lt;"; break;
+ case '&': os << "&amp;"; break;
- namespace StdString {
-
- CasedString::CasedString( std::string const& str, CaseSensitive::Choice caseSensitivity )
- : m_caseSensitivity( caseSensitivity ),
- m_str( adjustString( str ) )
- {}
- std::string CasedString::adjustString( std::string const& str ) const {
- return m_caseSensitivity == CaseSensitive::No
- ? toLower( str )
- : str;
- }
- std::string CasedString::caseSensitivitySuffix() const {
- return m_caseSensitivity == CaseSensitive::No
- ? " (case insensitive)"
- : std::string();
- }
+ case '>':
+ // See: http://www.w3.org/TR/xml/#syntax
+ if (idx > 2 && m_str[idx - 1] == ']' && m_str[idx - 2] == ']')
+ os << "&gt;";
+ else
+ os << c;
+ break;
- StringMatcherBase::StringMatcherBase( std::string const& operation, CasedString const& comparator )
- : m_comparator( comparator ),
- m_operation( operation ) {
- }
+ case '\"':
+ if (m_forWhat == ForAttributes)
+ os << "&quot;";
+ else
+ os << c;
+ break;
- std::string StringMatcherBase::describe() const {
- std::string description;
- description.reserve(5 + m_operation.size() + m_comparator.m_str.size() +
- m_comparator.caseSensitivitySuffix().size());
- description += m_operation;
- description += ": \"";
- description += m_comparator.m_str;
- description += "\"";
- description += m_comparator.caseSensitivitySuffix();
- return description;
- }
+ default:
+ // Check for control characters and invalid utf-8
- EqualsMatcher::EqualsMatcher( CasedString const& comparator ) : StringMatcherBase( "equals", comparator ) {}
+ // Escape control characters in standard ascii
+ // see http://stackoverflow.com/questions/404107/why-are-control-characters-illegal-in-xml-1-0
+ if (c < 0x09 || (c > 0x0D && c < 0x20) || c == 0x7F) {
+ hexEscapeChar(os, c);
+ break;
+ }
- bool EqualsMatcher::match( std::string const& source ) const {
- return m_comparator.adjustString( source ) == m_comparator.m_str;
- }
+ // Plain ASCII: Write it to stream
+ if (c < 0x7F) {
+ os << c;
+ break;
+ }
- ContainsMatcher::ContainsMatcher( CasedString const& comparator ) : StringMatcherBase( "contains", comparator ) {}
+ // UTF-8 territory
+ // Check if the encoding is valid and if it is not, hex escape bytes.
+ // Important: We do not check the exact decoded values for validity, only the encoding format
+ // First check that this bytes is a valid lead byte:
+ // This means that it is not encoded as 1111 1XXX
+ // Or as 10XX XXXX
+ if (c < 0xC0 ||
+ c >= 0xF8) {
+ hexEscapeChar(os, c);
+ break;
+ }
- bool ContainsMatcher::match( std::string const& source ) const {
- return contains( m_comparator.adjustString( source ), m_comparator.m_str );
- }
+ auto encBytes = trailingBytes(c);
+ // Are there enough bytes left to avoid accessing out-of-bounds memory?
+ if (idx + encBytes - 1 >= m_str.size()) {
+ hexEscapeChar(os, c);
+ break;
+ }
+ // The header is valid, check data
+ // The next encBytes bytes must together be a valid utf-8
+ // This means: bitpattern 10XX XXXX and the extracted value is sane (ish)
+ bool valid = true;
+ uint32_t value = headerValue(c);
+ for (std::size_t n = 1; n < encBytes; ++n) {
+ uchar nc = m_str[idx + n];
+ valid &= ((nc & 0xC0) == 0x80);
+ value = (value << 6) | (nc & 0x3F);
+ }
- StartsWithMatcher::StartsWithMatcher( CasedString const& comparator ) : StringMatcherBase( "starts with", comparator ) {}
+ if (
+ // Wrong bit pattern of following bytes
+ (!valid) ||
+ // Overlong encodings
+ (value < 0x80) ||
+ (0x80 <= value && value < 0x800 && encBytes > 2) ||
+ (0x800 < value && value < 0x10000 && encBytes > 3) ||
+ // Encoded value out of range
+ (value >= 0x110000)
+ ) {
+ hexEscapeChar(os, c);
+ break;
+ }
- bool StartsWithMatcher::match( std::string const& source ) const {
- return startsWith( m_comparator.adjustString( source ), m_comparator.m_str );
+ // If we got here, this is in fact a valid(ish) utf-8 sequence
+ for (std::size_t n = 0; n < encBytes; ++n) {
+ os << m_str[idx + n];
+ }
+ idx += encBytes - 1;
+ break;
+ }
}
+ }
- EndsWithMatcher::EndsWithMatcher( CasedString const& comparator ) : StringMatcherBase( "ends with", comparator ) {}
-
- bool EndsWithMatcher::match( std::string const& source ) const {
- return endsWith( m_comparator.adjustString( source ), m_comparator.m_str );
- }
+ std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ) {
+ xmlEncode.encodeTo( os );
+ return os;
+ }
- } // namespace StdString
+ XmlWriter::ScopedElement::ScopedElement( XmlWriter* writer )
+ : m_writer( writer )
+ {}
- StdString::EqualsMatcher Equals( std::string const& str, CaseSensitive::Choice caseSensitivity ) {
- return StdString::EqualsMatcher( StdString::CasedString( str, caseSensitivity) );
- }
- StdString::ContainsMatcher Contains( std::string const& str, CaseSensitive::Choice caseSensitivity ) {
- return StdString::ContainsMatcher( StdString::CasedString( str, caseSensitivity) );
- }
- StdString::EndsWithMatcher EndsWith( std::string const& str, CaseSensitive::Choice caseSensitivity ) {
- return StdString::EndsWithMatcher( StdString::CasedString( str, caseSensitivity) );
+ XmlWriter::ScopedElement::ScopedElement( ScopedElement&& other ) noexcept
+ : m_writer( other.m_writer ){
+ other.m_writer = nullptr;
}
- StdString::StartsWithMatcher StartsWith( std::string const& str, CaseSensitive::Choice caseSensitivity ) {
- return StdString::StartsWithMatcher( StdString::CasedString( str, caseSensitivity) );
+ XmlWriter::ScopedElement& XmlWriter::ScopedElement::operator=( ScopedElement&& other ) noexcept {
+ if ( m_writer ) {
+ m_writer->endElement();
+ }
+ m_writer = other.m_writer;
+ other.m_writer = nullptr;
+ return *this;
}
-} // namespace Matchers
-} // namespace Catch
-// #included from: ../reporters/catch_reporter_multi.hpp
-#define TWOBLUECUBES_CATCH_REPORTER_MULTI_HPP_INCLUDED
-
-namespace Catch {
-
-class MultipleReporters : public SharedImpl<IStreamingReporter> {
- typedef std::vector<Ptr<IStreamingReporter> > Reporters;
- Reporters m_reporters;
-
-public:
- void add( Ptr<IStreamingReporter> const& reporter ) {
- m_reporters.push_back( reporter );
+ XmlWriter::ScopedElement::~ScopedElement() {
+ if( m_writer )
+ m_writer->endElement();
}
-public: // IStreamingReporter
-
- virtual ReporterPreferences getPreferences() const CATCH_OVERRIDE {
- return m_reporters[0]->getPreferences();
+ XmlWriter::ScopedElement& XmlWriter::ScopedElement::writeText( std::string const& text, bool indent ) {
+ m_writer->writeText( text, indent );
+ return *this;
}
- virtual void noMatchingTestCases( std::string const& spec ) CATCH_OVERRIDE {
- for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end();
- it != itEnd;
- ++it )
- (*it)->noMatchingTestCases( spec );
+ XmlWriter::XmlWriter( std::ostream& os ) : m_os( os )
+ {
+ writeDeclaration();
}
- virtual void testRunStarting( TestRunInfo const& testRunInfo ) CATCH_OVERRIDE {
- for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end();
- it != itEnd;
- ++it )
- (*it)->testRunStarting( testRunInfo );
+ XmlWriter::~XmlWriter() {
+ while( !m_tags.empty() )
+ endElement();
}
- virtual void testGroupStarting( GroupInfo const& groupInfo ) CATCH_OVERRIDE {
- for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end();
- it != itEnd;
- ++it )
- (*it)->testGroupStarting( groupInfo );
+ XmlWriter& XmlWriter::startElement( std::string const& name ) {
+ ensureTagClosed();
+ newlineIfNecessary();
+ m_os << m_indent << '<' << name;
+ m_tags.push_back( name );
+ m_indent += " ";
+ m_tagIsOpen = true;
+ return *this;
}
- virtual void testCaseStarting( TestCaseInfo const& testInfo ) CATCH_OVERRIDE {
- for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end();
- it != itEnd;
- ++it )
- (*it)->testCaseStarting( testInfo );
+ XmlWriter::ScopedElement XmlWriter::scopedElement( std::string const& name ) {
+ ScopedElement scoped( this );
+ startElement( name );
+ return scoped;
}
- virtual void sectionStarting( SectionInfo const& sectionInfo ) CATCH_OVERRIDE {
- for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end();
- it != itEnd;
- ++it )
- (*it)->sectionStarting( sectionInfo );
+ XmlWriter& XmlWriter::endElement() {
+ newlineIfNecessary();
+ m_indent = m_indent.substr( 0, m_indent.size()-2 );
+ if( m_tagIsOpen ) {
+ m_os << "/>";
+ m_tagIsOpen = false;
+ }
+ else {
+ m_os << m_indent << "</" << m_tags.back() << ">";
+ }
+ m_os << std::endl;
+ m_tags.pop_back();
+ return *this;
}
- virtual void assertionStarting( AssertionInfo const& assertionInfo ) CATCH_OVERRIDE {
- for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end();
- it != itEnd;
- ++it )
- (*it)->assertionStarting( assertionInfo );
+ XmlWriter& XmlWriter::writeAttribute( std::string const& name, std::string const& attribute ) {
+ if( !name.empty() && !attribute.empty() )
+ m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"';
+ return *this;
}
- // The return value indicates if the messages buffer should be cleared:
- virtual bool assertionEnded( AssertionStats const& assertionStats ) CATCH_OVERRIDE {
- bool clearBuffer = false;
- for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end();
- it != itEnd;
- ++it )
- clearBuffer |= (*it)->assertionEnded( assertionStats );
- return clearBuffer;
+ XmlWriter& XmlWriter::writeAttribute( std::string const& name, bool attribute ) {
+ m_os << ' ' << name << "=\"" << ( attribute ? "true" : "false" ) << '"';
+ return *this;
}
- virtual void sectionEnded( SectionStats const& sectionStats ) CATCH_OVERRIDE {
- for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end();
- it != itEnd;
- ++it )
- (*it)->sectionEnded( sectionStats );
+ XmlWriter& XmlWriter::writeText( std::string const& text, bool indent ) {
+ if( !text.empty() ){
+ bool tagWasOpen = m_tagIsOpen;
+ ensureTagClosed();
+ if( tagWasOpen && indent )
+ m_os << m_indent;
+ m_os << XmlEncode( text );
+ m_needsNewline = true;
+ }
+ return *this;
}
- virtual void testCaseEnded( TestCaseStats const& testCaseStats ) CATCH_OVERRIDE {
- for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end();
- it != itEnd;
- ++it )
- (*it)->testCaseEnded( testCaseStats );
+ XmlWriter& XmlWriter::writeComment( std::string const& text ) {
+ ensureTagClosed();
+ m_os << m_indent << "<!--" << text << "-->";
+ m_needsNewline = true;
+ return *this;
}
- virtual void testGroupEnded( TestGroupStats const& testGroupStats ) CATCH_OVERRIDE {
- for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end();
- it != itEnd;
- ++it )
- (*it)->testGroupEnded( testGroupStats );
+ void XmlWriter::writeStylesheetRef( std::string const& url ) {
+ m_os << "<?xml-stylesheet type=\"text/xsl\" href=\"" << url << "\"?>\n";
}
- virtual void testRunEnded( TestRunStats const& testRunStats ) CATCH_OVERRIDE {
- for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end();
- it != itEnd;
- ++it )
- (*it)->testRunEnded( testRunStats );
+ XmlWriter& XmlWriter::writeBlankLine() {
+ ensureTagClosed();
+ m_os << '\n';
+ return *this;
}
- virtual void skipTest( TestCaseInfo const& testInfo ) CATCH_OVERRIDE {
- for( Reporters::const_iterator it = m_reporters.begin(), itEnd = m_reporters.end();
- it != itEnd;
- ++it )
- (*it)->skipTest( testInfo );
+ void XmlWriter::ensureTagClosed() {
+ if( m_tagIsOpen ) {
+ m_os << ">" << std::endl;
+ m_tagIsOpen = false;
+ }
}
- virtual MultipleReporters* tryAsMulti() CATCH_OVERRIDE {
- return this;
+ void XmlWriter::writeDeclaration() {
+ m_os << "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n";
}
-};
-
-Ptr<IStreamingReporter> addReporter( Ptr<IStreamingReporter> const& existingReporter, Ptr<IStreamingReporter> const& additionalReporter ) {
- Ptr<IStreamingReporter> resultingReporter;
-
- if( existingReporter ) {
- MultipleReporters* multi = existingReporter->tryAsMulti();
- if( !multi ) {
- multi = new MultipleReporters;
- resultingReporter = Ptr<IStreamingReporter>( multi );
- if( existingReporter )
- multi->add( existingReporter );
+ void XmlWriter::newlineIfNecessary() {
+ if( m_needsNewline ) {
+ m_os << std::endl;
+ m_needsNewline = false;
}
- else
- resultingReporter = existingReporter;
- multi->add( additionalReporter );
}
- else
- resultingReporter = additionalReporter;
-
- return resultingReporter;
}
-
-} // end namespace Catch
-
-// #included from: ../reporters/catch_reporter_xml.hpp
-#define TWOBLUECUBES_CATCH_REPORTER_XML_HPP_INCLUDED
-
-// #included from: catch_reporter_bases.hpp
-#define TWOBLUECUBES_CATCH_REPORTER_BASES_HPP_INCLUDED
+// end catch_xmlwriter.cpp
+// start catch_reporter_bases.cpp
#include <cstring>
#include <cfloat>
#include <cstdio>
-#include <assert.h>
+#include <cassert>
+#include <memory>
namespace Catch {
+ void prepareExpandedExpression(AssertionResult& result) {
+ result.getExpandedExpression();
+ }
- namespace {
- // Because formatting using c++ streams is stateful, drop down to C is required
- // Alternatively we could use stringstream, but its performance is... not good.
- std::string getFormattedDuration( double duration ) {
- // Max exponent + 1 is required to represent the whole part
- // + 1 for decimal point
- // + 3 for the 3 decimal places
- // + 1 for null terminator
- const size_t maxDoubleSize = DBL_MAX_10_EXP + 1 + 1 + 3 + 1;
- char buffer[maxDoubleSize];
-
- // Save previous errno, to prevent sprintf from overwriting it
- ErrnoGuard guard;
+ // Because formatting using c++ streams is stateful, drop down to C is required
+ // Alternatively we could use stringstream, but its performance is... not good.
+ std::string getFormattedDuration( double duration ) {
+ // Max exponent + 1 is required to represent the whole part
+ // + 1 for decimal point
+ // + 3 for the 3 decimal places
+ // + 1 for null terminator
+ const std::size_t maxDoubleSize = DBL_MAX_10_EXP + 1 + 1 + 3 + 1;
+ char buffer[maxDoubleSize];
+
+ // Save previous errno, to prevent sprintf from overwriting it
+ ErrnoGuard guard;
#ifdef _MSC_VER
- sprintf_s(buffer, "%.3f", duration);
+ sprintf_s(buffer, "%.3f", duration);
#else
- sprintf(buffer, "%.3f", duration);
+ std::sprintf(buffer, "%.3f", duration);
#endif
- return std::string(buffer);
- }
+ return std::string(buffer);
}
- struct StreamingReporterBase : SharedImpl<IStreamingReporter> {
-
- StreamingReporterBase( ReporterConfig const& _config )
- : m_config( _config.fullConfig() ),
- stream( _config.stream() )
+ std::string serializeFilters( std::vector<std::string> const& container ) {
+ ReusableStringStream oss;
+ bool first = true;
+ for (auto&& filter : container)
{
- m_reporterPrefs.shouldRedirectStdOut = false;
- }
+ if (!first)
+ oss << ' ';
+ else
+ first = false;
- virtual ReporterPreferences getPreferences() const CATCH_OVERRIDE {
- return m_reporterPrefs;
+ oss << filter;
}
+ return oss.str();
+ }
- virtual ~StreamingReporterBase() CATCH_OVERRIDE;
+ TestEventListenerBase::TestEventListenerBase(ReporterConfig const & _config)
+ :StreamingReporterBase(_config) {}
- virtual void noMatchingTestCases( std::string const& ) CATCH_OVERRIDE {}
+ std::set<Verbosity> TestEventListenerBase::getSupportedVerbosities() {
+ return { Verbosity::Quiet, Verbosity::Normal, Verbosity::High };
+ }
- virtual void testRunStarting( TestRunInfo const& _testRunInfo ) CATCH_OVERRIDE {
- currentTestRunInfo = _testRunInfo;
- }
- virtual void testGroupStarting( GroupInfo const& _groupInfo ) CATCH_OVERRIDE {
- currentGroupInfo = _groupInfo;
- }
+ void TestEventListenerBase::assertionStarting(AssertionInfo const &) {}
- virtual void testCaseStarting( TestCaseInfo const& _testInfo ) CATCH_OVERRIDE {
- currentTestCaseInfo = _testInfo;
- }
- virtual void sectionStarting( SectionInfo const& _sectionInfo ) CATCH_OVERRIDE {
- m_sectionStack.push_back( _sectionInfo );
- }
+ bool TestEventListenerBase::assertionEnded(AssertionStats const &) {
+ return false;
+ }
- virtual void sectionEnded( SectionStats const& /* _sectionStats */ ) CATCH_OVERRIDE {
- m_sectionStack.pop_back();
- }
- virtual void testCaseEnded( TestCaseStats const& /* _testCaseStats */ ) CATCH_OVERRIDE {
- currentTestCaseInfo.reset();
- }
- virtual void testGroupEnded( TestGroupStats const& /* _testGroupStats */ ) CATCH_OVERRIDE {
- currentGroupInfo.reset();
- }
- virtual void testRunEnded( TestRunStats const& /* _testRunStats */ ) CATCH_OVERRIDE {
- currentTestCaseInfo.reset();
- currentGroupInfo.reset();
- currentTestRunInfo.reset();
- }
+} // end namespace Catch
+// end catch_reporter_bases.cpp
+// start catch_reporter_compact.cpp
- virtual void skipTest( TestCaseInfo const& ) CATCH_OVERRIDE {
- // Don't do anything with this by default.
- // It can optionally be overridden in the derived class.
- }
+namespace {
- Ptr<IConfig const> m_config;
- std::ostream& stream;
+#ifdef CATCH_PLATFORM_MAC
+ const char* failedString() { return "FAILED"; }
+ const char* passedString() { return "PASSED"; }
+#else
+ const char* failedString() { return "failed"; }
+ const char* passedString() { return "passed"; }
+#endif
- LazyStat<TestRunInfo> currentTestRunInfo;
- LazyStat<GroupInfo> currentGroupInfo;
- LazyStat<TestCaseInfo> currentTestCaseInfo;
+ // Colour::LightGrey
+ Catch::Colour::Code dimColour() { return Catch::Colour::FileName; }
- std::vector<SectionInfo> m_sectionStack;
- ReporterPreferences m_reporterPrefs;
- };
+ std::string bothOrAll( std::size_t count ) {
+ return count == 1 ? std::string() :
+ count == 2 ? "both " : "all " ;
+ }
- struct CumulativeReporterBase : SharedImpl<IStreamingReporter> {
- template<typename T, typename ChildNodeT>
- struct Node : SharedImpl<> {
- explicit Node( T const& _value ) : value( _value ) {}
- virtual ~Node() {}
+} // anon namespace
- typedef std::vector<Ptr<ChildNodeT> > ChildNodes;
- T value;
- ChildNodes children;
- };
- struct SectionNode : SharedImpl<> {
- explicit SectionNode( SectionStats const& _stats ) : stats( _stats ) {}
- virtual ~SectionNode();
+namespace Catch {
+namespace {
+// Colour, message variants:
+// - white: No tests ran.
+// - red: Failed [both/all] N test cases, failed [both/all] M assertions.
+// - white: Passed [both/all] N test cases (no assertions).
+// - red: Failed N tests cases, failed M assertions.
+// - green: Passed [both/all] N tests cases with M assertions.
+void printTotals(std::ostream& out, const Totals& totals) {
+ if (totals.testCases.total() == 0) {
+ out << "No tests ran.";
+ } else if (totals.testCases.failed == totals.testCases.total()) {
+ Colour colour(Colour::ResultError);
+ const std::string qualify_assertions_failed =
+ totals.assertions.failed == totals.assertions.total() ?
+ bothOrAll(totals.assertions.failed) : std::string();
+ out <<
+ "Failed " << bothOrAll(totals.testCases.failed)
+ << pluralise(totals.testCases.failed, "test case") << ", "
+ "failed " << qualify_assertions_failed <<
+ pluralise(totals.assertions.failed, "assertion") << '.';
+ } else if (totals.assertions.total() == 0) {
+ out <<
+ "Passed " << bothOrAll(totals.testCases.total())
+ << pluralise(totals.testCases.total(), "test case")
+ << " (no assertions).";
+ } else if (totals.assertions.failed) {
+ Colour colour(Colour::ResultError);
+ out <<
+ "Failed " << pluralise(totals.testCases.failed, "test case") << ", "
+ "failed " << pluralise(totals.assertions.failed, "assertion") << '.';
+ } else {
+ Colour colour(Colour::ResultSuccess);
+ out <<
+ "Passed " << bothOrAll(totals.testCases.passed)
+ << pluralise(totals.testCases.passed, "test case") <<
+ " with " << pluralise(totals.assertions.passed, "assertion") << '.';
+ }
+}
- bool operator == ( SectionNode const& other ) const {
- return stats.sectionInfo.lineInfo == other.stats.sectionInfo.lineInfo;
- }
- bool operator == ( Ptr<SectionNode> const& other ) const {
- return operator==( *other );
- }
+// Implementation of CompactReporter formatting
+class AssertionPrinter {
+public:
+ AssertionPrinter& operator= (AssertionPrinter const&) = delete;
+ AssertionPrinter(AssertionPrinter const&) = delete;
+ AssertionPrinter(std::ostream& _stream, AssertionStats const& _stats, bool _printInfoMessages)
+ : stream(_stream)
+ , result(_stats.assertionResult)
+ , messages(_stats.infoMessages)
+ , itMessage(_stats.infoMessages.begin())
+ , printInfoMessages(_printInfoMessages) {}
+
+ void print() {
+ printSourceInfo();
+
+ itMessage = messages.begin();
+
+ switch (result.getResultType()) {
+ case ResultWas::Ok:
+ printResultType(Colour::ResultSuccess, passedString());
+ printOriginalExpression();
+ printReconstructedExpression();
+ if (!result.hasExpression())
+ printRemainingMessages(Colour::None);
+ else
+ printRemainingMessages();
+ break;
+ case ResultWas::ExpressionFailed:
+ if (result.isOk())
+ printResultType(Colour::ResultSuccess, failedString() + std::string(" - but was ok"));
+ else
+ printResultType(Colour::Error, failedString());
+ printOriginalExpression();
+ printReconstructedExpression();
+ printRemainingMessages();
+ break;
+ case ResultWas::ThrewException:
+ printResultType(Colour::Error, failedString());
+ printIssue("unexpected exception with message:");
+ printMessage();
+ printExpressionWas();
+ printRemainingMessages();
+ break;
+ case ResultWas::FatalErrorCondition:
+ printResultType(Colour::Error, failedString());
+ printIssue("fatal error condition with message:");
+ printMessage();
+ printExpressionWas();
+ printRemainingMessages();
+ break;
+ case ResultWas::DidntThrowException:
+ printResultType(Colour::Error, failedString());
+ printIssue("expected exception, got none");
+ printExpressionWas();
+ printRemainingMessages();
+ break;
+ case ResultWas::Info:
+ printResultType(Colour::None, "info");
+ printMessage();
+ printRemainingMessages();
+ break;
+ case ResultWas::Warning:
+ printResultType(Colour::None, "warning");
+ printMessage();
+ printRemainingMessages();
+ break;
+ case ResultWas::ExplicitFailure:
+ printResultType(Colour::Error, failedString());
+ printIssue("explicitly");
+ printRemainingMessages(Colour::None);
+ break;
+ // These cases are here to prevent compiler warnings
+ case ResultWas::Unknown:
+ case ResultWas::FailureBit:
+ case ResultWas::Exception:
+ printResultType(Colour::Error, "** internal error **");
+ break;
+ }
+ }
- SectionStats stats;
- typedef std::vector<Ptr<SectionNode> > ChildSections;
- typedef std::vector<AssertionStats> Assertions;
- ChildSections childSections;
- Assertions assertions;
- std::string stdOut;
- std::string stdErr;
- };
+private:
+ void printSourceInfo() const {
+ Colour colourGuard(Colour::FileName);
+ stream << result.getSourceInfo() << ':';
+ }
- struct BySectionInfo {
- BySectionInfo( SectionInfo const& other ) : m_other( other ) {}
- BySectionInfo( BySectionInfo const& other ) : m_other( other.m_other ) {}
- bool operator() ( Ptr<SectionNode> const& node ) const {
- return node->stats.sectionInfo.lineInfo == m_other.lineInfo;
+ void printResultType(Colour::Code colour, std::string const& passOrFail) const {
+ if (!passOrFail.empty()) {
+ {
+ Colour colourGuard(colour);
+ stream << ' ' << passOrFail;
}
- private:
- void operator=( BySectionInfo const& );
- SectionInfo const& m_other;
- };
-
- typedef Node<TestCaseStats, SectionNode> TestCaseNode;
- typedef Node<TestGroupStats, TestCaseNode> TestGroupNode;
- typedef Node<TestRunStats, TestGroupNode> TestRunNode;
-
- CumulativeReporterBase( ReporterConfig const& _config )
- : m_config( _config.fullConfig() ),
- stream( _config.stream() )
- {
- m_reporterPrefs.shouldRedirectStdOut = false;
- }
- ~CumulativeReporterBase();
-
- virtual ReporterPreferences getPreferences() const CATCH_OVERRIDE {
- return m_reporterPrefs;
+ stream << ':';
}
+ }
- virtual void testRunStarting( TestRunInfo const& ) CATCH_OVERRIDE {}
- virtual void testGroupStarting( GroupInfo const& ) CATCH_OVERRIDE {}
-
- virtual void testCaseStarting( TestCaseInfo const& ) CATCH_OVERRIDE {}
+ void printIssue(std::string const& issue) const {
+ stream << ' ' << issue;
+ }
- virtual void sectionStarting( SectionInfo const& sectionInfo ) CATCH_OVERRIDE {
- SectionStats incompleteStats( sectionInfo, Counts(), 0, false );
- Ptr<SectionNode> node;
- if( m_sectionStack.empty() ) {
- if( !m_rootSection )
- m_rootSection = new SectionNode( incompleteStats );
- node = m_rootSection;
- }
- else {
- SectionNode& parentNode = *m_sectionStack.back();
- SectionNode::ChildSections::const_iterator it =
- std::find_if( parentNode.childSections.begin(),
- parentNode.childSections.end(),
- BySectionInfo( sectionInfo ) );
- if( it == parentNode.childSections.end() ) {
- node = new SectionNode( incompleteStats );
- parentNode.childSections.push_back( node );
- }
- else
- node = *it;
+ void printExpressionWas() {
+ if (result.hasExpression()) {
+ stream << ';';
+ {
+ Colour colour(dimColour());
+ stream << " expression was:";
}
- m_sectionStack.push_back( node );
- m_deepestSection = node;
+ printOriginalExpression();
}
+ }
- virtual void assertionStarting( AssertionInfo const& ) CATCH_OVERRIDE {}
-
- virtual bool assertionEnded( AssertionStats const& assertionStats ) CATCH_OVERRIDE {
- assert( !m_sectionStack.empty() );
- SectionNode& sectionNode = *m_sectionStack.back();
- sectionNode.assertions.push_back( assertionStats );
- // AssertionResult holds a pointer to a temporary DecomposedExpression,
- // which getExpandedExpression() calls to build the expression string.
- // Our section stack copy of the assertionResult will likely outlive the
- // temporary, so it must be expanded or discarded now to avoid calling
- // a destroyed object later.
- prepareExpandedExpression( sectionNode.assertions.back().assertionResult );
- return true;
+ void printOriginalExpression() const {
+ if (result.hasExpression()) {
+ stream << ' ' << result.getExpression();
}
- virtual void sectionEnded( SectionStats const& sectionStats ) CATCH_OVERRIDE {
- assert( !m_sectionStack.empty() );
- SectionNode& node = *m_sectionStack.back();
- node.stats = sectionStats;
- m_sectionStack.pop_back();
- }
- virtual void testCaseEnded( TestCaseStats const& testCaseStats ) CATCH_OVERRIDE {
- Ptr<TestCaseNode> node = new TestCaseNode( testCaseStats );
- assert( m_sectionStack.size() == 0 );
- node->children.push_back( m_rootSection );
- m_testCases.push_back( node );
- m_rootSection.reset();
+ }
- assert( m_deepestSection );
- m_deepestSection->stdOut = testCaseStats.stdOut;
- m_deepestSection->stdErr = testCaseStats.stdErr;
- }
- virtual void testGroupEnded( TestGroupStats const& testGroupStats ) CATCH_OVERRIDE {
- Ptr<TestGroupNode> node = new TestGroupNode( testGroupStats );
- node->children.swap( m_testCases );
- m_testGroups.push_back( node );
- }
- virtual void testRunEnded( TestRunStats const& testRunStats ) CATCH_OVERRIDE {
- Ptr<TestRunNode> node = new TestRunNode( testRunStats );
- node->children.swap( m_testGroups );
- m_testRuns.push_back( node );
- testRunEndedCumulative();
+ void printReconstructedExpression() const {
+ if (result.hasExpandedExpression()) {
+ {
+ Colour colour(dimColour());
+ stream << " for: ";
+ }
+ stream << result.getExpandedExpression();
}
- virtual void testRunEndedCumulative() = 0;
-
- virtual void skipTest( TestCaseInfo const& ) CATCH_OVERRIDE {}
+ }
- virtual void prepareExpandedExpression( AssertionResult& result ) const {
- if( result.isOk() )
- result.discardDecomposedExpression();
- else
- result.expandDecomposedExpression();
+ void printMessage() {
+ if (itMessage != messages.end()) {
+ stream << " '" << itMessage->message << '\'';
+ ++itMessage;
}
+ }
- Ptr<IConfig const> m_config;
- std::ostream& stream;
- std::vector<AssertionStats> m_assertions;
- std::vector<std::vector<Ptr<SectionNode> > > m_sections;
- std::vector<Ptr<TestCaseNode> > m_testCases;
- std::vector<Ptr<TestGroupNode> > m_testGroups;
-
- std::vector<Ptr<TestRunNode> > m_testRuns;
-
- Ptr<SectionNode> m_rootSection;
- Ptr<SectionNode> m_deepestSection;
- std::vector<Ptr<SectionNode> > m_sectionStack;
- ReporterPreferences m_reporterPrefs;
+ void printRemainingMessages(Colour::Code colour = dimColour()) {
+ if (itMessage == messages.end())
+ return;
- };
+ const auto itEnd = messages.cend();
+ const auto N = static_cast<std::size_t>(std::distance(itMessage, itEnd));
- template<char C>
- char const* getLineOfChars() {
- static char line[CATCH_CONFIG_CONSOLE_WIDTH] = {0};
- if( !*line ) {
- std::memset( line, C, CATCH_CONFIG_CONSOLE_WIDTH-1 );
- line[CATCH_CONFIG_CONSOLE_WIDTH-1] = 0;
+ {
+ Colour colourGuard(colour);
+ stream << " with " << pluralise(N, "message") << ':';
}
- return line;
- }
- struct TestEventListenerBase : StreamingReporterBase {
- TestEventListenerBase( ReporterConfig const& _config )
- : StreamingReporterBase( _config )
- {}
-
- virtual void assertionStarting( AssertionInfo const& ) CATCH_OVERRIDE {}
- virtual bool assertionEnded( AssertionStats const& ) CATCH_OVERRIDE {
- return false;
+ while (itMessage != itEnd) {
+ // If this assertion is a warning ignore any INFO messages
+ if (printInfoMessages || itMessage->type != ResultWas::Info) {
+ printMessage();
+ if (itMessage != itEnd) {
+ Colour colourGuard(dimColour());
+ stream << " and";
+ }
+ continue;
+ }
+ ++itMessage;
}
- };
-
-} // end namespace Catch
-
-// #included from: ../internal/catch_reporter_registrars.hpp
-#define TWOBLUECUBES_CATCH_REPORTER_REGISTRARS_HPP_INCLUDED
-
-namespace Catch {
+ }
- template<typename T>
- class LegacyReporterRegistrar {
+private:
+ std::ostream& stream;
+ AssertionResult const& result;
+ std::vector<MessageInfo> messages;
+ std::vector<MessageInfo>::const_iterator itMessage;
+ bool printInfoMessages;
+};
- class ReporterFactory : public IReporterFactory {
- virtual IStreamingReporter* create( ReporterConfig const& config ) const {
- return new LegacyReporterAdapter( new T( config ) );
- }
+} // anon namespace
- virtual std::string getDescription() const {
- return T::getDescription();
- }
- };
+ std::string CompactReporter::getDescription() {
+ return "Reports test results on a single line, suitable for IDEs";
+ }
- public:
+ ReporterPreferences CompactReporter::getPreferences() const {
+ return m_reporterPrefs;
+ }
- LegacyReporterRegistrar( std::string const& name ) {
- getMutableRegistryHub().registerReporter( name, new ReporterFactory() );
+ void CompactReporter::noMatchingTestCases( std::string const& spec ) {
+ stream << "No test cases matched '" << spec << '\'' << std::endl;
}
- };
- template<typename T>
- class ReporterRegistrar {
+ void CompactReporter::assertionStarting( AssertionInfo const& ) {}
- class ReporterFactory : public SharedImpl<IReporterFactory> {
+ bool CompactReporter::assertionEnded( AssertionStats const& _assertionStats ) {
+ AssertionResult const& result = _assertionStats.assertionResult;
- // *** Please Note ***:
- // - If you end up here looking at a compiler error because it's trying to register
- // your custom reporter class be aware that the native reporter interface has changed
- // to IStreamingReporter. The "legacy" interface, IReporter, is still supported via
- // an adapter. Just use REGISTER_LEGACY_REPORTER to take advantage of the adapter.
- // However please consider updating to the new interface as the old one is now
- // deprecated and will probably be removed quite soon!
- // Please contact me via github if you have any questions at all about this.
- // In fact, ideally, please contact me anyway to let me know you've hit this - as I have
- // no idea who is actually using custom reporters at all (possibly no-one!).
- // The new interface is designed to minimise exposure to interface changes in the future.
- virtual IStreamingReporter* create( ReporterConfig const& config ) const {
- return new T( config );
- }
+ bool printInfoMessages = true;
- virtual std::string getDescription() const {
- return T::getDescription();
+ // Drop out if result was successful and we're not printing those
+ if( !m_config->includeSuccessfulResults() && result.isOk() ) {
+ if( result.getResultType() != ResultWas::Warning )
+ return false;
+ printInfoMessages = false;
}
- };
- public:
+ AssertionPrinter printer( stream, _assertionStats, printInfoMessages );
+ printer.print();
- ReporterRegistrar( std::string const& name ) {
- getMutableRegistryHub().registerReporter( name, new ReporterFactory() );
+ stream << std::endl;
+ return true;
}
- };
- template<typename T>
- class ListenerRegistrar {
-
- class ListenerFactory : public SharedImpl<IReporterFactory> {
-
- virtual IStreamingReporter* create( ReporterConfig const& config ) const {
- return new T( config );
- }
- virtual std::string getDescription() const {
- return std::string();
+ void CompactReporter::sectionEnded(SectionStats const& _sectionStats) {
+ if (m_config->showDurations() == ShowDurations::Always) {
+ stream << getFormattedDuration(_sectionStats.durationInSeconds) << " s: " << _sectionStats.sectionInfo.name << std::endl;
}
- };
-
- public:
+ }
- ListenerRegistrar() {
- getMutableRegistryHub().registerListener( new ListenerFactory() );
+ void CompactReporter::testRunEnded( TestRunStats const& _testRunStats ) {
+ printTotals( stream, _testRunStats.totals );
+ stream << '\n' << std::endl;
+ StreamingReporterBase::testRunEnded( _testRunStats );
}
- };
-}
-#define INTERNAL_CATCH_REGISTER_LEGACY_REPORTER( name, reporterType ) \
- namespace{ Catch::LegacyReporterRegistrar<reporterType> catch_internal_RegistrarFor##reporterType( name ); }
+ CompactReporter::~CompactReporter() {}
-#define INTERNAL_CATCH_REGISTER_REPORTER( name, reporterType ) \
- namespace{ Catch::ReporterRegistrar<reporterType> catch_internal_RegistrarFor##reporterType( name ); }
+ CATCH_REGISTER_REPORTER( "compact", CompactReporter )
-// Deprecated - use the form without INTERNAL_
-#define INTERNAL_CATCH_REGISTER_LISTENER( listenerType ) \
- namespace{ Catch::ListenerRegistrar<listenerType> catch_internal_RegistrarFor##listenerType; }
+} // end namespace Catch
+// end catch_reporter_compact.cpp
+// start catch_reporter_console.cpp
-#define CATCH_REGISTER_LISTENER( listenerType ) \
- namespace{ Catch::ListenerRegistrar<listenerType> catch_internal_RegistrarFor##listenerType; }
+#include <cfloat>
+#include <cstdio>
-// #included from: ../internal/catch_xmlwriter.hpp
-#define TWOBLUECUBES_CATCH_XMLWRITER_HPP_INCLUDED
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable:4061) // Not all labels are EXPLICITLY handled in switch
+ // Note that 4062 (not all labels are handled and default is missing) is enabled
+#endif
-#include <sstream>
-#include <string>
-#include <vector>
-#include <iomanip>
+#if defined(__clang__)
+# pragma clang diagnostic push
+// For simplicity, benchmarking-only helpers are always enabled
+# pragma clang diagnostic ignored "-Wunused-function"
+#endif
namespace Catch {
- class XmlEncode {
- public:
- enum ForWhat { ForTextNodes, ForAttributes };
-
- XmlEncode( std::string const& str, ForWhat forWhat = ForTextNodes )
- : m_str( str ),
- m_forWhat( forWhat )
- {}
+namespace {
- void encodeTo( std::ostream& os ) const {
-
- // Apostrophe escaping not necessary if we always use " to write attributes
- // (see: http://www.w3.org/TR/xml/#syntax)
-
- for( std::size_t i = 0; i < m_str.size(); ++ i ) {
- char c = m_str[i];
- switch( c ) {
- case '<': os << "&lt;"; break;
- case '&': os << "&amp;"; break;
-
- case '>':
- // See: http://www.w3.org/TR/xml/#syntax
- if( i > 2 && m_str[i-1] == ']' && m_str[i-2] == ']' )
- os << "&gt;";
- else
- os << c;
- break;
-
- case '\"':
- if( m_forWhat == ForAttributes )
- os << "&quot;";
- else
- os << c;
- break;
-
- default:
- // Escape control chars - based on contribution by @espenalb in PR #465 and
- // by @mrpi PR #588
- if ( ( c >= 0 && c < '\x09' ) || ( c > '\x0D' && c < '\x20') || c=='\x7F' ) {
- // see http://stackoverflow.com/questions/404107/why-are-control-characters-illegal-in-xml-1-0
- os << "\\x" << std::uppercase << std::hex << std::setfill('0') << std::setw(2)
- << static_cast<int>( c );
- }
- else
- os << c;
- }
+// Formatter impl for ConsoleReporter
+class ConsoleAssertionPrinter {
+public:
+ ConsoleAssertionPrinter& operator= (ConsoleAssertionPrinter const&) = delete;
+ ConsoleAssertionPrinter(ConsoleAssertionPrinter const&) = delete;
+ ConsoleAssertionPrinter(std::ostream& _stream, AssertionStats const& _stats, bool _printInfoMessages)
+ : stream(_stream),
+ stats(_stats),
+ result(_stats.assertionResult),
+ colour(Colour::None),
+ message(result.getMessage()),
+ messages(_stats.infoMessages),
+ printInfoMessages(_printInfoMessages) {
+ switch (result.getResultType()) {
+ case ResultWas::Ok:
+ colour = Colour::Success;
+ passOrFail = "PASSED";
+ //if( result.hasMessage() )
+ if (_stats.infoMessages.size() == 1)
+ messageLabel = "with message";
+ if (_stats.infoMessages.size() > 1)
+ messageLabel = "with messages";
+ break;
+ case ResultWas::ExpressionFailed:
+ if (result.isOk()) {
+ colour = Colour::Success;
+ passOrFail = "FAILED - but was ok";
+ } else {
+ colour = Colour::Error;
+ passOrFail = "FAILED";
}
+ if (_stats.infoMessages.size() == 1)
+ messageLabel = "with message";
+ if (_stats.infoMessages.size() > 1)
+ messageLabel = "with messages";
+ break;
+ case ResultWas::ThrewException:
+ colour = Colour::Error;
+ passOrFail = "FAILED";
+ messageLabel = "due to unexpected exception with ";
+ if (_stats.infoMessages.size() == 1)
+ messageLabel += "message";
+ if (_stats.infoMessages.size() > 1)
+ messageLabel += "messages";
+ break;
+ case ResultWas::FatalErrorCondition:
+ colour = Colour::Error;
+ passOrFail = "FAILED";
+ messageLabel = "due to a fatal error condition";
+ break;
+ case ResultWas::DidntThrowException:
+ colour = Colour::Error;
+ passOrFail = "FAILED";
+ messageLabel = "because no exception was thrown where one was expected";
+ break;
+ case ResultWas::Info:
+ messageLabel = "info";
+ break;
+ case ResultWas::Warning:
+ messageLabel = "warning";
+ break;
+ case ResultWas::ExplicitFailure:
+ passOrFail = "FAILED";
+ colour = Colour::Error;
+ if (_stats.infoMessages.size() == 1)
+ messageLabel = "explicitly with message";
+ if (_stats.infoMessages.size() > 1)
+ messageLabel = "explicitly with messages";
+ break;
+ // These cases are here to prevent compiler warnings
+ case ResultWas::Unknown:
+ case ResultWas::FailureBit:
+ case ResultWas::Exception:
+ passOrFail = "** internal error **";
+ colour = Colour::Error;
+ break;
+ }
+ }
+
+ void print() const {
+ printSourceInfo();
+ if (stats.totals.assertions.total() > 0) {
+ printResultType();
+ printOriginalExpression();
+ printReconstructedExpression();
+ } else {
+ stream << '\n';
}
+ printMessage();
+ }
- friend std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ) {
- xmlEncode.encodeTo( os );
- return os;
+private:
+ void printResultType() const {
+ if (!passOrFail.empty()) {
+ Colour colourGuard(colour);
+ stream << passOrFail << ":\n";
}
+ }
+ void printOriginalExpression() const {
+ if (result.hasExpression()) {
+ Colour colourGuard(Colour::OriginalExpression);
+ stream << " ";
+ stream << result.getExpressionInMacro();
+ stream << '\n';
+ }
+ }
+ void printReconstructedExpression() const {
+ if (result.hasExpandedExpression()) {
+ stream << "with expansion:\n";
+ Colour colourGuard(Colour::ReconstructedExpression);
+ stream << Column(result.getExpandedExpression()).indent(2) << '\n';
+ }
+ }
+ void printMessage() const {
+ if (!messageLabel.empty())
+ stream << messageLabel << ':' << '\n';
+ for (auto const& msg : messages) {
+ // If this assertion is a warning ignore any INFO messages
+ if (printInfoMessages || msg.type != ResultWas::Info)
+ stream << Column(msg.message).indent(2) << '\n';
+ }
+ }
+ void printSourceInfo() const {
+ Colour colourGuard(Colour::FileName);
+ stream << result.getSourceInfo() << ": ";
+ }
- private:
- std::string m_str;
- ForWhat m_forWhat;
- };
-
- class XmlWriter {
- public:
-
- class ScopedElement {
- public:
- ScopedElement( XmlWriter* writer )
- : m_writer( writer )
- {}
+ std::ostream& stream;
+ AssertionStats const& stats;
+ AssertionResult const& result;
+ Colour::Code colour;
+ std::string passOrFail;
+ std::string messageLabel;
+ std::string message;
+ std::vector<MessageInfo> messages;
+ bool printInfoMessages;
+};
- ScopedElement( ScopedElement const& other )
- : m_writer( other.m_writer ){
- other.m_writer = CATCH_NULL;
- }
+std::size_t makeRatio(std::size_t number, std::size_t total) {
+ std::size_t ratio = total > 0 ? CATCH_CONFIG_CONSOLE_WIDTH * number / total : 0;
+ return (ratio == 0 && number > 0) ? 1 : ratio;
+}
- ~ScopedElement() {
- if( m_writer )
- m_writer->endElement();
- }
+std::size_t& findMax(std::size_t& i, std::size_t& j, std::size_t& k) {
+ if (i > j && i > k)
+ return i;
+ else if (j > k)
+ return j;
+ else
+ return k;
+}
- ScopedElement& writeText( std::string const& text, bool indent = true ) {
- m_writer->writeText( text, indent );
- return *this;
- }
+struct ColumnInfo {
+ enum Justification { Left, Right };
+ std::string name;
+ int width;
+ Justification justification;
+};
+struct ColumnBreak {};
+struct RowBreak {};
- template<typename T>
- ScopedElement& writeAttribute( std::string const& name, T const& attribute ) {
- m_writer->writeAttribute( name, attribute );
- return *this;
- }
+class Duration {
+ enum class Unit {
+ Auto,
+ Nanoseconds,
+ Microseconds,
+ Milliseconds,
+ Seconds,
+ Minutes
+ };
+ static const uint64_t s_nanosecondsInAMicrosecond = 1000;
+ static const uint64_t s_nanosecondsInAMillisecond = 1000 * s_nanosecondsInAMicrosecond;
+ static const uint64_t s_nanosecondsInASecond = 1000 * s_nanosecondsInAMillisecond;
+ static const uint64_t s_nanosecondsInAMinute = 60 * s_nanosecondsInASecond;
- private:
- mutable XmlWriter* m_writer;
- };
+ uint64_t m_inNanoseconds;
+ Unit m_units;
- XmlWriter()
- : m_tagIsOpen( false ),
- m_needsNewline( false ),
- m_os( Catch::cout() )
- {
- writeDeclaration();
+public:
+ explicit Duration(double inNanoseconds, Unit units = Unit::Auto)
+ : Duration(static_cast<uint64_t>(inNanoseconds), units) {
+ }
+
+ explicit Duration(uint64_t inNanoseconds, Unit units = Unit::Auto)
+ : m_inNanoseconds(inNanoseconds),
+ m_units(units) {
+ if (m_units == Unit::Auto) {
+ if (m_inNanoseconds < s_nanosecondsInAMicrosecond)
+ m_units = Unit::Nanoseconds;
+ else if (m_inNanoseconds < s_nanosecondsInAMillisecond)
+ m_units = Unit::Microseconds;
+ else if (m_inNanoseconds < s_nanosecondsInASecond)
+ m_units = Unit::Milliseconds;
+ else if (m_inNanoseconds < s_nanosecondsInAMinute)
+ m_units = Unit::Seconds;
+ else
+ m_units = Unit::Minutes;
}
- XmlWriter( std::ostream& os )
- : m_tagIsOpen( false ),
- m_needsNewline( false ),
- m_os( os )
- {
- writeDeclaration();
- }
+ }
- ~XmlWriter() {
- while( !m_tags.empty() )
- endElement();
+ auto value() const -> double {
+ switch (m_units) {
+ case Unit::Microseconds:
+ return m_inNanoseconds / static_cast<double>(s_nanosecondsInAMicrosecond);
+ case Unit::Milliseconds:
+ return m_inNanoseconds / static_cast<double>(s_nanosecondsInAMillisecond);
+ case Unit::Seconds:
+ return m_inNanoseconds / static_cast<double>(s_nanosecondsInASecond);
+ case Unit::Minutes:
+ return m_inNanoseconds / static_cast<double>(s_nanosecondsInAMinute);
+ default:
+ return static_cast<double>(m_inNanoseconds);
}
-
- XmlWriter& startElement( std::string const& name ) {
- ensureTagClosed();
- newlineIfNecessary();
- m_os << m_indent << '<' << name;
- m_tags.push_back( name );
- m_indent += " ";
- m_tagIsOpen = true;
- return *this;
+ }
+ auto unitsAsString() const -> std::string {
+ switch (m_units) {
+ case Unit::Nanoseconds:
+ return "ns";
+ case Unit::Microseconds:
+ return "us";
+ case Unit::Milliseconds:
+ return "ms";
+ case Unit::Seconds:
+ return "s";
+ case Unit::Minutes:
+ return "m";
+ default:
+ return "** internal error **";
}
- ScopedElement scopedElement( std::string const& name ) {
- ScopedElement scoped( this );
- startElement( name );
- return scoped;
- }
+ }
+ friend auto operator << (std::ostream& os, Duration const& duration) -> std::ostream& {
+ return os << duration.value() << ' ' << duration.unitsAsString();
+ }
+};
+} // end anon namespace
- XmlWriter& endElement() {
- newlineIfNecessary();
- m_indent = m_indent.substr( 0, m_indent.size()-2 );
- if( m_tagIsOpen ) {
- m_os << "/>";
- m_tagIsOpen = false;
- }
- else {
- m_os << m_indent << "</" << m_tags.back() << ">";
- }
- m_os << std::endl;
- m_tags.pop_back();
- return *this;
- }
+class TablePrinter {
+ std::ostream& m_os;
+ std::vector<ColumnInfo> m_columnInfos;
+ std::ostringstream m_oss;
+ int m_currentColumn = -1;
+ bool m_isOpen = false;
- XmlWriter& writeAttribute( std::string const& name, std::string const& attribute ) {
- if( !name.empty() && !attribute.empty() )
- m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"';
- return *this;
- }
+public:
+ TablePrinter( std::ostream& os, std::vector<ColumnInfo> columnInfos )
+ : m_os( os ),
+ m_columnInfos( std::move( columnInfos ) ) {}
- XmlWriter& writeAttribute( std::string const& name, bool attribute ) {
- m_os << ' ' << name << "=\"" << ( attribute ? "true" : "false" ) << '"';
- return *this;
- }
+ auto columnInfos() const -> std::vector<ColumnInfo> const& {
+ return m_columnInfos;
+ }
- template<typename T>
- XmlWriter& writeAttribute( std::string const& name, T const& attribute ) {
- std::ostringstream oss;
- oss << attribute;
- return writeAttribute( name, oss.str() );
- }
+ void open() {
+ if (!m_isOpen) {
+ m_isOpen = true;
+ *this << RowBreak();
- XmlWriter& writeText( std::string const& text, bool indent = true ) {
- if( !text.empty() ){
- bool tagWasOpen = m_tagIsOpen;
- ensureTagClosed();
- if( tagWasOpen && indent )
- m_os << m_indent;
- m_os << XmlEncode( text );
- m_needsNewline = true;
- }
- return *this;
- }
+ Columns headerCols;
+ Spacer spacer(2);
+ for (auto const& info : m_columnInfos) {
+ headerCols += Column(info.name).width(static_cast<std::size_t>(info.width - 2));
+ headerCols += spacer;
+ }
+ m_os << headerCols << '\n';
- XmlWriter& writeComment( std::string const& text ) {
- ensureTagClosed();
- m_os << m_indent << "<!--" << text << "-->";
- m_needsNewline = true;
- return *this;
+ m_os << Catch::getLineOfChars<'-'>() << '\n';
}
-
- void writeStylesheetRef( std::string const& url ) {
- m_os << "<?xml-stylesheet type=\"text/xsl\" href=\"" << url << "\"?>\n";
+ }
+ void close() {
+ if (m_isOpen) {
+ *this << RowBreak();
+ m_os << std::endl;
+ m_isOpen = false;
}
+ }
- XmlWriter& writeBlankLine() {
- ensureTagClosed();
- m_os << '\n';
- return *this;
- }
+ template<typename T>
+ friend TablePrinter& operator << (TablePrinter& tp, T const& value) {
+ tp.m_oss << value;
+ return tp;
+ }
+
+ friend TablePrinter& operator << (TablePrinter& tp, ColumnBreak) {
+ auto colStr = tp.m_oss.str();
+ const auto strSize = colStr.size();
+ tp.m_oss.str("");
+ tp.open();
+ if (tp.m_currentColumn == static_cast<int>(tp.m_columnInfos.size() - 1)) {
+ tp.m_currentColumn = -1;
+ tp.m_os << '\n';
+ }
+ tp.m_currentColumn++;
+
+ auto colInfo = tp.m_columnInfos[tp.m_currentColumn];
+ auto padding = (strSize + 1 < static_cast<std::size_t>(colInfo.width))
+ ? std::string(colInfo.width - (strSize + 1), ' ')
+ : std::string();
+ if (colInfo.justification == ColumnInfo::Left)
+ tp.m_os << colStr << padding << ' ';
+ else
+ tp.m_os << padding << colStr << ' ';
+ return tp;
+ }
- void ensureTagClosed() {
- if( m_tagIsOpen ) {
- m_os << ">" << std::endl;
- m_tagIsOpen = false;
- }
+ friend TablePrinter& operator << (TablePrinter& tp, RowBreak) {
+ if (tp.m_currentColumn > 0) {
+ tp.m_os << '\n';
+ tp.m_currentColumn = -1;
}
+ return tp;
+ }
+};
- private:
- XmlWriter( XmlWriter const& );
- void operator=( XmlWriter const& );
-
- void writeDeclaration() {
- m_os << "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n";
+ConsoleReporter::ConsoleReporter(ReporterConfig const& config)
+ : StreamingReporterBase(config),
+ m_tablePrinter(new TablePrinter(config.stream(),
+ [&config]() -> std::vector<ColumnInfo> {
+ if (config.fullConfig()->benchmarkNoAnalysis())
+ {
+ return{
+ { "benchmark name", CATCH_CONFIG_CONSOLE_WIDTH - 43, ColumnInfo::Left },
+ { " samples", 14, ColumnInfo::Right },
+ { " iterations", 14, ColumnInfo::Right },
+ { " mean", 14, ColumnInfo::Right }
+ };
}
-
- void newlineIfNecessary() {
- if( m_needsNewline ) {
- m_os << std::endl;
- m_needsNewline = false;
- }
+ else
+ {
+ return{
+ { "benchmark name", CATCH_CONFIG_CONSOLE_WIDTH - 32, ColumnInfo::Left },
+ { "samples mean std dev", 14, ColumnInfo::Right },
+ { "iterations low mean low std dev", 14, ColumnInfo::Right },
+ { "estimated high mean high std dev", 14, ColumnInfo::Right }
+ };
}
+ }())) {}
+ConsoleReporter::~ConsoleReporter() = default;
- bool m_tagIsOpen;
- bool m_needsNewline;
- std::vector<std::string> m_tags;
- std::string m_indent;
- std::ostream& m_os;
- };
-
+std::string ConsoleReporter::getDescription() {
+ return "Reports test results as plain lines of text";
}
-// #included from: catch_reenable_warnings.h
-#define TWOBLUECUBES_CATCH_REENABLE_WARNINGS_H_INCLUDED
-
-#ifdef __clang__
-# ifdef __ICC // icpc defines the __clang__ macro
-# pragma warning(pop)
-# else
-# pragma clang diagnostic pop
-# endif
-#elif defined __GNUC__
-# pragma GCC diagnostic pop
-#endif
-
-
-namespace Catch {
- class XmlReporter : public StreamingReporterBase {
- public:
- XmlReporter( ReporterConfig const& _config )
- : StreamingReporterBase( _config ),
- m_xml(_config.stream()),
- m_sectionDepth( 0 )
- {
- m_reporterPrefs.shouldRedirectStdOut = true;
- }
+void ConsoleReporter::noMatchingTestCases(std::string const& spec) {
+ stream << "No test cases matched '" << spec << '\'' << std::endl;
+}
- virtual ~XmlReporter() CATCH_OVERRIDE;
+void ConsoleReporter::reportInvalidArguments(std::string const&arg){
+ stream << "Invalid Filter: " << arg << std::endl;
+}
- static std::string getDescription() {
- return "Reports test results as an XML document";
- }
+void ConsoleReporter::assertionStarting(AssertionInfo const&) {}
- virtual std::string getStylesheetRef() const {
- return std::string();
- }
+bool ConsoleReporter::assertionEnded(AssertionStats const& _assertionStats) {
+ AssertionResult const& result = _assertionStats.assertionResult;
- void writeSourceInfo( SourceLineInfo const& sourceInfo ) {
- m_xml
- .writeAttribute( "filename", sourceInfo.file )
- .writeAttribute( "line", sourceInfo.line );
- }
+ bool includeResults = m_config->includeSuccessfulResults() || !result.isOk();
- public: // StreamingReporterBase
+ // Drop out if result was successful but we're not printing them.
+ if (!includeResults && result.getResultType() != ResultWas::Warning)
+ return false;
- virtual void noMatchingTestCases( std::string const& s ) CATCH_OVERRIDE {
- StreamingReporterBase::noMatchingTestCases( s );
- }
+ lazyPrint();
- virtual void testRunStarting( TestRunInfo const& testInfo ) CATCH_OVERRIDE {
- StreamingReporterBase::testRunStarting( testInfo );
- std::string stylesheetRef = getStylesheetRef();
- if( !stylesheetRef.empty() )
- m_xml.writeStylesheetRef( stylesheetRef );
- m_xml.startElement( "Catch" );
- if( !m_config->name().empty() )
- m_xml.writeAttribute( "name", m_config->name() );
- }
+ ConsoleAssertionPrinter printer(stream, _assertionStats, includeResults);
+ printer.print();
+ stream << std::endl;
+ return true;
+}
- virtual void testGroupStarting( GroupInfo const& groupInfo ) CATCH_OVERRIDE {
- StreamingReporterBase::testGroupStarting( groupInfo );
- m_xml.startElement( "Group" )
- .writeAttribute( "name", groupInfo.name );
- }
+void ConsoleReporter::sectionStarting(SectionInfo const& _sectionInfo) {
+ m_tablePrinter->close();
+ m_headerPrinted = false;
+ StreamingReporterBase::sectionStarting(_sectionInfo);
+}
+void ConsoleReporter::sectionEnded(SectionStats const& _sectionStats) {
+ m_tablePrinter->close();
+ if (_sectionStats.missingAssertions) {
+ lazyPrint();
+ Colour colour(Colour::ResultError);
+ if (m_sectionStack.size() > 1)
+ stream << "\nNo assertions in section";
+ else
+ stream << "\nNo assertions in test case";
+ stream << " '" << _sectionStats.sectionInfo.name << "'\n" << std::endl;
+ }
+ if (m_config->showDurations() == ShowDurations::Always) {
+ stream << getFormattedDuration(_sectionStats.durationInSeconds) << " s: " << _sectionStats.sectionInfo.name << std::endl;
+ }
+ if (m_headerPrinted) {
+ m_headerPrinted = false;
+ }
+ StreamingReporterBase::sectionEnded(_sectionStats);
+}
- virtual void testCaseStarting( TestCaseInfo const& testInfo ) CATCH_OVERRIDE {
- StreamingReporterBase::testCaseStarting(testInfo);
- m_xml.startElement( "TestCase" )
- .writeAttribute( "name", trim( testInfo.name ) )
- .writeAttribute( "description", testInfo.description )
- .writeAttribute( "tags", testInfo.tagsAsString );
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+void ConsoleReporter::benchmarkPreparing(std::string const& name) {
+ lazyPrintWithoutClosingBenchmarkTable();
- writeSourceInfo( testInfo.lineInfo );
+ auto nameCol = Column(name).width(static_cast<std::size_t>(m_tablePrinter->columnInfos()[0].width - 2));
- if ( m_config->showDurations() == ShowDurations::Always )
- m_testCaseTimer.start();
- m_xml.ensureTagClosed();
- }
+ bool firstLine = true;
+ for (auto line : nameCol) {
+ if (!firstLine)
+ (*m_tablePrinter) << ColumnBreak() << ColumnBreak() << ColumnBreak();
+ else
+ firstLine = false;
- virtual void sectionStarting( SectionInfo const& sectionInfo ) CATCH_OVERRIDE {
- StreamingReporterBase::sectionStarting( sectionInfo );
- if( m_sectionDepth++ > 0 ) {
- m_xml.startElement( "Section" )
- .writeAttribute( "name", trim( sectionInfo.name ) )
- .writeAttribute( "description", sectionInfo.description );
- writeSourceInfo( sectionInfo.lineInfo );
- m_xml.ensureTagClosed();
- }
- }
+ (*m_tablePrinter) << line << ColumnBreak();
+ }
+}
- virtual void assertionStarting( AssertionInfo const& ) CATCH_OVERRIDE { }
+void ConsoleReporter::benchmarkStarting(BenchmarkInfo const& info) {
+ (*m_tablePrinter) << info.samples << ColumnBreak()
+ << info.iterations << ColumnBreak();
+ if (!m_config->benchmarkNoAnalysis())
+ (*m_tablePrinter) << Duration(info.estimatedDuration) << ColumnBreak();
+}
+void ConsoleReporter::benchmarkEnded(BenchmarkStats<> const& stats) {
+ if (m_config->benchmarkNoAnalysis())
+ {
+ (*m_tablePrinter) << Duration(stats.mean.point.count()) << ColumnBreak();
+ }
+ else
+ {
+ (*m_tablePrinter) << ColumnBreak()
+ << Duration(stats.mean.point.count()) << ColumnBreak()
+ << Duration(stats.mean.lower_bound.count()) << ColumnBreak()
+ << Duration(stats.mean.upper_bound.count()) << ColumnBreak() << ColumnBreak()
+ << Duration(stats.standardDeviation.point.count()) << ColumnBreak()
+ << Duration(stats.standardDeviation.lower_bound.count()) << ColumnBreak()
+ << Duration(stats.standardDeviation.upper_bound.count()) << ColumnBreak() << ColumnBreak() << ColumnBreak() << ColumnBreak() << ColumnBreak();
+ }
+}
- virtual bool assertionEnded( AssertionStats const& assertionStats ) CATCH_OVERRIDE {
+void ConsoleReporter::benchmarkFailed(std::string const& error) {
+ Colour colour(Colour::Red);
+ (*m_tablePrinter)
+ << "Benchmark failed (" << error << ')'
+ << ColumnBreak() << RowBreak();
+}
+#endif // CATCH_CONFIG_ENABLE_BENCHMARKING
- AssertionResult const& result = assertionStats.assertionResult;
+void ConsoleReporter::testCaseEnded(TestCaseStats const& _testCaseStats) {
+ m_tablePrinter->close();
+ StreamingReporterBase::testCaseEnded(_testCaseStats);
+ m_headerPrinted = false;
+}
+void ConsoleReporter::testGroupEnded(TestGroupStats const& _testGroupStats) {
+ if (currentGroupInfo.used) {
+ printSummaryDivider();
+ stream << "Summary for group '" << _testGroupStats.groupInfo.name << "':\n";
+ printTotals(_testGroupStats.totals);
+ stream << '\n' << std::endl;
+ }
+ StreamingReporterBase::testGroupEnded(_testGroupStats);
+}
+void ConsoleReporter::testRunEnded(TestRunStats const& _testRunStats) {
+ printTotalsDivider(_testRunStats.totals);
+ printTotals(_testRunStats.totals);
+ stream << std::endl;
+ StreamingReporterBase::testRunEnded(_testRunStats);
+}
+void ConsoleReporter::testRunStarting(TestRunInfo const& _testInfo) {
+ StreamingReporterBase::testRunStarting(_testInfo);
+ printTestFilters();
+}
- bool includeResults = m_config->includeSuccessfulResults() || !result.isOk();
+void ConsoleReporter::lazyPrint() {
- if( includeResults ) {
- // Print any info messages in <Info> tags.
- for( std::vector<MessageInfo>::const_iterator it = assertionStats.infoMessages.begin(), itEnd = assertionStats.infoMessages.end();
- it != itEnd;
- ++it ) {
- if( it->type == ResultWas::Info ) {
- m_xml.scopedElement( "Info" )
- .writeText( it->message );
- } else if ( it->type == ResultWas::Warning ) {
- m_xml.scopedElement( "Warning" )
- .writeText( it->message );
- }
- }
- }
+ m_tablePrinter->close();
+ lazyPrintWithoutClosingBenchmarkTable();
+}
- // Drop out if result was successful but we're not printing them.
- if( !includeResults && result.getResultType() != ResultWas::Warning )
- return true;
+void ConsoleReporter::lazyPrintWithoutClosingBenchmarkTable() {
- // Print the expression if there is one.
- if( result.hasExpression() ) {
- m_xml.startElement( "Expression" )
- .writeAttribute( "success", result.succeeded() )
- .writeAttribute( "type", result.getTestMacroName() );
+ if (!currentTestRunInfo.used)
+ lazyPrintRunInfo();
+ if (!currentGroupInfo.used)
+ lazyPrintGroupInfo();
- writeSourceInfo( result.getSourceInfo() );
+ if (!m_headerPrinted) {
+ printTestCaseAndSectionHeader();
+ m_headerPrinted = true;
+ }
+}
+void ConsoleReporter::lazyPrintRunInfo() {
+ stream << '\n' << getLineOfChars<'~'>() << '\n';
+ Colour colour(Colour::SecondaryText);
+ stream << currentTestRunInfo->name
+ << " is a Catch v" << libraryVersion() << " host application.\n"
+ << "Run with -? for options\n\n";
- m_xml.scopedElement( "Original" )
- .writeText( result.getExpression() );
- m_xml.scopedElement( "Expanded" )
- .writeText( result.getExpandedExpression() );
- }
+ if (m_config->rngSeed() != 0)
+ stream << "Randomness seeded to: " << m_config->rngSeed() << "\n\n";
- // And... Print a result applicable to each result type.
- switch( result.getResultType() ) {
- case ResultWas::ThrewException:
- m_xml.startElement( "Exception" );
- writeSourceInfo( result.getSourceInfo() );
- m_xml.writeText( result.getMessage() );
- m_xml.endElement();
- break;
- case ResultWas::FatalErrorCondition:
- m_xml.startElement( "FatalErrorCondition" );
- writeSourceInfo( result.getSourceInfo() );
- m_xml.writeText( result.getMessage() );
- m_xml.endElement();
- break;
- case ResultWas::Info:
- m_xml.scopedElement( "Info" )
- .writeText( result.getMessage() );
- break;
- case ResultWas::Warning:
- // Warning will already have been written
- break;
- case ResultWas::ExplicitFailure:
- m_xml.startElement( "Failure" );
- writeSourceInfo( result.getSourceInfo() );
- m_xml.writeText( result.getMessage() );
- m_xml.endElement();
- break;
- default:
- break;
- }
+ currentTestRunInfo.used = true;
+}
+void ConsoleReporter::lazyPrintGroupInfo() {
+ if (!currentGroupInfo->name.empty() && currentGroupInfo->groupsCounts > 1) {
+ printClosedHeader("Group: " + currentGroupInfo->name);
+ currentGroupInfo.used = true;
+ }
+}
+void ConsoleReporter::printTestCaseAndSectionHeader() {
+ assert(!m_sectionStack.empty());
+ printOpenHeader(currentTestCaseInfo->name);
- if( result.hasExpression() )
- m_xml.endElement();
+ if (m_sectionStack.size() > 1) {
+ Colour colourGuard(Colour::Headers);
- return true;
- }
+ auto
+ it = m_sectionStack.begin() + 1, // Skip first section (test case)
+ itEnd = m_sectionStack.end();
+ for (; it != itEnd; ++it)
+ printHeaderString(it->name, 2);
+ }
- virtual void sectionEnded( SectionStats const& sectionStats ) CATCH_OVERRIDE {
- StreamingReporterBase::sectionEnded( sectionStats );
- if( --m_sectionDepth > 0 ) {
- XmlWriter::ScopedElement e = m_xml.scopedElement( "OverallResults" );
- e.writeAttribute( "successes", sectionStats.assertions.passed );
- e.writeAttribute( "failures", sectionStats.assertions.failed );
- e.writeAttribute( "expectedFailures", sectionStats.assertions.failedButOk );
+ SourceLineInfo lineInfo = m_sectionStack.back().lineInfo;
- if ( m_config->showDurations() == ShowDurations::Always )
- e.writeAttribute( "durationInSeconds", sectionStats.durationInSeconds );
+ stream << getLineOfChars<'-'>() << '\n';
+ Colour colourGuard(Colour::FileName);
+ stream << lineInfo << '\n';
+ stream << getLineOfChars<'.'>() << '\n' << std::endl;
+}
- m_xml.endElement();
- }
- }
+void ConsoleReporter::printClosedHeader(std::string const& _name) {
+ printOpenHeader(_name);
+ stream << getLineOfChars<'.'>() << '\n';
+}
+void ConsoleReporter::printOpenHeader(std::string const& _name) {
+ stream << getLineOfChars<'-'>() << '\n';
+ {
+ Colour colourGuard(Colour::Headers);
+ printHeaderString(_name);
+ }
+}
- virtual void testCaseEnded( TestCaseStats const& testCaseStats ) CATCH_OVERRIDE {
- StreamingReporterBase::testCaseEnded( testCaseStats );
- XmlWriter::ScopedElement e = m_xml.scopedElement( "OverallResult" );
- e.writeAttribute( "success", testCaseStats.totals.assertions.allOk() );
+// if string has a : in first line will set indent to follow it on
+// subsequent lines
+void ConsoleReporter::printHeaderString(std::string const& _string, std::size_t indent) {
+ std::size_t i = _string.find(": ");
+ if (i != std::string::npos)
+ i += 2;
+ else
+ i = 0;
+ stream << Column(_string).indent(indent + i).initialIndent(indent) << '\n';
+}
- if ( m_config->showDurations() == ShowDurations::Always )
- e.writeAttribute( "durationInSeconds", m_testCaseTimer.getElapsedSeconds() );
+struct SummaryColumn {
+
+ SummaryColumn( std::string _label, Colour::Code _colour )
+ : label( std::move( _label ) ),
+ colour( _colour ) {}
+ SummaryColumn addRow( std::size_t count ) {
+ ReusableStringStream rss;
+ rss << count;
+ std::string row = rss.str();
+ for (auto& oldRow : rows) {
+ while (oldRow.size() < row.size())
+ oldRow = ' ' + oldRow;
+ while (oldRow.size() > row.size())
+ row = ' ' + row;
+ }
+ rows.push_back(row);
+ return *this;
+ }
- if( !testCaseStats.stdOut.empty() )
- m_xml.scopedElement( "StdOut" ).writeText( trim( testCaseStats.stdOut ), false );
- if( !testCaseStats.stdErr.empty() )
- m_xml.scopedElement( "StdErr" ).writeText( trim( testCaseStats.stdErr ), false );
+ std::string label;
+ Colour::Code colour;
+ std::vector<std::string> rows;
- m_xml.endElement();
- }
+};
- virtual void testGroupEnded( TestGroupStats const& testGroupStats ) CATCH_OVERRIDE {
- StreamingReporterBase::testGroupEnded( testGroupStats );
- // TODO: Check testGroupStats.aborting and act accordingly.
- m_xml.scopedElement( "OverallResults" )
- .writeAttribute( "successes", testGroupStats.totals.assertions.passed )
- .writeAttribute( "failures", testGroupStats.totals.assertions.failed )
- .writeAttribute( "expectedFailures", testGroupStats.totals.assertions.failedButOk );
- m_xml.endElement();
+void ConsoleReporter::printTotals( Totals const& totals ) {
+ if (totals.testCases.total() == 0) {
+ stream << Colour(Colour::Warning) << "No tests ran\n";
+ } else if (totals.assertions.total() > 0 && totals.testCases.allPassed()) {
+ stream << Colour(Colour::ResultSuccess) << "All tests passed";
+ stream << " ("
+ << pluralise(totals.assertions.passed, "assertion") << " in "
+ << pluralise(totals.testCases.passed, "test case") << ')'
+ << '\n';
+ } else {
+
+ std::vector<SummaryColumn> columns;
+ columns.push_back(SummaryColumn("", Colour::None)
+ .addRow(totals.testCases.total())
+ .addRow(totals.assertions.total()));
+ columns.push_back(SummaryColumn("passed", Colour::Success)
+ .addRow(totals.testCases.passed)
+ .addRow(totals.assertions.passed));
+ columns.push_back(SummaryColumn("failed", Colour::ResultError)
+ .addRow(totals.testCases.failed)
+ .addRow(totals.assertions.failed));
+ columns.push_back(SummaryColumn("failed as expected", Colour::ResultExpectedFailure)
+ .addRow(totals.testCases.failedButOk)
+ .addRow(totals.assertions.failedButOk));
+
+ printSummaryRow("test cases", columns, 0);
+ printSummaryRow("assertions", columns, 1);
+ }
+}
+void ConsoleReporter::printSummaryRow(std::string const& label, std::vector<SummaryColumn> const& cols, std::size_t row) {
+ for (auto col : cols) {
+ std::string value = col.rows[row];
+ if (col.label.empty()) {
+ stream << label << ": ";
+ if (value != "0")
+ stream << value;
+ else
+ stream << Colour(Colour::Warning) << "- none -";
+ } else if (value != "0") {
+ stream << Colour(Colour::LightGrey) << " | ";
+ stream << Colour(col.colour)
+ << value << ' ' << col.label;
}
+ }
+ stream << '\n';
+}
- virtual void testRunEnded( TestRunStats const& testRunStats ) CATCH_OVERRIDE {
- StreamingReporterBase::testRunEnded( testRunStats );
- m_xml.scopedElement( "OverallResults" )
- .writeAttribute( "successes", testRunStats.totals.assertions.passed )
- .writeAttribute( "failures", testRunStats.totals.assertions.failed )
- .writeAttribute( "expectedFailures", testRunStats.totals.assertions.failedButOk );
- m_xml.endElement();
- }
+void ConsoleReporter::printTotalsDivider(Totals const& totals) {
+ if (totals.testCases.total() > 0) {
+ std::size_t failedRatio = makeRatio(totals.testCases.failed, totals.testCases.total());
+ std::size_t failedButOkRatio = makeRatio(totals.testCases.failedButOk, totals.testCases.total());
+ std::size_t passedRatio = makeRatio(totals.testCases.passed, totals.testCases.total());
+ while (failedRatio + failedButOkRatio + passedRatio < CATCH_CONFIG_CONSOLE_WIDTH - 1)
+ findMax(failedRatio, failedButOkRatio, passedRatio)++;
+ while (failedRatio + failedButOkRatio + passedRatio > CATCH_CONFIG_CONSOLE_WIDTH - 1)
+ findMax(failedRatio, failedButOkRatio, passedRatio)--;
+
+ stream << Colour(Colour::Error) << std::string(failedRatio, '=');
+ stream << Colour(Colour::ResultExpectedFailure) << std::string(failedButOkRatio, '=');
+ if (totals.testCases.allPassed())
+ stream << Colour(Colour::ResultSuccess) << std::string(passedRatio, '=');
+ else
+ stream << Colour(Colour::Success) << std::string(passedRatio, '=');
+ } else {
+ stream << Colour(Colour::Warning) << std::string(CATCH_CONFIG_CONSOLE_WIDTH - 1, '=');
+ }
+ stream << '\n';
+}
+void ConsoleReporter::printSummaryDivider() {
+ stream << getLineOfChars<'-'>() << '\n';
+}
- private:
- Timer m_testCaseTimer;
- XmlWriter m_xml;
- int m_sectionDepth;
- };
+void ConsoleReporter::printTestFilters() {
+ if (m_config->testSpec().hasFilters())
+ stream << Colour(Colour::BrightYellow) << "Filters: " << serializeFilters( m_config->getTestsOrTags() ) << '\n';
+}
- INTERNAL_CATCH_REGISTER_REPORTER( "xml", XmlReporter )
+CATCH_REGISTER_REPORTER("console", ConsoleReporter)
} // end namespace Catch
-// #included from: ../reporters/catch_reporter_junit.hpp
-#define TWOBLUECUBES_CATCH_REPORTER_JUNIT_HPP_INCLUDED
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
-#include <assert.h>
+#if defined(__clang__)
+# pragma clang diagnostic pop
+#endif
+// end catch_reporter_console.cpp
+// start catch_reporter_junit.cpp
+
+#include <cassert>
+#include <sstream>
+#include <ctime>
+#include <algorithm>
namespace Catch {
@@ -10268,7 +16355,7 @@ namespace Catch {
// Also, UTC only, again because of backward compatibility (%z is C++11)
time_t rawtime;
std::time(&rawtime);
- const size_t timeStampSize = sizeof("2017-01-16T17:06:45Z");
+ auto const timeStampSize = sizeof("2017-01-16T17:06:45Z");
#ifdef _MSC_VER
std::tm timeInfo = {};
@@ -10289,1013 +16376,657 @@ namespace Catch {
return std::string(timeStamp);
}
- }
+ std::string fileNameTag(const std::vector<std::string> &tags) {
+ auto it = std::find_if(begin(tags),
+ end(tags),
+ [] (std::string const& tag) {return tag.front() == '#'; });
+ if (it != tags.end())
+ return it->substr(1);
+ return std::string();
+ }
+ } // anonymous namespace
- class JunitReporter : public CumulativeReporterBase {
- public:
- JunitReporter( ReporterConfig const& _config )
+ JunitReporter::JunitReporter( ReporterConfig const& _config )
: CumulativeReporterBase( _config ),
- xml( _config.stream() ),
- m_okToFail( false )
+ xml( _config.stream() )
{
m_reporterPrefs.shouldRedirectStdOut = true;
+ m_reporterPrefs.shouldReportAllAssertions = true;
}
- virtual ~JunitReporter() CATCH_OVERRIDE;
+ JunitReporter::~JunitReporter() {}
- static std::string getDescription() {
- return "Reports test results in an XML format that looks like Ant's junitreport target";
- }
+ std::string JunitReporter::getDescription() {
+ return "Reports test results in an XML format that looks like Ant's junitreport target";
+ }
- virtual void noMatchingTestCases( std::string const& /*spec*/ ) CATCH_OVERRIDE {}
+ void JunitReporter::noMatchingTestCases( std::string const& /*spec*/ ) {}
- virtual void testRunStarting( TestRunInfo const& runInfo ) CATCH_OVERRIDE {
- CumulativeReporterBase::testRunStarting( runInfo );
- xml.startElement( "testsuites" );
- }
+ void JunitReporter::testRunStarting( TestRunInfo const& runInfo ) {
+ CumulativeReporterBase::testRunStarting( runInfo );
+ xml.startElement( "testsuites" );
+ }
- virtual void testGroupStarting( GroupInfo const& groupInfo ) CATCH_OVERRIDE {
- suiteTimer.start();
- stdOutForSuite.str("");
- stdErrForSuite.str("");
- unexpectedExceptions = 0;
- CumulativeReporterBase::testGroupStarting( groupInfo );
- }
+ void JunitReporter::testGroupStarting( GroupInfo const& groupInfo ) {
+ suiteTimer.start();
+ stdOutForSuite.clear();
+ stdErrForSuite.clear();
+ unexpectedExceptions = 0;
+ CumulativeReporterBase::testGroupStarting( groupInfo );
+ }
- virtual void testCaseStarting( TestCaseInfo const& testCaseInfo ) CATCH_OVERRIDE {
- m_okToFail = testCaseInfo.okToFail();
- }
- virtual bool assertionEnded( AssertionStats const& assertionStats ) CATCH_OVERRIDE {
- if( assertionStats.assertionResult.getResultType() == ResultWas::ThrewException && !m_okToFail )
- unexpectedExceptions++;
- return CumulativeReporterBase::assertionEnded( assertionStats );
- }
+ void JunitReporter::testCaseStarting( TestCaseInfo const& testCaseInfo ) {
+ m_okToFail = testCaseInfo.okToFail();
+ }
- virtual void testCaseEnded( TestCaseStats const& testCaseStats ) CATCH_OVERRIDE {
- stdOutForSuite << testCaseStats.stdOut;
- stdErrForSuite << testCaseStats.stdErr;
- CumulativeReporterBase::testCaseEnded( testCaseStats );
- }
+ bool JunitReporter::assertionEnded( AssertionStats const& assertionStats ) {
+ if( assertionStats.assertionResult.getResultType() == ResultWas::ThrewException && !m_okToFail )
+ unexpectedExceptions++;
+ return CumulativeReporterBase::assertionEnded( assertionStats );
+ }
- virtual void testGroupEnded( TestGroupStats const& testGroupStats ) CATCH_OVERRIDE {
- double suiteTime = suiteTimer.getElapsedSeconds();
- CumulativeReporterBase::testGroupEnded( testGroupStats );
- writeGroup( *m_testGroups.back(), suiteTime );
- }
+ void JunitReporter::testCaseEnded( TestCaseStats const& testCaseStats ) {
+ stdOutForSuite += testCaseStats.stdOut;
+ stdErrForSuite += testCaseStats.stdErr;
+ CumulativeReporterBase::testCaseEnded( testCaseStats );
+ }
- virtual void testRunEndedCumulative() CATCH_OVERRIDE {
- xml.endElement();
- }
+ void JunitReporter::testGroupEnded( TestGroupStats const& testGroupStats ) {
+ double suiteTime = suiteTimer.getElapsedSeconds();
+ CumulativeReporterBase::testGroupEnded( testGroupStats );
+ writeGroup( *m_testGroups.back(), suiteTime );
+ }
- void writeGroup( TestGroupNode const& groupNode, double suiteTime ) {
- XmlWriter::ScopedElement e = xml.scopedElement( "testsuite" );
- TestGroupStats const& stats = groupNode.value;
- xml.writeAttribute( "name", stats.groupInfo.name );
- xml.writeAttribute( "errors", unexpectedExceptions );
- xml.writeAttribute( "failures", stats.totals.assertions.failed-unexpectedExceptions );
- xml.writeAttribute( "tests", stats.totals.assertions.total() );
- xml.writeAttribute( "hostname", "tbd" ); // !TBD
- if( m_config->showDurations() == ShowDurations::Never )
- xml.writeAttribute( "time", "" );
- else
- xml.writeAttribute( "time", suiteTime );
- xml.writeAttribute( "timestamp", getCurrentTimestamp() );
+ void JunitReporter::testRunEndedCumulative() {
+ xml.endElement();
+ }
- // Write test cases
- for( TestGroupNode::ChildNodes::const_iterator
- it = groupNode.children.begin(), itEnd = groupNode.children.end();
- it != itEnd;
- ++it )
- writeTestCase( **it );
+ void JunitReporter::writeGroup( TestGroupNode const& groupNode, double suiteTime ) {
+ XmlWriter::ScopedElement e = xml.scopedElement( "testsuite" );
- xml.scopedElement( "system-out" ).writeText( trim( stdOutForSuite.str() ), false );
- xml.scopedElement( "system-err" ).writeText( trim( stdErrForSuite.str() ), false );
+ TestGroupStats const& stats = groupNode.value;
+ xml.writeAttribute( "name", stats.groupInfo.name );
+ xml.writeAttribute( "errors", unexpectedExceptions );
+ xml.writeAttribute( "failures", stats.totals.assertions.failed-unexpectedExceptions );
+ xml.writeAttribute( "tests", stats.totals.assertions.total() );
+ xml.writeAttribute( "hostname", "tbd" ); // !TBD
+ if( m_config->showDurations() == ShowDurations::Never )
+ xml.writeAttribute( "time", "" );
+ else
+ xml.writeAttribute( "time", suiteTime );
+ xml.writeAttribute( "timestamp", getCurrentTimestamp() );
+
+ // Write properties if there are any
+ if (m_config->hasTestFilters() || m_config->rngSeed() != 0) {
+ auto properties = xml.scopedElement("properties");
+ if (m_config->hasTestFilters()) {
+ xml.scopedElement("property")
+ .writeAttribute("name", "filters")
+ .writeAttribute("value", serializeFilters(m_config->getTestsOrTags()));
+ }
+ if (m_config->rngSeed() != 0) {
+ xml.scopedElement("property")
+ .writeAttribute("name", "random-seed")
+ .writeAttribute("value", m_config->rngSeed());
+ }
}
- void writeTestCase( TestCaseNode const& testCaseNode ) {
- TestCaseStats const& stats = testCaseNode.value;
-
- // All test cases have exactly one section - which represents the
- // test case itself. That section may have 0-n nested sections
- assert( testCaseNode.children.size() == 1 );
- SectionNode const& rootSection = *testCaseNode.children.front();
-
- std::string className = stats.testInfo.className;
+ // Write test cases
+ for( auto const& child : groupNode.children )
+ writeTestCase( *child );
- if( className.empty() ) {
- if( rootSection.childSections.empty() )
- className = "global";
- }
- writeSection( className, "", rootSection );
- }
-
- void writeSection( std::string const& className,
- std::string const& rootName,
- SectionNode const& sectionNode ) {
- std::string name = trim( sectionNode.stats.sectionInfo.name );
- if( !rootName.empty() )
- name = rootName + '/' + name;
-
- if( !sectionNode.assertions.empty() ||
- !sectionNode.stdOut.empty() ||
- !sectionNode.stdErr.empty() ) {
- XmlWriter::ScopedElement e = xml.scopedElement( "testcase" );
- if( className.empty() ) {
- xml.writeAttribute( "classname", name );
- xml.writeAttribute( "name", "root" );
- }
- else {
- xml.writeAttribute( "classname", className );
- xml.writeAttribute( "name", name );
- }
- xml.writeAttribute( "time", Catch::toString( sectionNode.stats.durationInSeconds ) );
-
- writeAssertions( sectionNode );
-
- if( !sectionNode.stdOut.empty() )
- xml.scopedElement( "system-out" ).writeText( trim( sectionNode.stdOut ), false );
- if( !sectionNode.stdErr.empty() )
- xml.scopedElement( "system-err" ).writeText( trim( sectionNode.stdErr ), false );
- }
- for( SectionNode::ChildSections::const_iterator
- it = sectionNode.childSections.begin(),
- itEnd = sectionNode.childSections.end();
- it != itEnd;
- ++it )
- if( className.empty() )
- writeSection( name, "", **it );
- else
- writeSection( className, name, **it );
- }
-
- void writeAssertions( SectionNode const& sectionNode ) {
- for( SectionNode::Assertions::const_iterator
- it = sectionNode.assertions.begin(), itEnd = sectionNode.assertions.end();
- it != itEnd;
- ++it )
- writeAssertion( *it );
- }
- void writeAssertion( AssertionStats const& stats ) {
- AssertionResult const& result = stats.assertionResult;
- if( !result.isOk() ) {
- std::string elementName;
- switch( result.getResultType() ) {
- case ResultWas::ThrewException:
- case ResultWas::FatalErrorCondition:
- elementName = "error";
- break;
- case ResultWas::ExplicitFailure:
- elementName = "failure";
- break;
- case ResultWas::ExpressionFailed:
- elementName = "failure";
- break;
- case ResultWas::DidntThrowException:
- elementName = "failure";
- break;
-
- // We should never see these here:
- case ResultWas::Info:
- case ResultWas::Warning:
- case ResultWas::Ok:
- case ResultWas::Unknown:
- case ResultWas::FailureBit:
- case ResultWas::Exception:
- elementName = "internalError";
- break;
- }
+ xml.scopedElement( "system-out" ).writeText( trim( stdOutForSuite ), false );
+ xml.scopedElement( "system-err" ).writeText( trim( stdErrForSuite ), false );
+ }
- XmlWriter::ScopedElement e = xml.scopedElement( elementName );
+ void JunitReporter::writeTestCase( TestCaseNode const& testCaseNode ) {
+ TestCaseStats const& stats = testCaseNode.value;
- xml.writeAttribute( "message", result.getExpandedExpression() );
- xml.writeAttribute( "type", result.getTestMacroName() );
+ // All test cases have exactly one section - which represents the
+ // test case itself. That section may have 0-n nested sections
+ assert( testCaseNode.children.size() == 1 );
+ SectionNode const& rootSection = *testCaseNode.children.front();
- std::ostringstream oss;
- if( !result.getMessage().empty() )
- oss << result.getMessage() << '\n';
- for( std::vector<MessageInfo>::const_iterator
- it = stats.infoMessages.begin(),
- itEnd = stats.infoMessages.end();
- it != itEnd;
- ++it )
- if( it->type == ResultWas::Info )
- oss << it->message << '\n';
+ std::string className = stats.testInfo.className;
- oss << "at " << result.getSourceInfo();
- xml.writeText( oss.str(), false );
- }
+ if( className.empty() ) {
+ className = fileNameTag(stats.testInfo.tags);
+ if ( className.empty() )
+ className = "global";
}
- XmlWriter xml;
- Timer suiteTimer;
- std::ostringstream stdOutForSuite;
- std::ostringstream stdErrForSuite;
- unsigned int unexpectedExceptions;
- bool m_okToFail;
- };
-
- INTERNAL_CATCH_REGISTER_REPORTER( "junit", JunitReporter )
-
-} // end namespace Catch
+ if ( !m_config->name().empty() )
+ className = m_config->name() + "." + className;
-// #included from: ../reporters/catch_reporter_console.hpp
-#define TWOBLUECUBES_CATCH_REPORTER_CONSOLE_HPP_INCLUDED
+ writeSection( className, "", rootSection );
+ }
-#include <cfloat>
-#include <cstdio>
+ void JunitReporter::writeSection( std::string const& className,
+ std::string const& rootName,
+ SectionNode const& sectionNode ) {
+ std::string name = trim( sectionNode.stats.sectionInfo.name );
+ if( !rootName.empty() )
+ name = rootName + '/' + name;
-namespace Catch {
+ if( !sectionNode.assertions.empty() ||
+ !sectionNode.stdOut.empty() ||
+ !sectionNode.stdErr.empty() ) {
+ XmlWriter::ScopedElement e = xml.scopedElement( "testcase" );
+ if( className.empty() ) {
+ xml.writeAttribute( "classname", name );
+ xml.writeAttribute( "name", "root" );
+ }
+ else {
+ xml.writeAttribute( "classname", className );
+ xml.writeAttribute( "name", name );
+ }
+ xml.writeAttribute( "time", ::Catch::Detail::stringify( sectionNode.stats.durationInSeconds ) );
- struct ConsoleReporter : StreamingReporterBase {
- ConsoleReporter( ReporterConfig const& _config )
- : StreamingReporterBase( _config ),
- m_headerPrinted( false )
- {}
+ writeAssertions( sectionNode );
- virtual ~ConsoleReporter() CATCH_OVERRIDE;
- static std::string getDescription() {
- return "Reports test results as plain lines of text";
+ if( !sectionNode.stdOut.empty() )
+ xml.scopedElement( "system-out" ).writeText( trim( sectionNode.stdOut ), false );
+ if( !sectionNode.stdErr.empty() )
+ xml.scopedElement( "system-err" ).writeText( trim( sectionNode.stdErr ), false );
}
+ for( auto const& childNode : sectionNode.childSections )
+ if( className.empty() )
+ writeSection( name, "", *childNode );
+ else
+ writeSection( className, name, *childNode );
+ }
- virtual void noMatchingTestCases( std::string const& spec ) CATCH_OVERRIDE {
- stream << "No test cases matched '" << spec << '\'' << std::endl;
- }
+ void JunitReporter::writeAssertions( SectionNode const& sectionNode ) {
+ for( auto const& assertion : sectionNode.assertions )
+ writeAssertion( assertion );
+ }
- virtual void assertionStarting( AssertionInfo const& ) CATCH_OVERRIDE {
- }
+ void JunitReporter::writeAssertion( AssertionStats const& stats ) {
+ AssertionResult const& result = stats.assertionResult;
+ if( !result.isOk() ) {
+ std::string elementName;
+ switch( result.getResultType() ) {
+ case ResultWas::ThrewException:
+ case ResultWas::FatalErrorCondition:
+ elementName = "error";
+ break;
+ case ResultWas::ExplicitFailure:
+ elementName = "failure";
+ break;
+ case ResultWas::ExpressionFailed:
+ elementName = "failure";
+ break;
+ case ResultWas::DidntThrowException:
+ elementName = "failure";
+ break;
- virtual bool assertionEnded( AssertionStats const& _assertionStats ) CATCH_OVERRIDE {
- AssertionResult const& result = _assertionStats.assertionResult;
+ // We should never see these here:
+ case ResultWas::Info:
+ case ResultWas::Warning:
+ case ResultWas::Ok:
+ case ResultWas::Unknown:
+ case ResultWas::FailureBit:
+ case ResultWas::Exception:
+ elementName = "internalError";
+ break;
+ }
- bool includeResults = m_config->includeSuccessfulResults() || !result.isOk();
+ XmlWriter::ScopedElement e = xml.scopedElement( elementName );
- // Drop out if result was successful but we're not printing them.
- if( !includeResults && result.getResultType() != ResultWas::Warning )
- return false;
+ xml.writeAttribute( "message", result.getExpandedExpression() );
+ xml.writeAttribute( "type", result.getTestMacroName() );
- lazyPrint();
+ ReusableStringStream rss;
+ if( !result.getMessage().empty() )
+ rss << result.getMessage() << '\n';
+ for( auto const& msg : stats.infoMessages )
+ if( msg.type == ResultWas::Info )
+ rss << msg.message << '\n';
- AssertionPrinter printer( stream, _assertionStats, includeResults );
- printer.print();
- stream << std::endl;
- return true;
+ rss << "at " << result.getSourceInfo();
+ xml.writeText( rss.str(), false );
}
+ }
- virtual void sectionStarting( SectionInfo const& _sectionInfo ) CATCH_OVERRIDE {
- m_headerPrinted = false;
- StreamingReporterBase::sectionStarting( _sectionInfo );
- }
- virtual void sectionEnded( SectionStats const& _sectionStats ) CATCH_OVERRIDE {
- if( _sectionStats.missingAssertions ) {
- lazyPrint();
- Colour colour( Colour::ResultError );
- if( m_sectionStack.size() > 1 )
- stream << "\nNo assertions in section";
- else
- stream << "\nNo assertions in test case";
- stream << " '" << _sectionStats.sectionInfo.name << "'\n" << std::endl;
- }
- if( m_config->showDurations() == ShowDurations::Always ) {
- stream << getFormattedDuration(_sectionStats.durationInSeconds) << " s: " << _sectionStats.sectionInfo.name << std::endl;
- }
- if( m_headerPrinted ) {
- m_headerPrinted = false;
- }
- StreamingReporterBase::sectionEnded( _sectionStats );
- }
+ CATCH_REGISTER_REPORTER( "junit", JunitReporter )
- virtual void testCaseEnded( TestCaseStats const& _testCaseStats ) CATCH_OVERRIDE {
- StreamingReporterBase::testCaseEnded( _testCaseStats );
- m_headerPrinted = false;
- }
- virtual void testGroupEnded( TestGroupStats const& _testGroupStats ) CATCH_OVERRIDE {
- if( currentGroupInfo.used ) {
- printSummaryDivider();
- stream << "Summary for group '" << _testGroupStats.groupInfo.name << "':\n";
- printTotals( _testGroupStats.totals );
- stream << '\n' << std::endl;
- }
- StreamingReporterBase::testGroupEnded( _testGroupStats );
- }
- virtual void testRunEnded( TestRunStats const& _testRunStats ) CATCH_OVERRIDE {
- printTotalsDivider( _testRunStats.totals );
- printTotals( _testRunStats.totals );
- stream << std::endl;
- StreamingReporterBase::testRunEnded( _testRunStats );
- }
+} // end namespace Catch
+// end catch_reporter_junit.cpp
+// start catch_reporter_listening.cpp
- private:
+#include <cassert>
- class AssertionPrinter {
- void operator= ( AssertionPrinter const& );
- public:
- AssertionPrinter( std::ostream& _stream, AssertionStats const& _stats, bool _printInfoMessages )
- : stream( _stream ),
- stats( _stats ),
- result( _stats.assertionResult ),
- colour( Colour::None ),
- message( result.getMessage() ),
- messages( _stats.infoMessages ),
- printInfoMessages( _printInfoMessages )
- {
- switch( result.getResultType() ) {
- case ResultWas::Ok:
- colour = Colour::Success;
- passOrFail = "PASSED";
- //if( result.hasMessage() )
- if( _stats.infoMessages.size() == 1 )
- messageLabel = "with message";
- if( _stats.infoMessages.size() > 1 )
- messageLabel = "with messages";
- break;
- case ResultWas::ExpressionFailed:
- if( result.isOk() ) {
- colour = Colour::Success;
- passOrFail = "FAILED - but was ok";
- }
- else {
- colour = Colour::Error;
- passOrFail = "FAILED";
- }
- if( _stats.infoMessages.size() == 1 )
- messageLabel = "with message";
- if( _stats.infoMessages.size() > 1 )
- messageLabel = "with messages";
- break;
- case ResultWas::ThrewException:
- colour = Colour::Error;
- passOrFail = "FAILED";
- messageLabel = "due to unexpected exception with ";
- if (_stats.infoMessages.size() == 1)
- messageLabel += "message";
- if (_stats.infoMessages.size() > 1)
- messageLabel += "messages";
- break;
- case ResultWas::FatalErrorCondition:
- colour = Colour::Error;
- passOrFail = "FAILED";
- messageLabel = "due to a fatal error condition";
- break;
- case ResultWas::DidntThrowException:
- colour = Colour::Error;
- passOrFail = "FAILED";
- messageLabel = "because no exception was thrown where one was expected";
- break;
- case ResultWas::Info:
- messageLabel = "info";
- break;
- case ResultWas::Warning:
- messageLabel = "warning";
- break;
- case ResultWas::ExplicitFailure:
- passOrFail = "FAILED";
- colour = Colour::Error;
- if( _stats.infoMessages.size() == 1 )
- messageLabel = "explicitly with message";
- if( _stats.infoMessages.size() > 1 )
- messageLabel = "explicitly with messages";
- break;
- // These cases are here to prevent compiler warnings
- case ResultWas::Unknown:
- case ResultWas::FailureBit:
- case ResultWas::Exception:
- passOrFail = "** internal error **";
- colour = Colour::Error;
- break;
- }
- }
+namespace Catch {
- void print() const {
- printSourceInfo();
- if( stats.totals.assertions.total() > 0 ) {
- if( result.isOk() )
- stream << '\n';
- printResultType();
- printOriginalExpression();
- printReconstructedExpression();
- }
- else {
- stream << '\n';
- }
- printMessage();
- }
+ ListeningReporter::ListeningReporter() {
+ // We will assume that listeners will always want all assertions
+ m_preferences.shouldReportAllAssertions = true;
+ }
- private:
- void printResultType() const {
- if( !passOrFail.empty() ) {
- Colour colourGuard( colour );
- stream << passOrFail << ":\n";
- }
- }
- void printOriginalExpression() const {
- if( result.hasExpression() ) {
- Colour colourGuard( Colour::OriginalExpression );
- stream << " ";
- stream << result.getExpressionInMacro();
- stream << '\n';
- }
- }
- void printReconstructedExpression() const {
- if( result.hasExpandedExpression() ) {
- stream << "with expansion:\n";
- Colour colourGuard( Colour::ReconstructedExpression );
- stream << Text( result.getExpandedExpression(), TextAttributes().setIndent(2) ) << '\n';
- }
- }
- void printMessage() const {
- if( !messageLabel.empty() )
- stream << messageLabel << ':' << '\n';
- for( std::vector<MessageInfo>::const_iterator it = messages.begin(), itEnd = messages.end();
- it != itEnd;
- ++it ) {
- // If this assertion is a warning ignore any INFO messages
- if( printInfoMessages || it->type != ResultWas::Info )
- stream << Text( it->message, TextAttributes().setIndent(2) ) << '\n';
- }
- }
- void printSourceInfo() const {
- Colour colourGuard( Colour::FileName );
- stream << result.getSourceInfo() << ": ";
- }
+ void ListeningReporter::addListener( IStreamingReporterPtr&& listener ) {
+ m_listeners.push_back( std::move( listener ) );
+ }
- std::ostream& stream;
- AssertionStats const& stats;
- AssertionResult const& result;
- Colour::Code colour;
- std::string passOrFail;
- std::string messageLabel;
- std::string message;
- std::vector<MessageInfo> messages;
- bool printInfoMessages;
- };
+ void ListeningReporter::addReporter(IStreamingReporterPtr&& reporter) {
+ assert(!m_reporter && "Listening reporter can wrap only 1 real reporter");
+ m_reporter = std::move( reporter );
+ m_preferences.shouldRedirectStdOut = m_reporter->getPreferences().shouldRedirectStdOut;
+ }
- void lazyPrint() {
+ ReporterPreferences ListeningReporter::getPreferences() const {
+ return m_preferences;
+ }
- if( !currentTestRunInfo.used )
- lazyPrintRunInfo();
- if( !currentGroupInfo.used )
- lazyPrintGroupInfo();
+ std::set<Verbosity> ListeningReporter::getSupportedVerbosities() {
+ return std::set<Verbosity>{ };
+ }
- if( !m_headerPrinted ) {
- printTestCaseAndSectionHeader();
- m_headerPrinted = true;
- }
+ void ListeningReporter::noMatchingTestCases( std::string const& spec ) {
+ for ( auto const& listener : m_listeners ) {
+ listener->noMatchingTestCases( spec );
}
- void lazyPrintRunInfo() {
- stream << '\n' << getLineOfChars<'~'>() << '\n';
- Colour colour( Colour::SecondaryText );
- stream << currentTestRunInfo->name
- << " is a Catch v" << libraryVersion() << " host application.\n"
- << "Run with -? for options\n\n";
+ m_reporter->noMatchingTestCases( spec );
+ }
- if( m_config->rngSeed() != 0 )
- stream << "Randomness seeded to: " << m_config->rngSeed() << "\n\n";
+ void ListeningReporter::reportInvalidArguments(std::string const&arg){
+ for ( auto const& listener : m_listeners ) {
+ listener->reportInvalidArguments( arg );
+ }
+ m_reporter->reportInvalidArguments( arg );
+ }
- currentTestRunInfo.used = true;
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+ void ListeningReporter::benchmarkPreparing( std::string const& name ) {
+ for (auto const& listener : m_listeners) {
+ listener->benchmarkPreparing(name);
+ }
+ m_reporter->benchmarkPreparing(name);
+ }
+ void ListeningReporter::benchmarkStarting( BenchmarkInfo const& benchmarkInfo ) {
+ for ( auto const& listener : m_listeners ) {
+ listener->benchmarkStarting( benchmarkInfo );
}
- void lazyPrintGroupInfo() {
- if( !currentGroupInfo->name.empty() && currentGroupInfo->groupsCounts > 1 ) {
- printClosedHeader( "Group: " + currentGroupInfo->name );
- currentGroupInfo.used = true;
- }
+ m_reporter->benchmarkStarting( benchmarkInfo );
+ }
+ void ListeningReporter::benchmarkEnded( BenchmarkStats<> const& benchmarkStats ) {
+ for ( auto const& listener : m_listeners ) {
+ listener->benchmarkEnded( benchmarkStats );
}
- void printTestCaseAndSectionHeader() {
- assert( !m_sectionStack.empty() );
- printOpenHeader( currentTestCaseInfo->name );
-
- if( m_sectionStack.size() > 1 ) {
- Colour colourGuard( Colour::Headers );
-
- std::vector<SectionInfo>::const_iterator
- it = m_sectionStack.begin()+1, // Skip first section (test case)
- itEnd = m_sectionStack.end();
- for( ; it != itEnd; ++it )
- printHeaderString( it->name, 2 );
- }
+ m_reporter->benchmarkEnded( benchmarkStats );
+ }
- SourceLineInfo lineInfo = m_sectionStack.back().lineInfo;
+ void ListeningReporter::benchmarkFailed( std::string const& error ) {
+ for (auto const& listener : m_listeners) {
+ listener->benchmarkFailed(error);
+ }
+ m_reporter->benchmarkFailed(error);
+ }
+#endif // CATCH_CONFIG_ENABLE_BENCHMARKING
- if( !lineInfo.empty() ){
- stream << getLineOfChars<'-'>() << '\n';
- Colour colourGuard( Colour::FileName );
- stream << lineInfo << '\n';
- }
- stream << getLineOfChars<'.'>() << '\n' << std::endl;
+ void ListeningReporter::testRunStarting( TestRunInfo const& testRunInfo ) {
+ for ( auto const& listener : m_listeners ) {
+ listener->testRunStarting( testRunInfo );
}
+ m_reporter->testRunStarting( testRunInfo );
+ }
- void printClosedHeader( std::string const& _name ) {
- printOpenHeader( _name );
- stream << getLineOfChars<'.'>() << '\n';
- }
- void printOpenHeader( std::string const& _name ) {
- stream << getLineOfChars<'-'>() << '\n';
- {
- Colour colourGuard( Colour::Headers );
- printHeaderString( _name );
- }
+ void ListeningReporter::testGroupStarting( GroupInfo const& groupInfo ) {
+ for ( auto const& listener : m_listeners ) {
+ listener->testGroupStarting( groupInfo );
}
+ m_reporter->testGroupStarting( groupInfo );
+ }
- // if string has a : in first line will set indent to follow it on
- // subsequent lines
- void printHeaderString( std::string const& _string, std::size_t indent = 0 ) {
- std::size_t i = _string.find( ": " );
- if( i != std::string::npos )
- i+=2;
- else
- i = 0;
- stream << Text( _string, TextAttributes()
- .setIndent( indent+i)
- .setInitialIndent( indent ) ) << '\n';
+ void ListeningReporter::testCaseStarting( TestCaseInfo const& testInfo ) {
+ for ( auto const& listener : m_listeners ) {
+ listener->testCaseStarting( testInfo );
}
+ m_reporter->testCaseStarting( testInfo );
+ }
- struct SummaryColumn {
-
- SummaryColumn( std::string const& _label, Colour::Code _colour )
- : label( _label ),
- colour( _colour )
- {}
- SummaryColumn addRow( std::size_t count ) {
- std::ostringstream oss;
- oss << count;
- std::string row = oss.str();
- for( std::vector<std::string>::iterator it = rows.begin(); it != rows.end(); ++it ) {
- while( it->size() < row.size() )
- *it = ' ' + *it;
- while( it->size() > row.size() )
- row = ' ' + row;
- }
- rows.push_back( row );
- return *this;
- }
-
- std::string label;
- Colour::Code colour;
- std::vector<std::string> rows;
+ void ListeningReporter::sectionStarting( SectionInfo const& sectionInfo ) {
+ for ( auto const& listener : m_listeners ) {
+ listener->sectionStarting( sectionInfo );
+ }
+ m_reporter->sectionStarting( sectionInfo );
+ }
- };
+ void ListeningReporter::assertionStarting( AssertionInfo const& assertionInfo ) {
+ for ( auto const& listener : m_listeners ) {
+ listener->assertionStarting( assertionInfo );
+ }
+ m_reporter->assertionStarting( assertionInfo );
+ }
- void printTotals( Totals const& totals ) {
- if( totals.testCases.total() == 0 ) {
- stream << Colour( Colour::Warning ) << "No tests ran\n";
- }
- else if( totals.assertions.total() > 0 && totals.testCases.allPassed() ) {
- stream << Colour( Colour::ResultSuccess ) << "All tests passed";
- stream << " ("
- << pluralise( totals.assertions.passed, "assertion" ) << " in "
- << pluralise( totals.testCases.passed, "test case" ) << ')'
- << '\n';
- }
- else {
+ // The return value indicates if the messages buffer should be cleared:
+ bool ListeningReporter::assertionEnded( AssertionStats const& assertionStats ) {
+ for( auto const& listener : m_listeners ) {
+ static_cast<void>( listener->assertionEnded( assertionStats ) );
+ }
+ return m_reporter->assertionEnded( assertionStats );
+ }
- std::vector<SummaryColumn> columns;
- columns.push_back( SummaryColumn( "", Colour::None )
- .addRow( totals.testCases.total() )
- .addRow( totals.assertions.total() ) );
- columns.push_back( SummaryColumn( "passed", Colour::Success )
- .addRow( totals.testCases.passed )
- .addRow( totals.assertions.passed ) );
- columns.push_back( SummaryColumn( "failed", Colour::ResultError )
- .addRow( totals.testCases.failed )
- .addRow( totals.assertions.failed ) );
- columns.push_back( SummaryColumn( "failed as expected", Colour::ResultExpectedFailure )
- .addRow( totals.testCases.failedButOk )
- .addRow( totals.assertions.failedButOk ) );
-
- printSummaryRow( "test cases", columns, 0 );
- printSummaryRow( "assertions", columns, 1 );
- }
- }
- void printSummaryRow( std::string const& label, std::vector<SummaryColumn> const& cols, std::size_t row ) {
- for( std::vector<SummaryColumn>::const_iterator it = cols.begin(); it != cols.end(); ++it ) {
- std::string value = it->rows[row];
- if( it->label.empty() ) {
- stream << label << ": ";
- if( value != "0" )
- stream << value;
- else
- stream << Colour( Colour::Warning ) << "- none -";
- }
- else if( value != "0" ) {
- stream << Colour( Colour::LightGrey ) << " | ";
- stream << Colour( it->colour )
- << value << ' ' << it->label;
- }
- }
- stream << '\n';
+ void ListeningReporter::sectionEnded( SectionStats const& sectionStats ) {
+ for ( auto const& listener : m_listeners ) {
+ listener->sectionEnded( sectionStats );
}
+ m_reporter->sectionEnded( sectionStats );
+ }
- static std::size_t makeRatio( std::size_t number, std::size_t total ) {
- std::size_t ratio = total > 0 ? CATCH_CONFIG_CONSOLE_WIDTH * number/ total : 0;
- return ( ratio == 0 && number > 0 ) ? 1 : ratio;
+ void ListeningReporter::testCaseEnded( TestCaseStats const& testCaseStats ) {
+ for ( auto const& listener : m_listeners ) {
+ listener->testCaseEnded( testCaseStats );
}
- static std::size_t& findMax( std::size_t& i, std::size_t& j, std::size_t& k ) {
- if( i > j && i > k )
- return i;
- else if( j > k )
- return j;
- else
- return k;
- }
-
- void printTotalsDivider( Totals const& totals ) {
- if( totals.testCases.total() > 0 ) {
- std::size_t failedRatio = makeRatio( totals.testCases.failed, totals.testCases.total() );
- std::size_t failedButOkRatio = makeRatio( totals.testCases.failedButOk, totals.testCases.total() );
- std::size_t passedRatio = makeRatio( totals.testCases.passed, totals.testCases.total() );
- while( failedRatio + failedButOkRatio + passedRatio < CATCH_CONFIG_CONSOLE_WIDTH-1 )
- findMax( failedRatio, failedButOkRatio, passedRatio )++;
- while( failedRatio + failedButOkRatio + passedRatio > CATCH_CONFIG_CONSOLE_WIDTH-1 )
- findMax( failedRatio, failedButOkRatio, passedRatio )--;
-
- stream << Colour( Colour::Error ) << std::string( failedRatio, '=' );
- stream << Colour( Colour::ResultExpectedFailure ) << std::string( failedButOkRatio, '=' );
- if( totals.testCases.allPassed() )
- stream << Colour( Colour::ResultSuccess ) << std::string( passedRatio, '=' );
- else
- stream << Colour( Colour::Success ) << std::string( passedRatio, '=' );
- }
- else {
- stream << Colour( Colour::Warning ) << std::string( CATCH_CONFIG_CONSOLE_WIDTH-1, '=' );
- }
- stream << '\n';
+ m_reporter->testCaseEnded( testCaseStats );
+ }
+
+ void ListeningReporter::testGroupEnded( TestGroupStats const& testGroupStats ) {
+ for ( auto const& listener : m_listeners ) {
+ listener->testGroupEnded( testGroupStats );
}
- void printSummaryDivider() {
- stream << getLineOfChars<'-'>() << '\n';
+ m_reporter->testGroupEnded( testGroupStats );
+ }
+
+ void ListeningReporter::testRunEnded( TestRunStats const& testRunStats ) {
+ for ( auto const& listener : m_listeners ) {
+ listener->testRunEnded( testRunStats );
}
+ m_reporter->testRunEnded( testRunStats );
+ }
- private:
- bool m_headerPrinted;
- };
+ void ListeningReporter::skipTest( TestCaseInfo const& testInfo ) {
+ for ( auto const& listener : m_listeners ) {
+ listener->skipTest( testInfo );
+ }
+ m_reporter->skipTest( testInfo );
+ }
- INTERNAL_CATCH_REGISTER_REPORTER( "console", ConsoleReporter )
+ bool ListeningReporter::isMulti() const {
+ return true;
+ }
} // end namespace Catch
+// end catch_reporter_listening.cpp
+// start catch_reporter_xml.cpp
-// #included from: ../reporters/catch_reporter_compact.hpp
-#define TWOBLUECUBES_CATCH_REPORTER_COMPACT_HPP_INCLUDED
+#if defined(_MSC_VER)
+#pragma warning(push)
+#pragma warning(disable:4061) // Not all labels are EXPLICITLY handled in switch
+ // Note that 4062 (not all labels are handled
+ // and default is missing) is enabled
+#endif
namespace Catch {
+ XmlReporter::XmlReporter( ReporterConfig const& _config )
+ : StreamingReporterBase( _config ),
+ m_xml(_config.stream())
+ {
+ m_reporterPrefs.shouldRedirectStdOut = true;
+ m_reporterPrefs.shouldReportAllAssertions = true;
+ }
- struct CompactReporter : StreamingReporterBase {
+ XmlReporter::~XmlReporter() = default;
- CompactReporter( ReporterConfig const& _config )
- : StreamingReporterBase( _config )
- {}
+ std::string XmlReporter::getDescription() {
+ return "Reports test results as an XML document";
+ }
- virtual ~CompactReporter();
+ std::string XmlReporter::getStylesheetRef() const {
+ return std::string();
+ }
- static std::string getDescription() {
- return "Reports test results on a single line, suitable for IDEs";
- }
+ void XmlReporter::writeSourceInfo( SourceLineInfo const& sourceInfo ) {
+ m_xml
+ .writeAttribute( "filename", sourceInfo.file )
+ .writeAttribute( "line", sourceInfo.line );
+ }
- virtual ReporterPreferences getPreferences() const {
- ReporterPreferences prefs;
- prefs.shouldRedirectStdOut = false;
- return prefs;
- }
+ void XmlReporter::noMatchingTestCases( std::string const& s ) {
+ StreamingReporterBase::noMatchingTestCases( s );
+ }
- virtual void noMatchingTestCases( std::string const& spec ) {
- stream << "No test cases matched '" << spec << '\'' << std::endl;
- }
+ void XmlReporter::testRunStarting( TestRunInfo const& testInfo ) {
+ StreamingReporterBase::testRunStarting( testInfo );
+ std::string stylesheetRef = getStylesheetRef();
+ if( !stylesheetRef.empty() )
+ m_xml.writeStylesheetRef( stylesheetRef );
+ m_xml.startElement( "Catch" );
+ if( !m_config->name().empty() )
+ m_xml.writeAttribute( "name", m_config->name() );
+ if (m_config->testSpec().hasFilters())
+ m_xml.writeAttribute( "filters", serializeFilters( m_config->getTestsOrTags() ) );
+ if( m_config->rngSeed() != 0 )
+ m_xml.scopedElement( "Randomness" )
+ .writeAttribute( "seed", m_config->rngSeed() );
+ }
- virtual void assertionStarting( AssertionInfo const& ) {}
+ void XmlReporter::testGroupStarting( GroupInfo const& groupInfo ) {
+ StreamingReporterBase::testGroupStarting( groupInfo );
+ m_xml.startElement( "Group" )
+ .writeAttribute( "name", groupInfo.name );
+ }
- virtual bool assertionEnded( AssertionStats const& _assertionStats ) {
- AssertionResult const& result = _assertionStats.assertionResult;
+ void XmlReporter::testCaseStarting( TestCaseInfo const& testInfo ) {
+ StreamingReporterBase::testCaseStarting(testInfo);
+ m_xml.startElement( "TestCase" )
+ .writeAttribute( "name", trim( testInfo.name ) )
+ .writeAttribute( "description", testInfo.description )
+ .writeAttribute( "tags", testInfo.tagsAsString() );
- bool printInfoMessages = true;
+ writeSourceInfo( testInfo.lineInfo );
- // Drop out if result was successful and we're not printing those
- if( !m_config->includeSuccessfulResults() && result.isOk() ) {
- if( result.getResultType() != ResultWas::Warning )
- return false;
- printInfoMessages = false;
- }
-
- AssertionPrinter printer( stream, _assertionStats, printInfoMessages );
- printer.print();
+ if ( m_config->showDurations() == ShowDurations::Always )
+ m_testCaseTimer.start();
+ m_xml.ensureTagClosed();
+ }
- stream << std::endl;
- return true;
+ void XmlReporter::sectionStarting( SectionInfo const& sectionInfo ) {
+ StreamingReporterBase::sectionStarting( sectionInfo );
+ if( m_sectionDepth++ > 0 ) {
+ m_xml.startElement( "Section" )
+ .writeAttribute( "name", trim( sectionInfo.name ) );
+ writeSourceInfo( sectionInfo.lineInfo );
+ m_xml.ensureTagClosed();
}
+ }
- virtual void sectionEnded(SectionStats const& _sectionStats) CATCH_OVERRIDE {
- if (m_config->showDurations() == ShowDurations::Always) {
- stream << getFormattedDuration(_sectionStats.durationInSeconds) << " s: " << _sectionStats.sectionInfo.name << std::endl;
- }
- }
+ void XmlReporter::assertionStarting( AssertionInfo const& ) { }
- virtual void testRunEnded( TestRunStats const& _testRunStats ) {
- printTotals( _testRunStats.totals );
- stream << '\n' << std::endl;
- StreamingReporterBase::testRunEnded( _testRunStats );
- }
+ bool XmlReporter::assertionEnded( AssertionStats const& assertionStats ) {
- private:
- class AssertionPrinter {
- void operator= ( AssertionPrinter const& );
- public:
- AssertionPrinter( std::ostream& _stream, AssertionStats const& _stats, bool _printInfoMessages )
- : stream( _stream )
- , stats( _stats )
- , result( _stats.assertionResult )
- , messages( _stats.infoMessages )
- , itMessage( _stats.infoMessages.begin() )
- , printInfoMessages( _printInfoMessages )
- {}
+ AssertionResult const& result = assertionStats.assertionResult;
- void print() {
- printSourceInfo();
-
- itMessage = messages.begin();
-
- switch( result.getResultType() ) {
- case ResultWas::Ok:
- printResultType( Colour::ResultSuccess, passedString() );
- printOriginalExpression();
- printReconstructedExpression();
- if ( ! result.hasExpression() )
- printRemainingMessages( Colour::None );
- else
- printRemainingMessages();
- break;
- case ResultWas::ExpressionFailed:
- if( result.isOk() )
- printResultType( Colour::ResultSuccess, failedString() + std::string( " - but was ok" ) );
- else
- printResultType( Colour::Error, failedString() );
- printOriginalExpression();
- printReconstructedExpression();
- printRemainingMessages();
- break;
- case ResultWas::ThrewException:
- printResultType( Colour::Error, failedString() );
- printIssue( "unexpected exception with message:" );
- printMessage();
- printExpressionWas();
- printRemainingMessages();
- break;
- case ResultWas::FatalErrorCondition:
- printResultType( Colour::Error, failedString() );
- printIssue( "fatal error condition with message:" );
- printMessage();
- printExpressionWas();
- printRemainingMessages();
- break;
- case ResultWas::DidntThrowException:
- printResultType( Colour::Error, failedString() );
- printIssue( "expected exception, got none" );
- printExpressionWas();
- printRemainingMessages();
- break;
- case ResultWas::Info:
- printResultType( Colour::None, "info" );
- printMessage();
- printRemainingMessages();
- break;
- case ResultWas::Warning:
- printResultType( Colour::None, "warning" );
- printMessage();
- printRemainingMessages();
- break;
- case ResultWas::ExplicitFailure:
- printResultType( Colour::Error, failedString() );
- printIssue( "explicitly" );
- printRemainingMessages( Colour::None );
- break;
- // These cases are here to prevent compiler warnings
- case ResultWas::Unknown:
- case ResultWas::FailureBit:
- case ResultWas::Exception:
- printResultType( Colour::Error, "** internal error **" );
- break;
+ bool includeResults = m_config->includeSuccessfulResults() || !result.isOk();
+
+ if( includeResults || result.getResultType() == ResultWas::Warning ) {
+ // Print any info messages in <Info> tags.
+ for( auto const& msg : assertionStats.infoMessages ) {
+ if( msg.type == ResultWas::Info && includeResults ) {
+ m_xml.scopedElement( "Info" )
+ .writeText( msg.message );
+ } else if ( msg.type == ResultWas::Warning ) {
+ m_xml.scopedElement( "Warning" )
+ .writeText( msg.message );
}
}
+ }
- private:
- // Colour::LightGrey
+ // Drop out if result was successful but we're not printing them.
+ if( !includeResults && result.getResultType() != ResultWas::Warning )
+ return true;
- static Colour::Code dimColour() { return Colour::FileName; }
+ // Print the expression if there is one.
+ if( result.hasExpression() ) {
+ m_xml.startElement( "Expression" )
+ .writeAttribute( "success", result.succeeded() )
+ .writeAttribute( "type", result.getTestMacroName() );
-#ifdef CATCH_PLATFORM_MAC
- static const char* failedString() { return "FAILED"; }
- static const char* passedString() { return "PASSED"; }
-#else
- static const char* failedString() { return "failed"; }
- static const char* passedString() { return "passed"; }
-#endif
-
- void printSourceInfo() const {
- Colour colourGuard( Colour::FileName );
- stream << result.getSourceInfo() << ':';
- }
+ writeSourceInfo( result.getSourceInfo() );
- void printResultType( Colour::Code colour, std::string const& passOrFail ) const {
- if( !passOrFail.empty() ) {
- {
- Colour colourGuard( colour );
- stream << ' ' << passOrFail;
- }
- stream << ':';
- }
- }
+ m_xml.scopedElement( "Original" )
+ .writeText( result.getExpression() );
+ m_xml.scopedElement( "Expanded" )
+ .writeText( result.getExpandedExpression() );
+ }
- void printIssue( std::string const& issue ) const {
- stream << ' ' << issue;
- }
+ // And... Print a result applicable to each result type.
+ switch( result.getResultType() ) {
+ case ResultWas::ThrewException:
+ m_xml.startElement( "Exception" );
+ writeSourceInfo( result.getSourceInfo() );
+ m_xml.writeText( result.getMessage() );
+ m_xml.endElement();
+ break;
+ case ResultWas::FatalErrorCondition:
+ m_xml.startElement( "FatalErrorCondition" );
+ writeSourceInfo( result.getSourceInfo() );
+ m_xml.writeText( result.getMessage() );
+ m_xml.endElement();
+ break;
+ case ResultWas::Info:
+ m_xml.scopedElement( "Info" )
+ .writeText( result.getMessage() );
+ break;
+ case ResultWas::Warning:
+ // Warning will already have been written
+ break;
+ case ResultWas::ExplicitFailure:
+ m_xml.startElement( "Failure" );
+ writeSourceInfo( result.getSourceInfo() );
+ m_xml.writeText( result.getMessage() );
+ m_xml.endElement();
+ break;
+ default:
+ break;
+ }
- void printExpressionWas() {
- if( result.hasExpression() ) {
- stream << ';';
- {
- Colour colour( dimColour() );
- stream << " expression was:";
- }
- printOriginalExpression();
- }
- }
+ if( result.hasExpression() )
+ m_xml.endElement();
- void printOriginalExpression() const {
- if( result.hasExpression() ) {
- stream << ' ' << result.getExpression();
- }
- }
+ return true;
+ }
- void printReconstructedExpression() const {
- if( result.hasExpandedExpression() ) {
- {
- Colour colour( dimColour() );
- stream << " for: ";
- }
- stream << result.getExpandedExpression();
- }
- }
+ void XmlReporter::sectionEnded( SectionStats const& sectionStats ) {
+ StreamingReporterBase::sectionEnded( sectionStats );
+ if( --m_sectionDepth > 0 ) {
+ XmlWriter::ScopedElement e = m_xml.scopedElement( "OverallResults" );
+ e.writeAttribute( "successes", sectionStats.assertions.passed );
+ e.writeAttribute( "failures", sectionStats.assertions.failed );
+ e.writeAttribute( "expectedFailures", sectionStats.assertions.failedButOk );
- void printMessage() {
- if ( itMessage != messages.end() ) {
- stream << " '" << itMessage->message << '\'';
- ++itMessage;
- }
- }
+ if ( m_config->showDurations() == ShowDurations::Always )
+ e.writeAttribute( "durationInSeconds", sectionStats.durationInSeconds );
- void printRemainingMessages( Colour::Code colour = dimColour() ) {
- if ( itMessage == messages.end() )
- return;
+ m_xml.endElement();
+ }
+ }
- // using messages.end() directly yields compilation error:
- std::vector<MessageInfo>::const_iterator itEnd = messages.end();
- const std::size_t N = static_cast<std::size_t>( std::distance( itMessage, itEnd ) );
+ void XmlReporter::testCaseEnded( TestCaseStats const& testCaseStats ) {
+ StreamingReporterBase::testCaseEnded( testCaseStats );
+ XmlWriter::ScopedElement e = m_xml.scopedElement( "OverallResult" );
+ e.writeAttribute( "success", testCaseStats.totals.assertions.allOk() );
- {
- Colour colourGuard( colour );
- stream << " with " << pluralise( N, "message" ) << ':';
- }
+ if ( m_config->showDurations() == ShowDurations::Always )
+ e.writeAttribute( "durationInSeconds", m_testCaseTimer.getElapsedSeconds() );
- for(; itMessage != itEnd; ) {
- // If this assertion is a warning ignore any INFO messages
- if( printInfoMessages || itMessage->type != ResultWas::Info ) {
- stream << " '" << itMessage->message << '\'';
- if ( ++itMessage != itEnd ) {
- Colour colourGuard( dimColour() );
- stream << " and";
- }
- }
- }
- }
+ if( !testCaseStats.stdOut.empty() )
+ m_xml.scopedElement( "StdOut" ).writeText( trim( testCaseStats.stdOut ), false );
+ if( !testCaseStats.stdErr.empty() )
+ m_xml.scopedElement( "StdErr" ).writeText( trim( testCaseStats.stdErr ), false );
- private:
- std::ostream& stream;
- AssertionStats const& stats;
- AssertionResult const& result;
- std::vector<MessageInfo> messages;
- std::vector<MessageInfo>::const_iterator itMessage;
- bool printInfoMessages;
- };
+ m_xml.endElement();
+ }
- // Colour, message variants:
- // - white: No tests ran.
- // - red: Failed [both/all] N test cases, failed [both/all] M assertions.
- // - white: Passed [both/all] N test cases (no assertions).
- // - red: Failed N tests cases, failed M assertions.
- // - green: Passed [both/all] N tests cases with M assertions.
-
- std::string bothOrAll( std::size_t count ) const {
- return count == 1 ? std::string() : count == 2 ? "both " : "all " ;
- }
-
- void printTotals( const Totals& totals ) const {
- if( totals.testCases.total() == 0 ) {
- stream << "No tests ran.";
- }
- else if( totals.testCases.failed == totals.testCases.total() ) {
- Colour colour( Colour::ResultError );
- const std::string qualify_assertions_failed =
- totals.assertions.failed == totals.assertions.total() ?
- bothOrAll( totals.assertions.failed ) : std::string();
- stream <<
- "Failed " << bothOrAll( totals.testCases.failed )
- << pluralise( totals.testCases.failed, "test case" ) << ", "
- "failed " << qualify_assertions_failed <<
- pluralise( totals.assertions.failed, "assertion" ) << '.';
- }
- else if( totals.assertions.total() == 0 ) {
- stream <<
- "Passed " << bothOrAll( totals.testCases.total() )
- << pluralise( totals.testCases.total(), "test case" )
- << " (no assertions).";
- }
- else if( totals.assertions.failed ) {
- Colour colour( Colour::ResultError );
- stream <<
- "Failed " << pluralise( totals.testCases.failed, "test case" ) << ", "
- "failed " << pluralise( totals.assertions.failed, "assertion" ) << '.';
- }
- else {
- Colour colour( Colour::ResultSuccess );
- stream <<
- "Passed " << bothOrAll( totals.testCases.passed )
- << pluralise( totals.testCases.passed, "test case" ) <<
- " with " << pluralise( totals.assertions.passed, "assertion" ) << '.';
- }
- }
- };
+ void XmlReporter::testGroupEnded( TestGroupStats const& testGroupStats ) {
+ StreamingReporterBase::testGroupEnded( testGroupStats );
+ // TODO: Check testGroupStats.aborting and act accordingly.
+ m_xml.scopedElement( "OverallResults" )
+ .writeAttribute( "successes", testGroupStats.totals.assertions.passed )
+ .writeAttribute( "failures", testGroupStats.totals.assertions.failed )
+ .writeAttribute( "expectedFailures", testGroupStats.totals.assertions.failedButOk );
+ m_xml.endElement();
+ }
- INTERNAL_CATCH_REGISTER_REPORTER( "compact", CompactReporter )
+ void XmlReporter::testRunEnded( TestRunStats const& testRunStats ) {
+ StreamingReporterBase::testRunEnded( testRunStats );
+ m_xml.scopedElement( "OverallResults" )
+ .writeAttribute( "successes", testRunStats.totals.assertions.passed )
+ .writeAttribute( "failures", testRunStats.totals.assertions.failed )
+ .writeAttribute( "expectedFailures", testRunStats.totals.assertions.failedButOk );
+ m_xml.endElement();
+ }
-} // end namespace Catch
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+ void XmlReporter::benchmarkPreparing(std::string const& name) {
+ m_xml.startElement("BenchmarkResults")
+ .writeAttribute("name", name);
+ }
-namespace Catch {
- // These are all here to avoid warnings about not having any out of line
- // virtual methods
- NonCopyable::~NonCopyable() {}
- IShared::~IShared() {}
- IStream::~IStream() CATCH_NOEXCEPT {}
- FileStream::~FileStream() CATCH_NOEXCEPT {}
- CoutStream::~CoutStream() CATCH_NOEXCEPT {}
- DebugOutStream::~DebugOutStream() CATCH_NOEXCEPT {}
- StreamBufBase::~StreamBufBase() CATCH_NOEXCEPT {}
- IContext::~IContext() {}
- IResultCapture::~IResultCapture() {}
- ITestCase::~ITestCase() {}
- ITestCaseRegistry::~ITestCaseRegistry() {}
- IRegistryHub::~IRegistryHub() {}
- IMutableRegistryHub::~IMutableRegistryHub() {}
- IExceptionTranslator::~IExceptionTranslator() {}
- IExceptionTranslatorRegistry::~IExceptionTranslatorRegistry() {}
- IReporter::~IReporter() {}
- IReporterFactory::~IReporterFactory() {}
- IReporterRegistry::~IReporterRegistry() {}
- IStreamingReporter::~IStreamingReporter() {}
- AssertionStats::~AssertionStats() {}
- SectionStats::~SectionStats() {}
- TestCaseStats::~TestCaseStats() {}
- TestGroupStats::~TestGroupStats() {}
- TestRunStats::~TestRunStats() {}
- CumulativeReporterBase::SectionNode::~SectionNode() {}
- CumulativeReporterBase::~CumulativeReporterBase() {}
-
- StreamingReporterBase::~StreamingReporterBase() {}
- ConsoleReporter::~ConsoleReporter() {}
- CompactReporter::~CompactReporter() {}
- IRunner::~IRunner() {}
- IMutableContext::~IMutableContext() {}
- IConfig::~IConfig() {}
- XmlReporter::~XmlReporter() {}
- JunitReporter::~JunitReporter() {}
- TestRegistry::~TestRegistry() {}
- FreeFunctionTestCase::~FreeFunctionTestCase() {}
- IGeneratorInfo::~IGeneratorInfo() {}
- IGeneratorsForTest::~IGeneratorsForTest() {}
- WildcardPattern::~WildcardPattern() {}
- TestSpec::Pattern::~Pattern() {}
- TestSpec::NamePattern::~NamePattern() {}
- TestSpec::TagPattern::~TagPattern() {}
- TestSpec::ExcludedPattern::~ExcludedPattern() {}
- Matchers::Impl::MatcherUntypedBase::~MatcherUntypedBase() {}
+ void XmlReporter::benchmarkStarting(BenchmarkInfo const &info) {
+ m_xml.writeAttribute("samples", info.samples)
+ .writeAttribute("resamples", info.resamples)
+ .writeAttribute("iterations", info.iterations)
+ .writeAttribute("clockResolution", static_cast<uint64_t>(info.clockResolution))
+ .writeAttribute("estimatedDuration", static_cast<uint64_t>(info.estimatedDuration))
+ .writeComment("All values in nano seconds");
+ }
- void Config::dummy() {}
+ void XmlReporter::benchmarkEnded(BenchmarkStats<> const& benchmarkStats) {
+ m_xml.startElement("mean")
+ .writeAttribute("value", static_cast<uint64_t>(benchmarkStats.mean.point.count()))
+ .writeAttribute("lowerBound", static_cast<uint64_t>(benchmarkStats.mean.lower_bound.count()))
+ .writeAttribute("upperBound", static_cast<uint64_t>(benchmarkStats.mean.upper_bound.count()))
+ .writeAttribute("ci", benchmarkStats.mean.confidence_interval);
+ m_xml.endElement();
+ m_xml.startElement("standardDeviation")
+ .writeAttribute("value", benchmarkStats.standardDeviation.point.count())
+ .writeAttribute("lowerBound", benchmarkStats.standardDeviation.lower_bound.count())
+ .writeAttribute("upperBound", benchmarkStats.standardDeviation.upper_bound.count())
+ .writeAttribute("ci", benchmarkStats.standardDeviation.confidence_interval);
+ m_xml.endElement();
+ m_xml.startElement("outliers")
+ .writeAttribute("variance", benchmarkStats.outlierVariance)
+ .writeAttribute("lowMild", benchmarkStats.outliers.low_mild)
+ .writeAttribute("lowSevere", benchmarkStats.outliers.low_severe)
+ .writeAttribute("highMild", benchmarkStats.outliers.high_mild)
+ .writeAttribute("highSevere", benchmarkStats.outliers.high_severe);
+ m_xml.endElement();
+ m_xml.endElement();
+ }
- namespace TestCaseTracking {
- ITracker::~ITracker() {}
- TrackerBase::~TrackerBase() {}
- SectionTracker::~SectionTracker() {}
- IndexTracker::~IndexTracker() {}
+ void XmlReporter::benchmarkFailed(std::string const &error) {
+ m_xml.scopedElement("failed").
+ writeAttribute("message", error);
+ m_xml.endElement();
}
+#endif // CATCH_CONFIG_ENABLE_BENCHMARKING
+
+ CATCH_REGISTER_REPORTER( "xml", XmlReporter )
+
+} // end namespace Catch
+
+#if defined(_MSC_VER)
+#pragma warning(pop)
+#endif
+// end catch_reporter_xml.cpp
+
+namespace Catch {
+ LeakDetector leakDetector;
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif
+// end catch_impl.hpp
#endif
#ifdef CATCH_CONFIG_MAIN
-// #included from: internal/catch_default_main.hpp
-#define TWOBLUECUBES_CATCH_DEFAULT_MAIN_HPP_INCLUDED
+// start catch_default_main.hpp
#ifndef __OBJC__
-#if defined(WIN32) && defined(_UNICODE) && !defined(DO_NOT_USE_WMAIN)
+#if defined(CATCH_CONFIG_WCHAR) && defined(WIN32) && defined(_UNICODE) && !defined(DO_NOT_USE_WMAIN)
// Standard C/C++ Win32 Unicode wmain entry point
extern "C" int wmain (int argc, wchar_t * argv[], wchar_t * []) {
#else
@@ -11303,8 +17034,7 @@ extern "C" int wmain (int argc, wchar_t * argv[], wchar_t * []) {
int main (int argc, char * argv[]) {
#endif
- int result = Catch::Session().run(argc, argv);
- return ( result < 0xff ? result : 0xff );
+ return Catch::Session().run( argc, argv );
}
#else // __OBJC__
@@ -11316,193 +17046,425 @@ int main (int argc, char * const argv[]) {
#endif
Catch::registerTestMethods();
- int result = Catch::Session().run( argc, (char* const*)argv );
+ int result = Catch::Session().run( argc, (char**)argv );
#if !CATCH_ARC_ENABLED
[pool drain];
#endif
- return ( result < 0xff ? result : 0xff );
+ return result;
}
#endif // __OBJC__
+// end catch_default_main.hpp
#endif
+#if !defined(CATCH_CONFIG_IMPL_ONLY)
+
#ifdef CLARA_CONFIG_MAIN_NOT_DEFINED
# undef CLARA_CONFIG_MAIN
#endif
+#if !defined(CATCH_CONFIG_DISABLE)
//////
-
// If this config identifier is defined then all CATCH macros are prefixed with CATCH_
#ifdef CATCH_CONFIG_PREFIX_ALL
-#if defined(CATCH_CONFIG_FAST_COMPILE)
-#define CATCH_REQUIRE( expr ) INTERNAL_CATCH_TEST_NO_TRY( "CATCH_REQUIRE", Catch::ResultDisposition::Normal, expr )
-#define CATCH_REQUIRE_FALSE( expr ) INTERNAL_CATCH_TEST_NO_TRY( "CATCH_REQUIRE_FALSE", Catch::ResultDisposition::Normal | Catch::ResultDisposition::FalseTest, expr )
-#else
-#define CATCH_REQUIRE( expr ) INTERNAL_CATCH_TEST( "CATCH_REQUIRE", Catch::ResultDisposition::Normal, expr )
-#define CATCH_REQUIRE_FALSE( expr ) INTERNAL_CATCH_TEST( "CATCH_REQUIRE_FALSE", Catch::ResultDisposition::Normal | Catch::ResultDisposition::FalseTest, expr )
-#endif
+#define CATCH_REQUIRE( ... ) INTERNAL_CATCH_TEST( "CATCH_REQUIRE", Catch::ResultDisposition::Normal, __VA_ARGS__ )
+#define CATCH_REQUIRE_FALSE( ... ) INTERNAL_CATCH_TEST( "CATCH_REQUIRE_FALSE", Catch::ResultDisposition::Normal | Catch::ResultDisposition::FalseTest, __VA_ARGS__ )
-#define CATCH_REQUIRE_THROWS( expr ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, "", expr )
+#define CATCH_REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ )
#define CATCH_REQUIRE_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( "CATCH_REQUIRE_THROWS_AS", exceptionType, Catch::ResultDisposition::Normal, expr )
-#define CATCH_REQUIRE_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS( "CATCH_REQUIRE_THROWS_WITH", Catch::ResultDisposition::Normal, matcher, expr )
-#define CATCH_REQUIRE_NOTHROW( expr ) INTERNAL_CATCH_NO_THROW( "CATCH_REQUIRE_NOTHROW", Catch::ResultDisposition::Normal, expr )
-
-#define CATCH_CHECK( expr ) INTERNAL_CATCH_TEST( "CATCH_CHECK", Catch::ResultDisposition::ContinueOnFailure, expr )
-#define CATCH_CHECK_FALSE( expr ) INTERNAL_CATCH_TEST( "CATCH_CHECK_FALSE", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::FalseTest, expr )
-#define CATCH_CHECKED_IF( expr ) INTERNAL_CATCH_IF( "CATCH_CHECKED_IF", Catch::ResultDisposition::ContinueOnFailure, expr )
-#define CATCH_CHECKED_ELSE( expr ) INTERNAL_CATCH_ELSE( "CATCH_CHECKED_ELSE", Catch::ResultDisposition::ContinueOnFailure, expr )
-#define CATCH_CHECK_NOFAIL( expr ) INTERNAL_CATCH_TEST( "CATCH_CHECK_NOFAIL", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::SuppressFail, expr )
-
-#define CATCH_CHECK_THROWS( expr ) INTERNAL_CATCH_THROWS( "CATCH_CHECK_THROWS", Catch::ResultDisposition::ContinueOnFailure, "", expr )
+#define CATCH_REQUIRE_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS_STR_MATCHES( "CATCH_REQUIRE_THROWS_WITH", Catch::ResultDisposition::Normal, matcher, expr )
+#if !defined(CATCH_CONFIG_DISABLE_MATCHERS)
+#define CATCH_REQUIRE_THROWS_MATCHES( expr, exceptionType, matcher ) INTERNAL_CATCH_THROWS_MATCHES( "CATCH_REQUIRE_THROWS_MATCHES", exceptionType, Catch::ResultDisposition::Normal, matcher, expr )
+#endif// CATCH_CONFIG_DISABLE_MATCHERS
+#define CATCH_REQUIRE_NOTHROW( ... ) INTERNAL_CATCH_NO_THROW( "CATCH_REQUIRE_NOTHROW", Catch::ResultDisposition::Normal, __VA_ARGS__ )
+
+#define CATCH_CHECK( ... ) INTERNAL_CATCH_TEST( "CATCH_CHECK", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ )
+#define CATCH_CHECK_FALSE( ... ) INTERNAL_CATCH_TEST( "CATCH_CHECK_FALSE", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::FalseTest, __VA_ARGS__ )
+#define CATCH_CHECKED_IF( ... ) INTERNAL_CATCH_IF( "CATCH_CHECKED_IF", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ )
+#define CATCH_CHECKED_ELSE( ... ) INTERNAL_CATCH_ELSE( "CATCH_CHECKED_ELSE", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ )
+#define CATCH_CHECK_NOFAIL( ... ) INTERNAL_CATCH_TEST( "CATCH_CHECK_NOFAIL", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::SuppressFail, __VA_ARGS__ )
+
+#define CATCH_CHECK_THROWS( ... ) INTERNAL_CATCH_THROWS( "CATCH_CHECK_THROWS", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ )
#define CATCH_CHECK_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( "CATCH_CHECK_THROWS_AS", exceptionType, Catch::ResultDisposition::ContinueOnFailure, expr )
-#define CATCH_CHECK_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS( "CATCH_CHECK_THROWS_WITH", Catch::ResultDisposition::ContinueOnFailure, matcher, expr )
-#define CATCH_CHECK_NOTHROW( expr ) INTERNAL_CATCH_NO_THROW( "CATCH_CHECK_NOTHROW", Catch::ResultDisposition::ContinueOnFailure, expr )
+#define CATCH_CHECK_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS_STR_MATCHES( "CATCH_CHECK_THROWS_WITH", Catch::ResultDisposition::ContinueOnFailure, matcher, expr )
+#if !defined(CATCH_CONFIG_DISABLE_MATCHERS)
+#define CATCH_CHECK_THROWS_MATCHES( expr, exceptionType, matcher ) INTERNAL_CATCH_THROWS_MATCHES( "CATCH_CHECK_THROWS_MATCHES", exceptionType, Catch::ResultDisposition::ContinueOnFailure, matcher, expr )
+#endif // CATCH_CONFIG_DISABLE_MATCHERS
+#define CATCH_CHECK_NOTHROW( ... ) INTERNAL_CATCH_NO_THROW( "CATCH_CHECK_NOTHROW", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ )
+#if !defined(CATCH_CONFIG_DISABLE_MATCHERS)
#define CATCH_CHECK_THAT( arg, matcher ) INTERNAL_CHECK_THAT( "CATCH_CHECK_THAT", matcher, Catch::ResultDisposition::ContinueOnFailure, arg )
-#if defined(CATCH_CONFIG_FAST_COMPILE)
-#define CATCH_REQUIRE_THAT( arg, matcher ) INTERNAL_CHECK_THAT_NO_TRY( "CATCH_REQUIRE_THAT", matcher, Catch::ResultDisposition::Normal, arg )
-#else
#define CATCH_REQUIRE_THAT( arg, matcher ) INTERNAL_CHECK_THAT( "CATCH_REQUIRE_THAT", matcher, Catch::ResultDisposition::Normal, arg )
-#endif
+#endif // CATCH_CONFIG_DISABLE_MATCHERS
#define CATCH_INFO( msg ) INTERNAL_CATCH_INFO( "CATCH_INFO", msg )
+#define CATCH_UNSCOPED_INFO( msg ) INTERNAL_CATCH_UNSCOPED_INFO( "CATCH_UNSCOPED_INFO", msg )
#define CATCH_WARN( msg ) INTERNAL_CATCH_MSG( "CATCH_WARN", Catch::ResultWas::Warning, Catch::ResultDisposition::ContinueOnFailure, msg )
-#define CATCH_SCOPED_INFO( msg ) INTERNAL_CATCH_INFO( "CATCH_INFO", msg )
-#define CATCH_CAPTURE( msg ) INTERNAL_CATCH_INFO( "CATCH_CAPTURE", #msg " := " << Catch::toString(msg) )
-#define CATCH_SCOPED_CAPTURE( msg ) INTERNAL_CATCH_INFO( "CATCH_CAPTURE", #msg " := " << Catch::toString(msg) )
-
-#ifdef CATCH_CONFIG_VARIADIC_MACROS
- #define CATCH_TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE( __VA_ARGS__ )
- #define CATCH_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, __VA_ARGS__ )
- #define CATCH_METHOD_AS_TEST_CASE( method, ... ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, __VA_ARGS__ )
- #define CATCH_REGISTER_TEST_CASE( Function, ... ) INTERNAL_CATCH_REGISTER_TESTCASE( Function, __VA_ARGS__ )
- #define CATCH_SECTION( ... ) INTERNAL_CATCH_SECTION( __VA_ARGS__ )
- #define CATCH_FAIL( ... ) INTERNAL_CATCH_MSG( "CATCH_FAIL", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, __VA_ARGS__ )
- #define CATCH_FAIL_CHECK( ... ) INTERNAL_CATCH_MSG( "CATCH_FAIL_CHECK", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ )
- #define CATCH_SUCCEED( ... ) INTERNAL_CATCH_MSG( "CATCH_SUCCEED", Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ )
+#define CATCH_CAPTURE( ... ) INTERNAL_CATCH_CAPTURE( INTERNAL_CATCH_UNIQUE_NAME(capturer), "CATCH_CAPTURE",__VA_ARGS__ )
+
+#define CATCH_TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE( __VA_ARGS__ )
+#define CATCH_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, __VA_ARGS__ )
+#define CATCH_METHOD_AS_TEST_CASE( method, ... ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, __VA_ARGS__ )
+#define CATCH_REGISTER_TEST_CASE( Function, ... ) INTERNAL_CATCH_REGISTER_TESTCASE( Function, __VA_ARGS__ )
+#define CATCH_SECTION( ... ) INTERNAL_CATCH_SECTION( __VA_ARGS__ )
+#define CATCH_DYNAMIC_SECTION( ... ) INTERNAL_CATCH_DYNAMIC_SECTION( __VA_ARGS__ )
+#define CATCH_FAIL( ... ) INTERNAL_CATCH_MSG( "CATCH_FAIL", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, __VA_ARGS__ )
+#define CATCH_FAIL_CHECK( ... ) INTERNAL_CATCH_MSG( "CATCH_FAIL_CHECK", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ )
+#define CATCH_SUCCEED( ... ) INTERNAL_CATCH_MSG( "CATCH_SUCCEED", Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ )
+
+#define CATCH_ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE()
+
+#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+#define CATCH_TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ )
+#define CATCH_TEMPLATE_TEST_CASE_SIG( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_SIG( __VA_ARGS__ )
+#define CATCH_TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ )
+#define CATCH_TEMPLATE_TEST_CASE_METHOD_SIG( className, ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_SIG( className, __VA_ARGS__ )
+#define CATCH_TEMPLATE_PRODUCT_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE( __VA_ARGS__ )
+#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_SIG( ... ) INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_SIG( __VA_ARGS__ )
+#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, __VA_ARGS__ )
+#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_SIG( className, ... ) INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_SIG( className, __VA_ARGS__ )
#else
- #define CATCH_TEST_CASE( name, description ) INTERNAL_CATCH_TESTCASE( name, description )
- #define CATCH_TEST_CASE_METHOD( className, name, description ) INTERNAL_CATCH_TEST_CASE_METHOD( className, name, description )
- #define CATCH_METHOD_AS_TEST_CASE( method, name, description ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, name, description )
- #define CATCH_REGISTER_TEST_CASE( function, name, description ) INTERNAL_CATCH_REGISTER_TESTCASE( function, name, description )
- #define CATCH_SECTION( name, description ) INTERNAL_CATCH_SECTION( name, description )
- #define CATCH_FAIL( msg ) INTERNAL_CATCH_MSG( "CATCH_FAIL", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, msg )
- #define CATCH_FAIL_CHECK( msg ) INTERNAL_CATCH_MSG( "CATCH_FAIL_CHECK", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::ContinueOnFailure, msg )
- #define CATCH_SUCCEED( msg ) INTERNAL_CATCH_MSG( "CATCH_SUCCEED", Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, msg )
+#define CATCH_TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ ) )
+#define CATCH_TEMPLATE_TEST_CASE_SIG( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_SIG( __VA_ARGS__ ) )
+#define CATCH_TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) )
+#define CATCH_TEMPLATE_TEST_CASE_METHOD_SIG( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_SIG( className, __VA_ARGS__ ) )
+#define CATCH_TEMPLATE_PRODUCT_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE( __VA_ARGS__ ) )
+#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_SIG( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_SIG( __VA_ARGS__ ) )
+#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, __VA_ARGS__ ) )
+#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_SIG( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_SIG( className, __VA_ARGS__ ) )
#endif
-#define CATCH_ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE( "", "" )
-
-#define CATCH_REGISTER_REPORTER( name, reporterType ) INTERNAL_CATCH_REGISTER_REPORTER( name, reporterType )
-#define CATCH_REGISTER_LEGACY_REPORTER( name, reporterType ) INTERNAL_CATCH_REGISTER_LEGACY_REPORTER( name, reporterType )
-#define CATCH_GENERATE( expr) INTERNAL_CATCH_GENERATE( expr )
+#if !defined(CATCH_CONFIG_RUNTIME_STATIC_REQUIRE)
+#define CATCH_STATIC_REQUIRE( ... ) static_assert( __VA_ARGS__ , #__VA_ARGS__ ); CATCH_SUCCEED( #__VA_ARGS__ )
+#define CATCH_STATIC_REQUIRE_FALSE( ... ) static_assert( !(__VA_ARGS__), "!(" #__VA_ARGS__ ")" ); CATCH_SUCCEED( #__VA_ARGS__ )
+#else
+#define CATCH_STATIC_REQUIRE( ... ) CATCH_REQUIRE( __VA_ARGS__ )
+#define CATCH_STATIC_REQUIRE_FALSE( ... ) CATCH_REQUIRE_FALSE( __VA_ARGS__ )
+#endif
// "BDD-style" convenience wrappers
-#ifdef CATCH_CONFIG_VARIADIC_MACROS
#define CATCH_SCENARIO( ... ) CATCH_TEST_CASE( "Scenario: " __VA_ARGS__ )
#define CATCH_SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, "Scenario: " __VA_ARGS__ )
-#else
-#define CATCH_SCENARIO( name, tags ) CATCH_TEST_CASE( "Scenario: " name, tags )
-#define CATCH_SCENARIO_METHOD( className, name, tags ) INTERNAL_CATCH_TEST_CASE_METHOD( className, "Scenario: " name, tags )
-#endif
-#define CATCH_GIVEN( desc ) CATCH_SECTION( std::string( "Given: ") + desc, "" )
-#define CATCH_WHEN( desc ) CATCH_SECTION( std::string( " When: ") + desc, "" )
-#define CATCH_AND_WHEN( desc ) CATCH_SECTION( std::string( " And: ") + desc, "" )
-#define CATCH_THEN( desc ) CATCH_SECTION( std::string( " Then: ") + desc, "" )
-#define CATCH_AND_THEN( desc ) CATCH_SECTION( std::string( " And: ") + desc, "" )
+#define CATCH_GIVEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " Given: " << desc )
+#define CATCH_AND_GIVEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( "And given: " << desc )
+#define CATCH_WHEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " When: " << desc )
+#define CATCH_AND_WHEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " And when: " << desc )
+#define CATCH_THEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " Then: " << desc )
+#define CATCH_AND_THEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " And: " << desc )
+
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+#define CATCH_BENCHMARK(...) \
+ INTERNAL_CATCH_BENCHMARK(INTERNAL_CATCH_UNIQUE_NAME(____C_A_T_C_H____B_E_N_C_H____), INTERNAL_CATCH_GET_1_ARG(__VA_ARGS__,,), INTERNAL_CATCH_GET_2_ARG(__VA_ARGS__,,))
+#define CATCH_BENCHMARK_ADVANCED(name) \
+ INTERNAL_CATCH_BENCHMARK_ADVANCED(INTERNAL_CATCH_UNIQUE_NAME(____C_A_T_C_H____B_E_N_C_H____), name)
+#endif // CATCH_CONFIG_ENABLE_BENCHMARKING
// If CATCH_CONFIG_PREFIX_ALL is not defined then the CATCH_ prefix is not required
#else
-#if defined(CATCH_CONFIG_FAST_COMPILE)
-#define REQUIRE( expr ) INTERNAL_CATCH_TEST_NO_TRY( "REQUIRE", Catch::ResultDisposition::Normal, expr )
-#define REQUIRE_FALSE( expr ) INTERNAL_CATCH_TEST_NO_TRY( "REQUIRE_FALSE", Catch::ResultDisposition::Normal | Catch::ResultDisposition::FalseTest, expr )
+#define REQUIRE( ... ) INTERNAL_CATCH_TEST( "REQUIRE", Catch::ResultDisposition::Normal, __VA_ARGS__ )
+#define REQUIRE_FALSE( ... ) INTERNAL_CATCH_TEST( "REQUIRE_FALSE", Catch::ResultDisposition::Normal | Catch::ResultDisposition::FalseTest, __VA_ARGS__ )
-#else
-#define REQUIRE( expr ) INTERNAL_CATCH_TEST( "REQUIRE", Catch::ResultDisposition::Normal, expr )
-#define REQUIRE_FALSE( expr ) INTERNAL_CATCH_TEST( "REQUIRE_FALSE", Catch::ResultDisposition::Normal | Catch::ResultDisposition::FalseTest, expr )
-#endif
-
-#define REQUIRE_THROWS( expr ) INTERNAL_CATCH_THROWS( "REQUIRE_THROWS", Catch::ResultDisposition::Normal, "", expr )
+#define REQUIRE_THROWS( ... ) INTERNAL_CATCH_THROWS( "REQUIRE_THROWS", Catch::ResultDisposition::Normal, __VA_ARGS__ )
#define REQUIRE_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( "REQUIRE_THROWS_AS", exceptionType, Catch::ResultDisposition::Normal, expr )
-#define REQUIRE_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS( "REQUIRE_THROWS_WITH", Catch::ResultDisposition::Normal, matcher, expr )
-#define REQUIRE_NOTHROW( expr ) INTERNAL_CATCH_NO_THROW( "REQUIRE_NOTHROW", Catch::ResultDisposition::Normal, expr )
-
-#define CHECK( expr ) INTERNAL_CATCH_TEST( "CHECK", Catch::ResultDisposition::ContinueOnFailure, expr )
-#define CHECK_FALSE( expr ) INTERNAL_CATCH_TEST( "CHECK_FALSE", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::FalseTest, expr )
-#define CHECKED_IF( expr ) INTERNAL_CATCH_IF( "CHECKED_IF", Catch::ResultDisposition::ContinueOnFailure, expr )
-#define CHECKED_ELSE( expr ) INTERNAL_CATCH_ELSE( "CHECKED_ELSE", Catch::ResultDisposition::ContinueOnFailure, expr )
-#define CHECK_NOFAIL( expr ) INTERNAL_CATCH_TEST( "CHECK_NOFAIL", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::SuppressFail, expr )
-
-#define CHECK_THROWS( expr ) INTERNAL_CATCH_THROWS( "CHECK_THROWS", Catch::ResultDisposition::ContinueOnFailure, "", expr )
+#define REQUIRE_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS_STR_MATCHES( "REQUIRE_THROWS_WITH", Catch::ResultDisposition::Normal, matcher, expr )
+#if !defined(CATCH_CONFIG_DISABLE_MATCHERS)
+#define REQUIRE_THROWS_MATCHES( expr, exceptionType, matcher ) INTERNAL_CATCH_THROWS_MATCHES( "REQUIRE_THROWS_MATCHES", exceptionType, Catch::ResultDisposition::Normal, matcher, expr )
+#endif // CATCH_CONFIG_DISABLE_MATCHERS
+#define REQUIRE_NOTHROW( ... ) INTERNAL_CATCH_NO_THROW( "REQUIRE_NOTHROW", Catch::ResultDisposition::Normal, __VA_ARGS__ )
+
+#define CHECK( ... ) INTERNAL_CATCH_TEST( "CHECK", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ )
+#define CHECK_FALSE( ... ) INTERNAL_CATCH_TEST( "CHECK_FALSE", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::FalseTest, __VA_ARGS__ )
+#define CHECKED_IF( ... ) INTERNAL_CATCH_IF( "CHECKED_IF", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ )
+#define CHECKED_ELSE( ... ) INTERNAL_CATCH_ELSE( "CHECKED_ELSE", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ )
+#define CHECK_NOFAIL( ... ) INTERNAL_CATCH_TEST( "CHECK_NOFAIL", Catch::ResultDisposition::ContinueOnFailure | Catch::ResultDisposition::SuppressFail, __VA_ARGS__ )
+
+#define CHECK_THROWS( ... ) INTERNAL_CATCH_THROWS( "CHECK_THROWS", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ )
#define CHECK_THROWS_AS( expr, exceptionType ) INTERNAL_CATCH_THROWS_AS( "CHECK_THROWS_AS", exceptionType, Catch::ResultDisposition::ContinueOnFailure, expr )
-#define CHECK_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS( "CHECK_THROWS_WITH", Catch::ResultDisposition::ContinueOnFailure, matcher, expr )
-#define CHECK_NOTHROW( expr ) INTERNAL_CATCH_NO_THROW( "CHECK_NOTHROW", Catch::ResultDisposition::ContinueOnFailure, expr )
+#define CHECK_THROWS_WITH( expr, matcher ) INTERNAL_CATCH_THROWS_STR_MATCHES( "CHECK_THROWS_WITH", Catch::ResultDisposition::ContinueOnFailure, matcher, expr )
+#if !defined(CATCH_CONFIG_DISABLE_MATCHERS)
+#define CHECK_THROWS_MATCHES( expr, exceptionType, matcher ) INTERNAL_CATCH_THROWS_MATCHES( "CHECK_THROWS_MATCHES", exceptionType, Catch::ResultDisposition::ContinueOnFailure, matcher, expr )
+#endif // CATCH_CONFIG_DISABLE_MATCHERS
+#define CHECK_NOTHROW( ... ) INTERNAL_CATCH_NO_THROW( "CHECK_NOTHROW", Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ )
+#if !defined(CATCH_CONFIG_DISABLE_MATCHERS)
#define CHECK_THAT( arg, matcher ) INTERNAL_CHECK_THAT( "CHECK_THAT", matcher, Catch::ResultDisposition::ContinueOnFailure, arg )
-#if defined(CATCH_CONFIG_FAST_COMPILE)
-#define REQUIRE_THAT( arg, matcher ) INTERNAL_CHECK_THAT_NO_TRY( "REQUIRE_THAT", matcher, Catch::ResultDisposition::Normal, arg )
-#else
#define REQUIRE_THAT( arg, matcher ) INTERNAL_CHECK_THAT( "REQUIRE_THAT", matcher, Catch::ResultDisposition::Normal, arg )
-#endif
+#endif // CATCH_CONFIG_DISABLE_MATCHERS
#define INFO( msg ) INTERNAL_CATCH_INFO( "INFO", msg )
+#define UNSCOPED_INFO( msg ) INTERNAL_CATCH_UNSCOPED_INFO( "UNSCOPED_INFO", msg )
#define WARN( msg ) INTERNAL_CATCH_MSG( "WARN", Catch::ResultWas::Warning, Catch::ResultDisposition::ContinueOnFailure, msg )
-#define SCOPED_INFO( msg ) INTERNAL_CATCH_INFO( "INFO", msg )
-#define CAPTURE( msg ) INTERNAL_CATCH_INFO( "CAPTURE", #msg " := " << Catch::toString(msg) )
-#define SCOPED_CAPTURE( msg ) INTERNAL_CATCH_INFO( "CAPTURE", #msg " := " << Catch::toString(msg) )
+#define CAPTURE( ... ) INTERNAL_CATCH_CAPTURE( INTERNAL_CATCH_UNIQUE_NAME(capturer), "CAPTURE",__VA_ARGS__ )
-#ifdef CATCH_CONFIG_VARIADIC_MACROS
#define TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE( __VA_ARGS__ )
#define TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, __VA_ARGS__ )
#define METHOD_AS_TEST_CASE( method, ... ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, __VA_ARGS__ )
#define REGISTER_TEST_CASE( Function, ... ) INTERNAL_CATCH_REGISTER_TESTCASE( Function, __VA_ARGS__ )
#define SECTION( ... ) INTERNAL_CATCH_SECTION( __VA_ARGS__ )
+#define DYNAMIC_SECTION( ... ) INTERNAL_CATCH_DYNAMIC_SECTION( __VA_ARGS__ )
#define FAIL( ... ) INTERNAL_CATCH_MSG( "FAIL", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, __VA_ARGS__ )
#define FAIL_CHECK( ... ) INTERNAL_CATCH_MSG( "FAIL_CHECK", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ )
#define SUCCEED( ... ) INTERNAL_CATCH_MSG( "SUCCEED", Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, __VA_ARGS__ )
+#define ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE()
+
+#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+#define TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ )
+#define TEMPLATE_TEST_CASE_SIG( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_SIG( __VA_ARGS__ )
+#define TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ )
+#define TEMPLATE_TEST_CASE_METHOD_SIG( className, ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_SIG( className, __VA_ARGS__ )
+#define TEMPLATE_PRODUCT_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE( __VA_ARGS__ )
+#define TEMPLATE_PRODUCT_TEST_CASE_SIG( ... ) INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_SIG( __VA_ARGS__ )
+#define TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, __VA_ARGS__ )
+#define TEMPLATE_PRODUCT_TEST_CASE_METHOD_SIG( className, ... ) INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_SIG( className, __VA_ARGS__ )
+#define TEMPLATE_LIST_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE(__VA_ARGS__)
+#define TEMPLATE_LIST_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE_METHOD( className, __VA_ARGS__ )
#else
-#define TEST_CASE( name, description ) INTERNAL_CATCH_TESTCASE( name, description )
- #define TEST_CASE_METHOD( className, name, description ) INTERNAL_CATCH_TEST_CASE_METHOD( className, name, description )
- #define METHOD_AS_TEST_CASE( method, name, description ) INTERNAL_CATCH_METHOD_AS_TEST_CASE( method, name, description )
- #define REGISTER_TEST_CASE( method, name, description ) INTERNAL_CATCH_REGISTER_TESTCASE( method, name, description )
- #define SECTION( name, description ) INTERNAL_CATCH_SECTION( name, description )
- #define FAIL( msg ) INTERNAL_CATCH_MSG( "FAIL", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::Normal, msg )
- #define FAIL_CHECK( msg ) INTERNAL_CATCH_MSG( "FAIL_CHECK", Catch::ResultWas::ExplicitFailure, Catch::ResultDisposition::ContinueOnFailure, msg )
- #define SUCCEED( msg ) INTERNAL_CATCH_MSG( "SUCCEED", Catch::ResultWas::Ok, Catch::ResultDisposition::ContinueOnFailure, msg )
+#define TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ ) )
+#define TEMPLATE_TEST_CASE_SIG( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_SIG( __VA_ARGS__ ) )
+#define TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ ) )
+#define TEMPLATE_TEST_CASE_METHOD_SIG( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_SIG( className, __VA_ARGS__ ) )
+#define TEMPLATE_PRODUCT_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE( __VA_ARGS__ ) )
+#define TEMPLATE_PRODUCT_TEST_CASE_SIG( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_SIG( __VA_ARGS__ ) )
+#define TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, __VA_ARGS__ ) )
+#define TEMPLATE_PRODUCT_TEST_CASE_METHOD_SIG( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_SIG( className, __VA_ARGS__ ) )
+#define TEMPLATE_LIST_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE( __VA_ARGS__ ) )
+#define TEMPLATE_LIST_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_LIST_TEST_CASE_METHOD( className, __VA_ARGS__ ) )
#endif
-#define ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE( "", "" )
-#define REGISTER_REPORTER( name, reporterType ) INTERNAL_CATCH_REGISTER_REPORTER( name, reporterType )
-#define REGISTER_LEGACY_REPORTER( name, reporterType ) INTERNAL_CATCH_REGISTER_LEGACY_REPORTER( name, reporterType )
-
-#define GENERATE( expr) INTERNAL_CATCH_GENERATE( expr )
+#if !defined(CATCH_CONFIG_RUNTIME_STATIC_REQUIRE)
+#define STATIC_REQUIRE( ... ) static_assert( __VA_ARGS__, #__VA_ARGS__ ); SUCCEED( #__VA_ARGS__ )
+#define STATIC_REQUIRE_FALSE( ... ) static_assert( !(__VA_ARGS__), "!(" #__VA_ARGS__ ")" ); SUCCEED( "!(" #__VA_ARGS__ ")" )
+#else
+#define STATIC_REQUIRE( ... ) REQUIRE( __VA_ARGS__ )
+#define STATIC_REQUIRE_FALSE( ... ) REQUIRE_FALSE( __VA_ARGS__ )
+#endif
#endif
#define CATCH_TRANSLATE_EXCEPTION( signature ) INTERNAL_CATCH_TRANSLATE_EXCEPTION( signature )
// "BDD-style" convenience wrappers
-#ifdef CATCH_CONFIG_VARIADIC_MACROS
#define SCENARIO( ... ) TEST_CASE( "Scenario: " __VA_ARGS__ )
#define SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TEST_CASE_METHOD( className, "Scenario: " __VA_ARGS__ )
+
+#define GIVEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " Given: " << desc )
+#define AND_GIVEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( "And given: " << desc )
+#define WHEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " When: " << desc )
+#define AND_WHEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " And when: " << desc )
+#define THEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " Then: " << desc )
+#define AND_THEN( desc ) INTERNAL_CATCH_DYNAMIC_SECTION( " And: " << desc )
+
+#if defined(CATCH_CONFIG_ENABLE_BENCHMARKING)
+#define BENCHMARK(...) \
+ INTERNAL_CATCH_BENCHMARK(INTERNAL_CATCH_UNIQUE_NAME(____C_A_T_C_H____B_E_N_C_H____), INTERNAL_CATCH_GET_1_ARG(__VA_ARGS__,,), INTERNAL_CATCH_GET_2_ARG(__VA_ARGS__,,))
+#define BENCHMARK_ADVANCED(name) \
+ INTERNAL_CATCH_BENCHMARK_ADVANCED(INTERNAL_CATCH_UNIQUE_NAME(____C_A_T_C_H____B_E_N_C_H____), name)
+#endif // CATCH_CONFIG_ENABLE_BENCHMARKING
+
+using Catch::Detail::Approx;
+
+#else // CATCH_CONFIG_DISABLE
+
+//////
+// If this config identifier is defined then all CATCH macros are prefixed with CATCH_
+#ifdef CATCH_CONFIG_PREFIX_ALL
+
+#define CATCH_REQUIRE( ... ) (void)(0)
+#define CATCH_REQUIRE_FALSE( ... ) (void)(0)
+
+#define CATCH_REQUIRE_THROWS( ... ) (void)(0)
+#define CATCH_REQUIRE_THROWS_AS( expr, exceptionType ) (void)(0)
+#define CATCH_REQUIRE_THROWS_WITH( expr, matcher ) (void)(0)
+#if !defined(CATCH_CONFIG_DISABLE_MATCHERS)
+#define CATCH_REQUIRE_THROWS_MATCHES( expr, exceptionType, matcher ) (void)(0)
+#endif// CATCH_CONFIG_DISABLE_MATCHERS
+#define CATCH_REQUIRE_NOTHROW( ... ) (void)(0)
+
+#define CATCH_CHECK( ... ) (void)(0)
+#define CATCH_CHECK_FALSE( ... ) (void)(0)
+#define CATCH_CHECKED_IF( ... ) if (__VA_ARGS__)
+#define CATCH_CHECKED_ELSE( ... ) if (!(__VA_ARGS__))
+#define CATCH_CHECK_NOFAIL( ... ) (void)(0)
+
+#define CATCH_CHECK_THROWS( ... ) (void)(0)
+#define CATCH_CHECK_THROWS_AS( expr, exceptionType ) (void)(0)
+#define CATCH_CHECK_THROWS_WITH( expr, matcher ) (void)(0)
+#if !defined(CATCH_CONFIG_DISABLE_MATCHERS)
+#define CATCH_CHECK_THROWS_MATCHES( expr, exceptionType, matcher ) (void)(0)
+#endif // CATCH_CONFIG_DISABLE_MATCHERS
+#define CATCH_CHECK_NOTHROW( ... ) (void)(0)
+
+#if !defined(CATCH_CONFIG_DISABLE_MATCHERS)
+#define CATCH_CHECK_THAT( arg, matcher ) (void)(0)
+
+#define CATCH_REQUIRE_THAT( arg, matcher ) (void)(0)
+#endif // CATCH_CONFIG_DISABLE_MATCHERS
+
+#define CATCH_INFO( msg ) (void)(0)
+#define CATCH_UNSCOPED_INFO( msg ) (void)(0)
+#define CATCH_WARN( msg ) (void)(0)
+#define CATCH_CAPTURE( msg ) (void)(0)
+
+#define CATCH_TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ))
+#define CATCH_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ))
+#define CATCH_METHOD_AS_TEST_CASE( method, ... )
+#define CATCH_REGISTER_TEST_CASE( Function, ... ) (void)(0)
+#define CATCH_SECTION( ... )
+#define CATCH_DYNAMIC_SECTION( ... )
+#define CATCH_FAIL( ... ) (void)(0)
+#define CATCH_FAIL_CHECK( ... ) (void)(0)
+#define CATCH_SUCCEED( ... ) (void)(0)
+
+#define CATCH_ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ))
+
+#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+#define CATCH_TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(__VA_ARGS__)
+#define CATCH_TEMPLATE_TEST_CASE_SIG( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_SIG_NO_REGISTRATION(__VA_ARGS__)
+#define CATCH_TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION(className, __VA_ARGS__)
+#define CATCH_TEMPLATE_TEST_CASE_METHOD_SIG( className, ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_SIG_NO_REGISTRATION(className, __VA_ARGS__ )
+#define CATCH_TEMPLATE_PRODUCT_TEST_CASE( ... ) CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ )
+#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_SIG( ... ) CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ )
+#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ )
+#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_SIG( className, ... ) CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ )
+#else
+#define CATCH_TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(__VA_ARGS__) )
+#define CATCH_TEMPLATE_TEST_CASE_SIG( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_SIG_NO_REGISTRATION(__VA_ARGS__) )
+#define CATCH_TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION(className, __VA_ARGS__ ) )
+#define CATCH_TEMPLATE_TEST_CASE_METHOD_SIG( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_SIG_NO_REGISTRATION(className, __VA_ARGS__ ) )
+#define CATCH_TEMPLATE_PRODUCT_TEST_CASE( ... ) CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ )
+#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_SIG( ... ) CATCH_TEMPLATE_TEST_CASE( __VA_ARGS__ )
+#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ )
+#define CATCH_TEMPLATE_PRODUCT_TEST_CASE_METHOD_SIG( className, ... ) CATCH_TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ )
+#endif
+
+// "BDD-style" convenience wrappers
+#define CATCH_SCENARIO( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ))
+#define CATCH_SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), className )
+#define CATCH_GIVEN( desc )
+#define CATCH_AND_GIVEN( desc )
+#define CATCH_WHEN( desc )
+#define CATCH_AND_WHEN( desc )
+#define CATCH_THEN( desc )
+#define CATCH_AND_THEN( desc )
+
+#define CATCH_STATIC_REQUIRE( ... ) (void)(0)
+#define CATCH_STATIC_REQUIRE_FALSE( ... ) (void)(0)
+
+// If CATCH_CONFIG_PREFIX_ALL is not defined then the CATCH_ prefix is not required
+#else
+
+#define REQUIRE( ... ) (void)(0)
+#define REQUIRE_FALSE( ... ) (void)(0)
+
+#define REQUIRE_THROWS( ... ) (void)(0)
+#define REQUIRE_THROWS_AS( expr, exceptionType ) (void)(0)
+#define REQUIRE_THROWS_WITH( expr, matcher ) (void)(0)
+#if !defined(CATCH_CONFIG_DISABLE_MATCHERS)
+#define REQUIRE_THROWS_MATCHES( expr, exceptionType, matcher ) (void)(0)
+#endif // CATCH_CONFIG_DISABLE_MATCHERS
+#define REQUIRE_NOTHROW( ... ) (void)(0)
+
+#define CHECK( ... ) (void)(0)
+#define CHECK_FALSE( ... ) (void)(0)
+#define CHECKED_IF( ... ) if (__VA_ARGS__)
+#define CHECKED_ELSE( ... ) if (!(__VA_ARGS__))
+#define CHECK_NOFAIL( ... ) (void)(0)
+
+#define CHECK_THROWS( ... ) (void)(0)
+#define CHECK_THROWS_AS( expr, exceptionType ) (void)(0)
+#define CHECK_THROWS_WITH( expr, matcher ) (void)(0)
+#if !defined(CATCH_CONFIG_DISABLE_MATCHERS)
+#define CHECK_THROWS_MATCHES( expr, exceptionType, matcher ) (void)(0)
+#endif // CATCH_CONFIG_DISABLE_MATCHERS
+#define CHECK_NOTHROW( ... ) (void)(0)
+
+#if !defined(CATCH_CONFIG_DISABLE_MATCHERS)
+#define CHECK_THAT( arg, matcher ) (void)(0)
+
+#define REQUIRE_THAT( arg, matcher ) (void)(0)
+#endif // CATCH_CONFIG_DISABLE_MATCHERS
+
+#define INFO( msg ) (void)(0)
+#define UNSCOPED_INFO( msg ) (void)(0)
+#define WARN( msg ) (void)(0)
+#define CAPTURE( msg ) (void)(0)
+
+#define TEST_CASE( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ))
+#define TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ))
+#define METHOD_AS_TEST_CASE( method, ... )
+#define REGISTER_TEST_CASE( Function, ... ) (void)(0)
+#define SECTION( ... )
+#define DYNAMIC_SECTION( ... )
+#define FAIL( ... ) (void)(0)
+#define FAIL_CHECK( ... ) (void)(0)
+#define SUCCEED( ... ) (void)(0)
+#define ANON_TEST_CASE() INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ))
+
+#ifndef CATCH_CONFIG_TRADITIONAL_MSVC_PREPROCESSOR
+#define TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(__VA_ARGS__)
+#define TEMPLATE_TEST_CASE_SIG( ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_SIG_NO_REGISTRATION(__VA_ARGS__)
+#define TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION(className, __VA_ARGS__)
+#define TEMPLATE_TEST_CASE_METHOD_SIG( className, ... ) INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_SIG_NO_REGISTRATION(className, __VA_ARGS__ )
+#define TEMPLATE_PRODUCT_TEST_CASE( ... ) TEMPLATE_TEST_CASE( __VA_ARGS__ )
+#define TEMPLATE_PRODUCT_TEST_CASE_SIG( ... ) TEMPLATE_TEST_CASE( __VA_ARGS__ )
+#define TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ )
+#define TEMPLATE_PRODUCT_TEST_CASE_METHOD_SIG( className, ... ) TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ )
#else
-#define SCENARIO( name, tags ) TEST_CASE( "Scenario: " name, tags )
-#define SCENARIO_METHOD( className, name, tags ) INTERNAL_CATCH_TEST_CASE_METHOD( className, "Scenario: " name, tags )
+#define TEMPLATE_TEST_CASE( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_NO_REGISTRATION(__VA_ARGS__) )
+#define TEMPLATE_TEST_CASE_SIG( ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_SIG_NO_REGISTRATION(__VA_ARGS__) )
+#define TEMPLATE_TEST_CASE_METHOD( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_NO_REGISTRATION(className, __VA_ARGS__ ) )
+#define TEMPLATE_TEST_CASE_METHOD_SIG( className, ... ) INTERNAL_CATCH_EXPAND_VARGS( INTERNAL_CATCH_TEMPLATE_TEST_CASE_METHOD_SIG_NO_REGISTRATION(className, __VA_ARGS__ ) )
+#define TEMPLATE_PRODUCT_TEST_CASE( ... ) TEMPLATE_TEST_CASE( __VA_ARGS__ )
+#define TEMPLATE_PRODUCT_TEST_CASE_SIG( ... ) TEMPLATE_TEST_CASE( __VA_ARGS__ )
+#define TEMPLATE_PRODUCT_TEST_CASE_METHOD( className, ... ) TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ )
+#define TEMPLATE_PRODUCT_TEST_CASE_METHOD_SIG( className, ... ) TEMPLATE_TEST_CASE_METHOD( className, __VA_ARGS__ )
#endif
-#define GIVEN( desc ) SECTION( std::string(" Given: ") + desc, "" )
-#define WHEN( desc ) SECTION( std::string(" When: ") + desc, "" )
-#define AND_WHEN( desc ) SECTION( std::string("And when: ") + desc, "" )
-#define THEN( desc ) SECTION( std::string(" Then: ") + desc, "" )
-#define AND_THEN( desc ) SECTION( std::string(" And: ") + desc, "" )
+
+#define STATIC_REQUIRE( ... ) (void)(0)
+#define STATIC_REQUIRE_FALSE( ... ) (void)(0)
+
+#endif
+
+#define CATCH_TRANSLATE_EXCEPTION( signature ) INTERNAL_CATCH_TRANSLATE_EXCEPTION_NO_REG( INTERNAL_CATCH_UNIQUE_NAME( catch_internal_ExceptionTranslator ), signature )
+
+// "BDD-style" convenience wrappers
+#define SCENARIO( ... ) INTERNAL_CATCH_TESTCASE_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ) )
+#define SCENARIO_METHOD( className, ... ) INTERNAL_CATCH_TESTCASE_METHOD_NO_REGISTRATION(INTERNAL_CATCH_UNIQUE_NAME( ____C_A_T_C_H____T_E_S_T____ ), className )
+
+#define GIVEN( desc )
+#define AND_GIVEN( desc )
+#define WHEN( desc )
+#define AND_WHEN( desc )
+#define THEN( desc )
+#define AND_THEN( desc )
using Catch::Detail::Approx;
+#endif
+
+#endif // ! CATCH_CONFIG_IMPL_ONLY
+
+// start catch_reenable_warnings.h
+
+
+#ifdef __clang__
+# ifdef __ICC // icpc defines the __clang__ macro
+# pragma warning(pop)
+# else
+# pragma clang diagnostic pop
+# endif
+#elif defined __GNUC__
+# pragma GCC diagnostic pop
+#endif
+
+// end catch_reenable_warnings.h
+// end catch.hpp
#endif // TWOBLUECUBES_SINGLE_INCLUDE_CATCH_HPP_INCLUDED
diff --git a/src/3rd_party/cnpy/cnpy.cpp b/src/3rd_party/cnpy/cnpy.cpp
index 9ad3c2fa..9ad3c2fa 100755..100644
--- a/src/3rd_party/cnpy/cnpy.cpp
+++ b/src/3rd_party/cnpy/cnpy.cpp
diff --git a/src/3rd_party/cnpy/cnpy.h b/src/3rd_party/cnpy/cnpy.h
index 15053ed0..15053ed0 100755..100644
--- a/src/3rd_party/cnpy/cnpy.h
+++ b/src/3rd_party/cnpy/cnpy.h
diff --git a/src/3rd_party/fbgemm b/src/3rd_party/fbgemm
new file mode 160000
+Subproject 84e66a976046180187724aff60a236c5378fde7
diff --git a/src/3rd_party/half_float/HalfPrecisionFloatTest.cpp b/src/3rd_party/half_float/HalfPrecisionFloatTest.cpp
new file mode 100755
index 00000000..a5d8b846
--- /dev/null
+++ b/src/3rd_party/half_float/HalfPrecisionFloatTest.cpp
@@ -0,0 +1,115 @@
+
+
+
+#include "umHalf.h"
+#include <iostream>
+#include <assert.h>
+
+#define VALIDATE(x) if (!(x)){std::cout << "Failed: " << #x << std::endl;assert((x));}
+
+int main(int, char*)
+{
+ half h = 1.f, h2 = 2.f;
+ --h2;
+ ++h2;
+ --h;
+ ++h;
+ h2 -= 1.f;
+ float f = h2, f2 = h;
+ VALIDATE(1.f == f && f == f2);
+
+ h = h2;
+ h2 = 15.5f;
+
+ f = h2, f2 = h;
+ VALIDATE(15.5f == f && 1.f == f2);
+ h2 *= h;
+ f = h2, f2 = h;
+ VALIDATE(15.5f == f && 1.f == f2);
+ h2 /= h;
+ f = h2, f2 = h;
+ VALIDATE(15.5f == f && 1.f == f2);
+ h2 += h;
+ f = h2, f2 = h;
+ VALIDATE(16.5f == f && 1.f == f2);
+ h++;h++;h++;
+ h2 = -h2;
+ h2 += 17.5f;
+ h2 *= h;
+ f = h2, f2 = h;
+ VALIDATE(4.f == f && 4.f == f2);
+ VALIDATE(h == h2);
+ VALIDATE(h <= h2);
+ --h;
+ VALIDATE(h <= h2);
+
+ h -= 250.f;
+ VALIDATE(h < h2);
+
+ h += 500.f;
+ VALIDATE(h > h2);
+ VALIDATE(h >= h2);
+
+ f = h2, f2 = h;
+ VALIDATE(h * h2 == (half)(f * f2));
+
+ // addition
+ // ****************************************************************************
+
+ // identical exponents
+ for (float v = 0.f; v < 1000.f; ++v)
+ {
+ half one = v;
+ half two = v;
+ half three = one + two;
+ f2 = three;
+ VALIDATE(v*2.f == f2);
+ }
+
+ // different exponents
+ for (float v = 0.f, fp = 1000.f; v < 500.f; ++v, --fp)
+ {
+ half one = v;
+ half two = fp;
+ half three = one + two;
+ f2 = three;
+ VALIDATE(v+fp == f2);
+ }
+
+ // very small numbers - this is already beyond the accuracy of 16 bit floats.
+ for (float v = 0.003f; v < 1000.f; v += 0.0005f)
+ {
+ half one = v;
+ half two = v;
+ half three = one + two;
+ f2 = three;
+ float m = v*2.f;
+ VALIDATE(f2 > (m-0.05*m) && f2 < (m+0.05*m));
+ }
+
+
+ // subtraction
+ // ****************************************************************************
+
+ // identical exponents
+ for (float v = 0.f; v < 1000.f; ++v)
+ {
+ half one = v;
+ half two = v;
+ half three = one - two;
+ f2 = three;
+ VALIDATE(0.f == f2);
+ }
+
+ // different exponents
+ for (float v = 0.f, fp = 1000.f; v < 500.f; ++v, --fp)
+ {
+ half one = v;
+ half two = fp;
+ half three = one - two;
+ f2 = three;
+ VALIDATE(v-fp == f2);
+ }
+ return 0;
+}
+
diff --git a/src/3rd_party/half_float/Readme.md b/src/3rd_party/half_float/Readme.md
new file mode 100644
index 00000000..dd20491a
--- /dev/null
+++ b/src/3rd_party/half_float/Readme.md
@@ -0,0 +1,43 @@
+half_float
+========
+
+#### 16 bit floating-point data type for C++ ####
+
+Implements a `HalfFloat` class that implements all the common arithmetic operations for a 16 bit
+floating-point type (10 bits mantissa, 5 bits exponent and one sign bit) and can thus be used (almost)
+interchangeably with regular `float`s. Not all operations have efficent implementations (some just convert to `float`,
+compute the result and convert back again) - if in doubt, check out the source code.
+
+The implementation tries to adhere to IEEE 754 in that it supports NaN and Infinity, but fails in other points:
+
+ - no difference between qnan and snan
+ - no traps
+ - no well-defined rounding mode
+
+
+We also supply a specialization for `std::numeric_limits<half>` that `half` be usable in template code
+dependent on type traits.
+
+
+#### Usage ####
+
+ // get some halfs (half is a typedef for HalfFloat)
+ half a = 1.0f;
+ half b = 0.5f;
+
+ // and have some FUN
+ half c = (a+b) / (a-b);
+ ++c;
+
+ // now that we have a result in loosy precision,
+ // convert it back to double precision.
+ // if anybody asks, it's for the lulz.
+ double result = c;
+
+
+Credits to _Chris Maiwald_ for the conversion code to `double` and extensive testing.
+
+
+#### License ####
+
+3-clause BSD license: use it for anything, but give credit, don't blame us if your rocket crashes and don't advertise with it (who would). \ No newline at end of file
diff --git a/src/3rd_party/half_float/stdint.h b/src/3rd_party/half_float/stdint.h
new file mode 100644
index 00000000..0cc9fd99
--- /dev/null
+++ b/src/3rd_party/half_float/stdint.h
@@ -0,0 +1,222 @@
+// ISO C9x compliant stdint.h for Microsoft Visual Studio
+// Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124
+//
+// Copyright (c) 2006 Alexander Chemeris
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// 1. Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+//
+// 2. Redistributions in binary form must reproduce the above copyright
+// notice, this list of conditions and the following disclaimer in the
+// documentation and/or other materials provided with the distribution.
+//
+// 3. The name of the author may be used to endorse or promote products
+// derived from this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
+// WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
+// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
+// EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
+// ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+//
+///////////////////////////////////////////////////////////////////////////////
+
+#ifndef _MSC_VER // [
+#error "Use this header only with Microsoft Visual C++ compilers!"
+#endif // _MSC_VER ]
+
+#ifndef _MSC_STDINT_H_ // [
+#define _MSC_STDINT_H_
+
+#if _MSC_VER > 1000
+#pragma once
+#endif
+
+#include <limits.h>
+
+// For Visual Studio 6 in C++ mode wrap <wchar.h> include with 'extern "C++" {}'
+// or compiler give many errors like this:
+// error C2733: second C linkage of overloaded function 'wmemchr' not allowed
+#if (_MSC_VER < 1300) && defined(__cplusplus)
+ extern "C++" {
+#endif
+# include <wchar.h>
+#if (_MSC_VER < 1300) && defined(__cplusplus)
+ }
+#endif
+
+// 7.18.1 Integer types
+
+// 7.18.1.1 Exact-width integer types
+typedef __int8 int8_t;
+typedef __int16 int16_t;
+typedef __int32 int32_t;
+typedef __int64 int64_t;
+typedef unsigned __int8 uint8_t;
+typedef unsigned __int16 uint16_t;
+typedef unsigned __int32 uint32_t;
+typedef unsigned __int64 uint64_t;
+
+// 7.18.1.2 Minimum-width integer types
+typedef int8_t int_least8_t;
+typedef int16_t int_least16_t;
+typedef int32_t int_least32_t;
+typedef int64_t int_least64_t;
+typedef uint8_t uint_least8_t;
+typedef uint16_t uint_least16_t;
+typedef uint32_t uint_least32_t;
+typedef uint64_t uint_least64_t;
+
+// 7.18.1.3 Fastest minimum-width integer types
+typedef int8_t int_fast8_t;
+typedef int16_t int_fast16_t;
+typedef int32_t int_fast32_t;
+typedef int64_t int_fast64_t;
+typedef uint8_t uint_fast8_t;
+typedef uint16_t uint_fast16_t;
+typedef uint32_t uint_fast32_t;
+typedef uint64_t uint_fast64_t;
+
+// 7.18.1.4 Integer types capable of holding object pointers
+#ifdef _WIN64 // [
+ typedef __int64 intptr_t;
+ typedef unsigned __int64 uintptr_t;
+#else // _WIN64 ][
+ typedef int intptr_t;
+ typedef unsigned int uintptr_t;
+#endif // _WIN64 ]
+
+// 7.18.1.5 Greatest-width integer types
+typedef int64_t intmax_t;
+typedef uint64_t uintmax_t;
+
+
+// 7.18.2 Limits of specified-width integer types
+
+#if !defined(__cplusplus) || defined(__STDC_LIMIT_MACROS) // [ See footnote 220 at page 257 and footnote 221 at page 259
+
+// 7.18.2.1 Limits of exact-width integer types
+#define INT8_MIN ((int8_t)_I8_MIN)
+#define INT8_MAX _I8_MAX
+#define INT16_MIN ((int16_t)_I16_MIN)
+#define INT16_MAX _I16_MAX
+#define INT32_MIN ((int32_t)_I32_MIN)
+#define INT32_MAX _I32_MAX
+#define INT64_MIN ((int64_t)_I64_MIN)
+#define INT64_MAX _I64_MAX
+#define UINT8_MAX _UI8_MAX
+#define UINT16_MAX _UI16_MAX
+#define UINT32_MAX _UI32_MAX
+#define UINT64_MAX _UI64_MAX
+
+// 7.18.2.2 Limits of minimum-width integer types
+#define INT_LEAST8_MIN INT8_MIN
+#define INT_LEAST8_MAX INT8_MAX
+#define INT_LEAST16_MIN INT16_MIN
+#define INT_LEAST16_MAX INT16_MAX
+#define INT_LEAST32_MIN INT32_MIN
+#define INT_LEAST32_MAX INT32_MAX
+#define INT_LEAST64_MIN INT64_MIN
+#define INT_LEAST64_MAX INT64_MAX
+#define UINT_LEAST8_MAX UINT8_MAX
+#define UINT_LEAST16_MAX UINT16_MAX
+#define UINT_LEAST32_MAX UINT32_MAX
+#define UINT_LEAST64_MAX UINT64_MAX
+
+// 7.18.2.3 Limits of fastest minimum-width integer types
+#define INT_FAST8_MIN INT8_MIN
+#define INT_FAST8_MAX INT8_MAX
+#define INT_FAST16_MIN INT16_MIN
+#define INT_FAST16_MAX INT16_MAX
+#define INT_FAST32_MIN INT32_MIN
+#define INT_FAST32_MAX INT32_MAX
+#define INT_FAST64_MIN INT64_MIN
+#define INT_FAST64_MAX INT64_MAX
+#define UINT_FAST8_MAX UINT8_MAX
+#define UINT_FAST16_MAX UINT16_MAX
+#define UINT_FAST32_MAX UINT32_MAX
+#define UINT_FAST64_MAX UINT64_MAX
+
+// 7.18.2.4 Limits of integer types capable of holding object pointers
+#ifdef _WIN64 // [
+# define INTPTR_MIN INT64_MIN
+# define INTPTR_MAX INT64_MAX
+# define UINTPTR_MAX UINT64_MAX
+#else // _WIN64 ][
+# define INTPTR_MIN INT32_MIN
+# define INTPTR_MAX INT32_MAX
+# define UINTPTR_MAX UINT32_MAX
+#endif // _WIN64 ]
+
+// 7.18.2.5 Limits of greatest-width integer types
+#define INTMAX_MIN INT64_MIN
+#define INTMAX_MAX INT64_MAX
+#define UINTMAX_MAX UINT64_MAX
+
+// 7.18.3 Limits of other integer types
+
+#ifdef _WIN64 // [
+# define PTRDIFF_MIN _I64_MIN
+# define PTRDIFF_MAX _I64_MAX
+#else // _WIN64 ][
+# define PTRDIFF_MIN _I32_MIN
+# define PTRDIFF_MAX _I32_MAX
+#endif // _WIN64 ]
+
+#define SIG_ATOMIC_MIN INT_MIN
+#define SIG_ATOMIC_MAX INT_MAX
+
+#ifndef SIZE_MAX // [
+# ifdef _WIN64 // [
+# define SIZE_MAX _UI64_MAX
+# else // _WIN64 ][
+# define SIZE_MAX _UI32_MAX
+# endif // _WIN64 ]
+#endif // SIZE_MAX ]
+
+// WCHAR_MIN and WCHAR_MAX are also defined in <wchar.h>
+#ifndef WCHAR_MIN // [
+# define WCHAR_MIN 0
+#endif // WCHAR_MIN ]
+#ifndef WCHAR_MAX // [
+# define WCHAR_MAX _UI16_MAX
+#endif // WCHAR_MAX ]
+
+#define WINT_MIN 0
+#define WINT_MAX _UI16_MAX
+
+#endif // __STDC_LIMIT_MACROS ]
+
+
+// 7.18.4 Limits of other integer types
+
+#if !defined(__cplusplus) || defined(__STDC_CONSTANT_MACROS) // [ See footnote 224 at page 260
+
+// 7.18.4.1 Macros for minimum-width integer constants
+
+#define INT8_C(val) val##i8
+#define INT16_C(val) val##i16
+#define INT32_C(val) val##i32
+#define INT64_C(val) val##i64
+
+#define UINT8_C(val) val##ui8
+#define UINT16_C(val) val##ui16
+#define UINT32_C(val) val##ui32
+#define UINT64_C(val) val##ui64
+
+// 7.18.4.2 Macros for greatest-width integer constants
+#define INTMAX_C INT64_C
+#define UINTMAX_C UINT64_C
+
+#endif // __STDC_CONSTANT_MACROS ]
+
+
+#endif // _MSC_STDINT_H_ ] \ No newline at end of file
diff --git a/src/3rd_party/half_float/umHalf.h b/src/3rd_party/half_float/umHalf.h
new file mode 100755
index 00000000..c7ea5dcc
--- /dev/null
+++ b/src/3rd_party/half_float/umHalf.h
@@ -0,0 +1,294 @@
+
+///////////////////////////////////////////////////////////////////////////////////
+/*
+Copyright (c) 2006-2008,
+Chris "Krishty" Maiwald, Alexander "Aramis" Gessler
+
+All rights reserved.
+
+Redistribution and use of this software in source and binary forms,
+with or without modification, are permitted provided that the following
+conditions are met:
+
+* Redistributions of source code must retain the above
+ copyright notice, this list of conditions and the
+ following disclaimer.
+
+* Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the
+ following disclaimer in the documentation and/or other
+ materials provided with the distribution.
+
+* Neither the name of the class, nor the names of its
+ contributors may be used to endorse or promote products
+ derived from this software without specific prior
+ written permission of the Development Team.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+///////////////////////////////////////////////////////////////////////////////////
+
+#ifndef UM_HALF_H_INCLUDED
+#define UM_HALF_H_INCLUDED
+
+#include <limits>
+#include <algorithm>
+
+#include <stdint.h>
+
+#undef min
+#undef max
+
+///////////////////////////////////////////////////////////////////////////////////
+/** 1. Represents a half-precision floating point value (16 bits) that behaves
+ * nearly conformant to the IEE 754 standard for floating-point computations.
+ *
+ * Not all operators have special implementations, most perform time-consuming
+ * conversions from half to float and back again.
+ * Differences to IEEE 754:
+ * - no difference between qnan and snan
+ * - no traps
+ * - no well-defined rounding mode
+ */
+///////////////////////////////////////////////////////////////////////////////////
+class HalfFloat
+{
+ friend HalfFloat operator+ (HalfFloat, HalfFloat);
+ friend HalfFloat operator- (HalfFloat, HalfFloat);
+ friend HalfFloat operator* (HalfFloat, HalfFloat);
+ friend HalfFloat operator/ (HalfFloat, HalfFloat);
+
+public:
+
+ enum { BITS_MANTISSA = 10 };
+ enum { BITS_EXPONENT = 5 };
+
+ enum { MAX_EXPONENT_VALUE = 31 };
+ enum { BIAS = MAX_EXPONENT_VALUE/2 };
+
+ enum { MAX_EXPONENT = BIAS };
+ enum { MIN_EXPONENT = -BIAS };
+
+ enum { MAX_EXPONENT10 = 9 };
+ enum { MIN_EXPONENT10 = -9 };
+
+public:
+
+ /** Default constructor. Unitialized by default.
+ */
+ inline HalfFloat() {}
+
+ /** Construction from an existing half
+ */
+ inline HalfFloat(const HalfFloat& other)
+ : bits(other.GetBits())
+ {}
+
+ /** Construction from existing values for mantissa, sign
+ * and exponent. No validation is performed.
+ * @note The exponent is unsigned and biased by #BIAS
+ */
+ inline HalfFloat(uint16_t _m,uint16_t _e,uint16_t _s);
+
+
+ /** Construction from a single-precision float
+ */
+ inline HalfFloat(float other);
+
+ /** Conversion operator to convert from half to float
+ */
+ inline operator float() const;
+
+ /** Assignment operator to assign another half to
+ * *this* object.
+ */
+ inline HalfFloat& operator= (HalfFloat other);
+ inline HalfFloat& operator= (float other);
+
+ /** Comparison operators
+ */
+ inline bool operator== (HalfFloat other) const;
+ inline bool operator!= (HalfFloat other) const;
+
+
+ /** Relational comparison operators
+ */
+ inline bool operator< (HalfFloat other) const;
+ inline bool operator> (HalfFloat other) const;
+ inline bool operator<= (HalfFloat other) const;
+ inline bool operator>= (HalfFloat other) const;
+
+ inline bool operator< (float other) const;
+ inline bool operator> (float other) const;
+ inline bool operator<= (float other) const;
+ inline bool operator>= (float other) const;
+
+
+ /** Combined assignment operators
+ */
+ inline HalfFloat& operator += (HalfFloat other);
+ inline HalfFloat& operator -= (HalfFloat other);
+ inline HalfFloat& operator *= (HalfFloat other);
+ inline HalfFloat& operator /= (HalfFloat other);
+
+ inline HalfFloat& operator += (float other);
+ inline HalfFloat& operator -= (float other);
+ inline HalfFloat& operator *= (float other);
+ inline HalfFloat& operator /= (float other);
+
+ /** Post and prefix increment operators
+ */
+ inline HalfFloat& operator++();
+ inline HalfFloat operator++(int);
+
+ /** Post and prefix decrement operators
+ */
+ inline HalfFloat& operator--();
+ inline HalfFloat operator--(int);
+
+ /** Unary minus operator
+ */
+ inline HalfFloat operator-() const;
+
+
+ /** Provides direct access to the bits of a half float
+ */
+ inline uint16_t GetBits() const;
+ inline uint16_t& GetBits();
+
+
+ /** Classification of floating-point types
+ */
+ inline bool IsNaN() const;
+ inline bool IsInfinity() const;
+ inline bool IsDenorm() const;
+
+ /** Returns the sign of the floating-point value -
+ * true stands for positive.
+ */
+ inline bool GetSign() const;
+
+public:
+
+ union
+ {
+ uint16_t bits; // All bits
+ struct
+ {
+ uint16_t Frac : 10; // mantissa
+ uint16_t Exp : 5; // exponent
+ uint16_t Sign : 1; // sign
+ } IEEE;
+ };
+
+
+ union IEEESingle
+ {
+ float Float;
+ struct
+ {
+ uint32_t Frac : 23;
+ uint32_t Exp : 8;
+ uint32_t Sign : 1;
+ } IEEE;
+ };
+};
+
+/** 2. Binary operations
+ */
+inline HalfFloat operator+ (HalfFloat one, HalfFloat two);
+inline HalfFloat operator- (HalfFloat one, HalfFloat two);
+inline HalfFloat operator* (HalfFloat one, HalfFloat two);
+inline HalfFloat operator/ (HalfFloat one, HalfFloat two);
+
+inline float operator+ (HalfFloat one, float two);
+inline float operator- (HalfFloat one, float two);
+inline float operator* (HalfFloat one, float two);
+inline float operator/ (HalfFloat one, float two);
+
+inline float operator+ (float one, HalfFloat two);
+inline float operator- (float one, HalfFloat two);
+inline float operator* (float one, HalfFloat two);
+inline float operator/ (float one, HalfFloat two);
+
+
+
+///////////////////////////////////////////////////////////////////////////////////
+/** 3. Specialization of std::numeric_limits for type half.
+ */
+///////////////////////////////////////////////////////////////////////////////////
+namespace std {
+template <>
+class numeric_limits<HalfFloat> {
+
+ public:
+
+ // General -- meaningful for all specializations.
+
+ static const bool is_specialized = true;
+ static HalfFloat min ()
+ {return HalfFloat(0,1,0);}
+ static HalfFloat max ()
+ {return HalfFloat((uint16_t)~0,HalfFloat::MAX_EXPONENT_VALUE-1,0);}
+ static const int radix = 2;
+ static const int digits = 10; // conservative assumption
+ static const int digits10 = 2; // conservative assumption
+ static const bool is_signed = true;
+ static const bool is_integer = true;
+ static const bool is_exact = false;
+ static const bool traps = false;
+ static const bool is_modulo = false;
+ static const bool is_bounded = true;
+
+ static const HalfFloat lowest() {
+ return HalfFloat((uint16_t)~0,HalfFloat::MAX_EXPONENT_VALUE-1,(uint16_t)~0);
+ }
+
+ // Floating point specific.
+
+ static HalfFloat epsilon ()
+ {return HalfFloat(0.00097656f);} // from OpenEXR, needs to be confirmed
+ static HalfFloat round_error ()
+ {return HalfFloat(0.00097656f/2);}
+ static const int min_exponent10 = HalfFloat::MIN_EXPONENT10;
+ static const int max_exponent10 = HalfFloat::MAX_EXPONENT10;
+ static const int min_exponent = HalfFloat::MIN_EXPONENT;
+ static const int max_exponent = HalfFloat::MAX_EXPONENT;
+
+ static const bool has_infinity = true;
+ static const bool has_quiet_NaN = true;
+ static const bool has_signaling_NaN = true;
+ static const bool is_iec559 = false;
+ static const bool has_denorm = denorm_present;
+ static const bool tinyness_before = false;
+ static const float_round_style round_style = round_to_nearest;
+
+ static HalfFloat denorm_min ()
+ {return HalfFloat(1,0,1);}
+ static HalfFloat infinity ()
+ {return HalfFloat(0,HalfFloat::MAX_EXPONENT_VALUE,0);}
+ static HalfFloat quiet_NaN ()
+ {return HalfFloat(1,HalfFloat::MAX_EXPONENT_VALUE,0);}
+ static HalfFloat signaling_NaN ()
+ {return HalfFloat(1,HalfFloat::MAX_EXPONENT_VALUE,0);}
+ };
+} // end namespace std
+
+
+#include "umHalf.inl"
+
+#ifndef UM_HALF_NO_TYPEDEFS
+ typedef HalfFloat float16;
+#endif
+
+#endif // !! UM_HALF_H_INCLUDED
diff --git a/src/3rd_party/half_float/umHalf.inl b/src/3rd_party/half_float/umHalf.inl
new file mode 100644
index 00000000..3f5285a2
--- /dev/null
+++ b/src/3rd_party/half_float/umHalf.inl
@@ -0,0 +1,495 @@
+
+///////////////////////////////////////////////////////////////////////////////////
+/*
+Copyright (c) 2006-2008, Alexander Gessler
+
+All rights reserved.
+
+Redistribution and use of this software in source and binary forms,
+with or without modification, are permitted provided that the following
+conditions are met:
+
+* Redistributions of source code must retain the above
+ copyright notice, this list of conditions and the
+ following disclaimer.
+
+* Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the
+ following disclaimer in the documentation and/or other
+ materials provided with the distribution.
+
+* Neither the name of the ASSIMP team, nor the names of its
+ contributors may be used to endorse or promote products
+ derived from this software without specific prior
+ written permission of the ASSIMP Development Team.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+///////////////////////////////////////////////////////////////////////////////////
+
+#ifndef UM_HALF_INL_INCLUDED
+#define UM_HALF_INL_INCLUDED
+
+#ifdef _MSC_VER
+ #include <intrin.h>
+ #pragma intrinsic(_BitScanReverse)
+#endif
+
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat::HalfFloat(float other)
+{
+ IEEESingle f;
+ f.Float = other;
+
+ IEEE.Sign = f.IEEE.Sign;
+
+ if ( !f.IEEE.Exp)
+ {
+ IEEE.Frac = 0;
+ IEEE.Exp = 0;
+ }
+ else if (f.IEEE.Exp==0xff)
+ {
+ // NaN or INF
+ IEEE.Frac = (f.IEEE.Frac!=0) ? 1 : 0;
+ IEEE.Exp = 31;
+ }
+ else
+ {
+ // regular number
+ int new_exp = f.IEEE.Exp-127;
+
+ if (new_exp<-24)
+ { // this maps to 0
+ IEEE.Frac = 0;
+ IEEE.Exp = 0;
+ }
+
+ else if (new_exp<-14)
+ {
+ // this maps to a denorm
+ IEEE.Exp = 0;
+ unsigned int exp_val = (unsigned int) (-14 - new_exp); // 2^-exp_val
+ switch (exp_val)
+ {
+ case 0:
+ IEEE.Frac = 0;
+ break;
+ case 1: IEEE.Frac = 512 + (f.IEEE.Frac>>14); break;
+ case 2: IEEE.Frac = 256 + (f.IEEE.Frac>>15); break;
+ case 3: IEEE.Frac = 128 + (f.IEEE.Frac>>16); break;
+ case 4: IEEE.Frac = 64 + (f.IEEE.Frac>>17); break;
+ case 5: IEEE.Frac = 32 + (f.IEEE.Frac>>18); break;
+ case 6: IEEE.Frac = 16 + (f.IEEE.Frac>>19); break;
+ case 7: IEEE.Frac = 8 + (f.IEEE.Frac>>20); break;
+ case 8: IEEE.Frac = 4 + (f.IEEE.Frac>>21); break;
+ case 9: IEEE.Frac = 2 + (f.IEEE.Frac>>22); break;
+ case 10: IEEE.Frac = 1; break;
+ }
+ }
+ else if (new_exp>15)
+ { // map this value to infinity
+ IEEE.Frac = 0;
+ IEEE.Exp = 31;
+ }
+ else
+ {
+ IEEE.Exp = new_exp+15;
+ IEEE.Frac = (f.IEEE.Frac >> 13);
+ }
+ }
+}
+
+inline HalfFloat::HalfFloat(uint16_t _m,uint16_t _e,uint16_t _s)
+{
+ IEEE.Frac = _m;
+ IEEE.Exp = _e;
+ IEEE.Sign = _s;
+}
+// ------------------------------------------------------------------------------------------------
+HalfFloat::operator float() const
+{
+ IEEESingle sng;
+ sng.IEEE.Sign = IEEE.Sign;
+
+ if (!IEEE.Exp)
+ {
+ if (!IEEE.Frac)
+ {
+ sng.IEEE.Frac=0;
+ sng.IEEE.Exp=0;
+ }
+ else
+ {
+ const float half_denorm = (1.0f/16384.0f);
+ float mantissa = ((float)(IEEE.Frac)) / 1024.0f;
+ float sgn = (IEEE.Sign)? -1.0f :1.0f;
+ sng.Float = sgn*mantissa*half_denorm;
+ }
+ }
+ else if (31 == IEEE.Exp)
+ {
+ sng.IEEE.Exp = 0xff;
+ sng.IEEE.Frac = (IEEE.Frac!=0) ? 1 : 0;
+ }
+ else
+ {
+ sng.IEEE.Exp = IEEE.Exp+112;
+ sng.IEEE.Frac = (IEEE.Frac << 13);
+ }
+ return sng.Float;
+}
+
+// ------------------------------------------------------------------------------------------------
+inline bool HalfFloat::IsNaN() const
+{
+ return IEEE.Frac != 0 && IEEE.Exp == MAX_EXPONENT_VALUE;
+}
+// ------------------------------------------------------------------------------------------------
+inline bool HalfFloat::IsInfinity() const
+{
+ return IEEE.Frac == 0 && IEEE.Exp == MAX_EXPONENT_VALUE;
+}
+// ------------------------------------------------------------------------------------------------
+inline bool HalfFloat::IsDenorm() const
+{
+ return IEEE.Exp == 0;
+}
+// ------------------------------------------------------------------------------------------------
+inline bool HalfFloat::GetSign() const
+{
+ return IEEE.Sign == 0;
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat& HalfFloat::operator= (HalfFloat other)
+{
+ bits = other.GetBits();
+ return *this;
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat& HalfFloat::operator= (float other)
+{
+ *this = (HalfFloat)other;
+ return *this;
+}
+// ------------------------------------------------------------------------------------------------
+inline bool HalfFloat::operator== (HalfFloat other) const
+{
+ // +0 and -0 are considered to be equal
+ if ((bits << 1u) == 0 && (other.bits << 1u) == 0) return true;
+
+ return bits == other.bits && !this->IsNaN();
+}
+// ------------------------------------------------------------------------------------------------
+inline bool HalfFloat::operator!= (HalfFloat other) const
+{
+ // +0 and -0 are considered to be equal
+ if ((bits << 1u) == 0 && (other.bits << 1u) == 0) return false;
+
+ return bits != other.bits || this->IsNaN();
+}
+// ------------------------------------------------------------------------------------------------
+inline bool HalfFloat::operator< (HalfFloat other) const
+{
+ // NaN comparisons are always false
+ if (this->IsNaN() || other.IsNaN())
+ return false;
+
+ // this works since the segment oder is s,e,m.
+ return (int16_t)this->bits < (int16_t)other.GetBits();
+}
+// ------------------------------------------------------------------------------------------------
+inline bool HalfFloat::operator> (HalfFloat other) const
+{
+ // NaN comparisons are always false
+ if (this->IsNaN() || other.IsNaN())
+ return false;
+
+ // this works since the segment oder is s,e,m.
+ return (int16_t)this->bits > (int16_t)other.GetBits();
+}
+// ------------------------------------------------------------------------------------------------
+inline bool HalfFloat::operator<= (HalfFloat other) const
+{
+ return !(*this > other);
+}
+// ------------------------------------------------------------------------------------------------
+inline bool HalfFloat::operator>= (HalfFloat other) const
+{
+ return !(*this < other);
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat& HalfFloat::operator += (HalfFloat other)
+{
+ *this = (*this) + other;
+ return *this;
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat& HalfFloat::operator -= (HalfFloat other)
+{
+ *this = (*this) - other;
+ return *this;
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat& HalfFloat::operator *= (HalfFloat other)
+{
+ *this = (float)(*this) * (float)other;
+ return *this;
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat& HalfFloat::operator /= (HalfFloat other)
+{
+ *this = (float)(*this) / (float)other;
+ return *this;
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat& HalfFloat::operator += (float other)
+{
+ *this = (*this) + (HalfFloat)other;
+ return *this;
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat& HalfFloat::operator -= (float other)
+{
+ *this = (*this) - (HalfFloat)other;
+ return *this;
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat& HalfFloat::operator *= (float other)
+{
+ *this = (float)(*this) * other;
+ return *this;
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat& HalfFloat::operator /= (float other)
+{
+ *this = (float)(*this) / other;
+ return *this;
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat& HalfFloat::operator++()
+{
+ // setting the exponent to bias means using 0 as exponent - thus we
+ // can set the mantissa to any value we like, we'll always get 1.0
+ return this->operator+=(HalfFloat(0,BIAS,0));
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat HalfFloat::operator++(int)
+{
+ HalfFloat f = *this;
+ this->operator+=(HalfFloat(0,BIAS,0));
+ return f;
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat& HalfFloat::operator--()
+{
+ return this->operator-=(HalfFloat(0,BIAS,0));
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat HalfFloat::operator--(int)
+{
+ HalfFloat f = *this;
+ this->operator-=(HalfFloat(0,BIAS,0));
+ return f;
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat HalfFloat::operator-() const
+{
+ return HalfFloat(IEEE.Frac,IEEE.Exp,~IEEE.Sign);
+}
+// ------------------------------------------------------------------------------------------------
+inline uint16_t HalfFloat::GetBits() const
+{
+ return bits;
+}
+// ------------------------------------------------------------------------------------------------
+inline uint16_t& HalfFloat::GetBits()
+{
+ return bits;
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat operator+ (HalfFloat one, HalfFloat two)
+{
+#if (!defined HALFFLOAT_NO_CUSTOM_IMPLEMENTATIONS)
+
+ if (one.IEEE.Exp == HalfFloat::MAX_EXPONENT_VALUE)
+ {
+ // if one of the components is NaN the result becomes NaN, too.
+ if (0 != one.IEEE.Frac || two.IsNaN())
+ return HalfFloat(1,HalfFloat::MAX_EXPONENT_VALUE,0);
+
+ // otherwise this must be infinity
+ return HalfFloat(0,HalfFloat::MAX_EXPONENT_VALUE,one.IEEE.Sign | two.IEEE.Sign);
+ }
+ else if (two.IEEE.Exp == HalfFloat::MAX_EXPONENT_VALUE)
+ {
+ if (one.IsNaN() || 0 != two.IEEE.Frac)
+ return HalfFloat(1,HalfFloat::MAX_EXPONENT_VALUE,0);
+
+ return HalfFloat(0,HalfFloat::MAX_EXPONENT_VALUE,one.IEEE.Sign | two.IEEE.Sign);
+ }
+
+ HalfFloat out;
+ long m1,m2,temp;
+
+ // compute the difference between the two exponents. shifts with negative
+ // numbers are undefined, thus we need two code paths
+ register int expDiff = one.IEEE.Exp - two.IEEE.Exp;
+
+ if (0 == expDiff)
+ {
+ // the exponents are equal, thus we must just add the hidden bit
+ temp = two.IEEE.Exp;
+
+ if (0 == one.IEEE.Exp)m1 = one.IEEE.Frac;
+ else m1 = (int)one.IEEE.Frac | ( 1 << HalfFloat::BITS_MANTISSA );
+
+ if (0 == two.IEEE.Exp)m2 = two.IEEE.Frac;
+ else m2 = (int)two.IEEE.Frac | ( 1 << HalfFloat::BITS_MANTISSA );
+ }
+ else
+ {
+ if (expDiff < 0)
+ {
+ expDiff = -expDiff;
+ std::swap(one,two);
+ }
+
+ m1 = (int)one.IEEE.Frac | ( 1 << HalfFloat::BITS_MANTISSA );
+
+ if (0 == two.IEEE.Exp)m2 = two.IEEE.Frac;
+ else m2 = (int)two.IEEE.Frac | ( 1 << HalfFloat::BITS_MANTISSA );
+
+ if (expDiff < ((sizeof(long)<<3)-(HalfFloat::BITS_MANTISSA+1)))
+ {
+ m1 <<= expDiff;
+ temp = two.IEEE.Exp;
+ }
+ else
+ {
+ if (0 != two.IEEE.Exp)
+ {
+ // arithmetic underflow
+ if (expDiff > HalfFloat::BITS_MANTISSA)return HalfFloat(0,0,0);
+ else
+ {
+ m2 >>= expDiff;
+ }
+ }
+ temp = one.IEEE.Exp;
+ }
+ }
+
+ // convert from sign-bit to two's complement representation
+ if (one.IEEE.Sign)m1 = -m1;
+ if (two.IEEE.Sign)m2 = -m2;
+ m1 += m2;
+ if (m1 < 0)
+ {
+ out.IEEE.Sign = 1;
+ m1 = -m1;
+ }
+ else out.IEEE.Sign = 0;
+
+ // and renormalize the result to fit in a half
+ if (0 == m1)return HalfFloat(0,0,0);
+
+#ifdef _MSC_VER
+ _BitScanReverse((unsigned long*)&m2,m1);
+#else
+ m2 = __builtin_clz(m1);
+#endif
+ expDiff = m2 - HalfFloat::BITS_MANTISSA;
+ temp += expDiff;
+ if (expDiff >= HalfFloat::MAX_EXPONENT_VALUE)
+ {
+ // arithmetic overflow. return INF and keep the sign
+ return HalfFloat(0,HalfFloat::MAX_EXPONENT_VALUE,out.IEEE.Sign);
+ }
+ else if (temp <= 0)
+ {
+ // this maps to a denorm
+ m1 <<= (-expDiff-1);
+ temp = 0;
+ }
+ else
+ {
+ // rebuild the normalized representation, take care of the hidden bit
+ if (expDiff < 0)m1 <<= (-expDiff);
+ else m1 >>= expDiff; // m1 >= 0
+ }
+ out.IEEE.Frac = m1;
+ out.IEEE.Exp = temp;
+ return out;
+
+#else
+ return HalfFloat((float)one + (float)two);
+#endif
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat operator- (HalfFloat one, HalfFloat two)
+{
+ return HalfFloat(one + (-two));
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat operator* (HalfFloat one, HalfFloat two)
+{
+ return HalfFloat((float)one * (float)two);
+}
+// ------------------------------------------------------------------------------------------------
+inline HalfFloat operator/ (HalfFloat one, HalfFloat two)
+{
+ return HalfFloat((float)one / (float)two);
+}
+// ------------------------------------------------------------------------------------------------
+inline float operator+ (HalfFloat one, float two)
+{
+ return (float)one + two;
+}
+// ------------------------------------------------------------------------------------------------
+inline float operator- (HalfFloat one, float two)
+{
+ return (float)one - two;
+}
+// ------------------------------------------------------------------------------------------------
+inline float operator* (HalfFloat one, float two)
+{
+ return (float)one * two;
+}
+// ------------------------------------------------------------------------------------------------
+inline float operator/ (HalfFloat one, float two)
+{
+ return (float)one / two;
+}
+// ------------------------------------------------------------------------------------------------
+inline float operator+ (float one, HalfFloat two)
+{
+ return two + one;
+}
+// ------------------------------------------------------------------------------------------------
+inline float operator- (float one, HalfFloat two)
+{
+ return two - one;
+}
+// ------------------------------------------------------------------------------------------------
+inline float operator* (float one, HalfFloat two)
+{
+ return two * one;
+}
+// ------------------------------------------------------------------------------------------------
+inline float operator/ (float one, HalfFloat two)
+{
+ return two / one;
+}
+
+#endif //!! UM_HALF_INL_INCLUDED
diff --git a/src/3rd_party/mio/LICENSE b/src/3rd_party/mio/LICENSE
new file mode 100644
index 00000000..36177074
--- /dev/null
+++ b/src/3rd_party/mio/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2018 https://github.com/mandreyel/
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/src/3rd_party/mio/README.md b/src/3rd_party/mio/README.md
new file mode 100644
index 00000000..e50b23ff
--- /dev/null
+++ b/src/3rd_party/mio/README.md
@@ -0,0 +1,337 @@
+# mio
+An easy to use header-only cross-platform C++11 memory mapping library with an MIT license.
+
+mio has been created with the goal to be easily includable (i.e. no dependencies) in any C++ project that needs memory mapped file IO without the need to pull in Boost.
+
+Please feel free to open an issue, I'll try to address any concerns as best I can.
+
+### Why?
+Because memory mapping is the best thing since sliced bread!
+
+More seriously, the primary motivation for writing this library instead of using Boost.Iostreams, was the lack of support for establishing a memory mapping with an already open file handle/descriptor. This is possible with mio.
+
+Furthermore, Boost.Iostreams' solution requires that the user pick offsets exactly at page boundaries, which is cumbersome and error prone. mio, on the other hand, manages this internally, accepting any offset and finding the nearest page boundary.
+
+Albeit a minor nitpick, Boost.Iostreams implements memory mapped file IO with a `std::shared_ptr` to provide shared semantics, even if not needed, and the overhead of the heap allocation may be unnecessary and/or unwanted.
+In mio, there are two classes to cover the two use-cases: one that is move-only (basically a zero-cost abstraction over the system specific mmapping functions), and the other that acts just like its Boost.Iostreams counterpart, with shared semantics.
+
+### How to create a mapping
+NOTE: the file must exist before creating a mapping.
+
+There are three ways to map a file into memory:
+
+- Using the constructor, which throws a `std::system_error` on failure:
+```c++
+mio::mmap_source mmap(path, offset, size_to_map);
+```
+or you can omit the `offset` and `size_to_map` arguments, in which case the
+entire file is mapped:
+```c++
+mio::mmap_source mmap(path);
+```
+
+- Using the factory function:
+```c++
+std::error_code error;
+mio::mmap_source mmap = mio::make_mmap_source(path, offset, size_to_map, error);
+```
+or:
+```c++
+mio::mmap_source mmap = mio::make_mmap_source(path, error);
+```
+
+- Using the `map` member function:
+```c++
+std::error_code error;
+mio::mmap_source mmap;
+mmap.map(path, offset, size_to_map, error);
+```
+or:
+```c++
+mmap.map(path, error);
+```
+**NOTE:** The constructors **require** exceptions to be enabled. If you prefer
+to build your projects with `-fno-exceptions`, you can still use the other ways.
+
+Moreover, in each case, you can provide either some string type for the file's path, or you can use an existing, valid file handle.
+```c++
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <fcntl.h>
+#include <mio/mmap.hpp>
+#include <algorithm>
+
+int main()
+{
+ // NOTE: error handling omitted for brevity.
+ const int fd = open("file.txt", O_RDONLY);
+ mio::mmap_source mmap(fd, 0, mio::map_entire_file);
+ // ...
+}
+```
+However, mio does not check whether the provided file descriptor has the same access permissions as the desired mapping, so the mapping may fail. Such errors are reported via the `std::error_code` out parameter that is passed to the mapping function.
+
+**WINDOWS USERS**: This library *does* support the use of wide character types
+for functions where character strings are expected (e.g. path parameters).
+
+### Example
+
+```c++
+#include <mio/mmap.hpp>
+#include <system_error> // for std::error_code
+#include <cstdio> // for std::printf
+#include <cassert>
+#include <algorithm>
+#include <fstream>
+
+int handle_error(const std::error_code& error);
+void allocate_file(const std::string& path, const int size);
+
+int main()
+{
+ const auto path = "file.txt";
+
+ // NOTE: mio does *not* create the file for you if it doesn't exist! You
+ // must ensure that the file exists before establishing a mapping. It
+ // must also be non-empty. So for illustrative purposes the file is
+ // created now.
+ allocate_file(path, 155);
+
+ // Read-write memory map the whole file by using `map_entire_file` where the
+ // length of the mapping is otherwise expected, with the factory method.
+ std::error_code error;
+ mio::mmap_sink rw_mmap = mio::make_mmap_sink(
+ path, 0, mio::map_entire_file, error);
+ if (error) { return handle_error(error); }
+
+ // You can use any iterator based function.
+ std::fill(rw_mmap.begin(), rw_mmap.end(), 'a');
+
+ // Or manually iterate through the mapped region just as if it were any other
+ // container, and change each byte's value (since this is a read-write mapping).
+ for (auto& b : rw_mmap) {
+ b += 10;
+ }
+
+ // Or just change one value with the subscript operator.
+ const int answer_index = rw_mmap.size() / 2;
+ rw_mmap[answer_index] = 42;
+
+ // Don't forget to flush changes to disk before unmapping. However, if
+ // `rw_mmap` were to go out of scope at this point, the destructor would also
+ // automatically invoke `sync` before `unmap`.
+ rw_mmap.sync(error);
+ if (error) { return handle_error(error); }
+
+ // We can then remove the mapping, after which rw_mmap will be in a default
+ // constructed state, i.e. this and the above call to `sync` have the same
+ // effect as if the destructor had been invoked.
+ rw_mmap.unmap();
+
+ // Now create the same mapping, but in read-only mode. Note that calling the
+ // overload without the offset and file length parameters maps the entire
+ // file.
+ mio::mmap_source ro_mmap;
+ ro_mmap.map(path, error);
+ if (error) { return handle_error(error); }
+
+ const int the_answer_to_everything = ro_mmap[answer_index];
+ assert(the_answer_to_everything == 42);
+}
+
+int handle_error(const std::error_code& error)
+{
+ const auto& errmsg = error.message();
+ std::printf("error mapping file: %s, exiting...\n", errmsg.c_str());
+ return error.value();
+}
+
+void allocate_file(const std::string& path, const int size)
+{
+ std::ofstream file(path);
+ std::string s(size, '0');
+ file << s;
+}
+```
+
+`mio::basic_mmap` is move-only, but if multiple copies to the same mapping are needed, use `mio::basic_shared_mmap` which has `std::shared_ptr` semantics and has the same interface as `mio::basic_mmap`.
+```c++
+#include <mio/shared_mmap.hpp>
+
+mio::shared_mmap_source shared_mmap1("path", offset, size_to_map);
+mio::shared_mmap_source shared_mmap2(std::move(mmap1)); // or use operator=
+mio::shared_mmap_source shared_mmap3(std::make_shared<mio::mmap_source>(mmap1)); // or use operator=
+mio::shared_mmap_source shared_mmap4;
+shared_mmap4.map("path", offset, size_to_map, error);
+```
+
+It's possible to define the type of a byte (which has to be the same width as `char`), though aliases for the most common ones are provided by default:
+```c++
+using mmap_source = basic_mmap_source<char>;
+using ummap_source = basic_mmap_source<unsigned char>;
+
+using mmap_sink = basic_mmap_sink<char>;
+using ummap_sink = basic_mmap_sink<unsigned char>;
+```
+But it may be useful to define your own types, say when using the new `std::byte` type in C++17:
+```c++
+using mmap_source = mio::basic_mmap_source<std::byte>;
+using mmap_sink = mio::basic_mmap_sink<std::byte>;
+```
+
+Though generally not needed, since mio maps users requested offsets to page boundaries, you can query the underlying system's page allocation granularity by invoking `mio::page_size()`, which is located in `mio/page.hpp`.
+
+### Single Header File
+Mio can be added to your project as a single header file simply by including `\single_include\mio\mio.hpp`. Single header files can be regenerated at any time by running the `amalgamate.py` script within `\third_party`.
+```
+python amalgamate.py -c config.json -s ../include
+```
+
+## CMake
+As a header-only library, mio has no compiled components. Nevertheless, a [CMake](https://cmake.org/overview/) build system is provided to allow easy testing, installation, and subproject composition on many platforms and operating systems.
+
+### Testing
+Mio is distributed with a small suite of tests and examples.
+When mio is configured as the highest level CMake project, this suite of executables is built by default.
+Mio's test executables are integrated with the CMake test driver program, [CTest](https://cmake.org/cmake/help/latest/manual/ctest.1.html).
+
+CMake supports a number of backends for compilation and linking.
+
+To use a static configuration build tool, such as GNU Make or Ninja:
+
+```sh
+cd <mio source directory>
+mkdir build
+cd build
+
+# Configure the build
+cmake -D CMAKE_BUILD_TYPE=<Debug | Release> \
+ -G <"Unix Makefiles" | "Ninja"> ..
+
+# build the tests
+< make | ninja | cmake --build . >
+
+# run the tests
+< make test | ninja test | cmake --build . --target test | ctest >
+```
+
+To use a dynamic configuration build tool, such as Visual Studio or Xcode:
+
+```sh
+cd <mio source directory>
+mkdir build
+cd build
+
+# Configure the build
+cmake -G <"Visual Studio 14 2015 Win64" | "Xcode"> ..
+
+# build the tests
+cmake --build . --config <Debug | Release>
+
+# run the tests via ctest...
+ctest --build-config <Debug | Release>
+
+# ... or via CMake build tool mode...
+cmake --build . --config <Debug | Release> --target test
+```
+
+Of course the **build** and **test** steps can also be executed via the **all** and **test** targets, respectively, from within the IDE after opening the project file generated during the configuration step.
+
+Mio's testing is also configured to operate as a client to the [CDash](https://www.cdash.org/) software quality dashboard application. Please see the [Kitware documentation](https://cmake.org/cmake/help/latest/manual/ctest.1.html#dashboard-client) for more information on this mode of operation.
+
+### Installation
+
+Mio's build system provides an installation target and support for downstream consumption via CMake's [`find_package`](https://cmake.org/cmake/help/v3.0/command/find_package.html) intrinsic function.
+CMake allows installation to an arbitrary location, which may be specified by defining `CMAKE_INSTALL_PREFIX` at configure time.
+In the absense of a user specification, CMake will install mio to conventional location based on the platform operating system.
+
+To use a static configuration build tool, such as GNU Make or Ninja:
+
+```sh
+cd <mio source directory>
+mkdir build
+cd build
+
+# Configure the build
+cmake [-D CMAKE_INSTALL_PREFIX="path/to/installation"] \
+ [-D BUILD_TESTING=False] \
+ -D CMAKE_BUILD_TYPE=Release \
+ -G <"Unix Makefiles" | "Ninja"> ..
+
+# install mio
+<make install | ninja install | cmake --build . --target install>
+```
+
+To use a dynamic configuration build tool, such as Visual Studio or Xcode:
+
+```sh
+cd <mio source directory>
+mkdir build
+cd build
+
+# Configure the project
+cmake [-D CMAKE_INSTALL_PREFIX="path/to/installation"] \
+ [-D BUILD_TESTING=False] \
+ -G <"Visual Studio 14 2015 Win64" | "Xcode"> ..
+
+# install mio
+cmake --build . --config Release --target install
+```
+
+Note that the last command of the installation sequence may require administrator privileges (e.g. `sudo`) if the installation root directory lies outside your home directory.
+
+This installation
++ copies the mio header files to the `include/mio` subdirectory of the installation root
++ generates and copies several CMake configuration files to the `share/cmake/mio` subdirectory of the installation root
+
+This latter step allows downstream CMake projects to consume mio via `find_package`, e.g.
+
+```cmake
+find_package( mio REQUIRED )
+target_link_libraries( MyTarget PUBLIC mio::mio )
+```
+
+**WINDOWS USERS**: The `mio::mio` target `#define`s `WIN32_LEAN_AND_MEAN` and `NOMINMAX`. The former ensures the imported surface area of the Win API is minimal, and the latter disables Windows' `min` and `max` macros so they don't intefere with `std::min` and `std::max`. Because *mio* is a header only library, these defintions will leak into downstream CMake builds. If their presence is causing problems with your build then you can use the alternative `mio::mio_full_winapi` target, which adds none of these defintions.
+
+If mio was installed to a non-conventional location, it may be necessary for downstream projects to specify the mio installation root directory via either
+
++ the `CMAKE_PREFIX_PATH` configuration option,
++ the `CMAKE_PREFIX_PATH` environment variable, or
++ `mio_DIR` environment variable.
+
+Please see the [Kitware documentation](https://cmake.org/cmake/help/v3.0/command/find_package.html) for more information.
+
+In addition, mio supports packaged relocatable installations via [CPack](https://cmake.org/cmake/help/latest/manual/cpack.1.html).
+Following configuration, from the build directory, invoke cpack as follows to generate a packaged installation:
+
+```sh
+cpack -G <generator name> -C Release
+```
+
+The list of supported generators varies from platform to platform. See the output of `cpack --help` for a complete list of supported generators on your platform.
+
+### Subproject Composition
+To use mio as a subproject, copy the mio repository to your project's dependencies/externals folder.
+If your project is version controlled using git, a git submodule or git subtree can be used to syncronize with the updstream repository.
+The [use](https://services.github.com/on-demand/downloads/submodule-vs-subtree-cheat-sheet/) and [relative advantages](https://andrey.nering.com.br/2016/git-submodules-vs-subtrees/) of these git facilities is beyond the scope of this document, but in brief, each may be established as follows:
+
+```sh
+# via git submodule
+cd <my project's dependencies directory>
+git submodule add -b master https://github.com/mandreyel/mio.git
+
+# via git subtree
+cd <my project's root directory>
+git subtree add --prefix <path/to/dependencies>/mio \
+ https://github.com/mandreyel/mio.git master --squash
+```
+
+Given a mio subdirectory in a project, simply add the following lines to your project's to add mio include directories to your target's include path.
+
+```cmake
+add_subdirectory( path/to/mio/ )
+target_link_libraries( MyTarget PUBLIC <mio::mio | mio> )
+```
+
+Note that, as a subproject, mio's tests and examples will not be built and CPack integration is deferred to the host project.
+
diff --git a/src/3rd_party/mio/mio.hpp b/src/3rd_party/mio/mio.hpp
new file mode 100644
index 00000000..b4b8cd5e
--- /dev/null
+++ b/src/3rd_party/mio/mio.hpp
@@ -0,0 +1,1748 @@
+/* Copyright 2017 https://github.com/mandreyel
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this
+ * software and associated documentation files (the "Software"), to deal in the Software
+ * without restriction, including without limitation the rights to use, copy, modify,
+ * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to the following
+ * conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies
+ * or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+ * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+ * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
+ * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
+ * OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+#ifndef MIO_MMAP_HEADER
+#define MIO_MMAP_HEADER
+
+// #include "mio/page.hpp"
+/* Copyright 2017 https://github.com/mandreyel
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this
+ * software and associated documentation files (the "Software"), to deal in the Software
+ * without restriction, including without limitation the rights to use, copy, modify,
+ * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to the following
+ * conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies
+ * or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+ * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+ * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
+ * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
+ * OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+#ifndef MIO_PAGE_HEADER
+#define MIO_PAGE_HEADER
+
+#ifdef _WIN32
+# include <windows.h>
+#else
+# include <unistd.h>
+#endif
+
+namespace mio {
+
+/**
+ * This is used by `basic_mmap` to determine whether to create a read-only or
+ * a read-write memory mapping.
+ */
+enum class access_mode
+{
+ read,
+ write
+};
+
+/**
+ * Determines the operating system's page allocation granularity.
+ *
+ * On the first call to this function, it invokes the operating system specific syscall
+ * to determine the page size, caches the value, and returns it. Any subsequent call to
+ * this function serves the cached value, so no further syscalls are made.
+ */
+inline size_t page_size()
+{
+ static const size_t page_size = []
+ {
+#ifdef _WIN32
+ SYSTEM_INFO SystemInfo;
+ GetSystemInfo(&SystemInfo);
+ return SystemInfo.dwAllocationGranularity;
+#else
+ return sysconf(_SC_PAGE_SIZE);
+#endif
+ }();
+ return page_size;
+}
+
+/**
+ * Alligns `offset` to the operating's system page size such that it subtracts the
+ * difference until the nearest page boundary before `offset`, or does nothing if
+ * `offset` is already page aligned.
+ */
+inline size_t make_offset_page_aligned(size_t offset) noexcept
+{
+ const size_t page_size_ = page_size();
+ // Use integer division to round down to the nearest page alignment.
+ return offset / page_size_ * page_size_;
+}
+
+} // namespace mio
+
+#endif // MIO_PAGE_HEADER
+
+
+#include <iterator>
+#include <string>
+#include <system_error>
+#include <cstdint>
+
+#ifdef _WIN32
+# ifndef WIN32_LEAN_AND_MEAN
+# define WIN32_LEAN_AND_MEAN
+# endif // WIN32_LEAN_AND_MEAN
+# include <windows.h>
+#else // ifdef _WIN32
+# define INVALID_HANDLE_VALUE -1
+#endif // ifdef _WIN32
+
+namespace mio {
+
+// This value may be provided as the `length` parameter to the constructor or
+// `map`, in which case a memory mapping of the entire file is created.
+enum { map_entire_file = 0 };
+
+#ifdef _WIN32
+using file_handle_type = HANDLE;
+#else
+using file_handle_type = int;
+#endif
+
+// This value represents an invalid file handle type. This can be used to
+// determine whether `basic_mmap::file_handle` is valid, for example.
+const static file_handle_type invalid_handle = INVALID_HANDLE_VALUE;
+
+template<access_mode AccessMode, typename ByteT>
+struct basic_mmap
+{
+ using value_type = ByteT;
+ using size_type = size_t;
+ using reference = value_type&;
+ using const_reference = const value_type&;
+ using pointer = value_type*;
+ using const_pointer = const value_type*;
+ using difference_type = std::ptrdiff_t;
+ using iterator = pointer;
+ using const_iterator = const_pointer;
+ using reverse_iterator = std::reverse_iterator<iterator>;
+ using const_reverse_iterator = std::reverse_iterator<const_iterator>;
+ using iterator_category = std::random_access_iterator_tag;
+ using handle_type = file_handle_type;
+
+ static_assert(sizeof(ByteT) == sizeof(char), "ByteT must be the same size as char.");
+
+private:
+ // Points to the first requested byte, and not to the actual start of the mapping.
+ pointer data_ = nullptr;
+
+ // Length--in bytes--requested by user (which may not be the length of the
+ // full mapping) and the length of the full mapping.
+ size_type length_ = 0;
+ size_type mapped_length_ = 0;
+
+ // Letting user map a file using both an existing file handle and a path
+ // introcudes some complexity (see `is_handle_internal_`).
+ // On POSIX, we only need a file handle to create a mapping, while on
+ // Windows systems the file handle is necessary to retrieve a file mapping
+ // handle, but any subsequent operations on the mapped region must be done
+ // through the latter.
+ handle_type file_handle_ = INVALID_HANDLE_VALUE;
+#ifdef _WIN32
+ handle_type file_mapping_handle_ = INVALID_HANDLE_VALUE;
+#endif
+
+ // Letting user map a file using both an existing file handle and a path
+ // introcudes some complexity in that we must not close the file handle if
+ // user provided it, but we must close it if we obtained it using the
+ // provided path. For this reason, this flag is used to determine when to
+ // close `file_handle_`.
+ bool is_handle_internal_;
+
+public:
+ /**
+ * The default constructed mmap object is in a non-mapped state, that is,
+ * any operation that attempts to access nonexistent underlying data will
+ * result in undefined behaviour/segmentation faults.
+ */
+ basic_mmap() = default;
+
+#ifdef __cpp_exceptions
+ /**
+ * The same as invoking the `map` function, except any error that may occur
+ * while establishing the mapping is wrapped in a `std::system_error` and is
+ * thrown.
+ */
+ template<typename String>
+ basic_mmap(const String& path, const size_type offset = 0, const size_type length = map_entire_file)
+ {
+ std::error_code error;
+ map(path, offset, length, error);
+ if(error) { throw std::system_error(error); }
+ }
+
+ /**
+ * The same as invoking the `map` function, except any error that may occur
+ * while establishing the mapping is wrapped in a `std::system_error` and is
+ * thrown.
+ */
+ basic_mmap(const handle_type handle, const size_type offset = 0, const size_type length = map_entire_file)
+ {
+ std::error_code error;
+ map(handle, offset, length, error);
+ if(error) { throw std::system_error(error); }
+ }
+#endif // __cpp_exceptions
+
+ /**
+ * `basic_mmap` has single-ownership semantics, so transferring ownership
+ * may only be accomplished by moving the object.
+ */
+ basic_mmap(const basic_mmap&) = delete;
+ basic_mmap(basic_mmap&&);
+ basic_mmap& operator=(const basic_mmap&) = delete;
+ basic_mmap& operator=(basic_mmap&&);
+
+ /**
+ * If this is a read-write mapping, the destructor invokes sync. Regardless
+ * of the access mode, unmap is invoked as a final step.
+ */
+ ~basic_mmap();
+
+ /**
+ * On UNIX systems 'file_handle' and 'mapping_handle' are the same. On Windows,
+ * however, a mapped region of a file gets its own handle, which is returned by
+ * 'mapping_handle'.
+ */
+ handle_type file_handle() const noexcept { return file_handle_; }
+ handle_type mapping_handle() const noexcept;
+
+ /** Returns whether a valid memory mapping has been created. */
+ bool is_open() const noexcept { return file_handle_ != invalid_handle; }
+
+ /**
+ * Returns true if no mapping was established, that is, conceptually the
+ * same as though the length that was mapped was 0. This function is
+ * provided so that this class has Container semantics.
+ */
+ bool empty() const noexcept { return length() == 0; }
+
+ /** Returns true if a mapping was established. */
+ bool is_mapped() const noexcept;
+
+ /**
+ * `size` and `length` both return the logical length, i.e. the number of bytes
+ * user requested to be mapped, while `mapped_length` returns the actual number of
+ * bytes that were mapped which is a multiple of the underlying operating system's
+ * page allocation granularity.
+ */
+ size_type size() const noexcept { return length(); }
+ size_type length() const noexcept { return length_; }
+ size_type mapped_length() const noexcept { return mapped_length_; }
+
+ /** Returns the offset relative to the start of the mapping. */
+ size_type mapping_offset() const noexcept
+ {
+ return mapped_length_ - length_;
+ }
+
+ /**
+ * Returns a pointer to the first requested byte, or `nullptr` if no memory mapping
+ * exists.
+ */
+ template<
+ access_mode A = AccessMode,
+ typename = typename std::enable_if<A == access_mode::write>::type
+ > pointer data() noexcept { return data_; }
+ const_pointer data() const noexcept { return data_; }
+
+ /**
+ * Returns an iterator to the first requested byte, if a valid memory mapping
+ * exists, otherwise this function call is undefined behaviour.
+ */
+ template<
+ access_mode A = AccessMode,
+ typename = typename std::enable_if<A == access_mode::write>::type
+ > iterator begin() noexcept { return data(); }
+ const_iterator begin() const noexcept { return data(); }
+ const_iterator cbegin() const noexcept { return data(); }
+
+ /**
+ * Returns an iterator one past the last requested byte, if a valid memory mapping
+ * exists, otherwise this function call is undefined behaviour.
+ */
+ template<
+ access_mode A = AccessMode,
+ typename = typename std::enable_if<A == access_mode::write>::type
+ > iterator end() noexcept { return data() + length(); }
+ const_iterator end() const noexcept { return data() + length(); }
+ const_iterator cend() const noexcept { return data() + length(); }
+
+ /**
+ * Returns a reverse iterator to the last memory mapped byte, if a valid
+ * memory mapping exists, otherwise this function call is undefined
+ * behaviour.
+ */
+ template<
+ access_mode A = AccessMode,
+ typename = typename std::enable_if<A == access_mode::write>::type
+ > reverse_iterator rbegin() noexcept { return reverse_iterator(end()); }
+ const_reverse_iterator rbegin() const noexcept
+ { return const_reverse_iterator(end()); }
+ const_reverse_iterator crbegin() const noexcept
+ { return const_reverse_iterator(end()); }
+
+ /**
+ * Returns a reverse iterator past the first mapped byte, if a valid memory
+ * mapping exists, otherwise this function call is undefined behaviour.
+ */
+ template<
+ access_mode A = AccessMode,
+ typename = typename std::enable_if<A == access_mode::write>::type
+ > reverse_iterator rend() noexcept { return reverse_iterator(begin()); }
+ const_reverse_iterator rend() const noexcept
+ { return const_reverse_iterator(begin()); }
+ const_reverse_iterator crend() const noexcept
+ { return const_reverse_iterator(begin()); }
+
+ /**
+ * Returns a reference to the `i`th byte from the first requested byte (as returned
+ * by `data`). If this is invoked when no valid memory mapping has been created
+ * prior to this call, undefined behaviour ensues.
+ */
+ reference operator[](const size_type i) noexcept { return data_[i]; }
+ const_reference operator[](const size_type i) const noexcept { return data_[i]; }
+
+ /**
+ * Establishes a memory mapping with AccessMode. If the mapping is unsuccesful, the
+ * reason is reported via `error` and the object remains in a state as if this
+ * function hadn't been called.
+ *
+ * `path`, which must be a path to an existing file, is used to retrieve a file
+ * handle (which is closed when the object destructs or `unmap` is called), which is
+ * then used to memory map the requested region. Upon failure, `error` is set to
+ * indicate the reason and the object remains in an unmapped state.
+ *
+ * `offset` is the number of bytes, relative to the start of the file, where the
+ * mapping should begin. When specifying it, there is no need to worry about
+ * providing a value that is aligned with the operating system's page allocation
+ * granularity. This is adjusted by the implementation such that the first requested
+ * byte (as returned by `data` or `begin`), so long as `offset` is valid, will be at
+ * `offset` from the start of the file.
+ *
+ * `length` is the number of bytes to map. It may be `map_entire_file`, in which
+ * case a mapping of the entire file is created.
+ */
+ template<typename String>
+ void map(const String& path, const size_type offset,
+ const size_type length, std::error_code& error);
+
+ /**
+ * Establishes a memory mapping with AccessMode. If the mapping is unsuccesful, the
+ * reason is reported via `error` and the object remains in a state as if this
+ * function hadn't been called.
+ *
+ * `path`, which must be a path to an existing file, is used to retrieve a file
+ * handle (which is closed when the object destructs or `unmap` is called), which is
+ * then used to memory map the requested region. Upon failure, `error` is set to
+ * indicate the reason and the object remains in an unmapped state.
+ *
+ * The entire file is mapped.
+ */
+ template<typename String>
+ void map(const String& path, std::error_code& error)
+ {
+ map(path, 0, map_entire_file, error);
+ }
+
+ /**
+ * Establishes a memory mapping with AccessMode. If the mapping is
+ * unsuccesful, the reason is reported via `error` and the object remains in
+ * a state as if this function hadn't been called.
+ *
+ * `handle`, which must be a valid file handle, which is used to memory map the
+ * requested region. Upon failure, `error` is set to indicate the reason and the
+ * object remains in an unmapped state.
+ *
+ * `offset` is the number of bytes, relative to the start of the file, where the
+ * mapping should begin. When specifying it, there is no need to worry about
+ * providing a value that is aligned with the operating system's page allocation
+ * granularity. This is adjusted by the implementation such that the first requested
+ * byte (as returned by `data` or `begin`), so long as `offset` is valid, will be at
+ * `offset` from the start of the file.
+ *
+ * `length` is the number of bytes to map. It may be `map_entire_file`, in which
+ * case a mapping of the entire file is created.
+ */
+ void map(const handle_type handle, const size_type offset,
+ const size_type length, std::error_code& error);
+
+ /**
+ * Establishes a memory mapping with AccessMode. If the mapping is
+ * unsuccesful, the reason is reported via `error` and the object remains in
+ * a state as if this function hadn't been called.
+ *
+ * `handle`, which must be a valid file handle, which is used to memory map the
+ * requested region. Upon failure, `error` is set to indicate the reason and the
+ * object remains in an unmapped state.
+ *
+ * The entire file is mapped.
+ */
+ void map(const handle_type handle, std::error_code& error)
+ {
+ map(handle, 0, map_entire_file, error);
+ }
+
+ /**
+ * If a valid memory mapping has been created prior to this call, this call
+ * instructs the kernel to unmap the memory region and disassociate this object
+ * from the file.
+ *
+ * The file handle associated with the file that is mapped is only closed if the
+ * mapping was created using a file path. If, on the other hand, an existing
+ * file handle was used to create the mapping, the file handle is not closed.
+ */
+ void unmap();
+
+ void swap(basic_mmap& other);
+
+ /** Flushes the memory mapped page to disk. Errors are reported via `error`. */
+ template<access_mode A = AccessMode>
+ typename std::enable_if<A == access_mode::write, void>::type
+ sync(std::error_code& error);
+
+ /**
+ * All operators compare the address of the first byte and size of the two mapped
+ * regions.
+ */
+
+private:
+ template<
+ access_mode A = AccessMode,
+ typename = typename std::enable_if<A == access_mode::write>::type
+ > pointer get_mapping_start() noexcept
+ {
+ return !data() ? nullptr : data() - mapping_offset();
+ }
+
+ const_pointer get_mapping_start() const noexcept
+ {
+ return !data() ? nullptr : data() - mapping_offset();
+ }
+
+ /**
+ * The destructor syncs changes to disk if `AccessMode` is `write`, but not
+ * if it's `read`, but since the destructor cannot be templated, we need to
+ * do SFINAE in a dedicated function, where one syncs and the other is a noop.
+ */
+ template<access_mode A = AccessMode>
+ typename std::enable_if<A == access_mode::write, void>::type
+ conditional_sync();
+ template<access_mode A = AccessMode>
+ typename std::enable_if<A == access_mode::read, void>::type conditional_sync();
+};
+
+template<access_mode AccessMode, typename ByteT>
+bool operator==(const basic_mmap<AccessMode, ByteT>& a,
+ const basic_mmap<AccessMode, ByteT>& b);
+
+template<access_mode AccessMode, typename ByteT>
+bool operator!=(const basic_mmap<AccessMode, ByteT>& a,
+ const basic_mmap<AccessMode, ByteT>& b);
+
+template<access_mode AccessMode, typename ByteT>
+bool operator<(const basic_mmap<AccessMode, ByteT>& a,
+ const basic_mmap<AccessMode, ByteT>& b);
+
+template<access_mode AccessMode, typename ByteT>
+bool operator<=(const basic_mmap<AccessMode, ByteT>& a,
+ const basic_mmap<AccessMode, ByteT>& b);
+
+template<access_mode AccessMode, typename ByteT>
+bool operator>(const basic_mmap<AccessMode, ByteT>& a,
+ const basic_mmap<AccessMode, ByteT>& b);
+
+template<access_mode AccessMode, typename ByteT>
+bool operator>=(const basic_mmap<AccessMode, ByteT>& a,
+ const basic_mmap<AccessMode, ByteT>& b);
+
+/**
+ * This is the basis for all read-only mmap objects and should be preferred over
+ * directly using `basic_mmap`.
+ */
+template<typename ByteT>
+using basic_mmap_source = basic_mmap<access_mode::read, ByteT>;
+
+/**
+ * This is the basis for all read-write mmap objects and should be preferred over
+ * directly using `basic_mmap`.
+ */
+template<typename ByteT>
+using basic_mmap_sink = basic_mmap<access_mode::write, ByteT>;
+
+/**
+ * These aliases cover the most common use cases, both representing a raw byte stream
+ * (either with a char or an unsigned char/uint8_t).
+ */
+using mmap_source = basic_mmap_source<char>;
+using ummap_source = basic_mmap_source<unsigned char>;
+
+using mmap_sink = basic_mmap_sink<char>;
+using ummap_sink = basic_mmap_sink<unsigned char>;
+
+/**
+ * Convenience factory method that constructs a mapping for any `basic_mmap` or
+ * `basic_mmap` type.
+ */
+template<
+ typename MMap,
+ typename MappingToken
+> MMap make_mmap(const MappingToken& token,
+ int64_t offset, int64_t length, std::error_code& error)
+{
+ MMap mmap;
+ mmap.map(token, offset, length, error);
+ return mmap;
+}
+
+/**
+ * Convenience factory method.
+ *
+ * MappingToken may be a String (`std::string`, `std::string_view`, `const char*`,
+ * `std::filesystem::path`, `std::vector<char>`, or similar), or a
+ * `mmap_source::handle_type`.
+ */
+template<typename MappingToken>
+mmap_source make_mmap_source(const MappingToken& token, mmap_source::size_type offset,
+ mmap_source::size_type length, std::error_code& error)
+{
+ return make_mmap<mmap_source>(token, offset, length, error);
+}
+
+template<typename MappingToken>
+mmap_source make_mmap_source(const MappingToken& token, std::error_code& error)
+{
+ return make_mmap_source(token, 0, map_entire_file, error);
+}
+
+/**
+ * Convenience factory method.
+ *
+ * MappingToken may be a String (`std::string`, `std::string_view`, `const char*`,
+ * `std::filesystem::path`, `std::vector<char>`, or similar), or a
+ * `mmap_sink::handle_type`.
+ */
+template<typename MappingToken>
+mmap_sink make_mmap_sink(const MappingToken& token, mmap_sink::size_type offset,
+ mmap_sink::size_type length, std::error_code& error)
+{
+ return make_mmap<mmap_sink>(token, offset, length, error);
+}
+
+template<typename MappingToken>
+mmap_sink make_mmap_sink(const MappingToken& token, std::error_code& error)
+{
+ return make_mmap_sink(token, 0, map_entire_file, error);
+}
+
+} // namespace mio
+
+// #include "detail/mmap.ipp"
+/* Copyright 2017 https://github.com/mandreyel
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this
+ * software and associated documentation files (the "Software"), to deal in the Software
+ * without restriction, including without limitation the rights to use, copy, modify,
+ * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to the following
+ * conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies
+ * or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+ * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+ * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
+ * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
+ * OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+#ifndef MIO_BASIC_MMAP_IMPL
+#define MIO_BASIC_MMAP_IMPL
+
+// #include "mio/mmap.hpp"
+
+// #include "mio/page.hpp"
+
+// #include "mio/detail/string_util.hpp"
+/* Copyright 2017 https://github.com/mandreyel
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this
+ * software and associated documentation files (the "Software"), to deal in the Software
+ * without restriction, including without limitation the rights to use, copy, modify,
+ * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to the following
+ * conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies
+ * or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+ * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+ * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
+ * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
+ * OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+#ifndef MIO_STRING_UTIL_HEADER
+#define MIO_STRING_UTIL_HEADER
+
+#include <type_traits>
+
+namespace mio {
+namespace detail {
+
+template<
+ typename S,
+ typename C = typename std::decay<S>::type,
+ typename = decltype(std::declval<C>().data()),
+ typename = typename std::enable_if<
+ std::is_same<typename C::value_type, char>::value
+#ifdef _WIN32
+ || std::is_same<typename C::value_type, wchar_t>::value
+#endif
+ >::type
+> struct char_type_helper {
+ using type = typename C::value_type;
+};
+
+template<class T>
+struct char_type {
+ using type = typename char_type_helper<T>::type;
+};
+
+// TODO: can we avoid this brute force approach?
+template<>
+struct char_type<char*> {
+ using type = char;
+};
+
+template<>
+struct char_type<const char*> {
+ using type = char;
+};
+
+template<size_t N>
+struct char_type<char[N]> {
+ using type = char;
+};
+
+template<size_t N>
+struct char_type<const char[N]> {
+ using type = char;
+};
+
+#ifdef _WIN32
+template<>
+struct char_type<wchar_t*> {
+ using type = wchar_t;
+};
+
+template<>
+struct char_type<const wchar_t*> {
+ using type = wchar_t;
+};
+
+template<size_t N>
+struct char_type<wchar_t[N]> {
+ using type = wchar_t;
+};
+
+template<size_t N>
+struct char_type<const wchar_t[N]> {
+ using type = wchar_t;
+};
+#endif // _WIN32
+
+template<typename CharT, typename S>
+struct is_c_str_helper
+{
+ static constexpr bool value = std::is_same<
+ CharT*,
+ // TODO: I'm so sorry for this... Can this be made cleaner?
+ typename std::add_pointer<
+ typename std::remove_cv<
+ typename std::remove_pointer<
+ typename std::decay<
+ S
+ >::type
+ >::type
+ >::type
+ >::type
+ >::value;
+};
+
+template<typename S>
+struct is_c_str
+{
+ static constexpr bool value = is_c_str_helper<char, S>::value;
+};
+
+#ifdef _WIN32
+template<typename S>
+struct is_c_wstr
+{
+ static constexpr bool value = is_c_str_helper<wchar_t, S>::value;
+};
+#endif // _WIN32
+
+template<typename S>
+struct is_c_str_or_c_wstr
+{
+ static constexpr bool value = is_c_str<S>::value
+#ifdef _WIN32
+ || is_c_wstr<S>::value
+#endif
+ ;
+};
+
+template<
+ typename String,
+ typename = decltype(std::declval<String>().data()),
+ typename = typename std::enable_if<!is_c_str_or_c_wstr<String>::value>::type
+> const typename char_type<String>::type* c_str(const String& path)
+{
+ return path.data();
+}
+
+template<
+ typename String,
+ typename = decltype(std::declval<String>().empty()),
+ typename = typename std::enable_if<!is_c_str_or_c_wstr<String>::value>::type
+> bool empty(const String& path)
+{
+ return path.empty();
+}
+
+template<
+ typename String,
+ typename = typename std::enable_if<is_c_str_or_c_wstr<String>::value>::type
+> const typename char_type<String>::type* c_str(String path)
+{
+ return path;
+}
+
+template<
+ typename String,
+ typename = typename std::enable_if<is_c_str_or_c_wstr<String>::value>::type
+> bool empty(String path)
+{
+ return !path || (*path == 0);
+}
+
+} // namespace detail
+} // namespace mio
+
+#endif // MIO_STRING_UTIL_HEADER
+
+
+#include <algorithm>
+
+#ifndef _WIN32
+# include <unistd.h>
+# include <fcntl.h>
+# include <sys/mman.h>
+# include <sys/stat.h>
+#endif
+
+namespace mio {
+namespace detail {
+
+#ifdef _WIN32
+namespace win {
+
+/** Returns the 4 upper bytes of an 8-byte integer. */
+inline DWORD int64_high(int64_t n) noexcept
+{
+ return n >> 32;
+}
+
+/** Returns the 4 lower bytes of an 8-byte integer. */
+inline DWORD int64_low(int64_t n) noexcept
+{
+ return n & 0xffffffff;
+}
+
+template<
+ typename String,
+ typename = typename std::enable_if<
+ std::is_same<typename char_type<String>::type, char>::value
+ >::type
+> file_handle_type open_file_helper(const String& path, const access_mode mode)
+{
+ return ::CreateFileA(c_str(path),
+ mode == access_mode::read ? GENERIC_READ : GENERIC_READ | GENERIC_WRITE,
+ FILE_SHARE_READ | FILE_SHARE_WRITE,
+ 0,
+ OPEN_EXISTING,
+ FILE_ATTRIBUTE_NORMAL,
+ 0);
+}
+
+template<typename String>
+typename std::enable_if<
+ std::is_same<typename char_type<String>::type, wchar_t>::value,
+ file_handle_type
+>::type open_file_helper(const String& path, const access_mode mode)
+{
+ return ::CreateFileW(c_str(path),
+ mode == access_mode::read ? GENERIC_READ : GENERIC_READ | GENERIC_WRITE,
+ FILE_SHARE_READ | FILE_SHARE_WRITE,
+ 0,
+ OPEN_EXISTING,
+ FILE_ATTRIBUTE_NORMAL,
+ 0);
+}
+
+} // win
+#endif // _WIN32
+
+/**
+ * Returns the last platform specific system error (errno on POSIX and
+ * GetLastError on Win) as a `std::error_code`.
+ */
+inline std::error_code last_error() noexcept
+{
+ std::error_code error;
+#ifdef _WIN32
+ error.assign(GetLastError(), std::system_category());
+#else
+ error.assign(errno, std::system_category());
+#endif
+ return error;
+}
+
+template<typename String>
+file_handle_type open_file(const String& path, const access_mode mode,
+ std::error_code& error)
+{
+ error.clear();
+ if(detail::empty(path))
+ {
+ error = std::make_error_code(std::errc::invalid_argument);
+ return invalid_handle;
+ }
+#ifdef _WIN32
+ const auto handle = win::open_file_helper(path, mode);
+#else // POSIX
+ const auto handle = ::open(c_str(path),
+ mode == access_mode::read ? O_RDONLY : O_RDWR);
+#endif
+ if(handle == invalid_handle)
+ {
+ error = detail::last_error();
+ }
+ return handle;
+}
+
+inline size_t query_file_size(file_handle_type handle, std::error_code& error)
+{
+ error.clear();
+#ifdef _WIN32
+ LARGE_INTEGER file_size;
+ if(::GetFileSizeEx(handle, &file_size) == 0)
+ {
+ error = detail::last_error();
+ return 0;
+ }
+ return static_cast<int64_t>(file_size.QuadPart);
+#else // POSIX
+ struct stat sbuf;
+ if(::fstat(handle, &sbuf) == -1)
+ {
+ error = detail::last_error();
+ return 0;
+ }
+ return sbuf.st_size;
+#endif
+}
+
+struct mmap_context
+{
+ char* data;
+ int64_t length;
+ int64_t mapped_length;
+#ifdef _WIN32
+ file_handle_type file_mapping_handle;
+#endif
+};
+
+inline mmap_context memory_map(const file_handle_type file_handle, const int64_t offset,
+ const int64_t length, const access_mode mode, std::error_code& error)
+{
+ const int64_t aligned_offset = make_offset_page_aligned(offset);
+ const int64_t length_to_map = offset - aligned_offset + length;
+#ifdef _WIN32
+ const int64_t max_file_size = offset + length;
+ const auto file_mapping_handle = ::CreateFileMapping(
+ file_handle,
+ 0,
+ mode == access_mode::read ? PAGE_READONLY : PAGE_READWRITE,
+ win::int64_high(max_file_size),
+ win::int64_low(max_file_size),
+ 0);
+ if(file_mapping_handle == invalid_handle)
+ {
+ error = detail::last_error();
+ return {};
+ }
+ char* mapping_start = static_cast<char*>(::MapViewOfFile(
+ file_mapping_handle,
+ mode == access_mode::read ? FILE_MAP_READ : FILE_MAP_WRITE,
+ win::int64_high(aligned_offset),
+ win::int64_low(aligned_offset),
+ length_to_map));
+ if(mapping_start == nullptr)
+ {
+ error = detail::last_error();
+ return {};
+ }
+#else // POSIX
+ char* mapping_start = static_cast<char*>(::mmap(
+ 0, // Don't give hint as to where to map.
+ length_to_map,
+ mode == access_mode::read ? PROT_READ : PROT_WRITE,
+ MAP_SHARED,
+ file_handle,
+ aligned_offset));
+ if(mapping_start == MAP_FAILED)
+ {
+ error = detail::last_error();
+ return {};
+ }
+#endif
+ mmap_context ctx;
+ ctx.data = mapping_start + offset - aligned_offset;
+ ctx.length = length;
+ ctx.mapped_length = length_to_map;
+#ifdef _WIN32
+ ctx.file_mapping_handle = file_mapping_handle;
+#endif
+ return ctx;
+}
+
+} // namespace detail
+
+// -- basic_mmap --
+
+template<access_mode AccessMode, typename ByteT>
+basic_mmap<AccessMode, ByteT>::~basic_mmap()
+{
+ conditional_sync();
+ unmap();
+}
+
+template<access_mode AccessMode, typename ByteT>
+basic_mmap<AccessMode, ByteT>::basic_mmap(basic_mmap&& other)
+ : data_(std::move(other.data_))
+ , length_(std::move(other.length_))
+ , mapped_length_(std::move(other.mapped_length_))
+ , file_handle_(std::move(other.file_handle_))
+#ifdef _WIN32
+ , file_mapping_handle_(std::move(other.file_mapping_handle_))
+#endif
+ , is_handle_internal_(std::move(other.is_handle_internal_))
+{
+ other.data_ = nullptr;
+ other.length_ = other.mapped_length_ = 0;
+ other.file_handle_ = invalid_handle;
+#ifdef _WIN32
+ other.file_mapping_handle_ = invalid_handle;
+#endif
+}
+
+template<access_mode AccessMode, typename ByteT>
+basic_mmap<AccessMode, ByteT>&
+basic_mmap<AccessMode, ByteT>::operator=(basic_mmap&& other)
+{
+ if(this != &other)
+ {
+ // First the existing mapping needs to be removed.
+ unmap();
+ data_ = std::move(other.data_);
+ length_ = std::move(other.length_);
+ mapped_length_ = std::move(other.mapped_length_);
+ file_handle_ = std::move(other.file_handle_);
+#ifdef _WIN32
+ file_mapping_handle_ = std::move(other.file_mapping_handle_);
+#endif
+ is_handle_internal_ = std::move(other.is_handle_internal_);
+
+ // The moved from basic_mmap's fields need to be reset, because
+ // otherwise other's destructor will unmap the same mapping that was
+ // just moved into this.
+ other.data_ = nullptr;
+ other.length_ = other.mapped_length_ = 0;
+ other.file_handle_ = invalid_handle;
+#ifdef _WIN32
+ other.file_mapping_handle_ = invalid_handle;
+#endif
+ other.is_handle_internal_ = false;
+ }
+ return *this;
+}
+
+template<access_mode AccessMode, typename ByteT>
+typename basic_mmap<AccessMode, ByteT>::handle_type
+basic_mmap<AccessMode, ByteT>::mapping_handle() const noexcept
+{
+#ifdef _WIN32
+ return file_mapping_handle_;
+#else
+ return file_handle_;
+#endif
+}
+
+template<access_mode AccessMode, typename ByteT>
+template<typename String>
+void basic_mmap<AccessMode, ByteT>::map(const String& path, const size_type offset,
+ const size_type length, std::error_code& error)
+{
+ error.clear();
+ if(detail::empty(path))
+ {
+ error = std::make_error_code(std::errc::invalid_argument);
+ return;
+ }
+ const auto handle = detail::open_file(path, AccessMode, error);
+ if(error)
+ {
+ return;
+ }
+
+ map(handle, offset, length, error);
+ // This MUST be after the call to map, as that sets this to true.
+ if(!error)
+ {
+ is_handle_internal_ = true;
+ }
+}
+
+template<access_mode AccessMode, typename ByteT>
+void basic_mmap<AccessMode, ByteT>::map(const handle_type handle,
+ const size_type offset, const size_type length, std::error_code& error)
+{
+ error.clear();
+ if(handle == invalid_handle)
+ {
+ error = std::make_error_code(std::errc::bad_file_descriptor);
+ return;
+ }
+
+ const auto file_size = detail::query_file_size(handle, error);
+ if(error)
+ {
+ return;
+ }
+
+ if(offset + length > file_size)
+ {
+ error = std::make_error_code(std::errc::invalid_argument);
+ return;
+ }
+
+ const auto ctx = detail::memory_map(handle, offset,
+ length == map_entire_file ? (file_size - offset) : length,
+ AccessMode, error);
+ if(!error)
+ {
+ // We must unmap the previous mapping that may have existed prior to this call.
+ // Note that this must only be invoked after a new mapping has been created in
+ // order to provide the strong guarantee that, should the new mapping fail, the
+ // `map` function leaves this instance in a state as though the function had
+ // never been invoked.
+ unmap();
+ file_handle_ = handle;
+ is_handle_internal_ = false;
+ data_ = reinterpret_cast<pointer>(ctx.data);
+ length_ = ctx.length;
+ mapped_length_ = ctx.mapped_length;
+#ifdef _WIN32
+ file_mapping_handle_ = ctx.file_mapping_handle;
+#endif
+ }
+}
+
+template<access_mode AccessMode, typename ByteT>
+template<access_mode A>
+typename std::enable_if<A == access_mode::write, void>::type
+basic_mmap<AccessMode, ByteT>::sync(std::error_code& error)
+{
+ error.clear();
+ if(!is_open())
+ {
+ error = std::make_error_code(std::errc::bad_file_descriptor);
+ return;
+ }
+
+ if(data())
+ {
+#ifdef _WIN32
+ if(::FlushViewOfFile(get_mapping_start(), mapped_length_) == 0
+ || ::FlushFileBuffers(file_handle_) == 0)
+#else // POSIX
+ if(::msync(get_mapping_start(), mapped_length_, MS_SYNC) != 0)
+#endif
+ {
+ error = detail::last_error();
+ return;
+ }
+ }
+#ifdef _WIN32
+ if(::FlushFileBuffers(file_handle_) == 0)
+ {
+ error = detail::last_error();
+ }
+#endif
+}
+
+template<access_mode AccessMode, typename ByteT>
+void basic_mmap<AccessMode, ByteT>::unmap()
+{
+ if(!is_open()) { return; }
+ // TODO do we care about errors here?
+#ifdef _WIN32
+ if(is_mapped())
+ {
+ ::UnmapViewOfFile(get_mapping_start());
+ ::CloseHandle(file_mapping_handle_);
+ }
+#else // POSIX
+ if(data_) { ::munmap(const_cast<pointer>(get_mapping_start()), mapped_length_); }
+#endif
+
+ // If `file_handle_` was obtained by our opening it (when map is called with
+ // a path, rather than an existing file handle), we need to close it,
+ // otherwise it must not be closed as it may still be used outside this
+ // instance.
+ if(is_handle_internal_)
+ {
+#ifdef _WIN32
+ ::CloseHandle(file_handle_);
+#else // POSIX
+ ::close(file_handle_);
+#endif
+ }
+
+ // Reset fields to their default values.
+ data_ = nullptr;
+ length_ = mapped_length_ = 0;
+ file_handle_ = invalid_handle;
+#ifdef _WIN32
+ file_mapping_handle_ = invalid_handle;
+#endif
+}
+
+template<access_mode AccessMode, typename ByteT>
+bool basic_mmap<AccessMode, ByteT>::is_mapped() const noexcept
+{
+#ifdef _WIN32
+ return file_mapping_handle_ != invalid_handle;
+#else // POSIX
+ return is_open();
+#endif
+}
+
+template<access_mode AccessMode, typename ByteT>
+void basic_mmap<AccessMode, ByteT>::swap(basic_mmap& other)
+{
+ if(this != &other)
+ {
+ using std::swap;
+ swap(data_, other.data_);
+ swap(file_handle_, other.file_handle_);
+#ifdef _WIN32
+ swap(file_mapping_handle_, other.file_mapping_handle_);
+#endif
+ swap(length_, other.length_);
+ swap(mapped_length_, other.mapped_length_);
+ swap(is_handle_internal_, other.is_handle_internal_);
+ }
+}
+
+template<access_mode AccessMode, typename ByteT>
+template<access_mode A>
+typename std::enable_if<A == access_mode::write, void>::type
+basic_mmap<AccessMode, ByteT>::conditional_sync()
+{
+ // This is invoked from the destructor, so not much we can do about
+ // failures here.
+ std::error_code ec;
+ sync(ec);
+}
+
+template<access_mode AccessMode, typename ByteT>
+template<access_mode A>
+typename std::enable_if<A == access_mode::read, void>::type
+basic_mmap<AccessMode, ByteT>::conditional_sync()
+{
+ // noop
+}
+
+template<access_mode AccessMode, typename ByteT>
+bool operator==(const basic_mmap<AccessMode, ByteT>& a,
+ const basic_mmap<AccessMode, ByteT>& b)
+{
+ return a.data() == b.data()
+ && a.size() == b.size();
+}
+
+template<access_mode AccessMode, typename ByteT>
+bool operator!=(const basic_mmap<AccessMode, ByteT>& a,
+ const basic_mmap<AccessMode, ByteT>& b)
+{
+ return !(a == b);
+}
+
+template<access_mode AccessMode, typename ByteT>
+bool operator<(const basic_mmap<AccessMode, ByteT>& a,
+ const basic_mmap<AccessMode, ByteT>& b)
+{
+ if(a.data() == b.data()) { return a.size() < b.size(); }
+ return a.data() < b.data();
+}
+
+template<access_mode AccessMode, typename ByteT>
+bool operator<=(const basic_mmap<AccessMode, ByteT>& a,
+ const basic_mmap<AccessMode, ByteT>& b)
+{
+ return !(a > b);
+}
+
+template<access_mode AccessMode, typename ByteT>
+bool operator>(const basic_mmap<AccessMode, ByteT>& a,
+ const basic_mmap<AccessMode, ByteT>& b)
+{
+ if(a.data() == b.data()) { return a.size() > b.size(); }
+ return a.data() > b.data();
+}
+
+template<access_mode AccessMode, typename ByteT>
+bool operator>=(const basic_mmap<AccessMode, ByteT>& a,
+ const basic_mmap<AccessMode, ByteT>& b)
+{
+ return !(a < b);
+}
+
+} // namespace mio
+
+#endif // MIO_BASIC_MMAP_IMPL
+
+
+#endif // MIO_MMAP_HEADER
+/* Copyright 2017 https://github.com/mandreyel
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this
+ * software and associated documentation files (the "Software"), to deal in the Software
+ * without restriction, including without limitation the rights to use, copy, modify,
+ * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to the following
+ * conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies
+ * or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+ * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+ * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
+ * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
+ * OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+#ifndef MIO_PAGE_HEADER
+#define MIO_PAGE_HEADER
+
+#ifdef _WIN32
+# include <windows.h>
+#else
+# include <unistd.h>
+#endif
+
+namespace mio {
+
+/**
+ * This is used by `basic_mmap` to determine whether to create a read-only or
+ * a read-write memory mapping.
+ */
+enum class access_mode
+{
+ read,
+ write
+};
+
+/**
+ * Determines the operating system's page allocation granularity.
+ *
+ * On the first call to this function, it invokes the operating system specific syscall
+ * to determine the page size, caches the value, and returns it. Any subsequent call to
+ * this function serves the cached value, so no further syscalls are made.
+ */
+inline size_t page_size()
+{
+ static const size_t page_size = []
+ {
+#ifdef _WIN32
+ SYSTEM_INFO SystemInfo;
+ GetSystemInfo(&SystemInfo);
+ return SystemInfo.dwAllocationGranularity;
+#else
+ return sysconf(_SC_PAGE_SIZE);
+#endif
+ }();
+ return page_size;
+}
+
+/**
+ * Alligns `offset` to the operating's system page size such that it subtracts the
+ * difference until the nearest page boundary before `offset`, or does nothing if
+ * `offset` is already page aligned.
+ */
+inline size_t make_offset_page_aligned(size_t offset) noexcept
+{
+ const size_t page_size_ = page_size();
+ // Use integer division to round down to the nearest page alignment.
+ return offset / page_size_ * page_size_;
+}
+
+} // namespace mio
+
+#endif // MIO_PAGE_HEADER
+/* Copyright 2017 https://github.com/mandreyel
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy of this
+ * software and associated documentation files (the "Software"), to deal in the Software
+ * without restriction, including without limitation the rights to use, copy, modify,
+ * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
+ * permit persons to whom the Software is furnished to do so, subject to the following
+ * conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all copies
+ * or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
+ * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+ * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
+ * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
+ * OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+
+#ifndef MIO_SHARED_MMAP_HEADER
+#define MIO_SHARED_MMAP_HEADER
+
+// #include "mio/mmap.hpp"
+
+
+#include <system_error> // std::error_code
+#include <memory> // std::shared_ptr
+
+namespace mio {
+
+/**
+ * Exposes (nearly) the same interface as `basic_mmap`, but endowes it with
+ * `std::shared_ptr` semantics.
+ *
+ * This is not the default behaviour of `basic_mmap` to avoid allocating on the heap if
+ * shared semantics are not required.
+ */
+template<
+ access_mode AccessMode,
+ typename ByteT
+> class basic_shared_mmap
+{
+ using impl_type = basic_mmap<AccessMode, ByteT>;
+ std::shared_ptr<impl_type> pimpl_;
+
+public:
+ using value_type = typename impl_type::value_type;
+ using size_type = typename impl_type::size_type;
+ using reference = typename impl_type::reference;
+ using const_reference = typename impl_type::const_reference;
+ using pointer = typename impl_type::pointer;
+ using const_pointer = typename impl_type::const_pointer;
+ using difference_type = typename impl_type::difference_type;
+ using iterator = typename impl_type::iterator;
+ using const_iterator = typename impl_type::const_iterator;
+ using reverse_iterator = typename impl_type::reverse_iterator;
+ using const_reverse_iterator = typename impl_type::const_reverse_iterator;
+ using iterator_category = typename impl_type::iterator_category;
+ using handle_type = typename impl_type::handle_type;
+ using mmap_type = impl_type;
+
+ basic_shared_mmap() = default;
+ basic_shared_mmap(const basic_shared_mmap&) = default;
+ basic_shared_mmap& operator=(const basic_shared_mmap&) = default;
+ basic_shared_mmap(basic_shared_mmap&&) = default;
+ basic_shared_mmap& operator=(basic_shared_mmap&&) = default;
+
+ /** Takes ownership of an existing mmap object. */
+ basic_shared_mmap(mmap_type&& mmap)
+ : pimpl_(std::make_shared<mmap_type>(std::move(mmap)))
+ {}
+
+ /** Takes ownership of an existing mmap object. */
+ basic_shared_mmap& operator=(mmap_type&& mmap)
+ {
+ pimpl_ = std::make_shared<mmap_type>(std::move(mmap));
+ return *this;
+ }
+
+ /** Initializes this object with an already established shared mmap. */
+ basic_shared_mmap(std::shared_ptr<mmap_type> mmap) : pimpl_(std::move(mmap)) {}
+
+ /** Initializes this object with an already established shared mmap. */
+ basic_shared_mmap& operator=(std::shared_ptr<mmap_type> mmap)
+ {
+ pimpl_ = std::move(mmap);
+ return *this;
+ }
+
+#ifdef __cpp_exceptions
+ /**
+ * The same as invoking the `map` function, except any error that may occur
+ * while establishing the mapping is wrapped in a `std::system_error` and is
+ * thrown.
+ */
+ template<typename String>
+ basic_shared_mmap(const String& path, const size_type offset = 0, const size_type length = map_entire_file)
+ {
+ std::error_code error;
+ map(path, offset, length, error);
+ if(error) { throw std::system_error(error); }
+ }
+
+ /**
+ * The same as invoking the `map` function, except any error that may occur
+ * while establishing the mapping is wrapped in a `std::system_error` and is
+ * thrown.
+ */
+ basic_shared_mmap(const handle_type handle, const size_type offset = 0, const size_type length = map_entire_file)
+ {
+ std::error_code error;
+ map(handle, offset, length, error);
+ if(error) { throw std::system_error(error); }
+ }
+#endif // __cpp_exceptions
+
+ /**
+ * If this is a read-write mapping and the last reference to the mapping,
+ * the destructor invokes sync. Regardless of the access mode, unmap is
+ * invoked as a final step.
+ */
+ ~basic_shared_mmap() = default;
+
+ /** Returns the underlying `std::shared_ptr` instance that holds the mmap. */
+ std::shared_ptr<mmap_type> get_shared_ptr() { return pimpl_; }
+
+ /**
+ * On UNIX systems 'file_handle' and 'mapping_handle' are the same. On Windows,
+ * however, a mapped region of a file gets its own handle, which is returned by
+ * 'mapping_handle'.
+ */
+ handle_type file_handle() const noexcept
+ {
+ return pimpl_ ? pimpl_->file_handle() : invalid_handle;
+ }
+
+ handle_type mapping_handle() const noexcept
+ {
+ return pimpl_ ? pimpl_->mapping_handle() : invalid_handle;
+ }
+
+ /** Returns whether a valid memory mapping has been created. */
+ bool is_open() const noexcept { return pimpl_ && pimpl_->is_open(); }
+
+ /**
+ * Returns true if no mapping was established, that is, conceptually the
+ * same as though the length that was mapped was 0. This function is
+ * provided so that this class has Container semantics.
+ */
+ bool empty() const noexcept { return !pimpl_ || pimpl_->empty(); }
+
+ /**
+ * `size` and `length` both return the logical length, i.e. the number of bytes
+ * user requested to be mapped, while `mapped_length` returns the actual number of
+ * bytes that were mapped which is a multiple of the underlying operating system's
+ * page allocation granularity.
+ */
+ size_type size() const noexcept { return pimpl_ ? pimpl_->length() : 0; }
+ size_type length() const noexcept { return pimpl_ ? pimpl_->length() : 0; }
+ size_type mapped_length() const noexcept
+ {
+ return pimpl_ ? pimpl_->mapped_length() : 0;
+ }
+
+ /**
+ * Returns a pointer to the first requested byte, or `nullptr` if no memory mapping
+ * exists.
+ */
+ template<
+ access_mode A = AccessMode,
+ typename = typename std::enable_if<A == access_mode::write>::type
+ > pointer data() noexcept { return pimpl_->data(); }
+ const_pointer data() const noexcept { return pimpl_ ? pimpl_->data() : nullptr; }
+
+ /**
+ * Returns an iterator to the first requested byte, if a valid memory mapping
+ * exists, otherwise this function call is undefined behaviour.
+ */
+ iterator begin() noexcept { return pimpl_->begin(); }
+ const_iterator begin() const noexcept { return pimpl_->begin(); }
+ const_iterator cbegin() const noexcept { return pimpl_->cbegin(); }
+
+ /**
+ * Returns an iterator one past the last requested byte, if a valid memory mapping
+ * exists, otherwise this function call is undefined behaviour.
+ */
+ template<
+ access_mode A = AccessMode,
+ typename = typename std::enable_if<A == access_mode::write>::type
+ > iterator end() noexcept { return pimpl_->end(); }
+ const_iterator end() const noexcept { return pimpl_->end(); }
+ const_iterator cend() const noexcept { return pimpl_->cend(); }
+
+ /**
+ * Returns a reverse iterator to the last memory mapped byte, if a valid
+ * memory mapping exists, otherwise this function call is undefined
+ * behaviour.
+ */
+ template<
+ access_mode A = AccessMode,
+ typename = typename std::enable_if<A == access_mode::write>::type
+ > reverse_iterator rbegin() noexcept { return pimpl_->rbegin(); }
+ const_reverse_iterator rbegin() const noexcept { return pimpl_->rbegin(); }
+ const_reverse_iterator crbegin() const noexcept { return pimpl_->crbegin(); }
+
+ /**
+ * Returns a reverse iterator past the first mapped byte, if a valid memory
+ * mapping exists, otherwise this function call is undefined behaviour.
+ */
+ template<
+ access_mode A = AccessMode,
+ typename = typename std::enable_if<A == access_mode::write>::type
+ > reverse_iterator rend() noexcept { return pimpl_->rend(); }
+ const_reverse_iterator rend() const noexcept { return pimpl_->rend(); }
+ const_reverse_iterator crend() const noexcept { return pimpl_->crend(); }
+
+ /**
+ * Returns a reference to the `i`th byte from the first requested byte (as returned
+ * by `data`). If this is invoked when no valid memory mapping has been created
+ * prior to this call, undefined behaviour ensues.
+ */
+ reference operator[](const size_type i) noexcept { return (*pimpl_)[i]; }
+ const_reference operator[](const size_type i) const noexcept { return (*pimpl_)[i]; }
+
+ /**
+ * Establishes a memory mapping with AccessMode. If the mapping is unsuccesful, the
+ * reason is reported via `error` and the object remains in a state as if this
+ * function hadn't been called.
+ *
+ * `path`, which must be a path to an existing file, is used to retrieve a file
+ * handle (which is closed when the object destructs or `unmap` is called), which is
+ * then used to memory map the requested region. Upon failure, `error` is set to
+ * indicate the reason and the object remains in an unmapped state.
+ *
+ * `offset` is the number of bytes, relative to the start of the file, where the
+ * mapping should begin. When specifying it, there is no need to worry about
+ * providing a value that is aligned with the operating system's page allocation
+ * granularity. This is adjusted by the implementation such that the first requested
+ * byte (as returned by `data` or `begin`), so long as `offset` is valid, will be at
+ * `offset` from the start of the file.
+ *
+ * `length` is the number of bytes to map. It may be `map_entire_file`, in which
+ * case a mapping of the entire file is created.
+ */
+ template<typename String>
+ void map(const String& path, const size_type offset,
+ const size_type length, std::error_code& error)
+ {
+ map_impl(path, offset, length, error);
+ }
+
+ /**
+ * Establishes a memory mapping with AccessMode. If the mapping is unsuccesful, the
+ * reason is reported via `error` and the object remains in a state as if this
+ * function hadn't been called.
+ *
+ * `path`, which must be a path to an existing file, is used to retrieve a file
+ * handle (which is closed when the object destructs or `unmap` is called), which is
+ * then used to memory map the requested region. Upon failure, `error` is set to
+ * indicate the reason and the object remains in an unmapped state.
+ *
+ * The entire file is mapped.
+ */
+ template<typename String>
+ void map(const String& path, std::error_code& error)
+ {
+ map_impl(path, 0, map_entire_file, error);
+ }
+
+ /**
+ * Establishes a memory mapping with AccessMode. If the mapping is unsuccesful, the
+ * reason is reported via `error` and the object remains in a state as if this
+ * function hadn't been called.
+ *
+ * `handle`, which must be a valid file handle, which is used to memory map the
+ * requested region. Upon failure, `error` is set to indicate the reason and the
+ * object remains in an unmapped state.
+ *
+ * `offset` is the number of bytes, relative to the start of the file, where the
+ * mapping should begin. When specifying it, there is no need to worry about
+ * providing a value that is aligned with the operating system's page allocation
+ * granularity. This is adjusted by the implementation such that the first requested
+ * byte (as returned by `data` or `begin`), so long as `offset` is valid, will be at
+ * `offset` from the start of the file.
+ *
+ * `length` is the number of bytes to map. It may be `map_entire_file`, in which
+ * case a mapping of the entire file is created.
+ */
+ void map(const handle_type handle, const size_type offset,
+ const size_type length, std::error_code& error)
+ {
+ map_impl(handle, offset, length, error);
+ }
+
+ /**
+ * Establishes a memory mapping with AccessMode. If the mapping is unsuccesful, the
+ * reason is reported via `error` and the object remains in a state as if this
+ * function hadn't been called.
+ *
+ * `handle`, which must be a valid file handle, which is used to memory map the
+ * requested region. Upon failure, `error` is set to indicate the reason and the
+ * object remains in an unmapped state.
+ *
+ * The entire file is mapped.
+ */
+ void map(const handle_type handle, std::error_code& error)
+ {
+ map_impl(handle, 0, map_entire_file, error);
+ }
+
+ /**
+ * If a valid memory mapping has been created prior to this call, this call
+ * instructs the kernel to unmap the memory region and disassociate this object
+ * from the file.
+ *
+ * The file handle associated with the file that is mapped is only closed if the
+ * mapping was created using a file path. If, on the other hand, an existing
+ * file handle was used to create the mapping, the file handle is not closed.
+ */
+ void unmap() { if(pimpl_) pimpl_->unmap(); }
+
+ void swap(basic_shared_mmap& other) { pimpl_.swap(other.pimpl_); }
+
+ /** Flushes the memory mapped page to disk. Errors are reported via `error`. */
+ template<
+ access_mode A = AccessMode,
+ typename = typename std::enable_if<A == access_mode::write>::type
+ > void sync(std::error_code& error) { if(pimpl_) pimpl_->sync(error); }
+
+ /** All operators compare the underlying `basic_mmap`'s addresses. */
+
+ friend bool operator==(const basic_shared_mmap& a, const basic_shared_mmap& b)
+ {
+ return a.pimpl_ == b.pimpl_;
+ }
+
+ friend bool operator!=(const basic_shared_mmap& a, const basic_shared_mmap& b)
+ {
+ return !(a == b);
+ }
+
+ friend bool operator<(const basic_shared_mmap& a, const basic_shared_mmap& b)
+ {
+ return a.pimpl_ < b.pimpl_;
+ }
+
+ friend bool operator<=(const basic_shared_mmap& a, const basic_shared_mmap& b)
+ {
+ return a.pimpl_ <= b.pimpl_;
+ }
+
+ friend bool operator>(const basic_shared_mmap& a, const basic_shared_mmap& b)
+ {
+ return a.pimpl_ > b.pimpl_;
+ }
+
+ friend bool operator>=(const basic_shared_mmap& a, const basic_shared_mmap& b)
+ {
+ return a.pimpl_ >= b.pimpl_;
+ }
+
+private:
+ template<typename MappingToken>
+ void map_impl(const MappingToken& token, const size_type offset,
+ const size_type length, std::error_code& error)
+ {
+ if(!pimpl_)
+ {
+ mmap_type mmap = make_mmap<mmap_type>(token, offset, length, error);
+ if(error) { return; }
+ pimpl_ = std::make_shared<mmap_type>(std::move(mmap));
+ }
+ else
+ {
+ pimpl_->map(token, offset, length, error);
+ }
+ }
+};
+
+/**
+ * This is the basis for all read-only mmap objects and should be preferred over
+ * directly using basic_shared_mmap.
+ */
+template<typename ByteT>
+using basic_shared_mmap_source = basic_shared_mmap<access_mode::read, ByteT>;
+
+/**
+ * This is the basis for all read-write mmap objects and should be preferred over
+ * directly using basic_shared_mmap.
+ */
+template<typename ByteT>
+using basic_shared_mmap_sink = basic_shared_mmap<access_mode::write, ByteT>;
+
+/**
+ * These aliases cover the most common use cases, both representing a raw byte stream
+ * (either with a char or an unsigned char/uint8_t).
+ */
+using shared_mmap_source = basic_shared_mmap_source<char>;
+using shared_ummap_source = basic_shared_mmap_source<unsigned char>;
+
+using shared_mmap_sink = basic_shared_mmap_sink<char>;
+using shared_ummap_sink = basic_shared_mmap_sink<unsigned char>;
+
+} // namespace mio
+
+#endif // MIO_SHARED_MMAP_HEADER
diff --git a/src/3rd_party/nccl b/src/3rd_party/nccl
-Subproject d6297d250433715c283d17f1969cfcb50d2b653
+Subproject b56650c7f59b8cd40d18809784a6d6be38ef8ac
diff --git a/src/3rd_party/pathie-cpp/src/entry_iterator.cpp b/src/3rd_party/pathie-cpp/src/entry_iterator.cpp
index e2ecb2fe..baaf5394 100644
--- a/src/3rd_party/pathie-cpp/src/entry_iterator.cpp
+++ b/src/3rd_party/pathie-cpp/src/entry_iterator.cpp
@@ -31,7 +31,7 @@
#include "../include/path.hpp"
#include "../include/errors.hpp"
-#if defined(__unix__)
+#if defined(__unix__) || defined(__APPLE__)
#include <sys/types.h>
#include <dirent.h>
#include <errno.h>
@@ -178,7 +178,7 @@ entry_iterator& entry_iterator::operator++(int)
/// Same as the other operator++().
entry_iterator& entry_iterator::operator++()
{
- return (operator++());
+ return (operator++(0));
}
/**
diff --git a/src/3rd_party/pathie-cpp/src/path.cpp b/src/3rd_party/pathie-cpp/src/path.cpp
index 3dc1e14b..e732e09c 100644
--- a/src/3rd_party/pathie-cpp/src/path.cpp
+++ b/src/3rd_party/pathie-cpp/src/path.cpp
@@ -51,7 +51,7 @@
#include <shlwapi.h>
//#include <ntifs.h> // Currently not in msys2
-// @TODO: This is a hack to make it compile under Windows, check if this is save.
+// @TODO: This is a hack to make it compile under Windows, check if this is safe.
#define F_OK 0
#elif defined(_PATHIE_UNIX)
@@ -149,6 +149,8 @@ Path::Path(const std::vector<Path>& components)
*/
void Path::sanitize()
{
+ bool isWindowsUNCPath = m_path.size() >= 2 && (m_path[0] == '\\' && m_path[1] == '\\'); // UNC path
+
// Replace any backslashes \ with forward slashes /.
size_t cur = string::npos;
while ((cur = m_path.find("\\")) != string::npos) { // assignment intended
@@ -156,8 +158,9 @@ void Path::sanitize()
}
// Replace all double slashes // with a single one
+ // [fseide] except for the first position, which would be a Windows UNC path
cur = string::npos;
- while ((cur = m_path.find("//")) != string::npos) { // assignment intended
+ while ((cur = m_path.find("//", isWindowsUNCPath ? 1 : 0)) != string::npos) { // assignment intended
m_path.replace(cur, 2, "/");
}
@@ -899,7 +902,7 @@ Path Path::pwd()
*/
Path Path::exe()
{
-#if defined(__linux__)
+#if defined(__linux__) || defined(__APPLE__)
char buf[PATH_MAX];
ssize_t size = ::readlink("/proc/self/exe", buf, PATH_MAX);
@@ -1546,7 +1549,7 @@ bool Path::is_directory() const
throw(Pathie::ErrnoError(errsav));
}
- return s.st_mode & S_IFDIR;
+ return (s.st_mode & S_IFDIR) != 0;
#else
#error Unsupported system.
#endif
@@ -1590,7 +1593,7 @@ bool Path::is_file() const
throw(Pathie::ErrnoError(errno));
}
- return s.st_mode & S_IFREG;
+ return (s.st_mode & S_IFREG) != 0;
#else
#error Unsupported system.
#endif
@@ -1710,9 +1713,9 @@ void Path::remove() const
* function uses the apropriate native Win32API function
* calls accordingly therefore. */
if (is_directory())
- result = RemoveDirectoryW(utf16.c_str());
+ result = RemoveDirectoryW(utf16.c_str()) != 0;
else
- result = DeleteFileW(utf16.c_str());
+ result = DeleteFileW(utf16.c_str()) != 0;
if (!result) {
DWORD err = GetLastError();
@@ -3282,7 +3285,7 @@ bool Path::fnmatch(const std::string& pattern, int flags /* = 0 */) const
#elif defined(_WIN32)
std::wstring utf16path = utf8_to_utf16(m_path);
std::wstring utf16pattern = utf8_to_utf16(pattern);
- return PathMatchSpecW(utf16path.c_str(), utf16pattern.c_str());
+ return PathMatchSpecW(utf16path.c_str(), utf16pattern.c_str()) != 0;
#else
#error Unsupported system.
#endif
diff --git a/src/3rd_party/pathie-cpp/src/pathie.cpp b/src/3rd_party/pathie-cpp/src/pathie.cpp
index c0abc689..d8e567e1 100644
--- a/src/3rd_party/pathie-cpp/src/pathie.cpp
+++ b/src/3rd_party/pathie-cpp/src/pathie.cpp
@@ -143,7 +143,7 @@ std::string Pathie::convert_encodings(const char* from_encoding, const char* to_
errno = 0;
errsav = 0;
-#ifdef BSD
+#if defined(BSD) && ! defined(__APPLE__) //Since MacOS evolved from BSD, it is captured here but the iconv on macos behaves differently
// What the heck. FreeBSD violates POSIX.1-2008: it declares iconv()
// differently than mandated by POSIX: http://pubs.opengroup.org/onlinepubs/9699919799/functions/iconv.html
// (it declares a `const' where it must not be).
@@ -181,11 +181,10 @@ std::string Pathie::convert_encodings(const char* from_encoding, const char* to_
std::string Pathie::utf8_to_filename(const std::string& utf8)
{
bool fs_encoding_is_utf8 = false;
-
+ char* fsencoding = NULL;
#if defined(__APPLE__) || defined(PATHIE_ASSUME_UTF8_ON_UNIX)
fs_encoding_is_utf8 = true;
#else
- char* fsencoding = NULL;
fsencoding = nl_langinfo(CODESET);
fs_encoding_is_utf8 = (strcmp(fsencoding, "UTF-8") == 0);
#endif
@@ -206,11 +205,10 @@ std::string Pathie::utf8_to_filename(const std::string& utf8)
std::string Pathie::filename_to_utf8(const std::string& native_filename)
{
bool fs_encoding_is_utf8 = false;
-
+ char* fsencoding = NULL;
#if defined(__APPLE__) || defined(PATHIE_ASSUME_UTF8_ON_UNIX)
fs_encoding_is_utf8 = true;
#else
- char* fsencoding = NULL;
fsencoding = nl_langinfo(CODESET);
fs_encoding_is_utf8 = (strcmp(fsencoding, "UTF-8") == 0);
#endif
diff --git a/src/3rd_party/phf/LICENSE b/src/3rd_party/phf/LICENSE
new file mode 100644
index 00000000..039208c4
--- /dev/null
+++ b/src/3rd_party/phf/LICENSE
@@ -0,0 +1,19 @@
+Copyright (c) 2014-2015 William Ahern
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to
+deal in the Software without restriction, including without limitation the
+rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+sell copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+IN THE SOFTWARE.
diff --git a/src/3rd_party/phf/README.md b/src/3rd_party/phf/README.md
new file mode 100644
index 00000000..f3fd6287
--- /dev/null
+++ b/src/3rd_party/phf/README.md
@@ -0,0 +1,182 @@
+# Introduction #
+
+This is a simple implementation of the CHD perfect hash algorithm. CHD can
+generate perfect hash functions for very large key sets--on the order of
+millions of keys--in a very short time. On my circa 2012 desktop and using
+the default parameters (hash load factor of 80% and average displacement map
+bucket load of 4.0 keys) this implementation can generate a hash function
+for 1,000 keys in less than 1/100th of a second, and 1,000,000 keys in less
+than a second.
+
+For more information about the algorithm, see
+http://cmph.sourceforge.net/chd.html.
+
+# Dependencies #
+
+* No runtime dependencies.
+* Requires a modern C++ compiler to build.
+* The included build requires GNU Make.
+
+# Building #
+
+## Make Macros ##
+
+The typical GNU macros can be used control the build.
+
+### Compilation ###
+
+Note that the modules for Lua 5.1, 5.2, and 5.3 can be built simultaneously.
+
+* CXX: C++ compiler path.
+* CXXFLAGS: C++ compiler flags.
+* CPPFLAGS: C preprocessor flags. Necessary if Lua API cannot be discovered
+ automatically. You can specify multiple include paths if building more than
+ one Lua module.
+* LDFLAGS: Linker flags. Not normally needed.
+* SOFLAGS: Flags needed to build dynamic library.
+* LOFLAGS: Flags needed to build loadable module. Normally should be the
+ same as SOFLAGS, except on OS X.
+* LIBS: Library dependencies. Normally empty, but see the section Avoiding
+ C++ Dependencies.
+
+#### Avoiding C++ Dependencies
+
+Defining the preprocessor macro PHF_NO_LIBCXX to 1 will prevent usage of C++
+interfaces such as std::string that would require a dependency on libc++ or
+libstdc++. This allows using platform-dependent flags in CXXFLAGS, LDFLAGS,
+and SOFLAGS to prevent a dependency on the system C++ library.
+
+For example, on OS X you can do:
+```sh
+$ make CPPFLAGS="-DPHF_NO_LIBCXX" \
+CXXFLAGS="-std=c++11 -fno-rtti -fno-exceptions -O3 -march=native" \
+LDFLAGS="-nostdlib" \
+LIBS="-lSystem"
+```
+
+### Installation ####
+* prefix
+* includedir
+* libdir
+* luacpath: Lua C module install path. Can be used for one-shot installation
+ of a particular Lua version module.
+* lua51cpath: Lua 5.1 C module install path.
+* lua52cpath: Same as above, for 5.2.
+* lua53cpath: Same as above, for 5.3.
+
+## Make Targets ##
+
+* phf: Builds command-line utility (development)
+* libphf.so: Builds dynamic library for non-OS X
+* libphf.dylib: Builds dynamic library for OS X
+* lua5.1: Builds Lua 5.1 module at 5.1/phf.so. Lua 5.1 headers should be
+ specified using CPPFLAGS if not in normal locations.
+* lua5.2: Same as above, for Lua 5.2.
+* lua5.3: Same as above, for Lua 5.3.
+
+# Usage #
+
+## Lua ##
+
+## API ###
+
+### phf.new(keys[, lambda][, alpha][, seed][, nodiv]) ###
+
+* keys: array of keys in order from 1..#keys. They should be all
+ numbers or all strings.
+
+* lambda: number of keys per bucket when generating the g() function mapping.
+
+* alpha: output hash space loading factor as percentage from
+ 1..100. 100% generates a *minimal* perfect hash function. But note that
+ the implementation does *not* implement the necessary optimizations to
+ ensure timely generation of minimal perfect hash functions. Normally you
+ want a loading factor of 80% to 90% for large key sets.
+
+* seed: random integer seed.
+
+* nodiv: if true rounds r and m to powers of 2, and performs modular
+ reduction using bitwise AND. Otherwise, r and m are rounded up to the
+ nearest primes and modulo division used when indexing tables. Note that
+ the rounding occurs after calculation of the intermediate and output hash
+ table loading.
+
+ This is more important when building small hash tables with the C
+ interface. The optimization is substantial when the compiler can inline
+ the code, but isn't substantial from Lua.
+
+Returns a callable object.
+
+### phf:hash(key)
+
+* Returns an integer hash in the range 1..phf:m(). The returned integer will
+ be unique for all keys in the original set. Otherwise the result is
+ unspecified.
+
+### Example ###
+
+```Lua
+local phf = require"phf"
+
+local lambda = 4 -- how many keys per intermediate bucket
+local alpha = 80 -- output hash space loading in percentage.
+
+local keys = { "apple", "banana", "cherry", "date", "eggplant", "fig",
+ "guava", "honeydew", "jackfruit", "kiwi", "lemon", "mango" }
+
+local F = phf.new(keys, lambda, alpha)
+
+for i=1,#keys do
+ print(keys[i], F(keys[i]))
+end
+
+```
+
+## C++ ##
+
+## API ##
+
+### PHF::uniq<T>(T k[], size_t n); ###
+
+Similar to the shell command `sort | uniq`. Sorts, deduplicates, and shifts
+down the keys in the array k. Returns the number of unique keys, which will
+have been moved to the beginning of the array. If necessary do this before
+calling PHF::init, as PHF::init does not tolerate duplicate keys.
+
+### int PHF::init<T, nodiv>(struct phf *f, const T k[], size_t n, size_t l, size_t a, phf_seed_t s);
+
+Generate a perfect hash function for the n keys in array k and store the
+results in f. Returns a system error number on failure, or 0 on success. f
+is unmodified on failure.
+
+### void PHF::destroy(struct phf *);
+
+Deallocates internal tables, but not the struct object itself.
+
+### void PHF::compact<T, nodiv>(struct phf *);
+
+By default the displacement map is an array of uint32_t integers. This
+function will select the smallest type necessary to hold the largest
+displacement value and update the internal state accordingly. For a loading
+factor of 80% (0.8) in the output hash space, and displacement map loading
+factor of 4 (400%), the smallest primitive type will often be uint8_t.
+
+### phf_hash_t PHF::hash<T>(struct phf *f, T k);
+
+Returns an integer hash value, h, where 0 <= h < f->m. h will be unique for
+each unique key provided when generating the function. f->m will be larger
+than the number of unique keys and is based on the specified loading factor
+(alpha), rounded up to the nearest prime or nearest power of 2, depending on
+the mode of modular reduction selected. For example, for a loading factor of
+80% m will be 127: 100 is 80% of 125, and 127 is the closest prime greater
+than or equal to 125. With the nodiv option, m would be 128: 100 is 80% of
+125, and 128 is the closest power of 2 greater than or equal to 125.
+
+## C ##
+
+The C API is nearly identical to the C++ API, except the prefix is phf_
+instead of PHF::. phf_uniq, phf_init, and phf_hash are macros which utilize
+C11's _Generic or GCC's __builtin_types_compatible_p interfaces to overload
+the interfaces by key type. The explicit suffixes _uint32, _uint64, and
+_string may be used directly.
+
diff --git a/src/3rd_party/phf/phf.cc b/src/3rd_party/phf/phf.cc
new file mode 100644
index 00000000..d00d5c5e
--- /dev/null
+++ b/src/3rd_party/phf/phf.cc
@@ -0,0 +1,1478 @@
+/* ==========================================================================
+ * phf.cc - Tiny perfect hash function library.
+ * --------------------------------------------------------------------------
+ * Copyright (c) 2014-2015 William Ahern
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to permit
+ * persons to whom the Software is furnished to do so, subject to the
+ * following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN
+ * NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+ * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
+ * USE OR OTHER DEALINGS IN THE SOFTWARE.
+ * ==========================================================================
+ */
+#include <limits.h> /* CHAR_BIT SIZE_MAX */
+#include <inttypes.h> /* PRIu32 PRIu64 PRIx64 */
+#include <stdint.h> /* UINT32_C UINT64_C uint32_t uint64_t */
+#include <stdlib.h> /* abort(3) calloc(3) free(3) qsort(3) */
+#include <string.h> /* memset(3) */
+#include <errno.h> /* errno */
+#include <assert.h> /* assert(3) */
+#if !PHF_NO_LIBCXX
+#include <string> /* std::string */
+#endif
+
+
+#include "phf.h"
+
+
+#ifdef __clang__
+#pragma clang diagnostic ignored "-Wunused-function"
+#if __cplusplus < 201103L
+#pragma clang diagnostic ignored "-Wc++11-long-long"
+#endif
+#elif PHF_GNUC_PREREQ(4, 6)
+#pragma GCC diagnostic ignored "-Wunused-function"
+#if __cplusplus < 201103L
+#pragma GCC diagnostic ignored "-Wlong-long"
+#pragma GCC diagnostic ignored "-Wformat" // %zu
+#endif
+#endif
+
+
+
+/*
+ * M A C R O R O U T I N E S
+ *
+ * Mostly copies of <sys/param.h>
+ *
+ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
+
+#define PHF_BITS(T) (sizeof (T) * CHAR_BIT)
+#define PHF_HOWMANY(x, y) (((x) + ((y) - 1)) / (y))
+#define PHF_MIN(a, b) (((a) < (b))? (a) : (b))
+#define PHF_MAX(a, b) (((a) > (b))? (a) : (b))
+#define PHF_ROTL(x, y) (((x) << (y)) | ((x) >> (PHF_BITS(x) - (y))))
+#define PHF_COUNTOF(a) (sizeof (a) / sizeof *(a))
+
+
+/*
+ * M O D U L A R A R I T H M E T I C R O U T I N E S
+ *
+ * Two modular reduction schemes are supported: bitwise AND and naive
+ * modular division. For bitwise AND we must round up the values r and m to
+ * a power of 2.
+ *
+ * TODO: Implement and test Barrett reduction as alternative to naive
+ * modular division.
+ *
+ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
+
+/* round up to nearest power of 2 */
+static inline size_t phf_powerup(size_t i) {
+#if defined SIZE_MAX
+ i--;
+ i |= i >> 1;
+ i |= i >> 2;
+ i |= i >> 4;
+ i |= i >> 8;
+ i |= i >> 16;
+#if SIZE_MAX != 0xffffffffu
+ i |= i >> 32;
+#endif
+ return ++i;
+#else
+#error No SIZE_MAX defined
+#endif
+} /* phf_powerup() */
+
+static inline uint64_t phf_a_s_mod_n(uint64_t a, uint64_t s, uint64_t n) {
+ uint64_t v;
+
+ assert(n <= UINT32_MAX);
+
+ v = 1;
+ a %= n;
+
+ while (s > 0) {
+ if (s % 2 == 1)
+ v = (v * a) % n;
+ a = (a * a) % n;
+ s /= 2;
+ }
+
+ return v;
+} /* phf_a_s_mod_n() */
+
+/*
+ * Rabin-Miller primality test adapted from Niels Ferguson and Bruce
+ * Schneier, "Practical Cryptography" (Wiley, 2003), 201-204.
+ */
+static inline bool phf_witness(uint64_t n, uint64_t a, uint64_t s, uint64_t t) {
+ uint64_t v, i;
+
+ assert(a > 0 && a < n);
+ assert(n <= UINT32_MAX);
+
+ if (1 == (v = phf_a_s_mod_n(a, s, n)))
+ return 1;
+
+ for (i = 0; v != n - 1; i++) {
+ if (i == t - 1)
+ return 0;
+ v = (v * v) % n;
+ }
+
+ return 1;
+} /* phf_witness() */
+
+static inline bool phf_rabinmiller(uint64_t n) {
+ /*
+ * Witness 2 is deterministic for all n < 2047. Witnesses 2, 7, 61
+ * are deterministic for all n < 4,759,123,141.
+ */
+ static const int witness[] = { 2, 7, 61 };
+ uint64_t s, t, i;
+
+ assert(n <= UINT32_MAX);
+
+ if (n < 3 || n % 2 == 0)
+ return 0;
+
+ /* derive 2^t * s = n - 1 where s is odd */
+ s = n - 1;
+ t = 0;
+ while (s % 2 == 0) {
+ s /= 2;
+ t++;
+ }
+
+ /* NB: witness a must be 1 <= a < n */
+ if (n < 2027)
+ return phf_witness(n, 2, s, t);
+
+ for (i = 0; i < PHF_COUNTOF(witness); i++) {
+ if (!phf_witness(n, witness[i], s, t))
+ return 0;
+ }
+
+ return 1;
+} /* phf_rabinmiller() */
+
+static inline bool phf_isprime(size_t n) {
+ static const char map[] = { 0, 1, 2, 3, 0, 5, 0, 7 };
+ size_t i;
+
+ if (n < PHF_COUNTOF(map))
+ return map[n];
+
+ for (i = 2; i < PHF_COUNTOF(map); i++) {
+ if (map[i] && (n % map[i] == 0))
+ return 0;
+ }
+
+ return phf_rabinmiller(n);
+} /* phf_isprime() */
+
+static inline size_t phf_primeup(size_t n) {
+ /* NB: 4294967291 is largest 32-bit prime */
+ if (n > 4294967291)
+ return 0;
+
+ while (n < SIZE_MAX && !phf_isprime(n))
+ n++;
+
+ return n;
+} /* phf_primeup() */
+
+
+/*
+ * B I T M A P R O U T I N E S
+ *
+ * We use a bitmap to track output hash occupancy when searching for
+ * displacement values.
+ *
+ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
+
+typedef unsigned long phf_bits_t;
+
+static inline bool phf_isset(phf_bits_t *set, size_t i) {
+ return set[i / PHF_BITS(*set)] & ((size_t)1 << (i % PHF_BITS(*set)));
+} /* phf_isset() */
+
+static inline void phf_setbit(phf_bits_t *set, size_t i) {
+ set[i / PHF_BITS(*set)] |= ((size_t)1 << (i % PHF_BITS(*set)));
+} /* phf_setbit() */
+
+static inline void phf_clrbit(phf_bits_t *set, size_t i) {
+ set[i / PHF_BITS(*set)] &= ~((size_t)1 << (i % PHF_BITS(*set)));
+} /* phf_clrbit() */
+
+static inline void phf_clrall(phf_bits_t *set, size_t n) {
+ memset(set, '\0', PHF_HOWMANY(n, PHF_BITS(*set)) * sizeof *set);
+} /* phf_clrall() */
+
+
+/*
+ * K E Y D E D U P L I C A T I O N
+ *
+ * Auxiliary routine to ensure uniqueness of each key in array.
+ *
+ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
+
+namespace PHF {
+ namespace Uniq {
+ static bool operator!=(const phf_string_t &a, const phf_string_t &b) {
+ return a.n != b.n || 0 != memcmp(a.p, b.p, a.n);
+ }
+
+ template<typename T>
+ static int cmp(const T *a, const T *b) {
+ if (*a > *b)
+ return -1;
+ if (*a < *b)
+ return 1;
+ return 0;
+ } /* cmp() */
+
+ template<>
+ int cmp(const phf_string_t *a, const phf_string_t *b) {
+ int cmp;
+ if ((cmp = memcmp(a->p, b->p, PHF_MIN(a->n, b->n))))
+ return cmp;
+ if (a->n > b->n)
+ return -1;
+ if (a->n < b->n)
+ return 1;
+ return 0;
+ } /* cmp<phf_string_t>() */
+ } /* Uniq:: */
+} /* PHF:: */
+
+template<typename key_t>
+PHF_PUBLIC size_t PHF::uniq(key_t k[], const size_t n) {
+ using namespace PHF::Uniq;
+ size_t i, j;
+
+ qsort(k, n, sizeof *k, reinterpret_cast<int(*)(const void *, const void *)>(&cmp<key_t>));
+
+ for (i = 1, j = 0; i < n; i++) {
+ if (k[i] != k[j])
+ k[++j] = k[i];
+ }
+
+ return (n > 0)? j + 1 : 0;
+} /* PHF::uniq() */
+
+template size_t PHF::uniq<uint32_t>(uint32_t[], const size_t);
+template size_t PHF::uniq<uint64_t>(uint64_t[], const size_t);
+template size_t PHF::uniq<phf_string_t>(phf_string_t[], const size_t);
+#if !PHF_NO_LIBCXX
+template size_t PHF::uniq<std::string>(std::string[], const size_t);
+#endif
+
+PHF_PUBLIC size_t phf_uniq_uint32(uint32_t k[], const size_t n) {
+ return PHF::uniq(k, n);
+} /* phf_uniq_uint32() */
+
+PHF_PUBLIC size_t phf_uniq_uint64(uint64_t k[], const size_t n) {
+ return PHF::uniq(k, n);
+} /* phf_uniq_uint64() */
+
+PHF_PUBLIC size_t phf_uniq_string(phf_string_t k[], const size_t n) {
+ return PHF::uniq(k, n);
+} /* phf_uniq_string() */
+
+
+/*
+ * H A S H P R I M I T I V E S
+ *
+ * Universal hash based on MurmurHash3_x86_32. Variants for 32- and 64-bit
+ * integer keys, and string keys.
+ *
+ * We use a random seed to address the non-cryptographic-strength collision
+ * resistance of MurmurHash3. A stronger hash like SipHash is just too slow
+ * and unnecessary for my particular needs. For some environments a
+ * cryptographically stronger hash may be prudent.
+ *
+ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
+
+static inline uint32_t phf_round32(uint32_t k1, uint32_t h1) {
+ k1 *= UINT32_C(0xcc9e2d51);
+ k1 = PHF_ROTL(k1, 15);
+ k1 *= UINT32_C(0x1b873593);
+
+ h1 ^= k1;
+ h1 = PHF_ROTL(h1, 13);
+ h1 = h1 * 5 + UINT32_C(0xe6546b64);
+
+ return h1;
+} /* phf_round32() */
+
+static inline uint32_t phf_round32(const unsigned char *p, size_t n, uint32_t h1) {
+ uint32_t k1;
+
+ while (n >= 4) {
+ k1 = (p[0] << 24)
+ | (p[1] << 16)
+ | (p[2] << 8)
+ | (p[3] << 0);
+
+ h1 = phf_round32(k1, h1);
+
+ p += 4;
+ n -= 4;
+ }
+
+ k1 = 0;
+
+ switch (n & 3) {
+ case 3:
+ k1 |= p[2] << 8;
+ /* FALLTHRU */
+ case 2:
+ k1 |= p[1] << 16;
+ /* FALLTHRU */
+ case 1:
+ k1 |= p[0] << 24;
+ h1 = phf_round32(k1, h1);
+ }
+
+ return h1;
+} /* phf_round32() */
+
+static inline uint32_t phf_round32(phf_string_t k, uint32_t h1) {
+ return phf_round32(reinterpret_cast<const unsigned char *>(k.p), k.n, h1);
+} /* phf_round32() */
+
+#if !PHF_NO_LIBCXX
+static inline uint32_t phf_round32(std::string k, uint32_t h1) {
+ return phf_round32(reinterpret_cast<const unsigned char *>(k.c_str()), k.length(), h1);
+} /* phf_round32() */
+#endif
+
+static inline uint32_t phf_mix32(uint32_t h1) {
+ h1 ^= h1 >> 16;
+ h1 *= UINT32_C(0x85ebca6b);
+ h1 ^= h1 >> 13;
+ h1 *= UINT32_C(0xc2b2ae35);
+ h1 ^= h1 >> 16;
+
+ return h1;
+} /* phf_mix32() */
+
+
+/*
+ * g(k) & f(d, k) S P E C I A L I Z A T I O N S
+ *
+ * For every key we first calculate g(k). Then for every group of collisions
+ * from g(k) we search for a displacement value d such that f(d, k) places
+ * each key into a unique hash slot.
+ *
+ * g() and f() are specialized for 32-bit, 64-bit, and string keys.
+ *
+ * g_mod_r() and f_mod_n() are specialized for the method of modular
+ * reduction--modular division or bitwise AND. bitwise AND is substantially
+ * faster than modular division, and more than makes up for any space
+ * inefficiency, particularly for small hash tables.
+ *
+ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
+
+/* 32-bit, phf_string_t, and std::string keys */
+template<typename T>
+static inline uint32_t phf_g(T k, uint32_t seed) {
+ uint32_t h1 = seed;
+
+ h1 = phf_round32(k, h1);
+
+ return phf_mix32(h1);
+} /* phf_g() */
+
+template<typename T>
+static inline uint32_t phf_f(uint32_t d, T k, uint32_t seed) {
+ uint32_t h1 = seed;
+
+ h1 = phf_round32(d, h1);
+ h1 = phf_round32(k, h1);
+
+ return phf_mix32(h1);
+} /* phf_f() */
+
+
+/* 64-bit keys */
+static inline uint32_t phf_g(uint64_t k, uint32_t seed) {
+ uint32_t h1 = seed;
+
+ h1 = phf_round32(k, h1);
+ h1 = phf_round32(k >> 32, h1);
+
+ return phf_mix32(h1);
+} /* phf_g() */
+
+static inline uint32_t phf_f(uint32_t d, uint64_t k, uint32_t seed) {
+ uint32_t h1 = seed;
+
+ h1 = phf_round32(d, h1);
+ h1 = phf_round32(static_cast<uint32_t>(k), h1);
+ h1 = phf_round32(static_cast<uint32_t>(k >> 32), h1);
+
+ return phf_mix32(h1);
+} /* phf_f() */
+
+
+/* g() and f() which parameterize modular reduction */
+template<bool nodiv, typename T>
+static inline uint32_t phf_g_mod_r(T k, uint32_t seed, size_t r) {
+ return (nodiv)? (phf_g(k, seed) & (r - 1)) : (phf_g(k, seed) % r);
+} /* phf_g_mod_r() */
+
+template<bool nodiv, typename T>
+static inline uint32_t phf_f_mod_m(uint32_t d, T k, uint32_t seed, size_t m) {
+ return (nodiv)? (phf_f(d, k, seed) & (m - 1)) : (phf_f(d, k, seed) % m);
+} /* phf_f_mod_m() */
+
+
+/*
+ * B U C K E T S O R T I N G I N T E R F A C E S
+ *
+ * For every key [0..n) we calculate g(k) % r, where 0 < r <= n, and
+ * associate it with a bucket [0..r). We then sort the buckets in decreasing
+ * order according to the number of keys. The sorting is required for both
+ * optimal time complexity when calculating f(d, k) (less contention) and
+ * optimal space complexity (smaller d).
+ *
+ * The actual sorting is done in the core routine. The buckets are organized
+ * and sorted as a 1-dimensional array to minimize run-time memory (less
+ * data structure overhead) and improve data locality (less pointer
+ * indirection). The following section merely implements a templated
+ * bucket-key structure and the comparison routine passed to qsort(3).
+ *
+ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
+
+static bool operator==(const phf_string_t &a, const phf_string_t &b) {
+ return a.n == b.n && 0 == memcmp(a.p, b.p, a.n);
+}
+
+template<typename T>
+struct phf_key {
+ T k;
+ phf_hash_t g; /* result of g(k) % r */
+ size_t *n; /* number of keys in bucket g */
+}; /* struct phf_key */
+
+template<typename T>
+static int phf_keycmp(const phf_key<T> *a, const phf_key<T> *b) {
+ if (*(a->n) > *(b->n))
+ return -1;
+ if (*(a->n) < *(b->n))
+ return 1;
+ if (a->g > b->g)
+ return -1;
+ if (a->g < b->g)
+ return 1;
+
+ /* duplicate key? */
+ if (a->k == b->k && a != b) {
+ assert(!(a->k == b->k));
+ abort(); /* if NDEBUG defined */
+ }
+
+ return 0;
+} /* phf_keycmp() */
+
+
+/*
+ * C O R E F U N C T I O N G E N E R A T O R
+ *
+ * The entire algorithm is contained in PHF:init. Everything else in this
+ * source file is either a simple utility routine used by PHF:init, or an
+ * interface to PHF:init or the generated function state.
+ *
+ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
+
+template<typename key_t, bool nodiv>
+PHF_PUBLIC int PHF::init(struct phf *phf, const key_t k[], const size_t n, const size_t l, const size_t a, const phf_seed_t seed) {
+ size_t n1 = PHF_MAX(n, 1); /* for computations that require n > 0 */
+ size_t l1 = PHF_MAX(l, 1);
+ size_t a1 = PHF_MAX(PHF_MIN(a, 100), 1);
+ size_t r; /* number of buckets */
+ size_t m; /* size of output array */
+ phf_key<key_t> *B_k = NULL; /* linear bucket-slot array */
+ size_t *B_z = NULL; /* number of slots per bucket */
+ phf_key<key_t> *B_p, *B_pe;
+ phf_bits_t *T = NULL; /* bitmap to track index occupancy */
+ phf_bits_t *T_b; /* per-bucket working bitmap */
+ size_t T_n;
+ uint32_t *g = NULL; /* displacement map */
+ uint32_t d_max = 0; /* maximum displacement value */
+ int error;
+
+ if ((phf->nodiv = nodiv)) {
+ /* round to power-of-2 so we can use bit masks instead of modulo division */
+ r = phf_powerup(n1 / PHF_MIN(l1, n1));
+ m = phf_powerup((n1 * 100) / a1);
+ } else {
+ r = phf_primeup(PHF_HOWMANY(n1, l1));
+ /* XXX: should we bother rounding m to prime number for small n? */
+ m = phf_primeup((n1 * 100) / a1);
+ }
+
+ if (r == 0 || m == 0)
+ return ERANGE;
+
+ if (!(B_k = static_cast<phf_key<key_t> *>(calloc(n1, sizeof *B_k))))
+ goto syerr;
+ if (!(B_z = static_cast<size_t *>(calloc(r, sizeof *B_z))))
+ goto syerr;
+
+ for (size_t i = 0; i < n; i++) {
+ phf_hash_t g = phf_g_mod_r<nodiv>(k[i], seed, r);
+
+ B_k[i].k = k[i];
+ B_k[i].g = g;
+ B_k[i].n = &B_z[g];
+ ++*B_k[i].n;
+ }
+
+ qsort(B_k, n1, sizeof *B_k, reinterpret_cast<int(*)(const void *, const void *)>(&phf_keycmp<key_t>));
+
+ T_n = PHF_HOWMANY(m, PHF_BITS(*T));
+ if (!(T = static_cast<phf_bits_t *>(calloc(T_n * 2, sizeof *T))))
+ goto syerr;
+ T_b = &T[T_n]; /* share single allocation */
+
+ /*
+ * FIXME: T_b[] is unnecessary. We could clear T[] the same way we
+ * clear T_b[]. In fact, at the end of generation T_b[] is identical
+ * to T[] because we don't clear T_b[] on success.
+ *
+ * We just need to tweak the current reset logic to stop before the
+ * key that failed, and then we can elide the commit to T[] at the
+ * end of the outer loop.
+ */
+
+ if (!(g = static_cast<uint32_t *>(calloc(r, sizeof *g))))
+ goto syerr;
+
+ B_p = B_k;
+ B_pe = &B_k[n];
+
+ for (; B_p < B_pe && *B_p->n > 0; B_p += *B_p->n) {
+ phf_key<key_t> *Bi_p, *Bi_pe;
+ size_t d = 0;
+ uint32_t f;
+retry:
+ d++;
+ Bi_p = B_p;
+ Bi_pe = B_p + *B_p->n;
+
+ for (; Bi_p < Bi_pe; Bi_p++) {
+ f = phf_f_mod_m<nodiv>(d, Bi_p->k, seed, m);
+
+ if (phf_isset(T, f) || phf_isset(T_b, f)) {
+ /* reset T_b[] */
+ for (Bi_p = B_p; Bi_p < Bi_pe; Bi_p++) {
+ f = phf_f_mod_m<nodiv>(d, Bi_p->k, seed, m);
+ phf_clrbit(T_b, f);
+ }
+
+ goto retry;
+ } else {
+ phf_setbit(T_b, f);
+ }
+ }
+
+ /* commit to T[] */
+ for (Bi_p = B_p; Bi_p < Bi_pe; Bi_p++) {
+ f = phf_f_mod_m<nodiv>(d, Bi_p->k, seed, m);
+ phf_setbit(T, f);
+ }
+
+ /* commit to g[] */
+ g[B_p->g] = d;
+ d_max = PHF_MAX(d, d_max);
+ }
+
+ phf->seed = seed;
+ phf->r = r;
+ phf->m = m;
+
+ phf->g = g;
+ g = NULL;
+
+ phf->d_max = d_max;
+ phf->g_op = (nodiv)? phf::PHF_G_UINT32_BAND_R : phf::PHF_G_UINT32_MOD_R;
+ phf->g_jmp = NULL;
+
+ error = 0;
+
+ goto clean;
+syerr:
+ error = errno;
+clean:
+ free(g);
+ free(T);
+ free(B_z);
+ free(B_k);
+
+ return error;
+} /* PHF::init() */
+
+
+/*
+ * D I S P L A C E M E N T M A P C O M P A C T I O N
+ *
+ * By default the displacement map is an array of uint32_t. This routine
+ * compacts the map by using the smallest primitive type that will fit the
+ * largest displacement value.
+ *
+ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
+
+template<typename dst_t, typename src_t>
+static inline void phf_memmove(dst_t *dst, src_t *src, size_t n) {
+ for (size_t i = 0; i < n; i++) {
+ dst_t tmp = src[i];
+ dst[i] = tmp;
+ }
+} /* phf_memmove() */
+
+PHF_PUBLIC void PHF::compact(struct phf *phf) {
+ size_t size = 0;
+ void *tmp;
+
+ switch (phf->g_op) {
+ case phf::PHF_G_UINT32_MOD_R:
+ case phf::PHF_G_UINT32_BAND_R:
+ break;
+ default:
+ return; /* already compacted */
+ }
+
+ if (phf->d_max <= 255) {
+ phf_memmove(reinterpret_cast<uint8_t *>(phf->g), reinterpret_cast<uint32_t *>(phf->g), phf->r);
+ phf->g_op = (phf->nodiv)? phf::PHF_G_UINT8_BAND_R : phf::PHF_G_UINT8_MOD_R;
+ size = sizeof (uint8_t);
+ } else if (phf->d_max <= 65535) {
+ phf_memmove(reinterpret_cast<uint16_t *>(phf->g), reinterpret_cast<uint32_t *>(phf->g), phf->r);
+ phf->g_op = (phf->nodiv)? phf::PHF_G_UINT16_BAND_R : phf::PHF_G_UINT16_MOD_R;
+ size = sizeof (uint16_t);
+ } else {
+ return; /* nothing to compact */
+ }
+
+ /* simply keep old array if realloc fails */
+ if ((tmp = realloc(phf->g, phf->r * size)))
+ phf->g = static_cast<uint32_t *>(tmp);
+} /* PHF::compact() */
+
+
+/*
+ * F U N C T I O N G E N E R A T O R & S T A T E I N T E R F A C E S
+ *
+ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
+
+template int PHF::init<uint32_t, true>(struct phf *, const uint32_t[], const size_t, const size_t, const size_t, const phf_seed_t);
+template int PHF::init<uint64_t, true>(struct phf *, const uint64_t[], const size_t, const size_t, const size_t, const phf_seed_t);
+template int PHF::init<phf_string_t, true>(struct phf *, const phf_string_t[], const size_t, const size_t, const size_t, const phf_seed_t);
+#if !PHF_NO_LIBCXX
+template int PHF::init<std::string, true>(struct phf *, const std::string[], const size_t, const size_t, const size_t, const phf_seed_t);
+#endif
+
+template int PHF::init<uint32_t, false>(struct phf *, const uint32_t[], const size_t, const size_t, const size_t, const phf_seed_t);
+template int PHF::init<uint64_t, false>(struct phf *, const uint64_t[], const size_t, const size_t, const size_t, const phf_seed_t);
+template int PHF::init<phf_string_t, false>(struct phf *, const phf_string_t[], const size_t, const size_t, const size_t, const phf_seed_t);
+#if !PHF_NO_LIBCXX
+template int PHF::init<std::string, false>(struct phf *, const std::string[], const size_t, const size_t, const size_t, const phf_seed_t);
+#endif
+
+template<bool nodiv, typename map_t, typename key_t>
+static inline phf_hash_t phf_hash_(map_t *g, key_t k, uint32_t seed, size_t r, size_t m) {
+ if (nodiv) {
+ uint32_t d = g[phf_g(k, seed) & (r - 1)];
+
+ return phf_f(d, k, seed) & (m - 1);
+ } else {
+ uint32_t d = g[phf_g(k, seed) % r];
+
+ return phf_f(d, k, seed) % m;
+ }
+} /* phf_hash_() */
+
+template<typename T>
+PHF_PUBLIC phf_hash_t PHF::hash(struct phf *phf, T k) {
+#if PHF_HAVE_COMPUTED_GOTOS && !PHF_NO_COMPUTED_GOTOS
+ static const void *const jmp[] = {
+ NULL,
+ &&uint8_mod_r, &&uint8_band_r,
+ &&uint16_mod_r, &&uint16_band_r,
+ &&uint32_mod_r, &&uint32_band_r,
+ };
+
+ goto *((phf->g_jmp)? phf->g_jmp : (phf->g_jmp = jmp[phf->g_op]));
+
+ uint8_mod_r:
+ return phf_hash_<false>(reinterpret_cast<uint8_t *>(phf->g), k, phf->seed, phf->r, phf->m);
+ uint8_band_r:
+ return phf_hash_<true>(reinterpret_cast<uint8_t *>(phf->g), k, phf->seed, phf->r, phf->m);
+ uint16_mod_r:
+ return phf_hash_<false>(reinterpret_cast<uint16_t *>(phf->g), k, phf->seed, phf->r, phf->m);
+ uint16_band_r:
+ return phf_hash_<true>(reinterpret_cast<uint16_t *>(phf->g), k, phf->seed, phf->r, phf->m);
+ uint32_mod_r:
+ return phf_hash_<false>(reinterpret_cast<uint32_t *>(phf->g), k, phf->seed, phf->r, phf->m);
+ uint32_band_r:
+ return phf_hash_<true>(reinterpret_cast<uint32_t *>(phf->g), k, phf->seed, phf->r, phf->m);
+#else
+ switch (phf->g_op) {
+ case phf::PHF_G_UINT8_MOD_R:
+ return phf_hash_<false>(reinterpret_cast<uint8_t *>(phf->g), k, phf->seed, phf->r, phf->m);
+ case phf::PHF_G_UINT8_BAND_R:
+ return phf_hash_<true>(reinterpret_cast<uint8_t *>(phf->g), k, phf->seed, phf->r, phf->m);
+ case phf::PHF_G_UINT16_MOD_R:
+ return phf_hash_<false>(reinterpret_cast<uint16_t *>(phf->g), k, phf->seed, phf->r, phf->m);
+ case phf::PHF_G_UINT16_BAND_R:
+ return phf_hash_<true>(reinterpret_cast<uint16_t *>(phf->g), k, phf->seed, phf->r, phf->m);
+ case phf::PHF_G_UINT32_MOD_R:
+ return phf_hash_<false>(reinterpret_cast<uint32_t *>(phf->g), k, phf->seed, phf->r, phf->m);
+ case phf::PHF_G_UINT32_BAND_R:
+ return phf_hash_<true>(reinterpret_cast<uint32_t *>(phf->g), k, phf->seed, phf->r, phf->m);
+ default:
+ abort();
+ return 0;
+ }
+#endif
+} /* PHF::hash() */
+
+template phf_hash_t PHF::hash<uint32_t>(struct phf *, uint32_t);
+template phf_hash_t PHF::hash<uint64_t>(struct phf *, uint64_t);
+template phf_hash_t PHF::hash<phf_string_t>(struct phf *, phf_string_t);
+#if !PHF_NO_LIBCXX
+template phf_hash_t PHF::hash<std::string>(struct phf *, std::string);
+#endif
+
+PHF_PUBLIC void PHF::destroy(struct phf *phf) {
+ free(phf->g);
+ phf->g = NULL;
+} /* PHF::destroy() */
+
+PHF_PUBLIC int phf_init_uint32(struct phf *phf, const uint32_t *k, const size_t n, const size_t lambda, const size_t alpha, const phf_seed_t seed, const bool nodiv) {
+ if (nodiv)
+ return PHF::init<uint32_t, true>(phf, k, n, lambda, alpha, seed);
+ else
+ return PHF::init<uint32_t, false>(phf, k, n, lambda, alpha, seed);
+} /* phf_init_uint32() */
+
+PHF_PUBLIC int phf_init_uint64(struct phf *phf, const uint64_t *k, const size_t n, const size_t lambda, const size_t alpha, const phf_seed_t seed, const bool nodiv) {
+ if (nodiv)
+ return PHF::init<uint64_t, true>(phf, k, n, lambda, alpha, seed);
+ else
+ return PHF::init<uint64_t, false>(phf, k, n, lambda, alpha, seed);
+} /* phf_init_uint64() */
+
+PHF_PUBLIC int phf_init_string(struct phf *phf, const phf_string_t *k, const size_t n, const size_t lambda, const size_t alpha, const phf_seed_t seed, const bool nodiv) {
+ if (nodiv)
+ return PHF::init<phf_string_t, true>(phf, k, n, lambda, alpha, seed);
+ else
+ return PHF::init<phf_string_t, false>(phf, k, n, lambda, alpha, seed);
+} /* phf_init_string() */
+
+PHF_PUBLIC void phf_compact(struct phf *phf) {
+ PHF::compact(phf);
+} /* phf_compact() */
+
+PHF_PUBLIC phf_hash_t phf_hash_uint32(struct phf *phf, const uint32_t k) {
+ return PHF::hash(phf, k);
+} /* phf_hash_uint32() */
+
+PHF_PUBLIC phf_hash_t phf_hash_uint64(struct phf *phf, const uint64_t k) {
+ return PHF::hash(phf, k);
+} /* phf_hash_uint64() */
+
+PHF_PUBLIC phf_hash_t phf_hash_string(struct phf *phf, const phf_string_t k) {
+ return PHF::hash(phf, k);
+} /* phf_hash_string() */
+
+PHF_PUBLIC void phf_destroy(struct phf *phf) {
+ PHF::destroy(phf);
+} /* phf_destroy() */
+
+
+#if PHF_LUALIB
+#include <time.h> /* time(2) */
+
+#include <lua.hpp>
+
+
+#if LUA_VERSION_NUM < 502
+static int lua_absindex(lua_State *L, int idx) {
+ return (idx > 0 || idx <= LUA_REGISTRYINDEX)? idx : lua_gettop(L) + idx + 1;
+} /* lua_absindex() */
+
+#define lua_rawlen(t, index) lua_objlen(t, index)
+#endif
+
+
+struct phfctx {
+ int (*hash)(struct phf *, lua_State *, int index);
+ struct phf ctx;
+}; /* struct phfctx */
+
+
+static int phf_hash_uint32(struct phf *phf, lua_State *L, int index) {
+ uint32_t k = static_cast<uint32_t>(luaL_checkinteger(L, index));
+
+ lua_pushinteger(L, static_cast<lua_Integer>(PHF::hash(phf, k) + 1));
+
+ return 1;
+} /* phf_hash_uint32() */
+
+static int phf_hash_uint64(struct phf *phf, lua_State *L, int index) {
+ uint64_t k = static_cast<uint64_t>(luaL_checkinteger(L, index));
+
+ lua_pushinteger(L, static_cast<lua_Integer>(PHF::hash(phf, k) + 1));
+
+ return 1;
+} /* phf_hash_uint64() */
+
+static int phf_hash_string(struct phf *phf, lua_State *L, int index) {
+ phf_string_t k;
+
+ k.p = const_cast<char *>(luaL_checklstring(L, index, &k.n));
+
+ lua_pushinteger(L, static_cast<lua_Integer>(PHF::hash(phf, k) + 1));
+
+ return 1;
+} /* phf_hash_string() */
+
+static phf_seed_t phf_seed(lua_State *L) {
+ return phf_g(static_cast<uint32_t>(reinterpret_cast<intptr_t>(L)), static_cast<uint32_t>(time(NULL)));
+} /* phf_seed() */
+
+template<typename T>
+static phf_error_t phf_reallocarray(T **p, size_t count) {
+ T *tmp;
+
+ if (SIZE_MAX / sizeof **p < count)
+ return ENOMEM;
+
+ if (!(tmp = static_cast<T*>(realloc(*p, count * sizeof **p))))
+ return errno;
+
+ *p = tmp;
+
+ return 0;
+} /* phf_reallocarray() */
+
+static phf_error_t phf_tokey(lua_State *L, int index, uint32_t *k) {
+ if (LUA_TNUMBER != lua_type(L, index))
+ return EINVAL;
+
+#if LUA_VERSION_NUM > 502
+ lua_Integer v;
+
+ v = static_cast<lua_Integer>(lua_tointeger(L, index));
+
+ if (v > UINT32_MAX)
+ return ERANGE;
+#else
+ lua_Number v;
+
+ v = static_cast<lua_Number>(lua_tonumber(L, index));
+
+ if (v > UINT32_MAX)
+ return ERANGE;
+#endif
+ *k = static_cast<uint32_t>(v);
+
+ return 0;
+} /* phf_tokey() */
+
+static phf_error_t phf_tokey(lua_State *L, int index, uint64_t *k) {
+ if (LUA_TNUMBER != lua_type(L, index))
+ return EINVAL;
+
+#if LUA_VERSION_NUM > 502
+ lua_Integer v;
+
+ v = static_cast<lua_Integer>(lua_tointeger(L, index));
+#else
+ lua_Number v;
+
+ v = static_cast<lua_Number>(lua_tonumber(L, index));
+#endif
+ *k = static_cast<uint64_t>(v);
+
+ return 0;
+} /* phf_tokey() */
+
+static phf_error_t phf_tokey(lua_State *L, int index, phf_string_t *k) {
+ if (LUA_TSTRING != lua_type(L, index))
+ return EINVAL;
+
+ k->p = const_cast<char *>(lua_tolstring(L, index, &k->n));
+
+ return 0;
+} /* phf_tokey() */
+
+template<typename T>
+static phf_error_t phf_addkeys(lua_State *L, int index, T **keys, int *n) {
+ int i, error = 0;
+ T *p;
+
+ *n = lua_rawlen(L, index);
+
+ if ((error = phf_reallocarray(keys, *n)))
+ return error;
+
+ p = *keys;
+
+ for (i = 1; i <= *n; i++) {
+ lua_rawgeti(L, index, i);
+
+ error = phf_tokey(L, -1, p++);
+
+ lua_pop(L, 1);
+
+ if (error)
+ return error;
+
+ }
+
+ *n = PHF::uniq(*keys, *n);
+
+ return 0;
+} /* phf_addkeys() */
+
+static int phf_new(lua_State *L) {
+ size_t l = static_cast<size_t>(luaL_optinteger(L, 2, 4));
+ size_t a = static_cast<size_t>(luaL_optinteger(L, 3, 80));
+ phf_seed_t seed = (lua_isnoneornil(L, 4))? phf_seed(L) : static_cast<phf_seed_t>(luaL_checkinteger(L, 4));
+ bool nodiv = static_cast<bool>(lua_toboolean(L, 5));
+ void *keys = NULL;
+ struct phfctx *phf;
+ int n, error;
+
+ lua_settop(L, 5);
+ luaL_checktype(L, 1, LUA_TTABLE);
+
+ phf = static_cast<struct phfctx *>(lua_newuserdata(L, sizeof *phf));
+ memset(phf, 0, sizeof *phf);
+
+ luaL_getmetatable(L, "PHF*");
+ lua_setmetatable(L, -2);
+
+ switch ((error = phf_addkeys(L, 1, reinterpret_cast<uint32_t **>(&keys), &n))) {
+ case 0:
+ break;
+ case ERANGE:
+ goto uint64;
+ case EINVAL:
+ goto string;
+ default:
+ goto error;
+ }
+
+ if (n == 0)
+ goto empty;
+
+ if ((error = phf_init_uint32(&phf->ctx, reinterpret_cast<uint32_t *>(keys), n, l, a, seed, nodiv)))
+ goto error;
+
+ phf->hash = &phf_hash_uint32;
+
+ goto done;
+uint64:
+ switch ((error = phf_addkeys(L, 1, reinterpret_cast<uint64_t **>(&keys), &n))) {
+ case 0:
+ break;
+ case EINVAL:
+ goto string;
+ default:
+ goto error;
+ }
+
+ if (n == 0)
+ goto empty;
+
+ if ((error = phf_init_uint64(&phf->ctx, reinterpret_cast<uint64_t *>(keys), n, l, a, seed, nodiv)))
+ goto error;
+
+ phf->hash = &phf_hash_uint64;
+
+ goto done;
+string:
+ if ((error = phf_addkeys(L, 1, reinterpret_cast<phf_string_t **>(&keys), &n)))
+ goto error;
+
+ if (n == 0)
+ goto empty;
+
+ if ((error = phf_init_string(&phf->ctx, reinterpret_cast<phf_string_t *>(keys), n, l, a, seed, nodiv)))
+ goto error;
+
+ phf->hash = &phf_hash_string;
+
+ goto done;
+done:
+ free(keys);
+
+ PHF::compact(&phf->ctx);
+
+ return 1;
+empty:
+ free(keys);
+
+ lua_pushstring(L, "empty key set");
+
+ return lua_error(L);
+error:
+ free(keys);
+
+ lua_pushstring(L, strerror(error));
+
+ return lua_error(L);
+} /* phf_new() */
+
+static int phf_r(lua_State *L) {
+ struct phfctx *phf = static_cast<struct phfctx *>(luaL_checkudata(L, 1, "PHF*"));
+
+ lua_pushinteger(L, static_cast<lua_Integer>(phf->ctx.r));
+
+ return 1;
+} /* phf_r() */
+
+static int phf_m(lua_State *L) {
+ struct phfctx *phf = static_cast<struct phfctx *>(luaL_checkudata(L, 1, "PHF*"));
+
+ lua_pushinteger(L, static_cast<lua_Integer>(phf->ctx.m));
+
+ return 1;
+} /* phf_m() */
+
+static int (phf_hash)(lua_State *L) {
+ struct phfctx *phf = static_cast<struct phfctx *>(luaL_checkudata(L, 1, "PHF*"));
+
+ return phf->hash(&phf->ctx, L, 2);
+} /* phf_hash() */
+
+static int phf__gc(lua_State *L) {
+ struct phfctx *phf = (struct phfctx *)luaL_checkudata(L, 1, "PHF*");
+
+ phf_destroy(&phf->ctx);
+
+ return 0;
+} /* phf__gc() */
+
+static const luaL_Reg phf_methods[] = {
+ { "hash", &(phf_hash) },
+ { "r", &phf_r },
+ { "m", &phf_m },
+ { NULL, NULL },
+}; /* phf_methods[] */
+
+static const luaL_Reg phf_metatable[] = {
+ { "__call", &phf_hash },
+ { "__gc", &phf__gc },
+ { NULL, NULL },
+}; /* phf_metatable[] */
+
+static const luaL_Reg phf_globals[] = {
+ { "new", &phf_new },
+ { NULL, NULL },
+}; /* phf_globals[] */
+
+static void phf_register(lua_State *L, const luaL_Reg *l) {
+#if LUA_VERSION_NUM >= 502
+ luaL_setfuncs(L, l, 0);
+#else
+ luaL_register(L, NULL, l);
+#endif
+} /* phf_register() */
+
+extern "C" int luaopen_phf(lua_State *L) {
+ if (luaL_newmetatable(L, "PHF*")) {
+ phf_register(L, phf_metatable);
+ lua_newtable(L);
+ phf_register(L, phf_methods);
+ lua_setfield(L, -2, "__index");
+ }
+
+ lua_pop(L, 1);
+
+ lua_newtable(L);
+ phf_register(L, phf_globals);
+
+ return 1;
+} /* luaopen_phf() */
+
+#endif /* PHF_LUALIB */
+
+
+#if PHF_MAIN
+
+#include <stdlib.h> /* arc4random(3) free(3) realloc(3) */
+#include <stdio.h> /* fclose(3) fopen(3) fprintf(3) fread(3) freopen(3) printf(3) */
+#include <time.h> /* CLOCKS_PER_SEC clock(3) */
+#include <string.h> /* strcmp(3) */
+#include <sys/param.h> /* BSD */
+#include <unistd.h> /* getopt(3) */
+#include <strings.h> /* ffsl(3) */
+#include <err.h> /* err(3) errx(3) warnx(3) */
+
+
+static uint32_t randomseed(void) {
+#if defined BSD /* catchall for modern BSDs, which all have arc4random */
+ return arc4random();
+#else
+ FILE *fp;
+ uint32_t seed;
+
+ if (!(fp = fopen("/dev/urandom", "r")))
+ err(1, "/dev/urandom");
+
+ if (1 != fread(&seed, sizeof seed, 1, fp))
+ err(1, "/dev/urandom");
+
+ fclose(fp);
+
+ return seed;
+#endif
+} /* randomseed() */
+
+
+template<typename T>
+static void pushkey(T **k, size_t *n, size_t *z, T kn) {
+ if (!(*n < *z)) {
+ size_t z1 = PHF_MAX(*z, 1) * 2;
+ T *p;
+
+ if (z1 < *z || (SIZE_MAX / sizeof **k) < z1)
+ errx(1, "addkey: %s", strerror(ERANGE));
+
+ if (!(p = (T *)realloc(*k, z1 * sizeof **k)))
+ err(1, "realloc");
+
+ *k = p;
+ *z = z1;
+ }
+
+ (*k)[(*n)++] = kn;
+} /* pushkey() */
+
+
+template<typename T>
+static void addkey(T **k, size_t *n, size_t *z, const char *src) {
+ pushkey(k, n, z, static_cast<T>(strtoull(src, NULL, 0)));
+} /* addkey() */
+
+static void addkey(phf_string_t **k, size_t *n, size_t *z, char *src, size_t len) {
+ phf_string_t kn = { (void *)src, len };
+ pushkey(k, n, z, kn);
+} /* addkey() */
+
+static void addkey(phf_string_t **k, size_t *n, size_t *z, char *src) {
+ addkey(k, n, z, src, strlen(src));
+} /* addkey() */
+
+#if !PHF_NO_LIBCXX
+static void addkey(std::string **k, size_t *n, size_t *z, char *src, size_t len) {
+ pushkey(k, n, z, std::string(src, len));
+} /* addkey() */
+
+static void addkey(std::string **k, size_t *n, size_t *z, char *src) {
+ addkey(k, n, z, src, strlen(src));
+} /* addkey() */
+#endif
+
+template<typename T>
+static void addkeys(T **k, size_t *n, size_t *z, char **src, int count) {
+ for (int i = 0; i < count; i++)
+ addkey(k, n, z, src[i]);
+} /* addkey() */
+
+template<typename T>
+static void addkeys(T **k, size_t *n, size_t *z, FILE *fp, char **data) {
+ char *ln = NULL;
+ size_t lz = 0;
+ ssize_t len;
+
+ (void)data;
+
+ while ((len = getline(&ln, &lz, fp)) > 0) {
+ if (--len > 0) {
+ if (ln[len] == '\n')
+ ln[len] = '\0';
+ addkey(k, n, z, ln);
+ }
+ }
+
+ free(ln);
+} /* addkeys() */
+
+/* slurp file into a single string and take pointers */
+static void addkeys(phf_string_t **k, size_t *n, size_t *z, FILE *fp, char **data) {
+ size_t p = 0, pe = 0, tp;
+ char buf[BUFSIZ], *tmp;
+ size_t buflen;
+
+ while ((buflen = fread(buf, 1, sizeof buf, fp))) {
+ if (buflen > (pe - p)) {
+ if (~buflen < pe || 0 == (pe = phf_powerup(buflen + pe)))
+ errx(1, "realloc: %s", strerror(ERANGE));
+ if (!(tmp = (char *)realloc(*data, pe)))
+ err(1, "realloc");
+ *data = tmp;
+ }
+
+ memcpy(*data + p, buf, buflen);
+ p += buflen;
+ }
+
+ for (p = 0; p < pe; ) {
+ while (p < pe && (*data)[p] == '\n')
+ p++;
+
+ tp = p;
+
+ while (p < pe && (*data)[p] != '\n')
+ p++;
+
+ if (p > tp)
+ addkey(k, n, z, &(*data)[tp], (size_t)(p - tp));
+ }
+} /* addkeys() */
+
+
+static inline void printkey(phf_string_t &k, phf_hash_t hash) {
+ printf("%-32.*s : %" PHF_PRIuHASH "\n", (int)k.n, (char *)k.p, hash);
+} /* printkey() */
+
+#if !PHF_NO_LIBCXX
+static inline void printkey(std::string &k, phf_hash_t hash) {
+ printf("%-32s : %" PHF_PRIuHASH "\n", k.c_str(), hash);
+} /* printkey() */
+#endif
+
+template<typename T>
+static inline void printkey(T k, phf_hash_t hash) {
+ printf("%llu : %" PHF_PRIuHASH "\n", (unsigned long long)k, hash);
+} /* printkey() */
+
+template<typename T, bool nodiv>
+static inline void exec(int argc, char **argv, size_t lambda, size_t alpha, size_t seed, bool verbose, bool noprint) {
+ T *k = NULL;
+ size_t n = 0, z = 0;
+ char *data = NULL;
+ struct phf phf;
+ clock_t begin, end;
+
+ addkeys(&k, &n, &z, argv, argc);
+ addkeys(&k, &n, &z, stdin, &data);
+
+ size_t m = PHF::uniq(k, n);
+ if (verbose)
+ warnx("loaded %zu keys (%zu duplicates)", m, (n - m));
+ n = m;
+
+ begin = clock();
+ PHF::init<T, nodiv>(&phf, k, n, lambda, alpha, seed);
+ end = clock();
+
+
+ if (verbose) {
+ warnx("found perfect hash for %zu keys in %fs", n, (double)(end - begin) / CLOCKS_PER_SEC);
+
+ begin = clock();
+ PHF::compact(&phf);
+ end = clock();
+ warnx("compacted displacement map in %fs", (double)(end - begin) / CLOCKS_PER_SEC);
+
+ int d_bits = ffsl((long)phf_powerup(phf.d_max));
+ double k_bits = ((double)phf.r * d_bits) / n;
+ double g_load = (double)n / phf.r;
+ warnx("r:%zu m:%zu d_max:%zu d_bits:%d k_bits:%.2f g_load:%.2f", phf.r, phf.m, phf.d_max, d_bits, k_bits, g_load);
+
+ size_t x = 0;
+ begin = clock();
+ for (size_t i = 0; i < n; i++) {
+ x += PHF::hash(&phf, k[i]);
+ }
+ end = clock();
+ warnx("hashed %zu keys in %fs (x:%zu)", n, (double)(end - begin) / CLOCKS_PER_SEC, x);
+ }
+
+ if (!noprint) {
+ for (size_t i = 0; i < n; i++) {
+ printkey(k[i], PHF::hash(&phf, k[i]));
+ }
+ }
+
+ phf_destroy(&phf);
+ free(data);
+ free(k);
+} /* exec() */
+
+static void printprimes(int argc, char **argv) {
+ intmax_t n = 0, m = UINT32_MAX;
+ char *end;
+
+ if (argc > 0) {
+ n = strtoimax(argv[0], &end, 0);
+ if (end == argv[0] || *end != '\0' || n < 0 || n > UINT32_MAX)
+ errx(1, "%s: invalid number", argv[0]);
+ n = PHF_MAX(n, 2);
+ }
+
+ if (argc > 1) {
+ m = strtoimax(argv[1], &end, 0);
+ if (end == argv[1] || *end != '\0' || m < n || m > UINT32_MAX)
+ errx(1, "%s: invalid number", argv[1]);
+ }
+
+ for (; n <= m; n++) {
+ if (phf_isprime(n))
+ printf("%" PRIdMAX "\n", n);
+ }
+} /* printprimes() */
+
+int main(int argc, char **argv) {
+ const char *path = "/dev/null";
+ size_t lambda = 4;
+ size_t alpha = 80;
+ uint32_t seed = randomseed();
+ bool verbose = 0;
+ bool noprint = 0;
+ bool nodiv = 0;
+ enum {
+ PHF_UINT32,
+ PHF_UINT64,
+ PHF_STRING,
+#if !PHF_NO_LIBCXX
+ PHF_STD_STRING
+#endif
+ } type = PHF_UINT32;
+ bool primes = 0;
+ extern char *optarg;
+ extern int optind;
+ int optc;
+
+ while (-1 != (optc = getopt(argc, argv, "f:l:a:s:2t:nvph"))) {
+ switch (optc) {
+ case 'f':
+ path = optarg;
+ break;
+ case 'l':
+ lambda = strtoul(optarg, NULL, 0);
+ break;
+ case 'a':
+ alpha = strtoul(optarg, NULL, 0);
+ break;
+ case 's':
+ seed = strtoul(optarg, NULL, 0);
+ break;
+ case '2':
+ nodiv = 1;
+ break;
+ case 't':
+ if (!strcmp(optarg, "uint32")) {
+ type = PHF_UINT32;
+ } else if (!strcmp(optarg, "uint64")) {
+ type = PHF_UINT64;
+ } else if (!strcmp(optarg, "string")) {
+ type = PHF_STRING;
+#if !PHF_NO_LIBCXX
+ } else if (!strcmp(optarg, "std::string")) {
+ type = PHF_STD_STRING;
+#endif
+ } else {
+ errx(1, "%s: invalid key type", optarg);
+ }
+
+ break;
+ case 'n':
+ noprint = 1;
+ break;
+ case 'v':
+ verbose = 1;
+ break;
+ case 'p':
+ primes = 1;
+ break;
+ case 'h':
+ /* FALL THROUGH */
+ default:
+ fprintf(optc == 'h'? stdout : stderr,
+ "%s [-f:l:a:s:t:2nvph] [key [...]]\n"
+ " -f PATH read keys from PATH (- for stdin)\n"
+ " -l NUM number of keys per displacement map bucket (reported as g_load)\n"
+ " -a PCT hash table load factor (1%% - 100%%)\n"
+ " -s SEED random seed\n"
+ " -t TYPE parse and hash keys as uint32, uint64, or string\n"
+ " -2 avoid modular division by rounding r and m to power of 2\n"
+ " -n do not print key-hash pairs\n"
+ " -v report hashing status\n"
+ " -p operate like primes(3) utility\n"
+ " -h print usage message\n"
+ "\n"
+ "Report bugs to <william@25thandClement.com>\n",
+ argv[0]
+ );
+
+ return optc == 'h'? 0 : 1;
+ }
+ }
+
+ argc -= optind;
+ argv += optind;
+
+ if (primes)
+ return printprimes(argc, argv), 0;
+
+ if (strcmp(path, "-") && !freopen(path, "r", stdin))
+ err(1, "%s", path);
+
+ switch (type) {
+ case PHF_UINT32:
+ if (nodiv)
+ exec<uint32_t, true>(argc, argv, lambda, alpha, seed, verbose, noprint);
+ else
+ exec<uint32_t, false>(argc, argv, lambda, alpha, seed, verbose, noprint);
+ break;
+ case PHF_UINT64:
+ if (nodiv)
+ exec<uint64_t, true>(argc, argv, lambda, alpha, seed, verbose, noprint);
+ else
+ exec<uint64_t, false>(argc, argv, lambda, alpha, seed, verbose, noprint);
+ break;
+ case PHF_STRING:
+ if (nodiv)
+ exec<phf_string_t, true>(argc, argv, lambda, alpha, seed, verbose, noprint);
+ else
+ exec<phf_string_t, false>(argc, argv, lambda, alpha, seed, verbose, noprint);
+ break;
+#if !PHF_NO_LIBCXX
+ case PHF_STD_STRING:
+ if (nodiv)
+ exec<std::string, true>(argc, argv, lambda, alpha, seed, verbose, noprint);
+ else
+ exec<std::string, false>(argc, argv, lambda, alpha, seed, verbose, noprint);
+ break;
+#endif
+ }
+
+ return 0;
+} /* main() */
+
+#endif /* PHF_MAIN */
diff --git a/src/3rd_party/phf/phf.h b/src/3rd_party/phf/phf.h
new file mode 100644
index 00000000..f821dfa4
--- /dev/null
+++ b/src/3rd_party/phf/phf.h
@@ -0,0 +1,299 @@
+/* ==========================================================================
+ * phf.h - Tiny perfect hash function library.
+ * --------------------------------------------------------------------------
+ * Copyright (c) 2014-2015 William Ahern
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the
+ * "Software"), to deal in the Software without restriction, including
+ * without limitation the rights to use, copy, modify, merge, publish,
+ * distribute, sublicense, and/or sell copies of the Software, and to permit
+ * persons to whom the Software is furnished to do so, subject to the
+ * following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included
+ * in all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN
+ * NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+ * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
+ * USE OR OTHER DEALINGS IN THE SOFTWARE.
+ * ==========================================================================
+ */
+#ifndef PHF_H
+#define PHF_H
+
+#include <stddef.h> /* size_t */
+#include <stdint.h> /* UINT32_MAX uint32_t uint64_t */
+#include <stdbool.h> /* bool */
+#include <inttypes.h> /* PRIu32 PRIx32 */
+
+
+/*
+ * C O M P I L E R F E A T U R E S & D I A G N O S T I C S
+ *
+ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
+
+#define PHF_GNUC_PREREQ(M, m) (__GNUC__ > (M) || (__GNUC__ == (M) && __GNUC_MINOR__ >= (m)))
+
+#ifdef __clang__
+#define phf_has_extension(x) __has_extension(x)
+#define phf_has_attribute(x) __has_attribute(x)
+#else
+#define phf_has_extension(x) 0
+#define phf_has_attribute(x) 0
+#endif
+
+#ifndef PHF_HAVE_NOEXCEPT
+#define PHF_HAVE_NOEXCEPT \
+ (__cplusplus >= 201103L || \
+ phf_has_extension(cxx_noexcept) || \
+ PHF_GNUC_PREREQ(4, 6))
+#endif
+
+#ifndef PHF_HAVE_GENERIC
+#define PHF_HAVE_GENERIC \
+ (__STDC_VERSION__ >= 201112L || \
+ phf_has_extension(c_generic_selections) || \
+ PHF_GNUC_PREREQ(4, 9))
+#endif
+
+#ifndef PHF_HAVE_BUILTIN_TYPES_COMPATIBLE_P
+#define PHF_HAVE_BUILTIN_TYPES_COMPATIBLE_P (defined __GNUC__)
+#endif
+
+#ifndef PHF_HAVE_BUILTIN_CHOOSE_EXPR
+#define PHF_HAVE_BUILTIN_CHOOSE_EXPR (defined __GNUC__)
+#endif
+
+#ifndef PHF_HAVE_ATTRIBUTE_VISIBILITY
+#define PHF_HAVE_ATTRIBUTE_VISIBILITY \
+ (phf_has_attribute(visibility) || PHF_GNUC_PREREQ(4, 0))
+#endif
+
+#ifndef PHF_HAVE_COMPUTED_GOTOS
+#ifdef __GNUC__
+#define PHF_HAVE_COMPUTED_GOTOS 1
+#else
+#define PHF_HAVE_COMPUTED_GOTOS 0
+#endif
+#endif
+
+#ifdef __clang__
+#pragma clang diagnostic push
+#if __cplusplus < 201103L
+#pragma clang diagnostic ignored "-Wc++11-extensions"
+#pragma clang diagnostic ignored "-Wvariadic-macros"
+#endif
+#elif PHF_GNUC_PREREQ(4, 6)
+#pragma GCC diagnostic push
+#if __cplusplus < 201103L
+#pragma GCC diagnostic ignored "-Wpedantic"
+#pragma GCC diagnostic ignored "-Wvariadic-macros"
+#endif
+#endif
+
+
+/*
+ * C / C + + V I S I B I L I T Y
+ *
+ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
+
+#ifndef PHF_PUBLIC
+#define PHF_PUBLIC
+#endif
+
+#ifndef PHF_LOCAL
+#if PHF_HAVE_ATTRIBUTE_VISIBILITY
+#define PHF_LOCAL __attribute__((visibility("hidden")))
+#else
+#define PHF_LOCAL
+#endif
+#endif
+
+
+/*
+ * C / C + + S H A R E D T Y P E S
+ *
+ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
+
+#define phf_error_t int /* for documentation purposes */
+
+#define PHF_HASH_MAX UINT32_MAX
+#define PHF_PRIuHASH PRIu32
+#define PHF_PRIxHASH PRIx32
+
+typedef uint32_t phf_hash_t;
+typedef uint32_t phf_seed_t;
+
+typedef struct phf_string {
+ void *p;
+ size_t n;
+} phf_string_t;
+
+struct phf {
+ bool nodiv;
+
+ phf_seed_t seed;
+
+ size_t r; /* number of elements in g */
+ size_t m; /* number of elements in perfect hash */
+ uint32_t *g; /* displacement map indexed by g(k) % r */
+
+ size_t d_max; /* maximum displacement value in g */
+
+ enum {
+ PHF_G_UINT8_MOD_R = 1,
+ PHF_G_UINT8_BAND_R,
+ PHF_G_UINT16_MOD_R,
+ PHF_G_UINT16_BAND_R,
+ PHF_G_UINT32_MOD_R,
+ PHF_G_UINT32_BAND_R,
+ } g_op;
+
+ const void *g_jmp;
+}; /* struct phf */
+
+
+/*
+ * C + + I N T E R F A C E S
+ *
+ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
+#ifdef __cplusplus
+
+#if !PHF_NO_LIBCXX
+#include <string> /* std::string */
+#endif
+
+namespace PHF {
+ template<typename key_t>
+ PHF_PUBLIC size_t uniq(key_t[], const size_t);
+
+ template<typename key_t, bool nodiv>
+ PHF_PUBLIC phf_error_t init(struct phf *, const key_t[], const size_t, const size_t, const size_t, const phf_seed_t);
+
+ PHF_PUBLIC void compact(struct phf *);
+
+ template<typename key_t>
+ PHF_PUBLIC phf_hash_t hash(struct phf *, key_t);
+
+ PHF_PUBLIC void destroy(struct phf *);
+}
+
+extern template size_t PHF::uniq<uint32_t>(uint32_t[], const size_t);
+extern template size_t PHF::uniq<uint64_t>(uint64_t[], const size_t);
+extern template size_t PHF::uniq<phf_string_t>(phf_string_t[], const size_t);
+#if !PHF_NO_LIBCXX
+extern template size_t PHF::uniq<std::string>(std::string[], const size_t);
+#endif
+
+extern template phf_error_t PHF::init<uint32_t, true>(struct phf *, const uint32_t[], const size_t, const size_t, const size_t, const phf_seed_t);
+extern template phf_error_t PHF::init<uint64_t, true>(struct phf *, const uint64_t[], const size_t, const size_t, const size_t, const phf_seed_t);
+extern template phf_error_t PHF::init<phf_string_t, true>(struct phf *, const phf_string_t[], const size_t, const size_t, const size_t, const phf_seed_t);
+#if !PHF_NO_LIBCXX
+extern template phf_error_t PHF::init<std::string, true>(struct phf *, const std::string[], const size_t, const size_t, const size_t, const phf_seed_t);
+#endif
+
+extern template phf_error_t PHF::init<uint32_t, false>(struct phf *, const uint32_t[], const size_t, const size_t, const size_t, const phf_seed_t);
+extern template phf_error_t PHF::init<uint64_t, false>(struct phf *, const uint64_t[], const size_t, const size_t, const size_t, const phf_seed_t);
+extern template phf_error_t PHF::init<phf_string_t, false>(struct phf *, const phf_string_t[], const size_t, const size_t, const size_t, const phf_seed_t);
+#if !PHF_NO_LIBCXX
+extern template phf_error_t PHF::init<std::string, false>(struct phf *, const std::string[], const size_t, const size_t, const size_t, const phf_seed_t);
+#endif
+
+extern template phf_hash_t PHF::hash<uint32_t>(struct phf *, uint32_t);
+extern template phf_hash_t PHF::hash<uint64_t>(struct phf *, uint64_t);
+extern template phf_hash_t PHF::hash<phf_string_t>(struct phf *, phf_string_t);
+#if !PHF_NO_LIBCXX
+extern template phf_hash_t PHF::hash<std::string>(struct phf *, std::string);
+#endif
+
+#endif /* __cplusplus */
+
+
+/*
+ * C 8 9 I N T E R F A C E S
+ *
+ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+PHF_PUBLIC size_t phf_uniq_uint32(uint32_t *, const size_t);
+PHF_PUBLIC size_t phf_uniq_uint64(uint64_t *, const size_t);
+PHF_PUBLIC size_t phf_uniq_string(phf_string_t *, const size_t);
+
+PHF_PUBLIC phf_error_t phf_init_uint32(struct phf *, const uint32_t *, const size_t, const size_t, const size_t, const phf_seed_t, const bool nodiv);
+PHF_PUBLIC phf_error_t phf_init_uint64(struct phf *, const uint64_t *, const size_t, const size_t, const size_t, const phf_seed_t, const bool nodiv);
+PHF_PUBLIC phf_error_t phf_init_string(struct phf *, const phf_string_t *, const size_t, const size_t, const size_t, const phf_seed_t, const bool nodiv);
+
+PHF_PUBLIC void phf_compact(struct phf *);
+
+PHF_PUBLIC phf_hash_t phf_hash_uint32(struct phf *, const uint32_t);
+PHF_PUBLIC phf_hash_t phf_hash_uint64(struct phf *, const uint64_t);
+PHF_PUBLIC phf_hash_t phf_hash_string(struct phf *, const phf_string_t);
+
+PHF_PUBLIC void phf_destroy(struct phf *);
+
+#ifdef __cplusplus
+}
+#endif
+
+
+/*
+ * C 1 1 / G N U I N T E R F A C E S
+ *
+ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
+#if PHF_HAVE_GENERIC
+
+#define phf_uniq(k, n) _Generic(*(k), \
+ uint32_t: phf_uniq_uint32, \
+ uint64_t: phf_uniq_uint64, \
+ phf_string_t: phf_uniq_string)((k), (n))
+
+#define phf_init(f, k, ...) _Generic(*(k), \
+ uint32_t: phf_init_uint32, \
+ uint64_t: phf_init_uint64, \
+ phf_string_t: phf_init_string)((f), (k), __VA_ARGS__)
+
+#define phf_hash(f, k) _Generic((k), \
+ uint32_t: phf_hash_uint32, \
+ uint64_t: phf_hash_uint64, \
+ phf_string_t: phf_hash_string)((f), (k))
+
+#elif PHF_HAVE_BUILTIN_TYPES_COMPATIBLE_P && PHF_HAVE_BUILTIN_CHOOSE_EXPR
+
+#define phf_choose(cond, a, b) __builtin_choose_expr(cond, a, b)
+#define phf_istype(E, T) __builtin_types_compatible_p(__typeof__(E), T)
+
+#define phf_uniq(k, n) \
+ phf_choose(phf_istype(*(k), uint32_t), phf_uniq_uint32((uint32_t *)(k), (n)), \
+ phf_choose(phf_istype(*(k), uint64_t), phf_uniq_uint64((uint64_t *)(k), (n)), \
+ phf_choose(phf_istype(*(k), phf_string_t), phf_uniq_string((phf_string_t *)(k), (n)), \
+ (void)0)))
+
+#define phf_init(f, k, ...) \
+ phf_choose(phf_istype(*(k), uint32_t), phf_init_uint32((f), (const uint32_t *)(k), __VA_ARGS__), \
+ phf_choose(phf_istype(*(k), uint64_t), phf_init_uint64((f), (const uint64_t *)(k), __VA_ARGS__), \
+ phf_choose(phf_istype(*(k), phf_string_t), phf_init_string((f), (const phf_string_t *)(k), __VA_ARGS__), \
+ (void)0)))
+
+#define phf_hash(f, k) ((*(phf_hash_t (*)()) \
+ phf_choose(phf_istype((k), uint32_t), &phf_hash_uint32, \
+ phf_choose(phf_istype((k), uint64_t), &phf_hash_uint64, \
+ phf_choose(phf_istype((k), phf_string_t), &phf_hash_string, \
+ (void)0))))((f), (k)))
+
+#endif
+
+
+#ifdef __clang__
+#pragma clang diagnostic pop
+#elif PHF_GNUC_PREREQ(4, 6)
+#pragma GCC diagnostic pop
+#endif
+
+#endif /* PHF_H */
diff --git a/src/3rd_party/reduce_all.h b/src/3rd_party/reduce_all.h
index 1869d122..26397e37 100644
--- a/src/3rd_party/reduce_all.h
+++ b/src/3rd_party/reduce_all.h
@@ -1,321 +1,248 @@
-/*
- * Copyright 1993-2015 NVIDIA Corporation. All rights reserved.
+/* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
*
- * Please refer to the NVIDIA end user license agreement (EULA) associated
- * with this source code for terms and conditions that govern your use of
- * this software. Any use, reproduction, disclosure, or distribution of
- * this software and related documentation outside the terms of the EULA
- * is strictly prohibited.
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ * * Neither the name of NVIDIA CORPORATION nor the names of its
+ * contributors may be used to endorse or promote products derived
+ * from this software without specific prior written permission.
*
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+ * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+ * CONTRIBUTORS BE LIABLE FOR ANY DIRECINDIRECFunctor, T, AccTyf, pe, INCIDENTAL, SPECIAL,
+ * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+ * OF LIABILITY, WHETHER IN CONTRACSf, TRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
-#pragma once
-
-#include "tensors/tensor.h"
-#include <cuda_runtime.h>
+#include "functional/tmp.h"
+#include <cooperative_groups.h>
namespace marian {
-template <unsigned int blockSize>
-__device__ void
-reduceBlock(volatile float *sdata, float mySum, const unsigned int tid)
-{
- sdata[tid] = mySum;
- __syncthreads();
-
- // do reduction in shared mem
- if (blockSize >= 512)
- {
- if (tid < 256)
- {
- sdata[tid] = mySum = mySum + sdata[tid + 256];
- }
-
- __syncthreads();
- }
-
- if (blockSize >= 256)
- {
- if (tid < 128)
- {
- sdata[tid] = mySum = mySum + sdata[tid + 128];
- }
-
- __syncthreads();
- }
-
- if (blockSize >= 128)
- {
- if (tid < 64)
- {
- sdata[tid] = mySum = mySum + sdata[tid + 64];
- }
-
- __syncthreads();
- }
+namespace cg = cooperative_groups;
+
+// Utility class used to avoid linker errors with extern
+// unsized shared memory arrays with templated type
+template <class T>
+struct SharedMemory {
+ __device__ inline operator T *() {
+ extern __shared__ int __smem[];
+ return (T *)__smem;
+ }
+
+ __device__ inline operator const T *() const {
+ extern __shared__ int __smem[];
+ return (T *)__smem;
+ }
+};
+
+// specialize for double to avoid unaligned memory
+// access compile errors
+template <>
+struct SharedMemory<double> {
+ __device__ inline operator double *() {
+ extern __shared__ double __smem_d[];
+ return (double *)__smem_d;
+ }
+
+ __device__ inline operator const double *() const {
+ extern __shared__ double __smem_d[];
+ return (double *)__smem_d;
+ }
+};
- if (tid < 32)
- {
- if (blockSize >= 64)
- {
- sdata[tid] = mySum = mySum + sdata[tid + 32];
- }
-
- if (blockSize >= 32)
- {
- sdata[tid] = mySum = mySum + sdata[tid + 16];
- }
-
- if (blockSize >= 16)
- {
- sdata[tid] = mySum = mySum + sdata[tid + 8];
- }
-
- if (blockSize >= 8)
- {
- sdata[tid] = mySum = mySum + sdata[tid + 4];
- }
-
- if (blockSize >= 4)
- {
- sdata[tid] = mySum = mySum + sdata[tid + 2];
- }
-
- if (blockSize >= 2)
- {
- sdata[tid] = mySum = mySum + sdata[tid + 1];
- }
- }
-}
-template <unsigned int blockSize, bool nIsPow2, class Functor>
-__device__ void
-reduceBlocks(Functor f, float *g_idata, float *g_odata, unsigned int n)
-{
- extern __shared__ float sdata[];
-
- // perform first level of reduction,
- // reading from global memory, writing to shared memory
- unsigned int tid = threadIdx.x;
- unsigned int i = blockIdx.x*(blockSize*2) + threadIdx.x;
- unsigned int gridSize = blockSize*2*gridDim.x;
- float mySum = 0;
-
- // we reduce multiple elements per thread. The number is determined by the
- // number of active thread blocks (via gridDim). More blocks will result
- // in a larger gridSize and therefore fewer elements per thread
- while (i < n)
- {
- mySum += f(g_idata[i]);
-
- // ensure we don't read out of bounds -- this is optimized away for powerOf2 sized arrays
- if (nIsPow2 || i + blockSize < n)
- mySum += f(g_idata[i+blockSize]);
-
- i += gridSize;
+/*
+ This version adds multiple elements per thread sequentially. This reduces
+ the overall cost of the algorithm while keeping the work complexity O(n) and
+ the step complexity O(log n). (Brent's Theorem optimization)
+
+ Note, this kernel needs a minimum of 64*sizeof(T) bytes of shared memory.
+ In other words if blockSize <= 32, allocate 64*sizeof(T) bytes.
+ If blockSize > 32, allocate blockSize*sizeof(T) bytes.
+*/
+template <typename T, typename AccType, unsigned int blockSize, bool nIsPow2Greater1, size_t K, class Functor, class AggFunctor>
+__global__ void reduceSinglePass(Functor functor, AccType aggInit, AggFunctor aggFunctor, AccType scale,
+ const functional::Shape full,
+ functional::Tensor<AccType> out,
+ functional::Array<functional::Tensor<T>, K> ins) {
+ int n = full.elements();
+
+ // Handle to thread block group
+ cg::thread_block cta = cg::this_thread_block();
+ AccType *sdata = SharedMemory<AccType>();
+
+ // perform first level of reduction,
+ // reading from global memory, writing to shared memory
+ unsigned int tid = threadIdx.x;
+ unsigned int i = blockIdx.x * blockSize * 2 + threadIdx.x;
+ unsigned int gridSize = blockSize * 2 * gridDim.x;
+
+ AccType mySum = aggInit;
+
+ // we reduceSinglePass multiple elements per thread. The number is determined by the
+ // number of active thread blocks (via gridDim). More blocks will result
+ // in a larger gridSize and therefore fewer elements per thread
+ while (i < n) {
+ mySum = aggFunctor(mySum, functional::applyWithCast<AccType>(functor, ins, i));
+
+ // ensure we don't read out of bounds -- this is optimized away for powerOf2
+ // sized arrays
+ if (nIsPow2Greater1 || i + blockSize < n)
+ mySum = aggFunctor(mySum, functional::applyWithCast<AccType>(functor, ins, i + blockSize));
+
+ i += gridSize;
+ }
+
+ // each thread puts its local sum into shared memory
+ sdata[tid] = mySum;
+ cg::sync(cta);
+
+ // do reduction in shared mem
+ if ((blockSize >= 512) && (tid < 256)) {
+ sdata[tid] = mySum = aggFunctor(mySum, sdata[tid + 256]);
+ }
+
+ cg::sync(cta);
+
+ if ((blockSize >= 256) && (tid < 128)) {
+ sdata[tid] = mySum = aggFunctor(mySum, sdata[tid + 128]);
+ }
+
+ cg::sync(cta);
+
+ if ((blockSize >= 128) && (tid < 64)) {
+ sdata[tid] = mySum = aggFunctor(mySum, sdata[tid + 64]);
+ }
+
+ cg::sync(cta);
+
+ cg::thread_block_tile<32> tile32 = cg::tiled_partition<32>(cta);
+
+ if (cta.thread_rank() < 32) {
+ // Fetch final intermediate sum from 2nd warp
+ if (blockSize >= 64)
+ mySum = aggFunctor(mySum, sdata[tid + 32]);
+ // reduce final warp using shuffle
+ for (int offset = tile32.size() / 2; offset > 0; offset /= 2) {
+ mySum = aggFunctor(mySum, tile32.shfl_down(mySum, offset));
}
+ }
- // do reduction in shared mem
- reduceBlock<blockSize>(sdata, mySum, tid);
-
- // write result for this block to global mem
- if (tid == 0) g_odata[blockIdx.x] = sdata[0];
+ // write result for this block to global mem
+ if (cta.thread_rank() == 0)
+ out[blockIdx.x] = aggFunctor(out[blockIdx.x], mySum * scale); // aggFunctor?
}
-// Global variable used by reduceSinglePass to count how many blocks have finished
-__device__ unsigned int retirementCount = 0;
-
-cudaError_t setRetirementCount(int retCnt)
-{
- return cudaMemcpyToSymbol(retirementCount, &retCnt, sizeof(unsigned int), 0, cudaMemcpyHostToDevice);
+static inline bool isPow2Greater1(unsigned int x) { // is power of two but also larger than 1, otherwise an out-of-bounds read occurs
+ return x > 1 && ((x & (x - 1)) == 0);
}
-// This reduction kernel reduces an arbitrary size array in a single kernel invocation
-// It does so by keeping track of how many blocks have finished. After each thread
-// block completes the reduction of its own block of data, it "takes a ticket" by
-// atomically incrementing a global counter. If the ticket value is equal to the number
-// of thread blocks, then the block holding the ticket knows that it is the last block
-// to finish. This last block is responsible for summing the results of all the other
-// blocks.
-//
-// In order for this to work, we must be sure that before a block takes a ticket, all
-// of its memory transactions have completed. This is what __threadfence() does -- it
-// blocks until the results of all outstanding memory transactions within the
-// calling thread are visible to all other threads.
-//
-// For more details on the reduction algorithm (notably the multi-pass approach), see
-// the "reduction" sample in the CUDA SDK.
-
-template <unsigned int blockSize, bool nIsPow2, class Functor>
-__global__ void reduceSinglePass(Functor f, float *g_idata, float *g_odata, unsigned int n)
-{
-
- //
- // PHASE 1: Process all inputs assigned to this block
- //
-
- reduceBlocks<blockSize, nIsPow2>(f, g_idata, g_odata, n);
-
- //
- // PHASE 2: Last block finished will process all partial sums
- //
-
- if (gridDim.x > 1)
- {
- const unsigned int tid = threadIdx.x;
- __shared__ bool amLast;
- extern float __shared__ smem[];
-
- // wait until all outstanding memory instructions in this thread are finished
- __threadfence();
-
- // Thread 0 takes a ticket
- if (tid==0)
- {
- unsigned int ticket = atomicInc(&retirementCount, gridDim.x);
- // If the ticket ID is equal to the number of blocks, we are the last block!
- amLast = (ticket == gridDim.x-1);
- }
-
- __syncthreads();
-
- // The last block sums the results of all other blocks
- if (amLast)
- {
- int i = tid;
- float mySum = 0;
-
- while (i < gridDim.x)
- {
- mySum += g_odata[i];
- i += blockSize;
- }
-
- reduceBlock<blockSize>(smem, mySum, tid);
-
- if (tid==0)
- {
- g_odata[0] = smem[0];
-
- // reset retirement count so that next run succeeds
- retirementCount = 0;
- }
- }
- }
-}
-
-bool isPow2(unsigned int x)
-{
- return ((x&(x-1))==0);
+static inline unsigned int nextPow2(unsigned int x) {
+ --x;
+ x |= x >> 1;
+ x |= x >> 2;
+ x |= x >> 4;
+ x |= x >> 8;
+ x |= x >> 16;
+ return ++x;
}
-template <class Functor>
-void ReduceAll(Functor f, Tensor out, Tensor in)
-{
- cudaSetDevice(out->getDeviceId().no);
- int size = in->shape().elements();
- int threads = std::min(MAX_THREADS, size);
- int blocks = std::min(MAX_BLOCKS, size / threads + (size % threads != 0));
-
- dim3 dimBlock(threads, 1, 1);
- dim3 dimGrid(blocks, 1, 1);
- int smemSize = threads * sizeof(float);
-
- float* d_idata = in->data();
- float* d_odata = out->data();
-
- // choose which of the optimized versions of reduction to launch
- if (isPow2(size))
- {
- switch (threads)
- {
- case 512:
- reduceSinglePass<512, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 256:
- reduceSinglePass<256, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 128:
- reduceSinglePass<128, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 64:
- reduceSinglePass< 64, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 32:
- reduceSinglePass< 32, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 16:
- reduceSinglePass< 16, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 8:
- reduceSinglePass< 8, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 4:
- reduceSinglePass< 4, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 2:
- reduceSinglePass< 2, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 1:
- reduceSinglePass< 1, true><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
- }
+////////////////////////////////////////////////////////////////////////////////
+// Wrapper function for kernel launch
+////////////////////////////////////////////////////////////////////////////////
+template <typename T, typename AccType, size_t K, class Functor, class AggFunctor>
+void reduceSinglePass(Functor functor, AccType aggInit, AggFunctor aggFunctor, AccType scale,
+ const functional::Shape full,
+ functional::Tensor<AccType> out,
+ functional::Array<functional::Tensor<T>, K> ins,
+ int threads, int blocks) {
+ int size = full.elements();
+ // when there is only one warp per block, we need to allocate two warps
+ // worth of shared memory so that we don't index shared memory out of bounds
+ int smemSize = (threads <= 32) ? 2 * threads * sizeof(AccType) : threads * sizeof(AccType);
+ dim3 dimBlock(threads, 1, 1);
+ dim3 dimGrid(blocks, 1, 1);
+
+ if (isPow2Greater1(size)) {
+ switch (threads) {
+ case 512:
+ reduceSinglePass<T, AccType, 512, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 256:
+ reduceSinglePass<T, AccType, 256, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 128:
+ reduceSinglePass<T, AccType, 128, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 64:
+ reduceSinglePass<T, AccType, 64, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 32:
+ reduceSinglePass<T, AccType, 32, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 16:
+ reduceSinglePass<T, AccType, 16, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 8:
+ reduceSinglePass<T, AccType, 8, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 4:
+ reduceSinglePass<T, AccType, 4, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 2:
+ reduceSinglePass<T, AccType, 2, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 1:
+ reduceSinglePass<T, AccType, 1, true><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
}
- else
- {
- switch (threads)
- {
- case 512:
- reduceSinglePass<512, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 256:
- reduceSinglePass<256, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 128:
- reduceSinglePass<128, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 64:
- reduceSinglePass< 64, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 32:
- reduceSinglePass< 32, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 16:
- reduceSinglePass< 16, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 8:
- reduceSinglePass< 8, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 4:
- reduceSinglePass< 4, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 2:
- reduceSinglePass< 2, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
-
- case 1:
- reduceSinglePass< 1, false><<< dimGrid, dimBlock, smemSize >>>(f, d_idata, d_odata, size);
- break;
- }
+ } else {
+ switch (threads) {
+ case 512:
+ reduceSinglePass<T, AccType, 512, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 256:
+ reduceSinglePass<T, AccType, 256, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 128:
+ reduceSinglePass<T, AccType, 128, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 64:
+ reduceSinglePass<T, AccType, 64, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 32:
+ reduceSinglePass<T, AccType, 32, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 16:
+ reduceSinglePass<T, AccType, 16, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 8:
+ reduceSinglePass<T, AccType, 8, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 4:
+ reduceSinglePass<T, AccType, 4, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 2:
+ reduceSinglePass<T, AccType, 2, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
+ case 1:
+ reduceSinglePass<T, AccType, 1, false><<<dimGrid, dimBlock, smemSize>>>(functor, aggInit, aggFunctor, scale, full, out, ins);
+ break;
}
+ }
}
-}
+} \ No newline at end of file
diff --git a/src/3rd_party/sentencepiece b/src/3rd_party/sentencepiece
-Subproject 1a38d26a13cc67b1aae641d4983b624bef6d530
+Subproject 1d33bb67c3b6b2a51d3c9ffd55f37725801da39
diff --git a/src/3rd_party/spdlog/astyle.sh b/src/3rd_party/spdlog/astyle.sh
index a7a90510..a7a90510 100755..100644
--- a/src/3rd_party/spdlog/astyle.sh
+++ b/src/3rd_party/spdlog/astyle.sh
diff --git a/src/3rd_party/spdlog/bench/latency/compare.sh b/src/3rd_party/spdlog/bench/latency/compare.sh
index 0f0e4c97..0f0e4c97 100755..100644
--- a/src/3rd_party/spdlog/bench/latency/compare.sh
+++ b/src/3rd_party/spdlog/bench/latency/compare.sh
diff --git a/src/3rd_party/spdlog/details/format.h b/src/3rd_party/spdlog/details/format.h
index 4f2b5729..4f2b5729 100755..100644
--- a/src/3rd_party/spdlog/details/format.h
+++ b/src/3rd_party/spdlog/details/format.h
diff --git a/src/3rd_party/spdlog/details/logger_impl.h b/src/3rd_party/spdlog/details/logger_impl.h
index 428cd189..428cd189 100755..100644
--- a/src/3rd_party/spdlog/details/logger_impl.h
+++ b/src/3rd_party/spdlog/details/logger_impl.h
diff --git a/src/3rd_party/spdlog/logger.h b/src/3rd_party/spdlog/logger.h
index 41d51fbf..41d51fbf 100755..100644
--- a/src/3rd_party/spdlog/logger.h
+++ b/src/3rd_party/spdlog/logger.h
diff --git a/src/3rd_party/spdlog/tests/catch.hpp b/src/3rd_party/spdlog/tests/catch.hpp
index 3493d9cb..3493d9cb 100755..100644
--- a/src/3rd_party/spdlog/tests/catch.hpp
+++ b/src/3rd_party/spdlog/tests/catch.hpp
diff --git a/src/3rd_party/spdlog/tests/install_libcxx.sh b/src/3rd_party/spdlog/tests/install_libcxx.sh
index cee97692..cee97692 100755..100644
--- a/src/3rd_party/spdlog/tests/install_libcxx.sh
+++ b/src/3rd_party/spdlog/tests/install_libcxx.sh
diff --git a/src/3rd_party/sse_mathfun.h b/src/3rd_party/sse_mathfun.h
new file mode 100644
index 00000000..a3951805
--- /dev/null
+++ b/src/3rd_party/sse_mathfun.h
@@ -0,0 +1,711 @@
+/* SIMD (SSE1+MMX or SSE2) implementation of sin, cos, exp and log
+
+ Inspired by Intel Approximate Math library, and based on the
+ corresponding algorithms of the cephes math library
+
+ The default is to use the SSE1 version. If you define USE_SSE2 the
+ the SSE2 intrinsics will be used in place of the MMX intrinsics. Do
+ not expect any significant performance improvement with SSE2.
+*/
+
+/* Copyright (C) 2007 Julien Pommier
+
+ This software is provided 'as-is', without any express or implied
+ warranty. In no event will the authors be held liable for any damages
+ arising from the use of this software.
+
+ Permission is granted to anyone to use this software for any purpose,
+ including commercial applications, and to alter it and redistribute it
+ freely, subject to the following restrictions:
+
+ 1. The origin of this software must not be misrepresented; you must not
+ claim that you wrote the original software. If you use this software
+ in a product, an acknowledgment in the product documentation would be
+ appreciated but is not required.
+ 2. Altered source versions must be plainly marked as such, and must not be
+ misrepresented as being the original software.
+ 3. This notice may not be removed or altered from any source distribution.
+
+ (this is the zlib license)
+*/
+
+#include <xmmintrin.h>
+
+/* yes I know, the top of this file is quite ugly */
+
+#ifdef _MSC_VER /* visual c++ */
+# define ALIGN16_BEG __declspec(align(16))
+# define ALIGN16_END
+#else /* gcc or icc */
+# define ALIGN16_BEG
+# define ALIGN16_END __attribute__((aligned(16)))
+#endif
+
+/* __m128 is ugly to write */
+typedef __m128 v4sf; // vector of 4 float (sse1)
+
+#ifdef USE_SSE2
+# include <emmintrin.h>
+typedef __m128i v4si; // vector of 4 int (sse2)
+#else
+typedef __m64 v2si; // vector of 2 int (mmx)
+#endif
+
+/* declare some SSE constants -- why can't I figure a better way to do that? */
+#define _PS_CONST(Name, Val) \
+ static const ALIGN16_BEG float _ps_##Name[4] ALIGN16_END = { (float)Val, (float)Val, (float)Val, (float)Val }
+#define _PI32_CONST(Name, Val) \
+ static const ALIGN16_BEG int _pi32_##Name[4] ALIGN16_END = { Val, Val, Val, Val }
+#define _PS_CONST_TYPE(Name, Type, Val) \
+ static const ALIGN16_BEG Type _ps_##Name[4] ALIGN16_END = { Val, Val, Val, Val }
+
+_PS_CONST(1 , 1.0f);
+_PS_CONST(0p5, 0.5f);
+/* the smallest non denormalized float number */
+_PS_CONST_TYPE(min_norm_pos, int, 0x00800000);
+_PS_CONST_TYPE(mant_mask, int, 0x7f800000);
+_PS_CONST_TYPE(inv_mant_mask, int, ~0x7f800000);
+
+_PS_CONST_TYPE(sign_mask, int, (int)0x80000000);
+_PS_CONST_TYPE(inv_sign_mask, int, ~0x80000000);
+
+_PI32_CONST(1, 1);
+_PI32_CONST(inv1, ~1);
+_PI32_CONST(2, 2);
+_PI32_CONST(4, 4);
+_PI32_CONST(0x7f, 0x7f);
+
+_PS_CONST(cephes_SQRTHF, 0.707106781186547524);
+_PS_CONST(cephes_log_p0, 7.0376836292E-2);
+_PS_CONST(cephes_log_p1, - 1.1514610310E-1);
+_PS_CONST(cephes_log_p2, 1.1676998740E-1);
+_PS_CONST(cephes_log_p3, - 1.2420140846E-1);
+_PS_CONST(cephes_log_p4, + 1.4249322787E-1);
+_PS_CONST(cephes_log_p5, - 1.6668057665E-1);
+_PS_CONST(cephes_log_p6, + 2.0000714765E-1);
+_PS_CONST(cephes_log_p7, - 2.4999993993E-1);
+_PS_CONST(cephes_log_p8, + 3.3333331174E-1);
+_PS_CONST(cephes_log_q1, -2.12194440e-4);
+_PS_CONST(cephes_log_q2, 0.693359375);
+
+#ifndef USE_SSE2
+typedef union xmm_mm_union {
+ __m128 xmm;
+ __m64 mm[2];
+} xmm_mm_union;
+
+#define COPY_XMM_TO_MM(xmm_, mm0_, mm1_) { \
+ xmm_mm_union u; u.xmm = xmm_; \
+ mm0_ = u.mm[0]; \
+ mm1_ = u.mm[1]; \
+}
+
+#define COPY_MM_TO_XMM(mm0_, mm1_, xmm_) { \
+ xmm_mm_union u; u.mm[0]=mm0_; u.mm[1]=mm1_; xmm_ = u.xmm; \
+ }
+
+#endif // USE_SSE2
+
+/* natural logarithm computed for 4 simultaneous float
+ return NaN for x <= 0
+*/
+static inline v4sf log_ps(v4sf x) {
+#ifdef USE_SSE2
+ v4si emm0;
+#else
+ v2si mm0, mm1;
+#endif
+ v4sf one = *(v4sf*)_ps_1;
+
+ v4sf invalid_mask = _mm_cmple_ps(x, _mm_setzero_ps());
+
+ x = _mm_max_ps(x, *(v4sf*)_ps_min_norm_pos); /* cut off denormalized stuff */
+
+#ifndef USE_SSE2
+ /* part 1: x = frexpf(x, &e); */
+ COPY_XMM_TO_MM(x, mm0, mm1);
+ mm0 = _mm_srli_pi32(mm0, 23);
+ mm1 = _mm_srli_pi32(mm1, 23);
+#else
+ emm0 = _mm_srli_epi32(_mm_castps_si128(x), 23);
+#endif
+ /* keep only the fractional part */
+ x = _mm_and_ps(x, *(v4sf*)_ps_inv_mant_mask);
+ x = _mm_or_ps(x, *(v4sf*)_ps_0p5);
+
+#ifndef USE_SSE2
+ /* now e=mm0:mm1 contain the really base-2 exponent */
+ mm0 = _mm_sub_pi32(mm0, *(v2si*)_pi32_0x7f);
+ mm1 = _mm_sub_pi32(mm1, *(v2si*)_pi32_0x7f);
+ v4sf e = _mm_cvtpi32x2_ps(mm0, mm1);
+ _mm_empty(); /* bye bye mmx */
+#else
+ emm0 = _mm_sub_epi32(emm0, *(v4si*)_pi32_0x7f);
+ v4sf e = _mm_cvtepi32_ps(emm0);
+#endif
+
+ e = _mm_add_ps(e, one);
+
+ /* part2:
+ if( x < SQRTHF ) {
+ e -= 1;
+ x = x + x - 1.0;
+ } else { x = x - 1.0; }
+ */
+ v4sf mask = _mm_cmplt_ps(x, *(v4sf*)_ps_cephes_SQRTHF);
+ v4sf tmp = _mm_and_ps(x, mask);
+ x = _mm_sub_ps(x, one);
+ e = _mm_sub_ps(e, _mm_and_ps(one, mask));
+ x = _mm_add_ps(x, tmp);
+
+
+ v4sf z = _mm_mul_ps(x,x);
+
+ v4sf y = *(v4sf*)_ps_cephes_log_p0;
+ y = _mm_mul_ps(y, x);
+ y = _mm_add_ps(y, *(v4sf*)_ps_cephes_log_p1);
+ y = _mm_mul_ps(y, x);
+ y = _mm_add_ps(y, *(v4sf*)_ps_cephes_log_p2);
+ y = _mm_mul_ps(y, x);
+ y = _mm_add_ps(y, *(v4sf*)_ps_cephes_log_p3);
+ y = _mm_mul_ps(y, x);
+ y = _mm_add_ps(y, *(v4sf*)_ps_cephes_log_p4);
+ y = _mm_mul_ps(y, x);
+ y = _mm_add_ps(y, *(v4sf*)_ps_cephes_log_p5);
+ y = _mm_mul_ps(y, x);
+ y = _mm_add_ps(y, *(v4sf*)_ps_cephes_log_p6);
+ y = _mm_mul_ps(y, x);
+ y = _mm_add_ps(y, *(v4sf*)_ps_cephes_log_p7);
+ y = _mm_mul_ps(y, x);
+ y = _mm_add_ps(y, *(v4sf*)_ps_cephes_log_p8);
+ y = _mm_mul_ps(y, x);
+
+ y = _mm_mul_ps(y, z);
+
+
+ tmp = _mm_mul_ps(e, *(v4sf*)_ps_cephes_log_q1);
+ y = _mm_add_ps(y, tmp);
+
+
+ tmp = _mm_mul_ps(z, *(v4sf*)_ps_0p5);
+ y = _mm_sub_ps(y, tmp);
+
+ tmp = _mm_mul_ps(e, *(v4sf*)_ps_cephes_log_q2);
+ x = _mm_add_ps(x, y);
+ x = _mm_add_ps(x, tmp);
+ x = _mm_or_ps(x, invalid_mask); // negative arg will be NAN
+ return x;
+}
+
+_PS_CONST(exp_hi, 88.3762626647949f);
+_PS_CONST(exp_lo, -88.3762626647949f);
+
+_PS_CONST(cephes_LOG2EF, 1.44269504088896341);
+_PS_CONST(cephes_exp_C1, 0.693359375);
+_PS_CONST(cephes_exp_C2, -2.12194440e-4);
+
+_PS_CONST(cephes_exp_p0, 1.9875691500E-4);
+_PS_CONST(cephes_exp_p1, 1.3981999507E-3);
+_PS_CONST(cephes_exp_p2, 8.3334519073E-3);
+_PS_CONST(cephes_exp_p3, 4.1665795894E-2);
+_PS_CONST(cephes_exp_p4, 1.6666665459E-1);
+_PS_CONST(cephes_exp_p5, 5.0000001201E-1);
+
+static inline v4sf exp_ps(v4sf x) {
+ v4sf tmp = _mm_setzero_ps(), fx;
+#ifdef USE_SSE2
+ v4si emm0;
+#else
+ v2si mm0, mm1;
+#endif
+ v4sf one = *(v4sf*)_ps_1;
+
+ x = _mm_min_ps(x, *(v4sf*)_ps_exp_hi);
+ x = _mm_max_ps(x, *(v4sf*)_ps_exp_lo);
+
+ /* express exp(x) as exp(g + n*log(2)) */
+ fx = _mm_mul_ps(x, *(v4sf*)_ps_cephes_LOG2EF);
+ fx = _mm_add_ps(fx, *(v4sf*)_ps_0p5);
+
+ /* how to perform a floorf with SSE: just below */
+#ifndef USE_SSE2
+ /* step 1 : cast to int */
+ tmp = _mm_movehl_ps(tmp, fx);
+ mm0 = _mm_cvttps_pi32(fx);
+ mm1 = _mm_cvttps_pi32(tmp);
+ /* step 2 : cast back to float */
+ tmp = _mm_cvtpi32x2_ps(mm0, mm1);
+#else
+ emm0 = _mm_cvttps_epi32(fx);
+ tmp = _mm_cvtepi32_ps(emm0);
+#endif
+ /* if greater, substract 1 */
+ v4sf mask = _mm_cmpgt_ps(tmp, fx);
+ mask = _mm_and_ps(mask, one);
+ fx = _mm_sub_ps(tmp, mask);
+
+ tmp = _mm_mul_ps(fx, *(v4sf*)_ps_cephes_exp_C1);
+ v4sf z = _mm_mul_ps(fx, *(v4sf*)_ps_cephes_exp_C2);
+ x = _mm_sub_ps(x, tmp);
+ x = _mm_sub_ps(x, z);
+
+ z = _mm_mul_ps(x,x);
+
+ v4sf y = *(v4sf*)_ps_cephes_exp_p0;
+ y = _mm_mul_ps(y, x);
+ y = _mm_add_ps(y, *(v4sf*)_ps_cephes_exp_p1);
+ y = _mm_mul_ps(y, x);
+ y = _mm_add_ps(y, *(v4sf*)_ps_cephes_exp_p2);
+ y = _mm_mul_ps(y, x);
+ y = _mm_add_ps(y, *(v4sf*)_ps_cephes_exp_p3);
+ y = _mm_mul_ps(y, x);
+ y = _mm_add_ps(y, *(v4sf*)_ps_cephes_exp_p4);
+ y = _mm_mul_ps(y, x);
+ y = _mm_add_ps(y, *(v4sf*)_ps_cephes_exp_p5);
+ y = _mm_mul_ps(y, z);
+ y = _mm_add_ps(y, x);
+ y = _mm_add_ps(y, one);
+
+ /* build 2^n */
+#ifndef USE_SSE2
+ z = _mm_movehl_ps(z, fx);
+ mm0 = _mm_cvttps_pi32(fx);
+ mm1 = _mm_cvttps_pi32(z);
+ mm0 = _mm_add_pi32(mm0, *(v2si*)_pi32_0x7f);
+ mm1 = _mm_add_pi32(mm1, *(v2si*)_pi32_0x7f);
+ mm0 = _mm_slli_pi32(mm0, 23);
+ mm1 = _mm_slli_pi32(mm1, 23);
+
+ v4sf pow2n;
+ COPY_MM_TO_XMM(mm0, mm1, pow2n);
+ _mm_empty();
+#else
+ emm0 = _mm_cvttps_epi32(fx);
+ emm0 = _mm_add_epi32(emm0, *(v4si*)_pi32_0x7f);
+ emm0 = _mm_slli_epi32(emm0, 23);
+ v4sf pow2n = _mm_castsi128_ps(emm0);
+#endif
+ y = _mm_mul_ps(y, pow2n);
+ return y;
+}
+
+_PS_CONST(minus_cephes_DP1, -0.78515625);
+_PS_CONST(minus_cephes_DP2, -2.4187564849853515625e-4);
+_PS_CONST(minus_cephes_DP3, -3.77489497744594108e-8);
+_PS_CONST(sincof_p0, -1.9515295891E-4);
+_PS_CONST(sincof_p1, 8.3321608736E-3);
+_PS_CONST(sincof_p2, -1.6666654611E-1);
+_PS_CONST(coscof_p0, 2.443315711809948E-005);
+_PS_CONST(coscof_p1, -1.388731625493765E-003);
+_PS_CONST(coscof_p2, 4.166664568298827E-002);
+_PS_CONST(cephes_FOPI, 1.27323954473516); // 4 / M_PI
+
+
+/* evaluation of 4 sines at onces, using only SSE1+MMX intrinsics so
+ it runs also on old athlons XPs and the pentium III of your grand
+ mother.
+
+ The code is the exact rewriting of the cephes sinf function.
+ Precision is excellent as long as x < 8192 (I did not bother to
+ take into account the special handling they have for greater values
+ -- it does not return garbage for arguments over 8192, though, but
+ the extra precision is missing).
+
+ Note that it is such that sinf((float)M_PI) = 8.74e-8, which is the
+ surprising but correct result.
+
+ Performance is also surprisingly good, 1.33 times faster than the
+ macos vsinf SSE2 function, and 1.5 times faster than the
+ __vrs4_sinf of amd's ACML (which is only available in 64 bits). Not
+ too bad for an SSE1 function (with no special tuning) !
+ However the latter libraries probably have a much better handling of NaN,
+ Inf, denormalized and other special arguments..
+
+ On my core 1 duo, the execution of this function takes approximately 95 cycles.
+
+ From what I have observed on the experiments with Intel AMath lib, switching to an
+ SSE2 version would improve the perf by only 10%.
+
+ Since it is based on SSE intrinsics, it has to be compiled at -O2 to
+ deliver full speed.
+*/
+static inline v4sf sin_ps(v4sf x) { // any x
+ v4sf xmm1, xmm2 = _mm_setzero_ps(), xmm3, sign_bit, y;
+
+#ifdef USE_SSE2
+ v4si emm0, emm2;
+#else
+ v2si mm0, mm1, mm2, mm3;
+#endif
+ sign_bit = x;
+ /* take the absolute value */
+ x = _mm_and_ps(x, *(v4sf*)_ps_inv_sign_mask);
+ /* extract the sign bit (upper one) */
+ sign_bit = _mm_and_ps(sign_bit, *(v4sf*)_ps_sign_mask);
+
+ /* scale by 4/Pi */
+ y = _mm_mul_ps(x, *(v4sf*)_ps_cephes_FOPI);
+
+#ifdef USE_SSE2
+ /* store the integer part of y in mm0 */
+ emm2 = _mm_cvttps_epi32(y);
+ /* j=(j+1) & (~1) (see the cephes sources) */
+ emm2 = _mm_add_epi32(emm2, *(v4si*)_pi32_1);
+ emm2 = _mm_and_si128(emm2, *(v4si*)_pi32_inv1);
+ y = _mm_cvtepi32_ps(emm2);
+
+ /* get the swap sign flag */
+ emm0 = _mm_and_si128(emm2, *(v4si*)_pi32_4);
+ emm0 = _mm_slli_epi32(emm0, 29);
+ /* get the polynom selection mask
+ there is one polynom for 0 <= x <= Pi/4
+ and another one for Pi/4<x<=Pi/2
+
+ Both branches will be computed.
+ */
+ emm2 = _mm_and_si128(emm2, *(v4si*)_pi32_2);
+ emm2 = _mm_cmpeq_epi32(emm2, _mm_setzero_si128());
+
+ v4sf swap_sign_bit = _mm_castsi128_ps(emm0);
+ v4sf poly_mask = _mm_castsi128_ps(emm2);
+ sign_bit = _mm_xor_ps(sign_bit, swap_sign_bit);
+
+#else
+ /* store the integer part of y in mm0:mm1 */
+ xmm2 = _mm_movehl_ps(xmm2, y);
+ mm2 = _mm_cvttps_pi32(y);
+ mm3 = _mm_cvttps_pi32(xmm2);
+ /* j=(j+1) & (~1) (see the cephes sources) */
+ mm2 = _mm_add_pi32(mm2, *(v2si*)_pi32_1);
+ mm3 = _mm_add_pi32(mm3, *(v2si*)_pi32_1);
+ mm2 = _mm_and_si64(mm2, *(v2si*)_pi32_inv1);
+ mm3 = _mm_and_si64(mm3, *(v2si*)_pi32_inv1);
+ y = _mm_cvtpi32x2_ps(mm2, mm3);
+ /* get the swap sign flag */
+ mm0 = _mm_and_si64(mm2, *(v2si*)_pi32_4);
+ mm1 = _mm_and_si64(mm3, *(v2si*)_pi32_4);
+ mm0 = _mm_slli_pi32(mm0, 29);
+ mm1 = _mm_slli_pi32(mm1, 29);
+ /* get the polynom selection mask */
+ mm2 = _mm_and_si64(mm2, *(v2si*)_pi32_2);
+ mm3 = _mm_and_si64(mm3, *(v2si*)_pi32_2);
+ mm2 = _mm_cmpeq_pi32(mm2, _mm_setzero_si64());
+ mm3 = _mm_cmpeq_pi32(mm3, _mm_setzero_si64());
+ v4sf swap_sign_bit, poly_mask;
+ COPY_MM_TO_XMM(mm0, mm1, swap_sign_bit);
+ COPY_MM_TO_XMM(mm2, mm3, poly_mask);
+ sign_bit = _mm_xor_ps(sign_bit, swap_sign_bit);
+ _mm_empty(); /* good-bye mmx */
+#endif
+
+ /* The magic pass: "Extended precision modular arithmetic"
+ x = ((x - y * DP1) - y * DP2) - y * DP3; */
+ xmm1 = *(v4sf*)_ps_minus_cephes_DP1;
+ xmm2 = *(v4sf*)_ps_minus_cephes_DP2;
+ xmm3 = *(v4sf*)_ps_minus_cephes_DP3;
+ xmm1 = _mm_mul_ps(y, xmm1);
+ xmm2 = _mm_mul_ps(y, xmm2);
+ xmm3 = _mm_mul_ps(y, xmm3);
+ x = _mm_add_ps(x, xmm1);
+ x = _mm_add_ps(x, xmm2);
+ x = _mm_add_ps(x, xmm3);
+
+ /* Evaluate the first polynom (0 <= x <= Pi/4) */
+ y = *(v4sf*)_ps_coscof_p0;
+ v4sf z = _mm_mul_ps(x,x);
+
+ y = _mm_mul_ps(y, z);
+ y = _mm_add_ps(y, *(v4sf*)_ps_coscof_p1);
+ y = _mm_mul_ps(y, z);
+ y = _mm_add_ps(y, *(v4sf*)_ps_coscof_p2);
+ y = _mm_mul_ps(y, z);
+ y = _mm_mul_ps(y, z);
+ v4sf tmp = _mm_mul_ps(z, *(v4sf*)_ps_0p5);
+ y = _mm_sub_ps(y, tmp);
+ y = _mm_add_ps(y, *(v4sf*)_ps_1);
+
+ /* Evaluate the second polynom (Pi/4 <= x <= 0) */
+
+ v4sf y2 = *(v4sf*)_ps_sincof_p0;
+ y2 = _mm_mul_ps(y2, z);
+ y2 = _mm_add_ps(y2, *(v4sf*)_ps_sincof_p1);
+ y2 = _mm_mul_ps(y2, z);
+ y2 = _mm_add_ps(y2, *(v4sf*)_ps_sincof_p2);
+ y2 = _mm_mul_ps(y2, z);
+ y2 = _mm_mul_ps(y2, x);
+ y2 = _mm_add_ps(y2, x);
+
+ /* select the correct result from the two polynoms */
+ xmm3 = poly_mask;
+ y2 = _mm_and_ps(xmm3, y2); //, xmm3);
+ y = _mm_andnot_ps(xmm3, y);
+ y = _mm_add_ps(y,y2);
+ /* update the sign */
+ y = _mm_xor_ps(y, sign_bit);
+ return y;
+}
+
+/* almost the same as sin_ps */
+static inline v4sf cos_ps(v4sf x) { // any x
+ v4sf xmm1, xmm2 = _mm_setzero_ps(), xmm3, y;
+#ifdef USE_SSE2
+ v4si emm0, emm2;
+#else
+ v2si mm0, mm1, mm2, mm3;
+#endif
+ /* take the absolute value */
+ x = _mm_and_ps(x, *(v4sf*)_ps_inv_sign_mask);
+
+ /* scale by 4/Pi */
+ y = _mm_mul_ps(x, *(v4sf*)_ps_cephes_FOPI);
+
+#ifdef USE_SSE2
+ /* store the integer part of y in mm0 */
+ emm2 = _mm_cvttps_epi32(y);
+ /* j=(j+1) & (~1) (see the cephes sources) */
+ emm2 = _mm_add_epi32(emm2, *(v4si*)_pi32_1);
+ emm2 = _mm_and_si128(emm2, *(v4si*)_pi32_inv1);
+ y = _mm_cvtepi32_ps(emm2);
+
+ emm2 = _mm_sub_epi32(emm2, *(v4si*)_pi32_2);
+
+ /* get the swap sign flag */
+ emm0 = _mm_andnot_si128(emm2, *(v4si*)_pi32_4);
+ emm0 = _mm_slli_epi32(emm0, 29);
+ /* get the polynom selection mask */
+ emm2 = _mm_and_si128(emm2, *(v4si*)_pi32_2);
+ emm2 = _mm_cmpeq_epi32(emm2, _mm_setzero_si128());
+
+ v4sf sign_bit = _mm_castsi128_ps(emm0);
+ v4sf poly_mask = _mm_castsi128_ps(emm2);
+#else
+ /* store the integer part of y in mm0:mm1 */
+ xmm2 = _mm_movehl_ps(xmm2, y);
+ mm2 = _mm_cvttps_pi32(y);
+ mm3 = _mm_cvttps_pi32(xmm2);
+
+ /* j=(j+1) & (~1) (see the cephes sources) */
+ mm2 = _mm_add_pi32(mm2, *(v2si*)_pi32_1);
+ mm3 = _mm_add_pi32(mm3, *(v2si*)_pi32_1);
+ mm2 = _mm_and_si64(mm2, *(v2si*)_pi32_inv1);
+ mm3 = _mm_and_si64(mm3, *(v2si*)_pi32_inv1);
+
+ y = _mm_cvtpi32x2_ps(mm2, mm3);
+
+
+ mm2 = _mm_sub_pi32(mm2, *(v2si*)_pi32_2);
+ mm3 = _mm_sub_pi32(mm3, *(v2si*)_pi32_2);
+
+ /* get the swap sign flag in mm0:mm1 and the
+ polynom selection mask in mm2:mm3 */
+
+ mm0 = _mm_andnot_si64(mm2, *(v2si*)_pi32_4);
+ mm1 = _mm_andnot_si64(mm3, *(v2si*)_pi32_4);
+ mm0 = _mm_slli_pi32(mm0, 29);
+ mm1 = _mm_slli_pi32(mm1, 29);
+
+ mm2 = _mm_and_si64(mm2, *(v2si*)_pi32_2);
+ mm3 = _mm_and_si64(mm3, *(v2si*)_pi32_2);
+
+ mm2 = _mm_cmpeq_pi32(mm2, _mm_setzero_si64());
+ mm3 = _mm_cmpeq_pi32(mm3, _mm_setzero_si64());
+
+ v4sf sign_bit, poly_mask;
+ COPY_MM_TO_XMM(mm0, mm1, sign_bit);
+ COPY_MM_TO_XMM(mm2, mm3, poly_mask);
+ _mm_empty(); /* good-bye mmx */
+#endif
+ /* The magic pass: "Extended precision modular arithmetic"
+ x = ((x - y * DP1) - y * DP2) - y * DP3; */
+ xmm1 = *(v4sf*)_ps_minus_cephes_DP1;
+ xmm2 = *(v4sf*)_ps_minus_cephes_DP2;
+ xmm3 = *(v4sf*)_ps_minus_cephes_DP3;
+ xmm1 = _mm_mul_ps(y, xmm1);
+ xmm2 = _mm_mul_ps(y, xmm2);
+ xmm3 = _mm_mul_ps(y, xmm3);
+ x = _mm_add_ps(x, xmm1);
+ x = _mm_add_ps(x, xmm2);
+ x = _mm_add_ps(x, xmm3);
+
+ /* Evaluate the first polynom (0 <= x <= Pi/4) */
+ y = *(v4sf*)_ps_coscof_p0;
+ v4sf z = _mm_mul_ps(x,x);
+
+ y = _mm_mul_ps(y, z);
+ y = _mm_add_ps(y, *(v4sf*)_ps_coscof_p1);
+ y = _mm_mul_ps(y, z);
+ y = _mm_add_ps(y, *(v4sf*)_ps_coscof_p2);
+ y = _mm_mul_ps(y, z);
+ y = _mm_mul_ps(y, z);
+ v4sf tmp = _mm_mul_ps(z, *(v4sf*)_ps_0p5);
+ y = _mm_sub_ps(y, tmp);
+ y = _mm_add_ps(y, *(v4sf*)_ps_1);
+
+ /* Evaluate the second polynom (Pi/4 <= x <= 0) */
+
+ v4sf y2 = *(v4sf*)_ps_sincof_p0;
+ y2 = _mm_mul_ps(y2, z);
+ y2 = _mm_add_ps(y2, *(v4sf*)_ps_sincof_p1);
+ y2 = _mm_mul_ps(y2, z);
+ y2 = _mm_add_ps(y2, *(v4sf*)_ps_sincof_p2);
+ y2 = _mm_mul_ps(y2, z);
+ y2 = _mm_mul_ps(y2, x);
+ y2 = _mm_add_ps(y2, x);
+
+ /* select the correct result from the two polynoms */
+ xmm3 = poly_mask;
+ y2 = _mm_and_ps(xmm3, y2); //, xmm3);
+ y = _mm_andnot_ps(xmm3, y);
+ y = _mm_add_ps(y,y2);
+ /* update the sign */
+ y = _mm_xor_ps(y, sign_bit);
+
+ return y;
+}
+
+/* since sin_ps and cos_ps are almost identical, sincos_ps could replace both of them..
+ it is almost as fast, and gives you a free cosine with your sine */
+static inline void sincos_ps(v4sf x, v4sf *s, v4sf *c) {
+ v4sf xmm1, xmm2, xmm3 = _mm_setzero_ps(), sign_bit_sin, y;
+#ifdef USE_SSE2
+ v4si emm0, emm2, emm4;
+#else
+ v2si mm0, mm1, mm2, mm3, mm4, mm5;
+#endif
+ sign_bit_sin = x;
+ /* take the absolute value */
+ x = _mm_and_ps(x, *(v4sf*)_ps_inv_sign_mask);
+ /* extract the sign bit (upper one) */
+ sign_bit_sin = _mm_and_ps(sign_bit_sin, *(v4sf*)_ps_sign_mask);
+
+ /* scale by 4/Pi */
+ y = _mm_mul_ps(x, *(v4sf*)_ps_cephes_FOPI);
+
+#ifdef USE_SSE2
+ /* store the integer part of y in emm2 */
+ emm2 = _mm_cvttps_epi32(y);
+
+ /* j=(j+1) & (~1) (see the cephes sources) */
+ emm2 = _mm_add_epi32(emm2, *(v4si*)_pi32_1);
+ emm2 = _mm_and_si128(emm2, *(v4si*)_pi32_inv1);
+ y = _mm_cvtepi32_ps(emm2);
+
+ emm4 = emm2;
+
+ /* get the swap sign flag for the sine */
+ emm0 = _mm_and_si128(emm2, *(v4si*)_pi32_4);
+ emm0 = _mm_slli_epi32(emm0, 29);
+ v4sf swap_sign_bit_sin = _mm_castsi128_ps(emm0);
+
+ /* get the polynom selection mask for the sine*/
+ emm2 = _mm_and_si128(emm2, *(v4si*)_pi32_2);
+ emm2 = _mm_cmpeq_epi32(emm2, _mm_setzero_si128());
+ v4sf poly_mask = _mm_castsi128_ps(emm2);
+#else
+ /* store the integer part of y in mm2:mm3 */
+ xmm3 = _mm_movehl_ps(xmm3, y);
+ mm2 = _mm_cvttps_pi32(y);
+ mm3 = _mm_cvttps_pi32(xmm3);
+
+ /* j=(j+1) & (~1) (see the cephes sources) */
+ mm2 = _mm_add_pi32(mm2, *(v2si*)_pi32_1);
+ mm3 = _mm_add_pi32(mm3, *(v2si*)_pi32_1);
+ mm2 = _mm_and_si64(mm2, *(v2si*)_pi32_inv1);
+ mm3 = _mm_and_si64(mm3, *(v2si*)_pi32_inv1);
+
+ y = _mm_cvtpi32x2_ps(mm2, mm3);
+
+ mm4 = mm2;
+ mm5 = mm3;
+
+ /* get the swap sign flag for the sine */
+ mm0 = _mm_and_si64(mm2, *(v2si*)_pi32_4);
+ mm1 = _mm_and_si64(mm3, *(v2si*)_pi32_4);
+ mm0 = _mm_slli_pi32(mm0, 29);
+ mm1 = _mm_slli_pi32(mm1, 29);
+ v4sf swap_sign_bit_sin;
+ COPY_MM_TO_XMM(mm0, mm1, swap_sign_bit_sin);
+
+ /* get the polynom selection mask for the sine */
+
+ mm2 = _mm_and_si64(mm2, *(v2si*)_pi32_2);
+ mm3 = _mm_and_si64(mm3, *(v2si*)_pi32_2);
+ mm2 = _mm_cmpeq_pi32(mm2, _mm_setzero_si64());
+ mm3 = _mm_cmpeq_pi32(mm3, _mm_setzero_si64());
+ v4sf poly_mask;
+ COPY_MM_TO_XMM(mm2, mm3, poly_mask);
+#endif
+
+ /* The magic pass: "Extended precision modular arithmetic"
+ x = ((x - y * DP1) - y * DP2) - y * DP3; */
+ xmm1 = *(v4sf*)_ps_minus_cephes_DP1;
+ xmm2 = *(v4sf*)_ps_minus_cephes_DP2;
+ xmm3 = *(v4sf*)_ps_minus_cephes_DP3;
+ xmm1 = _mm_mul_ps(y, xmm1);
+ xmm2 = _mm_mul_ps(y, xmm2);
+ xmm3 = _mm_mul_ps(y, xmm3);
+ x = _mm_add_ps(x, xmm1);
+ x = _mm_add_ps(x, xmm2);
+ x = _mm_add_ps(x, xmm3);
+
+#ifdef USE_SSE2
+ emm4 = _mm_sub_epi32(emm4, *(v4si*)_pi32_2);
+ emm4 = _mm_andnot_si128(emm4, *(v4si*)_pi32_4);
+ emm4 = _mm_slli_epi32(emm4, 29);
+ v4sf sign_bit_cos = _mm_castsi128_ps(emm4);
+#else
+ /* get the sign flag for the cosine */
+ mm4 = _mm_sub_pi32(mm4, *(v2si*)_pi32_2);
+ mm5 = _mm_sub_pi32(mm5, *(v2si*)_pi32_2);
+ mm4 = _mm_andnot_si64(mm4, *(v2si*)_pi32_4);
+ mm5 = _mm_andnot_si64(mm5, *(v2si*)_pi32_4);
+ mm4 = _mm_slli_pi32(mm4, 29);
+ mm5 = _mm_slli_pi32(mm5, 29);
+ v4sf sign_bit_cos;
+ COPY_MM_TO_XMM(mm4, mm5, sign_bit_cos);
+ _mm_empty(); /* good-bye mmx */
+#endif
+
+ sign_bit_sin = _mm_xor_ps(sign_bit_sin, swap_sign_bit_sin);
+
+
+ /* Evaluate the first polynom (0 <= x <= Pi/4) */
+ v4sf z = _mm_mul_ps(x,x);
+ y = *(v4sf*)_ps_coscof_p0;
+
+ y = _mm_mul_ps(y, z);
+ y = _mm_add_ps(y, *(v4sf*)_ps_coscof_p1);
+ y = _mm_mul_ps(y, z);
+ y = _mm_add_ps(y, *(v4sf*)_ps_coscof_p2);
+ y = _mm_mul_ps(y, z);
+ y = _mm_mul_ps(y, z);
+ v4sf tmp = _mm_mul_ps(z, *(v4sf*)_ps_0p5);
+ y = _mm_sub_ps(y, tmp);
+ y = _mm_add_ps(y, *(v4sf*)_ps_1);
+
+ /* Evaluate the second polynom (Pi/4 <= x <= 0) */
+
+ v4sf y2 = *(v4sf*)_ps_sincof_p0;
+ y2 = _mm_mul_ps(y2, z);
+ y2 = _mm_add_ps(y2, *(v4sf*)_ps_sincof_p1);
+ y2 = _mm_mul_ps(y2, z);
+ y2 = _mm_add_ps(y2, *(v4sf*)_ps_sincof_p2);
+ y2 = _mm_mul_ps(y2, z);
+ y2 = _mm_mul_ps(y2, x);
+ y2 = _mm_add_ps(y2, x);
+
+ /* select the correct result from the two polynoms */
+ xmm3 = poly_mask;
+ v4sf ysin2 = _mm_and_ps(xmm3, y2);
+ v4sf ysin1 = _mm_andnot_ps(xmm3, y);
+ y2 = _mm_sub_ps(y2,ysin2);
+ y = _mm_sub_ps(y, ysin1);
+
+ xmm1 = _mm_add_ps(ysin1,ysin2);
+ xmm2 = _mm_add_ps(y,y2);
+
+ /* update the sign */
+ *s = _mm_xor_ps(xmm1, sign_bit_sin);
+ *c = _mm_xor_ps(xmm2, sign_bit_cos);
+}
+
diff --git a/src/3rd_party/threadpool.h b/src/3rd_party/threadpool.h
index c97c8848..d77ce43c 100755..100644
--- a/src/3rd_party/threadpool.h
+++ b/src/3rd_party/threadpool.h
@@ -107,6 +107,7 @@ class ThreadPool {
// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads, size_t in_bound)
: bound(in_bound), stop(false) {
+ ABORT_IF(getThrowExceptionOnAbort(), "Throwing of MarianRuntimeException not presently supported in threads");
reserve(threads);
}
diff --git a/src/3rd_party/yaml-cpp/binary_renamed.cpp b/src/3rd_party/yaml-cpp/binary_renamed.cpp
index 7541cbb7..7541cbb7 100755..100644
--- a/src/3rd_party/yaml-cpp/binary_renamed.cpp
+++ b/src/3rd_party/yaml-cpp/binary_renamed.cpp
diff --git a/src/3rd_party/yaml-cpp/collectionstack.h b/src/3rd_party/yaml-cpp/collectionstack.h
index c9f6b9f3..c9f6b9f3 100755..100644
--- a/src/3rd_party/yaml-cpp/collectionstack.h
+++ b/src/3rd_party/yaml-cpp/collectionstack.h
diff --git a/src/3rd_party/yaml-cpp/dll.h b/src/3rd_party/yaml-cpp/dll.h
index fbbee527..fbbee527 100755..100644
--- a/src/3rd_party/yaml-cpp/dll.h
+++ b/src/3rd_party/yaml-cpp/dll.h
diff --git a/src/3rd_party/yaml-cpp/emitterstate.cpp b/src/3rd_party/yaml-cpp/emitterstate.cpp
index 3342704e..3342704e 100755..100644
--- a/src/3rd_party/yaml-cpp/emitterstate.cpp
+++ b/src/3rd_party/yaml-cpp/emitterstate.cpp
diff --git a/src/3rd_party/yaml-cpp/emitterstate.h b/src/3rd_party/yaml-cpp/emitterstate.h
index 75611019..75611019 100755..100644
--- a/src/3rd_party/yaml-cpp/emitterstate.h
+++ b/src/3rd_party/yaml-cpp/emitterstate.h
diff --git a/src/3rd_party/yaml-cpp/node/convert.h b/src/3rd_party/yaml-cpp/node/convert.h
index c9d36551..c9d36551 100755..100644
--- a/src/3rd_party/yaml-cpp/node/convert.h
+++ b/src/3rd_party/yaml-cpp/node/convert.h
diff --git a/src/3rd_party/yaml-cpp/node/node.h b/src/3rd_party/yaml-cpp/node/node.h
index c0a6fc69..c0a6fc69 100755..100644
--- a/src/3rd_party/yaml-cpp/node/node.h
+++ b/src/3rd_party/yaml-cpp/node/node.h
diff --git a/src/3rd_party/yaml-cpp/node_data.cpp b/src/3rd_party/yaml-cpp/node_data.cpp
index f4236da8..f4236da8 100755..100644
--- a/src/3rd_party/yaml-cpp/node_data.cpp
+++ b/src/3rd_party/yaml-cpp/node_data.cpp
diff --git a/src/3rd_party/yaml-cpp/scanner.cpp b/src/3rd_party/yaml-cpp/scanner.cpp
index 72f70b7e..72f70b7e 100755..100644
--- a/src/3rd_party/yaml-cpp/scanner.cpp
+++ b/src/3rd_party/yaml-cpp/scanner.cpp
diff --git a/src/3rd_party/yaml-cpp/scantoken.cpp b/src/3rd_party/yaml-cpp/scantoken.cpp
index b96edbfd..b96edbfd 100755..100644
--- a/src/3rd_party/yaml-cpp/scantoken.cpp
+++ b/src/3rd_party/yaml-cpp/scantoken.cpp
diff --git a/src/3rd_party/yaml-cpp/singledocparser.cpp b/src/3rd_party/yaml-cpp/singledocparser.cpp
index 6e419c58..6e419c58 100755..100644
--- a/src/3rd_party/yaml-cpp/singledocparser.cpp
+++ b/src/3rd_party/yaml-cpp/singledocparser.cpp
diff --git a/src/3rd_party/zstr/strict_fstream.hpp b/src/3rd_party/zstr/strict_fstream.hpp
index 21173c73..7b117393 100644
--- a/src/3rd_party/zstr/strict_fstream.hpp
+++ b/src/3rd_party/zstr/strict_fstream.hpp
@@ -27,7 +27,7 @@ static std::string strerror()
{
buff = "Unknown error";
}
-#elif (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600) && ! _GNU_SOURCE
+#elif (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600 || __APPLE__) && ! _GNU_SOURCE
// XSI-compliant strerror_r()
if (strerror_r(errno, &buff[0], buff.size()) != 0)
{
@@ -125,7 +125,7 @@ struct static_method_holder
is_p->peek();
peek_failed = is_p->fail();
}
- catch (std::ios_base::failure e) {}
+ catch (const std::ios_base::failure &e) {}
if (peek_failed)
{
throw Exception(std::string("strict_fstream: open('")
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 03524117..6d5a0b1f 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -4,8 +4,12 @@ include_directories(.)
include_directories(3rd_party)
include_directories(3rd_party/SQLiteCpp/include)
include_directories(3rd_party/sentencepiece)
+include_directories(3rd_party/fbgemm/include)
+include_directories(${CMAKE_BINARY_DIR}/local/include)
add_library(marian STATIC
+ common/aliases.cpp
+ common/fastopt.cpp
common/version.cpp
common/utils.cpp
common/logging.cpp
@@ -14,13 +18,19 @@ add_library(marian STATIC
common/config.cpp
common/config_parser.cpp
common/config_validator.cpp
+ common/options.cpp
common/binary.cpp
+ common/build_info.cpp
common/io.cpp
-
+ common/filesystem.cpp
+ common/file_stream.cpp
+ common/types.cpp
+
data/alignment.cpp
data/vocab.cpp
data/default_vocab.cpp
data/sentencepiece_vocab.cpp
+ data/factored_vocab.cpp
data/corpus_base.cpp
data/corpus.cpp
data/corpus_sqlite.cpp
@@ -30,8 +40,11 @@ add_library(marian STATIC
3rd_party/cnpy/cnpy.cpp
3rd_party/ExceptionWithCallStack.cpp
+ 3rd_party/phf/phf.cc
+
tensors/backend.cpp
tensors/rand.cpp
+ tensors/tensor.cpp
tensors/cpu/device.cpp
tensors/cpu/prod.cpp
tensors/cpu/tensor_operators.cpp
@@ -39,6 +52,7 @@ add_library(marian STATIC
tensors/cpu/sharp/int_gemm.cpp
tensors/cpu/sharp/avx_gemm.cpp
tensors/cpu/sharp/sse_gemm.cpp
+ tensors/cpu/fbgemm/packed_gemm.cpp
graph/expression_graph.cpp
graph/expression_operators.cpp
@@ -47,6 +61,7 @@ add_library(marian STATIC
graph/node_initializers.cpp
layers/convolution.cpp
+ layers/generic.cpp
layers/loss.cpp
layers/weight.cpp
@@ -77,6 +92,7 @@ add_library(marian STATIC
training/graph_group_multinode_sync.cpp
training/validator.cpp
training/communicator.cpp
+ training/scheduler.cpp
# this is only compiled to catch build errors, but not linked
microsoft/quicksand.cpp
@@ -91,21 +107,40 @@ target_compile_options(marian PUBLIC ${ALL_WARNINGS})
# Generate git_revision.h to reflect current git revision information
# [https://stackoverflow.com/questions/1435953/how-can-i-pass-git-sha1-to-compiler-as-definition-using-cmake]
# Git updates .git/logs/HEAD file whenever you pull or commit something.
+
+# If Marian is checked out as a submodule in another repository,
+# there's no .git directory in ${CMAKE_SOURCE_DIR}. Instead .git is a
+# file that specifies the relative path from ${CMAKE_SOURCE_DIR} to
+# ./git/modules/<MARIAN_ROOT_DIR> in the root of the repository that
+# contains Marian as a submodule. We set MARIAN_GIT_DIR to the appropriate
+# path, depending on whether ${CMAKE_SOURCE_DIR}/.git is a directory or file.
+if(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git) # not a submodule
+ set(MARIAN_GIT_DIR ${CMAKE_SOURCE_DIR}/.git)
+else(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git)
+ file(READ ${CMAKE_SOURCE_DIR}/.git MARIAN_GIT_DIR)
+ string(REGEX REPLACE "gitdir: (.*)\n" "\\1" MARIAN_GIT_DIR ${MARIAN_GIT_DIR})
+ get_filename_component(MARIAN_GIT_DIR "${CMAKE_SOURCE_DIR}/${MARIAN_GIT_DIR}" ABSOLUTE)
+endif(IS_DIRECTORY ${CMAKE_SOURCE_DIR}/.git)
+
add_custom_command(OUTPUT ${CMAKE_CURRENT_SOURCE_DIR}/common/git_revision.h
+ WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
COMMAND git log -1 --pretty=format:\#define\ GIT_REVISION\ \"\%h\ \%ai\" > ${CMAKE_CURRENT_SOURCE_DIR}/common/git_revision.h
- DEPENDS ${CMAKE_SOURCE_DIR}/.git/logs/HEAD
+ DEPENDS ${MARIAN_GIT_DIR}/logs/HEAD
VERBATIM
)
add_custom_target(marian_version DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/common/git_revision.h)
add_dependencies(marian marian_version) # marian must depend on it so that it gets created first
+# make sure all local dependencies are installed first before this is built
+add_dependencies(marian 3rd_party_installs)
if(CUDA_FOUND)
cuda_add_library(marian_cuda
tensors/gpu/device.cu
tensors/gpu/algorithm.cu
- tensors/gpu/prod.cu
+ tensors/gpu/prod.cpp
tensors/gpu/element.cu
- tensors/gpu/add.cu
+ tensors/gpu/add.cu
+ tensors/gpu/add_all.cu
tensors/gpu/tensor_operators.cu
tensors/gpu/cudnn_wrappers.cu
translator/nth_element.cu
@@ -115,6 +150,8 @@ cuda_add_library(marian_cuda
STATIC)
target_compile_options(marian_cuda PUBLIC ${ALL_WARNINGS})
+ # make sure all local dependencies are installed first before this is built
+ add_dependencies(marian_cuda 3rd_party_installs)
endif(CUDA_FOUND)
set_target_properties(marian PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
@@ -179,6 +216,10 @@ if(COMPILE_SERVER)
set(EXECUTABLES ${EXECUTABLES} marian_server)
endif(COMPILE_SERVER)
+if(APPLE) # This is a dependency of pathie but I can't seem to link it into that CMakeLists because we're not compiling it as a library.
+ set(EXT_LIBS ${EXT_LIBS} iconv)
+endif()
+
foreach(exec ${EXECUTABLES})
target_link_libraries(${exec} marian ${EXT_LIBS} ${EXT_LIBS} ${CMAKE_THREAD_LIBS_INIT})
if(CUDA_FOUND)
@@ -187,13 +228,6 @@ foreach(exec ${EXECUTABLES})
set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
endforeach(exec)
-#add_executable(
-# align2steps
-# tools/align2steps.cpp
-#)
-
-#set_target_properties(align2steps PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
-
if(COMPILE_TESTS)
set(CATCH_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/3rd_party)
add_library(Catch INTERFACE)
diff --git a/src/command/marian_conv.cpp b/src/command/marian_conv.cpp
index 23350311..97f83acf 100755..100644
--- a/src/command/marian_conv.cpp
+++ b/src/command/marian_conv.cpp
@@ -4,6 +4,8 @@
#include <sstream>
+#include "tensors/cpu/fbgemm/expression_graph_packable.h"
+
int main(int argc, char** argv) {
using namespace marian;
@@ -11,18 +13,35 @@ int main(int argc, char** argv) {
auto options = New<Options>();
{
+ YAML::Node config; // @TODO: get rid of YAML::Node here entirely to avoid the pattern. Currently not fixing as it requires more changes to the Options object.
auto cli = New<cli::CLIWrapper>(
- options,
- "Convert a model in the .npz format to a mmap-able binary model",
+ config,
+ "Convert a model in the .npz format and normal memory layout to a mmap-able binary model which could be in normal memory layout or packed memory layout",
"Allowed options",
"Examples:\n"
- " ./marian-conv -f model.npz -t model.bin");
+ " ./marian-conv -f model.npz -t model.bin --gemm-type packed16");
cli->add<std::string>("--from,-f", "Input model", "model.npz");
cli->add<std::string>("--to,-t", "Output model", "model.bin");
+ cli->add<std::string>("--gemm-type,-g", "GEMM Type to be used: float32, packed16, packed8avx2, packed8avx512", "float32");
cli->parse(argc, argv);
+ options->merge(config);
}
auto modelFrom = options->get<std::string>("from");
auto modelTo = options->get<std::string>("to");
+
+ auto saveGemmTypeStr = options->get<std::string>("gemm-type", "float32");
+ Type saveGemmType;
+ if(saveGemmTypeStr == "float32") {
+ saveGemmType = Type::float32;
+ } else if(saveGemmTypeStr == "packed16") { // packed16 only supports AVX2. AVX512 might be added later
+ saveGemmType = Type::packed16;
+ } else if(saveGemmTypeStr == "packed8avx2") { // packed8 for AVX2
+ saveGemmType = Type::packed8avx2;
+ } else if(saveGemmTypeStr == "packed8avx512") { // packed8 for AVX512
+ saveGemmType = Type::packed8avx512;
+ } else {
+ ABORT("Unknown gemm-type: {}", saveGemmTypeStr);
+ }
LOG(info, "Outputting {}", modelTo);
@@ -31,12 +50,14 @@ int main(int argc, char** argv) {
marian::io::getYamlFromModel(config, "special:model.yml", modelFrom);
configStr << config;
- auto graph = New<ExpressionGraph>(true, false);
+ auto graph = New<ExpressionGraphPackable>();
graph->setDevice(CPU0);
+ graph->getBackend()->setOptimized(false);
graph->load(modelFrom);
graph->forward();
- graph->save(modelTo, configStr.str());
+ // added a flag if the weights needs to be packed or not
+ graph->packAndSave(modelTo, configStr.str(), /* --gemm-type */ saveGemmType, Type::float32);
// graph->saveBinary(vm["bin"].as<std::string>());
diff --git a/src/command/marian_decoder.cpp b/src/command/marian_decoder.cpp
index 98efe49d..1621a7f3 100644
--- a/src/command/marian_decoder.cpp
+++ b/src/command/marian_decoder.cpp
@@ -8,7 +8,6 @@
int main(int argc, char** argv) {
using namespace marian;
-
auto options = parseOptions(argc, argv, cli::mode::translation);
auto task = New<Translate<BeamSearch>>(options);
diff --git a/src/command/marian_main.cpp b/src/command/marian_main.cpp
index a2a19145..a2a19145 100755..100644
--- a/src/command/marian_main.cpp
+++ b/src/command/marian_main.cpp
diff --git a/src/command/marian_server.cpp b/src/command/marian_server.cpp
index a9a1b0be..e4074bd1 100644
--- a/src/command/marian_server.cpp
+++ b/src/command/marian_server.cpp
@@ -12,7 +12,7 @@ int main(int argc, char **argv) {
using namespace marian;
// Initialize translation task
- auto options = parseOptions(argc, argv, cli::mode::translation, true);
+ auto options = parseOptions(argc, argv, cli::mode::server, true);
auto task = New<TranslateService<BeamSearch>>(options);
// Initialize web server
@@ -44,7 +44,7 @@ int main(int argc, char **argv) {
// Error Codes for error code meanings
// http://www.boost.org/doc/libs/1_55_0/doc/html/boost_asio/reference.html
- translate.on_error = [](Ptr<WSServer::Connection> connection,
+ translate.on_error = [](Ptr<WSServer::Connection> /*connection*/,
const SimpleWeb::error_code &ec) {
LOG(error, "Connection error: ({}) {}", ec.value(), ec.message());
};
diff --git a/src/command/marian_train.cpp b/src/command/marian_train.cpp
index 88980624..9d312bbb 100755..100644
--- a/src/command/marian_train.cpp
+++ b/src/command/marian_train.cpp
@@ -1,3 +1,4 @@
+#include <signal.h>
#include "marian.h"
#include "training/graph_group_async.h"
@@ -11,10 +12,12 @@
#include "training/graph_group_multinode.h"
#endif
+#include "3rd_party/ExceptionWithCallStack.h"
+
int main(int argc, char** argv) {
using namespace marian;
- auto options = parseOptions(argc, argv);
+ auto options = parseOptions(argc, argv, cli::mode::training);
// selects MultiNodeGraphGroup family
//
@@ -66,5 +69,13 @@ int main(int argc, char** argv) {
}
}
- return 0;
+ // If we exit due to SIGTERM, exit with 128 + the signal number, as suggested
+ // for bash in http://tldp.org/LDP/abs/html/exitcodes.html. This allows parent
+ // scripts to determine if training terminated naturally or via SIGTERM.
+ // Whith this approach we can accommodate additional signals in the future.
+ // An alternative would be to return 124, which is what the timeout command
+ // returns for timeout -s SIGTERM <seconds> ...., because exiting after SIGTERM
+ // is not technically a fatal error (which is what the 128+x convention usually
+ // stands for).
+ return getSigtermFlag() ? (128 + SIGTERM) : 0;
}
diff --git a/src/command/marian_vocab.cpp b/src/command/marian_vocab.cpp
index de8ef3c7..d9c83b5a 100755..100644
--- a/src/command/marian_vocab.cpp
+++ b/src/command/marian_vocab.cpp
@@ -9,10 +9,11 @@ int main(int argc, char** argv) {
createLoggers();
- auto options = New<Options>();
+ Ptr<Options> options = New<Options>();
{
+ YAML::Node config; // @TODO: get rid of YAML::Node here entirely to avoid the pattern. Currently not fixing as it requires more changes to the Options object.
auto cli = New<cli::CLIWrapper>(
- options,
+ config,
"Create a vocabulary from text corpora given on STDIN",
"Allowed options",
"Examples:\n"
@@ -20,6 +21,7 @@ int main(int argc, char** argv) {
" cat text.src text.trg | ./marian-vocab > vocab.yml");
cli->add<size_t>("--max-size,-m", "Generate only UINT most common vocabulary items", 0);
cli->parse(argc, argv);
+ options->merge(config);
}
LOG(info, "Creating vocabulary...");
diff --git a/src/common/aliases.cpp b/src/common/aliases.cpp
new file mode 100644
index 00000000..03cb60cf
--- /dev/null
+++ b/src/common/aliases.cpp
@@ -0,0 +1,150 @@
+#include "common/config_parser.h"
+#include "common/definitions.h"
+
+namespace marian {
+
+/**
+ * Add all aliases
+ *
+ * An alias is a command line option name and value pair that sets multiple other non-alias
+ * (standard) command line options. And example would be `--task transformer-big` which --
+ * as a whole -- is an alias for setting options and hyperparameters that would be reasonable
+ * for training a Google-style Transformer-Big model. Below key-value pairs
+ * ("task", "transformer-base") and ("task", "transformer-big") are different aliases that result
+ * in different option sets to be defined.
+ *
+ * The alias option has to be first defined using cli.add<T>(). Defining
+ * multiple aliases for the same option name but with different values is allowed.
+ *
+ * As aliases are key-value pairs by default, values are compared as std::string.
+ * If the command line option corresponding to the alias is a vector, the alias
+ * will be triggered if the requested value exists in that vector at least once.
+ *
+ * @see CLIWrapper::alias()
+ *
+ * The order of alias definitions *does* matter: options from later aliases override earlier
+ * regardless of its order in the command line or config file.
+ */
+void ConfigParser::addAliases(cli::CLIWrapper& cli) {
+ cli.alias("fp16", "true", [&](YAML::Node& config) {
+ if(mode_ == cli::mode::training) {
+ config["precision"] = std::vector<std::string>({"float16", "float32", "float32"}); // inference type, optimization type, save type
+ // @TODO: review this
+ // scaling factor (power of 2), frequency, multiplier at increase, tolerance, range, minium factor
+ config["cost-scaling"] = std::vector<std::string>({"7", "2000", "2", "0.05", "10", "1"});
+ } else {
+ config["precision"] = std::vector<std::string>({"float16"}); // for inference we do not need the other types
+ }
+ });
+
+ if(mode_ == cli::mode::training) {
+ // for backwards-compatibility with older version, "--no-shuffle" maps to "--shuffle none"
+ cli.alias("no-shuffle", "true", [](YAML::Node& config) {
+ config["shuffle"] = "none";
+ });
+
+ // Options setting the BiDeep architecture proposed in http://www.aclweb.org/anthology/W17-4710
+ cli.alias("best-deep", "true", [](YAML::Node& config) {
+ config["layer-normalization"] = true;
+ config["tied-embeddings"] = true;
+ config["enc-type"] = "alternating";
+ config["enc-cell-depth"] = 2;
+ config["enc-depth"] = 4;
+ config["dec-cell-base-depth"] = 4;
+ config["dec-cell-high-depth"] = 2;
+ config["dec-depth"] = 4;
+ config["skip"] = true;
+
+ // Training specific options
+ config["learn-rate"] = 0.0003;
+ config["cost-type"] = "ce-mean-words";
+ config["lr-decay-inv-sqrt"] = 16000;
+ config["label-smoothing"] = 0.1;
+ config["clip-norm"] = 0;
+ config["sync-sgd"] = true;
+ config["exponential-smoothing"] = 1e-4;
+ config["mini-batch-fit"] = true;
+ config["mini-batch"] = 1000;
+ config["maxi-batch"] = 1000;
+ // config["workspace"] = 6500;
+ });
+
+ // Architecture and proposed training settings for a Transformer "base" model introduced in
+ // https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf
+ cli.alias("task", "transformer-base", [](YAML::Node& config) {
+ // Model options
+ config["type"] = "transformer";
+ config["enc-depth"] = 6;
+ config["dec-depth"] = 6;
+ config["dim-emb"] = 512;
+ config["tied-embeddings-all"] = true;
+ config["transformer-dim-ffn"] = 2048;
+ config["transformer-heads"] = 8;
+ config["transformer-postprocess"] = "dan";
+ config["transformer-preprocess"] = "";
+ config["transformer-ffn-activation"] = "relu";
+ config["transformer-dropout"] = 0.1;
+
+ // Training specific options
+ config["learn-rate"] = 0.0003;
+ config["cost-type"] = "ce-mean-words";
+ config["lr-warmup"] = 16000;
+ config["lr-decay-inv-sqrt"] = 16000;
+ config["label-smoothing"] = 0.1;
+ config["clip-norm"] = 0;
+ config["sync-sgd"] = true;
+ config["exponential-smoothing"] = 1e-4;
+ config["max-length"] = 100;
+ config["mini-batch-fit"] = true;
+ config["mini-batch"] = 1000;
+ config["maxi-batch"] = 1000;
+ config["workspace"] = 9500;
+ config["optimizer-params"] = std::vector<float>({0.9f, 0.98f, 1e-09f});
+
+ // Validation specific options
+ config["beam-size"] = 8;
+ config["valid-mini-batch"] = 16;
+ config["normalize"] = 1.0;
+ });
+
+ // Architecture and proposed training settings for a Transformer "big" model introduced in
+ // https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf
+ cli.alias("task", "transformer-big", [](YAML::Node& config) {
+ // Model options
+ config["type"] = "transformer";
+ config["enc-depth"] = 6;
+ config["dec-depth"] = 6;
+ config["dim-emb"] = 1024;
+ config["tied-embeddings-all"] = true;
+ config["transformer-dim-ffn"] = 4096;
+ config["transformer-heads"] = 16;
+ config["transformer-postprocess"] = "dan";
+ config["transformer-preprocess"] = "";
+ config["transformer-ffn-activation"] = "relu";
+ config["transformer-dropout"] = 0.1;
+
+ // Training specific options
+ config["learn-rate"] = 0.0002;
+ config["cost-type"] = "ce-mean-words";
+ config["lr-warmup"] = 8000;
+ config["lr-decay-inv-sqrt"] = 8000;
+ config["label-smoothing"] = 0.1;
+ config["clip-norm"] = 0;
+ config["sync-sgd"] = true;
+ config["exponential-smoothing"] = 1e-4;
+ config["max-length"] = 100;
+ config["mini-batch-fit"] = true;
+ config["mini-batch"] = 1000;
+ config["maxi-batch"] = 1000;
+ config["workspace"] = 13000;
+ config["optimizer-params"] = std::vector<float>({0.9f, 0.998f, 1e-09f});
+
+ // Validation specific options
+ config["beam-size"] = 8;
+ config["valid-mini-batch"] = 8;
+ config["normalize"] = 1.0;
+ });
+ }
+}
+
+} // namespace marian
diff --git a/src/common/authors.h b/src/common/authors.h
new file mode 100644
index 00000000..5dcdaf28
--- /dev/null
+++ b/src/common/authors.h
@@ -0,0 +1,65 @@
+#pragma once
+
+#include <string>
+
+namespace marian {
+
+std::string citation() {
+ return "Marian: Fast Neural Machine Translation in C++\n"
+ "\n"
+ "Please cite the following paper if you use Marian:\n"
+ "\n"
+ "@InProceedings{mariannmt,\n"
+ " title = {Marian: Fast Neural Machine Translation in {C++}},\n"
+ " author = {Junczys-Dowmunt, Marcin and Grundkiewicz, Roman and\n"
+ " Dwojak, Tomasz and Hoang, Hieu and Heafield, Kenneth and\n"
+ " Neckermann, Tom and Seide, Frank and Germann, Ulrich and\n"
+ " Fikri Aji, Alham and Bogoychev, Nikolay and\n"
+ " Martins, Andr\\'{e} F. T. and Birch, Alexandra},\n"
+ " booktitle = {Proceedings of ACL 2018, System Demonstrations},\n"
+ " pages = {116--121},\n"
+ " publisher = {Association for Computational Linguistics},\n"
+ " year = {2018},\n"
+ " month = {July},\n"
+ " address = {Melbourne, Australia},\n"
+ " url = {http://www.aclweb.org/anthology/P18-4020}\n"
+ "}\n";
+}
+
+// The list of contributors has been compiled semi-automatically from the
+// GitHub contributor list in default order. That list can be printed out with
+// `git shortlog -s -n`.
+std::string authors() {
+ return "Marian: Fast Neural Machine Translation in C++\n"
+ "\n"
+ "An inevitably non-exhaustive list of contributors:\n"
+ "\n"
+ "Marcin Junczys-Dowmunt <marcinjd@microsoft.com>\n"
+ "Roman Grundkiewicz <rgrundki@inf.ed.ac.uk>\n"
+ "Frank Seide <fseide@microsoft.com>\n"
+ "Hieu Hoang <hieuhoang@gmail.com>\n"
+ "Tomasz Dwojak <t.dwojak@amu.edu.pl>\n"
+ "Ulrich Germann <ugermann@inf.ed.ac.uk>\n"
+ "Alham Fikri Aji <afaji321@gmail.com>\n"
+ "Cédric Rousseau <cedrou@gmail.com>\n"
+ "Young Jin Kim <youki@microsoft.com>\n"
+ "Lane Schwartz <dowobeha@gmail.com>\n"
+ "Andre Martins <andre.t.martins@gmail.com>\n"
+ "Nikolay Bogoychev <n.bogoych@ed.ac.uk>\n"
+ "Kenneth Heafield <kheafiel@ed.ac.uk>\n"
+ "Maximiliana Behnke <mbehnke@inf.ed.ac.uk>\n"
+ "Tom Neckermann <tomneckermann@gmail.com>\n"
+ "Hany Hassan Awadalla <hanyh@microsoft.com>\n"
+ "Jim Geovedi <jim@geovedi.com>\n"
+ "Catarina Silva <catarina.cruz.csilva@gmail.com>\n"
+ "Jon Clark <jonathac@microsoft.com>\n"
+ "Rihards Krišlauks <rihards.krislauks@gmail.com>\n"
+ "Vishal Chowdhary <vishalc@microsoft.com>\n"
+ "Barry Haddow <bhaddow@inf.ed.ac.uk>\n"
+ "Dominik Stańczak <stanczakdominik@gmail.com>\n"
+ "Michael Hutt <Michael.Hutt@gmail.com>\n"
+ "Richard Wei <rxwei@users.noreply.github.com>\n"
+ "Wenyong Huang <weyo.huang@gmail.com>\n"
+ "alancucki <alancucki+github@gmail.com>\n";
+}
+} // namespace marian
diff --git a/src/common/binary.cpp b/src/common/binary.cpp
index 983c15b5..1531fed6 100644
--- a/src/common/binary.cpp
+++ b/src/common/binary.cpp
@@ -18,6 +18,7 @@ struct Header {
size_t dataLength;
};
+// cast current void pointer to T pointer and move forward by num elements
template <typename T>
const T* get(const void*& current, size_t num = 1) {
const T* ptr = (const T*)current;
@@ -32,9 +33,10 @@ void loadItems(const void* current, std::vector<io::Item>& items, bool mapped) {
binaryFileVersion,
BINARY_FILE_VERSION);
- size_t numHeaders = *get<size_t>(current);
- const Header* headers = get<Header>(current, numHeaders);
+ size_t numHeaders = *get<size_t>(current); // number of item headers that follow
+ const Header* headers = get<Header>(current, numHeaders); // read that many headers
+ // prepopulate items with meta data from headers
items.resize(numHeaders);
for(int i = 0; i < numHeaders; ++i) {
items[i].type = (Type)headers[i].type;
@@ -42,21 +44,22 @@ void loadItems(const void* current, std::vector<io::Item>& items, bool mapped) {
items[i].mapped = mapped;
}
+ // read in actual shape and data
for(int i = 0; i < numHeaders; ++i) {
size_t len = headers[i].shapeLength;
- items[i].shape.resize(len);
- const int* arr = get<int>(current, len);
- std::copy(arr, arr + len, items[i].shape.begin());
+ items[i].shape.resize(len);
+ const int* arr = get<int>(current, len); // read shape
+ std::copy(arr, arr + len, items[i].shape.begin()); // copy to Item::shape
}
- // move by offset bytes
+ // move by offset bytes, aligned to 256-bytes boundary
size_t offset = *get<size_t>(current);
get<char>(current, offset);
for(int i = 0; i < numHeaders; ++i) {
- if(items[i].mapped) {
+ if(items[i].mapped) { // memory-mapped, hence only set pointer
items[i].ptr = get<char>(current, headers[i].dataLength);
- } else {
+ } else { // reading into item data
size_t len = headers[i].dataLength;
items[i].bytes.resize(len);
const char* ptr = get<char>(current, len);
@@ -68,15 +71,21 @@ void loadItems(const void* current, std::vector<io::Item>& items, bool mapped) {
void loadItems(const std::string& fileName, std::vector<io::Item>& items) {
// Read file into buffer
size_t fileSize = filesystem::fileSize(fileName);
- char* ptr = new char[fileSize];
+ std::vector<char> buf(fileSize);
+// @TODO: check this again:
+#if 1 // for some reason, the #else branch fails with "file not found" in the *read* operation (open succeeds)
+ FILE *f = fopen(fileName.c_str(), "rb");
+ ABORT_IF(f == nullptr, "Error {} ('{}') opening file '{}'", errno, strerror(errno), fileName);
+ auto rc = fread(buf.data(), sizeof(*buf.data()), buf.size(), f);
+ ABORT_IF(rc != buf.size(), "Error {} ('{}') reading file '{}'", errno, strerror(errno), fileName);
+ fclose(f);
+#else
io::InputFileStream in(fileName);
- in.read(ptr, fileSize);
+ in.read(buf.data(), buf.size());
+#endif
// Load items from buffer without mapping
- loadItems(ptr, items, false);
-
- // Delete buffer
- delete[] ptr;
+ loadItems(buf.data(), items, false);
}
io::Item getItem(const void* current, const std::string& varName) {
@@ -114,7 +123,7 @@ void saveItems(const std::string& fileName,
headers.push_back(Header{item.name.size() + 1,
(size_t)item.type,
item.shape.size(),
- item.size()});
+ item.bytes.size()}); // binary item size with padding, will be 256-byte-aligned
}
size_t headerSize = headers.size();
@@ -141,9 +150,11 @@ void saveItems(const std::string& fileName,
}
// Write out all values
- for(const auto& item : items) {
- pos += out.write(item.data(), item.size());
- }
+ for(const auto& item : items)
+ pos += out.write(item.data(), item.bytes.size()); // writes out data with padding, keeps 256-byte boundary.
+ // Amazingly this is binary-compatible with V1 and aligned and
+ // non-aligned models can be read with the same procedure.
+ // No version-bump required. Gets 5-8% of speed back when mmapped.
}
} // namespace binary
diff --git a/src/common/binary.h b/src/common/binary.h
index a1503348..7d73834b 100644
--- a/src/common/binary.h
+++ b/src/common/binary.h
@@ -5,10 +5,10 @@
#include <string>
#include <vector>
-// Increase this if binary format changes
-#define BINARY_FILE_VERSION 1
-
namespace marian {
+
+const static int BINARY_FILE_VERSION = 1;
+
namespace io {
namespace binary {
diff --git a/src/common/build_info.cpp.in b/src/common/build_info.cpp.in
new file mode 100644
index 00000000..b8d12e8a
--- /dev/null
+++ b/src/common/build_info.cpp.in
@@ -0,0 +1,18 @@
+#include "common/build_info.h"
+
+/*
+ * File build_info.cpp is generated using CMake. Do NOT modify it manually! Edit
+ * build_info.cpp.in file instead.
+ */
+
+std::string marian::cmakeBuildOptions() {
+ return ""
+@PROJECT_CMAKE_CACHE@
+ ;
+}
+
+std::string marian::cmakeBuildOptionsAdvanced() {
+ return ""
+@PROJECT_CMAKE_CACHE_ADVANCED@
+ ;
+}
diff --git a/src/common/build_info.h b/src/common/build_info.h
new file mode 100644
index 00000000..f9e4a57c
--- /dev/null
+++ b/src/common/build_info.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <string>
+
+namespace marian {
+
+// Returns list of non-advanced cache variables used by CMake
+std::string cmakeBuildOptions();
+
+// Returns list of advanced cache variables used by CMake
+std::string cmakeBuildOptionsAdvanced();
+
+} // namespace marian
diff --git a/src/common/cli_helper.h b/src/common/cli_helper.h
index 27761372..4477f0c0 100755..100644
--- a/src/common/cli_helper.h
+++ b/src/common/cli_helper.h
@@ -15,7 +15,6 @@ static inline std::string InterpolateEnvVars(std::string str) {
// presently has the form /hdfs/VC instead of /{gfs,hdfs}/CLUSTER/VC
// Catch stdin/stdout and do not process
- std::cerr << str << std::endl;
if(str == "stdin" || str == "stdout") {
return str;
}
diff --git a/src/common/cli_wrapper.cpp b/src/common/cli_wrapper.cpp
index 28826bb2..52a24887 100755..100644
--- a/src/common/cli_wrapper.cpp
+++ b/src/common/cli_wrapper.cpp
@@ -3,11 +3,20 @@
#include "common/logging.h"
#include "common/options.h"
#include "common/timer.h"
+#include "common/utils.h"
#include "common/version.h"
namespace marian {
namespace cli {
+// clang-format off
+const std::unordered_set<std::string> DEPRECIATED_OPTIONS = {
+ "version",
+ "special-vocab"
+};
+// clang-format on
+
+
/*
static uint16_t guess_terminal_width(uint16_t max_width, uint16_t default_width) {
uint16_t cols = 0;
@@ -91,18 +100,14 @@ CLIWrapper::CLIWrapper(YAML::Node &config,
optVersion_->group(defaultGroup_);
}
-CLIWrapper::CLIWrapper(Ptr<marian::Options> options,
- const std::string &description,
- const std::string &header,
- const std::string &footer,
- size_t columnWidth,
- size_t screenWidth)
- : CLIWrapper(options->getYaml(), description, header, footer, columnWidth, screenWidth) {}
-
CLIWrapper::~CLIWrapper() {}
-void CLIWrapper::switchGroup(const std::string &name) {
- currentGroup_ = name.empty() ? defaultGroup_ : name;
+// set current group to name, return previous group
+std::string CLIWrapper::switchGroup(std::string name) {
+ currentGroup_.swap(name);
+ if (currentGroup_.empty())
+ currentGroup_ = defaultGroup_;
+ return name;
}
void CLIWrapper::parse(int argc, char **argv) {
@@ -119,43 +124,110 @@ void CLIWrapper::parse(int argc, char **argv) {
}
}
-std::string CLIWrapper::failureMessage(const CLI::App *app, const CLI::Error &e) {
- std::string header = "Error: " + std::string(e.what()) + "\n";
- if(app->get_help_ptr() != nullptr)
- header += "Run with " + app->get_help_ptr()->get_name() + " for more information.\n";
- return header;
+void CLIWrapper::parseAliases() {
+ // Exit if no aliases defined
+ if(aliases_.empty())
+ return;
+
+ // Iterate all known aliases, each alias has a key, value, and config
+ for(const auto &alias : aliases_) {
+ // Check if the alias option exists in the config (it may come from command line or a config
+ // file)
+ if(config_[alias.key]) {
+ // Check if the option in the config stores the value required to expand the alias. If so,
+ // expand the alias.
+ // Two cases:
+ // * the option is a sequence: extract it as a vector of strings and look for the value
+ // * otherwise: compare values as strings
+ bool expand = false;
+ if(config_[alias.key].IsSequence()) {
+ auto aliasOpts = config_[alias.key].as<std::vector<std::string>>();
+ expand = std::find(aliasOpts.begin(), aliasOpts.end(), alias.value) != aliasOpts.end();
+ } else {
+ expand = config_[alias.key].as<std::string>() == alias.value;
+ }
+
+ if(expand) {
+ // Update global config options with the config associated with the alias. Abort if the
+ // alias contains an undefined option.
+ updateConfig(alias.config,
+ // Priority of each expanded option is the same as the priority of the alias
+ options_[alias.key].priority,
+ "Unknown option(s) in alias '" + alias.key + ": " + alias.value + "'");
+ }
+ }
+ }
+
+ // Remove aliases from the global config to avoid redundancy when writing/reading config files
+ for(const auto &alias : aliases_) {
+ config_.remove(alias.key);
+ }
}
-bool CLIWrapper::updateConfig(const YAML::Node &config) {
- bool success = true;
+void CLIWrapper::updateConfig(const YAML::Node &config, cli::OptionPriority priority, const std::string &errorMsg) {
auto cmdOptions = getParsedOptionNames();
+ // Keep track of unrecognized options from the provided config
+ std::vector<std::string> unknownOpts;
+
+ // Iterate incoming options: they need to be merged into the global config
for(auto it : config) {
auto key = it.first.as<std::string>();
- // skip options specified via command-line to allow overwriting them
+
+ // Skip options specified via command-line to allow overwriting them
if(cmdOptions.count(key))
continue;
+ // Skip options that might exist in config files generated by older versions of Marian
+ if(DEPRECIATED_OPTIONS.count(key))
+ continue;
+
+ // Check if an incoming option has been defined in CLI
if(options_.count(key)) {
- config_[key] = YAML::Clone(it.second);
- options_[key].modified = true;
- } else {
- success = false;
+ // Do not proceed if the priority of incoming option is not greater than the existing option
+ if(priority <= options_[key].priority) {
+ continue;
+ }
+ // Check if the option exists in the global config and types match
+ if(config_[key] && config_[key].Type() == it.second.Type()) {
+ config_[key] = YAML::Clone(it.second);
+ options_[key].priority = priority;
+ // If types doesn't match, try to convert
+ } else {
+ // Default value is a sequence and incoming node is a scalar, hence we can upcast to
+ // single element sequence
+ if(config_[key].IsSequence() && it.second.IsScalar()) {
+ // create single element sequence
+ YAML::Node sequence;
+ sequence.push_back(YAML::Clone(it.second));
+ config_[key] = sequence; // overwrite to replace default values
+ options_[key].priority = priority;
+ } else {
+ // Cannot convert other non-matching types, e.g. scalar <- list should fail
+ ABORT("Cannot convert values for the option: " + key);
+ }
+ }
+ } else { // an unknown option
+ unknownOpts.push_back(key);
}
}
- return success;
+
+ ABORT_IF(!unknownOpts.empty(), errorMsg + ": " + utils::join(unknownOpts, ", "));
}
-std::string CLIWrapper::dumpConfig(bool skipDefault /*= false*/) const {
+std::string CLIWrapper::dumpConfig(bool skipUnmodified /*= false*/) const {
YAML::Emitter out;
out << YAML::Comment("Marian configuration file generated at " + timer::currentDate()
+ " with version " + buildVersion());
out << YAML::BeginMap;
std::string comment;
+ // Iterate option names in the same order as they have been created
for(const auto &key : getOrderedOptionNames()) {
- // do not proceed keys that are removed from config_
+ // Do not dump options that were removed from config_
if(!config_[key])
continue;
- if(skipDefault && !options_.at(key).modified)
+ // Do not dump options that were not passed via the command line
+ if(skipUnmodified && options_.at(key).priority == cli::OptionPriority::DefaultValue)
continue;
+ // Put the group name as a comment before the first option in the group
auto group = options_.at(key).opt->get_group();
if(comment != group) {
if(!comment.empty())
@@ -192,5 +264,12 @@ std::vector<std::string> CLIWrapper::getOrderedOptionNames() const {
return keys;
}
+std::string CLIWrapper::failureMessage(const CLI::App *app, const CLI::Error &e) {
+ std::string header = "Error: " + std::string(e.what()) + "\n";
+ if(app->get_help_ptr() != nullptr)
+ header += "Run with " + app->get_help_ptr()->get_name() + " for more information.\n";
+ return header;
+}
+
} // namespace cli
} // namespace marian
diff --git a/src/common/cli_wrapper.h b/src/common/cli_wrapper.h
index cf47a310..349d353b 100755..100644
--- a/src/common/cli_wrapper.h
+++ b/src/common/cli_wrapper.h
@@ -9,6 +9,7 @@
#include <map>
#include <string>
#include <unordered_set>
+#include <vector>
namespace marian {
@@ -16,28 +17,30 @@ class Options;
namespace cli {
-// Try to determine the width of the terminal
-//
-// TODO: make use of it in the current CLI or remove. This is an old code used
-// for boost::program_options and might not be needed anymore.
-//static uint16_t guess_terminal_width(uint16_t max_width = 0,
-// uint16_t default_width = 180);
-
-// TODO: use validators in ConfigParser
-namespace validators {
-const CLI::detail::ExistingFileValidator file_exists;
-const CLI::detail::ExistingDirectoryValidator dir_exists;
-const CLI::detail::ExistingPathValidator path_exists;
-
-const CLI::detail::NonexistentPathValidator path_not_exists;
-
-typedef CLI::Range range;
-}
+// Option priority
+enum struct OptionPriority : int { DefaultValue = 0, ConfigFile = 1, CommandLine = 2 };
/**
- * The helper class for cli::CLIWrapper handling formatting of options and their
- * descriptions.
+ * Helper tuple storing an option object, the associated variable and creation index
+ *
+ * Note: bare pointers are used for CLI::Option objects as this comes from the CLI11 library.
+ * Removing it would require deep modifications in the 3rd party library, what we want to avoid.
*/
+struct CLIOptionTuple {
+ CLI::Option *opt; // a pointer to an option object from CLI11
+ Ptr<any_type> var; // value assigned to the option via command-line
+ size_t idx{0}; // order in which the option was created
+ OptionPriority priority{cli::OptionPriority::DefaultValue};
+};
+
+// Helper tuple for aliases storing the alias name, value, and options to be expanded
+struct CLIAliasTuple {
+ std::string key; // alias option name
+ std::string value; // value for the alias option indicating that it should be expanded
+ YAML::Node config; // config with options that the alias adds
+};
+
+// The helper class for cli::CLIWrapper handling formatting of options and their descriptions.
class CLIFormatter : public CLI::Formatter {
public:
CLIFormatter(size_t columnWidth, size_t screenWidth);
@@ -47,57 +50,39 @@ private:
size_t screenWidth_{0};
};
-// @TODO: in this file review the use of naked pointers. We use Ptr<Type> anywhere else,
-// what's up with that?
-
-/**
- * The helper structure storing an option object, the associated variable and creation index.
- */
-struct CLIOptionTuple {
- CLI::Option *opt;
- Ptr<any_type> var;
- size_t idx{0};
- bool modified{false};
-};
-
/**
* @brief The class used to define and parse command-line arguments.
*
- * It is a wrapper around https://github.com/CLIUtils/CLI11 that stores defined
- * command-line arguments in a YAML object.
+ * It is a wrapper around https://github.com/CLIUtils/CLI11 that stores defined command-line
+ * arguments in a YAML object.
*
- * Usage outline: first call add() methods to create all the options; then call
- * parse(argv, argc) to parse command line and get defined options and their
- * values in a YAML object. The object can be also obtained later by calling
+ * Usage outline: first call add() methods to create all the options; then call parse(argv, argc) to
+ * parse command line and get defined options and their values in a YAML object; finally call
+ * parseAliases() to expand alias options. The config object can be also obtained later by calling
* getConfig().
*
- * Options are organized in option groups. Each option group has a header that
- * preceeds all options in the group. The header for the default option group
- * can be set from the class constructor.
+ * Options are organized in option groups. Each option group has a header that preceeds all options
+ * in the group. The header for the default option group can be set from the class constructor.
*/
class CLIWrapper {
private:
// Map with option names and option tuples
std::unordered_map<std::string, CLIOptionTuple> options_;
- // Counter for created options
+ // Counter for created options to keep track of order in which options were created
size_t counter_{0};
- // Command-line argument parser
- Ptr<CLI::App> app_;
+ std::vector<CLIAliasTuple> aliases_; // List of alias tuples
+
+ Ptr<CLI::App> app_; // Command-line argument parser from CLI11
- // Name of the default option group
- std::string defaultGroup_{""};
- // Name of the current option group
- std::string currentGroup_{""};
+ std::string defaultGroup_{""}; // Name of the default option group
+ std::string currentGroup_{""}; // Name of the current option group
- // Reference to the main config object
- YAML::Node &config_;
+ YAML::Node &config_; // Reference to the main config object
// Option for --version flag. This is a special flag and similarly to --help,
// the key "version" will be not added into the YAML config
CLI::Option *optVersion_;
- static std::string failureMessage(const CLI::App *app, const CLI::Error &e);
-
// Extract option name from a comma-separated list of long and short options, e.g. 'help' from
// '--help,-h'
std::string keyName(const std::string &args) const {
@@ -107,7 +92,15 @@ private:
.front(); // get first long name
}
+ // Get names of options passed via command-line
+ std::unordered_set<std::string> getParsedOptionNames() const;
+ // Get option names in the same order as they are created
+ std::vector<std::string> getOrderedOptionNames() const;
+
+ static std::string failureMessage(const CLI::App *app, const CLI::Error &e);
+
public:
+
/**
* @brief Create an instance of the command-line argument parser
*
@@ -118,8 +111,7 @@ public:
* @param header Header text for the main option group
* @param footer Text displayed after the list of options
* @param columnWidth Width of the column with option names
- * @param screenWidth Maximum allowed width for help messages, 0 means no
- * limit
+ * @param screenWidth Maximum allowed width for help messages, 0 means no limit
*/
CLIWrapper(YAML::Node &config,
const std::string &description = "",
@@ -128,24 +120,13 @@ public:
size_t columnWidth = 40,
size_t screenWidth = 0);
- /**
- * @brief Create an instance of the command-line argument parser,
- * short-cuft for Options object.
- *
- * @see Other constructor
- */
- CLIWrapper(Ptr<Options> options,
- const std::string &description = "",
- const std::string &header = "General options",
- const std::string &footer = "",
- size_t columnWidth = 30,
- size_t screenWidth = 0);
-
virtual ~CLIWrapper();
/**
* @brief Define an option with a default value
*
+ * Explicit default values will appear in help messages.
+ *
* @param args Comma-separated list of short and long option names
* @param help Help message
* @param val Default value
@@ -154,110 +135,122 @@ public:
*/
template <typename T>
CLI::Option *add(const std::string &args, const std::string &help, T val) {
- return add_option<T>(keyName(args),
- args,
- help,
- val,
- /*defaulted =*/true,
- /*addToConfig =*/true);
+ return addOption<T>(keyName(args),
+ args,
+ help,
+ val,
+ /*defaulted =*/true);
}
/**
- * @brief Define an option without an explicit default value. The implicit
- * default value is T()
+ * @brief Define an option without an explicit default value. The implicit default value is T()
+ *
+ * The option will be defined in the config file even if not given as a command-line argument. The
+ * implicit default value for a boolean or numeric option is 0, for a string is an empty string,
+ * and for a vector is an empty vector.
*
- * The option will be defined in the config file even if not given as a
- * command-line argument. The implicit default value for a boolean or numeric
- * option is 0, for a string is an empty string, and for a vector is an empty
- * vector.
+ * Implicit default values will *NOT* appear in help messages.
*
* @param args Comma-separated list of short and long option names
* @param help Help message
*
* @return Option object
*
- * TODO: require to always state the default value creating the parser as this
- * will be clearer
+ * @TODO: require to always state the default value creating the parser as this will be clearer
*/
template <typename T>
CLI::Option *add(const std::string &args, const std::string &help) {
- return add_option<T>(keyName(args),
- args,
- help,
- T(),
- /*defaulted =*/false,
- /*addToConfig =*/true);
+ return addOption<T>(keyName(args),
+ args,
+ help,
+ T(),
+ /*defaulted =*/false);
}
/**
- * @brief Define a non-defaulted option
+ * @brief Transform a command line option into an alias. This alias will set other options later.
*
- * The option will not be present in the config file unless given as a
- * command-line argument.
+ * An alias sets one or more options to predefined values. The options expanded by the alias are
+ * provided as a function setting a temporary YAML config.
*
- * @param args Comma-separated list of short and long option names
- * @param help Help message
+ * The alias option has to be first defined using `add<T>()`. Otherwise, the program will abort.
*
- * @return Option object
+ * Defining more than one alias for the same `key` but different `value` is allowed.
+ *
+ * Option values are compared as std::string. If the alias option is a vector, the alias will be
+ * triggered if `value` exists in that vector at least once.
+ *
+ * Options set directly via command line have precedence over options defined in an alias, i.e. an
+ * option added via alias can be overwritten by setting a specific option via command line.
*
- * @TODO: consider removing this method during final refactorization of
- * command-line/config parsers in the future as all options should either
- * have a default value or be non-defaulted
+ * @param key Alias option name
+ * @param value Option value that trigger the alias
+ * @param fun Function setting a temporary YAML config with options expanded by alias
*/
- template <typename T>
- CLI::Option *add_nondefault(const std::string &args, const std::string &help) {
- return add_option<T>(keyName(args),
- args,
- help,
- T(),
- /*defaulted =*/false,
- /*addToConfig =*/false);
+ void alias(const std::string &key,
+ const std::string &value,
+ const std::function<void(YAML::Node &config)> &fun) {
+ ABORT_IF(!options_.count(key), "Option '{}' is not defined so alias can not be created", key);
+ aliases_.resize(aliases_.size() + 1);
+ aliases_.back().key = key;
+ aliases_.back().value = value;
+ fun(aliases_.back().config);
}
/**
* Switch to different option group or to the default group if argument is empty.
*
* @param name Header of the option group
+ * @return Previous group.
*/
- void switchGroup(const std::string &name = "");
+ std::string switchGroup(std::string name = "");
// Parse command-line arguments. Handles --help and --version options
void parse(int argc, char **argv);
- /*
- * @brief Overwrite values for unparsed options
+ /**
+ * @brief Expand aliases based on arguments parsed with parse(int, char**)
+ *
+ * Should be called after parse(int, char**) to take an effect. If any alias tries to expand an
+ * undefined option, the method will abort the program.
+ *
+ * All options defined as aliases are removed from the global config object to avoid redundancy
+ * when options are dumped (explicitly or implicitly) to a config file.
+ */
+ void parseAliases();
+
+ /**
+ * @brief Overwrite options with lower priority
+ *
+ * Values for options with lower priority than the provided priority remain unchanged. This allows
+ * for overwritting default options by options from config files, or both by options provided in
+ * the command line.
*
- * Default values are overwritten with the options from the config provided, while parsed
- * command-line options remain unchanged.
* This should be a preferred way of updating config options as the class keeps track of options,
* which values have changed.
*
- * @param node YAML config with new default values for options
+ * @param config YAML config with new default values for options
+ * @param priority priority of incoming options
+ * @param errorMsg error message printed if config contains undefined keys. The message is
+ * appended with ": <comma-separated list of invalid options>"
*/
- bool updateConfig(const YAML::Node &config);
+ void updateConfig(const YAML::Node &config, cli::OptionPriority priority, const std::string &errorMsg);
// Get textual YAML representation of the config
- std::string dumpConfig(bool skipDefault = false) const;
+ std::string dumpConfig(bool skipUnmodified = false) const;
private:
- // Get names of options passed via command-line
- std::unordered_set<std::string> getParsedOptionNames() const;
- // Get option names in the same order as they are created
- std::vector<std::string> getOrderedOptionNames() const;
-
template <typename T,
// options with numeric and string-like values
CLI::enable_if_t<!CLI::is_bool<T>::value && !CLI::is_vector<T>::value,
CLI::detail::enabler> = CLI::detail::dummy>
- CLI::Option *add_option(const std::string &key,
- const std::string &args,
- const std::string &help,
- T val,
- bool defaulted,
- bool addToConfig) {
- // define YAML entry if requested
- if(addToConfig)
- config_[key] = val;
+ CLI::Option *addOption(const std::string &key,
+ const std::string &args,
+ const std::string &help,
+ T val,
+ bool defaulted) {
+ // add key to YAML
+ config_[key] = val;
// create option tuple
CLIOptionTuple option;
@@ -266,7 +259,7 @@ private:
// callback function collecting a command-line argument
CLI::callback_t fun = [this, key](CLI::results_t res) {
- options_[key].modified = true;
+ options_[key].priority = cli::OptionPriority::CommandLine;
// get variable associated with the option
auto &var = options_[key].var->as<T>();
// store parser result in var
@@ -298,15 +291,13 @@ private:
template <typename T,
// options with vector values
CLI::enable_if_t<CLI::is_vector<T>::value, CLI::detail::enabler> = CLI::detail::dummy>
- CLI::Option *add_option(const std::string &key,
- const std::string &args,
- const std::string &help,
- T val,
- bool defaulted,
- bool addToConfig) {
- // define YAML entry if requested
- if(addToConfig)
- config_[key] = val;
+ CLI::Option *addOption(const std::string &key,
+ const std::string &args,
+ const std::string &help,
+ T val,
+ bool defaulted) {
+ // add key to YAML
+ config_[key] = val;
// create option tuple
CLIOptionTuple option;
@@ -315,7 +306,7 @@ private:
// callback function collecting command-line arguments
CLI::callback_t fun = [this, key](CLI::results_t res) {
- options_[key].modified = true;
+ options_[key].priority = cli::OptionPriority::CommandLine;
// get vector variable associated with the option
auto &vec = options_[key].var->as<T>();
vec.clear();
@@ -357,15 +348,13 @@ private:
template <typename T,
// options with boolean values, called flags in CLI11
CLI::enable_if_t<CLI::is_bool<T>::value, CLI::detail::enabler> = CLI::detail::dummy>
- CLI::Option *add_option(const std::string &key,
- const std::string &args,
- const std::string &help,
- T val,
- bool defaulted,
- bool addToConfig) {
- // define YAML entry if requested
- if(addToConfig)
- config_[key] = val;
+ CLI::Option *addOption(const std::string &key,
+ const std::string &args,
+ const std::string &help,
+ T val,
+ bool defaulted) {
+ // add key to YAML
+ config_[key] = val;
// create option tuple
CLIOptionTuple option;
@@ -374,7 +363,7 @@ private:
// callback function setting the flag
CLI::callback_t fun = [this, key](CLI::results_t res) {
- options_[key].modified = true;
+ options_[key].priority = cli::OptionPriority::CommandLine;
// get parser result, it is safe as boolean options have an implicit value
auto val = res[0];
auto ret = true;
diff --git a/src/common/compile_time_crc32.h b/src/common/compile_time_crc32.h
index 78211388..78211388 100755..100644
--- a/src/common/compile_time_crc32.h
+++ b/src/common/compile_time_crc32.h
diff --git a/src/common/config.cpp b/src/common/config.cpp
index e5208b0d..a6ce44c4 100755..100644
--- a/src/common/config.cpp
+++ b/src/common/config.cpp
@@ -1,9 +1,11 @@
#include "common/config.h"
+#include "common/config_parser.h"
#include "common/file_stream.h"
#include "common/logging.h"
+#include "common/options.h"
+#include "common/regex.h"
#include "common/utils.h"
#include "common/version.h"
-#include "common/regex.h"
#include <algorithm>
#include <set>
@@ -14,35 +16,26 @@ namespace marian {
// @TODO: keep seed in a single place, now it is kept here and in Config/Options
size_t Config::seed = (size_t)time(0);
-Config::Config(int argc,
- char** argv,
- cli::mode mode /*= cli::mode::training*/,
- bool validate /*= true*/) {
- initialize(argc, argv, mode, validate);
+
+Config::Config(ConfigParser const& cp) {
+ initialize(cp);
}
+Config::Config(int argc, char** argv, cli::mode mode, bool validate /*= true*/)
+ : Config(ConfigParser(argc, argv, mode, validate)) {}
+
Config::Config(const Config& other) : config_(YAML::Clone(other.config_)) {}
-Config::Config(const Options& options) : config_(YAML::Clone(options.getYaml())) {}
+Config::Config(const Options& options) : config_(options.cloneToYamlNode()) {}
-void Config::initialize(int argc, char** argv, cli::mode mode, bool validate) {
- auto parser = ConfigParser(argc, argv, mode, validate);
- config_ = parser.getConfig();
+void Config::initialize(ConfigParser const& cp) {
+ config_ = YAML::Clone(cp.getConfig());
+ cli::mode mode = cp.getMode();
createLoggers(this);
// echo version and command line
LOG(info, "[marian] Marian {}", buildVersion());
- std::string cmdLine;
- for (int i = 0; i < argc; i++) {
- std::string arg = argv[i];
- std::string quote; // attempt to quote special chars
- if (arg.empty() || arg.find_first_of(" #`\"'\\${}|&^?*!()%><") != std::string::npos)
- quote = "'";
- arg = regex::regex_replace(arg, regex::regex("'"), "'\\''");
- if (!cmdLine.empty())
- cmdLine.push_back(' ');
- cmdLine += quote + arg + quote;
- }
+ std::string cmdLine = cp.cmdLine();
std::string hostname; int pid; std::tie
(hostname, pid) = utils::hostnameAndProcessId();
LOG(info, "[marian] Running on {} as process {} with command line:", hostname, pid);
@@ -56,7 +49,17 @@ void Config::initialize(int argc, char** argv, cli::mode mode, bool validate) {
}
// load model parameters
- if(mode != cli::mode::translation) {
+ if(mode == cli::mode::translation || mode == cli::mode::server) {
+ auto model = get<std::vector<std::string>>("models")[0];
+ try {
+ if(!get<bool>("ignore-model-config"))
+ loadModelParameters(model);
+ } catch(std::runtime_error& ) {
+ LOG(info, "[config] No model configuration found in model file");
+ }
+ }
+ // if cli::mode::training or cli::mode::scoring
+ else {
auto model = get<std::string>("model");
if(filesystem::exists(model) && !get<bool>("no-reload")) {
try {
@@ -67,16 +70,6 @@ void Config::initialize(int argc, char** argv, cli::mode mode, bool validate) {
}
}
}
- // if cli::mode::translation
- else {
- auto model = get<std::vector<std::string>>("models")[0];
- try {
- if(!get<bool>("ignore-model-config"))
- loadModelParameters(model);
- } catch(std::runtime_error& ) {
- LOG(info, "[config] No model configuration found in model file");
- }
- }
// echo full configuration
log();
@@ -95,15 +88,14 @@ void Config::initialize(int argc, char** argv, cli::mode mode, bool validate) {
version,
buildVersion());
else
- LOG(info,
- "[config] Loaded model has been created with Marian {}",
- version);
+ LOG(info, "[config] Loaded model has been created with Marian {}", version);
+
+ // Remove "version" from config to make it consistent among different start-up scenarios
+ config_.remove("version");
}
// If this is a newly started training
else if(mode == cli::mode::training) {
- LOG(info,
- "[config] Model is being created with Marian {}",
- buildVersion());
+ LOG(info, "[config] Model is being created with Marian {}", buildVersion());
}
}
@@ -156,10 +148,9 @@ void Config::log() {
std::string configString = out.c_str();
// print YAML prepending each line with [config]
- std::vector<std::string> results;
- utils::split(configString, results, "\n");
- for(auto& r : results)
- LOG(info, "[config] {}", r);
+ auto lines = utils::split(configString, "\n");
+ for(auto& line : lines)
+ LOG(info, "[config] {}", line);
}
// Parse the device-spec parameters (--num-devices, --devices, --cpu-threads) into an array of
@@ -264,14 +255,17 @@ std::vector<DeviceId> Config::getDevices(Ptr<Options> options,
return devices;
}
-Ptr<Options> parseOptions(int argc,
- char** argv,
- cli::mode mode /*= cli::mode::training*/,
- bool validate /*= true*/) {
- auto config = New<Config>(argc, argv, mode, validate);
- auto options = New<Options>();
- options->merge(config->get());
- return options;
+Ptr<Options>
+parseOptions(int argc, char** argv, cli::mode mode, bool validate){
+ ConfigParser cp(mode);
+ return cp.parseOptions(argc, argv, validate);
+}
+
+std::ostream& operator<<(std::ostream& out, const Config& config) {
+ YAML::Emitter outYaml;
+ cli::OutputYaml(config.get(), outYaml);
+ out << outYaml.c_str();
+ return out;
}
} // namespace marian
diff --git a/src/common/config.h b/src/common/config.h
index ec462a53..d4784af7 100755..100644
--- a/src/common/config.h
+++ b/src/common/config.h
@@ -38,6 +38,7 @@ public:
typedef YAML::Node YamlNode;
+ Config(ConfigParser const& cp);
// TODO: remove mode from this class
Config(int argc,
char** argv,
@@ -47,7 +48,7 @@ public:
Config(const Config& other);
Config(const Options& options);
- void initialize(int argc, char** argv, cli::mode mode, bool validate);
+ void initialize(ConfigParser const& cp);
bool has(const std::string& key) const;
@@ -83,12 +84,7 @@ public:
void save(const std::string& name);
- friend std::ostream& operator<<(std::ostream& out, const Config& config) {
- YAML::Emitter outYaml;
- cli::OutputYaml(config.get(), outYaml);
- out << outYaml.c_str();
- return out;
- }
+ friend std::ostream& operator<<(std::ostream& out, const Config& config);
static std::vector<DeviceId> getDevices(Ptr<Options> options,
size_t myMPIRank = 0,
@@ -115,7 +111,7 @@ private:
*/
Ptr<Options> parseOptions(int argc,
char** argv,
- cli::mode mode = cli::mode::training,
+ cli::mode mode,
bool validate = true);
} // namespace marian
diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index d39f3b1a..9c711eaa 100755
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -1,12 +1,15 @@
-#include "common/config_parser.h"
-
-#include "common/definitions.h"
+#include "common/authors.h"
+#include "common/build_info.h"
#include "common/cli_helper.h"
+#include "common/config.h"
+#include "common/config_parser.h"
#include "common/config_validator.h"
+#include "common/definitions.h"
#include "common/file_stream.h"
#include "common/logging.h"
+#include "common/options.h"
+#include "common/regex.h"
#include "common/utils.h"
-
#include <algorithm>
#include <set>
#include <stdexcept>
@@ -22,7 +25,8 @@
namespace marian {
-// TODO: move to CLIWrapper
+// TODO: Move this to CLIWrapper and allow to mark options as paths in the same place they are
+// defined
// clang-format off
const std::set<std::string> PATHS = {
"model",
@@ -32,6 +36,7 @@ const std::set<std::string> PATHS = {
"embedding-vectors",
"valid-sets",
"valid-script-path",
+ "valid-script-args",
"valid-log",
"valid-translation-output",
"input", // except: stdin
@@ -47,44 +52,111 @@ const std::set<std::string> PATHS = {
};
// clang-format on
+std::string escapeCmdLine(int argc, char** argv){
+ std::string cmdLine;
+ for(int i = 0; i < argc; i++) {
+ std::string arg = argv[i];
+ std::string quote; // attempt to quote special chars
+ if(arg.empty() || arg.find_first_of(" #`\"'\\${}|&^?*!()%><") != std::string::npos)
+ quote = "'";
+ arg = regex::regex_replace(arg, regex::regex("'"), "'\\''");
+ if(!cmdLine.empty())
+ cmdLine.push_back(' ');
+ cmdLine += quote + arg + quote;
+ }
+ return cmdLine;
+}
+
+std::string const& ConfigParser::cmdLine() const {
+ return cmdLine_;
+}
+
+ConfigParser::ConfigParser(cli::mode mode)
+ : cli_(config_,"Marian: Fast Neural Machine Translation in C++",
+ "General options", "", 40),
+ mode_(mode == cli::mode::server ? cli::mode::translation : mode) {
+
+ addOptionsGeneral(cli_);
+ if (mode == cli::mode::server)
+ addOptionsServer(cli_);
+ addOptionsModel(cli_);
+
+ // clang-format off
+ switch(mode_) {
+ case cli::mode::training:
+ addOptionsTraining(cli_);
+ addOptionsValidation(cli_);
+ break;
+ case cli::mode::translation:
+ addOptionsTranslation(cli_);
+ break;
+ case cli::mode::scoring:
+ addOptionsScoring(cli_);
+ break;
+ default:
+ ABORT("wrong CLI mode");
+ break;
+ }
+
+ addAliases(cli_);
+ // clang-format on
+}
+
void ConfigParser::addOptionsGeneral(cli::CLIWrapper& cli) {
int defaultWorkspace = (mode_ == cli::mode::translation) ? 512 : 2048;
cli.switchGroup("General options");
// clang-format off
+ cli.add<bool>("--authors",
+ "Print list of authors and exit");
+ cli.add<bool>("--cite",
+ "Print citation and exit");
+ cli.add<std::string>("--build-info",
+ "Print CMake build options and exit. Set to 'all' to print advanced options")
+ ->implicit_val("basic");
cli.add<std::vector<std::string>>("--config,-c",
- "Configuration file(s). If multiple, later overrides earlier");
+ "Configuration file(s). If multiple, later overrides earlier");
cli.add<size_t>("--workspace,-w",
- "Preallocate arg MB of work space",
- defaultWorkspace);
- cli.add_nondefault<std::string>("--log",
- "Log training process information to file given by arg");
+ "Preallocate arg MB of work space",
+ defaultWorkspace);
+ cli.add<std::string>("--log",
+ "Log training process information to file given by arg");
cli.add<std::string>("--log-level",
- "Set verbosity level of logging: trace, debug, info, warn, err(or), critical, off",
- "info");
- cli.add_nondefault<std::string>("--log-time-zone",
- "Set time zone for the date shown on logging");
+ "Set verbosity level of logging: trace, debug, info, warn, err(or), critical, off",
+ "info");
+ cli.add<std::string>("--log-time-zone",
+ "Set time zone for the date shown on logging");
cli.add<bool>("--quiet",
- "Suppress all logging to stderr. Logging to files still works");
+ "Suppress all logging to stderr. Logging to files still works");
cli.add<bool>("--quiet-translation",
- "Suppress logging for translation");
+ "Suppress logging for translation");
cli.add<size_t>("--seed",
- "Seed for all random number generators. 0 means initialize randomly");
+ "Seed for all random number generators. 0 means initialize randomly");
cli.add<float>("--clip-gemm",
- "If not 0 clip GEMM input values to +/- arg");
+ "If not 0 clip GEMM input values to +/- arg");
cli.add<bool>("--interpolate-env-vars",
- "allow the use of environment variables in paths, of the form ${VAR_NAME}");
+ "allow the use of environment variables in paths, of the form ${VAR_NAME}");
cli.add<bool>("--relative-paths",
- "All paths are relative to the config file location");
- cli.add_nondefault<std::string>("--dump-config",
- "Dump current (modified) configuration to stdout and exit. Possible values: full, minimal")
+ "All paths are relative to the config file location");
+ cli.add<std::string>("--dump-config",
+ "Dump current (modified) configuration to stdout and exit. Possible values: full, minimal, expand")
->implicit_val("full");
// clang-format on
}
+void ConfigParser::addOptionsServer(cli::CLIWrapper& cli) {
+ // clang-format off
+ auto previous_group = cli.switchGroup("Server options");
+ cli.add<size_t>("--port,-p",
+ "Port number for web socket server",
+ 8080);
+ cli.switchGroup(previous_group);
+ // clang-format on
+}
+
void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
- cli.switchGroup("Model options");
+ auto previous_group = cli.switchGroup("Model options");
// clang-format off
if(mode_ == cli::mode::translation) {
@@ -96,7 +168,7 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
"model.npz");
if(mode_ == cli::mode::training) {
- cli.add_nondefault<std::string>("--pretrained-model",
+ cli.add<std::string>("--pretrained-model",
"Path prefix for pre-trained model to initialize model weights");
}
}
@@ -108,10 +180,13 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
"amun");
cli.add<std::vector<int>>("--dim-vocabs",
"Maximum items in vocabulary ordered by rank, 0 uses all items in the provided/created vocabulary file",
- std::vector<int>({0, 0}));
+ {0, 0});
cli.add<int>("--dim-emb",
"Size of embedding vector",
512);
+ cli.add<int>("--lemma-dim-emb",
+ "Re-embedding dimension of lemma in factors",
+ 0);
cli.add<int>("--dim-rnn",
"Size of rnn hidden state", 1024);
cli.add<std::string>("--enc-type",
@@ -143,10 +218,12 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
"Enable layer normalization");
cli.add<bool>("--right-left",
"Train right-to-left model");
+ cli.add<std::vector<std::string>>("--input-types",
+ "Provide type of input data if different than 'sequence'. "
+ "Possible values: sequence, class. You need to provide one type per input.",
+ {});
cli.add<bool>("--best-deep",
"Use Edinburgh deep RNN configuration (s2s)");
- cli.add_nondefault<std::vector<size_t>>("--special-vocab",
- "Model-specific special vocabulary ids");
cli.add<bool>("--tied-embeddings",
"Tie target embeddings and output embeddings in output layer");
cli.add<bool>("--tied-embeddings-src",
@@ -196,7 +273,17 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
cli.add<std::string>("--transformer-postprocess",
"Operation after each transformer layer: d = dropout, a = add, n = normalize",
"dan");
-
+ cli.add<bool>("--transformer-train-position-embeddings",
+ "Train positional embeddings instead of using static sinusoidal embeddings");
+ cli.add<bool>("--transformer-depth-scaling",
+ "Scale down weight initialization in transformer layers by 1 / sqrt(depth)");
+
+ cli.add<std::string>("--bert-mask-symbol", "Masking symbol for BERT masked-LM training", "[MASK]");
+ cli.add<std::string>("--bert-sep-symbol", "Sentence separator symbol for BERT next sentence prediction training", "[SEP]");
+ cli.add<std::string>("--bert-class-symbol", "Class symbol BERT classifier training", "[CLS]");
+ cli.add<float>("--bert-masking-fraction", "Fraction of masked out tokens during training", 0.15f);
+ cli.add<bool>("--bert-train-type-embeddings", "Train bert type embeddings, set to false to use static sinusoidal embeddings", true);
+ cli.add<int>("--bert-type-vocab-size", "Size of BERT type vocab (sentence A and B)", 2);
#ifdef CUDNN
cli.add<int>("--char-stride",
"Width of max-pooling layer after convolution layer in char-s2s model",
@@ -205,11 +292,11 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
"Number of highway network layers after max-pooling in char-s2s model",
4);
cli.add<std::vector<int>>("--char-conv-filters-num",
- "Numbers of convolution filters of correspoding width in char-s2s model",
- std::vector<int>({200, 200, 250, 250, 300, 300, 300, 300}));
+ "Numbers of convolution filters of corresponding width in char-s2s model",
+ {200, 200, 250, 250, 300, 300, 300, 300});
cli.add<std::vector<int>>("--char-conv-filters-widths",
"Convolution window widths in char-s2s model",
- std::vector<int>({1, 2, 3, 4, 5, 6, 7, 8}));
+ {1, 2, 3, 4, 5, 6, 7, 8});
#endif
if(mode_ == cli::mode::training) {
@@ -234,14 +321,19 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
cli.add<float>("--transformer-dropout-ffn",
"Dropout for transformer filter (0 = no dropout)");
}
+ cli.switchGroup(previous_group);
// clang-format on
}
void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
- cli.switchGroup("Training options");
+ auto previous_group = cli.switchGroup("Training options");
// clang-format off
- cli.add<std::string>("--cost-type",
+ cli.add<std::string>("--cost-type", // @TODO: rename to loss-type
"Optimization criterion: ce-mean, ce-mean-words, ce-sum, perplexity", "ce-mean");
+ cli.add<std::string>("--multi-loss-type",
+ "How to accumulate multi-objective losses: sum, scaled, mean", "sum");
+ cli.add<bool>("--unlikelihood-loss",
+ "Use word-level weights as indicators for sequence-level unlikelihood training");
cli.add<bool>("--overwrite",
"Do not create model checkpoints, only overwrite main model file with last checkpoint. "
"Reduces disk usage");
@@ -273,9 +365,11 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
"Display information every arg updates (append 't' for every arg target labels)",
"1000u");
cli.add<size_t>("--disp-first",
- "Display nformation for the first arg updates");
+ "Display information for the first arg updates");
cli.add<bool>("--disp-label-counts",
"Display label counts when logging loss progress");
+// cli.add<int>("--disp-label-index",
+// "Display label counts based on i-th input stream (-1 is last)", -1);
cli.add<std::string/*SchedulerPeriod*/>("--save-freq",
"Save model file every arg updates (append 't' for every arg target labels)",
"10000u");
@@ -283,8 +377,12 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
addSuboptionsInputLength(cli);
// data management options
+ cli.add<std::string>("--shuffle",
+ "How to shuffle input data (data: shuffles data and sorted batches; batches: "
+ "data is read in order into batches, but batches are shuffled; none: no shuffling). "
+ "Use with '--maxi-batch-sort none' in order to achieve exact reading order", "data");
cli.add<bool>("--no-shuffle",
- "Skip shuffling of training data before each epoch");
+ "Shortcut for backwards compatiblity, equivalent to --shuffle none (deprecated)");
cli.add<bool>("--no-restore-corpus",
"Skip restoring corpus state after training is restarted");
cli.add<std::string>("--tempdir,-T",
@@ -304,30 +402,33 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
cli.add<std::string>("--optimizer,-o",
"Optimization algorithm: sgd, adagrad, adam",
"adam");
- cli.add_nondefault<std::vector<float>>("--optimizer-params",
- "Parameters for optimization algorithm, e.g. betas for adam");
- cli.add<size_t>("--optimizer-delay",
- "SGD update delay, 1 = no delay",
- 1);
+ cli.add<std::vector<float>>("--optimizer-params",
+ "Parameters for optimization algorithm, e.g. betas for Adam. "
+ "Auto-adjusted to --mini-batch-words-ref if given");
+ cli.add<float>("--optimizer-delay",
+ "SGD update delay (#batches between updates). 1 = no delay. "
+ "Can be fractional, e.g. 0.1 to use only 10% of each batch",
+ 1.f);
cli.add<bool>("--sync-sgd",
"Use synchronous SGD instead of asynchronous for multi-gpu training");
// learning rate options
- cli.add<double>("--learn-rate,-l",
- "Learning rate",
- 0.0001);
+ cli.add<float>("--learn-rate,-l",
+ "Learning rate. "
+ "Auto-adjusted to --mini-batch-words-ref if given",
+ 0.0001f);
cli.add<bool>("--lr-report",
"Report learning rate for each update");
- cli.add<double>("--lr-decay",
+ cli.add<float>("--lr-decay",
"Per-update decay factor for learning rate: lr <- lr * arg (0 to disable)");
cli.add<std::string>("--lr-decay-strategy",
"Strategy for learning rate decaying: epoch, batches, stalled, epoch+batches, epoch+stalled",
"epoch+stalled");
cli.add<std::vector<size_t>>("--lr-decay-start",
"The first number of (epoch, batches, stalled) validations to start learning rate decaying (tuple)",
- std::vector<size_t>({10,1}));
+ {10,1});
cli.add<size_t>("--lr-decay-freq",
"Learning rate decaying frequency for batches, requires --lr-decay-strategy to be batches",
50000);
@@ -335,9 +436,10 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
"Reset running statistics of optimizer whenever learning rate decays");
cli.add<bool>("--lr-decay-repeat-warmup",
"Repeat learning rate warmup when learning rate is decayed");
- cli.add<std::string/*SchedulerPeriod*/>("--lr-decay-inv-sqrt",
- "Decrease learning rate at arg / sqrt(no. batches) starting at arg (append 't' or 'e' for sqrt(target labels or epochs))",
- "0");
+ cli.add<std::vector<std::string/*SchedulerPeriod*/>>("--lr-decay-inv-sqrt",
+ "Decrease learning rate at arg / sqrt(no. batches) starting at arg (append 't' or 'e' for sqrt(target labels or epochs)). "
+ "Add second argument to define the starting point (default: same as first value)",
+ {"0"});
cli.add<std::string/*SchedulerPeriod*/>("--lr-warmup",
"Increase learning rate linearly for arg first batches (append 't' for arg first target labels)",
@@ -351,12 +453,15 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
cli.add<double>("--label-smoothing",
"Epsilon for label smoothing (0 to disable)");
- cli.add<double>("--clip-norm",
- "Clip gradient norm to argcli.add<int>(0 to disable)",
- 1.f);
+ cli.add<double>("--factor-weight",
+ "Weight for loss function for factors (factored vocab only) (1 to disable)", 1.0f);
+ cli.add<float>("--clip-norm",
+ "Clip gradient norm to arg (0 to disable)",
+ 1.f); // @TODO: this is currently wrong with ce-sum and should rather be disabled or fixed by multiplying with labels
cli.add<float>("--exponential-smoothing",
- "Maintain smoothed version of parameters for validation and saving with smoothing factor. 0 to disable",
- 0)->implicit_val("1e-4");
+ "Maintain smoothed version of parameters for validation and saving with smoothing factor. 0 to disable. "
+ "Auto-adjusted to --mini-batch-words-ref if given.",
+ 0.f)->implicit_val("1e-4");
cli.add<std::string>("--guided-alignment",
"Path to a file with word alignments. Use guided alignment to guide attention or 'none'",
"none");
@@ -366,14 +471,14 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
cli.add<double>("--guided-alignment-weight",
"Weight for guided alignment cost",
0.1);
- cli.add_nondefault<std::string>("--data-weighting",
+ cli.add<std::string>("--data-weighting",
"Path to a file with sentence or word weights");
cli.add<std::string>("--data-weighting-type",
"Processing level for data weighting: sentence, word",
"sentence");
// embedding options
- cli.add_nondefault<std::vector<std::string>>("--embedding-vectors",
+ cli.add<std::vector<std::string>>("--embedding-vectors",
"Paths to files with custom source and target embedding vectors");
cli.add<bool>("--embedding-normalization",
"Normalize values from custom embedding vectors to [-1, 1]");
@@ -382,29 +487,50 @@ void ConfigParser::addOptionsTraining(cli::CLIWrapper& cli) {
cli.add<bool>("--embedding-fix-trg",
"Fix target embeddings. Affects all decoders");
+ // mixed precision training
+ cli.add<bool>("--fp16",
+ "Shortcut for mixed precision training with float16 and cost-scaling, "
+ "corresponds to: --precision float16 float32 float32 --cost-scaling 7 2000 2 0.05 10 1");
+ cli.add<std::vector<std::string>>("--precision",
+ "Mixed precision training for forward/backward pass and optimizaton. "
+ "Defines types for: forward/backward, optimization, saving.",
+ {"float32", "float32", "float32"});
+ cli.add<std::vector<std::string>>("--cost-scaling",
+ "Dynamic cost scaling for mixed precision training: "
+ "power of 2, scaling window, scaling factor, tolerance, range, minimum factor")->implicit_val("7.f 2000 2.f 0.05f 10 1.f");
+ cli.add<bool>("--normalize-gradient", "Normalize gradient by multiplying with no. devices / total labels");
+
+ // multi-node training
cli.add<bool>("--multi-node",
"Enable asynchronous multi-node training through MPI (and legacy sync if combined with --sync-sgd)");
cli.add<bool>("--multi-node-overlap",
"Overlap model computations with MPI communication",
true);
+
// add ULR settings
addSuboptionsULR(cli);
+
+ cli.add<std::vector<std::string>>("--task",
+ "Use predefined set of options. Possible values: transformer, transformer-big");
+ cli.switchGroup(previous_group);
// clang-format on
}
void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
- cli.switchGroup("Validation set options");
+ auto previous_group = cli.switchGroup("Validation set options");
// clang-format off
- cli.add_nondefault<std::vector<std::string>>("--valid-sets",
+ cli.add<std::vector<std::string>>("--valid-sets",
"Paths to validation corpora: source target");
cli.add<std::string/*SchedulerPeriod*/>("--valid-freq",
"Validate model every arg updates (append 't' for every arg target labels)",
"10000u");
cli.add<std::vector<std::string>>("--valid-metrics",
"Metric to use during validation: cross-entropy, ce-mean-words, perplexity, valid-script, "
- " translation, bleu, bleu-detok. Multiple metrics can be specified",
- std::vector<std::string>({"cross-entropy"}));
+ "translation, bleu, bleu-detok. Multiple metrics can be specified",
+ {"cross-entropy"});
+ cli.add<bool>("--valid-reset-stalled",
+ "Reset all stalled validation metrics when the training is restarted");
cli.add<size_t>("--early-stopping",
"Stop if the first validation metric does not improve for arg consecutive validation steps",
10);
@@ -425,38 +551,47 @@ void ConfigParser::addOptionsValidation(cli::CLIWrapper& cli) {
"Allow unknown words to appear in output");
cli.add<bool>("--n-best",
"Generate n-best list");
+ cli.add<bool>("--word-scores",
+ "Print word-level scores");
// efficiency options
cli.add<int>("--valid-mini-batch",
"Size of mini-batch used during validation",
32);
cli.add<size_t>("--valid-max-length",
- "Maximum length of a sentence in a validating sentence pair",
+ "Maximum length of a sentence in a validating sentence pair. "
+ "Sentences longer than valid-max-length are cropped to valid-max-length",
1000);
// options for validation script
- cli.add_nondefault<std::string>("--valid-script-path",
+ cli.add<std::string>("--valid-script-path",
"Path to external validation script."
" It should print a single score to stdout."
" If the option is used with validating translation, the output"
" translation file will be passed as a first argument");
- cli.add_nondefault<std::string>("--valid-translation-output",
- "Path to store the translation");
-
+ cli.add<std::vector<std::string>>("--valid-script-args",
+ "Additional args passed to --valid-script-path. These are inserted"
+ " between the script path and the output translation-file path");
+ cli.add<std::string>("--valid-translation-output",
+ "(Template for) path to store the translation. "
+ "E.g., validation-output-after-{U}-updates-{T}-tokens.txt. Template "
+ "parameters: {E} for epoch; {B} for No. of batches within epoch; "
+ "{U} for total No. of updates; {T} for total No. of tokens seen.");
cli.add<bool>("--keep-best",
"Keep best model for each validation metric");
- cli.add_nondefault<std::string>("--valid-log",
+ cli.add<std::string>("--valid-log",
"Log validation scores to file given by arg");
+ cli.switchGroup(previous_group);
// clang-format on
}
void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
- cli.switchGroup("Translator options");
+ auto previous_group = cli.switchGroup("Translator options");
// clang-format off
cli.add<std::vector<std::string>>("--input,-i",
"Paths to input file(s), stdin by default",
- std::vector<std::string>({"stdin"}));
+ {"stdin"});
cli.add<std::string>("--output,-o",
"Path to output file, stdout by default",
"stdout");
@@ -478,9 +613,15 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
"Allow unknown words to appear in output");
cli.add<bool>("--n-best",
"Generate n-best list");
- cli.add_nondefault<std::string>("--alignment",
+ cli.add<std::string>("--alignment",
"Return word alignment. Possible values: 0.0-1.0, hard, soft")
->implicit_val("1");
+ cli.add<bool>("--word-scores",
+ "Print word-level scores");
+#ifdef USE_SENTENCEPIECE
+ cli.add<bool>("--no-spm-decode",
+ "Keep the output segmented into SentencePiece subwords");
+#endif
addSuboptionsDevices(cli);
addSuboptionsInputLength(cli);
@@ -490,26 +631,31 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
"Optimize speed aggressively sacrificing memory or precision");
cli.add<bool>("--skip-cost",
"Ignore model cost during translation, not recommended for beam-size > 1");
+ cli.add<bool>("--fp16",
+ "Shortcut for mixed precision inference with float16, corresponds to: --precision float16");
+ cli.add<std::vector<std::string>>("--precision",
+ "Mixed precision for inference, set parameter type in expression graph",
+ {"float32"});
- cli.add_nondefault<std::vector<std::string>>("--shortlist",
+ cli.add<std::vector<std::string>>("--shortlist",
"Use softmax shortlist: path first best prune");
- cli.add_nondefault<std::vector<float>>("--weights",
+ cli.add<std::vector<float>>("--weights",
"Scorer weights");
cli.add<bool>("--output-sampling",
- "Noise output layer with gumbel noise",
- false);
+ "Noise output layer with gumbel noise",
+ false);
- // TODO: the options should be available only in server
- cli.add_nondefault<size_t>("--port,-p",
- "Port number for web socket server");
+#if 0 // @TODO: Ask Hany if there are any decoding-time options
// add ULR settings
addSuboptionsULR(cli);
+#endif
+ cli.switchGroup(previous_group);
// clang-format on
}
void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) {
- cli.switchGroup("Scorer options");
+ auto previous_group = cli.switchGroup("Scorer options");
// clang-format off
cli.add<bool>("--no-reload",
@@ -521,19 +667,19 @@ void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) {
"Path to output file, stdout by default",
"stdout");
cli.add<std::vector<std::string>>("--vocabs,-v",
- "Paths to vocabulary files have to correspond to --train-sets."
- " If this parameter is not supplied we look for vocabulary files source.{yml,json} and target.{yml,json}."
- " If these files do not exists they are created");
+ "Paths to vocabulary files have to correspond to --train-sets. "
+ "If this parameter is not supplied we look for vocabulary files source.{yml,json} and target.{yml,json}. "
+ "If these files do not exists they are created");
cli.add<bool>("--n-best",
"Score n-best list instead of plain text corpus");
cli.add<std::string>("--n-best-feature",
"Feature name to be inserted into n-best list", "Score");
cli.add<bool>("--normalize,-n",
"Divide translation score by translation length");
- cli.add_nondefault<std::string>("--summary",
+ cli.add<std::string>("--summary",
"Only print total cost, possible values: cross-entropy (ce-mean), ce-mean-words, ce-sum, perplexity")
->implicit_val("cross-entropy");
- cli.add_nondefault<std::string>("--alignment",
+ cli.add<std::string>("--alignment",
"Return word alignments. Possible values: 0.0-1.0, hard, soft")
->implicit_val("1"),
@@ -543,6 +689,13 @@ void ConfigParser::addOptionsScoring(cli::CLIWrapper& cli) {
cli.add<bool>("--optimize",
"Optimize speed aggressively sacrificing memory or precision");
+ cli.add<bool>("--fp16",
+ "Shortcut for mixed precision inference with float16, corresponds to: --precision float16");
+ cli.add<std::vector<std::string>>("--precision",
+ "Mixed precision for inference, set parameter type in expression graph",
+ {"float32"});
+
+ cli.switchGroup(previous_group);
// clang-format on
}
@@ -550,8 +703,8 @@ void ConfigParser::addSuboptionsDevices(cli::CLIWrapper& cli) {
// clang-format off
cli.add<std::vector<std::string>>("--devices,-d",
"Specifies GPU ID(s) to use for training. Defaults to 0..num-devices-1",
- std::vector<std::string>({"0"}));
- cli.add_nondefault<size_t>("--num-devices",
+ {"0"});
+ cli.add<size_t>("--num-devices",
"Number of GPUs to use for this process. Defaults to length(devices) or 1");
#ifdef USE_NCCL
if(mode_ == cli::mode::training)
@@ -560,12 +713,13 @@ void ConfigParser::addSuboptionsDevices(cli::CLIWrapper& cli) {
#endif
#ifdef CUDA_FOUND
cli.add<size_t>("--cpu-threads",
- "Use CPU-based computation with this many independent threads, 0 means GPU-based computation")
- ->default_val("0")->implicit_val("1");
+ "Use CPU-based computation with this many independent threads, 0 means GPU-based computation",
+ 0)
+ ->implicit_val("1");
#else
cli.add<size_t>("--cpu-threads",
- "Use CPU-based computation with this many independent threads, 0 means GPU-based computation")
- ->default_val("1");
+ "Use CPU-based computation with this many independent threads, 0 means GPU-based computation",
+ 1);
#endif
// clang-format on
}
@@ -593,6 +747,8 @@ void ConfigParser::addSuboptionsBatching(cli::CLIWrapper& cli) {
cli.add<size_t>("--mini-batch-fit-step",
"Step size for mini-batch-fit statistics",
10);
+ cli.add<bool>("--gradient-checkpointing",
+ "Enable gradient-checkpointing to minimize memory usage");
}
cli.add<int>("--maxi-batch",
@@ -602,8 +758,25 @@ void ConfigParser::addSuboptionsBatching(cli::CLIWrapper& cli) {
"Sorting strategy for maxi-batch: none, src, trg (not available for decoder)",
defaultMaxiBatchSort);
- cli.add<bool>("--shuffle-in-ram",
- "Keep shuffled corpus in RAM, do not write to temp file");
+ if(mode_ == cli::mode::training) {
+ cli.add<bool>("--shuffle-in-ram",
+ "Keep shuffled corpus in RAM, do not write to temp file");
+ // @TODO: Consider making the next two options options of the vocab instead, to make it more local in scope.
+ cli.add<size_t>("--all-caps-every",
+ "When forming minibatches, preprocess every Nth line on the fly to all-caps. Assumes UTF-8");
+ cli.add<size_t>("--english-title-case-every",
+ "When forming minibatches, preprocess every Nth line on the fly to title-case. Assumes English (ASCII only)");
+
+ cli.add<int>("--mini-batch-words-ref",
+ "If given, the following hyper parameters are adjusted as-if we had this mini-batch size: "
+ "--learn-rate, --optimizer-params, --exponential-smoothing, --mini-batch-warmup");
+ cli.add<std::string/*SchedulerPeriod*/>("--mini-batch-warmup",
+ "Linear ramp-up of MB size, up to this #updates (append 't' for up to this #target labels). "
+ "Auto-adjusted to --mini-batch-words-ref if given",
+ {"0"});
+ cli.add<bool>("--mini-batch-track-lr",
+ "Dynamically track mini-batch size inverse to actual learning rate (not considering lr-warmup)");
+ }
// clang-format on
}
@@ -614,7 +787,7 @@ void ConfigParser::addSuboptionsInputLength(cli::CLIWrapper& cli) {
"Maximum length of a sentence in a training sentence pair",
defaultMaxLength);
cli.add<bool>("--max-length-crop",
- "Crop a sentence to max-length instead of ommitting it if longer than max-length");
+ "Crop a sentence to max-length instead of omitting it if longer than max-length");
// clang-format on
}
@@ -622,8 +795,7 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
// clang-format off
// support for universal encoder ULR https://arxiv.org/pdf/1802.05368.pdf
cli.add<bool>("--ulr",
- "Enable ULR (Universal Language Representation)",
- false);
+ "Enable ULR (Universal Language Representation)");
// reading pre-trained universal embeddings for multi-sources.
// Note that source and target here is relative to ULR not the translation langs
// queries: EQ in Fig2 : is the unified embeddings projected to one space.
@@ -636,8 +808,7 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
"Path to file with universal sources embeddings of traget keys from projection into universal space",
"");
cli.add<bool>("--ulr-trainable-transformation",
- "Make Query Transformation Matrix A trainable",
- false);
+ "Make Query Transformation Matrix A trainable");
cli.add<int>("--ulr-dim-emb",
"ULR monolingual embeddings dimension");
cli.add<float>("--ulr-dropout",
@@ -649,62 +820,41 @@ void ConfigParser::addSuboptionsULR(cli::CLIWrapper& cli) {
// clang-format on
}
-void ConfigParser::expandAliases(cli::CLIWrapper& cli) {
- YAML::Node config;
- // The order of aliases does matter as later options overwrite earlier
-
- if(config_["best-deep"].as<bool>()) {
- config["layer-normalization"] = true;
- config["tied-embeddings"] = true;
- config["enc-type"] = "alternating";
- config["enc-cell-depth"] = 2;
- config["enc-depth"] = 4;
- config["dec-cell-base-depth"] = 4;
- config["dec-cell-high-depth"] = 2;
- config["dec-depth"] = 4;
- config["skip"] = true;
- }
- if(config) {
- auto success = cli.updateConfig(config);
- ABORT_IF(!success, "Unknown option(s) in aliases, check if aliases consist of correct options");
- }
-}
+cli::mode ConfigParser::getMode() const { return mode_; }
-void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
- cli::CLIWrapper cli(config_,
- "Marian: Fast Neural Machine Translation in C++",
- "General options",
- "",
- 40);
+Ptr<Options> ConfigParser::parseOptions(int argc, char** argv, bool doValidate){
+ cmdLine_ = escapeCmdLine(argc,argv);
- addOptionsGeneral(cli);
- addOptionsModel(cli);
+ // parse command-line options and fill wrapped YAML config
+ cli_.parse(argc, argv);
- // clang-format off
- switch(mode_) {
- case cli::mode::training:
- addOptionsTraining(cli);
- addOptionsValidation(cli);
- break;
- case cli::mode::translation:
- addOptionsTranslation(cli);
- break;
- case cli::mode::scoring:
- addOptionsScoring(cli);
- break;
+ if(get<bool>("authors")) {
+ std::cerr << authors() << std::endl;
+ exit(0);
}
- // clang-format on
- // parse command-line options and fill wrapped YAML config
- cli.parse(argc, argv);
+ if(get<bool>("cite")) {
+ std::cerr << citation() << std::endl;
+ exit(0);
+ }
+
+ auto buildInfo = get<std::string>("build-info");
+ if(!buildInfo.empty() && buildInfo != "false") {
+ if(buildInfo == "all")
+ std::cerr << cmakeBuildOptionsAdvanced() << std::endl;
+ else
+ std::cerr << cmakeBuildOptions() << std::endl;
+ exit(0);
+ }
// get paths to extra config files
auto configPaths = findConfigPaths();
if(!configPaths.empty()) {
auto config = loadConfigFiles(configPaths);
- auto success = cli.updateConfig(config);
- ABORT_IF(!success, "There are option(s) in a config file that are not expected");
+ cli_.updateConfig(config,
+ cli::OptionPriority::ConfigFile,
+ "There are option(s) in a config file that are not expected");
}
if(get<bool>("interpolate-env-vars")) {
@@ -712,21 +862,29 @@ void ConfigParser::parseOptions(int argc, char** argv, bool doValidate) {
}
if(doValidate) {
- // this aborts the program on first validation error
ConfigValidator(config_).validateOptions(mode_);
}
// remove extra config files from the config to avoid redundancy
config_.remove("config");
- if(has("dump-config") && get<std::string>("dump-config") != "false") {
- bool skipDefault = get<std::string>("dump-config") == "minimal";
+ if(!get<std::string>("dump-config").empty() && get<std::string>("dump-config") != "false") {
+ auto dumpMode = get<std::string>("dump-config");
config_.remove("dump-config");
- std::cout << cli.dumpConfig(skipDefault) << std::endl;
+
+ if(dumpMode == "expand") {
+ cli_.parseAliases();
+ }
+
+ bool minimal = (dumpMode == "minimal" || dumpMode == "expand");
+ std::cout << cli_.dumpConfig(minimal) << std::endl;
exit(0);
}
- expandAliases(cli);
+ cli_.parseAliases();
+ auto opts = New<Options>();
+ opts->merge(Config(*this).get());
+ return opts;
}
std::vector<std::string> ConfigParser::findConfigPaths() {
@@ -760,7 +918,8 @@ YAML::Node ConfigParser::loadConfigFiles(const std::vector<std::string>& paths)
for(auto& path : paths) {
// load single config file
- YAML::Node config = YAML::Load(io::InputFileStream(path));
+ io::InputFileStream strm(path);
+ YAML::Node config = YAML::Load(strm);
// expand relative paths if requested
if(config["relative-paths"] && config["relative-paths"].as<bool>()) {
@@ -787,7 +946,7 @@ YAML::Node ConfigParser::loadConfigFiles(const std::vector<std::string>& paths)
return configAll;
}
-YAML::Node ConfigParser::getConfig() const {
+const YAML::Node& ConfigParser::getConfig() const {
return config_;
}
} // namespace marian
diff --git a/src/common/config_parser.h b/src/common/config_parser.h
index de1cb70e..652a1d24 100755..100644
--- a/src/common/config_parser.h
+++ b/src/common/config_parser.h
@@ -14,21 +14,72 @@
namespace marian {
namespace cli {
-enum struct mode { training, translation, scoring };
+enum struct mode { training, translation, scoring, server };
} // namespace cli
/**
* @brief Command-line options parser
*
* New options and aliases should be defined within `addOptions*` methods.
+ * ... unless they are specific to certain executables.
+ * In that case, use a pattern like this (e.g., for a server):
+ * int main(int argc, char* argv[]) {
+ * ConfigParser cp(cli::mode::translation);
+ * cp.addOption<int>("--port", // option name
+ * "Server Options", // option group name
+ * "Port for server.", // help string
+ * 5678); // default value
+ * auto opts = cp.parseOptions(argc,argv,true); // 'true' for validation
+ * ...
+ *
+ *
*/
class ConfigParser {
public:
+
+ ConfigParser(cli::mode mode);
+
ConfigParser(int argc, char** argv, cli::mode mode, bool validate = false)
- : mode_(mode) {
+ : ConfigParser(mode) {
parseOptions(argc, argv, validate);
}
+ template<typename T>
+ ConfigParser&
+ addOption(const std::string& args,
+ const std::string& group,
+ const std::string& help,
+ const T val) {
+ std::string previous_group = cli_.switchGroup(group);
+ cli_.add<T>(args,help,val);
+ cli_.switchGroup(previous_group);
+ return *this;
+ }
+
+ template<typename T>
+ ConfigParser&
+ addOption(const std::string& args,
+ const std::string& group,
+ const std::string& help,
+ const T val,
+ const T implicit_val) {
+ std::string previous_group = cli_.switchGroup(group);
+ cli_.add<T>(args,help,val)->implicit_val(implicit_val);
+ cli_.switchGroup(previous_group);
+ return *this;
+ }
+
+ template<typename T>
+ ConfigParser&
+ addOption(const std::string& args,
+ const std::string& group,
+ const std::string& help) {
+ std::string previous_group = cli_.switchGroup(group);
+ cli_.add<T>(args,help);
+ cli_.switchGroup(previous_group);
+ return *this;
+ }
+
/**
* @brief Parse command-line options
*
@@ -45,14 +96,18 @@ public:
* @param argc
* @param argv
* @param validate Do or do not validate parsed options
+ * @return (YAML::Node const&)config_
*/
- void parseOptions(int argc, char** argv, bool validate);
-
- YAML::Node getConfig() const;
+ Ptr<Options> parseOptions(int argc, char** argv, bool validate);
+ YAML::Node const& getConfig() const;
+ cli::mode getMode() const;
+ std::string const& cmdLine() const;
private:
+ cli::CLIWrapper cli_;
cli::mode mode_;
YAML::Node config_;
+ std::string cmdLine_;
// Check if the config contains value for option key
bool has(const std::string& key) const {
@@ -68,17 +123,19 @@ private:
}
void addOptionsGeneral(cli::CLIWrapper&);
+ void addOptionsServer(cli::CLIWrapper&);
void addOptionsModel(cli::CLIWrapper&);
void addOptionsTraining(cli::CLIWrapper&);
void addOptionsValidation(cli::CLIWrapper&);
void addOptionsTranslation(cli::CLIWrapper&);
void addOptionsScoring(cli::CLIWrapper&);
+ void addAliases(cli::CLIWrapper&);
+
void addSuboptionsDevices(cli::CLIWrapper&);
void addSuboptionsBatching(cli::CLIWrapper&);
void addSuboptionsInputLength(cli::CLIWrapper&);
void addSuboptionsULR(cli::CLIWrapper&);
- void expandAliases(cli::CLIWrapper&);
// Extract paths to all config files found in the config object.
// Look at --config option and model.npz.yml files.
diff --git a/src/common/config_validator.cpp b/src/common/config_validator.cpp
index 5086c726..60917248 100755..100644
--- a/src/common/config_validator.cpp
+++ b/src/common/config_validator.cpp
@@ -10,7 +10,10 @@ bool ConfigValidator::has(const std::string& key) const {
return config_[key];
}
-ConfigValidator::ConfigValidator(const YAML::Node& config) : config_(config) {}
+ConfigValidator::ConfigValidator(const YAML::Node& config)
+ : config_(config),
+ dumpConfigOnly_(config["dump-config"] && !config["dump-config"].as<std::string>().empty()
+ && config["dump-config"].as<std::string>() != "false") {}
ConfigValidator::~ConfigValidator() {}
@@ -28,6 +31,9 @@ void ConfigValidator::validateOptions(cli::mode mode) const {
validateOptionsParallelData();
validateOptionsTraining();
break;
+ default:
+ ABORT("wrong CLI mode");
+ break;
}
// clang-format on
@@ -42,16 +48,25 @@ void ConfigValidator::validateOptionsTranslation() const {
ABORT_IF(models.empty() && configs.empty(),
"You need to provide at least one model file or a config file");
- auto vocabs = get<std::vector<std::string>>("vocabs");
- ABORT_IF(vocabs.empty(), "Translating, but vocabularies are not given!");
-
for(const auto& modelFile : models) {
filesystem::Path modelPath(modelFile);
ABORT_IF(!filesystem::exists(modelPath), "Model file does not exist: " + modelFile);
}
+
+ auto vocabs = get<std::vector<std::string>>("vocabs");
+ ABORT_IF(vocabs.empty(), "Translating, but vocabularies are not given");
+
+ for(const auto& vocabFile : vocabs) {
+ filesystem::Path vocabPath(vocabFile);
+ ABORT_IF(!filesystem::exists(vocabPath), "Vocabulary file does not exist: " + vocabFile);
+ }
}
void ConfigValidator::validateOptionsParallelData() const {
+ // Do not check these constraints if only goal is to dump config
+ if(dumpConfigOnly_)
+ return;
+
auto trainSets = get<std::vector<std::string>>("train-sets");
ABORT_IF(trainSets.empty(), "No train sets given in config file or on command line");
@@ -62,17 +77,23 @@ void ConfigValidator::validateOptionsParallelData() const {
void ConfigValidator::validateOptionsScoring() const {
filesystem::Path modelPath(get<std::string>("model"));
-
ABORT_IF(!filesystem::exists(modelPath), "Model file does not exist: " + modelPath.string());
- ABORT_IF(get<std::vector<std::string>>("vocabs").empty(),
- "Scoring, but vocabularies are not given!");
+
+ auto vocabs = get<std::vector<std::string>>("vocabs");
+ ABORT_IF(vocabs.empty(), "Scoring, but vocabularies are not given");
+
+ for(const auto& vocabFile : vocabs) {
+ filesystem::Path vocabPath(vocabFile);
+ ABORT_IF(!filesystem::exists(vocabPath), "Vocabulary file does not exist: " + vocabFile);
+ }
}
void ConfigValidator::validateOptionsTraining() const {
auto trainSets = get<std::vector<std::string>>("train-sets");
ABORT_IF(has("embedding-vectors")
- && get<std::vector<std::string>>("embedding-vectors").size() != trainSets.size(),
+ && get<std::vector<std::string>>("embedding-vectors").size() != trainSets.size()
+ && !get<std::vector<std::string>>("embedding-vectors").empty(),
"There should be as many embedding vector files as training sets");
filesystem::Path modelPath(get<std::string>("model"));
@@ -84,12 +105,13 @@ void ConfigValidator::validateOptionsTraining() const {
ABORT_IF(!modelDir.empty() && !filesystem::isDirectory(modelDir),
"Model directory does not exist");
- ABORT_IF(
- has("valid-sets") && get<std::vector<std::string>>("valid-sets").size() != trainSets.size(),
- "There should be as many validation sets as training sets");
+ ABORT_IF(has("valid-sets")
+ && get<std::vector<std::string>>("valid-sets").size() != trainSets.size()
+ && !get<std::vector<std::string>>("valid-sets").empty(),
+ "There should be as many validation sets as training sets");
// validations for learning rate decaying
- ABORT_IF(get<double>("lr-decay") > 1.0, "Learning rate decay factor greater than 1.0 is unusual");
+ ABORT_IF(get<float>("lr-decay") > 1.f, "Learning rate decay factor greater than 1.0 is unusual");
auto strategy = get<std::string>("lr-decay-strategy");
diff --git a/src/common/config_validator.h b/src/common/config_validator.h
index 59cd1186..0e73a9e3 100644
--- a/src/common/config_validator.h
+++ b/src/common/config_validator.h
@@ -5,19 +5,20 @@
namespace marian {
-// TODO: Finally refactorize Config, Options, ConfigParser and ConfigValidator
-// classes.
class ConfigValidator {
private:
const YAML::Node& config_;
bool has(const std::string& key) const;
-
template <typename T>
T get(const std::string& key) const {
return config_[key].as<T>();
}
+ // The option --dump-config is used, so alleviate some constraints, e.g. we don't want to require
+ // --train-sets or --vocabs
+ bool dumpConfigOnly_{false};
+
void validateOptionsTranslation() const;
void validateOptionsParallelData() const;
void validateOptionsScoring() const;
diff --git a/src/common/definitions.h b/src/common/definitions.h
index 3b4a6edd..96eb6ed1 100755
--- a/src/common/definitions.h
+++ b/src/common/definitions.h
@@ -1,7 +1,8 @@
#pragma once
#include "common/logging.h"
-#include "shape.h"
+#include "common/shape.h"
+#include "common/intrusive_ptr.h"
#include <functional>
#include <iostream>
@@ -9,10 +10,33 @@
#include <string>
#include <vector>
-//#define THREAD_GUARD(body) std::thread([&]() { body; }).join()
+// The macro MAYBE_UNUSED is used to selectively disable
+// unused-variable warnings. C++17 defines the attribute
+// [[maybe_unused]], but I don't think we're at C++17 yet. We can add it when we reach C++17.
+// The compilers gcc and clang (and maybe others) define
+// __has_attribute and support __attribute__(unused) in C++11,
+#if defined __has_attribute
+# if __has_attribute(unused)
+# define MAYBE_UNUSED __attribute__((unused))
+# else
+# define MAYBE_UNUSED
+# endif
+#else
+# define MAYBE_UNUSED
+#endif
+
#define THREAD_GUARD(body) [&]() { body; }() // test if THREAD_GUARD is neccessary, remove if no problems occur.
#define NodeOp(op) [=]() { op; }
+// helper macro to disable optimization (gcc only)
+// To use this, just insert DONT_OPTIMIZE right before the function definition
+// (e.g. where the "static" keyword would go).
+#ifdef __GNUC__
+#define DONT_OPTIMIZE __attribute__((optimize("O0")))
+#else
+#define DONT_OPTIMIZE // silently ignore on Visual Studio, where this is less of a problem
+#endif
+
namespace marian {
// Type to be used for all index types, e.g. for integer tensors for rows operator.
@@ -21,12 +45,21 @@ namespace marian {
// This minimizes bandwith at little cost.
typedef uint32_t IndexType;
+// @TODO: come up with better short name. "I..." stands for interface now. Here it stands
+// for "intrusive". Not a good overlap.
template <class T>
-using Ptr = std::shared_ptr<T>;
+using IPtr = IntrusivePtr<T>;
template <class T>
using UPtr = std::unique_ptr<T>;
+// @TODO: come up with better short name. "I..." stands for interface now.
+template <class T>
+using IWeak = T*;
+
+template <class T>
+using Ptr = std::shared_ptr<T>;
+
template <class T>
using Weak = std::weak_ptr<T>;
@@ -42,6 +75,18 @@ Ptr<T> New(Ptr<T> p) {
return Ptr<T>(p);
}
+/** @brief Creates InstrusivePtr of any type, passes all arguments to any available
+ * constructor */
+template <class T, typename... Args>
+IPtr<T> INew(Args&&... args) {
+ return IPtr<T>(new T(std::forward<Args>(args)...));
+}
+
+template <class T>
+IPtr<T> INew(Ptr<T> p) {
+ return IPtr<T>(p);
+}
+
enum class DeviceType : size_t { gpu = 0, cpu = 1 };
struct DeviceId {
@@ -51,8 +96,16 @@ struct DeviceId {
DeviceId() : no{0}, type{DeviceType::gpu} {}
DeviceId(size_t no_, DeviceType type_) : no(no_), type(type_) {}
+ std::string typeAsString() const {
+ return (type == DeviceType::gpu ? "gpu" : "cpu");
+ }
+
+ operator std::string() const {
+ return typeAsString() + std::to_string(no);
+ }
+
friend std::ostream& operator<<(std::ostream& out, DeviceId deviceId) {
- out << (deviceId.type == DeviceType::gpu ? "gpu" : "cpu") << deviceId.no;
+ out << std::string(deviceId);
return out;
}
@@ -81,12 +134,14 @@ const DeviceId GPU5{5, DeviceType::gpu};
const DeviceId GPU6{6, DeviceType::gpu};
const DeviceId GPU7{7, DeviceType::gpu};
+// These are many small objects, hence use IntrusivePtr
class TensorBase;
-typedef Ptr<TensorBase> Tensor;
+typedef IPtr<TensorBase> Tensor;
+// These are many small objects, hence use IntrusivePtr
template <class DataType>
class Chainable;
-typedef Ptr<Chainable<Tensor>> Expr;
+typedef IPtr<Chainable<Tensor>> Expr;
class OptimizerBase;
typedef Ptr<OptimizerBase> OptimizerBasePtr;
diff --git a/src/common/fastopt.cpp b/src/common/fastopt.cpp
new file mode 100644
index 00000000..9af8e844
--- /dev/null
+++ b/src/common/fastopt.cpp
@@ -0,0 +1,112 @@
+#include "common/fastopt.h"
+
+#include <utility>
+
+namespace marian {
+
+const std::unique_ptr<const FastOpt> FastOpt::uniqueNullPtr{nullptr};
+
+// see fastopt.h for comments
+namespace fastopt_helpers {
+
+// helper structs for dynamic type conversion and specializations
+// for different conversion scenarios.
+
+// general template, mostly for numerical and logical types
+template <typename To, typename From>
+struct Convert {
+ static inline To apply(const From& from) {
+ return (To)from;
+ }
+};
+
+// specialization for translating from string, @TODO check if this is required at all, mostly for compilation now.
+template <typename To>
+struct Convert<To, std::string> {
+ static inline To apply(const std::string& /* from */) {
+ ABORT("Not implemented");
+ }
+};
+
+// convert anything to string, checked at compile-time
+template <typename From>
+struct Convert<std::string, From> {
+ static inline std::string apply(const From& from) {
+ return std::to_string(from);
+ }
+};
+
+// do nothing conversion for std::string
+template <>
+struct Convert<std::string, std::string> {
+ static inline std::string apply(const std::string& from) {
+ return from;
+ }
+};
+
+// helper class for FastOpt::as<T>() used for specializations
+template <typename T>
+T As<T>::apply(const FastOpt& node) {
+ ABORT_IF(!node.isScalar(), "Node is not a scalar node");
+
+ if(node.isBool())
+ return Convert<T, bool>::apply(node.value_.as<bool>());
+ else if(node.isInt())
+ return Convert<T, int64_t>::apply(node.value_.as<int64_t>());
+ else if(node.isFloat())
+ return Convert<T, double>::apply(node.value_.as<double>());
+ else if(node.isString())
+ return Convert<T, std::string>::apply(node.value_.as<std::string>());
+ else {
+ ABORT("Casting of value failed");
+ }
+}
+
+// specializations for simple types
+template struct As<bool>;
+template struct As<int>;
+template struct As<unsigned long>;
+template struct As<float>;
+template struct As<double>;
+template struct As<std::string>;
+
+// specialization of above class for std::vector<T>
+template <typename T>
+std::vector<T> As<std::vector<T>>::apply(const FastOpt& node) {
+ ABORT_IF(!node.isSequence(), "Node is not a sequence node");
+
+ std::vector<T> seq;
+ for(const auto& elem : node.array_)
+ seq.push_back(elem->as<T>());
+ return seq;
+}
+
+// specializations for simple vector types
+template struct As<std::vector<bool>>;
+template struct As<std::vector<int>>;
+// Windows, Linux based OS and Mac have different type definitions for 'unsigned long'.
+// So, we need an explicit definitions for uint64_t, that cover different platforms.
+// Otherwise, there's a linking error on windows or Linux or Mac.
+// https://software.intel.com/en-us/articles/size-of-long-integer-type-on-different-architecture-and-os/
+// https://stackoverflow.com/questions/32021860/c-should-you-size-t-with-a-regular-array
+// MacOS: size_t = unsigned long (8 bytes), uint64_t = unsigned long long (8 bytes)
+// Linux: size_t = unsigned long (8 bytes), uint64_t = unsigned long (8 bytes)
+// Windows: size_t = unsigned long long (8 bytes), uint64_t = unsigned long long (8 bytes)
+template struct As<std::vector<unsigned long long>>;
+template struct As<std::vector<unsigned long>>;
+template struct As<std::vector<float>>;
+template struct As<std::vector<double>>;
+template struct As<std::vector<std::string>>;
+
+// specialization of above class for std::pair<T>
+template <typename T1, typename T2>
+std::pair<T1, T2> As<std::pair<T1, T2>>::apply(const FastOpt& node) {
+ ABORT_IF(!node.isSequence(), "Node is not a sequence node");
+ ABORT_IF(node.size() != 2, "Sequence must contain two elements in order to convert to pair");
+ return std::make_pair(node[0].as<T1>(), node[1].as<T2>());
+}
+
+template struct As<std::pair<int, int>>;
+
+}
+}
diff --git a/src/common/fastopt.h b/src/common/fastopt.h
new file mode 100644
index 00000000..3f735660
--- /dev/null
+++ b/src/common/fastopt.h
@@ -0,0 +1,379 @@
+#pragma once
+
+#include "common/definitions.h"
+#include "3rd_party/any_type.h"
+#include "3rd_party/phf/phf.h"
+#include "3rd_party/yaml-cpp/yaml.h"
+
+// This file contains code to create a fast access option class,
+// meant as a replacment/supplement to YAML::Node.
+
+namespace marian {
+
+namespace crc {
+// has to stay in header due to constexpr
+
+// This code comes from https://notes.underscorediscovery.com/constexpr-fnv1a/
+// and is distributed as public domain as stated by the author under that link
+
+// constants for hash computations
+constexpr uint64_t val_64_const = 0xcbf29ce484222325;
+constexpr uint64_t prime_64_const = 0x100000001b3;
+
+// recursive compile-time hash, looking for stack-overflow source
+inline constexpr uint64_t
+hash_64_fnv1a_const(const char* const str,
+ const uint64_t value = val_64_const) noexcept {
+ return (str[0] == '\0') ? value :
+ hash_64_fnv1a_const(&str[1], (value ^ uint64_t(str[0])) * prime_64_const);
+}
+
+// Compile time string hashing. Should work particularly well for option look up with explicitly used keys like options->get("dim-input");
+inline constexpr uint64_t crc(const char* const str) noexcept {
+ return hash_64_fnv1a_const(str);
+}
+
+}
+
+/*****************************************************************************/
+
+// PerfectHash constructs a perfect hash for a set K of n numeric keys. The size of
+// the hash is m > n (not much larger) and n << max(K) (much smaller). If I am not wrong m
+// is the next power of 2 larger than n? We then build an array of size m with n fields defined.
+// m - n fields stay undefined (a bit of waste).
+class PerfectHash {
+private:
+ phf phf_;
+
+ PerfectHash(const uint64_t keys[], size_t num) {
+ int error = PHF::init<uint64_t, true>(&phf_, keys, num,
+ /* bucket size */ 4,
+ /* loading factor */ 90,
+ /* seed */ 123456);
+ ABORT_IF(error != 0, "PHF error {}", error);
+ }
+
+public:
+
+ PerfectHash(const std::vector<uint64_t>& v)
+ : PerfectHash(v.data(), v.size()) { }
+
+ ~PerfectHash() {
+ PHF::destroy(&phf_);
+ }
+
+ uint32_t operator[](const uint64_t& key) const {
+ return PHF::hash<uint64_t>(const_cast<phf*>(&phf_), key);
+ }
+
+ uint32_t operator[](const char* const keyStr) const {
+ return (*this)[crc::crc(keyStr)];
+ }
+
+ size_t size() const {
+ return phf_.m;
+ }
+};
+
+/*****************************************************************************/
+
+class FastOpt;
+
+// helper class for conversion, see fastopt.cpp
+namespace fastopt_helpers {
+ template <typename T>
+ struct As {
+ static T apply(const FastOpt&);
+ };
+
+ template <typename T>
+ struct As<std::vector<T>> {
+ static std::vector<T> apply(const FastOpt&);
+ };
+
+ template <typename T1, typename T2>
+ struct As<std::pair<T1, T2>> {
+ static std::pair<T1, T2> apply(const FastOpt&);
+ };
+}
+
+// Fast access option class, meant as a replacment/supplement to YAML::Node.
+// Relatively expensive to construct, fast to access (not visible in profiler)
+// via std::vector or perfect hash. The perfect hash only requires a few look-ups
+// and arithmentic operations, still O(1).
+// Still requires YAML::Node support for parsing and modification via rebuilding.
+class FastOpt {
+private:
+ template <typename T>
+ friend struct fastopt_helpers::As;
+
+public:
+ // Node types for FastOpt, seem to be enough to cover YAML:NodeType
+ enum struct NodeType {
+ Null, Bool, Int64, Float64, String, Sequence, Map
+ };
+
+private:
+ any_type value_;
+ std::unique_ptr<const PerfectHash> ph_;
+ std::vector<std::unique_ptr<const FastOpt>> array_;
+ NodeType type_{NodeType::Null};
+
+ static const std::unique_ptr<const FastOpt> uniqueNullPtr; // return this unique_ptr if key not found, equivalent to nullptr
+
+ uint64_t fingerprint_{0}; // When node is used as a value in a map, used to check if the perfect hash
+ // returned the right value (they can produce false positives)
+ size_t elements_{0}; // Number of elements if isMap or isSequence is true, 0 otherwise.
+
+ // Used to find elements if isSequence() is true.
+ inline const std::unique_ptr<const FastOpt>& arrayLookup(size_t keyId) const {
+ if(keyId < array_.size())
+ return array_[keyId];
+ else
+ return uniqueNullPtr;
+ }
+
+ // Used to find elements if isMap() is true.
+ inline const std::unique_ptr<const FastOpt>& phLookup(size_t keyId) const {
+ if(ph_)
+ return array_[(*ph_)[keyId]];
+ else
+ return uniqueNullPtr;
+ }
+
+ // Build Null node.
+ void makeNull() {
+ elements_ = 0;
+ type_ = NodeType::Null;
+
+ ABORT_IF(ph_, "ph_ should be undefined");
+ ABORT_IF(!array_.empty(), "array_ should be empty");
+ }
+
+ // Build Scalar node via controlled failure to convert from a YAML::Node object.
+ void makeScalar(const YAML::Node& v) {
+ elements_ = 0;
+ try {
+ // Cast node to text first, that works for any scalar node and test that it does not contain single characters
+ // that according to YAML could be boolean values. Unfortunately, we do not have any type information at this point.
+ // This means we are disabling support for boolean values in YAML that are expressed with these characters.
+ auto asText = v.as<std::string>();
+ if(asText.size() == 1 && asText.find_first_of("nyNYtfTF") == 0) // @TODO: should we disallow other strings too?
+ throw YAML::BadConversion(YAML::Mark()); // get's picked up by next catch block
+
+ value_ = v.as<bool>();
+ type_ = NodeType::Bool;
+ } catch(const YAML::BadConversion& /*e*/) {
+ try {
+ value_ = v.as<int64_t>();
+ type_ = NodeType::Int64;
+ } catch(const YAML::BadConversion& /*e*/) {
+ try {
+ value_ = v.as<double>();
+ type_ = NodeType::Float64;
+ } catch(const YAML::BadConversion& /*e*/) {
+ try {
+ value_ = v.as<std::string>();
+ type_ = NodeType::String;
+ } catch (const YAML::BadConversion& /*e*/) {
+ ABORT("Cannot convert YAML node {}", v);
+ }
+ }
+ }
+ }
+
+ ABORT_IF(ph_, "ph_ should be undefined");
+ ABORT_IF(!array_.empty(), "array_ should be empty");
+ }
+
+ // Build a Sequence node, can by converted to std::vector<T> if elements can be converted to T.
+ void makeSequence(const std::vector<YAML::Node>& v) {
+ elements_ = v.size();
+ ABORT_IF(!array_.empty(), "array_ is not empty??");
+ for(size_t pos = 0; pos < v.size(); ++pos) {
+ array_.emplace_back(new FastOpt(v[pos], pos));
+ }
+ type_ = NodeType::Sequence;
+
+ ABORT_IF(ph_, "ph_ should be undefined");
+ }
+
+ // Build a Map node.
+ void makeMap(const std::map<uint64_t, YAML::Node>& m) {
+ std::vector<uint64_t> keys;
+ for(const auto& it : m)
+ keys.push_back(it.first);
+
+ ABORT_IF(ph_, "ph_ is already defined??");
+ ph_.reset(new PerfectHash(keys));
+
+ ABORT_IF(!array_.empty(), "array_ is not empty??");
+
+ // for lack of resize_emplace
+ for(int i = 0; i < ph_->size(); ++i)
+ array_.emplace_back(nullptr);
+ elements_ = keys.size();
+
+ for(const auto& it : m) {
+ uint64_t key = it.first;
+ size_t pos = (*ph_)[key];
+ array_[pos].reset(new FastOpt(it.second, key));
+ }
+
+ type_ = NodeType::Map;
+ }
+
+ // Build a Map node, uses std::string as key, which gets hashed to size_t and used in the function above.
+ void makeMap(const std::map<std::string, YAML::Node>& m) {
+ std::map<uint64_t, YAML::Node> mi;
+ for(const auto& it : m) {
+ auto key = it.first.c_str();
+ mi[crc::crc(key)] = it.second;
+ }
+
+ makeMap(mi);
+ }
+
+ // Only build from YAML::Node
+ FastOpt(const FastOpt&) = delete;
+ FastOpt() = delete;
+
+ void construct(const YAML::Node& node) {
+ switch(node.Type()) {
+ case YAML::NodeType::Scalar:
+ makeScalar(node);
+ break;
+ case YAML::NodeType::Sequence: {
+ std::vector<YAML::Node> nodesVec;
+ for(auto&& n : node)
+ nodesVec.push_back(n);
+ makeSequence(nodesVec);
+ } break;
+ case YAML::NodeType::Map: {
+ std::map<std::string, YAML::Node> nodesMap;
+ for(auto& n : node) {
+ auto key = n.first.as<std::string>();
+ nodesMap[key] = n.second;
+ }
+ makeMap(nodesMap);
+ } break;
+ case YAML::NodeType::Undefined:
+ case YAML::NodeType::Null:
+ makeNull();
+ }
+ }
+
+public:
+ // Constructor to recursively create a FastOpt object from a YAML::Node following the yaml structure.
+ FastOpt(const YAML::Node& node)
+ { construct(node); }
+
+ FastOpt(const YAML::Node& node, uint64_t fingerprint)
+ : fingerprint_{fingerprint}
+ { construct(node); }
+
+ bool isSequence() const {
+ return type_ == NodeType::Sequence;
+ }
+
+ bool isMap() const {
+ return type_ == NodeType::Map;
+ }
+
+ bool isScalar() const {
+ return type_ == NodeType::Bool
+ || type_ == NodeType::Float64
+ || type_ == NodeType::Int64
+ || type_ == NodeType::String;
+ }
+
+ bool isNull() const {
+ return type_ == NodeType::Null;
+ }
+
+ bool isInt() const {
+ return type_ == NodeType::Int64;
+ }
+
+ bool isBool() const {
+ return type_ == NodeType::Bool;
+ }
+
+ bool isFloat() const {
+ return type_ == NodeType::Float64;
+ }
+
+ bool isString() const {
+ return type_ == NodeType::String;
+ }
+
+ // actual number of elements in a sequence or map, 0 (not 1) for scalar nodes.
+ // 0 here means rather "not applicable".
+ size_t size() const {
+ return elements_;
+ }
+
+ // replace current node with an externally built FastOpt object
+ void swap(FastOpt& other) {
+ std::swap(value_, other.value_);
+ std::swap(ph_, other.ph_);
+ std::swap(array_, other.array_);
+ std::swap(type_, other.type_);
+ std::swap(elements_, other.elements_);
+ // leave fingerprint alone as it needed by parent node.
+ }
+
+ // Is the hashed key in a map?
+ bool has(size_t keyId) const {
+ if(isMap() && elements_ > 0) {
+ const auto& ptr = phLookup(keyId);
+ return ptr ? ptr->fingerprint_ == keyId : false;
+ } else {
+ return false;
+ }
+ }
+
+ bool has(const char* const key) const {
+ return has(crc::crc(key));
+ }
+
+ bool has(const std::string& key) const {
+ return has(key.c_str());
+ }
+
+ // convert to requested type
+ template <typename T>
+ inline T as() const {
+ return fastopt_helpers::As<T>::apply(*this);
+ }
+
+ // access sequence or map element
+ const FastOpt& operator[](size_t keyId) const {
+ if(isSequence()) {
+ const auto& ptr = arrayLookup(keyId);
+ ABORT_IF(!ptr, "Unseen key {}" , keyId);
+ return *ptr;
+ } else if(isMap()) {
+ const auto& ptr = phLookup(keyId);
+ ABORT_IF(!ptr || ptr->fingerprint_ != keyId, "Unseen key {}", keyId);
+ return *ptr;
+ } else {
+ ABORT("Not a sequence or map node");
+ }
+ }
+
+ const FastOpt& operator[](int key) const {
+ return operator[]((size_t)key);
+ }
+
+ const FastOpt& operator[](const char* const key) const {
+ // MacOS requires explicit cast to size_t before we can use it.
+ return operator[]((size_t)crc::crc(key));
+ }
+
+ const FastOpt& operator[](const std::string& key) const {
+ return operator[](key.c_str());
+ }
+};
+
+}
diff --git a/src/common/file_stream.cpp b/src/common/file_stream.cpp
new file mode 100755
index 00000000..81504830
--- /dev/null
+++ b/src/common/file_stream.cpp
@@ -0,0 +1,172 @@
+#include "common/file_stream.h"
+
+#include <streambuf>
+#include <string>
+#include <vector>
+#ifdef _MSC_VER
+#include <io.h>
+#include <windows.h>
+#include <fcntl.h>
+#include <stdlib.h>
+#else
+#include <sys/types.h>
+#include <unistd.h>
+#endif
+
+namespace marian {
+namespace io {
+
+///////////////////////////////////////////////////////////////////////////////////////////////
+InputFileStream::InputFileStream(const std::string &file)
+ : std::istream(NULL), file_(file) {
+ ABORT_IF(!marian::filesystem::exists(file_), "File '{}' does not exist", file);
+
+ streamBuf1_.reset(new std::filebuf());
+ auto ret = static_cast<std::filebuf*>(streamBuf1_.get())->open(file.c_str(), std::ios::in | std::ios::binary);
+ ABORT_IF(!ret, "File cannot be opened", file);
+ ABORT_IF(ret != streamBuf1_.get(), "Return value is not equal to streambuf pointer, that is weird");
+
+ if(file_.extension() == marian::filesystem::Path(".gz")) {
+ streamBuf2_.reset(new zstr::istreambuf(streamBuf1_.get()));
+ this->init(streamBuf2_.get());
+ } else {
+ this->init(streamBuf1_.get());
+ }
+}
+
+InputFileStream::~InputFileStream() {}
+
+bool InputFileStream::empty() {
+ return this->peek() == std::ifstream::traits_type::eof();
+}
+
+void InputFileStream::setbufsize(size_t size) {
+ rdbuf()->pubsetbuf(0, 0);
+ readBuf_.resize(size);
+ rdbuf()->pubsetbuf(readBuf_.data(), readBuf_.size());
+}
+
+std::string InputFileStream::getFileName() const {
+ return file_.string();
+}
+
+// wrapper around std::getline() that handles Windows input files with extra CR
+// chars at the line end
+std::istream &getline(std::istream &in, std::string &line) {
+ std::getline(in, line);
+ // bad() seems to be correct here. Should not abort on EOF.
+ ABORT_IF(in.bad(), "Error reading from stream");
+ // strip terminal CR if present
+ if(in && !line.empty() && line.back() == in.widen('\r'))
+ line.pop_back();
+ return in;
+}
+///////////////////////////////////////////////////////////////////////////////////////////////
+OutputFileStream::OutputFileStream(const std::string &file)
+ : std::ostream(NULL), file_(file) {
+ streamBuf1_.reset(new std::filebuf());
+ auto ret = static_cast<std::filebuf*>(streamBuf1_.get())->open(file.c_str(), std::ios::out | std::ios_base::binary);
+ ABORT_IF(!ret, "File cannot be opened", file);
+ ABORT_IF(ret != streamBuf1_.get(), "Return value is not equal to streambuf pointer, that is weird");
+
+ if(file_.extension() == marian::filesystem::Path(".gz")) {
+ streamBuf2_.reset(new zstr::ostreambuf(streamBuf1_.get()));
+ this->init(streamBuf2_.get());
+ } else {
+ this->init(streamBuf1_.get());
+ }
+}
+
+OutputFileStream::OutputFileStream()
+ : std::ostream(NULL) {}
+
+OutputFileStream::~OutputFileStream() {
+ this->flush();
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////
+TemporaryFile::TemporaryFile(const std::string &base, bool earlyUnlink)
+ : OutputFileStream(), unlink_(earlyUnlink) {
+ std::string baseTemp(base);
+ NormalizeTempPrefix(baseTemp);
+ MakeTemp(baseTemp);
+
+ inSteam_ = UPtr<io::InputFileStream>(new io::InputFileStream(file_.string()));
+ if(unlink_) {
+ ABORT_IF(remove(file_.string().c_str()), "Error while deleting '{}'", file_.string());
+ }
+}
+
+TemporaryFile::~TemporaryFile() {
+ if(!unlink_)
+ // We do not check for errors here as this is the destructor and we cannot really fix an error anyway.
+ remove(file_.string().c_str()), "Error while deleting '{}'", file_.string();
+}
+
+void TemporaryFile::NormalizeTempPrefix(std::string &base) const {
+ if(base.empty())
+ return;
+
+#ifdef _MSC_VER
+ if(base.substr(0, 4) == "/tmp")
+ base = getenv("TMP");
+#else
+ if(base[base.size() - 1] == '/')
+ return;
+ struct stat sb;
+ // It's fine for it to not exist.
+ if(stat(base.c_str(), &sb) == -1)
+ return;
+ if(S_ISDIR(sb.st_mode))
+ base += '/';
+#endif
+}
+void TemporaryFile::MakeTemp(const std::string &base) {
+#ifdef _MSC_VER
+ char *name = tempnam(base.c_str(), "marian.");
+ ABORT_IF(name == NULL, "Error while making a temporary based on '{}'", base);
+
+ int oflag = _O_RDWR | _O_CREAT | _O_EXCL;
+ if(unlink_)
+ oflag |= _O_TEMPORARY;
+
+ int fd = open(name, oflag, _S_IREAD | _S_IWRITE);
+ ABORT_IF(fd == -1, "Error while making a temporary based on '{}'", base);
+
+#else
+ // create temp file
+ std::string name(base);
+ name += "marian.XXXXXX";
+ name.push_back(0);
+ int fd = mkstemp(&name[0]);
+ ABORT_IF(fd == -1, "Error creating temp file {}", name);
+
+ file_ = name;
+#endif
+
+ // open again with c++
+ streamBuf1_.reset(new std::filebuf());
+ auto ret = static_cast<std::filebuf*>(streamBuf1_.get())->open(name, std::ios::out | std::ios_base::binary);
+ ABORT_IF(!streamBuf1_, "File cannot be temp opened", name);
+ ABORT_IF(ret != streamBuf1_.get(), "Return value is not equal to streambuf pointer, that is weird");
+
+ this->init(streamBuf1_.get());
+
+ // close original file descriptor
+ ABORT_IF(close(fd), "Can't close file descriptor", name);
+
+#ifdef _MSC_VER
+ free(name);
+#endif
+}
+
+UPtr<InputFileStream> TemporaryFile::getInputStream() {
+ return std::move(inSteam_);
+}
+
+std::string TemporaryFile::getFileName() const {
+ return file_.string();
+}
+
+} // namespace io
+} // namespace marian
diff --git a/src/common/file_stream.h b/src/common/file_stream.h
index 7f5236f7..9bbf3359 100755..100644
--- a/src/common/file_stream.h
+++ b/src/common/file_stream.h
@@ -1,309 +1,100 @@
#pragma once
+#include <fstream>
+#include <iostream>
+#include <memory>
+#include "common/definitions.h"
#include "common/filesystem.h"
#include "common/logging.h"
-#include "common/definitions.h"
+// Even when compiling with clang, __GNUC__ may be defined, so
+// we need to add some extra checks to avoid compile errors with
+// respect to -Wsuggest-override.
#ifdef __GNUC__
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wsuggest-override"
+# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wunused-value"
+# if defined(__has_warning)
+# if __has_warning("-Wsuggest-override")
+# pragma GCC diagnostic ignored "-Wsuggest-override"
+# endif
+# else
+# pragma GCC diagnostic ignored "-Wsuggest-override"
+# endif
#endif
-#include "3rd_party/zstr/zstr.hpp"
-#include <boost/iostreams/device/file_descriptor.hpp>
-#include <boost/iostreams/stream_buffer.hpp>
-#ifdef __GNUC__
-#pragma GCC diagnostic pop
-#endif
-
-#include <iostream>
-#include <memory>
#ifdef _MSC_VER
-#include <fcntl.h>
-#include <io.h>
-#include <stdlib.h>
-#endif
-
-namespace marian {
-namespace io {
-
-class TemporaryFile {
-private:
- int fd_{-1};
- bool unlink_;
- std::string name_;
-
-#ifndef _MSC_VER
- int mkstemp_and_unlink(char* tmpl) {
- int ret = mkstemp(tmpl);
- if(unlink_ && ret != -1) {
- ABORT_IF(unlink(tmpl), "Error while deleting '{}'", tmpl);
- }
- return ret;
- }
-#endif
-
-
- int MakeTemp(const std::string& base) {
-#ifdef _MSC_VER
- char* name = tempnam(base.c_str(), "marian.");
- ABORT_IF(name == NULL,
- "Error while making a temporary based on '{}'",
- base);
-
- int oflag = _O_RDWR | _O_CREAT | _O_EXCL;
- if (unlink_) oflag |= _O_TEMPORARY;
-
- int ret = open(name, oflag, _S_IREAD | _S_IWRITE);
- ABORT_IF(ret == -1,
- "Error while making a temporary based on '{}'",
- base);
-
- name_ = name;
- free(name);
-
- return ret;
-#else
- std::string name(base);
- name += "marian.XXXXXX";
- name.push_back(0);
- int ret;
- ABORT_IF(-1 == (ret = mkstemp_and_unlink(&name[0])),
- "Error while making a temporary based on '{}'",
- base);
- name_ = name;
- return ret;
+#pragma warning(push) // 4101: 'identifier' : unreferenced local variable. One parameter variable in zstr.hpp is not used.
+#pragma warning(disable : 4101)
#endif
- }
-
- void NormalizeTempPrefix(std::string& base) {
- if(base.empty())
- return;
-
+#include "3rd_party/zstr/zstr.hpp"
#ifdef _MSC_VER
- if(base.substr(0,4) == "/tmp")
- base = getenv("TMP");
-#else
- if(base[base.size() - 1] == '/')
- return;
- struct stat sb;
- // It's fine for it to not exist.
- if(stat(base.c_str(), &sb) == - 1)
- return;
- if(S_ISDIR(sb.st_mode))
- base += '/';
+#pragma warning(pop)
#endif
- }
-
-public:
- TemporaryFile(const std::string base = "/tmp/", bool earlyUnlink = true)
- : unlink_(earlyUnlink) {
- std::string baseTemp(base);
- NormalizeTempPrefix(baseTemp);
- fd_ = MakeTemp(baseTemp);
- }
-
- ~TemporaryFile() {
-#ifdef _MSC_VER
- if (fd_ == -1)
- return;
-
- if(close(fd_)) {
- std::cerr << "Could not close file " << fd_ << std::endl;
- std::abort();
- }
-
- if(!unlink_) {
- ABORT_IF(remove(name_.c_str()), "Error while deleting '{}'", name_);
- }
-#else
- if(fd_ != -1 && !unlink_) {
- ABORT_IF(unlink(name_.c_str()), "Error while deleting '{}'", name_);
- }
- if(fd_ != -1 && close(fd_)) {
- std::cerr << "Could not close file " << fd_ << std::endl;
- std::abort();
- }
+#ifdef __GNUC__
+#pragma GCC diagnostic pop
#endif
- }
-
- int getFileDescriptor() { return fd_; }
- std::string getFileName() { return name_; }
-};
+namespace marian {
+namespace io {
-class InputFileStream {
+//////////////////////////////////////////////////////////////////////////////////////////////
+class InputFileStream : public std::istream {
public:
- InputFileStream(const std::string& file)
- : file_(file) {
- ABORT_IF(!marian::filesystem::exists(file_), "File '{}' does not exist", file);
-
- if(file_.extension() == marian::filesystem::Path(".gz"))
- // @TODO: consider make_unique for next refactoring
- istream_.reset(new zstr::ifstream(file_.string()));
- else
- istream_.reset(new std::ifstream(file_.string()));
- }
-
- InputFileStream(TemporaryFile& tempfile)
- : fds_(tempfile.getFileDescriptor(), boost::iostreams::never_close_handle) {
- lseek(tempfile.getFileDescriptor(), 0, SEEK_SET);
-
- namespace bio = boost::iostreams;
- fdsBuffer_.reset(new bio::stream_buffer<bio::file_descriptor_source>(fds_));
- istream_.reset(new std::istream(fdsBuffer_.get()));
- }
-
- InputFileStream(std::istream& strm)
- : istream_(new std::istream(strm.rdbuf())) {}
-
- operator std::istream&() { return *istream_; }
-
- operator bool() { return (bool)*istream_; }
+ explicit InputFileStream(const std::string& file);
+ virtual ~InputFileStream();
- bool bad() const {
- return istream_->bad();
- }
-
- bool fail() const {
- return istream_->fail();
- }
-
- char widen(char c) {
- return istream_->widen(c);
- }
-
- std::string path() { return file_.string(); }
-
- bool empty() { return istream_->peek() == std::ifstream::traits_type::eof(); }
-
- void setbufsize(size_t size) const {
- istream_->rdbuf()->pubsetbuf(0, 0);
- readBuf_.resize(size);
- istream_->rdbuf()->pubsetbuf(readBuf_.data(), readBuf_.size());
- }
+ bool empty();
+ void setbufsize(size_t size);
+ std::string getFileName() const;
- template <typename T>
- friend InputFileStream& operator>>(InputFileStream& stream, T& t) {
- *stream.istream_ >> t;
- // bad() seems to be correct here. Should not abort on EOF.
- ABORT_IF(stream.bad(), "Error reading from file '{}'", stream.path());
- return stream;
- }
-
- template <typename T>
- size_t read(T* ptr, size_t num = 1) {
- istream_->read((char*)ptr, num * sizeof(T));
- // fail() seems to be correct here. Failure to read should abort.
- ABORT_IF(fail(), "Error reading from file '{}'", path());
- return num * sizeof(T);
- }
-
-private:
+protected:
marian::filesystem::Path file_;
- std::unique_ptr<std::istream> istream_;
-
- boost::iostreams::file_descriptor_source fds_;
- mutable std::vector<char> readBuf_; // for setbuf()
- std::unique_ptr<boost::iostreams::stream_buffer<boost::iostreams::file_descriptor_source>> fdsBuffer_;
-
-
+ std::unique_ptr<std::streambuf> streamBuf1_;
+ std::unique_ptr<std::streambuf> streamBuf2_;
+ std::vector<char> readBuf_;
};
-// wrapper around std::getline() that handles Windows input files with extra CR
-// chars at the line end
-static inline InputFileStream& getline(InputFileStream& in, std::string& line) {
- std::getline((std::istream&)in, line);
- // bad() seems to be correct here. Should not abort on EOF.
- ABORT_IF(in.bad(), "Error reading from file '{}'", in.path());
- // strip terminal CR if present
- if(in && !line.empty() && line.back() == in.widen('\r'))
- line.pop_back();
- return in;
-}
-
-// wrapper around std::getline() that handles Windows input files with extra CR
-// chars at the line end
-static inline InputFileStream& getline(InputFileStream& in, std::string& line, char delim) {
- std::getline((std::istream&)in, line, delim);
- // bad() seems to be correct here. Should not abort on EOF.
- ABORT_IF(in.bad(), "Error reading from file '{}'", in.path());
- // strip terminal CR if present
- if(in && !line.empty() && line.back() == in.widen('\r'))
- line.pop_back();
- return in;
-}
+std::istream& getline(std::istream& in, std::string& line);
-class OutputFileStream {
+//////////////////////////////////////////////////////////////////////////////////////////////
+class OutputFileStream : public std::ostream {
public:
- OutputFileStream(const std::string& file) : file_(file) {
- if(file_.extension() == marian::filesystem::Path(".gz"))
- ostream_.reset(new zstr::ofstream(file_.string()));
- else
- ostream_.reset(new std::ofstream(file_.string()));
-
- ABORT_IF(!marian::filesystem::exists(file_), "File '{}' could not be opened", file);
- }
-
- OutputFileStream(TemporaryFile& tempfile)
- : fds_(tempfile.getFileDescriptor(), boost::iostreams::never_close_handle) {
- lseek(tempfile.getFileDescriptor(), 0, SEEK_SET);
-
- namespace bio = boost::iostreams;
- fdsBuffer_.reset(new bio::stream_buffer<bio::file_descriptor_sink>(fds_));
- ostream_.reset(new std::ostream(fdsBuffer_.get()));
- }
-
- OutputFileStream(std::ostream& strm) {
- ostream_.reset(new std::ostream(strm.rdbuf()));
- }
-
- operator std::ostream&() { return *ostream_; }
-
- operator bool() { return (bool)*ostream_; }
-
- bool bad() const {
- return ostream_->bad();
- }
-
- bool fail() const {
- return ostream_->fail();
- }
-
- template <typename T>
- friend OutputFileStream& operator<<(OutputFileStream& stream, const T& t) {
- *stream.ostream_ << t;
- // fail() seems to be correct here. Failure to write should abort.
- ABORT_IF(stream.fail(), "Error writing to file '{}'", stream.path());
- return stream;
- }
-
- // handle things like std::endl which is actually a function not a value
- friend OutputFileStream& operator<<(OutputFileStream& stream, std::ostream& (*var)(std::ostream&)) {
- *stream.ostream_ << var;
- // fail() seems to be correct here. Failure to write should abort.
- ABORT_IF(stream.fail(), "Error writing to file '{}'", stream.path());
- return stream;
- }
+ explicit OutputFileStream(const std::string& file);
+ virtual ~OutputFileStream();
template <typename T>
size_t write(const T* ptr, size_t num = 1) {
- ostream_->write((char*)ptr, num * sizeof(T));
+ std::ostream::write((char*)ptr, num * sizeof(T));
// fail() seems to be correct here. Failure to write should abort.
- ABORT_IF(fail(), "Error writing to file '{}'", path());
+ ABORT_IF(fail(), "Error writing to file '{}'", file_.string());
return num * sizeof(T);
}
- std::string path() { return file_.string(); }
+protected:
+ explicit OutputFileStream(); // for temp file
-private:
marian::filesystem::Path file_;
- std::unique_ptr<std::ostream> ostream_;
+ std::unique_ptr<std::streambuf> streamBuf1_;
+ std::unique_ptr<std::streambuf> streamBuf2_;
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////
+class TemporaryFile : public OutputFileStream {
+public:
+ TemporaryFile(const std::string& base = "/tmp/", bool earlyUnlink = true);
+ virtual ~TemporaryFile();
+
+ UPtr<InputFileStream> getInputStream();
+ std::string getFileName() const;
+
+protected:
+ bool unlink_;
+ UPtr<InputFileStream> inSteam_;
+ void NormalizeTempPrefix(std::string& base) const;
+ void MakeTemp(const std::string& base);
- boost::iostreams::file_descriptor_sink fds_;
- std::unique_ptr<boost::iostreams::stream_buffer<boost::iostreams::file_descriptor_sink>> fdsBuffer_;
};
-}
-}
+} // namespace io
+} // namespace marian
diff --git a/src/common/filesystem.cpp b/src/common/filesystem.cpp
new file mode 100644
index 00000000..1abaeae4
--- /dev/null
+++ b/src/common/filesystem.cpp
@@ -0,0 +1,31 @@
+#include "filesystem.h"
+
+#ifndef _MSC_VER
+// don't include these on Windows:
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <unistd.h>
+#endif
+
+namespace marian {
+namespace filesystem {
+
+#ifdef _MSC_VER
+// Pretend that Windows knows no named pipes. It does, by the way, but
+// they seem to be different from pipes on Unix / Linux. See
+// https://docs.microsoft.com/en-us/windows/win32/ipc/named-pipes
+bool is_fifo(char const*) { return false; }
+#else
+bool is_fifo(char const* path) {
+ struct stat buf;
+ stat(path, &buf);
+ return S_ISFIFO(buf.st_mode);
+}
+#endif
+
+bool is_fifo(std::string const& path) {
+ return is_fifo(path.c_str());
+}
+
+} // end of namespace marian::filesystem
+} // end of namespace marian
diff --git a/src/common/filesystem.h b/src/common/filesystem.h
index f9c06104..d7cb3da6 100755..100644
--- a/src/common/filesystem.h
+++ b/src/common/filesystem.h
@@ -7,12 +7,22 @@
// @TODO: go back to canonical names for functions and objects
// as specified in C++17 so it becomes easy to move in the future
+// Even when compiling with clang, __GNUC__ may be defined, so
+// we need to add some extra checks to avoid compile errors with
+// respect to -Wsuggest-override.
#ifdef __GNUC__
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wsuggest-override"
+# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wunused-value"
+# if defined(__has_warning)
+# if __has_warning("-Wsuggest-override")
+# pragma GCC diagnostic ignored "-Wsuggest-override"
+# endif
+# else
+# pragma GCC diagnostic ignored "-Wsuggest-override"
+# endif
#endif
-#include "3rd_party/pathie-cpp/include/path.hpp"
+#include "3rd_party/pathie-cpp/include/path.hpp" // @TODO: update to latest Pathie
#include "3rd_party/pathie-cpp/include/errors.hpp"
#ifdef __GNUC__
@@ -22,6 +32,9 @@
namespace marian {
namespace filesystem {
+ bool is_fifo(char const* path);
+ bool is_fifo(std::string const& path);
+
class Path {
private:
Pathie::Path path;
@@ -29,6 +42,7 @@ namespace filesystem {
public:
Path() {}
Path(const Path& p) : path{p.path} {}
+ Path& operator=(const Path& p) = default;
Path(const std::string& s) : path{s} {}
Path(const Pathie::Path& p) : path{p} {}
@@ -78,6 +92,11 @@ namespace filesystem {
return p.getImpl().absolute(base.getImpl()).expand();
}
+ static inline Path relative(const Path& p, const Path& base) {
+ // create a path relative to the base path
+ return p.getImpl().absolute().expand().relative(base.getImpl().absolute().expand());
+ }
+
static inline bool exists(const Path& p) {
return p.getImpl().exists();
}
@@ -97,4 +116,4 @@ namespace filesystem {
using FilesystemError = Pathie::PathieError;
}
-} \ No newline at end of file
+}
diff --git a/src/common/intrusive_ptr.h b/src/common/intrusive_ptr.h
new file mode 100644
index 00000000..4fe501ba
--- /dev/null
+++ b/src/common/intrusive_ptr.h
@@ -0,0 +1,225 @@
+#pragma once
+
+#include <cassert>
+#include <iostream>
+#include "common/logging.h"
+
+// Smart pointer class for small objects with reference counting but no thread-safety.
+// Inspired by boost::intrusive_ptr<T>.
+
+// Compared to std::shared_ptr this is small and cheap to construct and destroy.
+// Does not hold the counter, the pointed to class `T` needs to add
+// ENABLE_INTRUSIVE_PTR(T) into the body of the class (private section). This adds
+// the reference counters and count manipulation functions to the class.
+
+#define ENABLE_INTRUSIVE_PTR(type) \
+ size_t references_{0}; \
+ \
+ inline friend void intrusivePtrAddRef(type* x) { \
+ if(x != 0) \
+ ++x->references_; \
+ } \
+ \
+ inline friend void intrusivePtrRelease(type* x) { \
+ if(x != 0 && --x->references_ == 0) { \
+ delete x; \
+ x = 0; \
+ } \
+ } \
+ \
+ inline friend size_t references(type* x) { \
+ return x->references_; \
+ } \
+
+
+template<class T>
+class IntrusivePtr {
+private:
+ typedef IntrusivePtr this_type;
+
+public:
+ typedef T element_type;
+
+ IntrusivePtr() : ptr_(0) {};
+
+ IntrusivePtr(T* p)
+ : ptr_(p) {
+ if(ptr_ != 0)
+ intrusivePtrAddRef(ptr_);
+ }
+
+ template<class Y>
+ IntrusivePtr(const IntrusivePtr<Y>& rhs)
+ : ptr_(rhs.get()) {
+ if(ptr_ != 0)
+ intrusivePtrAddRef(ptr_);
+ }
+
+ IntrusivePtr(const IntrusivePtr& rhs)
+ : ptr_(rhs.ptr_) {
+ if(ptr_ != 0)
+ intrusivePtrAddRef(ptr_);
+ }
+
+ ~IntrusivePtr() {
+ if(ptr_ != 0)
+ intrusivePtrRelease(ptr_);
+ }
+
+ IntrusivePtr(IntrusivePtr&& rhs)
+ : ptr_(rhs.ptr_) {
+ rhs.ptr_ = 0;
+ }
+
+ inline size_t useCount() {
+ return references(ptr_);
+ }
+
+ inline IntrusivePtr& operator=(IntrusivePtr&& rhs) {
+ this_type(static_cast<IntrusivePtr&&>(rhs)).swap(*this);
+ return *this;
+ }
+
+ inline IntrusivePtr& operator=(const IntrusivePtr& rhs) {
+ this_type(rhs).swap(*this);
+ return *this;
+ }
+
+ template<class Y>
+ inline IntrusivePtr& operator=(const IntrusivePtr<Y>& rhs) {
+ this_type(rhs).swap(*this);
+ return *this;
+ }
+
+ inline void reset() {
+ this_type().swap(*this);
+ }
+
+ inline void reset(T* rhs) {
+ this_type(rhs).swap(*this);
+ }
+
+ inline T* get() const {
+ return ptr_;
+ }
+
+ inline T* detach() {
+ T* ret = ptr_;
+ ptr_ = 0;
+ return ret;
+ }
+
+ inline T& operator*() const {
+ //ABORT_IF(ptr_ == 0, "Null pointer in IntrusivePtr");
+ return *ptr_;
+ }
+
+ inline T* operator->() const {
+ //ABORT_IF(ptr_ == 0, "Null pointer in IntrusivePtr");
+ return ptr_;
+ }
+
+ inline explicit operator bool() const {
+ return ptr_ != 0;
+ }
+
+ inline bool operator!() const {
+ return ptr_ == 0;
+ }
+
+ inline void swap(IntrusivePtr& rhs) {
+ T* tmp = ptr_;
+ ptr_ = rhs.ptr_;
+ rhs.ptr_ = tmp;
+ }
+
+private:
+ T* ptr_;
+};
+
+template<class T, class U>
+inline bool operator==(const IntrusivePtr<T>& a, const IntrusivePtr<U>& b) {
+ return a.get() == b.get();
+}
+
+template<class T, class U>
+inline bool operator!=(const IntrusivePtr<T>& a, const IntrusivePtr<U>& b) {
+ return a.get() != b.get();
+}
+
+template<class T>
+inline bool operator==(const IntrusivePtr<T>& a, std::nullptr_t) {
+ return a.get() == 0;
+}
+
+template<class T>
+inline bool operator!=(const IntrusivePtr<T>& a, std::nullptr_t) {
+ return a.get() != 0;
+}
+
+template<class T>
+inline bool operator==(const IntrusivePtr<T>& a, T* b) {
+ return a.get() == b;
+}
+
+template<class T>
+inline bool operator!=(const IntrusivePtr<T>& a, T* b) {
+ return a.get() != b;
+}
+
+template<class T>
+inline bool operator==(T* a, const IntrusivePtr<T>& b) {
+ return a == b.get();
+}
+
+template<class T>
+inline bool operator!=(T* a, const IntrusivePtr<T>& b) {
+ return a != b.get();
+}
+
+template<class T, class U>
+inline bool operator<(const IntrusivePtr<T>& a, const IntrusivePtr<U>& b) {
+ return std::less<T*>()(a.get(), b.get());
+}
+
+template<class T>
+inline void swap(IntrusivePtr<T> & a, IntrusivePtr<T> & b) {
+ a.swap(b);
+}
+
+template<class E, class T, class Y>
+std::basic_ostream<E, T>& operator<<(std::basic_ostream<E, T>& os, const IntrusivePtr<Y>& p) {
+ os << p.get();
+ return os;
+}
+
+// compatibility functions to make std::*_pointer_cast<T> work, also for automatic hashing
+namespace std {
+ template<class T>
+ T* get_pointer(const IntrusivePtr<T>& p) {
+ return p.get();
+ }
+
+ template<class T, class U>
+ IntrusivePtr<T> static_pointer_cast(const IntrusivePtr<U>& p) {
+ return static_cast<T*>(p.get());
+ }
+
+ template<class T, class U>
+ IntrusivePtr<T> const_pointer_cast(const IntrusivePtr<U>& p) {
+ return const_cast<T*>(p.get());
+ }
+
+ template<class T, class U>
+ IntrusivePtr<T> dynamic_pointer_cast(const IntrusivePtr<U>& p) {
+ return dynamic_cast<T*>(p.get());
+ }
+
+ // IntrusivePtr<T> can be used as hash map key
+ template <class T> struct hash<IntrusivePtr<T>> {
+ size_t operator()(const IntrusivePtr<T>& x) const {
+ std::hash<size_t> hasher;
+ return hasher((size_t)x.get());
+ }
+ };
+}
diff --git a/src/common/io.cpp b/src/common/io.cpp
index ae768e18..decc4aca 100755..100644
--- a/src/common/io.cpp
+++ b/src/common/io.cpp
@@ -128,17 +128,18 @@ void saveItemsNpz(const std::string& fileName, const std::vector<Item>& items) {
std::vector<cnpy::NpzItem> npzItems;
for(auto& item : items) {
std::vector<unsigned int> shape(item.shape.begin(), item.shape.end());
- char type = 'f';
+ char type;
if(item.type == Type::float32)
type = cnpy::map_type(typeid(float));
+ else if(item.type == Type::float64)
+ type = cnpy::map_type(typeid(double));
else if(item.type == Type::int8)
type = cnpy::map_type(typeid(char));
else
ABORT("Other types not supported yet");
- npzItems.emplace_back(
- item.name, item.bytes, shape, type, sizeOf(item.type));
+ npzItems.emplace_back(item.name, item.bytes, shape, type, sizeOf(item.type));
}
cnpy::npz_save(fileName, npzItems);
}
diff --git a/src/common/io_item.h b/src/common/io_item.h
index 2ef967f9..d86c01ac 100644..100755
--- a/src/common/io_item.h
+++ b/src/common/io_item.h
@@ -24,11 +24,70 @@ struct Item {
return bytes.data();
}
- size_t size() const {
- if(mapped)
- return shape.elements() * sizeOf(type);
+ size_t size() const { // @TODO: review this again for 256-bytes boundary alignment
+ return requiredBytes(shape, type);
+ }
+
+ // Extend this item with data and shape from the input item, creating a flattened concatenation.
+ void append(const Item& other) {
+ ABORT_IF(mapped, "Memory-mapped items cannot be appended");
+ ABORT_IF(type != other.type, "Only item of same type can be appended");
+
+ // abort if any of the shapes is not a flat array, i.e. the number of elements in the
+ // last dimension has to correspond to the number of bytes.
+ ABORT_IF(shape[-1] != shape.elements(), "1 - Only flat items can be appended : {}", shape);
+ ABORT_IF(other.shape[-1] != other.shape.elements(), "2 - Only flat items can be appended: {}", other.shape);
+
+ // cut to size (get rid of padding if any) to make append operation work correctly
+ size_t bytesWithoutPadding = shape.elements() * sizeOf(type);
+ bytes.resize(bytesWithoutPadding);
+
+ shape.set(-1, shape.elements() + other.shape.elements());
+
+ size_t addbytesWithoutPadding = other.shape.elements() * sizeOf(other.type); // ignore padding if any
+ bytes.insert(bytes.end(), other.bytes.begin(), other.bytes.begin() + addbytesWithoutPadding);
+
+ // grow to align to 256 bytes boundary (will be undone when more pieces are appended)
+ size_t multiplier = (size_t)ceil((float)bytes.size() / (float)256);
+ bytes.resize(multiplier * 256);
+ }
+
+ template <typename From, typename To>
+ void convertFromTo() {
+ size_t elements = size() / sizeof(From);
+ size_t newSize = elements * sizeof(To);
+ std::vector<char> newBytes(newSize);
+
+ From* in = (From*)bytes.data();
+ To* out = (To*)newBytes.data();
+ for(int i = 0; i < elements; ++i)
+ out[i] = (To)in[i];
+
+ bytes.swap(newBytes);
+ }
+
+ template <typename T>
+ void convertTo() {
+ if(type == Type::float32)
+ convertFromTo<float, T>();
+ else if(type == Type::float16)
+ convertFromTo<HalfFloat, T>();
+ else
+ ABORT("convert from type {} not implemented", type);
+ }
+
+ void convert(Type toType) {
+ if(type == toType)
+ return;
+
+ if(toType == Type::float32)
+ convertTo<float>();
+ else if(toType == Type::float16)
+ convertTo<float16>();
else
- return bytes.size();
+ ABORT("convert to type {} not implemented", toType);
+
+ type = toType;
}
};
diff --git a/src/common/logging.cpp b/src/common/logging.cpp
index 0170d633..62d76fee 100755..100644
--- a/src/common/logging.cpp
+++ b/src/common/logging.cpp
@@ -14,26 +14,28 @@
#define noinline __attribute__((noinline))
#endif
-std::shared_ptr<spdlog::logger> createStderrLogger(
- const std::string& name,
- const std::string& pattern,
- const std::vector<std::string>& files,
- bool quiet) {
+namespace marian {
+ static bool throwExceptionOnAbort = false;
+ bool getThrowExceptionOnAbort() { return throwExceptionOnAbort; }
+ void setThrowExceptionOnAbort(bool doThrowExceptionOnAbort) { throwExceptionOnAbort = doThrowExceptionOnAbort; };
+}
+
+std::shared_ptr<spdlog::logger> createStderrLogger(const std::string& name,
+ const std::string& pattern,
+ const std::vector<std::string>& files,
+ bool quiet) {
std::vector<spdlog::sink_ptr> sinks;
auto stderr_sink = spdlog::sinks::stderr_sink_mt::instance();
-
if(!quiet)
sinks.push_back(stderr_sink);
for(auto&& file : files) {
- auto file_sink
- = std::make_shared<spdlog::sinks::simple_file_sink_st>(file, true);
+ auto file_sink = std::make_shared<spdlog::sinks::simple_file_sink_st>(file, true);
sinks.push_back(file_sink);
}
- auto logger
- = std::make_shared<spdlog::logger>(name, begin(sinks), end(sinks));
+ auto logger = std::make_shared<spdlog::logger>(name, begin(sinks), end(sinks));
spdlog::register_logger(logger);
logger->set_pattern(pattern);
@@ -56,73 +58,65 @@ bool setLoggingLevel(spdlog::logger& logger, std::string const level) {
else if(level == "off")
logger.set_level(spdlog::level::off);
else {
- logger.warn("Unknown log level '{}' for logger '{}'",
- level.c_str(),
- logger.name().c_str());
+ logger.warn("Unknown log level '{}' for logger '{}'", level.c_str(), logger.name().c_str());
return false;
}
return true;
}
static void setErrorHandlers();
-void createLoggers(const marian::Config* options) {
+void createLoggers(const marian::Config* config) {
std::vector<std::string> generalLogs;
std::vector<std::string> validLogs;
- if(options && options->has("log")) {
- generalLogs.push_back(options->get<std::string>("log"));
+ if(config && !config->get<std::string>("log").empty()) {
+ generalLogs.push_back(config->get<std::string>("log"));
#ifndef _WIN32
// can't open the same file twice in Windows for some reason
- validLogs.push_back(options->get<std::string>("log"));
+ validLogs.push_back(config->get<std::string>("log"));
#endif
}
- if(options && options->has("valid-log")
- && !options->get<std::string>("valid-log").empty()) {
- validLogs.push_back(options->get<std::string>("valid-log"));
+ // valid-log is available only for training
+ if(config && config->has("valid-log") && !config->get<std::string>("valid-log").empty()) {
+ validLogs.push_back(config->get<std::string>("valid-log"));
}
- bool quiet = options && options->get<bool>("quiet");
- Logger general{
- createStderrLogger("general", "[%Y-%m-%d %T] %v", generalLogs, quiet)};
- Logger valid{
- createStderrLogger("valid", "[%Y-%m-%d %T] [valid] %v", validLogs, quiet)};
+ bool quiet = config && config->get<bool>("quiet");
+ Logger general{createStderrLogger("general", "[%Y-%m-%d %T] %v", generalLogs, quiet)};
+ Logger valid{createStderrLogger("valid", "[%Y-%m-%d %T] [valid] %v", validLogs, quiet)};
- if(options && options->has("log-level")) {
- std::string loglevel = options->get<std::string>("log-level");
+ if(config && config->has("log-level")) {
+ std::string loglevel = config->get<std::string>("log-level");
if(!setLoggingLevel(*general, loglevel))
return;
setLoggingLevel(*valid, loglevel);
}
- if (options && options->has("log-time-zone")) {
- std::string timezone = options->get<std::string>("log-time-zone");
- if (timezone != "") {
+ if(config && !config->get<std::string>("log-time-zone").empty()) {
+ std::string timezone = config->get<std::string>("log-time-zone");
#ifdef _WIN32
#define setenv(var, val, over) SetEnvironmentVariableA(var, val) // ignoring over flag
#endif
- setenv("TZ", timezone.c_str(), true);
- tzset();
- }
+ setenv("TZ", timezone.c_str(), true);
+ tzset();
}
setErrorHandlers();
}
static void unhandledException() {
- if (std::current_exception()) {
+ if(std::current_exception()) {
try {
- throw; // rethrow so that we can get access to what()
- }
- catch (const std::exception& e) {
- ABORT("Unhandled {}: {}", typeid(e).name(), e.what());
- }
- catch (...) {
+ throw; // rethrow so that we can get access to what()
+ } catch(const std::exception& e) {
+ ABORT("Unhandled exception of type '{}': {}", typeid(e).name(), e.what());
+ } catch(...) {
ABORT("Unhandled exception");
}
- }
- else
+ } else {
std::abort();
+ }
}
static void setErrorHandlers() {
@@ -130,7 +124,7 @@ static void setErrorHandlers() {
std::set_terminate(unhandledException);
#ifdef __unix__
// catch segfaults
- struct sigaction sa = { 0 };
+ struct sigaction sa = { {0} };
sigemptyset(&sa.sa_mask);
sa.sa_flags = SA_SIGINFO;
sa.sa_sigaction = [](int /*signal*/, siginfo_t*, void*) { ABORT("Segmentation fault"); };
@@ -144,8 +138,8 @@ static void setErrorHandlers() {
// This is called upon initializing MPI. It is needed to associated error messages to ranks.
void switchtoMultinodeLogging(std::string nodeIdStr) {
Logger log = spdlog::get("general");
- if (log)
- log->set_pattern("[%Y-%m-%d %T " + nodeIdStr + "] %v");
+ if(log)
+ log->set_pattern("[%Y-%m-%d %T " + nodeIdStr + ":%t] %v");
}
diff --git a/src/common/logging.h b/src/common/logging.h
index cdaa806c..6f292f61 100755..100644
--- a/src/common/logging.h
+++ b/src/common/logging.h
@@ -4,9 +4,37 @@
#include "spdlog/spdlog.h"
+
namespace marian {
void logCallStack(size_t skipLevels);
std::string getCallStack(size_t skipLevels);
+
+ // Marian gives a basic exception guarantee. If you catch a
+ // MarianRuntimeError you must assume that the object can be
+ // safely destructed, but cannot be used otherwise.
+
+ // Internal multi-threading in exception-throwing mode is not
+ // allowed; and constructing a thread-pool will cause an exception.
+
+ class MarianRuntimeException : public std::runtime_error {
+ private:
+ std::string callStack_;
+
+ public:
+ MarianRuntimeException(const std::string& message, const std::string& callStack)
+ : std::runtime_error(message),
+ callStack_(callStack) {}
+
+ const char* getCallStack() const throw() {
+ return callStack_.c_str();
+ }
+ };
+
+ // Get the state of throwExceptionOnAbort (see logging.cpp), by default false
+ bool getThrowExceptionOnAbort();
+
+ // Set the state of throwExceptionOnAbort (see logging.cpp)
+ void setThrowExceptionOnAbort(bool);
}
/**
@@ -21,6 +49,16 @@ namespace marian {
*/
#define LOG(level, ...) checkedLog("general", #level, __VA_ARGS__)
+// variant that prints the log message only upon the first time the call site is executed
+#define LOG_ONCE(level, ...) do { \
+ static bool logged = false; \
+ if (!logged) \
+ { \
+ logged = true; \
+ LOG(level, __VA_ARGS__); \
+ } \
+} while(0)
+
/**
* Prints logging message regarding validation into stderr and a file specified
* with `--valid-log` option.
@@ -47,19 +85,23 @@ namespace marian {
*
* @param ... Message text and variables
*/
-#define ABORT(...) \
- do { \
- auto logger = spdlog::get("general"); \
- if(logger == nullptr) \
- logger = createStderrLogger("general", "[%Y-%m-%d %T] Error: %v"); \
- else \
- logger->set_pattern("[%Y-%m-%d %T] Error: %v"); \
- checkedLog("general", "critical", __VA_ARGS__); \
- checkedLog("general", "critical", "Aborted from {} in {}:{}", \
- FUNCTION_NAME, __FILE__, __LINE__); \
- logger->set_pattern("%v"); \
- checkedLog("general", "critical", marian::getCallStack(/*skipLevels=*/0)); \
- std::abort(); \
+#define ABORT(...) \
+ do { \
+ auto logger = spdlog::get("general"); \
+ if(logger == nullptr) \
+ logger = createStderrLogger("general", "[%Y-%m-%d %T] Error: %v"); \
+ else \
+ logger->set_pattern("[%Y-%m-%d %T] Error: %v"); \
+ checkedLog("general", "critical", __VA_ARGS__); \
+ checkedLog("general", "critical", "Aborted from {} in {}:{}", \
+ FUNCTION_NAME, __FILE__, __LINE__); \
+ logger->set_pattern("%v"); \
+ auto callStack = marian::getCallStack(/*skipLevels=*/0); \
+ checkedLog("general", "critical", callStack); \
+ if(marian::getThrowExceptionOnAbort()) \
+ throw marian::MarianRuntimeException(fmt::format(__VA_ARGS__), callStack); \
+ else \
+ std::abort(); \
} while(0)
/**
diff --git a/src/common/options.cpp b/src/common/options.cpp
new file mode 100644
index 00000000..e39b5522
--- /dev/null
+++ b/src/common/options.cpp
@@ -0,0 +1,101 @@
+#include "options.h"
+
+namespace marian {
+
+Options::Options()
+#if FASTOPT
+ : fastOptions_(options_)
+#endif
+{}
+
+Options::Options(const Options& other)
+#if FASTOPT
+ : options_(YAML::Clone(other.options_)),
+ fastOptions_(options_) {}
+#else
+ : options_(YAML::Clone(other.options_)) {}
+#endif
+
+Options Options::clone() const {
+ return Options(*this); // fastOptions_ get set in constructor above
+}
+
+YAML::Node Options::cloneToYamlNode() const {
+ return YAML::Clone(options_); // Do not give access to internal YAML object
+}
+
+void Options::parse(const std::string& yaml) {
+ auto node = YAML::Load(yaml);
+ for(auto it : node)
+ options_[it.first.as<std::string>()] = YAML::Clone(it.second);
+#if FASTOPT
+ setLazyRebuild();
+#endif
+}
+
+void Options::merge(const YAML::Node& node, bool overwrite) {
+ for(auto it : node)
+ if(overwrite || !options_[it.first.as<std::string>()])
+ options_[it.first.as<std::string>()] = YAML::Clone(it.second);
+#if FASTOPT
+ setLazyRebuild();
+#endif
+}
+
+void Options::merge(Ptr<Options> options) {
+ merge(options->options_);
+}
+
+std::string Options::asYamlString() {
+ std::stringstream ss;
+ ss << options_;
+ return ss.str();
+}
+
+bool Options::hasAndNotEmpty(const char* const key) const {
+#if FASTOPT
+ lazyRebuild();
+ if(!fastOptions_.has(key)) {
+ return false;
+ } else {
+ auto& node = fastOptions_[key];
+ if(node.isSequence())
+ return node.size() != 0;
+ else if(node.isScalar()) // numerical values count as non-empty
+ return !node.as<std::string>().empty();
+ else
+ ABORT("Wrong node type");
+ }
+#else
+ if(!options_[key]) {
+ return false;
+ } else {
+ auto& node = options_[key];
+ if(node.IsSequence())
+ return node.size() != 0;
+ else if(node.IsScalar()) // numerical values count as non-empty
+ return !node.as<std::string>().empty();
+ else
+ ABORT("Wrong node type");
+ }
+#endif
+}
+
+bool Options::hasAndNotEmpty(const std::string& key) const {
+ return hasAndNotEmpty(key.c_str());
+}
+
+bool Options::has(const char* const key) const {
+#if FASTOPT
+ lazyRebuild();
+ return fastOptions_.has(key);
+#else
+ return options_[key];
+#endif
+}
+
+bool Options::has(const std::string& key) const {
+ return has(key.c_str());
+}
+
+}
diff --git a/src/common/options.h b/src/common/options.h
index 9f1bb8db..d288f0a3 100755
--- a/src/common/options.h
+++ b/src/common/options.h
@@ -1,13 +1,19 @@
#pragma once
+// @TODO: to be removed when sure it works
+#define FASTOPT 1 // for diagnostics, 0 reverts to old behavior
+
#include <sstream>
#include <string>
#include "common/definitions.h"
-
#include "3rd_party/yaml-cpp/yaml.h"
+#ifdef FASTOPT
+#include "common/fastopt.h"
+#endif
+
#define YAML_REGISTER_TYPE(registered, type) \
- namespace YAML { \
+namespace YAML { \
template <> \
struct convert<registered> { \
static Node encode(const registered& rhs) { \
@@ -20,34 +26,69 @@
return true; \
} \
}; \
- }
+}
namespace marian {
/**
* Container for options stored as key-value pairs. Keys are unique strings.
+ * This is not thread-safe and locking is the responsibility of the caller.
*/
class Options {
protected:
- YAML::Node options_;
+ YAML::Node options_; // YAML options use for parsing, modification and printing
+
+#if FASTOPT
+ // Only to be modified in lazyRebuild and setLazyRebuild
+ mutable FastOpt fastOptions_; // FastOpt used for fast lookup, lazily rebuilt from YYAML whenever required
+ mutable bool lazyRebuildPending_{false}; // flag if need to lazily rebuild
+
+ // set flag that a rebuild is required
+ void setLazyRebuild() const {
+ lazyRebuildPending_ = true;
+ }
+
+ // check if rebuild is required, rebuild, unset flag.
+ void lazyRebuild() const {
+ if(lazyRebuildPending_) {
+ FastOpt temp(options_);
+ fastOptions_.swap(temp);
+ lazyRebuildPending_ = false;
+ }
+ }
+#endif
public:
- Options() {}
- Options(const Options& other) : options_(YAML::Clone(other.options_)) {}
+ Options();
+ Options(const Options& other);
+
+ // constructor with one or more key-value pairs
+ // New<Options>("var1", val1, "var2", val2, ...)
+ template <typename T, typename... Args>
+ Options(const std::string& key, T value, Args&&... moreArgs) : Options() {
+ set(key, value, std::forward<Args>(moreArgs)...);
+ }
+
+ // constructor that clones and zero or more updates
+ // options->with("var1", val1, "var2", val2, ...)
+ template <typename... Args>
+ Ptr<Options> with(Args&&... args) const {
+ auto options = New<Options>(*this);
+ options->set(std::forward<Args>(args)...);
+ return options;
+ }
/**
* @brief Return a copy of the object that can be safely modified.
*/
- Options clone() const { return Options(*this); }
+ Options clone() const;
- YAML::Node& getYaml() { return options_; }
- const YAML::Node& getYaml() const { return options_; }
+ // Do not allow access to internal YAML object as changes on the outside are difficult to track
+ // and mess with the rebuilding of the fast options lookup. Hence only return a clone which guarentees
+ // full encapsulation.
+ YAML::Node cloneToYamlNode() const;
- void parse(const std::string& yaml) {
- auto node = YAML::Load(yaml);
- for(auto it : node)
- options_[it.first.as<std::string>()] = YAML::Clone(it.second);
- }
+ void parse(const std::string& yaml);
/**
* @brief Splice options from a YAML node
@@ -58,41 +99,81 @@ public:
* @param node a YAML node to transfer the options from
* @param overwrite overwrite all options
*/
- void merge(YAML::Node& node, bool overwrite = false) {
- for(auto it : node)
- if(overwrite || !options_[it.first.as<std::string>()])
- options_[it.first.as<std::string>()] = YAML::Clone(it.second);
- }
+ void merge(const YAML::Node& node, bool overwrite = false);
+ void merge(Ptr<Options> options);
- void merge(const YAML::Node& node, bool overwrite = false) { merge(node, overwrite); }
- void merge(Ptr<Options> options) { merge(options->getYaml()); }
-
- std::string str() {
- std::stringstream ss;
- ss << options_;
- return ss.str();
- }
+ std::string asYamlString();
template <typename T>
void set(const std::string& key, T value) {
options_[key] = value;
+#if FASTOPT
+ setLazyRebuild();
+#endif
+ }
+
+ // set multiple
+ // options->set("var1", val1, "var2", val2, ...)
+ template <typename T, typename... Args>
+ void set(const std::string& key, T value, Args&&... moreArgs) {
+ set(key, value);
+ set(std::forward<Args>(moreArgs)...);
+#if FASTOPT
+ setLazyRebuild();
+#endif
}
template <typename T>
- T get(const std::string& key) {
+ T get(const char* const key) const {
+#if FASTOPT
+ lazyRebuild();
+ ABORT_IF(!has(key), "Required option '{}' has not been set", key);
+ return fastOptions_[key].as<T>();
+#else
ABORT_IF(!has(key), "Required option '{}' has not been set", key);
return options_[key].as<T>();
+#endif
+ }
+
+ template <typename T>
+ T get(const std::string& key) const {
+ return get<T>(key.c_str());
}
template <typename T>
- T get(const std::string& key, T defaultValue) {
+ T get(const char* const key, T defaultValue) const {
+#if FASTOPT
+ lazyRebuild();
+ if(has(key))
+ return fastOptions_[key].as<T>();
+#else
if(has(key))
return options_[key].as<T>();
+#endif
else
return defaultValue;
}
- bool has(const std::string& key) const { return options_[key]; }
+ template <typename T>
+ T get(const std::string& key, T defaultValue) const {
+ return get<T>(key.c_str(), defaultValue);
+ }
+
+ /**
+ * @brief Check if a sequence or string option is defined and nonempty
+ *
+ * Aborts if the option does not store a sequence or string value. Returns false if an option with
+ * the given key does not exist.
+ *
+ * @param key option name
+ *
+ * @return true if the option is defined and is a nonempty sequence or string
+ */
+ bool hasAndNotEmpty(const char* const key) const;
+ bool hasAndNotEmpty(const std::string& key) const;
+
+ bool has(const char* const key) const;
+ bool has(const std::string& key) const;
};
} // namespace marian
diff --git a/src/common/project_version.h.in b/src/common/project_version.h.in
index 756f2e56..815aff13 100755..100644
--- a/src/common/project_version.h.in
+++ b/src/common/project_version.h.in
@@ -1,8 +1,8 @@
#pragma once
/*
- * File project-version.h is generated using CMake. Do NOT modify it manually! Edit
- * project-version.h.in file instead.
+ * File project_version.h is generated using CMake. Do NOT modify it manually! Edit
+ * project_version.h.in file instead.
*/
// e.g. v1.2.3-beta+1.abc123d
diff --git a/src/common/regex.h b/src/common/regex.h
index eb7482a8..f2ef95c1 100644
--- a/src/common/regex.h
+++ b/src/common/regex.h
@@ -1,8 +1,4 @@
+#pragma once
-#ifdef USE_BOOST_REGEX
-#include <boost/regex.hpp>
-namespace regex = boost;
-#else
#include <regex>
namespace regex = std;
-#endif
diff --git a/src/common/shape.h b/src/common/shape.h
index 89a120fe..fd86ef51 100755..100644
--- a/src/common/shape.h
+++ b/src/common/shape.h
@@ -12,6 +12,22 @@
namespace marian {
+struct Slice // Python-like slice/index descriptor
+{
+ Slice(int b, int e, int s) : begin(b), end(e), stride(s) {}
+ Slice(int b, int e) : Slice(b, e, 1) {}
+ Slice() : Slice(0, END) {}
+ explicit Slice(int i) : Slice(i, i + 1) {}
+ Slice(const Slice& other) : Slice(other.begin, other.end, other.stride) {}
+ const Slice& operator=(const Slice& other) { begin = other.begin; end = other.end; stride = other.stride; return *this; }
+ const Slice& operator=(int i) { begin = i; end = i + 1; stride = 1; return *this; }
+ bool operator==(const Slice& other) const { return begin == other.begin && end == other.end && stride == other.stride; }
+ bool operator!=(const Slice& other) const { return !(*this == other); }
+ /*const*/ int begin, end, stride;
+ static const int END = INT_MAX;
+};
+typedef std::vector<Slice> Slices;
+
struct Shape {
private:
std::vector<int> shape_;
@@ -31,6 +47,8 @@ public:
std::copy(shape.begin(), shape.end(), begin());
}
+ Shape& operator=(const Shape& p) = default;
+
inline size_t size() const { return shape_.size(); }
void resize(size_t n) { shape_.resize(n, 1); }
@@ -84,10 +102,11 @@ public:
return stride[size() + i];
}
- inline int elements() const {
- int el = 1;
+ template<typename T = int> // using a template so that FactoredSegmenter, which uses this as well, can pass size_t
+ inline T elements() const {
+ T el = 1;
for(auto s : shape_)
- el *= s;
+ el *= (T)s;
return el;
}
@@ -147,6 +166,17 @@ public:
return ax;
}
+ Slice slice(Slice slice, int ax) const { // interpret negative and special values in Slice
+ int n = dim(ax);
+ if (slice.begin < 0)
+ slice.begin += n;
+ if (slice.end < 0)
+ slice.end += n;
+ else if (slice.end == Slice::END)
+ slice.end = n;
+ return slice;
+ }
+
static Shape broadcast(const std::vector<Shape>& shapes) {
size_t maxDims = 0;
for(auto& s : shapes)
diff --git a/src/common/timer.h b/src/common/timer.h
index dfb91bf0..ac83b363 100755..100644
--- a/src/common/timer.h
+++ b/src/common/timer.h
@@ -1,21 +1,8 @@
#pragma once
-#ifdef _GNUC_
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wsuggest-override"
-#endif
-#include <boost/timer/timer.hpp>
-#ifdef _GNUC_
-#pragma GCC diagnostic pop
-#endif
-
-#ifdef _MSC_VER
-// (needed on Windows only to resolve a link error, but causes a warning on Linux)
-#include <boost/chrono.hpp>
-#endif
-
-#include <chrono>
+#include <iostream>
#include <sstream>
+#include <chrono>
namespace marian {
namespace timer {
@@ -29,7 +16,6 @@ static inline std::string currentDate() {
}
// Timer measures elapsed time.
-// This is a wrapper around std::chrono providing wall time only
class Timer {
protected:
using clock = std::chrono::steady_clock;
@@ -89,12 +75,9 @@ public:
}
};
-// @TODO: replace with a timer providing CPU/thread time on both Linux and Windows. This is required
-// for auto-tuner.
-// Check get_clocktime on Linux: https://linux.die.net/man/3/clock_gettime
-// Check GetThreadTimes on Windows:
-// https://docs.microsoft.com/en-gb/windows/desktop/api/processthreadsapi/nf-processthreadsapi-getthreadtimes
-using CPUTimer = boost::timer::cpu_timer;
+// std::chrono::steady_clock seems to be the right choice here.
+using CPUTimer = Timer;
+
} // namespace timer
} // namespace marian
diff --git a/src/common/types.cpp b/src/common/types.cpp
new file mode 100644
index 00000000..f358cdb6
--- /dev/null
+++ b/src/common/types.cpp
@@ -0,0 +1,38 @@
+#include "common/types.h"
+#include "tensors/cpu/fbgemm/packed_gemm.h"
+
+namespace marian {
+
+// this function calculates the amount of bytes needed to contain a tensor of given shape and type.
+// For most situation that is trivial (just number of elements time size of single element).
+// But for instance, for intransparent types like packed tensors, it cannot easily be inferred by
+// multiplying. All cases are handed here and can later be passed to allocators etc.
+size_t requiredBytes(const Shape& shape, Type type) {
+#if USE_FBGEMM
+ if (isPacked(type)) {
+ if (sizeOf(type) == 1) {
+ // Type::packed8avx2 || type == Type::packed8avx512
+ // AVX2 and AVX512 CPUs have different cache and vector lanes,
+ // so the optimal memory layouts for them are different.
+ int nrow, ncol;
+ uint64_t packsize;
+ cpu::variant::fbgemmPacked8PackInfo(shape, type, false, /*out=*/nrow, /*out=*/ncol, /*out=*/packsize);
+ return (size_t)packsize;
+ } else if (type == Type::packed16) {
+ uint64_t packsize;
+ cpu::variant::fbgemmPacked16PackInfo(shape, false, /*out=*/packsize);
+ return (size_t)packsize;
+ } else {
+ ABORT("Not a supported data type: {}", type);
+ return 0;
+ }
+ } else {
+ return shape.elements() * sizeOf(type);
+ }
+#else
+ return shape.elements() * sizeOf(type);
+#endif // USE_FBGEMM
+
+}
+
+} \ No newline at end of file
diff --git a/src/common/types.h b/src/common/types.h
index f25520f0..4bc4f9ad 100755..100644
--- a/src/common/types.h
+++ b/src/common/types.h
@@ -1,40 +1,270 @@
#pragma once
+#include "common/logging.h" // for ABORT and ABORT_IF
+#include "common/shape.h"
+
+#if __GNUC__ >= 7
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wint-in-bool-context" // gcc-7 introduces this warning, triggered in 3rd-party code
+#endif
+#include "half_float/umHalf.h"
+#if __GNUC__ >= 7
+#pragma GCC diagnostic pop
+#endif
#include <iostream>
#include <string>
+#include <functional>
+#include <type_traits>
+
+#ifndef __CUDACC__ // NVCC is very unreliable when it comes to CPU intrinsics, we hide them completely from NVCC-compiled code
+#include <immintrin.h>
+#endif
+
+#ifdef __CUDACC__ // nvcc is compiling this code
+#include <cuda.h> // required to see CUDA_VERSION
+#if (CUDA_VERSION > 9000 && (__CUDA_ARCH__ >= 600 || !defined(__CUDA_ARCH__)))
+#define COMPILE_FP16 1 // we are in GPU code and we know what to do with FP16 code
+#else
+#define COMPILE_FP16 0 // we are in GPU code, but compute capability is too low to use FP16
+#endif
+#elif CUDA_FOUND // other compiler, likely host code. Should be fine with seeing the correct includes with host code
+#include <cuda.h> // required to see CUDA_VERSION
+#if (CUDA_VERSION > 9000)
+#define COMPILE_FP16 1
+#else
+#define COMPILE_FP16 0
+#endif
+#else
+#define COMPILE_FP16 0
+#endif
+
+#ifdef _MSC_VER
+// @BUGBUG: Visual Studio somehow fails on template expansions for float16.
+// To be able to build on Windows, we temporarily disable this, until the greater merge has happened.
+#define DISPATCH_BY_TYPE0(type, func) \
+do { \
+ switch(type) { \
+ case Type::int8: return func<int8_t >(); \
+ case Type::int16: return func<int16_t >(); \
+ case Type::int32: return func<int32_t >(); \
+ case Type::int64: return func<int64_t >(); \
+ case Type::uint8: return func<uint8_t >(); \
+ case Type::uint16: return func<uint16_t>(); \
+ case Type::uint32: return func<uint32_t>(); \
+ case Type::uint64: return func<uint64_t>(); \
+ case Type::float16: ABORT("Broken type {}", type);/*return func<float16 >();*/ \
+ case Type::float32: return func<float >(); \
+ case Type::float64: return func<double >(); \
+ default: ABORT("Unknown type {}", type); \
+ } \
+} while(0)
+
+#define DISPATCH_BY_TYPE1(type, func, arg1) \
+do { \
+ switch(type) { \
+ case Type::int8: return func<int8_t >(arg1); \
+ case Type::int16: return func<int16_t >(arg1); \
+ case Type::int32: return func<int32_t >(arg1); \
+ case Type::int64: return func<int64_t >(arg1); \
+ case Type::uint8: return func<uint8_t >(arg1); \
+ case Type::uint16: return func<uint16_t>(arg1); \
+ case Type::uint32: return func<uint32_t>(arg1); \
+ case Type::uint64: return func<uint64_t>(arg1); \
+ case Type::float16: ABORT("Broken type {}", type);/*return func<float16 >(arg1);*/ \
+ case Type::float32: return func<float >(arg1); \
+ case Type::float64: return func<double >(arg1); \
+ default: ABORT("Unknown type {}", type); \
+ } \
+} while(0)
+#else
+#define DISPATCH_BY_TYPE0(type, func) \
+do { \
+ switch(type) { \
+ case Type::int8: return func<int8_t >(); \
+ case Type::int16: return func<int16_t >(); \
+ case Type::int32: return func<int32_t >(); \
+ case Type::int64: return func<int64_t >(); \
+ case Type::uint8: return func<uint8_t >(); \
+ case Type::uint16: return func<uint16_t>(); \
+ case Type::uint32: return func<uint32_t>(); \
+ case Type::uint64: return func<uint64_t>(); \
+ case Type::float16: return func<float16 >(); \
+ case Type::float32: return func<float >(); \
+ case Type::float64: return func<double >(); \
+ default: ABORT("Unknown type {}", type); \
+ } \
+} while(0)
+
+#define DISPATCH_BY_TYPE1(type, func, arg1) \
+do { \
+ switch(type) { \
+ case Type::int8: return func<int8_t >(arg1); \
+ case Type::int16: return func<int16_t >(arg1); \
+ case Type::int32: return func<int32_t >(arg1); \
+ case Type::int64: return func<int64_t >(arg1); \
+ case Type::uint8: return func<uint8_t >(arg1); \
+ case Type::uint16: return func<uint16_t>(arg1); \
+ case Type::uint32: return func<uint32_t>(arg1); \
+ case Type::uint64: return func<uint64_t>(arg1); \
+ case Type::float16: return func<float16 >(arg1); \
+ case Type::float32: return func<float >(arg1); \
+ case Type::float64: return func<double >(arg1); \
+ default: ABORT("Unknown type {}", type); \
+ } \
+} while(0)
+#endif
+
+#define DISPATCH_BY_TYPE2(type, func, arg1, arg2) \
+do { \
+ switch(type) { \
+ case Type::int8 : return func<int8_t >(arg1, arg2); \
+ case Type::int16 : return func<int16_t >(arg1, arg2); \
+ case Type::int32 : return func<int32_t >(arg1, arg2); \
+ case Type::int64 : return func<int64_t >(arg1, arg2); \
+ case Type::uint8 : return func<uint8_t >(arg1, arg2); \
+ case Type::uint16 : return func<uint16_t>(arg1, arg2); \
+ case Type::uint32 : return func<uint32_t>(arg1, arg2); \
+ case Type::uint64 : return func<uint64_t>(arg1, arg2); \
+ case Type::float16 : return func<float16 >(arg1, arg2); \
+ case Type::float32 : return func<float >(arg1, arg2); \
+ case Type::float64 : return func<double >(arg1, arg2); \
+ default: ABORT("Unknown type {}", type); \
+ } \
+} while(0)
namespace marian {
+// small struct to enable templating based on types use for packing
+struct packed16 {
+ uint16_t x;
+};
+
+// small struct to enable templating based on types use for packing. This is a memory holder.
+// There's no difference between packed8avx2 and packed8avx512. But, they are separately defined to be distinguished.
+struct packed8avx2 {
+ uint8_t x;
+};
+
+// small struct to enable templating based on types use for packing. This is a memory holder.
+struct packed8avx512 {
+ uint8_t x;
+};
+
+#ifndef __CUDACC__ // vectorized types not available from .cu files
+
+// @TODO: check what intrinsics are actually available.
+struct float32x4 {
+private:
+ __m128 f_;
+
+public:
+ float32x4() {}
+ float32x4(const __m128& f) : f_(f) {}
+ float32x4(const float& f) : f_(_mm_set1_ps(f)) {} // __m128 _mm_set1_ps(float) copies value into all slots
+
+ operator const __m128&() const { return f_; }
+ operator __m128&() { return f_; }
+
+ float operator[] (size_t i) const {
+ return *(((float*)&f_) + i); // potentially undefined, but efficient. In practice __m128 is an array of floats
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, float32x4 f4) {
+ float* a = (float*)&f4;
+ out << "[" << a[0];
+ for(int i = 1; i < 4; i++)
+ out << " " << a[i];
+ out << "]";
+ return out;
+ }
+};
+
+// @TODO: consider how code can be shared via templating
+#ifdef __AVX__
+struct float32x8 {
+private:
+ __m256 f_;
+
+public:
+ float32x8() {}
+ float32x8(const __m256& f) : f_(f) {}
+ float32x8(const float& f) : f_(_mm256_set1_ps(f)) {} // __m256 _mm_set1_ps(float) copies value into all slots
+
+ operator const __m256&() const { return f_; }
+ operator __m256&() { return f_; }
+
+ float operator[] (size_t i) const {
+ return *(((float*)&f_) + i); // potentially undefined, but efficient. In practice __m128 is an array of floats
+ }
+
+ friend std::ostream& operator<<(std::ostream& out, float32x8 f8) {
+ float* a = (float*)&f8;
+ out << "[" << a[0];
+ for(int i = 1; i < 8; i++)
+ out << " " << a[i];
+ out << "]";
+ return out;
+ }
+};
+#else
+//Dummy version to get things to compile on older CPUs
+struct float32x8 {
+};
+#endif
+#endif
+
+// Internal to types.h, don't use. Use test functions below.
enum class TypeClass : size_t {
- signed_type = 0x100,
- unsigned_type = 0x200,
- float_type = 0x400,
- size_mask = 0x0FF
+ signed_type = 0x0100,
+ unsigned_type = 0x0200,
+ float_type = 0x0400,
+
+ packed_type = 0x0800, // special packed (CPU cache friendly) type class, used in FBGEMM, not meant to be used anywhere else
+ avx2_type = 0x1000, // processor-specific layout for avx2, currently used for FBGEMM only
+ avx512_type = 0x2000, // processor-specific layout for avx512, currently used for FBGEMM only
+
+ size_mask = 0x00FF,
+ class_mask = 0xFF00
};
constexpr inline size_t operator+(TypeClass typeClass, size_t val) {
return (size_t)typeClass + val;
}
+constexpr inline size_t operator+(size_t val, TypeClass typeClass) {
+ return val + (size_t)typeClass;
+}
+
+// @TODO: rename to ElementType when things become stable, so it's easier to review
enum class Type : size_t {
- int8 = TypeClass::signed_type + 1u,
- int16 = TypeClass::signed_type + 2u,
- int32 = TypeClass::signed_type + 4u,
- int64 = TypeClass::signed_type + 8u,
-
- uint8 = TypeClass::unsigned_type + 1u,
- uint16 = TypeClass::unsigned_type + 2u,
- uint32 = TypeClass::unsigned_type + 4u,
- uint64 = TypeClass::unsigned_type + 8u,
-
- float32 = TypeClass::float_type + 4u,
- float64 = TypeClass::float_type + 8u
+ int8 = TypeClass::signed_type + 1u,
+ int16 = TypeClass::signed_type + 2u,
+ int32 = TypeClass::signed_type + 4u,
+ int64 = TypeClass::signed_type + 8u,
+
+ uint8 = TypeClass::unsigned_type + 1u,
+ uint16 = TypeClass::unsigned_type + 2u,
+ uint32 = TypeClass::unsigned_type + 4u,
+ uint64 = TypeClass::unsigned_type + 8u,
+
+ float16 = TypeClass::float_type + 2u,
+ float32 = TypeClass::float_type + 4u,
+ float64 = TypeClass::float_type + 8u,
+
+ packed16 = TypeClass::packed_type + 2u, // special type for FBGEMM, not meant to be used anywhere else, not meant to be accessed invidually. Internal actual type (uint16) is meaningless.
+ packed8avx2 = TypeClass::packed_type + 1u + TypeClass::avx2_type, // special type for FBGEMM with AVX2, not meant to be used anywhere else, not meant to be accessed invidually. Internal actual type (uint8) is meaningless.
+ packed8avx512 = TypeClass::packed_type + 1u + TypeClass::avx512_type, // special type for FBGEMM with AVX512, not meant to be used anywhere else, not meant to be accessed invidually. Internal actual type (uint8) is meaningless.
+
};
static inline size_t operator&(TypeClass typeClass, Type type) {
return (size_t)typeClass & (size_t)type;
}
+static inline bool isSameTypeClass(Type type1, Type type2) {
+ return (TypeClass::class_mask & type1) == (TypeClass::class_mask & type2);
+}
+
static inline size_t sizeOf(Type type) {
return TypeClass::size_mask & type;
}
@@ -55,38 +285,63 @@ static inline bool isFloat(Type type) {
return (TypeClass::float_type & type) != 0;
}
+static inline bool isPacked(Type type) {
+ return (TypeClass::packed_type & type) != 0;
+}
+
+static inline bool isAvx2(Type type) {
+ return (TypeClass::avx2_type & type) != 0;
+}
+
+static inline bool isAvx512(Type type) {
+ return (TypeClass::avx512_type & type) != 0;
+}
+
+size_t requiredBytes(const Shape& shape, Type type); // towards Frank's vision of joint Shape/Type
+
template <typename T>
inline bool matchType(Type type);
// clang-format off
-template <> inline bool matchType<int8_t>(Type type) { return type == Type::int8; }
-template <> inline bool matchType<int16_t>(Type type) { return type == Type::int16; }
-template <> inline bool matchType<int32_t>(Type type) { return type == Type::int32; }
-template <> inline bool matchType<int64_t>(Type type) { return type == Type::int64; }
-
-template <> inline bool matchType<uint8_t>(Type type) { return type == Type::uint8; }
-template <> inline bool matchType<uint16_t>(Type type) { return type == Type::uint16; }
-template <> inline bool matchType<uint32_t>(Type type) { return type == Type::uint32; }
-template <> inline bool matchType<uint64_t>(Type type) { return type == Type::uint64; }
-
-template <> inline bool matchType<float>(Type type) { return type == Type::float32; }
-template <> inline bool matchType<double>(Type type) { return type == Type::float64; }
+template <> inline bool matchType<int8_t>(Type type) { return type == Type::int8; }
+template <> inline bool matchType<int16_t>(Type type) { return type == Type::int16; }
+template <> inline bool matchType<int32_t>(Type type) { return type == Type::int32; }
+template <> inline bool matchType<int64_t>(Type type) { return type == Type::int64; }
+
+// In case of packed type, it uses uint8 as underlying memory type
+template <> inline bool matchType<uint8_t>(Type type) { return type == Type::uint8; }
+template <> inline bool matchType<uint16_t>(Type type) { return type == Type::uint16; }
+template <> inline bool matchType<uint32_t>(Type type) { return type == Type::uint32; }
+template <> inline bool matchType<uint64_t>(Type type) { return type == Type::uint64; }
+
+template <> inline bool matchType<float16>(Type type) { return type == Type::float16; }
+template <> inline bool matchType<float>(Type type) { return type == Type::float32; }
+template <> inline bool matchType<double>(Type type) { return type == Type::float64; }
+
+template <> inline bool matchType<packed16>(Type type) { return type == Type::packed16; }
+template <> inline bool matchType<packed8avx2>(Type type) { return type == Type::packed8avx2; }
+template <> inline bool matchType<packed8avx512>(Type type) { return type == Type::packed8avx512; }
// clang-format on
static inline std::ostream& operator<<(std::ostream& out, Type type) {
switch(type) {
- case Type::int8: out << "int8"; break;
- case Type::int16: out << "int16"; break;
- case Type::int32: out << "int32"; break;
- case Type::int64: out << "int64"; break;
-
- case Type::uint8: out << "uint8"; break;
- case Type::uint16: out << "uint16"; break;
- case Type::uint32: out << "uint32"; break;
- case Type::uint64: out << "uint64"; break;
-
- case Type::float32: out << "float32"; break;
- case Type::float64: out << "float64"; break;
+ case Type::int8 : out << "int8"; break;
+ case Type::int16 : out << "int16"; break;
+ case Type::int32 : out << "int32"; break;
+ case Type::int64 : out << "int64"; break;
+
+ case Type::uint8 : out << "uint8"; break;
+ case Type::uint16 : out << "uint16"; break;
+ case Type::uint32 : out << "uint32"; break;
+ case Type::uint64 : out << "uint64"; break;
+
+ case Type::float16 : out << "float16"; break;
+ case Type::float32 : out << "float32"; break;
+ case Type::float64 : out << "float64"; break;
+
+ case Type::packed16 : out << "packed16"; break;
+ case Type::packed8avx2 : out << "packed8avx2"; break;
+ case Type::packed8avx512 : out << "packed8avx512"; break;
}
return out;
}
@@ -105,10 +360,71 @@ template <> inline std::string request<uint16_t>() { return "uint16"; }
template <> inline std::string request<uint32_t>() { return "uint32"; }
template <> inline std::string request<uint64_t>() { return "uint64"; }
-template <> inline std::string request<float>() { return "float32"; }
-template <> inline std::string request<double>() { return "float64"; }
+template <> inline std::string request<float16>() { return "float16"; }
+template <> inline std::string request<float>() { return "float32"; }
+template <> inline std::string request<double>() { return "float64"; }
+
+template <> inline std::string request<packed16>() { return "packed16"; }
+template <> inline std::string request<packed8avx2>() { return "packed8avx2"; }
+template <> inline std::string request<packed8avx512>() { return "packed8avx512"; }
// clang-format on
+static Type inline typeFromString(const std::string& str) {
+ if(str == "int8")
+ return Type::int8;
+ if(str == "int16")
+ return Type::int16;
+ if(str == "int32")
+ return Type::int32;
+ if(str == "int64")
+ return Type::int64;
+
+ if(str == "uint8")
+ return Type::uint8;
+ if(str == "uint16")
+ return Type::uint16;
+ if(str == "uint32")
+ return Type::uint32;
+ if(str == "uint64")
+ return Type::uint64;
+
+ if(str == "float16")
+ return Type::float16;
+ if(str == "float32")
+ return Type::float32;
+ if(str == "float64")
+ return Type::float64;
+
+ if(str == "packed16")
+ return Type::packed16;
+ if(str == "packed8avx2")
+ return Type::packed8avx2;
+ if(str == "packed8avx512")
+ return Type::packed8avx512;
+
+ ABORT("Unknown type {}", str);
+}
+
+template <typename T>
+inline Type typeId();
+
+template <> inline Type typeId<int8_t>() { return Type::int8; }
+template <> inline Type typeId<int16_t>() { return Type::int16; }
+template <> inline Type typeId<int32_t>() { return Type::int32; }
+template <> inline Type typeId<int64_t>() { return Type::int64; }
+
+template <> inline Type typeId<uint8_t>() { return Type::uint8; }
+template <> inline Type typeId<uint16_t>() { return Type::uint16; }
+template <> inline Type typeId<uint32_t>() { return Type::uint32; }
+template <> inline Type typeId<uint64_t>() { return Type::uint64; }
+
+template <> inline Type typeId<float16>() { return Type::float16; }
+template <> inline Type typeId<float>() { return Type::float32; }
+template <> inline Type typeId<double>() { return Type::float64; }
+
+template <> inline Type typeId<packed16>() { return Type::packed16; }
+template <> inline Type typeId<packed8avx2>() { return Type::packed8avx2; }
+template <> inline Type typeId<packed8avx512>() { return Type::packed8avx512; }
// Abort if given C++ does not correspond to runtime type
template <typename T>
@@ -119,4 +435,77 @@ void matchOrAbort(Type type) {
type);
}
+namespace typeFitting { // own namespace instead of in class, otherwise we get error "explicit specialization in non-namespace scope"
+
+ // Helper function for fitsIntoMax() below
+ // Returns the 'capacity' of a type: number of digits for integers,
+ // max_exponent for floats. We ignore the mantissa for floats.
+ template<typename X> constexpr int capacity() {
+ static_assert(std::is_arithmetic<X>::value || std::is_same<X,HalfFloat>::value,
+ "Wrong type for this template");
+ return (std::is_integral<X>::value
+ ? std::numeric_limits<X>::digits
+ : std::numeric_limits<X>::max_exponent);
+ }
+
+
+ // Compare max for different types as constexpr, so can be used at compile-time to determine if RequestType type max fits into ReturnType max, see std::conditional below.
+ template <typename RequestType, typename ReturnType>
+ constexpr bool fitsIntoMax() {
+ // We can't just compare std::numeric_limits<>::max(), because Clang-10
+ // complains about rounding errors when implicitly converting int to float
+ return ((!std::is_integral<RequestType>::value // RequestType is a float
+ && std::is_integral<ReturnType>::value) // ReturnType an integer
+ ? capacity<RequestType>() < capacity<ReturnType>() // special case
+ : capacity<RequestType>() <= capacity<ReturnType>()); // normal case
+ } // for built-in types everything is constexpr
+
+}
+
+template <typename ReturnType>
+class NumericLimits {
+private:
+
+ template <typename MaxType> void setLimitsMax() {
+ max = (ReturnType)std::numeric_limits<MaxType>::max();
+ lowest = (ReturnType)std::numeric_limits<MaxType>::lowest();
+ }
+
+ template <typename RequestType>
+ void setLimits() {
+ // check if the maximum of type RequestType fits into ReturnType
+ constexpr bool fits = typeFitting::fitsIntoMax<RequestType, ReturnType>();
+ // sanity check:
+ static_assert(fits || typeFitting::fitsIntoMax<ReturnType, RequestType>(),
+ "RequestType doesn't fit into ReturnType, and ReturnType doesn't "
+ "fit into RequestType. fitsIntoMax is broken!");
+ // and then use the smaller of each types to determine max, min, lowest.
+ using MaxType = typename std::conditional<fits, RequestType, ReturnType>::type;
+ setLimitsMax<MaxType>();
+ // @TODO: should we rather abort if the RequestType does not fit into ReturnType instead of clipping to smaller type?
+ // ABORT_IF(!fits, "Type {} is too small to contain max of type {}", typeId<ReturnType>(), typeId<RequestType>());
+ }
+
+ void setLimits(Type type) {
+ DISPATCH_BY_TYPE0(type, setLimits);
+ }
+
+public:
+ ReturnType max;
+ ReturnType lowest;
+
+ NumericLimits(Type type) {
+ setLimits(type);
+ }
+};
+
} // namespace marian
+
+// custom specialization of std::hash can be injected in namespace std
+namespace std {
+ template<> struct hash<::marian::Type> {
+ size_t operator()(const ::marian::Type& type) const noexcept {
+ return (size_t)type; // type is already a unique value of type size_t
+ }
+ };
+}
diff --git a/src/common/utils.cpp b/src/common/utils.cpp
index b47cd1f4..3acb756d 100755
--- a/src/common/utils.cpp
+++ b/src/common/utils.cpp
@@ -7,9 +7,22 @@
#include <iostream>
#include <sstream>
#include <string>
-#ifdef __unix__
+#include <set>
+#if defined(__unix__) || defined(__APPLE__)
#include <unistd.h>
#endif
+#include <codecvt>
+#include <cwctype>
+
+// MACOS lacks HOST_NAME_MAX
+#ifndef HOST_NAME_MAX
+# if defined(_POSIX_HOST_NAME_MAX)
+# define HOST_NAME_MAX _POSIX_HOST_NAME_MAX
+# elif defined(MAXHOSTNAMELEN)
+# define HOST_NAME_MAX MAXHOSTNAMELEN
+# endif
+#endif
+
namespace marian {
namespace utils {
@@ -26,21 +39,26 @@ void trimLeft(std::string& s) {
CLI::detail::ltrim(s, " \t\n");
}
-// @TODO: use more functions from CLI instead of own implementations
void split(const std::string& line,
- std::vector<std::string>& pieces,
- const std::string del /*= " "*/,
- bool keepEmpty) {
+ /*out*/ std::vector<std::string>& pieces,
+ const std::string& del /*= " "*/,
+ bool keepEmpty /*= false*/,
+ bool anyOf /*= false*/) {
+ pieces.clear();
size_t begin = 0;
size_t pos = 0;
std::string token;
- while((pos = line.find(del, begin)) != std::string::npos) {
+ size_t delSize = anyOf ? 1 : del.size();
+ while(true) {
+ pos = anyOf ? line.find_first_of(del, begin) : line.find(del, begin);
+ if(pos == std::string::npos)
+ break;
if(pos >= begin) {
token = line.substr(begin, pos - begin);
if(token.size() > 0 || keepEmpty)
pieces.push_back(token);
}
- begin = pos + del.size();
+ begin = pos + delSize;
}
if(pos >= begin) {
token = line.substr(begin, pos - begin);
@@ -50,46 +68,28 @@ void split(const std::string& line,
}
std::vector<std::string> split(const std::string& line,
- const std::string del /*= " "*/,
- bool keepEmpty) {
+ const std::string& del /*= " "*/,
+ bool keepEmpty /*= false*/,
+ bool anyOf /*= false*/) {
std::vector<std::string> pieces;
- split(line, pieces, del, keepEmpty);
+ split(line, pieces, del, keepEmpty, anyOf);
return pieces;
}
-// @TODO: splitAny() shares all but 2 expressions with split(). Merge them.
void splitAny(const std::string& line,
- std::vector<std::string>& pieces,
- const std::string del /*= " "*/,
- bool keepEmpty) {
- size_t begin = 0;
- size_t pos = 0;
- std::string token;
- while((pos = line.find_first_of(del, begin)) != std::string::npos) {
- if(pos >= begin) {
- token = line.substr(begin, pos - begin);
- if(token.size() > 0 || keepEmpty)
- pieces.push_back(token);
- }
- begin = pos + 1;
- }
- if(pos >= begin) {
- token = line.substr(begin, pos - begin);
- if(token.size() > 0 || keepEmpty)
- pieces.push_back(token);
- }
+ /*out*/ std::vector<std::string>& pieces,
+ const std::string& del /*= " "*/,
+ bool keepEmpty /*= false*/) {
+ split(line, pieces, del, keepEmpty, /*anyOf =*/true);
}
std::vector<std::string> splitAny(const std::string& line,
- const std::string del /*= " "*/,
- bool keepEmpty) {
- std::vector<std::string> pieces;
- splitAny(line, pieces, del, keepEmpty);
- return pieces;
+ const std::string& del /*= " "*/,
+ bool keepEmpty /*= false*/) {
+ return split(line, del, keepEmpty, /*anyOf =*/true);
}
-std::string join(const std::vector<std::string>& words,
- const std::string& del /*= " "*/) {
+std::string join(const std::vector<std::string>& words, const std::string& del /*= " "*/) {
std::stringstream ss;
if(words.empty()) {
return "";
@@ -103,51 +103,292 @@ std::string join(const std::vector<std::string>& words,
return ss.str();
}
-std::string exec(const std::string& cmd) {
+// escapes a string for passing to popen, which uses /bin/sh to parse its argument string
+static std::string escapeForPOpen(const std::string& arg) {
+ // e.g. abc -> 'abc'; my file.txt -> 'my file.txt'; $10 -> '$10'; it's -> 'it'\''s'
+ return arg;
+ // @BUGBUG: This sometimes fails with "sh: 1: Syntax error: Unterminated quoted string",
+ // so since this is not super-critical, we will disable it for now.
+ //return "'" + findReplace(arg, "'", "'\\''", /*all=*/ true) + "'";
+}
+
+// execute an external command
+// The command is composed of three pieces:
+// - the executable path, e.g. --valid-script-path
+// - an optional array of arguments. Meant for options. E.g. --valid-script-args. Options with leading - can only be passed via Yaml/Json.
+// - one more optional single argument. Meant as the main filename argument.
+// Each item will be escaped for shell syntax.
+std::string exec(const std::string& cmd, const std::vector<std::string>& args /*= {}*/, const std::string& arg /*= ""*/) {
std::array<char, 128> buffer;
std::string result;
#ifdef _WIN32
#define popen _popen
#define pclose _pclose
#endif
- std::shared_ptr<std::FILE> pipe(popen(cmd.c_str(), "r"), pclose);
+ auto cmdLine = escapeForPOpen(cmd);
+ for (const auto& a : args) // @TODO: proper escaping
+ cmdLine += " " + escapeForPOpen(a);
+ if (!arg.empty())
+ cmdLine += " " + escapeForPOpen(arg);
+ //std::cerr << "###" << cmdLine << "###" << std::endl;
+ std::shared_ptr<std::FILE> pipe(popen(cmdLine.c_str(), "r"), pclose);
if(!pipe)
ABORT("popen() failed!");
while(!std::feof(pipe.get())) {
- if(std::fgets(buffer.data(), 128, pipe.get()) != NULL)
+ if(std::fgets(buffer.data(), (int)buffer.size(), pipe.get()) != NULL)
result += buffer.data();
}
return result;
}
-std::pair<std::string, int> hostnameAndProcessId() { // helper to get hostname:pid
+std::pair<std::string, int> hostnameAndProcessId() { // helper to get hostname:pid
#ifdef _WIN32
std::string hostname = getenv("COMPUTERNAME");
auto processId = (int)GetCurrentProcessId();
#else
- static std::string hostname = [](){ // not sure if gethostname() is expensive. This way we call it only once.
- char hostnamebuf[HOST_NAME_MAX + 1] = { 0 };
+ static std::string hostname = []() { // not sure if gethostname() is expensive. This way we call it only once.
+ char hostnamebuf[HOST_NAME_MAX + 1] = {0};
gethostname(hostnamebuf, sizeof(hostnamebuf));
return std::string(hostnamebuf);
}();
auto processId = (int)getpid();
#endif
- return{ hostname, processId };
+ return {hostname, processId};
}
// format a long number with comma separators
std::string withCommas(size_t n) {
std::string res = std::to_string(n);
- for (int i = (int)res.size() - 3; i > 0; i -= 3)
+ for(int i = (int)res.size() - 3; i > 0; i -= 3)
res.insert(i, ",");
return res;
}
+bool beginsWith(const std::string& text, const std::string& prefix) {
+ return text.size() >= prefix.size()
+ && !text.compare(0, prefix.size(), prefix);
+}
+
bool endsWith(const std::string& text, const std::string& suffix) {
return text.size() >= suffix.size()
&& !text.compare(text.size() - suffix.size(), suffix.size(), suffix);
}
+// @TODO: sort these functions into a separate header.
+std::u32string utf8ToUnicodeString(std::string const& s) {
+#ifdef _MSC_VER // workaround for a known bug in VS CRT
+ std::wstring_convert<std::codecvt_utf8<unsigned int/*char32_t*/>, unsigned int/*char32_t*/> converter;
+ auto res = converter.from_bytes(s);
+ return std::u32string(res.begin(), res.end());
+#else
+ std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
+ return converter.from_bytes(s);
+#endif
+}
+
+std::string utf8FromUnicodeString(const std::u32string& s) {
+#ifdef _MSC_VER // workaround for a known bug in VS CRT
+ std::wstring_convert<std::codecvt_utf8<unsigned int/*char32_t*/>, unsigned int/*char32_t*/> converter;
+ std::basic_string<unsigned int> si(s.begin(), s.end());
+ return converter.to_bytes(si);
+#else
+ std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
+ return converter.to_bytes(s);
+#endif
+}
+
+std::u16string utf8ToUtf16String(std::string const& s) {
+#ifdef _MSC_VER // workaround for a known bug in VS CRT
+ std::wstring_convert<std::codecvt_utf8<wchar_t/*char16_t*/>, wchar_t/*char16_t*/> converter;
+ auto res = converter.from_bytes(s);
+ return std::u16string(res.begin(), res.end());
+#else
+ std::wstring_convert<std::codecvt_utf8<char16_t>, char16_t> converter;
+ return converter.from_bytes(s);
+#endif
+}
+
+std::string utf8FromUtf16String(const std::u16string& s) {
+#ifdef _MSC_VER // workaround for a known bug in VS CRT
+ std::wstring_convert<std::codecvt_utf8<wchar_t/*char16_t*/>, wchar_t/*char16_t*/> converter;
+ std::basic_string<wchar_t> si(s.begin(), s.end());
+ return converter.to_bytes(si);
+#else
+ std::wstring_convert<std::codecvt_utf8<char16_t>, char16_t> converter;
+ return converter.to_bytes(s);
+#endif
+}
+
+// test whether a Unicode code point is in continuous script (e.g. Chinese or Thai)
+// This is used for detok bleu scoring where we have CJT characters.
+bool isContinuousScript(char32_t c) {
+ // currently, this table is hand-coded, and may need to be extended when the standard grows
+ auto in = [c](char32_t minVal, char32_t maxVal) { return c >= minVal && c <= maxVal; };
+ bool isHan = in(0x2E80, 0x2E99) || in(0x2E9B, 0x2EF3) || in(0x2F00, 0x2FD5) || in(0x3005, 0x3005) ||
+ in(0x3007, 0x3007) || in(0x3021, 0x3029) || in(0x3038, 0x303A) || in(0x303B, 0x303b) ||
+ in(0x3200, 0x32FF) || // Enclosed CJK Letters and Months, https://en.wikipedia.org/wiki/Enclosed_CJK_Letters_and_Months
+ in(0x3400, 0x4DB5) || // up to here, we have a few gaps compared to sacrebleu
+ in(0x4E00, 0x9FEF) || // sacrebleu: only up to 0x9fbb
+ in(0xF900, 0xFA6D) || in(0xFA70, 0xFAD9) || // similar to sacrebleu
+ in(0x20000, 0x2A6D6) ||
+ in(0x2A700, 0x2B734) || in(0x2B740, 0x2B81D) || in(0x2B820, 0x2CEA1) || in(0x2CEB0, 0x2EBE0) || // not in sacrebleu
+ in(0x2F800, 0x2FA1D);
+ bool isKana = in(0x3040, 0x30FF) || // Hiragana, Katakana
+ in(0x1B000, 0x1B0FF) || // Kana supplement, https://en.wikipedia.org/wiki/Kana_Supplement
+ in(0x1B130, 0x1B16F); // small Kana, https://en.wikipedia.org/wiki/Small_Kana_Extension
+ bool isThai = in(0x0E00, 0x0E7F); // https://en.wikipedia.org/wiki/Thai_(Unicode_block)
+ return isHan || isKana || isThai;
+ // Korean characters (Hangul syllables): 0xac00..0xd7a3
+ // Korean subwords (Hangul Jamo): 0x1100..0x11ff [https://en.wikipedia.org/wiki/Hangul_Jamo_(Unicode_block)]
+ // Sacrebleu uses characters units for Chinese characters; specifically, these ranges:
+ /* (ranges as used in sacrebleuy.py)
+ uchar >= u'\u2600' and uchar <= u'\u27bf' ## missing above
+ uchar >= u'\u2e80' # CJK Radicals Supplement
+ uchar >= u'\u2f00' and uchar <= u'\u2fdf' # Kangxi Radicals
+ uchar >= u'\u2ff0' # Chinese character structure
+ uchar >= u'\u3000' and uchar <= u'\u303f' # CJK punctuation mark ## 3040..30ff = Kana
+ uchar >= u'\u3100' and uchar <= u'\u312f' # Phonetic symbols
+ uchar >= u'\u31a0' # Phonetic symbols (Taiwanese and Hakka expansion)
+ uchar >= u'\u31c0' and uchar <= u'\u31ef' # CJK stroke
+ uchar >= u'\u3200' and uchar <= u'\u4db5' # CJK Unified Ideographs Extension A, release 3.0
+ uchar >= u'\u4e00' # CJK Unified Ideographs, release 1.1
+ uchar >= u'\u9fa6' and uchar <= u'\u9fbb' # CJK Unified Ideographs, release 4.1
+ uchar >= u'\uf900' and uchar <= u'\ufa2d' # CJK Compatibility Ideographs, release 1.1
+ uchar >= u'\ufa30' and uchar <= u'\ufa6a' # CJK Compatibility Ideographs, release 3.2
+ uchar >= u'\ufa70' and uchar <= u'\ufad9' # CJK Compatibility Ideographs, release 4.1
+ uchar >= u'\ufe10' and uchar <= u'\ufe1f' ## missing above
+ uchar >= u'\ufe30' and uchar <= u'\ufe4f' ## missing above
+ uchar >= u'\uff00' and uchar <= u'\uffef' # Full width ASCII, full width of English punctuation, half width Katakana, half wide half width kana, Korean alphabet
+ uchar >= u'\u20000' and uchar <= u'\u2a6d6' # CJK Unified Ideographs Extension B, release 3.1
+ uchar >= u'\u2f800' and uchar <= u'\u2fa1d' # CJK Compatibility Supplement, release 3.1
+ */
+}
+
+// convert UTF-8 characters to lower or upper case
+struct UTF8Mapper { // can't use the standard lib functions because MS-internal Philly servers do not have UTF-8 locale installed
+ std::map<char32_t, char32_t> toUpperMap, toLowerMap;
+ UTF8Mapper() {
+ /*
+ env LC_ALL=en_US.UTF-8 sed 's/\(.\)/\1\n/g' TEXT_FILE_CONTAINING_ALL_CHARS > l
+ env LC_ALL=en_US.UTF-8 sed 's/\(.\)/\U\1\n/g' TEXT_FILE_CONTAINING_ALL_CHARS > u
+ paste l u | env LC_ALL=en_US.UTF-8 sort -u > x
+ cat x | awk '{if($1 != $2){print}}' > y
+ cat y | tr -d '\r' \
+ | od -w10000 -t x1 \
+ | head -1 \
+ | sed -e 's/^0000000 /{{".x/g' -e 's/ 09 /",".x/g' -e 's/ 0a /"},{".x/g' -e 's/ 0a$/"}/' -e 's/ /.x/g' \
+ | tr '.' '\\' \
+ | xclip
+ */
+ std::vector<std::pair<std::string, std::string>> map8{ {"\xc9\x92","\xe2\xb1\xb0"},{"\x61","\x41"},{"\xc3\xa1","\xc3\x81"},{"\xc3\xa0","\xc3\x80"},{"\xe1\xba\xaf","\xe1\xba\xae"},{"\xe1\xba\xb1","\xe1\xba\xb0"},{"\xe1\xba\xb5","\xe1\xba\xb4"},{"\xe1\xba\xb3","\xe1\xba\xb2"},{"\xe1\xba\xb7","\xe1\xba\xb6"},{"\xc4\x83","\xc4\x82"},{"\xe1\xba\xa5","\xe1\xba\xa4"},{"\xe1\xba\xa7","\xe1\xba\xa6"},{"\xe1\xba\xab","\xe1\xba\xaa"},{"\xe1\xba\xa9","\xe1\xba\xa8"},{"\xe1\xba\xad","\xe1\xba\xac"},{"\xc3\xa2","\xc3\x82"},{"\xc7\x8e","\xc7\x8d"},{"\xc7\xbb","\xc7\xba"},{"\xc3\xa5","\xc3\x85"},{"\xc7\x9f","\xc7\x9e"},{"\xc3\xa4","\xc3\x84"},{"\xc3\xa3","\xc3\x83"},{"\xc4\x85","\xc4\x84"},{"\xc4\x81","\xc4\x80"},{"\xe1\xba\xa3","\xe1\xba\xa2"},{"\xc8\x83","\xc8\x82"},{"\xe1\xba\xa1","\xe1\xba\xa0"},{"\xc7\xa3","\xc7\xa2"},{"\xc3\xa6","\xc3\x86"},{"\x62","\x42"},{"\xe1\xb8\x87","\xe1\xb8\x86"},{"\x63","\x43"},{"\xc4\x87","\xc4\x86"},{"\xc4\x89","\xc4\x88"},{"\xc4\x8d","\xc4\x8c"},{"\xc4\x8b","\xc4\x8a"},{"\xc3\xa7","\xc3\x87"},{"\x64","\x44"},{"\xc4\x8f","\xc4\x8e"},{"\xc4\x91","\xc4\x90"},{"\xe1\xb8\x91","\xe1\xb8\x90"},{"\xe1\xb8\x8d","\xe1\xb8\x8c"},{"\xe1\xb8\x8f","\xe1\xb8\x8e"},{"\xc3\xb0","\xc3\x90"},{"\x65","\x45"},{"\xc3\xa9","\xc3\x89"},{"\xc3\xa8","\xc3\x88"},{"\xc4\x95","\xc4\x94"},{"\xe1\xba\xbf","\xe1\xba\xbe"},{"\xe1\xbb\x81","\xe1\xbb\x80"},{"\xe1\xbb\x85","\xe1\xbb\x84"},{"\xe1\xbb\x83","\xe1\xbb\x82"},{"\xe1\xbb\x87","\xe1\xbb\x86"},{"\xc3\xaa","\xc3\x8a"},{"\xc4\x9b","\xc4\x9a"},{"\xc3\xab","\xc3\x8b"},{"\xe1\xba\xbd","\xe1\xba\xbc"},{"\xc4\x97","\xc4\x96"},{"\xc4\x99","\xc4\x98"},{"\xe1\xb8\x97","\xe1\xb8\x96"},{"\xc4\x93","\xc4\x92"},{"\xe1\xba\xbb","\xe1\xba\xba"},{"\xc8\x87","\xc8\x86"},{"\xe1\xba\xb9","\xe1\xba\xb8"},{"\xc7\x9d","\xc6\x8e"},{"\x66","\x46"},{"\x67","\x47"},{"\xc7\xb5","\xc7\xb4"},{"\xc4\x9f","\xc4\x9e"},{"\xc4\x9d","\xc4\x9c"},{"\xc7\xa7","\xc7\xa6"},{"\xc4\xa1","\xc4\xa0"},{"\xc4\xa3","\xc4\xa2"},{"\xc9\xa0","\xc6\x93"},{"\x68","\x48"},{"\xc4\xa5","\xc4\xa4"},{"\xc4\xa7","\xc4\xa6"},{"\xe1\xb8\xa9","\xe1\xb8\xa8"},{"\xe1\xb8\xa5","\xe1\xb8\xa4"},{"\xe1\xb8\xab","\xe1\xb8\xaa"},{"\x69","\x49"},{"\xc4\xb1","\x49"},{"\xc3\xad","\xc3\x8d"},{"\xc3\xac","\xc3\x8c"},{"\xc4\xad","\xc4\xac"},{"\xc3\xae","\xc3\x8e"},{"\xc7\x90","\xc7\x8f"},{"\xc3\xaf","\xc3\x8f"},{"\xc4\xa9","\xc4\xa8"},{"\xc4\xaf","\xc4\xae"},{"\xc4\xab","\xc4\xaa"},{"\xe1\xbb\x89","\xe1\xbb\x88"},{"\xc8\x8b","\xc8\x8a"},{"\xe1\xbb\x8b","\xe1\xbb\x8a"},{"\x6a","\x4a"},{"\xc4\xb5","\xc4\xb4"},{"\x6b","\x4b"},{"\xe1\xb8\xb1","\xe1\xb8\xb0"},{"\xc4\xb7","\xc4\xb6"},{"\xe1\xb8\xb3","\xe1\xb8\xb2"},{"\xc6\x99","\xc6\x98"},{"\x6c","\x4c"},{"\xc4\xba","\xc4\xb9"},{"\xc4\xbe","\xc4\xbd"},{"\xc5\x82","\xc5\x81"},{"\xc4\xbc","\xc4\xbb"},{"\xe1\xb8\xb7","\xe1\xb8\xb6"},{"\x6d","\x4d"},{"\xe1\xb8\xbf","\xe1\xb8\xbe"},{"\xe1\xb9\x83","\xe1\xb9\x82"},{"\xc5\x8b","\xc5\x8a"},{"\x6e","\x4e"},{"\xc5\x84","\xc5\x83"},{"\xc5\x88","\xc5\x87"},{"\xc3\xb1","\xc3\x91"},{"\xe1\xb9\x85","\xe1\xb9\x84"},{"\xc5\x86","\xc5\x85"},{"\xe1\xb9\x87","\xe1\xb9\x86"},{"\xe1\xb9\x89","\xe1\xb9\x88"},{"\xc5\x93","\xc5\x92"},{"\x6f","\x4f"},{"\xc3\xb3","\xc3\x93"},{"\xc3\xb2","\xc3\x92"},{"\xc5\x8f","\xc5\x8e"},{"\xe1\xbb\x91","\xe1\xbb\x90"},{"\xe1\xbb\x93","\xe1\xbb\x92"},{"\xe1\xbb\x95","\xe1\xbb\x94"},{"\xe1\xbb\x99","\xe1\xbb\x98"},{"\xc3\xb4","\xc3\x94"},{"\xc7\x92","\xc7\x91"},{"\xc3\xb6","\xc3\x96"},{"\xc5\x91","\xc5\x90"},{"\xc3\xb5","\xc3\x95"},{"\xc3\xb8","\xc3\x98"},{"\xc7\xab","\xc7\xaa"},{"\xc5\x8d","\xc5\x8c"},{"\xe1\xbb\x8f","\xe1\xbb\x8e"},{"\xc8\x8f","\xc8\x8e"},{"\xe1\xbb\x8d","\xe1\xbb\x8c"},{"\xe1\xbb\x9b","\xe1\xbb\x9a"},{"\xe1\xbb\x9d","\xe1\xbb\x9c"},{"\xe1\xbb\xa1","\xe1\xbb\xa0"},{"\xe1\xbb\x9f","\xe1\xbb\x9e"},{"\xe1\xbb\xa3","\xe1\xbb\xa2"},{"\xc6\xa1","\xc6\xa0"},{"\xc9\x94","\xc6\x86"},{"\x70","\x50"},{"\xe1\xb9\x95","\xe1\xb9\x94"},{"\x71","\x51"},{"\x72","\x52"},{"\xc5\x95","\xc5\x94"},{"\xc5\x99","\xc5\x98"},{"\xc5\x97","\xc5\x96"},{"\xe1\xb9\x9b","\xe1\xb9\x9a"},{"\xe1\xb9\x9f","\xe1\xb9\x9e"},{"\x73","\x53"},{"\xc5\x9b","\xc5\x9a"},{"\xc5\x9d","\xc5\x9c"},{"\xc5\xa1","\xc5\xa0"},{"\xc5\x9f","\xc5\x9e"},{"\xe1\xb9\xa3","\xe1\xb9\xa2"},{"\x74","\x54"},{"\xc5\xa5","\xc5\xa4"},{"\xc5\xa3","\xc5\xa2"},{"\xe1\xb9\xad","\xe1\xb9\xac"},{"\xe1\xb9\xaf","\xe1\xb9\xae"},{"\xc8\x95","\xc8\x94"},{"\x75","\x55"},{"\xc3\xba","\xc3\x9a"},{"\xc3\xb9","\xc3\x99"},{"\xc5\xad","\xc5\xac"},{"\xc3\xbb","\xc3\x9b"},{"\xc7\x94","\xc7\x93"},{"\xc5\xaf","\xc5\xae"},{"\xc7\x98","\xc7\x97"},{"\xc7\x9c","\xc7\x9b"},{"\xc3\xbc","\xc3\x9c"},{"\xc5\xb1","\xc5\xb0"},{"\xc5\xa9","\xc5\xa8"},{"\xc5\xb3","\xc5\xb2"},{"\xc5\xab","\xc5\xaa"},{"\xe1\xbb\xa7","\xe1\xbb\xa6"},{"\xe1\xbb\xa5","\xe1\xbb\xa4"},{"\xe1\xb9\xb3","\xe1\xb9\xb2"},{"\xe1\xbb\xa9","\xe1\xbb\xa8"},{"\xe1\xbb\xab","\xe1\xbb\xaa"},{"\xe1\xbb\xaf","\xe1\xbb\xae"},{"\xe1\xbb\xad","\xe1\xbb\xac"},{"\xe1\xbb\xb1","\xe1\xbb\xb0"},{"\xc6\xb0","\xc6\xaf"},{"\x76","\x56"},{"\x77","\x57"},{"\xc5\xb5","\xc5\xb4"},{"\x78","\x58"},{"\xe1\xba\x8b","\xe1\xba\x8a"},{"\x79","\x59"},{"\xc3\xbd","\xc3\x9d"},{"\xe1\xbb\xb3","\xe1\xbb\xb2"},{"\xc5\xb7","\xc5\xb6"},{"\xc3\xbf","\xc5\xb8"},{"\xe1\xbb\xb9","\xe1\xbb\xb8"},{"\x7a","\x5a"},{"\xc5\xba","\xc5\xb9"},{"\xc5\xbe","\xc5\xbd"},{"\xc5\xbc","\xc5\xbb"},{"\xc6\xb6","\xc6\xb5"},{"\xe1\xba\x93","\xe1\xba\x92"},{"\xe1\xba\x95","\xe1\xba\x94"},{"\xc8\xa5","\xc8\xa4"},{"\xc3\xbe","\xc3\x9e"},{"\xca\x92","\xc6\xb7"},{"\xce\xb1","\xce\x91"},{"\xce\xac","\xce\x86"},{"\xce\xb2","\xce\x92"},{"\xce\xb3","\xce\x93"},{"\xce\xb4","\xce\x94"},{"\xce\xb5","\xce\x95"},{"\xce\xad","\xce\x88"},{"\xce\xb6","\xce\x96"},{"\xce\xb7","\xce\x97"},{"\xce\xae","\xce\x89"},{"\xce\xb8","\xce\x98"},{"\xce\xb9","\xce\x99"},{"\xce\xaf","\xce\x8a"},{"\xcf\x8a","\xce\xaa"},{"\xce\xba","\xce\x9a"},{"\xce\xbb","\xce\x9b"},{"\xce\xbc","\xce\x9c"},{"\xce\xbd","\xce\x9d"},{"\xce\xbe","\xce\x9e"},{"\xce\xbf","\xce\x9f"},{"\xcf\x8c","\xce\x8c"},{"\xcf\x80","\xce\xa0"},{"\xcf\x83","\xce\xa3"},{"\xcf\x82","\xce\xa3"},{"\xcf\x84","\xce\xa4"},{"\xcf\x85","\xce\xa5"},{"\xcf\x8d","\xce\x8e"},{"\xcf\x8b","\xce\xab"},{"\xcf\x86","\xce\xa6"},{"\xcf\x87","\xce\xa7"},{"\xcf\x88","\xce\xa8"},{"\xcf\x89","\xce\xa9"},{"\xcf\x8e","\xce\x8f"},{"\xd0\xb0","\xd0\x90"},{"\xd3\x93","\xd3\x92"},{"\xd3\x95","\xd3\x94"},{"\xd0\xb1","\xd0\x91"},{"\xd0\xb2","\xd0\x92"},{"\xd0\xb3","\xd0\x93"},{"\xd2\x93","\xd2\x92"},{"\xd2\x91","\xd2\x90"},{"\xd0\xb4","\xd0\x94"},{"\xd1\x93","\xd0\x83"},{"\xd1\x92","\xd0\x82"},{"\xd0\xb5","\xd0\x95"},{"\xd1\x90","\xd0\x80"},{"\xd3\x99","\xd3\x98"},{"\xd1\x94","\xd0\x84"},{"\xd1\x91","\xd0\x81"},{"\xd0\xb6","\xd0\x96"},{"\xd0\xb7","\xd0\x97"},{"\xd2\x99","\xd2\x98"},{"\xd1\x95","\xd0\x85"},{"\xd0\xb8","\xd0\x98"},{"\xd3\xa3","\xd3\xa2"},{"\xd1\x96","\xd0\x86"},{"\xd1\x97","\xd0\x87"},{"\xd0\xb9","\xd0\x99"},{"\xd1\x98","\xd0\x88"},{"\xd0\xba","\xd0\x9a"},{"\xd2\x9b","\xd2\x9a"},{"\xd3\x84","\xd3\x83"},{"\xd2\xa1","\xd2\xa0"},{"\xd0\xbb","\xd0\x9b"},{"\xd1\x99","\xd0\x89"},{"\xd0\xbc","\xd0\x9c"},{"\xd0\xbd","\xd0\x9d"},{"\xd2\xa3","\xd2\xa2"},{"\xd1\x9a","\xd0\x8a"},{"\xd0\xbe","\xd0\x9e"},{"\xd3\xa7","\xd3\xa6"},{"\xd3\xa9","\xd3\xa8"},{"\xd0\xbf","\xd0\x9f"},{"\xd1\x80","\xd0\xa0"},{"\xd1\x81","\xd0\xa1"},{"\xd2\xab","\xd2\xaa"},{"\xd1\x82","\xd0\xa2"},{"\xd1\x9c","\xd0\x8c"},{"\xd1\x9b","\xd0\x8b"},{"\xd1\x83","\xd0\xa3"},{"\xd3\xb1","\xd3\xb0"},{"\xd2\xb1","\xd2\xb0"},{"\xd2\xaf","\xd2\xae"},{"\xd1\x9e","\xd0\x8e"},{"\xd1\x84","\xd0\xa4"},{"\xd1\x85","\xd0\xa5"},{"\xd2\xb3","\xd2\xb2"},{"\xd2\xbb","\xd2\xba"},{"\xd1\x86","\xd0\xa6"},{"\xd1\x87","\xd0\xa7"},{"\xd1\x9f","\xd0\x8f"},{"\xd1\x88","\xd0\xa8"},{"\xd1\x89","\xd0\xa9"},{"\xd1\x8a","\xd0\xaa"},{"\xd1\x8b","\xd0\xab"},{"\xd1\x8c","\xd0\xac"},{"\xd1\x8d","\xd0\xad"},{"\xd1\x8e","\xd0\xae"},{"\xd1\x8f","\xd0\xaf"},{"\xd5\xa1","\xd4\xb1"},{"\xd5\xa3","\xd4\xb3"},{"\xd5\xa5","\xd4\xb5"},{"\xd5\xab","\xd4\xbb"},{"\xd5\xac","\xd4\xbc"},{"\xd5\xb2","\xd5\x82"},{"\xd5\xb8","\xd5\x88"},{"\xd5\xbd","\xd5\x8d"},{"\xd5\xbe","\xd5\x8e"},{"\xd5\xbf","\xd5\x8f"},{"\xd6\x80","\xd5\x90"},{"\xd6\x81","\xd5\x91"} };
+ for (auto p8 : map8) {
+ auto from = utf8ToUnicodeString(p8.first);
+ auto to = utf8ToUnicodeString(p8.second);
+ ABORT_IF(from.size() != 1 || to.size() != 1, "Incorrect character encoding??");
+ toUpperMap.insert(std::make_pair(from.front(), to.front()));
+ toLowerMap.insert(std::make_pair(to.front(), from.front()));
+ }
+ }
+ char32_t toUpperOrLower(char32_t c, bool toLower) const { return mapChar(toLower ? toLowerMap : toUpperMap, c); }
+private:
+ static char32_t mapChar(const std::map<char32_t, char32_t>& map, char32_t c) {
+ auto iter = map.find(c);
+ if (iter == map.end())
+ return c;
+ else
+ return iter->second;
+ }
+};
+
+// shared implementation of toUpper, toLower, and toCapitalized
+static std::string utf8ToUpperOrLower(const std::string& s, bool toLower, bool toInitCap) {
+ static UTF8Mapper utf8Mapper;
+ auto ws = utf8ToUnicodeString(s);
+ for (auto& c : ws) {
+ c = utf8Mapper.toUpperOrLower(c, toLower);
+ if (toInitCap)
+ toLower = true;
+ }
+ return utf8FromUnicodeString(ws);
+}
+
+std::string utf8ToUpper(const std::string& s) { return utf8ToUpperOrLower(s, /*toLower=*/false, /*toInitCap=*/false); }
+std::string utf8ToLower(const std::string& s) { return utf8ToUpperOrLower(s, /*toLower=*/true , /*toInitCap=*/false); }
+std::string utf8Capitalized(const std::string& s) { return utf8ToUpperOrLower(s, /*toLower=*/false, /*toInitCap=*/true ); }
+
+// convert an English sentence to title case
+// Since title case is an English thing, we only consider ASCII characters.
+std::string toEnglishTitleCase(const std::string& s) {
+ auto res = s;
+ // process token by token
+ const std::string wordStartChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
+ const std::string wordInternalChars = wordStartChars + "'"; // don't title-case letters after word-internal apostrophe
+ const std::set<std::string> exceptions = { // from moses-scripts/scripts/recaser/detruecase.perl
+ "a","after","against","al-.+","and","any","as","at","be","because","between","by","during","el-.+","for","from","his","in","is","its","last","not","of","off","on","than","the","their","this","to","was","were","which","will","with"
+ };
+ const std::set<char> wordPredChars = {' ', '"', '\'', '-'}; // only capitalize words if following these characters (to avoid upper-casing word-internal SPM units)
+ // These are tokenization heuristics, which may be incomplete.
+ size_t epos = 0;
+ for(size_t pos = epos; pos < res.size(); pos = epos) {
+ // locate the next word
+ pos = res.find_first_of(wordStartChars, pos); // find first letter
+ if (pos == std::string::npos)
+ break;
+ epos = res.find_first_not_of(wordInternalChars, pos + 1); // find first non-letter
+ if (epos == std::string::npos)
+ epos = res.size();
+ auto word = res.substr(pos, epos - pos);
+ // further checks of the word
+ if (res[pos] < 'a' || res[pos] > 'z') // skip if already upper-case
+ continue;
+ if (pos > 0 && wordPredChars.find(res[pos-1]) == wordPredChars.end()) // skip if unexpected char before the word
+ continue;
+ if (exceptions.find(word) != exceptions.end()) // skip if in the exception list
+ continue;
+ // upper-case it
+ res[pos] -= 'a' - 'A';
+ }
+ return res;
+}
+
+std::string findReplace(const std::string& in, const std::string& what, const std::string& withWhat, bool all /*= false*/) {
+ std::string res = in;
+ for(size_t pos = res.find(what); pos != std::string::npos; pos = res.find(what, pos + withWhat.length())) {
+ res.replace(pos, what.length(), withWhat);
+ if (!all)
+ break;
+ }
+ return res;
+}
+
+double parseDouble(std::string s) {
+ double res;
+ char c; // dummy char -- if we succeed to parse this, then there were extraneous characters after the number
+ auto rc = sscanf(s.c_str(), "%lf%c", &res, &c);
+ ABORT_IF(rc != 1, "Mal-formed number: {}", s);
+ return res;
+}
+
+// parses a user-friendly number that can have commas and (some) units
+double parseNumber(std::string param) {
+ // get unit prefix
+ double factor = 1.;
+ if(!param.empty() && param.back() >= 'A') {
+ switch(param.back()) {
+ case 'k': factor = 1.e3; break;
+ case 'M': factor = 1.e6; break;
+ case 'G': factor = 1.e9; break;
+ case 'T': factor = 1.e12; break;
+ default: ABORT("Invalid or unsupported unit prefix '{}' in {}", param.back(), param);
+ }
+ param.pop_back();
+ }
+ // we allow users to place commas in numbers (note: we are not actually verifying that they are in
+ // the right place)
+ std::remove_if(param.begin(), param.end(), [](char c) { return c == ','; });
+ return factor * parseDouble(param);
+}
+
} // namespace utils
} // namespace marian
diff --git a/src/common/utils.h b/src/common/utils.h
index cd4fc6de..c3266bbf 100755..100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -12,31 +12,47 @@ void trimRight(std::string& s);
void split(const std::string& line,
std::vector<std::string>& pieces,
- const std::string del = " ",
- bool keepEmpty = false);
-
-std::vector<std::string> split(const std::string& line,
- const std::string del = " ",
- bool keepEmpty = false);
-
+ const std::string& del = " ",
+ bool keepEmpty = false,
+ bool anyOf = false);
void splitAny(const std::string& line,
std::vector<std::string>& pieces,
- const std::string del = " ",
+ const std::string& del = " ",
bool keepEmpty = false);
+std::vector<std::string> split(const std::string& line,
+ const std::string& del = " ",
+ bool keepEmpty = false,
+ bool anyOf = false);
std::vector<std::string> splitAny(const std::string& line,
- const std::string del = " ",
+ const std::string& del = " ",
bool keepEmpty = false);
-std::string join(const std::vector<std::string>& words,
- const std::string& del = " ");
+std::string join(const std::vector<std::string>& words, const std::string& del = " ");
-std::string exec(const std::string& cmd);
+std::string exec(const std::string& cmd, const std::vector<std::string>& args = {}, const std::string& arg = "");
std::pair<std::string, int> hostnameAndProcessId();
std::string withCommas(size_t n);
+bool beginsWith(const std::string& text, const std::string& prefix);
bool endsWith(const std::string& text, const std::string& suffix);
+std::string utf8ToUpper(const std::string& s);
+std::string utf8ToLower(const std::string& s);
+std::string utf8Capitalized(const std::string& word); // capitalize the first character only
+std::string toEnglishTitleCase(const std::string& s);
+
+std::u32string utf8ToUnicodeString(const std::string& s);
+std::string utf8FromUnicodeString(const std::u32string& s);
+std::u16string utf8ToUtf16String(const std::string& s);
+std::string utf8FromUtf16String(const std::u16string& s);
+bool isContinuousScript(char32_t c);
+
+std::string findReplace(const std::string& in, const std::string& what, const std::string& withWhat, bool all = false);
+
+double parseDouble(std::string s);
+double parseNumber(std::string s);
+
} // namespace utils
} // namespace marian
diff --git a/src/common/version.cpp b/src/common/version.cpp
index 75814d92..75814d92 100755..100644
--- a/src/common/version.cpp
+++ b/src/common/version.cpp
diff --git a/src/common/version.h b/src/common/version.h
index a425af93..a425af93 100755..100644
--- a/src/common/version.h
+++ b/src/common/version.h
diff --git a/src/data/alignment.cpp b/src/data/alignment.cpp
index 7a2a0003..928beb21 100755..100644
--- a/src/data/alignment.cpp
+++ b/src/data/alignment.cpp
@@ -8,9 +8,7 @@ namespace data {
WordAlignment::WordAlignment() {}
-WordAlignment::WordAlignment(
- const std::vector<Point>& align)
- : data_(align) {}
+WordAlignment::WordAlignment(const std::vector<Point>& align) : data_(align) {}
WordAlignment::WordAlignment(const std::string& line) {
std::vector<std::string> atok = utils::splitAny(line, " -");
diff --git a/src/data/alignment.h b/src/data/alignment.h
index 49fbde76..1c68bb39 100755..100644
--- a/src/data/alignment.h
+++ b/src/data/alignment.h
@@ -52,7 +52,9 @@ public:
std::string toString() const;
};
-typedef std::vector<std::vector<float>> SoftAlignment;
+// soft alignment = P(src pos|trg pos) for each beam and batch index, stored in a flattened CPU-side array
+// Also used on QuickSAND boundary where beam and batch size is 1. Then it is simply [t][s] -> P(s|t)
+typedef std::vector<std::vector<float>> SoftAlignment; // [trg pos][beam depth * max src length * batch size]
WordAlignment ConvertSoftAlignToHardAlign(SoftAlignment alignSoft,
float threshold = 1.f);
diff --git a/src/data/batch.h b/src/data/batch.h
index a7832585..3c592b31 100755
--- a/src/data/batch.h
+++ b/src/data/batch.h
@@ -17,16 +17,16 @@ public:
virtual size_t wordsTrg() const { return 0; };
virtual size_t widthTrg() const { return 0; };
- virtual void debug(){};
+ virtual void debug(bool /*printIndices*/ = false) {};
- virtual std::vector<Ptr<Batch>> split(size_t n) = 0;
+ virtual std::vector<Ptr<Batch>> split(size_t n, size_t sizeLimit = SIZE_MAX) = 0;
const std::vector<size_t>& getSentenceIds() const { return sentenceIds_; }
void setSentenceIds(const std::vector<size_t>& ids) { sentenceIds_ = ids; }
virtual void setGuidedAlignment(std::vector<float>&&) = 0;
virtual void setDataWeights(const std::vector<float>&) = 0;
-
+ virtual ~Batch() {};
protected:
std::vector<size_t> sentenceIds_;
};
diff --git a/src/data/batch_generator.h b/src/data/batch_generator.h
index 4578c3bb..f16a7a81 100755..100644
--- a/src/data/batch_generator.h
+++ b/src/data/batch_generator.h
@@ -56,7 +56,7 @@ public:
typedef typename DataSet::batch_ptr BatchPtr;
typedef typename DataSet::Sample Sample;
- typedef std::vector<Sample> Samples; // @TODO: type names should be capitalized
+ typedef std::vector<Sample> Samples;
typedef BatchIterator<BatchGenerator> iterator;
friend iterator;
@@ -65,7 +65,13 @@ protected:
Ptr<DataSet> data_;
Ptr<Options> options_;
bool restored_{false};
- bool shuffle_;
+
+ // replacing old shuffle_ with two variants that determine more fine-grained shuffling behavior.
+ // Both set to false is equivalent to old shuffle_ == false.
+ // Now we can not shuffle the data, but shuffle batches. Useful for linear reading of very large data sets with pre-reading.
+ // Parameters like maxi-batch determine how much data is pre-read and sorted by length or other criteria.
+ bool shuffleData_{false}; // determine if full data should be shuffled before reading and batching.
+ bool shuffleBatches_{false}; // determine if batches should be shuffled after batching.
private:
Ptr<BatchStats> stats_;
@@ -83,7 +89,6 @@ private:
// this runs on a bg thread; sequencing is handled by caller, but locking is done in here
std::deque<BatchPtr> fetchBatches() {
- //LOG(info, "fillBatches entered");
typedef typename Sample::value_type Item;
auto itemCmp = [](const Item& sa, const Item& sb) { return sa.size() < sb.size(); }; // sort by element length, not content
@@ -97,7 +102,7 @@ private:
a.rbegin(), a.rend(), b.rbegin(), b.rend(), itemCmp);
};
- auto cmpNone = [](const Sample& a, const Sample& b) { return &a < &b; }; // instead sort by address, so we have something to work with
+ auto cmpNone = [](const Sample& a, const Sample& b) { return a.getId() < b.getId(); }; // sort in order of original ids = original data order unless shuffling
typedef std::function<bool(const Sample&, const Sample&)> cmp_type;
typedef std::priority_queue<Sample, Samples, cmp_type> sample_queue;
@@ -118,8 +123,6 @@ private:
size_t maxBatchSize = options_->get<int>("mini-batch");
size_t maxSize = maxBatchSize * options_->get<int>("maxi-batch");
- // LOG(info, "Preloading batches");
-
// consume data from corpus into maxi-batch (single sentences)
// sorted into specified order (due to queue)
if(newlyPrepared_) {
@@ -133,16 +136,14 @@ private:
while(current_ != data_->end() && maxiBatch->size() < maxSize) { // loop over data
maxiBatch->push(*current_);
sets = current_->size();
- // do not consume more than required for the maxi batch as this causes
- // that line-by-line translation is delayed by one sentence
- bool last = maxiBatch->size() == maxSize;
+ // do not consume more than required for the maxi batch as this causes
+ // that line-by-line translation is delayed by one sentence
+ bool last = maxiBatch->size() == maxSize;
if(!last)
++current_; // this actually reads the next line and pre-processes it
}
size_t numSentencesRead = maxiBatch->size();
- // LOG(info, "Turning samples into batches");
-
// construct the actual batches and place them in the queue
Samples batchVector;
size_t currentWords = 0;
@@ -152,7 +153,6 @@ private:
// process all loaded sentences in order of increasing length
// @TODO: we could just use a vector and do a sort() here; would make the cost more explicit
- //LOG(info, "begin form batches, #lines = {}", maxiBatch->size());
const size_t mbWords = options_->get<size_t>("mini-batch-words", 0);
const bool useDynamicBatching = options_->has("mini-batch-fit");
BatchStats::const_iterator cachedStatsIter;
@@ -205,21 +205,31 @@ private:
}
// turn rest into batch
+ // @BUGBUG: This can create a very small batch, which with ce-mean-words can artificially
+ // inflate the contribution of the sames in the batch, causing instability.
+ // I think a good alternative would be to carry over the left-over sentences into the next round.
if(!batchVector.empty())
tempBatches.push_back(data_->toBatch(batchVector));
- //LOG(info, "end form batches, #tempBatches = {}", tempBatches.size());
// Shuffle the batches
- if(shuffle_) {
+ if(shuffleBatches_) {
std::shuffle(tempBatches.begin(), tempBatches.end(), eng_);
}
- LOG(debug, "[data] fetched {} batches with {} sentences.", tempBatches.size(), numSentencesRead);
+ double totalSent{}, totalLabels{};
+ for (auto& b : tempBatches) {
+ totalSent += (double)b->size();
+ totalLabels += (double)b->words(-1);
+ }
+ auto totalDenom = tempBatches.empty() ? 1 : tempBatches.size(); // (make 0/0 = 0)
+ LOG(debug, "[data] fetched {} batches with {} sentences. Per batch: {} sentences, {} labels.",
+ tempBatches.size(), numSentencesRead,
+ (double)totalSent / (double)totalDenom, (double)totalLabels / (double)totalDenom);
return tempBatches;
}
// this starts fillBatches() as a background operation
void fetchBatchesAsync() {
- ABORT_IF(futureBufferedBatches_.valid(), "attempted to restart futureBufferedBatches_ while still running");
+ ABORT_IF(futureBufferedBatches_.valid(), "Attempted to restart futureBufferedBatches_ while still running");
futureBufferedBatches_ = threadPool_.enqueue([this]() {
return fetchBatches();
});
@@ -229,7 +239,9 @@ private:
if(bufferedBatches_.empty()) {
// out of data: need to get next batch from background thread
// We only get here if the future has been scheduled to run; it must be valid.
- ABORT_IF(!futureBufferedBatches_.valid(), "attempted to wait for futureBufferedBatches_ when none pending");
+ ABORT_IF(!futureBufferedBatches_.valid(), "Attempted to wait for futureBufferedBatches_ when none pending.\n"
+ "This error often occurs when Marian tries to restore the training data iterator, but the corpus has been changed or replaced.\n"
+ "If you have changed the training corpus, add --no-restore-corpus to the training command and run it again.");
bufferedBatches_ = std::move(futureBufferedBatches_.get());
// if bg thread returns an empty swath, we hit the end of the epoch
if (bufferedBatches_.empty()) {
@@ -248,7 +260,11 @@ public:
BatchGenerator(Ptr<DataSet> data,
Ptr<Options> options,
Ptr<BatchStats> stats = nullptr)
- : data_(data), options_(options), stats_(stats), threadPool_(1) {}
+ : data_(data), options_(options), stats_(stats), threadPool_(1) {
+ auto shuffle = options_->get<std::string>("shuffle");
+ shuffleData_ = shuffle == "data";
+ shuffleBatches_ = shuffleData_ || shuffle == "batches";
+ }
~BatchGenerator() {
if (futureBufferedBatches_.valid()) // bg thread holds a reference to 'this',
@@ -264,23 +280,20 @@ public:
}
// @TODO: get rid of this function, begin() or constructor should figure this out
- void prepare(bool shuffle = true) {
- if(shuffle)
+ void prepare() {
+ if(shuffleData_)
data_->shuffle();
else
data_->reset();
newlyPrepared_ = true;
- // @TODO: solve this better, maybe use options
- shuffle_ = shuffle;
-
// start the background pre-fetch operation
fetchBatchesAsync();
}
// Used to restore the state of a BatchGenerator after
// an interrupted and resumed training.
- bool restore(Ptr<TrainingState> state, bool shuffle) {
+ bool restore(Ptr<TrainingState> state) {
if(state->epochs == 1 && state->batchesEpoch == 0)
return false;
@@ -294,12 +307,24 @@ public:
setRNGState(state->seedBatch);
}
- prepare(shuffle);
+ prepare();
for(size_t i = 0; i < state->batchesEpoch; ++i)
next();
return true;
}
+
+ // this is needed for dynamic MB scaling. Returns 0 if size is not known in words.
+ size_t estimateTypicalTrgBatchWords() const {
+ const size_t mbWords = options_->get<size_t>("mini-batch-words", 0);
+ const bool useDynamicBatching = options_->has("mini-batch-fit");
+ if (useDynamicBatching && stats_)
+ return stats_->estimateTypicalTrgWords();
+ else if (mbWords)
+ return mbWords;
+ else
+ return 0;
+ }
};
class CorpusBatchGenerator : public BatchGenerator<CorpusBase>,
diff --git a/src/data/batch_stats.h b/src/data/batch_stats.h
index 50791ed0..bba055d5 100755..100644
--- a/src/data/batch_stats.h
+++ b/src/data/batch_stats.h
@@ -39,16 +39,28 @@ public:
return it->second;
}
- void add(Ptr<data::CorpusBatch> batch, size_t multiplier = 1) {
+ void add(Ptr<data::CorpusBatch> batch, double multiplier = 1.) {
std::vector<size_t> lengths;
for(size_t i = 0; i < batch->sets(); ++i)
lengths.push_back((*batch)[i]->batchWidth());
- size_t batchSize = batch->size() * multiplier;
+ size_t batchSize = (size_t)ceil((double)batch->size() * multiplier);
if(map_[lengths] < batchSize)
map_[lengths] = batchSize;
}
+ // return a rough minibatch size in labels
+ // We average over all (batch sizes * max trg length).
+ size_t estimateTypicalTrgWords() const {
+ size_t sum = 0;
+ for (const auto& entry : map_) {
+ auto maxTrgLength = entry.first.back();
+ auto numSentences = entry.second;
+ sum += numSentences * maxTrgLength;
+ }
+ return sum / map_.size();
+ }
+
// helpers for multi-node --note: presently unused, but keeping them around for later use
// serialize into a flat vector, for MPI data exchange
std::vector<size_t> flatten() const {
diff --git a/src/data/corpus.cpp b/src/data/corpus.cpp
index 7a7a846e..a861e663 100755
--- a/src/data/corpus.cpp
+++ b/src/data/corpus.cpp
@@ -4,21 +4,49 @@
#include <random>
#include "common/utils.h"
+#include "common/filesystem.h"
+
#include "data/corpus.h"
namespace marian {
namespace data {
Corpus::Corpus(Ptr<Options> options, bool translate /*= false*/)
- : CorpusBase(options, translate), shuffleInRAM_(options_->get<bool>("shuffle-in-ram")) {}
+ : CorpusBase(options, translate),
+ shuffleInRAM_(options_->get<bool>("shuffle-in-ram", false)),
+ allCapsEvery_(options_->get<size_t>("all-caps-every", 0)),
+ titleCaseEvery_(options_->get<size_t>("english-title-case-every", 0)) {}
Corpus::Corpus(std::vector<std::string> paths,
std::vector<Ptr<Vocab>> vocabs,
Ptr<Options> options)
- : CorpusBase(paths, vocabs, options), shuffleInRAM_(options_->get<bool>("shuffle-in-ram")) {}
+ : CorpusBase(paths, vocabs, options),
+ shuffleInRAM_(options_->get<bool>("shuffle-in-ram", false)),
+ allCapsEvery_(options_->get<size_t>("all-caps-every", 0)),
+ titleCaseEvery_(options_->get<size_t>("english-title-case-every", 0)) {}
+
+void Corpus::preprocessLine(std::string& line, size_t streamId) {
+ if (allCapsEvery_ != 0 && pos_ % allCapsEvery_ == 0 && !inference_) {
+ line = vocabs_[streamId]->toUpper(line);
+ if (streamId == 0)
+ LOG_ONCE(info, "[data] Source all-caps'ed line to: {}", line);
+ else
+ LOG_ONCE(info, "[data] Target all-caps'ed line to: {}", line);
+ }
+ else if (titleCaseEvery_ != 0 && pos_ % titleCaseEvery_ == 1 && !inference_ && streamId == 0) {
+ // Only applied to stream 0 (source) since this feature is aimed at robustness against
+ // title case in the source (and not at translating into title case).
+ // Note: It is user's responsibility to not enable this if the source language is not English.
+ line = vocabs_[streamId]->toEnglishTitleCase(line);
+ if (streamId == 0)
+ LOG_ONCE(info, "[data] Source English-title-case'd line to: {}", line);
+ else
+ LOG_ONCE(info, "[data] Target English-title-case'd line to: {}", line);
+ }
+}
SentenceTuple Corpus::next() {
- for (;;) { // (this is a retry loop for skipping invalid sentences)
+ for(;;) { // (this is a retry loop for skipping invalid sentences)
// get index of the current sentence
size_t curId = pos_; // note: at end, pos_ == total size
// if corpus has been shuffled, ids_ contains sentence indexes
@@ -43,7 +71,7 @@ SentenceTuple Corpus::next() {
}
}
else {
- bool gotLine = io::getline(*files_[i], line);
+ bool gotLine = io::getline(*files_[i], line).good();
if(!gotLine) {
eofsHit++;
continue;
@@ -55,6 +83,7 @@ SentenceTuple Corpus::next() {
} else if(i > 0 && i == weightFileIdx_) {
addWeightsToSentenceTuple(line, tup);
} else {
+ preprocessLine(line, i);
addWordsToSentenceTuple(line, i, tup);
}
}
@@ -85,16 +114,26 @@ void Corpus::shuffle() {
// Call either reset() or shuffle().
// @TODO: make shuffle() private, instad pass a shuffle() flag to reset(), to clarify mutual exclusiveness with shuffle()
void Corpus::reset() {
- files_.clear();
corpusInRAM_.clear();
ids_.clear();
+ if (pos_ == 0) // no data read yet
+ return;
pos_ = 0;
- for(auto& path : paths_) {
- if(path == "stdin")
- files_.emplace_back(new io::InputFileStream(std::cin));
- else
- files_.emplace_back(new io::InputFileStream(path));
- }
+ for (size_t i = 0; i < paths_.size(); ++i) {
+ if(paths_[i] == "stdin") {
+ files_[i].reset(new std::istream(std::cin.rdbuf()));
+ // Probably not necessary, unless there are some buffers
+ // that we want flushed.
+ }
+ else {
+ ABORT_IF(files_[i] && filesystem::is_fifo(paths_[i]),
+ "File '", paths_[i], "' is a pipe and cannot be re-opened.");
+ // Do NOT reset named pipes; that closes them and triggers a SIGPIPE
+ // (lost pipe) at the writing end, which may do whatever it wants
+ // in this situation.
+ files_[i].reset(new io::InputFileStream(paths_[i]));
+ }
+ }
}
void Corpus::restore(Ptr<TrainingState> ts) {
@@ -102,7 +141,7 @@ void Corpus::restore(Ptr<TrainingState> ts) {
}
void Corpus::shuffleData(const std::vector<std::string>& paths) {
- LOG(info, "[data] Shuffling files");
+ LOG(info, "[data] Shuffling data");
size_t numStreams = paths.size();
@@ -115,8 +154,9 @@ void Corpus::shuffleData(const std::vector<std::string>& paths) {
else {
files_.resize(numStreams);
for(size_t i = 0; i < numStreams; ++i) {
- files_[i].reset(new io::InputFileStream(paths[i]));
- files_[i]->setbufsize(10000000); // huge read-ahead buffer to avoid network round-trips
+ UPtr<io::InputFileStream> strm(new io::InputFileStream(paths[i]));
+ strm->setbufsize(10000000); // huge read-ahead buffer to avoid network round-trips
+ files_[i] = std::move(strm);
}
// read entire corpus into RAM
@@ -124,7 +164,7 @@ void Corpus::shuffleData(const std::vector<std::string>& paths) {
for (;;) {
size_t eofsHit = 0;
for(size_t i = 0; i < numStreams; ++i) {
- bool gotLine = io::getline(*files_[i], lineBuf);
+ bool gotLine = io::getline(*files_[i], lineBuf).good();
if (gotLine)
corpus[i].push_back(lineBuf);
else
@@ -154,7 +194,7 @@ void Corpus::shuffleData(const std::vector<std::string>& paths) {
tempFiles_.resize(numStreams);
for(size_t i = 0; i < numStreams; ++i) {
tempFiles_[i].reset(new io::TemporaryFile(options_->get<std::string>("tempdir")));
- io::OutputFileStream out(*tempFiles_[i]);
+ io::TemporaryFile &out = *tempFiles_[i];
const auto& corpusStream = corpus[i];
for(auto id : ids_) {
out << corpusStream[id] << std::endl;
@@ -164,12 +204,61 @@ void Corpus::shuffleData(const std::vector<std::string>& paths) {
// replace files_[] by the tempfiles we just created
files_.resize(numStreams);
for(size_t i = 0; i < numStreams; ++i) {
- files_[i].reset(new io::InputFileStream(*tempFiles_[i]));
- files_[i]->setbufsize(10000000);
+ auto inputStream = tempFiles_[i]->getInputStream();
+ inputStream->setbufsize(10000000);
+ files_[i] = std::move(inputStream);
}
LOG(info, "[data] Done shuffling {} sentences to temp files", numSentences);
}
pos_ = 0;
}
+
+CorpusBase::batch_ptr Corpus::toBatch(const std::vector<Sample>& batchVector) {
+ size_t batchSize = batchVector.size();
+
+ std::vector<size_t> sentenceIds;
+
+ std::vector<int> maxDims; // @TODO: What's this? widths? maxLengths?
+ for(auto& ex : batchVector) { // @TODO: rename 'ex' to 'sample' or 'sentenceTuple'
+ if(maxDims.size() < ex.size())
+ maxDims.resize(ex.size(), 0);
+ for(size_t i = 0; i < ex.size(); ++i) {
+ if(ex[i].size() > (size_t)maxDims[i])
+ maxDims[i] = (int)ex[i].size();
+ }
+ sentenceIds.push_back(ex.getId());
+ }
+
+ std::vector<Ptr<SubBatch>> subBatches;
+ for(size_t j = 0; j < maxDims.size(); ++j) {
+ subBatches.emplace_back(New<SubBatch>(batchSize, maxDims[j], vocabs_[j]));
+ }
+
+ std::vector<size_t> words(maxDims.size(), 0);
+ for(size_t b = 0; b < batchSize; ++b) { // loop over batch entries
+ for(size_t j = 0; j < maxDims.size(); ++j) { // loop over streams
+ auto subBatch = subBatches[j];
+ for(size_t s = 0; s < batchVector[b][j].size(); ++s) { // loop over word positions
+ subBatch->data()[subBatch->locate(/*batchIdx=*/b, /*wordPos=*/s)/*s * batchSize + b*/] = batchVector[b][j][s];
+ subBatch->mask()[subBatch->locate(/*batchIdx=*/b, /*wordPos=*/s)/*s * batchSize + b*/] = 1.f;
+ words[j]++;
+ }
+ }
+ }
+
+ for(size_t j = 0; j < maxDims.size(); ++j)
+ subBatches[j]->setWords(words[j]);
+
+ auto batch = batch_ptr(new batch_type(subBatches));
+ batch->setSentenceIds(sentenceIds);
+
+ if(options_->get("guided-alignment", std::string("none")) != "none" && alignFileIdx_)
+ addAlignmentsToBatch(batch, batchVector);
+ if(options_->hasAndNotEmpty("data-weighting") && weightFileIdx_)
+ addWeightsToBatch(batch, batchVector);
+
+ return batch;
+}
+
} // namespace data
} // namespace marian
diff --git a/src/data/corpus.h b/src/data/corpus.h
index 119f3aab..4ac43724 100755..100644
--- a/src/data/corpus.h
+++ b/src/data/corpus.h
@@ -27,6 +27,11 @@ private:
void shuffleData(const std::vector<std::string>& paths);
+ // for pre-processing
+ size_t allCapsEvery_{0}; // if set, convert every N-th input sentence (after randomization) to all-caps (source and target)
+ size_t titleCaseEvery_{0}; // ditto for title case (source only)
+ void preprocessLine(std::string& line, size_t streamId);
+
public:
// @TODO: check if translate can be replaced by an option in options
Corpus(Ptr<Options> options, bool translate = false);
@@ -58,51 +63,7 @@ public:
std::vector<Ptr<Vocab>>& getVocabs() override { return vocabs_; }
- batch_ptr toBatch(const std::vector<Sample>& batchVector) override {
- size_t batchSize = batchVector.size();
-
- std::vector<size_t> sentenceIds;
-
- std::vector<int> maxDims;
- for(auto& ex : batchVector) {
- if(maxDims.size() < ex.size())
- maxDims.resize(ex.size(), 0);
- for(size_t i = 0; i < ex.size(); ++i) {
- if(ex[i].size() > (size_t)maxDims[i])
- maxDims[i] = (int)ex[i].size();
- }
- sentenceIds.push_back(ex.getId());
- }
-
- std::vector<Ptr<SubBatch>> subBatches;
- for(size_t j = 0; j < maxDims.size(); ++j) {
- subBatches.emplace_back(New<SubBatch>(batchSize, maxDims[j], vocabs_[j]));
- }
-
- std::vector<size_t> words(maxDims.size(), 0);
- for(size_t i = 0; i < batchSize; ++i) {
- for(size_t j = 0; j < maxDims.size(); ++j) {
- for(size_t k = 0; k < batchVector[i][j].size(); ++k) {
- subBatches[j]->data()[k * batchSize + i] = batchVector[i][j][k];
- subBatches[j]->mask()[k * batchSize + i] = 1.f;
- words[j]++;
- }
- }
- }
-
- for(size_t j = 0; j < maxDims.size(); ++j)
- subBatches[j]->setWords(words[j]);
-
- auto batch = batch_ptr(new batch_type(subBatches));
- batch->setSentenceIds(sentenceIds);
-
- if(options_->get("guided-alignment", std::string("none")) != "none" && alignFileIdx_)
- addAlignmentsToBatch(batch, batchVector);
- if(options_->has("data-weighting") && weightFileIdx_)
- addWeightsToBatch(batch, batchVector);
-
- return batch;
- }
+ batch_ptr toBatch(const std::vector<Sample>& batchVector) override;
};
} // namespace data
} // namespace marian
diff --git a/src/data/corpus_base.cpp b/src/data/corpus_base.cpp
index c9df1c42..bb5fe735 100755
--- a/src/data/corpus_base.cpp
+++ b/src/data/corpus_base.cpp
@@ -1,6 +1,7 @@
#include <random>
#include "data/corpus.h"
+#include "data/factored_vocab.h"
namespace marian {
namespace data {
@@ -40,9 +41,12 @@ CorpusBase::CorpusBase(const std::vector<std::string>& paths,
"Number of corpus files and vocab files does not agree");
for(auto path : paths_) {
- files_.emplace_back(new io::InputFileStream(path));
- ABORT_IF(files_.back()->empty(), "File '{}' is empty", path);
+ UPtr<io::InputFileStream> strm(new io::InputFileStream(path));
+ ABORT_IF(strm->empty(), "File '{}' is empty", path);
+ files_.emplace_back(std::move(strm));
}
+
+ initEOS(/*training=*/true);
}
CorpusBase::CorpusBase(Ptr<Options> options, bool translate)
@@ -57,6 +61,8 @@ CorpusBase::CorpusBase(Ptr<Options> options, bool translate)
else
paths_ = options_->get<std::vector<std::string>>("input");
+ initEOS(training);
+
std::vector<std::string> vocabPaths;
if(!options_->get<std::vector<std::string>>("vocabs").empty())
vocabPaths = options_->get<std::vector<std::string>>("vocabs");
@@ -78,22 +84,24 @@ CorpusBase::CorpusBase(Ptr<Options> options, bool translate)
LOG(info, "No vocabulary files given, trying to find or build based on training data. "
"Vocabularies will be built separately for each file.");
+ std::vector<int> vocabDims(paths_.size(), 0);
+ std::vector<std::string> vocabPaths1(paths_.size());
// Create vocabs if not provided
for(size_t i = 0; i < paths_.size(); ++i) {
Ptr<Vocab> vocab = New<Vocab>(options_, i);
std::vector<std::string> trainPaths = { paths_[i] };
- size_t vocSize = vocab->loadOrCreate("", trainPaths, maxVocabs[i]);
- // TODO: this is not nice as it modifies the option object and needs to expose the changes
- // outside the corpus as models need to know about the vocabulary size; extract the vocab
- // creation functionality from the class.
- options_->getYaml()["dim-vocabs"][i] = vocSize;
-
- options_->getYaml()["vocabs"].push_back(paths_[i] + ".yml");
+ vocabDims[i] = (int) vocab->loadOrCreate("", trainPaths, maxVocabs[i]);
+ vocabPaths1[i] = paths_[i] + ".yml";
vocabs_.emplace_back(vocab);
}
+ // TODO: this is not nice as it modifies the option object and needs to expose the changes
+ // outside the corpus as models need to know about the vocabulary size; extract the vocab
+ // creation functionality from the class.
+ options_->set("dim-vocabs", vocabDims, "vocabs", vocabPaths1);
} else {
// Load all vocabs
- if(maxVocabs.size() < vocabPaths.size())
+ size_t numVocs = vocabPaths.size();
+ if(maxVocabs.size() < numVocs)
maxVocabs.resize(paths_.size(), 0);
// Helper object to for grouping training data based on vocabulary file name
@@ -101,33 +109,33 @@ CorpusBase::CorpusBase(Ptr<Options> options, bool translate)
std::set<std::string> paths; // contains all paths that are used for training the vocabulary
size_t size; // contains the maximum vocabulary size
};
-
+
// Group training files based on vocabulary path. If the same
// vocab path corresponds to different training files, this means
// that a single vocab should combine tokens from all files.
std::map<std::string, PathsAndSize> groupVocab;
- for(size_t i = 0; i < vocabPaths.size(); ++i) {
+ for(size_t i = 0; i < numVocs; ++i) {
groupVocab[vocabPaths[i]].paths.insert(paths_[i]);
if(groupVocab[vocabPaths[i]].size < maxVocabs[i])
groupVocab[vocabPaths[i]].size = maxVocabs[i];
}
- for(size_t i = 0; i < vocabPaths.size(); ++i) {
+ auto vocabDims = options_->get<std::vector<int>>("dim-vocabs");
+ vocabDims.resize(numVocs, 0);
+ for(size_t i = 0; i < numVocs; ++i) {
Ptr<Vocab> vocab = New<Vocab>(options_, i);
// Get the set of files that corresponds to the vocab. If the next file is the same vocab,
// it wild not be created again, but just correctly loaded.
auto pathsAndSize = groupVocab[vocabPaths[i]];
std::vector<std::string> groupedPaths(pathsAndSize.paths.begin(), pathsAndSize.paths.end());
- size_t vocSize = vocab->loadOrCreate(vocabPaths[i], groupedPaths, pathsAndSize.size);
-
- // TODO: this is not nice as it modifies the option object and needs to expose the changes
- // outside the corpus as models need to know about the vocabulary size; extract the vocab
- // creation functionality from the class.
- options_->getYaml()["dim-vocabs"][i] = vocSize;
-
+ vocabDims[i] = (int) vocab->loadOrCreate(vocabPaths[i], groupedPaths, pathsAndSize.size);
vocabs_.emplace_back(vocab);
}
+ // TODO: this is not nice as it modifies the option object and needs to expose the changes
+ // outside the corpus as models need to know about the vocabulary size; extract the vocab
+ // creation functionality from the class.
+ options_->set("dim-vocabs", vocabDims);
}
}
@@ -135,24 +143,30 @@ CorpusBase::CorpusBase(Ptr<Options> options, bool translate)
ABORT_IF(vocabPaths.empty(),
"Translating, but vocabularies are not given!");
- if(maxVocabs.size() < vocabPaths.size())
+ size_t numVocs = vocabPaths.size();
+ if(maxVocabs.size() < numVocs)
maxVocabs.resize(paths_.size(), 0);
- for(size_t i = 0; i + 1 < vocabPaths.size(); ++i) {
+ auto vocabDims = options_->get<std::vector<int>>("dim-vocabs");
+ vocabDims.resize(numVocs, 0);
+ for(size_t i = 0; i + 1 < numVocs; ++i) {
Ptr<Vocab> vocab = New<Vocab>(options_, i);
- size_t vocSize = vocab->load(vocabPaths[i], maxVocabs[i]);
- options_->getYaml()["dim-vocabs"][i] = vocSize;
-
+ vocabDims[i] = (int) vocab->load(vocabPaths[i], maxVocabs[i]);
vocabs_.emplace_back(vocab);
}
+ // TODO: As above, this is not nice as it modifies the option object and needs to expose the changes
+ // outside the corpus as models need to know about the vocabulary size; extract the vocab
+ // creation functionality from the class.
+ options_->set("dim-vocabs", vocabDims);
}
for(auto path : paths_) {
if(path == "stdin")
- files_.emplace_back(new io::InputFileStream(std::cin));
+ files_.emplace_back(new std::istream(std::cin.rdbuf()));
else {
- files_.emplace_back(new io::InputFileStream(path));
- ABORT_IF(files_.back()->empty(), "File '{}' is empty", path);
+ io::InputFileStream *strm = new io::InputFileStream(path);
+ ABORT_IF(strm->empty(), "File '{}' is empty", path);
+ files_.emplace_back(strm);
}
}
@@ -170,11 +184,12 @@ CorpusBase::CorpusBase(Ptr<Options> options, bool translate)
alignFileIdx_ = paths_.size();
paths_.emplace_back(path);
- files_.emplace_back(new io::InputFileStream(path));
- ABORT_IF(files_.back()->empty(), "File with alignments '{}' is empty", path);
+ io::InputFileStream* strm = new io::InputFileStream(path);
+ ABORT_IF(strm->empty(), "File with alignments '{}' is empty", path);
+ files_.emplace_back(strm);
}
- if(training && options_->has("data-weighting")) {
+ if(training && options_->hasAndNotEmpty("data-weighting")) {
auto path = options_->get<std::string>("data-weighting");
ABORT_IF(!filesystem::exists(path), "Weight file does not exist");
@@ -182,26 +197,27 @@ CorpusBase::CorpusBase(Ptr<Options> options, bool translate)
weightFileIdx_ = paths_.size();
paths_.emplace_back(path);
- files_.emplace_back(new io::InputFileStream(path));
- ABORT_IF(files_.back()->empty(), "File with weights '{}' is empty", path);
+ io::InputFileStream* strm = new io::InputFileStream(path);
+ ABORT_IF(strm->empty(), "File with weights '{}' is empty", path);
+ files_.emplace_back(strm);
}
}
void CorpusBase::addWordsToSentenceTuple(const std::string& line,
- size_t i,
+ size_t batchIndex,
SentenceTuple& tup) const {
// This turns a string in to a sequence of numerical word ids. Depending
// on the vocabulary type, this can be non-trivial, e.g. when SentencePiece
// is used.
- Words words = vocabs_[i]->encode(line, /*addEOS =*/ true, inference_);
+ Words words = vocabs_[batchIndex]->encode(line, /*addEOS =*/ addEOS_[batchIndex], inference_);
- if(words.empty())
- words.push_back(0);
+ ABORT_IF(words.empty(), "Empty input sequences are presently untested");
if(maxLengthCrop_ && words.size() > maxLength_) {
words.resize(maxLength_);
- words.back() = 0;
+ if(addEOS_[batchIndex])
+ words.back() = vocabs_[batchIndex]->getEosId();
}
if(rightLeft_)
@@ -225,10 +241,10 @@ void CorpusBase::addWeightsToSentenceTuple(const std::string& line, SentenceTupl
if(!elements.empty()) {
std::vector<float> weights;
- for(auto& e : elements) {
- if(maxLengthCrop_ && weights.size() > maxLength_)
+ for(auto& e : elements) { // Iterate weights as strings
+ if(maxLengthCrop_ && weights.size() >= maxLength_) // Cut if the input is going to be cut
break;
- weights.emplace_back(std::stof(e));
+ weights.emplace_back(std::stof(e)); // Add a weight converted into float
}
if(rightLeft_)
@@ -279,5 +295,90 @@ void CorpusBase::addWeightsToBatch(Ptr<CorpusBatch> batch,
batch->setDataWeights(weights);
}
+
+void CorpusBase::initEOS(bool training = true) {
+ // Labels fed into sub-batches that are just class-labels, not sequence labels do not require to
+ // add a EOS symbol. Hence decision to add EOS is now based on input stream positions and correspoding
+ // input type.
+
+ addEOS_.resize(paths_.size(), true);
+ // @TODO: think if this should be checked and processed here or in a validation step in config?
+ auto inputTypes = options_->get<std::vector<std::string>>("input-types", {}); // empty list by default
+
+ // make sure there is an input type for each path
+ ABORT_IF(inputTypes.size() > 0 && inputTypes.size() < paths_.size(),
+ "Input types have been specified ({}), you need to specify one per input ({})",
+ inputTypes.size(),
+ paths_.size());
+
+ // make sure there is an equal number of input types and paths when training
+ ABORT_IF(training && inputTypes.size() > 0 && inputTypes.size() != paths_.size(),
+ "Input types have been specified ({}), you need to specify one per input ({})",
+ inputTypes.size(),
+ paths_.size());
+
+ for(int i = 0; i < paths_.size(); ++i)
+ if(inputTypes.size() > i) {
+ if(inputTypes[i] == "class")
+ addEOS_[i] = false;
+ else if(inputTypes[i] == "sequence")
+ addEOS_[i] = true;
+ else
+ ABORT("Unknown input type {}: {}", i, inputTypes[i]);
+ } else {
+ // No input type specified, assuming "sequence"
+ addEOS_[i] = true;
+ }
+}
+
+// experimental: hide inline-fix source tokens from cross attention
+std::vector<float> SubBatch::crossMaskWithInlineFixSourceSuppressed() const
+{
+ const auto& srcVocab = *vocab();
+
+ auto factoredVocab = vocab()->tryAs<FactoredVocab>();
+ size_t inlineFixGroupIndex = 0, inlineFixSrc = 0;
+ auto hasInlineFixFactors = factoredVocab && factoredVocab->tryGetFactor(FactoredVocab_INLINE_FIX_WHAT_serialized, /*out*/ inlineFixGroupIndex, /*out*/ inlineFixSrc);
+
+ auto fixSrcId = srcVocab[FactoredVocab_FIX_SRC_ID_TAG];
+ auto fixTgtId = srcVocab[FactoredVocab_FIX_TGT_ID_TAG];
+ auto fixEndId = srcVocab[FactoredVocab_FIX_END_ID_TAG];
+ auto unkId = srcVocab.getUnkId();
+ auto hasInlineFixTags = fixSrcId != unkId && fixTgtId != unkId && fixEndId != unkId;
+
+ auto m = mask(); // default return value, which we will modify in-place below in case we need to
+ if (hasInlineFixFactors || hasInlineFixTags) {
+ LOG_ONCE(info, "[data] Suppressing cross-attention into inline-fix source tokens");
+
+ // example: force French translation of name "frank" to always be "franck"
+ // - hasInlineFixFactors: "frank|is franck|it", "frank|is" cannot be cross-attended to
+ // - hasInlineFixTags: "<IOPEN> frank <IDELIM> franck <ICLOSE>", "frank" and all tags cannot be cross-attended to
+ auto dimBatch = batchSize(); // number of sentences in the batch
+ auto dimWidth = batchWidth(); // number of words in the longest sentence in the batch
+ const auto& d = data();
+ size_t numWords = 0;
+ for (size_t b = 0; b < dimBatch; b++) { // loop over batch entries
+ bool inside = false;
+ for (size_t s = 0; s < dimWidth; s++) { // loop over source positions
+ auto i = locate(/*batchIdx=*/b, /*wordPos=*/s);
+ if (!m[i])
+ break;
+ numWords++;
+ // keep track of entering/exiting the inline-fix source tags
+ auto w = d[i];
+ if (w == fixSrcId)
+ inside = true;
+ else if (w == fixTgtId)
+ inside = false;
+ bool wHasSrcIdFactor = hasInlineFixFactors && factoredVocab->getFactor(w, inlineFixGroupIndex) == inlineFixSrc;
+ if (inside || w == fixSrcId || w == fixTgtId || w == fixEndId || wHasSrcIdFactor)
+ m[i] = 0.0f; // decoder must not look at embedded source, nor the markup tokens
+ }
+ }
+ ABORT_IF(batchWords() != 0/*n/a*/ && numWords != batchWords(), "batchWords() inconsistency??");
+ }
+ return m;
+}
+
} // namespace data
} // namespace marian
diff --git a/src/data/corpus_base.h b/src/data/corpus_base.h
index 8ecdf233..e85e378a 100755
--- a/src/data/corpus_base.h
+++ b/src/data/corpus_base.h
@@ -47,7 +47,7 @@ public:
/**
* @brief Adds a new sentence at the end of the tuple.
*
- * @param words A vector of word indexes.
+ * @param words A vector of word indices.
*/
void push_back(const Words& words) { tuple_.push_back(words); }
@@ -110,14 +110,14 @@ public:
*/
class SubBatch {
private:
- std::vector<Word> indices_;
+ Words indices_;
std::vector<float> mask_;
size_t size_;
size_t width_;
size_t words_;
- Ptr<Vocab> vocab_;
+ Ptr<const Vocab> vocab_;
// ... TODO: add the length information (remember it)
public:
@@ -127,8 +127,8 @@ public:
* @param size Number of sentences
* @param width Number of words in the longest sentence
*/
- SubBatch(size_t size, size_t width, const Ptr<Vocab>& vocab)
- : indices_(size * width, 0),
+ SubBatch(size_t size, size_t width, const Ptr<const Vocab>& vocab)
+ : indices_(size * width, vocab ? vocab->getEosId() : Word::ZERO), // note: for gaps, we must use a valid index
mask_(size * width, 0),
size_(size),
width_(width),
@@ -142,72 +142,80 @@ public:
* idx_{w,0},idx_{w,1},\dots,idx_{w,s}\f$, where \f$w\f$ is the number of
* words (width) and \f$s\f$ is the number of sentences (size).
*/
- std::vector<Word>& data() { return indices_; }
+ Words& data() { return indices_; }
+ const Words& data() const { return indices_; }
+ /**
+ * @brief compute flat index into data() and mask() vectors for given batch index and word index in sentence
+ */
+ size_t locate(size_t batchIdx, size_t wordPos) const { return locate(batchIdx, wordPos, size_); }
+ static size_t locate(size_t batchIdx, size_t wordPos, size_t batchSize) { return wordPos * batchSize + batchIdx; }
/**
* @brief Flat masking vector; 0 is used for masked words.
*
* @see data()
*/
std::vector<float>& mask() { return mask_; }
+ const std::vector<float>& mask() const { return mask_; }
/**
* @brief Accessors to the vocab_ field.
*/
- const Ptr<Vocab>& vocab() const { return vocab_; }
+ const Ptr<const Vocab>& vocab() const { return vocab_; }
/**
* @brief The number of sentences in the batch.
*/
- size_t batchSize() { return size_; }
+ size_t batchSize() const { return size_; }
/**
* @brief The number of words in the longest sentence in the batch.
*/
- size_t batchWidth() { return width_; };
+ size_t batchWidth() const { return width_; };
/**
- * @brief The total number of words in the batch, considering the mask.
+ * @brief The total number of words in the batch (not counting masked-out words).
*/
- size_t batchWords() { return words_; }
+ size_t batchWords() const { return words_; }
/**
- * @brief Splits the subbatch into subbatches of equal size.
+ * @brief Splits the stream into sub-batches of equal size (except for last).
+ *
+ * @param n number of sub-batches to split into
*
- * @param n Number of splits
+ * @param sizeLimit Pretend the batch only has this many sentences. Used for MB-size ramp-up.
*
- * @return Vector of pointers to new subbatches.
+ * @return Vector of pointers to new sub-batches (or nullptrs where run out of sub-batches)
*
* @see marian::data::Batch::split(size_t n)
*/
- std::vector<Ptr<SubBatch>> split(size_t n) {
- ABORT_IF(size_ == 0, "Encoutered sub-batch size of 0");
+ std::vector<Ptr<SubBatch>> split(size_t n, size_t sizeLimit /*or SIZE_MAX*/) const {
+ ABORT_IF(size_ == 0, "Encountered sub-batch size of 0");
- size_t subSize = (size_t)(std::ceil(size_ / (float)n));
+ auto size = std::min(size_, sizeLimit); // if limit is given then pretend the batch only has that many sentences
+ size_t targetSubSize = (size_t)(std::ceil(size / (float)n)); // aim at forming sub-batches of this #sentences
std::vector<Ptr<SubBatch>> splits;
- for(size_t pos = 0; pos < size_; pos += subSize) {
- size_t size = std::min(subSize, size_ - pos);
+ for(size_t pos = 0; pos < size; pos += targetSubSize) { // loop over ranges of size targetSubSize to form sub-batches of this size
+ size_t subSize = std::min(targetSubSize, size - pos); // actual number of sentences can be smaller at the end
- // determine actual width
+ // determine actual width (=max length) of this sub-batch, which may be smaller than the overall max length
size_t subWidth = 0;
- for(size_t j = 0; j < width_; ++j) {
- for(size_t i = 0; i < size; ++i) {
- if(mask_[j * size_ + (pos + i)] != 0)
- if (subWidth < j + 1)
- subWidth = j + 1;
+ for(size_t s = 0; s < width_; ++s) {
+ for(size_t b = 0; b < subSize; ++b) {
+ if(mask_[locate(/*batchIdx=*/pos + b, /*wordPos=*/s)] != 0) // s * size_ + (pos + b)
+ if (subWidth < s + 1)
+ subWidth = s + 1;
}
}
- //if (subWidth < width_)
- // LOG(info, "[data] sub-batch {} of {} wide batch has effective width of {}", pos / subSize, width_, subWidth);
// create sub-batch
- auto sb = New<SubBatch>(size, subWidth, vocab_);
+ auto sb = New<SubBatch>(subSize, subWidth, vocab_);
size_t words = 0;
- for(size_t j = 0; j < subWidth; ++j) {
- for(size_t i = 0; i < size; ++i) {
- sb->data()[j * size + i] = indices_[j * size_ + (pos + i)];
- sb->mask()[j * size + i] = mask_[j * size_ + (pos + i)];
+ for(size_t s = 0; s < subWidth; ++s) {
+ for(size_t b = 0; b < subSize; ++b) {
+ sb->data()[locate(/*batchIdx=*/b, /*wordPos=*/s, /*batchSize=*/subSize)/*s * subSize + b*/] = indices_[locate(/*batchIdx=*/pos + b, /*wordPos=*/s)]; // s * size_ + (pos + b)
+ sb->mask()[locate(/*batchIdx=*/b, /*wordPos=*/s, /*batchSize=*/subSize)/*s * subSize + b*/] = mask_[locate(/*batchIdx=*/pos + b, /*wordPos=*/s)]; // s * size_ + (pos + b)
- if(mask_[j * size_ + (pos + i)] != 0)
+ if(mask_[locate(/*batchIdx=*/pos + b, /*wordPos=*/s)/*s * size_ + (pos + b)*/] != 0)
words++;
}
}
@@ -219,6 +227,9 @@ public:
}
void setWords(size_t words) { words_ = words; }
+
+ // experimental: hide inline-fix source tokens from cross attention
+ std::vector<float> crossMaskWithInlineFixSourceSuppressed() const;
};
/**
@@ -226,9 +237,9 @@ public:
* such as guided alignments and sentence or word-leve weighting.
*/
class CorpusBatch : public Batch {
-private:
+protected:
std::vector<Ptr<SubBatch>> subBatches_;
- std::vector<float> guidedAlignment_;
+ std::vector<float> guidedAlignment_; // [max source len, batch size, max target len] flattened
std::vector<float> dataWeights_;
public:
@@ -263,8 +274,8 @@ public:
size_t size() const override { return subBatches_[0]->batchSize(); }
/**
- * @brief The total number of words for the longest sentence in the batch plus
- * one. Pass which=0 for source and -1 for target.
+ * @brief The total number of words in the batch (not counting masked-out words).
+ * Pass which=0 for source words and -1 for target words.
*/
size_t words(int which = 0) const override {
return subBatches_[which >= 0 ? which
@@ -283,7 +294,7 @@ public:
size_t sizeTrg() const override { return subBatches_.back()->batchSize(); }
/**
- * @brief The number of words for the longest sentence in the batch plus one.
+ * @brief The total number of words in the batch (not counting masked-out words).
*/
size_t wordsTrg() const override { return subBatches_.back()->batchWords(); };
@@ -299,7 +310,8 @@ public:
/**
* @brief Creates a batch filled with fake data. Used to determine the size of
- * the batch object.
+ * the batch object. With guided-alignments and multiple encoders, those
+ * multiple source streams are expected to have the same lengths.
*
* @param lengths List of subbatch sizes.
* @param batchSize Number of sentences in the batch.
@@ -307,21 +319,21 @@ public:
*
* @return Fake batch of the same size as the real batch.
*/
- static Ptr<CorpusBatch> fakeBatch(std::vector<size_t>& lengths,
+ static Ptr<CorpusBatch> fakeBatch(const std::vector<size_t>& lengths,
+ const std::vector<Ptr<Vocab>>& vocabs,
size_t batchSize,
Ptr<Options> options) {
std::vector<Ptr<SubBatch>> batches;
- size_t idx = 0;
+ size_t batchIndex = 0;
for(auto len : lengths) {
- auto vocab = New<Vocab>(options, 0);
- vocab->createFake();
- // data: gets initialized to 0. No EOS symbol is distinguished.
- auto sb = New<SubBatch>(batchSize, len, vocab);
- // set word indices to different values to avoid same hashes
- std::fill(sb->data().begin(), sb->data().end(), (unsigned int)idx++);
+ auto sb = New<SubBatch>(batchSize, len, vocabs[batchIndex]);
+ // set word indices to random values (not actually needed with current version --@marcinjd: please confirm)
+ std::transform(sb->data().begin(), sb->data().end(), sb->data().begin(),
+ [&](Word) -> Word { return vocabs[batchIndex]->randWord(); });
// mask: no items ask being masked out
std::fill(sb->mask().begin(), sb->mask().end(), 1.f);
+ batchIndex++;
batches.push_back(sb);
}
@@ -332,12 +344,13 @@ public:
return batch;
if(options->get("guided-alignment", std::string("none")) != "none") {
+ // @TODO: if > 1 encoder, verify that all encoders have the same sentence lengths
std::vector<float> alignment(batchSize * lengths.front() * lengths.back(),
0.f);
batch->setGuidedAlignment(std::move(alignment));
}
- if(options->has("data-weighting")) {
+ if(options->hasAndNotEmpty("data-weighting")) {
auto weightsSize = batchSize;
if(options->get<std::string>("data-weighting-type") != "sentence")
weightsSize *= lengths.back();
@@ -349,25 +362,27 @@ public:
}
/**
- * @brief Splits the batch into batches of equal size.
+ * @brief Splits the batch into batches of equal size (except for last).
+ *
+ * @param n number of sub-batches to split into
*
- * @param n number of splits
+ * @param sizeLimit Clip batch content to the first sizeLimit sentences in the batch
*
- * @return Vector of pointers to new batches.
+ * @return Vector of pointers to new sub-batches (or nullptrs where run out of sub-batches)
*
* @see marian::data::SubBatch::split(size_t n)
*/
- std::vector<Ptr<Batch>> split(size_t n) override {
+ std::vector<Ptr<Batch>> split(size_t n, size_t sizeLimit /*=SIZE_MAX*/) override {
ABORT_IF(size() == 0, "Encoutered batch size of 0");
- std::vector<std::vector<Ptr<SubBatch>>> subs;
- // split each subbatch separately
- for(auto subBatch : subBatches_) {
- size_t i = 0;
- for(auto splitSubBatch : subBatch->split(n)) {
+ std::vector<std::vector<Ptr<SubBatch>>> subs; // [subBatchIndex][streamIndex]
+ // split each stream separately
+ for(auto batchStream : subBatches_) {
+ size_t i = 0; // index into split batch
+ for(auto splitSubBatch : batchStream->split(n, sizeLimit)) { // splits a batch into pieces, can also change width
if(subs.size() <= i)
subs.resize(i + 1);
- subs[i++].push_back(splitSubBatch);
+ subs[i++].push_back(splitSubBatch); // this forms tuples across streams
}
}
@@ -403,7 +418,7 @@ public:
size_t bi = i + pos;
for(size_t sid = 0; sid < srcWords; ++sid) {
for(size_t tid = 0; tid < trgWords; ++tid) {
- size_t bidx = sid * oldSize * oldTrgWords + bi * oldTrgWords + tid;
+ size_t bidx = sid * oldSize * oldTrgWords + bi * oldTrgWords + tid; // [sid, bi, tid]
size_t idx = sid * dimBatch * trgWords + i * trgWords + tid;
aligns[idx] = guidedAlignment_[bidx];
}
@@ -419,20 +434,19 @@ public:
if(!dataWeights_.empty()) {
size_t oldSize = size();
- size_t width = 1;
- // There are more weights than sentences, i.e. these are word weights.
- if(dataWeights_.size() != oldSize)
- width = subBatches_.back()->batchWidth();
-
for(auto split : splits) {
+ auto cb = std::static_pointer_cast<CorpusBatch>(split);
+ size_t width = 1; // One weight per sentence in case of sentence-level weights
+ if(dataWeights_.size() != oldSize) // if number of weights does not correspond to number of sentences we have word-level weights
+ width = cb->back()->batchWidth(); // splitting also affects width, hence we need to accomodate this here
std::vector<float> ws(width * split->size(), 1.0f);
// this needs to be split along the batch dimension
// which is here the innermost dimension.
// Should work for sentence-based weights, too.
- for(size_t j = 0; j < width; ++j) {
- for(size_t i = 0; i < split->size(); ++i) {
- ws[j * split->size() + i] = dataWeights_[j * oldSize + i + pos];
+ for(size_t s = 0; s < width; ++s) {
+ for(size_t b = 0; b < split->size(); ++b) {
+ ws[s * split->size() + b] = dataWeights_[s * oldSize + b + pos]; // @TODO: use locate() as well
}
}
split->setDataWeights(ws);
@@ -443,9 +457,13 @@ public:
return splits;
}
- std::vector<float>& getGuidedAlignment() { return guidedAlignment_; }
+ const std::vector<float>& getGuidedAlignment() const { return guidedAlignment_; } // [dimSrcWords, dimBatch, dimTrgWords] flattened
void setGuidedAlignment(std::vector<float>&& aln) override {
- guidedAlignment_ = std::move(aln);
+ guidedAlignment_ = std::move(aln);
+ }
+
+ size_t locateInGuidedAlignments(size_t b, size_t s, size_t t) {
+ return ((s * size()) + b) * widthTrg() + t;
}
std::vector<float>& getDataWeights() { return dataWeights_; }
@@ -456,29 +474,29 @@ public:
/**
* @brief Prints the batch in a readable form on stderr for debugging.
*/
- void debug() override {
+ void debug(bool printIndices = false) override { // prints word string if subbatch has vocab and
+ // printIndices == false otherwise only numeric indices
std::cerr << "batches: " << sets() << std::endl;
if(!sentenceIds_.empty()) {
- std::cerr << "indexes: ";
+ std::cerr << "indices: ";
for(auto id : sentenceIds_)
std::cerr << id << " ";
std::cerr << std::endl;
}
- size_t b = 0;
+ size_t subBatchIndex = 0;
for(auto sb : subBatches_) {
- std::cerr << "batch " << b++ << ": " << std::endl;
+ std::cerr << "stream " << subBatchIndex++ << ": " << std::endl;
const auto& vocab = sb->vocab();
- for(size_t i = 0; i < sb->batchWidth(); i++) {
+ for(size_t s = 0; s < sb->batchWidth(); s++) {
std::cerr << "\t w: ";
- for(size_t j = 0; j < sb->batchSize(); j++) {
- size_t idx = i * sb->batchSize() + j;
- Word w = sb->data()[idx];
- if (vocab)
+ for(size_t b = 0; b < sb->batchSize(); b++) {
+ Word w = sb->data()[sb->locate(/*batchIdx=*/b, /*wordPos=*/s)]; // s * sb->batchSize() + b;
+ if (vocab && !printIndices)
std::cerr << (*vocab)[w] << " ";
else
- std::cerr << w << " "; // if not loaded then print numeric id instead
+ std::cerr << w.toString() << " "; // if not loaded then print numeric id instead
}
std::cerr << std::endl;
}
@@ -507,12 +525,20 @@ public:
const std::vector<Ptr<Vocab>>& vocabs,
Ptr<Options> options);
+ virtual ~CorpusBase() {}
virtual std::vector<Ptr<Vocab>>& getVocabs() = 0;
protected:
- std::vector<UPtr<io::InputFileStream>> files_;
+ std::vector<UPtr<std::istream>> files_;
std::vector<Ptr<Vocab>> vocabs_;
+ /**
+ * brief Determines if a EOS symbol should be added. By default this is true for any sequence,
+ * but should be false for instance for classifier labels. This is set per input stream, hence a
+ * vector.
+ */
+ std::vector<bool> addEOS_;
+
size_t pos_{0};
size_t maxLength_{0};
@@ -532,11 +558,16 @@ protected:
size_t alignFileIdx_{0};
/**
+ * @brief Determine if EOS symbol should be added to input
+ */
+ void initEOS(bool training);
+
+ /**
* @brief Helper function converting a line of text into words using the i-th
* vocabulary and adding them to the sentence tuple.
*/
void addWordsToSentenceTuple(const std::string& line,
- size_t i,
+ size_t batchIndex,
SentenceTuple& tup) const;
/**
* @brief Helper function parsing a line with word alignments and adding them
diff --git a/src/data/corpus_nbest.cpp b/src/data/corpus_nbest.cpp
index 328c3c0d..d5a48d8d 100644
--- a/src/data/corpus_nbest.cpp
+++ b/src/data/corpus_nbest.cpp
@@ -15,8 +15,7 @@ CorpusNBest::CorpusNBest(std::vector<std::string> paths,
: CorpusBase(paths, vocabs, options) {}
int numFromNbest(const std::string& line) {
- std::vector<std::string> fields;
- utils::split(line, fields, " ||| ", true);
+ auto fields = utils::split(line, " ||| ", true);
ABORT_IF(fields.size() < 4,
"Too few fields ({}) in line \"{}\", is this a correct n-best list?",
fields.size(),
@@ -25,8 +24,7 @@ int numFromNbest(const std::string& line) {
}
std::string lineFromNbest(const std::string& line) {
- std::vector<std::string> fields;
- utils::split(line, fields, " ||| ", true);
+ auto fields = utils::split(line, " ||| ", true);
ABORT_IF(fields.size() < 4,
"Too few fields ({}) in line \"{}\", is this a correct n-best list?",
fields.size(),
@@ -57,7 +55,7 @@ SentenceTuple CorpusNBest::next() {
for(size_t i = 0; i < last; ++i) {
if(curr_num > lastNum_) {
- ABORT_IF(!io::getline(*files_[i], lastLines_[i]),
+ ABORT_IF(!std::getline(*files_[i], lastLines_[i]),
"Too few lines in input {}",
i);
}
@@ -88,7 +86,7 @@ void CorpusNBest::reset() {
lastNum_ = -1;
for(auto& path : paths_) {
if(path == "stdin")
- files_.emplace_back(new io::InputFileStream(std::cin));
+ files_.emplace_back(new std::istream(std::cin.rdbuf()));
else
files_.emplace_back(new io::InputFileStream(path));
}
diff --git a/src/data/corpus_nbest.h b/src/data/corpus_nbest.h
index 2c7958b1..2da55ccb 100755..100644
--- a/src/data/corpus_nbest.h
+++ b/src/data/corpus_nbest.h
@@ -18,7 +18,6 @@ namespace data {
class CorpusNBest : public CorpusBase {
private:
- std::vector<UPtr<io::TemporaryFile>> tempFiles_;
std::vector<size_t> ids_;
int lastNum_{-1};
std::vector<std::string> lastLines_;
diff --git a/src/data/corpus_sqlite.cpp b/src/data/corpus_sqlite.cpp
index cbab750e..714a70f0 100644
--- a/src/data/corpus_sqlite.cpp
+++ b/src/data/corpus_sqlite.cpp
@@ -25,8 +25,7 @@ void CorpusSQLite::fillSQLite() {
if(options_->get<std::string>("sqlite") == "temporary") {
LOG(info, "[sqlite] Creating temporary database in {}", tempDir);
- db_.reset(
- new SQLite::Database("", SQLite::OPEN_READWRITE | SQLite::OPEN_CREATE));
+ db_.reset(new SQLite::Database("", SQLite::OPEN_READWRITE | SQLite::OPEN_CREATE));
db_->exec("PRAGMA temp_store_directory = '" + tempDir + "';");
fill = true;
diff --git a/src/data/corpus_sqlite.h b/src/data/corpus_sqlite.h
index 641da2a8..0da2a864 100755..100644
--- a/src/data/corpus_sqlite.h
+++ b/src/data/corpus_sqlite.h
@@ -104,7 +104,7 @@ public:
if(options_->has("guided-alignment") && alignFileIdx_)
addAlignmentsToBatch(batch, batchVector);
- if(options_->has("data-weighting") && weightFileIdx_)
+ if(options_->hasAndNotEmpty("data-weighting") && weightFileIdx_)
addWeightsToBatch(batch, batchVector);
return batch;
diff --git a/src/data/dataset.h b/src/data/dataset.h
index 34378954..881126a3 100755..100644
--- a/src/data/dataset.h
+++ b/src/data/dataset.h
@@ -15,12 +15,13 @@ class DatasetBase {
protected:
std::vector<std::string> paths_;
Ptr<Options> options_;
+
// Data processing may differ in training/inference settings
bool inference_{false};
public:
typedef Batch batch_type;
- typedef Ptr<Batch> batch_ptr;
+ typedef Ptr<Batch> batch_ptr; // @TODO: rename to camel case
typedef Iterator iterator;
typedef SampleType Sample;
diff --git a/src/data/default_vocab.cpp b/src/data/default_vocab.cpp
index e344eadf..590e9931 100755..100644
--- a/src/data/default_vocab.cpp
+++ b/src/data/default_vocab.cpp
@@ -15,16 +15,16 @@
namespace marian {
-class DefaultVocab : public VocabBase {
-private:
+class DefaultVocab : public IVocab {
+protected:
typedef std::map<std::string, Word> Str2Id;
Str2Id str2id_;
typedef std::vector<std::string> Id2Str;
Id2Str id2str_;
- Word eosId_ = (Word)-1;
- Word unkId_ = (Word)-1;
+ Word eosId_ = Word::NONE;
+ Word unkId_ = Word::NONE;
std::vector<std::string> suffixes_ = { ".yml", ".yaml", ".json" };
@@ -36,7 +36,7 @@ private:
VocabFreqOrderer(const std::unordered_map<std::string, size_t>& counter)
: counter_(counter) {}
- // order first by decreasing frequency,
+ // order first by decreasing frequency,
// if frequencies are the same order lexicographically by vocabulary string
bool operator()(const std::string& a, const std::string& b) const {
return counter_.at(a) > counter_.at(b) || (counter_.at(a) == counter_.at(b) && a < b);
@@ -44,6 +44,8 @@ private:
};
public:
+ // @TODO: choose between 'virtual' and 'final'. Can we derive from this class?
+ virtual ~DefaultVocab() {};
virtual const std::string& canonicalExtension() const override { return suffixes_[0]; }
virtual const std::vector<std::string>& suffixes() const override { return suffixes_; }
@@ -56,25 +58,28 @@ public:
}
Words encode(const std::string& line, bool addEOS, bool /*inference*/) const override {
- std::vector<std::string> lineTokens;
- utils::split(line, lineTokens, " ");
+ auto lineTokens = utils::split(line, " ");
return (*this)(lineTokens, addEOS);
}
std::string decode(const Words& sentence, bool ignoreEOS) const override {
- std::string line;
auto tokens = (*this)(sentence, ignoreEOS);
return utils::join(tokens, " ");
}
+ std::string surfaceForm(const Words& sentence) const override {
+ sentence;
+ ABORT("surfaceForm() not supported by this vocabulary type");
+ }
+
virtual std::string type() const override { return "DefaultVocab"; }
virtual Word getEosId() const override { return eosId_; }
virtual Word getUnkId() const override { return unkId_; }
-
- const std::string& operator[](Word id) const override {
- ABORT_IF(id >= id2str_.size(), "Unknown word id: ", id);
+ const std::string& operator[](Word word) const override {
+ auto id = word.toWordIndex();
+ ABORT_IF(id >= id2str_.size(), "Unknown word id: {}", id);
return id2str_[id];
}
@@ -89,15 +94,16 @@ public:
isJson ? "JSON/Yaml" : "text",
vocabPath);
ABORT_IF(!filesystem::exists(vocabPath),
- "DefaultVocabulary file {} does not exits",
+ "DefaultVocabulary file {} does not exist",
vocabPath);
std::map<std::string, Word> vocab;
// read from JSON (or Yaml) file
if(isJson) {
- YAML::Node vocabNode = YAML::Load(io::InputFileStream(vocabPath));
+ io::InputFileStream strm(vocabPath);
+ YAML::Node vocabNode = YAML::Load(strm);
for(auto&& pair : vocabNode)
- vocab.insert({pair.first.as<std::string>(), pair.second.as<Word>()});
+ vocab.insert({pair.first.as<std::string>(), Word::fromWordIndex(pair.second.as<IndexType>())});
}
// read from flat text file
else {
@@ -107,80 +113,33 @@ public:
ABORT_IF(line.empty(),
"DefaultVocabulary file {} must not contain empty lines",
vocabPath);
- vocab.insert({line, (Word)vocab.size()});
+ auto wasInserted = vocab.insert({line, Word::fromWordIndex(vocab.size())}).second;
+ ABORT_IF(!wasInserted, "Duplicate vocabulary entry {}", line);
}
ABORT_IF(in.bad(), "DefaultVocabulary file {} could not be read", vocabPath);
}
- std::unordered_set<Word> seenSpecial;
-
id2str_.reserve(vocab.size());
for(auto&& pair : vocab) {
auto str = pair.first;
auto id = pair.second;
// note: this requires ids to be sorted by frequency
- if(!maxSize || id < (Word)maxSize) {
+ if(!maxSize || id.toWordIndex() < maxSize) {
insertWord(id, str);
}
}
ABORT_IF(id2str_.empty(), "Empty vocabulary: ", vocabPath);
- // look up ids for </s> and <unk>, which are required
- // The name backCompatStr is alternatively accepted for Yaml vocabs if id
- // equals backCompatId.
- auto getRequiredWordId = [&](const std::string& str,
- const std::string& backCompatStr,
- Word backCompatId) {
- // back compat with Nematus Yaml dicts
- if(isJson) {
- // if word id 0 or 1 is either empty or has the Nematus-convention string,
- // then use it
- if(backCompatId < id2str_.size()
- && (id2str_[backCompatId].empty()
- || id2str_[backCompatId] == backCompatStr)) {
- LOG(info,
- "[data] Using unused word id {} for {}",
- backCompatStr,
- backCompatId,
- str);
- return backCompatId;
- }
- }
- auto iter = str2id_.find(str);
- ABORT_IF(iter == str2id_.end(),
- "DefaultVocabulary file {} is expected to contain an entry for {}",
- vocabPath,
- str);
- return iter->second;
- };
- eosId_ = getRequiredWordId(DEFAULT_EOS_STR, NEMATUS_EOS_STR, DEFAULT_EOS_ID);
- unkId_ = getRequiredWordId(DEFAULT_UNK_STR, NEMATUS_UNK_STR, DEFAULT_UNK_ID);
-
- // some special symbols for hard attention
- if(!seenSpecial.empty()) {
- auto requireWord = [&](Word id, const std::string& str) {
- auto iter = str2id_.find(str);
- // word already in vocab: must be at right index, else fail
- if(iter != str2id_.end())
- ABORT_IF(iter->second != id,
- "special vocabulary entry '{}' is expected to have id {}",
- str,
- id);
- else
- insertWord(id, str);
- };
- // @TODO: the hard-att code has not yet been updated to accept EOS at any id
- requireWord(DEFAULT_EOS_ID, DEFAULT_EOS_STR);
- }
+ addRequiredVocabulary(vocabPath, isJson);
return std::max(id2str_.size(), maxSize);
}
// for fakeBatch()
- void createFake() override {
- eosId_ = insertWord(DEFAULT_EOS_ID, DEFAULT_EOS_STR);
- unkId_ = insertWord(DEFAULT_UNK_ID, DEFAULT_UNK_STR);
+ virtual void createFake() override {
+ eosId_ = insertWord(Word::DEFAULT_EOS_ID, DEFAULT_EOS_STR);
+ unkId_ = insertWord(Word::DEFAULT_UNK_ID, DEFAULT_UNK_STR);
}
virtual void create(const std::string& vocabPath,
@@ -205,7 +164,7 @@ public:
"Vocabulary file '{}' exists. Not overwriting",
path.string());
}
-
+
std::unordered_map<std::string, size_t> counter;
for(const auto& trainPath : trainPaths)
addCounts(counter, trainPath);
@@ -214,17 +173,50 @@ public:
private:
+ virtual void addRequiredVocabulary(const std::string& vocabPath, bool isJson) {
+ // look up ids for </s> and <unk>, which are required
+ // The name backCompatStr is alternatively accepted for Yaml vocabs if id
+ // equals backCompatId.
+ auto getRequiredWordId = [&](const std::string& str,
+ const std::string& backCompatStr,
+ Word backCompatWord) -> Word {
+ // back compat with Nematus Yaml dicts
+ if(isJson) {
+ // if word id 0 or 1 is either empty or has the Nematus-convention string,
+ // then use it
+ auto backCompatId = backCompatWord.toWordIndex();
+ if(backCompatId < id2str_.size()
+ && (id2str_[backCompatId].empty()
+ || id2str_[backCompatId] == backCompatStr)) {
+ LOG(info,
+ "[data] Using unused word id {} for {}",
+ backCompatStr,
+ backCompatId,
+ str);
+ return backCompatWord;
+ }
+ }
+ auto iter = str2id_.find(str);
+ ABORT_IF(iter == str2id_.end(),
+ "DefaultVocabulary file {} is expected to contain an entry for {}",
+ vocabPath,
+ str);
+ return iter->second;
+ };
+ eosId_ = getRequiredWordId(DEFAULT_EOS_STR, NEMATUS_EOS_STR, Word::DEFAULT_EOS_ID);
+ unkId_ = getRequiredWordId(DEFAULT_UNK_STR, NEMATUS_UNK_STR, Word::DEFAULT_UNK_ID);
+ }
+
void addCounts(std::unordered_map<std::string, size_t>& counter,
const std::string& trainPath) {
- std::unique_ptr<io::InputFileStream> trainStrm(
- trainPath == "stdin" ? new io::InputFileStream(std::cin)
+ std::unique_ptr<std::istream> trainStrm(
+ trainPath == "stdin" ? new std::istream(std::cin.rdbuf())
: new io::InputFileStream(trainPath)
);
std::string line;
while(getline(*trainStrm, line)) {
- std::vector<std::string> toks;
- utils::split(line, toks, " ");
+ auto toks = utils::split(line, " ");
for(const std::string& tok : toks) {
auto iter = counter.find(tok);
@@ -236,9 +228,9 @@ private:
}
}
- void create(const std::string& vocabPath,
- const std::unordered_map<std::string, size_t>& counter,
- size_t maxSize = 0) {
+ virtual void create(const std::string& vocabPath,
+ const std::unordered_map<std::string, size_t>& counter,
+ size_t maxSize = 0) {
std::vector<std::string> vocabVec;
for(auto& p : counter)
@@ -247,10 +239,10 @@ private:
std::sort(vocabVec.begin(), vocabVec.end(), VocabFreqOrderer(counter));
YAML::Node vocabYaml;
- vocabYaml.force_insert(DEFAULT_EOS_STR, DEFAULT_EOS_ID);
- vocabYaml.force_insert(DEFAULT_UNK_STR, DEFAULT_UNK_ID);
+ vocabYaml.force_insert(DEFAULT_EOS_STR, Word::DEFAULT_EOS_ID.toWordIndex());
+ vocabYaml.force_insert(DEFAULT_UNK_STR, Word::DEFAULT_UNK_ID.toWordIndex());
- Word maxSpec = 1;
+ WordIndex maxSpec = 1;
auto vocabSize = vocabVec.size();
if(maxSize > maxSpec)
vocabSize = std::min(maxSize - maxSpec - 1, vocabVec.size());
@@ -258,8 +250,8 @@ private:
for(size_t i = 0; i < vocabSize; ++i)
vocabYaml.force_insert(vocabVec[i], i + maxSpec + 1);
- std::unique_ptr<io::OutputFileStream> vocabStrm(
- vocabPath == "stdout" ? new io::OutputFileStream(std::cout)
+ std::unique_ptr<std::ostream> vocabStrm(
+ vocabPath == "stdout" ? new std::ostream(std::cout.rdbuf())
: new io::OutputFileStream(vocabPath)
);
*vocabStrm << vocabYaml;
@@ -289,17 +281,55 @@ private:
}
// helper to insert a word into str2id_[] and id2str_[]
- Word insertWord(Word id, const std::string& str) {
- str2id_[str] = id;
+ Word insertWord(Word word, const std::string& str) {
+ str2id_[str] = word;
+ auto id = word.toWordIndex();
if(id >= id2str_.size())
id2str_.resize(id + 1);
id2str_[id] = str;
- return id;
+ return word;
};
};
-Ptr<VocabBase> createDefaultVocab() {
+// This is a vocabulary class that does not enforce </s> or <unk>.
+// This is used for class lists in a classifier.
+class ClassVocab : public DefaultVocab {
+private:
+ // Do nothing.
+ virtual void addRequiredVocabulary(const std::string& /*vocabPath*/, bool /*isJson*/) override {}
+
+ // Not adding special class labels, only seen classes.
+ virtual void create(const std::string& vocabPath,
+ const std::unordered_map<std::string, size_t>& counter,
+ size_t maxSize = 0) override {
+
+ std::vector<std::string> vocabVec;
+ for(auto& p : counter)
+ vocabVec.push_back(p.first);
+ std::sort(vocabVec.begin(), vocabVec.end(), VocabFreqOrderer(counter));
+
+ ABORT_IF(maxSize != 0 && vocabVec.size() != maxSize,
+ "Class vocab maxSize given ({}) has to match class vocab size ({})",
+ maxSize, vocabVec.size());
+
+ YAML::Node vocabYaml;
+ for(size_t i = 0; i < vocabVec.size(); ++i)
+ vocabYaml.force_insert(vocabVec[i], i);
+
+ std::unique_ptr<std::ostream> vocabStrm(
+ vocabPath == "stdout" ? new std::ostream(std::cout.rdbuf())
+ : new io::OutputFileStream(vocabPath)
+ );
+ *vocabStrm << vocabYaml;
+ }
+};
+
+Ptr<IVocab> createDefaultVocab() {
return New<DefaultVocab>();
}
+Ptr<IVocab> createClassVocab() {
+ return New<ClassVocab>();
+}
+
}
diff --git a/src/data/factored_vocab.cpp b/src/data/factored_vocab.cpp
new file mode 100755
index 00000000..971bf626
--- /dev/null
+++ b/src/data/factored_vocab.cpp
@@ -0,0 +1,778 @@
+// This is the main implementation of factored models, which are driven by the vocabulary.
+// Decoding, embedding, and output layer call into the vocab to drive their behavior.
+
+#include "data/vocab_base.h"
+#include "common/definitions.h"
+#include "data/types.h"
+#include "common/regex.h"
+#include "data/factored_vocab.h"
+#include <set>
+
+// @TODO: review all comments and clarify nomenclature:
+// * factor type (e.g. caps: |c* ); currently called a "group"
+// * factor name (e.g. all-caps: |ca )
+// * factor index (e.g. |ca is index 0 inside |ca |ci |cn)
+// * factor unit index (|ca is unit 41324 in joint factor vocab)
+// Also remove references to older outdated versions.
+
+namespace marian {
+
+/*virtual*/ size_t FactoredVocab::load(const std::string& modelPath, size_t maxSizeUnused /*= 0*/) /*override final*/ {
+ maxSizeUnused;
+ // If model has already been loaded, then assume this is a shared object, and skip loading it again.
+ // This can be multi-threaded, so must run under lock.
+ static std::mutex s_mtx;
+ std::lock_guard<std::mutex> criticalSection(s_mtx);
+ if (size() != 0) {
+ //LOG(info, "[vocab] Attempting to load model a second time; skipping (assuming shared vocab)");
+ return size();
+ }
+ LOG(info, "[vocab] Loading vocab spec file {}", modelPath);
+
+ // load factor-vocab file and parse it
+ std::vector<std::vector<std::string>> factorMapTokenized;
+ std::string line;
+ std::vector<std::string> tokBuf;
+ if (utils::endsWith(modelPath, ".fsv")) { // @TODO: this extension check is only for backcompat; can be removed once we no longer support the old format
+ // this is a fake parser for the generic factor spec, which makes a few hard assumptions:
+ // - all types starting with _ except _has_* are factor names
+ // - X : _x makes surface form X part of prob distribution _x except for _has_*
+ // - X : _has_x adds factor "x" to lemma X
+ // - _x <-> form only allows "_x <->" or "_x <-> _has_x" (same x), and is otherwise unused
+ // - _lemma is special
+ // The current version of the code just converts it internally to the legacy form.
+ // @TODO: Once the legacy form is no longer needed, simplify this.
+ io::InputFileStream in(modelPath);
+ WordIndex v = 0;
+ std::map<std::string,std::set<std::string>> factorTypeMap; // [type name] -> {factor-type names}
+ std::vector<std::string> deferredFactorVocab; // factor surface forms are presently expected to be at the end of factorVocab_, so collect them here first
+ while(io::getline(in, line)) {
+#if 1 // workaround for a bug fix in FactoredSegmenter that made old .fsv files incompatible
+ if (line == "\xef\xb8\x8f : _lemma _has_wb") // old vocabs have a wrong factor in here
+ line = "\xef\xb8\x8f : _lemma _has_gl _has_gr"; // patch it to the correct one
+ else if (line == "\xef\xb8\x8e : _lemma _has_wb")
+ line = "\xef\xb8\x8e : _lemma _has_gl _has_gr";
+#endif
+ utils::splitAny(line, tokBuf, " \t");
+ if (tokBuf.empty() || tokBuf[0][0] == '#') // skip comments and blank lines
+ continue;
+ const auto& lhs = tokBuf[0];
+ const auto& op = tokBuf.size() > 1 ? tokBuf[1] : "";
+ if (lhs[0] == '_') { // factor name
+ if (utils::beginsWith(lhs, "_has_")) {
+ const auto fName = lhs.substr(5); // skip _has_
+ ABORT_IF(factorTypeMap.find(fName) == factorTypeMap.end(), "Factor trait '{}' requires a factor named '{}' to exist", lhs, fName);
+ ABORT_IF(tokBuf.size() != 1, "Extraneous characters after factor trait: '{}'", line);
+ continue;
+ }
+ else if (op == "<->") {
+ ABORT_IF(lhs == "_lemma" && tokBuf.size() != 2, "Lemma factor distribution cannot be conditioned: '{}'", line);
+ ABORT_IF(lhs != "_lemma" && (tokBuf.size() != 3 || tokBuf[2] != "_has" + lhs), "Factor distribution can only be conditioned on nothing or on _has{}: '{}'", lhs, line);
+ continue;
+ }
+ else { // this declares a new factor
+ ABORT_IF(tokBuf.size() != 1, "Extraneous characters after factor declaration: '{}'", line);
+ const auto& fName = lhs.substr(1); // skip _
+ ABORT_IF(factorTypeMap.empty() && fName != "lemma", "First factor must be _lemma");
+ auto rv = factorTypeMap.insert(std::make_pair(fName, std::set<std::string>())); // create new factor
+ ABORT_IF(!rv.second, "Factor declared twice: '{}'", line);
+ groupPrefixes_.push_back(fName == "lemma" ? "(lemma)" : ("|" + fName));
+ continue;
+ }
+ }
+ else { // if not _ then it is a surface form
+ ABORT_IF(op != ":" || 2 >= tokBuf.size(), "Factor-lemma declaration should have the form LEMMA : _FACTOR, _has_FACTOR, _has_FACTOR... in '{}'", line);
+ ABORT_IF(tokBuf[2][0] != '_', "Factor name should begin with _ in '{}'", line);
+ ABORT_IF(utils::beginsWith(tokBuf[2], "_has_"), "The first factor after : must not begin with _has_ in '{}'", line);
+ // add to surface-form dictionary
+ const auto& fName = tokBuf[2].substr(1); // skip _
+ auto isLemma = fName == "lemma";
+ if (isLemma)
+ factorVocab_.add(lhs, v++); // note: each item can only be declared once
+ else
+ deferredFactorVocab.push_back(lhs); // add surface form to its declared factor type
+ auto surfaceFormSet = factorTypeMap.find(fName); // set of surface forms for this factor
+ ABORT_IF(surfaceFormSet == factorTypeMap.end(), "Unknown factor name in '{}'", line);
+ auto rv = surfaceFormSet->second.insert(lhs); // insert surface form into its declared factor type
+ ABORT_IF(!rv.second, "Factor declared twice: '{}'", line);
+ auto tokenizedMapLine = isLemma ? std::vector<std::string>{ lhs, lhs } : std::vector<std::string>();
+ // associated factors
+ for (size_t i = 3; i < tokBuf.size(); i++) {
+ const auto& has = tokBuf[i];
+ ABORT_IF(!utils::beginsWith(has, "_has_"), "Factor associations must use the form _has_X in '{}'", line);
+ ABORT_IF(!isLemma, "Factor associations are only allowed when factor type is _lemma: '{}', line");
+ const auto& faName = has.substr(5); // skip _has_ and prepend |
+ // for tokenized map, we pick one example of the factor names
+ auto iter = factorTypeMap.find(faName);
+ ABORT_IF(iter == factorTypeMap.end(), "Invalid factor association {}, no such factor: '{}'", has, line);
+ const auto& factorNames = iter->second;
+ ABORT_IF(factorNames.empty(), "Factor association {} refers to empty factor type: '{}'", has, line);
+ const auto& oneFactorName = "|" + *factorNames.begin(); // pick the first entry as one example
+ tokenizedMapLine[0] += oneFactorName;
+ tokenizedMapLine.push_back(oneFactorName);
+ }
+ if (isLemma)
+ factorMapTokenized.push_back(std::move(tokenizedMapLine));
+ continue;
+ }
+ ABORT("Malformed .fsv input line {}", line); // we only get here for lines we could not process
+ }
+ for (auto factorTypeName : deferredFactorVocab)
+ factorVocab_.add("|" + factorTypeName, v++);
+ } else { // legacy for old configs
+ // legacy format: one factor map, one flat list of factor surface forms
+ // load factor vocabulary
+ factorSeparator_ = '@';
+ auto factorVocabPath = modelPath;
+ factorVocabPath.back() = 'l'; // map .fm to .fl
+ factorVocab_.load(factorVocabPath);
+ groupPrefixes_ = { "(lemma)", "@C", "@GL", "@GR", "@WB"/*, "@WE"*/, "@CB"/*, "@CE"*/ }; // @TODO: hard-coded for these initial experiments
+ // @TODO: add checks for empty factor groups until it stops crashing (training already works; decoder still crashes)
+
+ io::InputFileStream in(modelPath);
+ for (WordIndex v = 0; io::getline(in, line); v++) {
+ utils::splitAny(line, tokBuf, " \t");
+ factorMapTokenized.push_back(tokBuf);
+ }
+ }
+
+ // construct mapping tables for factors
+ constructGroupInfoFromFactorVocab();
+ constructFactorIndexConversion();
+
+ // parse factorMap
+ // modelPath = path to file with entries in order of vocab entries of the form
+ // WORD FACTOR1 FACTOR2 FACTOR3...
+ // Factors are grouped
+ // - user specifies list-factor prefixes; all factors beginning with that prefix are in the same group
+ // - factors within a group as multi-class and normalized that way
+ // - groups of size 1 are interpreted as sigmoids, multiply with P(u) / P(u-1)
+ // - one prefix must not contain another
+ // - all factors not matching a prefix get lumped into yet another class (the lemmas)
+ // - factor vocab must be sorted such that all groups are consecutive
+ // - result of Output layer is nevertheless logits, not a normalized probability, due to the sigmoid entries
+ // For every lemma, the factor map contains one example. At the end of this loop, we have a vocabulary
+ // vocab_ that contains those examples, but not all possible combinations
+ lemmaHasFactorGroup_.resize(groupRanges_[0].second - groupRanges_[0].first); // group 0 is the lemmas; this difference is the number of lemma symbols
+ size_t numTotalFactors = 0;
+ for (WordIndex v = 0; v < factorMapTokenized.size(); v++) {
+ const auto& tokens = factorMapTokenized[v];
+ // parse the line, of the form WORD FACTOR1 FACTOR2 FACTOR1 ...
+ // where FACTOR1 is the lemma, a factor that all words have.
+ // Not every word has all other factors, so the n-th item is not always in the same factor group.
+ // @TODO: change to just use the .wl file, and manually split at @
+ ABORT_IF(tokens.size() < 2, "Factor map must have at least one factor per word", modelPath);
+ std::vector<WordIndex> factorUnits; // units in the joint factor vocab that belong to a specific factor type
+ for (size_t i = 1/*first factor*/; i < tokens.size(); i++) {
+ auto u = factorVocab_[tokens[i]];
+ factorUnits.push_back(u);
+ }
+ // convert to fully unrolled factors representation
+ auto na = FACTOR_NOT_APPLICABLE; // (gcc compiler bug: sometimes it cannot find this if passed directly)
+ std::vector<size_t> factorIndices(groupRanges_.size(), na); // default for unused factors
+ std::vector<bool> hasFactorGroupFlags(groupRanges_.size(), false);
+ for (auto u : factorUnits) {
+ factorIndices[factorGroups_[u]] = factorUnit2FactorIndex(u);
+ hasFactorGroupFlags[factorGroups_[u]] = true;
+ }
+ // record which lemma has what factor groups
+ ABORT_IF(!hasFactorGroupFlags[0], "Factor map does not specify a lemma (factor of first group) for word {}", tokens.front());
+ auto& lemmaFlags = lemmaHasFactorGroup_[factorIndices[0]];
+ if (lemmaFlags.empty())
+ lemmaFlags = std::move(hasFactorGroupFlags);
+ else
+ ABORT_IF(lemmaFlags != hasFactorGroupFlags, "Inconsistent factor groups used for word {}", tokens.front());
+ // map factors to non-dense integer
+ auto word = factors2word(factorIndices);
+ // add to vocab (the wordIndex are not dense, so the vocab will have holes)
+ // for now add what we get, and then expand more below
+ auto wordString = word2string(word);
+ if (tokens.front() != wordString) // order may differ, since we formed the input based on the factors in the user file, which may be in any order
+ LOG_ONCE(info, "[vocab] Word name in vocab file {} differs from canonical form {} (this warning is only shown once)", tokens.front(), wordString);
+ vocab_.add(wordString, word.toWordIndex());
+ numTotalFactors += tokens.size() - 1;
+ }
+ LOG(info, "[vocab] Factored-embedding map read with total/unique of {}/{} factors from {} example words (in space of {})",
+ numTotalFactors, factorVocabSize(), vocab_.size()/*numValid()*/, utils::withCommas(virtualVocabSize()));
+ //vocab_.dumpToFile(modelPath + "_examples");
+
+ // enumerate all valid combinations of factors for each lemma and add them to vocab_
+ // Having vocab_ makes life easier, although it is not strictly needed. Typical expanded valid vocabs
+ // are on the order of 200k entries. If we ever go much larger, we'd want to elimimate vocab_
+ // and fully virtualize its function.
+ LOG(info, "[vocab] Expanding all valid vocab entries out of {}...", utils::withCommas(virtualVocabSize()));
+ std::vector<size_t> factorIndices(getNumGroups());
+ rCompleteVocab(factorIndices, /*g=*/0);
+ LOG(info, "[vocab] Completed, total {} valid combinations", vocab_.size()/*numValid()*/);
+ //vocab_.dumpToFile(modelPath + "_expanded");
+
+#ifdef FACTOR_FULL_EXPANSION
+ // create mappings needed for normalization in factored outputs
+ constructNormalizationInfoForVocab();
+#endif
+
+ // </s> and <unk> must exist in the vocabulary
+ eosId_ = Word::fromWordIndex(vocab_[DEFAULT_EOS_STR]);
+ unkId_ = Word::fromWordIndex(vocab_[DEFAULT_UNK_STR]);
+ //LOG(info, "eos: {}; unk: {}", word2string(eosId_), word2string(unkId_));
+
+ return size();
+}
+
+// helper to add missing words to vocab_
+// factorIndices has been formed up to *ex*cluding position [g].
+void FactoredVocab::rCompleteVocab(std::vector<size_t>& factorIndices, size_t g) {
+ // reached the end
+ if (g == getNumGroups()) {
+ auto word = factors2word(factorIndices);
+ auto v = word.toWordIndex();
+ if (!vocab_.contains(v)) // add if missing
+ vocab_.add(word2string(word), v);
+ return;
+ }
+ // try next factor
+ if (g == 0 || lemmaHasFactorGroup(factorIndices[0], g)) {
+ for (size_t g1 = 0; g1 < factorShape_[g] - 1; g1++) {
+ factorIndices[g] = g1;
+ rCompleteVocab(factorIndices, g + 1);
+ }
+ }
+ else {
+ factorIndices[g] = FACTOR_NOT_APPLICABLE;
+ rCompleteVocab(factorIndices, g + 1);
+ }
+}
+
+void FactoredVocab::constructGroupInfoFromFactorVocab() {
+ // form groups
+ size_t numGroups = groupPrefixes_.size();
+ size_t factorVocabSize = this->factorVocabSize();
+ factorGroups_.resize(factorVocabSize, 0);
+ for (size_t g = 1; g < groupPrefixes_.size(); g++) { // set group labels; what does not match any prefix will stay in group 0
+ const auto& groupPrefix = groupPrefixes_[g];
+ for (WordIndex u = 0; u < factorVocabSize; u++)
+ if (utils::beginsWith(factorVocab_[u], groupPrefix)) {
+ //ABORT_IF(factorGroups_[u] != 0, "Factor {} matches multiple groups, incl. {}", factorVocab_[u], groupPrefix);
+ if(factorGroups_[u] != 0)
+ LOG(info, "Factor {} matches multiple groups, incl. {}, using {}", factorVocab_[u], groupPrefixes_[factorGroups_[u]], groupPrefix);
+ factorGroups_[u] = g;
+ }
+ }
+ // determine group index ranges
+ groupRanges_.resize(numGroups, { SIZE_MAX, (size_t)0 });
+ std::vector<int> groupCounts(numGroups); // number of group members
+ for (WordIndex u = 0; u < factorVocabSize; u++) { // determine ranges; these must be non-overlapping, verified via groupCounts
+ auto g = factorGroups_[u];
+ if (groupRanges_[g].first > u)
+ groupRanges_[g].first = u;
+ if (groupRanges_[g].second < u + 1)
+ groupRanges_[g].second = u + 1;
+ groupCounts[g]++;
+ }
+ for (size_t g = 0; g < numGroups; g++) { // detect non-overlapping groups
+ LOG(info, "[vocab] Factor group '{}' has {} members", groupPrefixes_[g], groupCounts[g]);
+ if (groupCounts[g] == 0) { // factor group is unused --@TODO: once this is not hard-coded, this is an error condition
+ groupRanges_[g].first = g > 0 ? groupRanges_[g-1].second : 0; // fix up the entry
+ groupRanges_[g].second = groupRanges_[g].first;
+ continue;
+ }
+ ABORT_IF(groupRanges_[g].second - groupRanges_[g].first != groupCounts[g],
+ "Factor group '{}' members should be consecutive in the factor vocabulary", groupPrefixes_[g]);
+ }
+ // we map between factors and flat WordIndex like indexing a tensor
+ constructFactorIndexConversion();
+}
+
+// create factorShape_ and factorStrides_, for mapping between flat (non-dense) ids and factor arrays
+void FactoredVocab::constructFactorIndexConversion() {
+ std::vector<int> shape;
+ for (const auto& r : groupRanges_)
+ shape.push_back((int)(r.second - r.first + 1)); // +1 to reserve the last value for either "factor not used" or "factor not present"
+ factorShape_ = Shape(std::move(shape));
+ factorStrides_.resize(factorShape_.size(), 1);
+ for (size_t g = factorStrides_.size() - 1; g --> 0; )
+ factorStrides_[g] = factorStrides_[g + 1] * (size_t)factorShape_[g + 1];
+ ABORT_IF((WordIndex)virtualVocabSize() != virtualVocabSize(),
+ "Too many factors, virtual index space {} exceeds the bit limit of WordIndex type", utils::withCommas(virtualVocabSize()));
+}
+
+// encode factors into a Word struct
+// inputs:
+// - factorIndices[factorType] = factorIndex (e.g. 0 for |ca )
+// output:
+// - representation as 'Word' (which is, in fact, a single big integer)
+Word FactoredVocab::factors2word(const std::vector<size_t>& factorIndices /* [numGroups] */) const {
+ size_t index = 0;
+ size_t numGroups = getNumGroups();
+ ABORT_IF(factorIndices.size() != numGroups, "Factor indices array size must be same as number of factor groups");
+ for (size_t g = 0; g < numGroups; g++) {
+ auto factorIndex = factorIndices[g];
+ if (factorIndex != FACTOR_NOT_SPECIFIED) { // check validity
+ auto factor0Index = factorIndices[0]; // lemma
+ ABORT_IF(factor0Index == FACTOR_NOT_SPECIFIED, "Without lemma, no other factor may be specified");
+ ABORT_IF(lemmaHasFactorGroup(factor0Index, g) == (factorIndex == FACTOR_NOT_APPLICABLE),
+ "Lemma '{}' {} factor group '{}'",
+ factorVocab_[WordIndex(factor0Index + groupRanges_[0].first)],
+ lemmaHasFactorGroup(factor0Index, g) ? "needs" : "does not have",
+ groupPrefixes_[g]);
+ }
+ if (factorIndex == FACTOR_NOT_APPLICABLE || factorIndex == FACTOR_NOT_SPECIFIED)
+ factorIndex = (size_t)factorShape_[g] - 1; // sentinel for "unused" or "not specified"
+ else
+ ABORT_IF(factorIndex >= (size_t)factorShape_[g] - 1, "Factor index out of range");
+ index += factorIndex * factorStrides_[g];
+ }
+ return Word::fromWordIndex(index);
+}
+
+// encode only a lemma into a 'Word'
+// The result is incomplete, in that the lemma likely has additional factors that are not yet specified.
+// Those are encoded as the value FACTOR_NOT_SPECIFIED. This function is used during beam search,
+// which starts with lemma scores, and then adds factors one by one to the path score.
+Word FactoredVocab::lemma2Word(size_t factor0Index) const {
+ size_t numGroups = getNumGroups();
+ std::vector<size_t> factorIndices;
+ factorIndices.reserve(numGroups);
+ factorIndices.push_back(factor0Index);
+ for (size_t g = 1; g < numGroups; g++) {
+ auto index = lemmaHasFactorGroup(factor0Index, g) ? FACTOR_NOT_SPECIFIED : FACTOR_NOT_APPLICABLE;
+ factorIndices.push_back(index);
+ }
+ return factors2word(factorIndices);
+}
+
+// replace a factor that is FACTOR_NOT_SPECIFIED by a specified one
+// This is used in beam search, where factors are searched one after another.
+Word FactoredVocab::expandFactoredWord(Word word, size_t groupIndex, size_t factorIndex) const {
+ //LOG(info, "expand {} + [{}]={}", word2string(word), groupIndex, factorIndex);
+ ABORT_IF(groupIndex == 0, "Cannot add or change lemma in a partial Word");
+ ABORT_IF(!isFactorValid(factorIndex), "Cannot add unspecified or n/a factor to a partial Word");
+ std::vector<size_t> factorIndices;
+ word2factors(word, factorIndices);
+ auto factor0Index = factorIndices[0];
+ ABORT_IF(!isFactorValid(factor0Index), "Cannot add factor to a partial Word without lemma");
+ ABORT_IF(factorIndices[groupIndex] == FACTOR_NOT_APPLICABLE, "Cannot add a factor that the lemma does not have");
+ ABORT_IF(factorIndices[groupIndex] != FACTOR_NOT_SPECIFIED, "Cannot modify a specified factor in a partial Word");
+ factorIndices[groupIndex] = factorIndex;
+ word = factors2word(factorIndices);
+ //LOG(info, "to {}", word2string(word));
+ return word;
+}
+
+// factor unit: index of factor name in the joint factor vocabulary
+// factor index: relative index within factor type, e.g. 0 for |ca
+size_t FactoredVocab::factorUnit2FactorIndex(WordIndex u) const {
+ auto g = factorGroups_[u]; // convert u to relative u within factor group range
+ ABORT_IF(u < groupRanges_[g].first || u >= groupRanges_[g].second, "Invalid factorGroups_ entry??");
+ return u - groupRanges_[g].first;
+}
+
+// split the 'Word' representation, which is really a single big integer, into the individual
+// factor indices for all factor types
+void FactoredVocab::word2factors(Word word, std::vector<size_t>& factorIndices /* [numGroups] */) const {
+ size_t numGroups = getNumGroups();
+ factorIndices.resize(numGroups);
+ for (size_t g = 0; g < numGroups; g++) {
+ auto factorIndex = getFactor(word, g);
+ factorIndices[g] = factorIndex;
+ }
+#if 1
+ auto test = factors2word(factorIndices);
+ ABORT_IF(test != word, "Word <-> factor conversion broken?? {} vs{}, '{}' vs. '{}'",
+ test.toWordIndex(), word.toWordIndex(), word2string(test), word2string(word));
+#endif
+}
+
+// serialize 'Word' representation into its string form
+std::string FactoredVocab::word2string(Word word) const {
+ // this function has some code dup, so that we can bypass some checks for debugging
+ size_t numGroups = getNumGroups();
+ size_t factor0Index = word.toWordIndex() / factorStrides_[0];
+ std::string res;
+ for (size_t g = 0; g < numGroups; g++) {
+ size_t index = word.toWordIndex();
+ index = index / factorStrides_[g];
+ index = index % (size_t)factorShape_[g];
+ if (index == (size_t)factorShape_[g] - 1) { // special sentinel value for unspecified or not-applicable
+ if (factor0Index >= (size_t)factorShape_[0])
+ res.append("(lemma oob)");
+ else if (lemmaHasFactorGroup(factor0Index, g))
+ res.append("?");
+ }
+ else
+ res.append(getFactorName(g, index));
+ }
+ return res;
+}
+
+// deserialize factored string form (e.g. HELLO|ci|wb) into its internal binary 'Word' representation
+Word FactoredVocab::string2word(const std::string& w) const {
+ auto sep = std::string(1, factorSeparator_);
+ auto parts = utils::splitAny(w, sep);
+ auto na = FACTOR_NOT_APPLICABLE; // (gcc compiler bug: sometimes it cannot find this if passed directly)
+ std::vector<size_t> factorIndices(groupRanges_.size(), na); // default for unused factors
+ for (size_t i = 0; i < parts.size(); i++) {
+ WordIndex u;
+ bool found = factorVocab_.tryFind(i == 0 ? parts[i] : sep + parts[i], u);
+ if (!found) {
+ static int logs = 100;
+ if (logs > 0) {
+ logs--;
+ LOG(info, "WARNING: Unknown factor '{}' in '{}'; mapping to '{}'", parts[i], w, word2string(getUnkId()));
+ }
+ return getUnkId();
+ }
+ // convert u to relative u within factor group range
+ auto g = factorGroups_[u];
+ ABORT_IF(u < groupRanges_[g].first || u >= groupRanges_[g].second, "Invalid factorGroups_ entry??");
+ factorIndices[g] = u - groupRanges_[g].first;
+ }
+ auto word = factors2word(factorIndices);
+ return word;
+}
+
+// does a specific factor exist in the vocabulary
+// Factor name must be given without separator. This function cannot be used for lemmas.
+bool FactoredVocab::tryGetFactor(const std::string& factorName, size_t& groupIndex, size_t& factorIndex) const {
+ WordIndex u;
+ if (factorVocab_.tryFind(factorSeparator_ + factorName, u))
+ {
+ groupIndex = factorGroups_[u];
+ ABORT_IF(u < groupRanges_[groupIndex].first || u >= groupRanges_[groupIndex].second, "Invalid factorGroups_ entry??");
+ factorIndex = u - groupRanges_[groupIndex].first;
+ return true;
+ }
+ else
+ return false;
+}
+
+// extract the factor index of a given factor type from the 'Word' representation
+size_t FactoredVocab::getFactor(Word word, size_t groupIndex) const {
+ size_t index = word.toWordIndex();
+ size_t factor0Index = index / factorStrides_[0];
+ index = index / factorStrides_[groupIndex];
+ index = index % (size_t)factorShape_[groupIndex];
+ if (index == (size_t)factorShape_[groupIndex] - 1) { // special sentinel value for unspecified or not-applicable
+ if (groupIndex == 0) // lemma itself is always applicable, hence 'not specified'
+ index = FACTOR_NOT_SPECIFIED;
+ else { // not lemma: check whether lemma of word has this factor group
+ if (lemmaHasFactorGroup(factor0Index, groupIndex))
+ index = FACTOR_NOT_SPECIFIED;
+ else
+ index = FACTOR_NOT_APPLICABLE;
+ }
+ }
+ else { // regular value: consistency check if lemma really has this factor group
+ ABORT_IF(factor0Index == (size_t)factorShape_[0] - 1, "Word has specified factor but no lemma??");
+ //ABORT_IF(!lemmaHasFactorGroup(factor0Index, groupIndex), "Word has a specified factor for a lemma that does not have that factor group??");
+ if (!lemmaHasFactorGroup(factor0Index, groupIndex))
+ index = FACTOR_NOT_SPECIFIED;
+ // @TODO: ^^ needed for determining all valid vocab entries; can we pass a flag in to allow this?
+ }
+ return index;
+}
+
+#ifdef FACTOR_FULL_EXPANSION
+void FactoredVocab::constructNormalizationInfoForVocab() {
+ // create mappings needed for normalization in factored outputs
+ //size_t numGroups = groupPrefixes_.size();
+ size_t vocabSize = virtualVocabSize();
+ //factorMasks_ .resize(numGroups, std::vector<float>(vocabSize, 0)); // [g][v] 1.0 if word v has factor g
+ //factorIndices_.resize(numGroups, std::vector<IndexType>(vocabSize, 0)); // [g][v] index of factor (or any valid index if it does not have it; we use 0)
+ gapLogMask_.resize(vocabSize, -1e8f);
+ for (WordIndex v = 0; v < vocabSize; v++) {
+#if 1 // @TODO: TEST THIS again by disabling factored decoding in beam_search.h
+ if (vocab_.contains(v))
+ gapLogMask_[v] = 0.0f; // valid entry
+#else
+ for (auto u : factorMap_[v]) {
+ auto g = factorGroups_[u]; // convert u to relative u within factor group range
+ ABORT_IF(u < groupRanges_[g].first || u >= groupRanges_[g].second, "Invalid factorGroups_ entry??");
+ //factorIndices_[g][v] = (IndexType)(u - groupRanges_[g].first);
+ //factorMasks_[g][v] = 1.0f;
+ gapLogMask_[v] = 0.0f; // valid entry
+ }
+#endif
+ }
+ //for (Word v = 0; v < vocabSize; v++) {
+ // LOG(info, "'{}': {}*{} {}*{} {}*{} {}*{}", vocab[v],
+ // factorMasks_[0][v], factorIndices_[0][v],
+ // factorMasks_[1][v], factorIndices_[1][v],
+ // factorMasks_[2][v], factorIndices_[2][v],
+ // factorMasks_[3][v], factorIndices_[3][v]);
+ //}
+
+ // create the global factor matrix, which is used for getLogits() only
+ // For invalid words, this leaves empty matrix rows, which are later masked by adding gapLogMask.
+ Words data;
+ for (size_t v = 0; v < vocabSize; v++) // note: this loops over the entire vocab space, incl. gaps
+ data.push_back(Word::fromWordIndex(v));
+ globalFactorMatrix_ = csr_rows(data); // [V x U]
+}
+#endif
+
+/*virtual*/ Word FactoredVocab::operator[](const std::string& word) const /*override final*/ {
+ // @TODO: do away with vocab_ altogether, and just always parse.
+ WordIndex index;
+ bool found = vocab_.tryFind(word, index);
+ if (found)
+ return Word::fromWordIndex(index);
+ else
+ return string2word(word);
+}
+
+/*virtual*/ const std::string& FactoredVocab::operator[](Word word) const /*override final*/ {
+ //LOG(info, "Looking up Word {}={}", word.toWordIndex(), word2string(word));
+ ABORT_IF(!vocab_.contains(word.toWordIndex()), "Invalid factor combination {}", word2string(word));
+ return vocab_[word.toWordIndex()];
+}
+
+// convert a string representation of a token sequence to all-caps by changing all capitalization factors to |ca
+/*virtual*/ std::string FactoredVocab::toUpper(const std::string& line) const /*override final*/ {
+ return utils::findReplace(utils::findReplace(utils::findReplace(utils::findReplace(utils::findReplace(line, "|scl", "|scu", /*all=*/true), "|ci", "|ca", /*all=*/true), "|cn", "|ca", /*all=*/true), "@CI", "@CA", /*all=*/true), "@CN", "@CA", /*all=*/true);
+}
+
+// convert a string representation of a token sequence to English title case by changing the capitalization factors to |ci
+/*virtual*/ std::string FactoredVocab::toEnglishTitleCase(const std::string& line) const /*override final*/ {
+ // @BUGBUG: does not handle the special words that should remain lower-case
+ // note: this presently supports both @WB and @GL- (legacy)
+ return utils::findReplace(utils::findReplace(utils::findReplace(utils::findReplace(utils::findReplace(line, "|scl", "|scu", /*all=*/true), "|cn|wb", "|ci|wb", /*all=*/true), "|cn|gl-", "|ci|gl-", /*all=*/true), "@CN@WB", "@CI@WB", /*all=*/true), "@CN@GL-", "@CI@GL-", /*all=*/true);
+}
+
+// convert word indices to indices of shortlist items
+// We only shortlist the lemmas, hence return the lemma index (offset to correctly index into the concatenated W matrix).
+// This strange pointer-based interface is for ease of interaction with our production environment.
+/*virtual*/ void FactoredVocab::transcodeToShortlistInPlace(WordIndex* ptr, size_t num) const {
+ for (; num-- > 0; ptr++) {
+ auto word = Word::fromWordIndex(*ptr);
+ auto wordString = word2string(word);
+ auto lemmaIndex = getFactor(word, 0) + groupRanges_[0].first;
+ *ptr = (WordIndex)lemmaIndex;
+ }
+}
+
+// generate a valid random factored word (used by collectStats())
+/*virtual*/ Word FactoredVocab::randWord() const /*override final*/ {
+ auto numGroups = getNumGroups();
+ std::vector<size_t> factorIndices; factorIndices.reserve(numGroups);
+ for (size_t g = 0; g < numGroups; g++) {
+ size_t factorIndex;
+ if (g == 0 || lemmaHasFactorGroup(factorIndices[0], g))
+ factorIndex = rand() % (factorShape_[g] - 1);
+ else
+ factorIndex = FACTOR_NOT_APPLICABLE;
+ factorIndices.push_back(factorIndex);
+ }
+ return factors2word(factorIndices);
+}
+
+// encode a string representation of an entire token sequence, as found in the corpus file, into a 'Word' array
+/*virtual*/ Words FactoredVocab::encode(const std::string& line, bool addEOS /*= true*/, bool /*inference*/ /*= false*/) const /*override final*/ {
+ std::vector<std::string> lineTokens;
+ utils::split(line, lineTokens, " ");
+ Words res; res.reserve(lineTokens.size() + addEOS);
+ for (const auto& tok : lineTokens)
+ res.push_back((*this)[tok]);
+ if (addEOS)
+ res.push_back(getEosId());
+ return res;
+}
+
+// decode a 'Word' array into the external string representation of that token sequence, as written to output files
+/*virtual*/ std::string FactoredVocab::decode(const Words& sentence, bool ignoreEOS /*= true*/) const /*override final*/ {
+ std::vector<std::string> decoded; decoded.reserve(sentence.size());
+ for(auto w : sentence)
+ if((w != getEosId() || !ignoreEOS))
+ decoded.push_back((*this)[w]);
+ return utils::join(decoded, " ");
+}
+
+// diagnostics version of decode() that will not fail on partial words, will print EOS, and is a little slower
+std::string FactoredVocab::decodeForDiagnostics(const Words& sentence) const {
+ std::vector<std::string> decoded; decoded.reserve(sentence.size());
+ for (auto w : sentence)
+ decoded.push_back(word2string(w));
+ return utils::join(decoded, " ");
+}
+
+// helper to unescape \x.. and \u....
+static void unescapeHexEscapes(std::string& utf8Lemma) {
+ if (utf8Lemma.find('\\') == std::string::npos)
+ return; // nothing to do
+ auto lemma = utils::utf8ToUtf16String(utf8Lemma); // \u.... implies we must operate on UTF-16 level (not UCS-4)
+ auto pos = lemma.find('\\');
+ while (pos != std::string::npos) {
+ ABORT_IF(pos + 1 >= lemma.size() || (lemma[pos+1] != 'x' && lemma[pos + 1] != 'u'), "Malformed escape in factored encoding: {}", utf8Lemma);
+ int numDigits = 2 + 2 * (lemma[pos + 1] == 'u'); // 2 for \x, 4 for \u
+ ABORT_IF(pos + 2 + numDigits > lemma.size(), "Malformed escape in factored encoding: {}", utf8Lemma);
+ auto digits = utils::utf8FromUtf16String(lemma.substr(pos + 2, numDigits));
+ auto c = std::strtoul(digits.c_str(), nullptr, 16);
+ lemma[pos] = (char16_t)c;
+ lemma.erase(pos + 1, 1 + numDigits);
+ pos = lemma.find('\\', pos+1);
+ }
+ utf8Lemma = utils::utf8FromUtf16String(lemma);
+}
+
+// convert a 'Word' sequence to its final human-readable surface form
+// This interprets the capitalization and glue factors.
+// This assumes a specific notation of factors, emulating our C# code for generating these factors:
+// - | as separator symbol
+// - capitalization factors are cn, ci, and ca
+// - glue factors are gl+, gr+, wbn, wen, cbn, cen
+std::string FactoredVocab::surfaceForm(const Words& sentence) const /*override final*/ {
+ std::string res;
+ res.reserve(sentence.size() * 10);
+ bool prevHadGlueRight = true; // no space at sentence start
+ for(auto w : sentence) {
+ if (w == getEosId())
+ break;
+ auto token = (*this)[w];
+ auto tokens = utils::split(token, "|");
+ //std::cerr << token << " ";
+ auto lemma = tokens[0];
+ std::set<std::string> tokenSet(tokens.begin() + 1, tokens.end());
+ auto has = [&](const char* factor) { return tokenSet.find(factor) != tokenSet.end(); };
+ // spacing
+ bool hasGlueRight = has("gr+") || has("wen") || has("cen");
+ bool hasGlueLeft = has("gl+") || has("wbn") || has("cbn") || has("wi");
+ bool insertSpaceBefore = !prevHadGlueRight && !hasGlueLeft;
+ if (insertSpaceBefore)
+ res.push_back(' ');
+ prevHadGlueRight = hasGlueRight;
+ // capitalization
+ unescapeHexEscapes(lemma); // unescape \x.. and \u....
+ if (utils::beginsWith(lemma, "\xE2\x96\x81")) // remove leading _ (\u2581, for DistinguishInitialAndInternalPieces mode)
+ lemma = lemma.substr(3);
+ if (has("ci")) lemma = utils::utf8Capitalized(lemma);
+ else if (has("ca")) lemma = utils::utf8ToUpper (lemma);
+ else if (has("cn")) lemma = utils::utf8ToLower (lemma);
+ else if (has("scu")) lemma = utils::utf8ToUpper (lemma);
+ else if (has("scl")) lemma = utils::utf8ToLower (lemma);
+ res.append(lemma);
+ }
+ //std::cerr << "\n" << res << "\n";
+ return res;
+}
+
+// create a CSR matrix M[V,U] from words[] with M[v,u] = 1 if factor u is a factor of word v
+// This is used to form the embedding of a multi-factor token.
+// That embedding is a sum of the embeddings of the individual factors.
+// Those individual embeddings are assumed to be concatenated into one joint large embedding matrix.
+// The factor embeddings are summed up by multiplying the joint embedding matrix with a sparse matrix
+// that contains a 1 for all positions in the joint matrix that should be summed up.
+// This function creates that sparse matrix in CSR form.
+FactoredVocab::CSRData FactoredVocab::csr_rows(const Words& words) const {
+ auto numGroups = getNumGroups();
+ std::vector<float> weights;
+ std::vector<IndexType> indices;
+ std::vector<IndexType> offsets;
+ offsets.reserve(words.size() + 1);
+ indices.reserve(words.size()); // (at least this many)
+ // loop over all input words, and select the corresponding set of unit indices into CSR format
+ offsets.push_back((IndexType)indices.size());
+ std::vector<size_t> factorIndices;
+ for (auto word : words) {
+ if (vocab_.contains(word.toWordIndex())) { // skip invalid combinations in the space (can only happen during initialization) --@TODO: add a check?
+ word2factors(word, factorIndices);
+ for (size_t g = 0; g < numGroups; g++) { // @TODO: make this faster by having a list of all factors to consider for a lemma?
+ auto factorIndex = factorIndices[g];
+ ABORT_IF(factorIndex == FACTOR_NOT_SPECIFIED, "Attempted to embed a word with a factor not specified");
+ if (factorIndex == FACTOR_NOT_APPLICABLE)
+ continue;
+ indices.push_back((IndexType)(factorIndex + groupRanges_[g].first)); // map to unit index
+ weights.push_back(1.0f);
+ }
+ }
+ offsets.push_back((IndexType)indices.size()); // next matrix row begins at this offset
+ }
+ return { Shape({(int)words.size(), (int)factorVocabSize()}), weights, indices, offsets };
+}
+
+// Helper to construct and load a FactordVocab from a path is given (non-empty) and if it specifies a factored vocab.
+// This is used by the Embedding and Output layers.
+/*static*/ Ptr<FactoredVocab> FactoredVocab::tryCreateAndLoad(const std::string& path) {
+ Ptr<FactoredVocab> res;
+ if (!path.empty()) {
+ res = std::static_pointer_cast<FactoredVocab>(createFactoredVocab(path)); // this checks the file extension
+ if (res)
+ res->load(path); // or throw
+ }
+ return res;
+}
+
+// WordLUT
+WordIndex FactoredVocab::WordLUT::add(const std::string& word, WordIndex index) {
+ ABORT_IF(word.empty(), "Attempted to add the empty word to a dictionary");
+ auto wasInserted = str2index_.insert(std::make_pair(word, index)).second;
+ ABORT_IF(!wasInserted, "Duplicate vocab entry for '{}', new index {} vs. existing index {}", word, index, str2index_[word]);
+ wasInserted = index2str_.insert(std::make_pair(index, word)).second;
+ ABORT_IF(!wasInserted, "Duplicate vocab entry for index {} (new: '{}'; existing: '{}')", index, word, index2str_[index]);
+ return index;
+}
+
+static const std::string g_emptyString;
+const std::string& FactoredVocab::WordLUT::operator[](WordIndex index) const {
+ auto iter = index2str_.find(index);
+ if (iter == index2str_.end())
+ // returns an empty string for unknown index values
+ // @TODO: is that ever used ? If so, document.If not, remove this feature and let it fail.static const std::string g_emptyString;
+ return g_emptyString; // (using a global since we return a reference)
+ else
+ return iter->second;
+}
+WordIndex FactoredVocab::WordLUT::operator[](const std::string& word) const {
+ auto iter = str2index_.find(word);
+ ABORT_IF(iter == str2index_.end(), "Token '{}' not found in vocabulary", word);
+ return iter->second;
+}
+bool FactoredVocab::WordLUT::tryFind(const std::string& word, WordIndex& index) const {
+ auto iter = str2index_.find(word);
+ if (iter == str2index_.end())
+ return false;
+ index = iter->second;
+ return true;
+}
+size_t FactoredVocab::WordLUT::load(const std::string& path) {
+ std::string line;
+ io::InputFileStream in(path);
+ for (WordIndex v = 0; io::getline(in, line); v++)
+ add(line, v);
+ return size();
+}
+
+void FactoredVocab::WordLUT::dumpToFile(const std::string& path) {
+ io::OutputFileStream out(path);
+ for (auto kvp : index2str_)
+ out << kvp.second << "\t" << utils::withCommas(kvp.first) << "\n";
+}
+
+const static std::vector<std::string> exts{ ".fsv", ".fm"/*legacy*/ }; // @TODO: delete the legacy one
+
+// Note: This does not actually load it, only checks the path for the type.
+// Since loading takes a while, we cache instances.
+Ptr<IVocab> createFactoredVocab(const std::string& vocabPath) {
+ // this can be multi-threaded, so must run under lock
+ static std::mutex s_mtx;
+ std::lock_guard<std::mutex> criticalSection(s_mtx);
+
+ bool isFactoredVocab = std::any_of(exts.begin(), exts.end(), [&](const std::string& ext) { return utils::endsWith(vocabPath, ext); });
+ if (isFactoredVocab) {
+ static std::map<std::string, Ptr<IVocab>> s_cache;
+ auto iter = s_cache.find(vocabPath);
+ if (iter != s_cache.end()) {
+ LOG_ONCE(info, "[vocab] Reusing existing vocabulary object in memory (vocab size {})", iter->second->size());
+ return iter->second;
+ }
+ auto vocab = New<FactoredVocab>();
+ s_cache.insert(std::make_pair(vocabPath, vocab));
+ return vocab;
+ }
+ else
+ return nullptr;
+}
+/*virtual*/ const std::vector<std::string>& FactoredVocab::suffixes() const /*override final*/ {
+ return exts;
+}
+
+} // namespace marian
diff --git a/src/data/factored_vocab.h b/src/data/factored_vocab.h
new file mode 100755
index 00000000..215e92f0
--- /dev/null
+++ b/src/data/factored_vocab.h
@@ -0,0 +1,137 @@
+// Implementation of an IVocab that represents a factored representation.
+// This is accessed via the IVocab interface for the base vocab functionality,
+// and via dynamic_cast to FactoredVocab for factored-specific things used by
+// the Embedding and Output layers.
+
+#pragma once
+
+#include "common/definitions.h"
+#include "data/types.h"
+#include "data/vocab_base.h"
+
+#undef FACTOR_FULL_EXPANSION // define this to get full expansion. @TODO: infeasible for many factors; just delete this
+
+namespace marian {
+
+class FactoredVocab : public IVocab {
+public:
+ struct CSRData {
+ Shape shape;
+ std::vector<float> weights;
+ std::vector<IndexType> indices;
+ std::vector<IndexType> offsets;
+ };
+
+ // from IVocab:
+ virtual size_t load(const std::string& factoredVocabPath, size_t maxSizeUnused = 0) override final;
+ virtual void create(const std::string& vocabPath, const std::vector<std::string>& trainPaths, size_t maxSize) override final { vocabPath, trainPaths, maxSize; ABORT("Factored vocab cannot be created on the fly"); }
+ virtual const std::string& canonicalExtension() const override final { return suffixes()[0]; }
+ virtual const std::vector<std::string>& suffixes() const override final;
+ virtual Word operator[](const std::string& word) const override final;
+ virtual Words encode(const std::string& line, bool addEOS = true, bool inference = false) const override final;
+ virtual std::string decode(const Words& sentence, bool ignoreEos = true) const override final;
+ virtual std::string surfaceForm(const Words& sentence) const override final;
+ virtual const std::string& operator[](Word id) const override final;
+ virtual size_t size() const override final { return vocab_.size(); } // active factored vocabulary size (counting all valid combinations but not gaps)
+ virtual std::string type() const override final { return "FactoredVocab"; }
+ virtual Word getEosId() const override final { return eosId_; }
+ virtual Word getUnkId() const override final { return unkId_; }
+ virtual std::string toUpper(const std::string& line) const override final;
+ virtual std::string toEnglishTitleCase(const std::string& line) const override final;
+ virtual void transcodeToShortlistInPlace(WordIndex* ptr, size_t num) const override final;
+ WordIndex getUnkIndex() const { return (WordIndex)getFactor(getUnkId(), 0); } // used in decoding
+ virtual void createFake() override final { ABORT("[data] Fake FactoredVocab vocabulary not supported"); }
+ virtual Word randWord() const override final;
+
+ // factor-specific. These methods are consumed by Output and Embedding.
+ size_t factorVocabSize() const { return factorVocab_.size(); } // total number of factors across all types
+ size_t virtualVocabSize() const { return factorShape_.elements<size_t>(); } // valid WordIndex range (representing all factor combinations including gaps); virtual and huge
+
+ CSRData csr_rows(const Words& words) const; // sparse matrix for summing up factors from the concatenated embedding matrix for each word
+
+#ifdef FACTOR_FULL_EXPANSION
+ const CSRData& getGlobalFactorMatrix() const { return globalFactorMatrix_; } // [v,u] (sparse) -> =1 if u is factor of v --only used in getLogits()
+#endif
+ size_t getNumGroups() const { return groupRanges_.size(); }
+ std::pair<size_t, size_t> getGroupRange(size_t g) const { return groupRanges_[g]; } // [g] -> (u_begin, u_end)
+#ifdef FACTOR_FULL_EXPANSION
+ const std::vector<float>& getGapLogMask() const { return gapLogMask_; } // [v] -inf if v is a gap entry, else 0
+#endif
+
+ // convert representations
+ Word factors2word(const std::vector<size_t>& factors) const;
+ void word2factors(Word word, std::vector<size_t>& factors) const;
+ Word lemma2Word(size_t factor0Index) const;
+ Word expandFactoredWord(Word word, size_t groupIndex, size_t factorIndex) const;
+ bool canExpandFactoredWord(Word word, size_t groupIndex) const { return lemmaHasFactorGroup(getFactor(word, 0), groupIndex); }
+ size_t getFactor(Word word, size_t groupIndex) const;
+ bool lemmaHasFactorGroup(size_t factor0Index, size_t g) const { return lemmaHasFactorGroup_[factor0Index][g]; }
+ const std::string& getFactorGroupPrefix(size_t groupIndex) const { return groupPrefixes_[groupIndex]; } // for diagnostics only
+ const std::string& getFactorName(size_t groupIndex, size_t factorIndex) const { return factorVocab_[(WordIndex)(factorIndex + groupRanges_[groupIndex].first)]; }
+ std::string decodeForDiagnostics(const Words& sentence) const;
+
+ static constexpr size_t FACTOR_NOT_APPLICABLE = (SIZE_MAX - 1);
+ static constexpr size_t FACTOR_NOT_SPECIFIED = (SIZE_MAX - 2);
+ static bool isFactorValid(size_t factorIndex) { return factorIndex < FACTOR_NOT_SPECIFIED; }
+
+ static Ptr<FactoredVocab> tryCreateAndLoad(const std::string& path); // load from "vocab" option if it specifies a factored vocab
+ std::string word2string(Word word) const;
+ Word string2word(const std::string& w) const;
+ bool tryGetFactor(const std::string& factorGroupName, size_t& groupIndex, size_t& factorIndex) const; // note: factorGroupName given without separator
+
+ // some hard-coded constants from FactoredSegmenter
+ // The naming mimics the names in FactoredSegmenter.cs, and therefore intentionally does not follow Marian conventions.
+ // @TODO: We have more hard-coded constants throughout the code. Move them all here.
+ // @TODO: figure out how to do this with static const*/constexpr
+#define FactoredVocab_INLINE_FIX_WHAT_serialized "is"
+#define FactoredVocab_FIX_SRC_ID_TAG "<IOPEN>"
+#define FactoredVocab_FIX_TGT_ID_TAG "<IDELIM>"
+#define FactoredVocab_FIX_END_ID_TAG "<ICLOSE>"
+
+private:
+ void constructGroupInfoFromFactorVocab();
+ void constructFactorIndexConversion();
+ void rCompleteVocab(std::vector<size_t>& factorIndices, size_t g);
+#ifdef FACTOR_FULL_EXPANSION
+ void constructNormalizationInfoForVocab();
+#endif
+ size_t factorUnit2FactorIndex(WordIndex u) const;
+private:
+ // @TODO: Should we move WordLUT to utils?
+ class WordLUT { // map between strings and WordIndex
+ std::map<std::string, WordIndex> str2index_;
+ std::map<WordIndex, std::string> index2str_;
+ public:
+ WordIndex add(const std::string& word, WordIndex index);
+ const std::string& operator[](WordIndex index) const;
+ WordIndex operator[](const std::string& word) const;
+ bool contains(WordIndex index) const { return index2str_.find(index) != index2str_.end(); }
+ bool tryFind(const std::string& word, WordIndex& index) const;
+ size_t size() const { return str2index_.size(); }
+ size_t load(const std::string& path);
+ void dumpToFile(const std::string& path);
+ };
+
+ // main vocab
+ Word eosId_{};
+ Word unkId_{};
+ WordLUT vocab_;
+
+ // factors
+ char factorSeparator_ = '|'; // separator symbol for parsing factored words
+ WordLUT factorVocab_; // [factor name] -> factor index = row of E_
+ std::vector<std::string> groupPrefixes_; // [group id g] shared prefix of factors (used for grouping)
+#ifdef FACTOR_FULL_EXPANSION
+ CSRData globalFactorMatrix_; // [v,u] (sparse) -> =1 if u is factor of v
+#endif
+ std::vector<size_t> factorGroups_; // [u] -> group id of factor u
+ std::vector<std::pair<size_t, size_t>> groupRanges_; // [group id g] -> (u_begin,u_end) index range of factors u for this group. These don't overlap.
+ std::vector<std::vector<bool>> lemmaHasFactorGroup_; // [factor 0 index][g] -> true if lemma has factor group
+ Shape factorShape_; // [g] number of factors in each factor group
+ std::vector<size_t> factorStrides_; // [g] stride for factor dimension
+#ifdef FACTOR_FULL_EXPANSION
+ std::vector<float> gapLogMask_; // [v] -1e8 if this is a gap, else 0
+#endif
+};
+
+} // namespace marian
diff --git a/src/data/rng_engine.h b/src/data/rng_engine.h
index 2efca136..2efca136 100755..100644
--- a/src/data/rng_engine.h
+++ b/src/data/rng_engine.h
diff --git a/src/data/sentencepiece_vocab.cpp b/src/data/sentencepiece_vocab.cpp
index 4457dc06..c168f6e3 100755..100644
--- a/src/data/sentencepiece_vocab.cpp
+++ b/src/data/sentencepiece_vocab.cpp
@@ -19,7 +19,7 @@ namespace marian {
#ifdef USE_SENTENCEPIECE
// Wrapper around https://github.com/google/sentencepiece
-class SentencePieceVocab : public VocabBase {
+class SentencePieceVocab : public IVocab {
private:
// Actual SentencePiece processor object
UPtr<sentencepiece::SentencePieceProcessor> spm_;
@@ -36,17 +36,18 @@ private:
std::mt19937 generator_;
std::uniform_int_distribution<int> randInt_; // from 0 to INT_MAX
+ // Keeps sentences segmented into subword units
+ bool keepEncoded_{false};
+
// Sample from one file, based on first algorithm from:
// https://en.wikipedia.org/wiki/Reservoir_sampling
void reservoirSampling(std::vector<std::string>& sample, size_t& seenLines,
const std::string& trainPath, size_t maxLines, size_t maxBytes) {
-
ABORT_IF(maxLines == 0, "Sample needs to be larger 0");
- std::unique_ptr<io::InputFileStream> trainStrm(
- trainPath == "stdin" ? new io::InputFileStream(std::cin)
- : new io::InputFileStream(trainPath)
- );
+ std::unique_ptr<std::istream> trainStrm(trainPath == "stdin"
+ ? new std::istream(std::cin.rdbuf())
+ : new io::InputFileStream(trainPath));
std::string line;
while(getline(*trainStrm, line)) {
@@ -78,9 +79,8 @@ private:
reservoirSampling(sample, seenLines, trainPath, maxLines, maxBytes);
std::shuffle(sample.begin(), sample.end(), generator_);
- io::OutputFileStream out(temp);
for(const auto& line : sample)
- out << line << std::endl;
+ temp << line << std::endl;
LOG(info, "[SentencePiece] Selected {} lines", sample.size());
return sample.size();
@@ -94,12 +94,11 @@ private:
size_t seenLines = 0;
std::string line;
- io::OutputFileStream out(temp);
for(const auto& trainPath : trainPaths) {
io::InputFileStream in(trainPath);
while(getline(in, line)) {
if(line.size() > 0 && line.size() < maxBytes) {
- out << line << std::endl;
+ temp << line << std::endl;
seenLines++;
}
}
@@ -111,8 +110,10 @@ private:
public:
SentencePieceVocab(Ptr<Options> options, size_t batchIndex)
- : options_(options), batchIndex_(batchIndex), generator_((uint32_t)Config::seed) {
-
+ : options_(options),
+ batchIndex_(batchIndex),
+ generator_((uint32_t)Config::seed),
+ keepEncoded_(options->get<bool>("no-spm-decode", false)) {
if(options_->has("sentencepiece-alphas")) {
auto alphas = options_->get<std::vector<float>>("sentencepiece-alphas");
if(alphas.size() <= batchIndex)
@@ -126,7 +127,6 @@ public:
alpha_,
batchIndex_);
}
-
}
virtual const std::string& canonicalExtension() const override { return suffixes_[0]; }
@@ -136,8 +136,8 @@ public:
virtual std::string type() const override { return "SentencePieceVocab"; }
- virtual Word getEosId() const override { return (Word)spm_->eos_id(); }
- virtual Word getUnkId() const override { return (Word)spm_->unk_id(); }
+ virtual Word getEosId() const override { return Word::fromWordIndex(spm_->eos_id()); }
+ virtual Word getUnkId() const override { return Word::fromWordIndex(spm_->unk_id()); }
void create(const std::string& vocabPath,
const std::vector<std::string>& trainPaths,
@@ -198,12 +198,12 @@ public:
}
Word operator[](const std::string& token) const override {
- return (Word)spm_->PieceToId(token);
+ return Word::fromWordIndex(spm_->PieceToId(token));
}
const std::string& operator[](Word id) const override {
- ABORT_IF(id >= size(), "Unknown word id: ", id);
- return spm_->IdToPiece(id);
+ ABORT_IF(id.toWordIndex() >= size(), "Unknown word id: ", id.toWordIndex());
+ return spm_->IdToPiece(id.toWordIndex());
}
Words encode(const std::string& line, bool addEOS, bool inference) const override {
@@ -213,7 +213,9 @@ public:
else
spm_->SampleEncode(line, -1, alpha_, &spmIds);
- Words words(spmIds.begin(), spmIds.end());
+ Words words; words.reserve(spmIds.size() + addEOS);
+ for (auto&& spmId : spmIds)
+ words.push_back(Word::fromWordIndex(spmId));
if(addEOS)
words.push_back(getEosId());
@@ -222,12 +224,26 @@ public:
std::string decode(const Words& sentence, bool /*ignoreEOS*/) const override {
std::string line;
- // convert vector of Word to vector of int
- std::vector<int> spmSentence(sentence.begin(), sentence.end());
- spm_->Decode(spmSentence, &line);
+ if(keepEncoded_) { // i.e. keep the sentence segmented into subword units
+ for(const Word& id : sentence)
+ line += (*this)[id] + " ";
+ line.pop_back(); // trim the trailing whitespace
+ } else {
+ // convert vector of Word to vector of int
+ std::vector<int> spmSentence;
+ spmSentence.reserve(sentence.size());
+ for(auto&& word : sentence)
+ spmSentence.push_back(word.toWordIndex());
+ spm_->Decode(spmSentence, &line);
+ }
return line;
}
+ std::string surfaceForm(const Words& sentence) const override {
+ // with SentencePiece, decoded form and surface form are identical
+ return decode(sentence, /*ignoreEOS=*/true);
+ }
+
size_t size() const override {
return spm_->GetPieceSize();
}
@@ -236,7 +252,7 @@ public:
LOG(info, "[data] Loading SentencePiece vocabulary from file {}", vocabPath);
ABORT_IF(!filesystem::exists(vocabPath),
- "SentencePiece vocabulary file {} does not exits",
+ "SentencePiece vocabulary file {} does not exist",
vocabPath);
spm_.reset(new sentencepiece::SentencePieceProcessor());
@@ -249,10 +265,12 @@ public:
return spm_->GetPieceSize();
}
+ std::string toUpper(const std::string& line) const override { return utils::utf8ToUpper(line); }
+ std::string toEnglishTitleCase(const std::string& line) const override { return utils::toEnglishTitleCase(line); }
};
#endif // USE_SENTENCEPIECE
-Ptr<VocabBase> createSentencePieceVocab(const std::string& vocabPath, Ptr<Options> options, size_t batchIndex) {
+Ptr<IVocab> createSentencePieceVocab(const std::string& vocabPath, Ptr<Options> options, size_t batchIndex) {
bool isSentencePiece = regex::regex_search(vocabPath, regex::regex("\\.(spm)$"));
if(isSentencePiece) {
#ifdef USE_SENTENCEPIECE
diff --git a/src/data/shortlist.h b/src/data/shortlist.h
index 55fe4f5b..177272df 100755..100644
--- a/src/data/shortlist.h
+++ b/src/data/shortlist.h
@@ -8,40 +8,48 @@
#include <unordered_map>
#include <vector>
#include <iostream>
+#include <algorithm>
namespace marian {
namespace data {
class Shortlist {
private:
- std::vector<Word> indices_;
- std::vector<Word> mappedIndices_;
- std::vector<Word> reverseMap_;
+ std::vector<WordIndex> indices_; // // [packed shortlist index] -> word index, used to select columns from output embeddings
public:
- Shortlist(const std::vector<Word>& indices,
- const std::vector<Word>& mappedIndices,
- const std::vector<Word>& reverseMap)
- : indices_(indices),
- mappedIndices_(mappedIndices),
- reverseMap_(reverseMap) {}
-
- std::vector<Word>& indices() { return indices_; }
- std::vector<Word>& mappedIndices() { return mappedIndices_; }
- Word reverseMap(Word idx) { return reverseMap_[idx]; }
+ Shortlist(const std::vector<WordIndex>& indices)
+ : indices_(indices) {}
+
+ const std::vector<WordIndex>& indices() const { return indices_; }
+ WordIndex reverseMap(int idx) { return indices_[idx]; }
+
+ int tryForwardMap(WordIndex wIdx) {
+ auto first = std::lower_bound(indices_.begin(), indices_.end(), wIdx);
+ if(first != indices_.end() && *first == wIdx) // check if element not less than wIdx has been found and if equal to wIdx
+ return (int)std::distance(indices_.begin(), first); // return coordinate if found
+ else
+ return -1; // return -1 if not found
+ }
+
};
class ShortlistGenerator {
public:
- virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) = 0;
+ virtual ~ShortlistGenerator() {}
+
+ virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const = 0;
// Writes text version of (possibly) pruned short list to file
// with given prefix and implementation-specific suffixes.
- virtual void dump(const std::string& /*prefix*/) {
+ virtual void dump(const std::string& /*prefix*/) const {
ABORT("Not implemented");
}
};
+
+// Intended for use during training in the future, currently disabled
+#if 0
class SampledShortlistGenerator : public ShortlistGenerator {
private:
Ptr<Options> options_;
@@ -54,8 +62,8 @@ private:
size_t trgIdx_;
bool shared_{false};
- std::random_device rd_;
- std::mt19937 gen_;
+ // static thread_local std::random_device rd_;
+ static thread_local std::unique_ptr<std::mt19937> gen_;
public:
SampledShortlistGenerator(Ptr<Options> options,
@@ -65,68 +73,70 @@ public:
: options_(options),
srcIdx_(srcIdx),
trgIdx_(trgIdx),
- shared_(shared),
- gen_(rd_()) {}
+ shared_(shared)
+ { }
- virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) override {
+ virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override {
auto srcBatch = (*batch)[srcIdx_];
auto trgBatch = (*batch)[trgIdx_];
// add firstNum most frequent words
- std::unordered_set<Word> idxSet;
- for(Word i = 0; i < firstNum_ && i < maxVocab_; ++i)
- idxSet.insert(i);
+ std::unordered_set<WordIndex> indexSet;
+ for(WordIndex i = 0; i < firstNum_ && i < maxVocab_; ++i)
+ indexSet.insert(i);
// add all words from ground truth
for(auto i : trgBatch->data())
- idxSet.insert(i);
+ indexSet.insert(i.toWordIndex());
// add all words from source
if(shared_)
for(auto i : srcBatch->data())
- idxSet.insert(i);
+ indexSet.insert(i.toWordIndex());
std::uniform_int_distribution<> dis((int)firstNum_, (int)maxVocab_);
- while(idxSet.size() < total_ && idxSet.size() < maxVocab_)
- idxSet.insert(dis(gen_));
+ if (gen_ == NULL)
+ gen_.reset(new std::mt19937(std::random_device{}()));
+ while(indexSet.size() < total_ && indexSet.size() < maxVocab_)
+ indexSet.insert(dis(*gen_));
// turn into vector and sort (selected indices)
- std::vector<Word> idx(idxSet.begin(), idxSet.end());
+ std::vector<WordIndex> idx(indexSet.begin(), indexSet.end());
std::sort(idx.begin(), idx.end());
// assign new shifted position
- std::unordered_map<Word, Word> pos;
- std::vector<Word> reverseMap;
+ std::unordered_map<WordIndex, WordIndex> pos;
+ std::vector<WordIndex> reverseMap;
- for(Word i = 0; i < idx.size(); ++i) {
+ for(WordIndex i = 0; i < idx.size(); ++i) {
pos[idx[i]] = i;
reverseMap.push_back(idx[i]);
}
- std::vector<Word> mapped;
+ Words mapped;
for(auto i : trgBatch->data()) {
// mapped postions for cross-entropy
- mapped.push_back(pos[i]);
+ mapped.push_back(Word::fromWordIndex(pos[i.toWordIndex()]));
}
return New<Shortlist>(idx, mapped, reverseMap);
}
};
+#endif
class LexicalShortlistGenerator : public ShortlistGenerator {
private:
Ptr<Options> options_;
- Ptr<Vocab> srcVocab_;
- Ptr<Vocab> trgVocab_;
+ Ptr<const Vocab> srcVocab_;
+ Ptr<const Vocab> trgVocab_;
size_t srcIdx_;
- size_t trgIdx_;
bool shared_{false};
size_t firstNum_{100};
size_t bestNum_{100};
- std::vector<std::unordered_map<Word, float>> data_;
+ std::vector<std::unordered_map<WordIndex, float>> data_; // [WordIndex src] -> [WordIndex tgt] -> P_trans(tgt|src) --@TODO: rename data_ accordingly
void load(const std::string& fname) {
io::InputFileStream in(fname);
@@ -138,8 +148,8 @@ private:
if(src == "NULL" || trg == "NULL")
continue;
- Word sId = (*srcVocab_)[src];
- Word tId = (*trgVocab_)[trg];
+ auto sId = (*srcVocab_)[src].toWordIndex();
+ auto tId = (*trgVocab_)[trg].toWordIndex();
if(data_.size() <= sId)
data_.resize(sId + 1);
@@ -150,12 +160,12 @@ private:
void prune(float threshold = 0.f) {
size_t i = 0;
for(auto& probs : data_) {
- std::vector<std::pair<float, Word>> sorter;
+ std::vector<std::pair<float, WordIndex>> sorter;
for(auto& it : probs)
- sorter.emplace_back(it.second, (Word)it.first);
+ sorter.emplace_back(it.second, it.first);
std::sort(
- sorter.begin(), sorter.end(), std::greater<std::pair<float, Word>>());
+ sorter.begin(), sorter.end(), std::greater<std::pair<float, WordIndex>>()); // sort by prob
probs.clear();
for(auto& it : sorter) {
@@ -171,19 +181,17 @@ private:
public:
LexicalShortlistGenerator(Ptr<Options> options,
- Ptr<Vocab> srcVocab,
- Ptr<Vocab> trgVocab,
+ Ptr<const Vocab> srcVocab,
+ Ptr<const Vocab> trgVocab,
size_t srcIdx = 0,
- size_t trgIdx = 1,
+ size_t /*trgIdx*/ = 1,
bool shared = false)
: options_(options),
srcVocab_(srcVocab),
trgVocab_(trgVocab),
srcIdx_(srcIdx),
- trgIdx_(trgIdx),
shared_(shared) {
- std::vector<std::string> vals
- = options_->get<std::vector<std::string>>("shortlist");
+ std::vector<std::string> vals = options_->get<std::vector<std::string>>("shortlist");
ABORT_IF(vals.empty(), "No path to filter path given");
std::string fname = vals[0];
@@ -200,6 +208,7 @@ public:
bestNum_,
threshold);
+ // @TODO: Load and prune in one go.
load(fname);
prune(threshold);
@@ -207,90 +216,68 @@ public:
dump(dumpPath);
}
- virtual void dump(const std::string& prefix) override {
+ virtual void dump(const std::string& prefix) const override {
// Dump top most frequent words from target vocabulary
LOG(info, "[data] Saving shortlist dump to {}", prefix + ".{top,dic}");
io::OutputFileStream outTop(prefix + ".top");
- for(Word i = 0; i < firstNum_ && i < trgVocab_->size(); ++i)
- outTop << (*trgVocab_)[i] << std::endl;
+ for(WordIndex i = 0; i < firstNum_ && i < trgVocab_->size(); ++i)
+ outTop << (*trgVocab_)[Word::fromWordIndex(i)] << std::endl;
// Dump translation pairs from dictionary
io::OutputFileStream outDic(prefix + ".dic");
- for(Word srcId = 0; srcId < data_.size(); srcId++) {
- for(auto& it : data_[srcId]) { // @TODO: change data_.first from size_t to Word
- Word trgId = (Word)it.first;
- outDic << (*srcVocab_)[srcId] << "\t" << (*trgVocab_)[trgId] << std::endl;
+ for(WordIndex srcId = 0; srcId < data_.size(); srcId++) {
+ for(auto& it : data_[srcId]) {
+ auto trgId = it.first;
+ outDic << (*srcVocab_)[Word::fromWordIndex(srcId)] << "\t" << (*trgVocab_)[Word::fromWordIndex(trgId)] << std::endl;
}
}
}
- virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) override {
+ virtual Ptr<Shortlist> generate(Ptr<data::CorpusBatch> batch) const override {
auto srcBatch = (*batch)[srcIdx_];
- // auto trgBatch = (*batch)[trgIdx_];
// add firstNum most frequent words
- std::unordered_set<Word> idxSet;
- for(Word i = 0; i < firstNum_ && i < trgVocab_->size(); ++i)
- idxSet.insert(i);
+ std::unordered_set<WordIndex> indexSet;
+ for(WordIndex i = 0; i < firstNum_ && i < trgVocab_->size(); ++i)
+ indexSet.insert(i);
// add all words from ground truth
// for(auto i : trgBatch->data())
- // idxSet.insert(i);
+ // indexSet.insert(i.toWordIndex());
// collect unique words form source
- std::unordered_set<Word> srcSet;
+ std::unordered_set<WordIndex> srcSet;
for(auto i : srcBatch->data())
- srcSet.insert(i);
+ srcSet.insert(i.toWordIndex());
// add aligned target words
for(auto i : srcSet) {
if(shared_)
- idxSet.insert(i);
+ indexSet.insert(i);
for(auto& it : data_[i])
- idxSet.insert((Word)it.first); // @TODO: change it.first to Word
+ indexSet.insert(it.first);
}
// turn into vector and sort (selected indices)
- std::vector<Word> idx(idxSet.begin(), idxSet.end());
- std::sort(idx.begin(), idx.end());
+ std::vector<WordIndex> indices(indexSet.begin(), indexSet.end());
+ std::sort(indices.begin(), indices.end());
- // assign new shifted position
- // std::unordered_map<Word, Word> pos;
- std::vector<Word> reverseMap;
-
- for(Word i = 0; i < idx.size(); ++i) {
- // pos[idx[i]] = i;
- reverseMap.push_back(idx[i]);
- }
-
- std::vector<Word> mapped;
- // for(auto i : trgBatch->data()) {
- // mapped postions for cross-entropy
- // mapped.push_back(pos[i]);
- //}
-
- return New<Shortlist>(idx, mapped, reverseMap);
+ return New<Shortlist>(indices);
}
};
class FakeShortlistGenerator : public ShortlistGenerator {
private:
- std::vector<Word> idx_;
- std::vector<Word> reverseIdx_;
+ std::vector<WordIndex> indices_;
public:
- FakeShortlistGenerator(const std::unordered_set<Word>& idxSet)
- : idx_(idxSet.begin(), idxSet.end()) {
- std::sort(idx_.begin(), idx_.end());
- // assign new shifted position
- for(Word i = 0; i < idx_.size(); ++i) {
- reverseIdx_.push_back(idx_[i]);
- }
+ FakeShortlistGenerator(const std::unordered_set<WordIndex>& indexSet)
+ : indices_(indexSet.begin(), indexSet.end()) {
+ std::sort(indices_.begin(), indices_.end());
}
- Ptr<Shortlist> generate(Ptr<data::CorpusBatch> /*batch*/) override {
- std::vector<Word> tmp;
- return New<Shortlist>(idx_, tmp, reverseIdx_);
+ Ptr<Shortlist> generate(Ptr<data::CorpusBatch> /*batch*/) const override {
+ return New<Shortlist>(indices_);
}
};
diff --git a/src/data/text_input.cpp b/src/data/text_input.cpp
index 78fbc7af..5d68d9fb 100755..100644
--- a/src/data/text_input.cpp
+++ b/src/data/text_input.cpp
@@ -5,9 +5,7 @@ namespace marian {
namespace data {
TextIterator::TextIterator() : pos_(-1), tup_(0) {}
-
-TextIterator::TextIterator(TextInput& corpus)
- : corpus_(&corpus), pos_(0), tup_(corpus_->next()) {}
+TextIterator::TextIterator(TextInput& corpus) : corpus_(&corpus), pos_(0), tup_(corpus_->next()) {}
void TextIterator::increment() {
tup_ = corpus_->next();
@@ -25,40 +23,36 @@ const SentenceTuple& TextIterator::dereference() const {
TextInput::TextInput(std::vector<std::string> inputs,
std::vector<Ptr<Vocab>> vocabs,
Ptr<Options> options)
- // TODO: fix this: input text is stored in an inherited variable named
- // paths_ that is very confusing
- : DatasetBase(inputs, options),
- vocabs_(vocabs) {
+ : DatasetBase(inputs, options), vocabs_(vocabs) {
+ // note: inputs are automatically stored in the inherited variable named paths_, but these are
+ // texts not paths!
for(const auto& text : paths_)
files_.emplace_back(new std::istringstream(text));
}
+// TextInput is mainly used for inference in the server mode, not for training, so skipping too long
+// or ill-formed inputs is not necessary here
SentenceTuple TextInput::next() {
- // @TODO: This code mixes two patterns (while and early exit). Fix that.
- bool cont = true;
- while(cont) {
- // get index of the current sentence
- size_t curId = pos_++;
-
- // fill up the sentence tuple with sentences from all input files
- SentenceTuple tup(curId);
- for(size_t i = 0; i < files_.size(); ++i) {
- std::string line;
- io::InputFileStream dummyStream(*files_[i]);
- if(io::getline(dummyStream, line)) {
- Words words = vocabs_[i]->encode(line, /*addEOS =*/ true, /*inference =*/ inference_);
- if(words.empty())
- words.push_back(0);
- tup.push_back(words);
- }
+ // get index of the current sentence
+ size_t curId = pos_++;
+
+ // fill up the sentence tuple with source and/or target sentences
+ SentenceTuple tup(curId);
+ for(size_t i = 0; i < files_.size(); ++i) {
+ std::string line;
+ if(io::getline(*files_[i], line)) {
+ Words words = vocabs_[i]->encode(line, /*addEOS =*/ true, /*inference =*/ inference_);
+ if(words.empty())
+ words.push_back(Word::ZERO); // @TODO: What is this for? @BUGBUG: addEOS=true, so this can never happen, right?
+ tup.push_back(words);
}
-
- // continue only if each input file has provided an example
- cont = tup.size() == files_.size();
- if(cont)
- return tup;
}
+
+ // check if each input file provided an example
+ if(tup.size() == files_.size())
+ return tup;
return SentenceTuple(0);
}
+
} // namespace data
} // namespace marian
diff --git a/src/data/text_input.h b/src/data/text_input.h
index bf1bc1eb..db99ef6a 100755..100644
--- a/src/data/text_input.h
+++ b/src/data/text_input.h
@@ -36,9 +36,8 @@ private:
public:
typedef SentenceTuple Sample;
- TextInput(std::vector<std::string> inputs,
- std::vector<Ptr<Vocab>> vocabs,
- Ptr<Options> options);
+ TextInput(std::vector<std::string> inputs, std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options);
+ virtual ~TextInput() {}
Sample next() override;
diff --git a/src/data/types.h b/src/data/types.h
index 2bda6ece..05cdf96c 100644
--- a/src/data/types.h
+++ b/src/data/types.h
@@ -7,18 +7,45 @@
#include <string>
#include <unordered_map>
#include <vector>
+#include <iterator>
namespace marian {
// Type for all vocabulary items, based on IndexType
-typedef IndexType Word;
+typedef IndexType WordIndex; // WordIndex is used for words or tokens arranged in consecutive order
+class Word { // Word is an abstraction of a unique id, not necessarily consecutive
+ WordIndex wordId_;
+ explicit Word(std::size_t wordId) : wordId_((WordIndex)wordId) {}
+public:
+ static Word fromWordIndex(std::size_t wordId) { return Word(wordId); }
+ const WordIndex& toWordIndex() const { return wordId_; }
+ std::string toString() const { return std::to_string(wordId_); }
+
+ // needed for STL containers
+ Word() : wordId_((WordIndex)-1) {}
+ bool operator==(const Word& other) const { return wordId_ == other.wordId_; }
+ bool operator!=(const Word& other) const { return !(*this == other); }
+ bool operator<(const Word& other) const { return wordId_ < other.wordId_; }
+ std::size_t hash() const { return std::hash<WordIndex>{}(wordId_); }
+
+ // constants
+ static Word NONE; // @TODO: decide whether we need this, in additional Word()
+ static Word ZERO; // an invalid word that nevertheless can safely be looked up (and then masked out)
+ // EOS and UNK are placed in these positions in Marian-generated vocabs
+ static Word DEFAULT_EOS_ID;
+ static Word DEFAULT_UNK_ID;
+};
// Sequence of vocabulary items
typedef std::vector<Word> Words;
-// EOS and UNK are placed in these positions in Marian-generated vocabs
-const Word DEFAULT_EOS_ID = 0;
-const Word DEFAULT_UNK_ID = 1;
+// Helper to map a Word vector to a WordIndex vector
+static inline std::vector<WordIndex> toWordIndexVector(const Words& words) {
+ std::vector<WordIndex> res;
+ std::transform(words.begin(), words.end(), std::back_inserter(res),
+ [](const Word& word) -> WordIndex { return word.toWordIndex(); });
+ return res;
+}
// names of EOS and UNK symbols
const std::string DEFAULT_EOS_STR = "</s>";
@@ -29,3 +56,9 @@ const std::string NEMATUS_EOS_STR = "eos";
const std::string NEMATUS_UNK_STR = "UNK";
} // namespace marian
+
+namespace std {
+ template<> struct hash<marian::Word> {
+ std::size_t operator()(const marian::Word& s) const noexcept { return s.hash(); }
+ };
+}
diff --git a/src/data/vocab.cpp b/src/data/vocab.cpp
index 62b021a3..b48d5eca 100755
--- a/src/data/vocab.cpp
+++ b/src/data/vocab.cpp
@@ -3,18 +3,31 @@
namespace marian {
-Ptr<VocabBase> createDefaultVocab();
-Ptr<VocabBase> createSentencePieceVocab(const std::string& /*vocabPath*/, Ptr<Options>, size_t /*batchIndex*/);
+Word Word::NONE = Word();
+Word Word::ZERO = Word(0);
+Word Word::DEFAULT_EOS_ID = Word(0);
+Word Word::DEFAULT_UNK_ID = Word(1);
// @TODO: make each vocab peek on type
-Ptr<VocabBase> createVocab(const std::string& vocabPath, Ptr<Options> options, size_t batchIndex) {
+Ptr<IVocab> createVocab(const std::string& vocabPath, Ptr<Options> options, size_t batchIndex) {
+ // try SentencePiece
auto vocab = createSentencePieceVocab(vocabPath, options, batchIndex);
- return vocab ? vocab : createDefaultVocab();
+ if(vocab)
+ return vocab; // this is defined which means that a sentencepiece vocabulary could be created, so return it
+ // try factored
+ vocab = createFactoredVocab(vocabPath);
+ if (vocab)
+ return vocab;
+ // regular vocab
+ // check type of input, if not given, assume "sequence"
+ auto inputTypes = options->get<std::vector<std::string>>("input-types", {});
+ std::string inputType = inputTypes.size() > batchIndex ? inputTypes[batchIndex] : "sequence";
+ return inputType == "class" ? createClassVocab() : createDefaultVocab();
}
size_t Vocab::loadOrCreate(const std::string& vocabPath,
- const std::vector<std::string>& trainPaths,
- size_t maxSize) {
+ const std::vector<std::string>& trainPaths,
+ size_t maxSize) {
size_t size = 0;
if(vocabPath.empty()) {
// No vocabulary path was given, attempt to first find a vocabulary
@@ -78,6 +91,10 @@ void Vocab::createFake() {
vImpl_->createFake();
}
+Word Vocab::randWord() {
+ return vImpl_->randWord();
+}
+
// string token to token id
Word Vocab::operator[](const std::string& word) const {
return vImpl_->operator[](word);
@@ -95,12 +112,19 @@ Words Vocab::encode(const std::string& line,
return vImpl_->encode(line, addEOS, inference);
}
-// list of token ids to single line, can perform detokenization
+// convert sequence of token ids to single line, can perform detokenization
std::string Vocab::decode(const Words& sentence,
bool ignoreEOS) const {
return vImpl_->decode(sentence, ignoreEOS);
}
+// convert sequence of token its to surface form (incl. removng spaces, applying factors)
+// for in-process BLEU validation
+std::string Vocab::surfaceForm(const Words& sentence) const {
+ return vImpl_->surfaceForm(sentence);
+}
+
+
// number of vocabulary items
size_t Vocab::size() const { return vImpl_->size(); }
@@ -113,4 +137,13 @@ Word Vocab::getEosId() const { return vImpl_->getEosId(); }
// return UNK symbol id
Word Vocab::getUnkId() const { return vImpl_->getUnkId(); }
+// for corpus augmentation: convert string to all-caps
+std::string Vocab::toUpper(const std::string& line) const { return vImpl_->toUpper(line); }
+
+// for corpus augmentation: convert string to title case
+std::string Vocab::toEnglishTitleCase(const std::string& line) const { return vImpl_->toEnglishTitleCase(line); }
+
+// for short-list generation
+void Vocab::transcodeToShortlistInPlace(WordIndex* ptr, size_t num) const { vImpl_->transcodeToShortlistInPlace(ptr, num); }
+
} // namespace marian
diff --git a/src/data/vocab.h b/src/data/vocab.h
index af4ea71f..9a40ba16 100755
--- a/src/data/vocab.h
+++ b/src/data/vocab.h
@@ -7,7 +7,7 @@
namespace marian {
-class VocabBase;
+class IVocab;
// Wrapper around vocabulary types. Can choose underlying
// vocabulary implementation (vImpl_) based on speficied path
@@ -17,7 +17,7 @@ class VocabBase;
// * SentencePiece with suffix *.spm (works, but has to be created outside Marian)
class Vocab {
private:
- Ptr<VocabBase> vImpl_;
+ Ptr<IVocab> vImpl_;
Ptr<Options> options_;
size_t batchIndex_;
@@ -42,18 +42,22 @@ public:
// string token to token id
Word operator[](const std::string& word) const;
- // token id to string token
- const std::string& operator[](Word id) const;
+ // token index to string token
+ const std::string& operator[](Word word) const;
// line of text to list of token ids, can perform tokenization
Words encode(const std::string& line,
bool addEOS = true,
bool inference = false) const;
- // list of token ids to single line, can perform detokenization
+ // convert sequence of token ids to single line, can perform detokenization
std::string decode(const Words& sentence,
bool ignoreEOS = true) const;
+ // convert sequence of token its to surface form (incl. removng spaces, applying factors)
+ // for in-process BLEU validation
+ std::string surfaceForm(const Words& sentence) const;
+
// number of vocabulary items
size_t size() const;
@@ -66,8 +70,26 @@ public:
// return UNK symbol id
Word getUnkId() const;
+ // for corpus augmentation: convert string to all-caps
+ // @TODO: Consider a different implementation where this does not show on the vocab interface,
+ // but instead as additional options passed to vocab instantiation.
+ std::string toUpper(const std::string& line) const;
+
+ // for corpus augmentation: convert string to title case
+ std::string toEnglishTitleCase(const std::string& line) const;
+
+ // for short-list generation
+ void transcodeToShortlistInPlace(WordIndex* ptr, size_t num) const;
+
// create fake vocabulary for collecting batch statistics
void createFake();
+
+ // generate a fake word (using rand())
+ Word randWord();
+
+ // give access to base implementation. Returns null if not the requested type.
+ template<class VocabType> // e.g. FactoredVocab
+ Ptr<VocabType> tryAs() const { return std::dynamic_pointer_cast<VocabType>(vImpl_); }
};
} // namespace marian
diff --git a/src/data/vocab_base.h b/src/data/vocab_base.h
index 2050792f..8c214c97 100644..100755
--- a/src/data/vocab_base.h
+++ b/src/data/vocab_base.h
@@ -7,7 +7,7 @@
namespace marian {
-class VocabBase {
+class IVocab {
public:
virtual size_t load(const std::string& vocabPath, size_t maxSize = 0) = 0;
@@ -19,7 +19,7 @@ public:
virtual const std::string& canonicalExtension() const = 0;
virtual const std::vector<std::string>& suffixes() const = 0;
- size_t findAndLoad(const std::string& path, size_t maxSize) {
+ size_t findAndLoad(const std::string& path, size_t maxSize) { // @TODO: Only used in one place; just inline it there -> true interface
for(auto suffix : suffixes())
if(filesystem::exists(path + suffix))
return load(path + suffix, maxSize);
@@ -34,6 +34,7 @@ public:
virtual std::string decode(const Words& sentence,
bool ignoreEos = true) const = 0;
+ virtual std::string surfaceForm(const Words& sentence) const = 0;
virtual const std::string& operator[](Word id) const = 0;
@@ -43,7 +44,26 @@ public:
virtual Word getEosId() const = 0;
virtual Word getUnkId() const = 0;
+ // without specific knowledge of tokenization, these two functions can do nothing
+ // Both SentencePieceVocab and FactoredSegmenterVocab
+ virtual std::string toUpper(const std::string& line) const { return line; }
+ virtual std::string toEnglishTitleCase(const std::string& line) const { return line; }
+
+ // this function is an identity mapping for default vocabularies, hence do nothing
+ virtual void transcodeToShortlistInPlace(WordIndex* ptr, size_t num) const { ptr; num; }
+
virtual void createFake() = 0;
+
+ virtual Word randWord() const {
+ return Word::fromWordIndex(rand() % size());
+ }
+ virtual ~IVocab() {};
};
-} \ No newline at end of file
+class Options;
+Ptr<IVocab> createDefaultVocab();
+Ptr<IVocab> createClassVocab();
+Ptr<IVocab> createSentencePieceVocab(const std::string& vocabPath, Ptr<Options>, size_t batchIndex);
+Ptr<IVocab> createFactoredVocab(const std::string& vocabPath);
+
+}
diff --git a/src/examples/CMakeLists.txt b/src/examples/CMakeLists.txt
index 8740fa6b..c904a3f3 100644
--- a/src/examples/CMakeLists.txt
+++ b/src/examples/CMakeLists.txt
@@ -4,7 +4,7 @@ add_executable(mnist_example mnist/mnist_ffnn.cpp)
foreach(exec iris_example mnist_example)
target_link_libraries(${exec} marian ${EXT_LIBS})
if(CUDA_FOUND)
- target_link_libraries(${exec} marian marian_cuda ${EXT_LIBS})
+ target_link_libraries(${exec} marian ${EXT_LIBS} marian_cuda ${EXT_LIBS})
endif(CUDA_FOUND)
set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
endforeach(exec)
diff --git a/src/examples/iris/iris.cpp b/src/examples/iris/iris.cpp
index 2d66282d..9878a1ff 100644
--- a/src/examples/iris/iris.cpp
+++ b/src/examples/iris/iris.cpp
@@ -22,16 +22,16 @@ Expr buildIrisClassifier(Ptr<ExpressionGraph> graph,
graph->clear();
// Define the input layer
- auto x = graph->constant({N, NUM_FEATURES}, inits::from_vector(inputData));
+ auto x = graph->constant({N, NUM_FEATURES}, inits::fromVector(inputData));
// Define the hidden layer
auto W1 = graph->param("W1", {NUM_FEATURES, 5}, inits::uniform(-0.1f, 0.1f));
- auto b1 = graph->param("b1", {1, 5}, inits::zeros);
+ auto b1 = graph->param("b1", {1, 5}, inits::zeros());
auto h = tanh(affine(x, W1, b1));
// Define the output layer
auto W2 = graph->param("W2", {5, NUM_LABELS}, inits::uniform(-0.1f, 0.1f));
- auto b2 = graph->param("b2", {1, NUM_LABELS}, inits::zeros);
+ auto b2 = graph->param("b2", {1, NUM_LABELS}, inits::zeros());
auto o = affine(h, W2, b2);
if(train) {
diff --git a/src/examples/mnist/dataset.h b/src/examples/mnist/dataset.h
index 2152f850..b0492b85 100755..100644
--- a/src/examples/mnist/dataset.h
+++ b/src/examples/mnist/dataset.h
@@ -19,7 +19,12 @@ namespace data {
typedef std::vector<float> Data;
typedef std::vector<IndexType> Labels;
-typedef std::vector<Data> Example;
+struct Example : public std::vector<Data> { // a std::vector<Data> with a getId() method
+ size_t id;
+ size_t getId() const { return id; }
+ Example(std::vector<Data>&& data, size_t id) : std::vector<Data>(std::move(data)), id(id) {}
+ Example() : id(SIZE_MAX) {}
+};
typedef std::vector<Example> Examples;
typedef Examples::const_iterator ExampleIterator;
@@ -57,13 +62,14 @@ private:
std::vector<Input> inputs_;
public:
+
std::vector<Input>& inputs() { return inputs_; }
const std::vector<Input>& inputs() const { return inputs_; }
void push_back(Input input) { inputs_.push_back(input); }
- virtual std::vector<Ptr<Batch>> split(size_t /*n*/) override { ABORT("Not implemented"); }
+ virtual std::vector<Ptr<Batch>> split(size_t /*n*/, size_t /*sizeLimit*/) override { ABORT("Not implemented"); }
Data& features() { return inputs_[0].data(); }
@@ -139,6 +145,8 @@ public:
loadData();
}
+ virtual ~MNISTData(){}
+
void loadData() override {
ABORT_IF(paths_.size() != 2, "Paths to MNIST data files are not specified");
@@ -147,12 +155,11 @@ public:
ABORT_IF(features.size() != labels.size(), "Features do not match labels");
for(size_t i = 0; i < features.size(); ++i) {
- Example ex = {features[i], labels[i]};
- examples_.emplace_back(ex);
+ examples_.emplace_back(std::vector<Data>{ features[i], labels[i] }, i);
}
}
- Example next() override { return{ }; } //@TODO: this return was added to fix a warning. Is it correct?
+ Example next() override { return Example(); } //@TODO: this return was added to fix a warning. Is it correct?
private:
typedef unsigned char uchar;
diff --git a/src/examples/mnist/download.sh b/src/examples/mnist/download.sh
index 54436c75..54436c75 100755..100644
--- a/src/examples/mnist/download.sh
+++ b/src/examples/mnist/download.sh
diff --git a/src/examples/mnist/model.h b/src/examples/mnist/model.h
index 71de15ec..f7af1681 100755
--- a/src/examples/mnist/model.h
+++ b/src/examples/mnist/model.h
@@ -8,21 +8,24 @@
#include "graph/expression_graph.h"
#include "models/costs.h"
#include "models/model_base.h"
+#include "layers/loss.h"
#include "examples/mnist/dataset.h"
namespace marian {
namespace models {
-class MNISTCrossEntropyCost : public CostBase {
+// @TODO: looking at this file, simplify the new RationalLoss idea. Here it gets too complicated
+
+class MNISTCrossEntropyCost : public ICost {
public:
MNISTCrossEntropyCost() {}
- Expr apply(Ptr<ModelBase> model,
- Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true) override {
- auto top = model->build(graph, batch, clearGraph);
+ Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
+ Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override {
+ auto top = model->build(graph, batch, clearGraph).getLogits();
auto vfLabels = std::static_pointer_cast<data::DataBatch>(batch)->labels();
@@ -31,35 +34,45 @@ public:
auto labels = graph->indices(vLabels);
// Define a top-level node for training
- return mean(cross_entropy(top, labels), /*axis =*/ 0);
+ // use CE loss
+
+ auto loss = sum(cross_entropy(top, labels), /*axis =*/ 0);
+ auto multiLoss = New<SumMultiRationalLoss>();
+ multiLoss->push_back({loss, (float)vLabels.size()});
+ return multiLoss;
}
};
-class MNISTLogsoftmax : public CostBase {
+class MNISTLogsoftmax : public ILogProb {
public:
MNISTLogsoftmax() {}
- Expr apply(Ptr<ModelBase> model,
+ virtual ~MNISTLogsoftmax(){}
+
+ Logits apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
auto top = model->build(graph, batch, clearGraph);
- return logsoftmax(top);
+ return top.applyUnaryFunction(logsoftmax);
}
};
-class MnistFeedForwardNet : public ModelBase {
+class MnistFeedForwardNet : public IModel {
public:
typedef data::MNISTData dataset_type;
template <class... Args>
- MnistFeedForwardNet(Ptr<Options> options, Args... args)
+ MnistFeedForwardNet(Ptr<Options> options, Args... /*args*/)
: options_(options), inference_(options->get<bool>("inference", false)) {}
- virtual Expr build(Ptr<ExpressionGraph> graph,
+ virtual ~MnistFeedForwardNet(){}
+
+ virtual Logits build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool /*clean*/ = false) override {
- return construct(graph, batch, inference_);
+
+ return Logits(apply(graph, batch, inference_));
}
void load(Ptr<ExpressionGraph> /*graph*/, const std::string& /*name*/, bool) override {
@@ -84,11 +97,10 @@ public:
protected:
Ptr<Options> options_;
- bool inference_{false};
+ const bool inference_{false};
/**
- * @brief Constructs an expression graph representing a feed-forward
- * classifier.
+ * @brief Builds an expression graph representing a feed-forward classifier.
*
* @param dims number of nodes in each layer of the feed-forward classifier
* @param batch a batch of training or testing examples
@@ -96,9 +108,9 @@ protected:
*
* @return a shared pointer to the newly constructed expression graph
*/
- virtual Expr construct(Ptr<ExpressionGraph> g,
- Ptr<data::Batch> batch,
- bool /*inference*/ = false) {
+ virtual Expr apply(Ptr<ExpressionGraph> g,
+ Ptr<data::Batch> batch,
+ bool /*inference*/ = false) {
const std::vector<int> dims = {784, 2048, 2048, 10};
// Start with an empty expression graph
@@ -109,7 +121,7 @@ protected:
auto features
= std::static_pointer_cast<data::DataBatch>(batch)->features();
auto x = g->constant({(int)batch->size(), dims[0]},
- inits::from_vector(features));
+ inits::fromVector(features));
// Construct hidden layers
std::vector<Expr> layers, weights, biases;
@@ -133,11 +145,11 @@ protected:
// Construct a weight node for the outgoing connections from layer i
weights.emplace_back(
- g->param("W" + std::to_string(i), {in, out}, inits::glorot_uniform));
+ g->param("W" + std::to_string(i), {in, out}, inits::glorotUniform()));
// Construct a bias node. These weights are initialized to zero
biases.emplace_back(
- g->param("b" + std::to_string(i), {1, out}, inits::zeros));
+ g->param("b" + std::to_string(i), {1, out}, inits::zeros()));
}
// Perform matrix multiplication and addition for the last layer
diff --git a/src/examples/mnist/model_lenet.h b/src/examples/mnist/model_lenet.h
index c2a39977..3abe7aa4 100644
--- a/src/examples/mnist/model_lenet.h
+++ b/src/examples/mnist/model_lenet.h
@@ -15,9 +15,9 @@ public:
virtual void clear(Ptr<ExpressionGraph> graph) override { graph->clear(); };
protected:
- virtual Expr construct(Ptr<ExpressionGraph> g,
- Ptr<data::Batch> batch,
- bool inference = false) override {
+ virtual Expr apply(Ptr<ExpressionGraph> g,
+ Ptr<data::Batch> batch,
+ bool inference = false) override {
const std::vector<int> dims = {784, 128, 10};
// Start with an empty expression graph
@@ -28,7 +28,7 @@ protected:
auto features
= std::static_pointer_cast<data::DataBatch>(batch)->features();
auto x = g->constant({(int)batch->size(), 1, 28, 28},
- inits::from_vector(features));
+ inits::fromVector(features));
// Construct hidden layers
@@ -80,7 +80,7 @@ protected:
// Construct a bias node. These weights are initialized to zero
biases.emplace_back(
- g->param("b" + std::to_string(i), {1, out}, inits::zeros));
+ g->param("b" + std::to_string(i), {1, out}, inits::zeros()));
}
// Perform matrix multiplication and addition for the last layer
@@ -91,7 +91,7 @@ protected:
// Create an output layer of shape batchSize x 1 and populate it with
// labels
auto labels = std::static_pointer_cast<data::DataBatch>(batch)->labels();
- auto y = g->constant({(int)batch->size(), 1}, inits::from_vector(labels));
+ auto y = g->constant({(int)batch->size(), 1}, inits::fromVector(labels));
// Define a top-level node for training
return mean(cross_entropy(last, y), /*axis =*/ 0);
diff --git a/src/examples/mnist/training.h b/src/examples/mnist/training.h
index 8cdacb5e..0b25f771 100644
--- a/src/examples/mnist/training.h
+++ b/src/examples/mnist/training.h
@@ -28,16 +28,19 @@ public:
// Prepare scheduler with validators
auto trainState = New<TrainingState>(options_->get<float>("learn-rate"));
auto scheduler = New<Scheduler>(options_, trainState);
- scheduler->addValidator(New<AccuracyValidator>(options_));
+ scheduler->addValidator(New<MNISTAccuracyValidator>(options_));
+
+ // Multi-node training
+ auto mpi = initMPI(/*multiThreaded=*/false);
// Prepare model
- auto model = New<ModelWrapper>(options_);
+ auto model = New<ModelWrapper>(options_, mpi);
model->setScheduler(scheduler);
model->load();
// Run training
while(scheduler->keepGoing()) {
- batchGenerator->prepare(!options_->get<bool>("no-shuffle"));
+ batchGenerator->prepare();
for(auto batch : *batchGenerator) {
if(!scheduler->keepGoing())
break;
@@ -47,6 +50,8 @@ public:
scheduler->increaseEpoch();
}
scheduler->finished();
+ model = nullptr;
+ finalizeMPI(std::move(mpi));
}
};
} // namespace marian
diff --git a/src/examples/mnist/validator.h b/src/examples/mnist/validator.h
index 907994ee..232bb124 100644
--- a/src/examples/mnist/validator.h
+++ b/src/examples/mnist/validator.h
@@ -12,14 +12,16 @@ using namespace marian;
namespace marian {
-class AccuracyValidator : public Validator<data::MNISTData> {
+class MNISTAccuracyValidator : public Validator<data::MNISTData, models::IModel> {
public:
- AccuracyValidator(Ptr<Options> options) : Validator(std::vector<Ptr<Vocab>>(), options, false) {
+ MNISTAccuracyValidator(Ptr<Options> options) : Validator(std::vector<Ptr<Vocab>>(), options, false) {
createBatchGenerator(/*isTranslating=*/false);
- builder_ = models::from_options(options, models::usage::scoring);
+ builder_ = models::createModelFromOptions(options, models::usage::translation);
}
- virtual void keepBest(const std::vector<Ptr<ExpressionGraph>>& graphs) override {
+ virtual ~MNISTAccuracyValidator(){}
+
+ virtual void keepBest(const std::vector<Ptr<ExpressionGraph>>& /*graphs*/) override {
LOG(warn, "Keeping best model for MNIST examples is not supported");
}
@@ -31,7 +33,7 @@ protected:
size_t samples = 0;
for(auto batch : *batchGenerator_) {
- auto probs = builder_->build(graphs[0], batch, true);
+ auto probs = builder_->build(graphs[0], batch, true).getLogits();
graphs[0]->forward();
std::vector<float> scores;
diff --git a/src/functional/approx.h b/src/functional/approx.h
index 2989965d..233fd89d 100755..100644
--- a/src/functional/approx.h
+++ b/src/functional/approx.h
@@ -6,7 +6,7 @@ namespace marian {
namespace functional {
// approximate any unary float function within range with
-// piecewise linear functions in equal steps.
+// piecewise linear functions in equal steps.
//
// Example:
// static Approx<10, 0, 100> approxSigmoid(stableSigmoid);
@@ -14,12 +14,12 @@ namespace functional {
//
// Creates a functor for range [-10,10] with piecewise linear
// approximations of a sigmoid, 100 pieces, step 0.2.
-// This is quite fast on the CPU.
+// This is quite fast on the CPU.
//
-// approxSigmoid.grad(x) computes the corresponding gradient.
-//
-// When used as a local variable, use static keyword to create
-// only once.
+// approxSigmoid.grad(x) computes the corresponding gradient.
+//
+// When used as a local variable, use static keyword to create
+// only once.
template <int radius = 5, int offset = 0, int pieces = 10>
struct Approx {
@@ -55,7 +55,7 @@ struct Approx {
}
- __HDI__ int index(float x) const {
+ HOST_DEVICE_INLINE int index(float x) const {
if(x <= -radius)
return 0;
if(x < radius) // +1 because 0 holds value for x < -radius
@@ -63,16 +63,16 @@ struct Approx {
return pieces + 1;
}
- __HDI__ float domain(int i) const {
+ HOST_DEVICE_INLINE float domain(int i) const {
return i * ((2.f * radius) / pieces) + offset - radius;
}
- __HDI__ float operator()(float x) const {
+ HOST_DEVICE_INLINE float operator()(float x) const {
int i = index(x);
return a[i] * x + b[i];
}
- __HDI__ float grad(float x) const {
+ HOST_DEVICE_INLINE float grad(float x) const {
int i = index(x);
return a[i];
}
diff --git a/src/functional/array.h b/src/functional/array.h
index 2b73fc28..3c94ff03 100644
--- a/src/functional/array.h
+++ b/src/functional/array.h
@@ -9,28 +9,37 @@ namespace functional {
template <typename T, size_t N>
struct Array {
typedef T value_type;
-
T data_[N];
- __HDI__ const T* data() const { return data_; }
+ HOST_DEVICE_INLINE const T* data() const { return data_; }
- __HDI__ T* data() { return data_; }
+ HOST_DEVICE_INLINE T* data() { return data_; }
- __HDI__ constexpr static size_t size() { return N; }
+ HOST_DEVICE_INLINE constexpr static size_t size() { return N; }
- __HDI__ T& operator[](size_t i) { return data_[i]; }
- __HDI__ const T& operator[](size_t i) const { return data_[i]; }
+ HOST_DEVICE_INLINE T& operator[](size_t i) { return data_[i]; }
+ HOST_DEVICE_INLINE const T& operator[](size_t i) const { return data_[i]; }
- __HDI__ T* begin() { return data_; }
- __HDI__ const T* begin() const { return data_; }
+ HOST_DEVICE_INLINE T* begin() { return data_; }
+ HOST_DEVICE_INLINE const T* begin() const { return data_; }
- __HDI__ T* end() { return data_ + N; }
- __HDI__ const T* end() const { return data_ + N; }
+ HOST_DEVICE_INLINE T* end() { return data_ + N; }
+ HOST_DEVICE_INLINE const T* end() const { return data_ + N; }
- __HDI__ void fill(T val) {
+ HOST_DEVICE_INLINE void fill(T val) {
for(int i = 0; i < N; ++i)
data_[i] = val;
}
+
+ HOST_DEVICE_INLINE T& back() { return data_[N - 1]; }
+ HOST_DEVICE_INLINE const T& back() const { return data_[N - 1]; }
+
+ HOST_DEVICE_INLINE bool operator==(const Array<T, N>& other) {
+ for(int i = 0; i < N; ++i)
+ if(data_[i] != other[i])
+ return false;
+ return true;
+ }
};
} // namespace functional
} // namespace marian
diff --git a/src/functional/defs.h b/src/functional/defs.h
index 7ad627d9..6873a453 100755..100644
--- a/src/functional/defs.h
+++ b/src/functional/defs.h
@@ -1,20 +1,22 @@
#pragma once
-#ifdef __CUDA_ARCH__
+#ifdef __CUDACC__ // Compiling with NVCC, host or device code
#include <cuda.h>
-#define __H__ __host__
-#define __D__ __device__
-#define __HI__ __host__ inline
-#define __HD__ __host__ __device__
-#define __HDI__ __host__ __device__ inline
+#define HOST __host__
+#define DEVICE __device__
+#define DEVICE_INLINE __device__ inline
+#define HOST_INLINE __host__ inline
+#define HOST_DEVICE __host__ __device__
+#define HOST_DEVICE_INLINE __host__ __device__ inline
-#else
+#else // Compiling with GCC or other host compiler
-#define __H__
-#define __D__
-#define __HI__ inline
-#define __HD__
-#define __HDI__ inline
+#define HOST
+#define DEVICE
+#define DEVICE_INLINE inline
+#define HOST_INLINE inline
+#define HOST_DEVICE
+#define HOST_DEVICE_INLINE inline
#endif
diff --git a/src/functional/floats.h b/src/functional/floats.h
index 03e3f540..2519d552 100644
--- a/src/functional/floats.h
+++ b/src/functional/floats.h
@@ -83,7 +83,7 @@ struct F {
static constexpr auto binary = V;
template <typename... Args>
- __HDI__ constexpr float operator()(Args&&... args) const {
+ HOST_DEVICE_INLINE constexpr float operator()(Args&&... args) const {
return value;
}
diff --git a/src/functional/functional.h b/src/functional/functional.h
index 435acbd1..a3e57e4c 100755..100644
--- a/src/functional/functional.h
+++ b/src/functional/functional.h
@@ -1,7 +1,10 @@
#pragma once
+// this header is meant to be included for all operations from the "functional" namespace.
+
#include "functional/operands.h"
#include "functional/predicates.h"
+#include "functional/operators.h"
namespace marian {
namespace functional {
diff --git a/src/functional/operands.h b/src/functional/operands.h
index 2d10a6a0..2bfa1cd0 100755..100644
--- a/src/functional/operands.h
+++ b/src/functional/operands.h
@@ -13,7 +13,7 @@ using IsClass = typename std::enable_if<std::is_class<C>::value, C>::type;
template <int N>
struct Select {
template <typename T, typename... Args>
- __HDI__ static auto apply(T&& /*arg*/, Args&&... args)
+ HOST_DEVICE_INLINE static auto apply(T&& /*arg*/, Args&&... args)
-> decltype(Select<N - 1>::apply(args...)) {
return Select<N - 1>::apply(args...);
}
@@ -22,7 +22,7 @@ struct Select {
template <>
struct Select<0> {
template <typename T, typename... Args>
- __HDI__ static T apply(T&& arg, Args&&... /*args*/) {
+ HOST_DEVICE_INLINE static T apply(T&& arg, Args&&... /*args*/) {
return arg;
}
};
@@ -33,12 +33,12 @@ template <int V>
struct C {
static constexpr auto value = V;
- template <typename... Args>
- __HDI__ float operator()(Args&&... args) {
- return V;
+ template <typename T, typename... Args>
+ HOST_DEVICE_INLINE T operator()(T&& /*arg*/, Args&&... /*args*/) {
+ return (T)V;
}
- std::string to_string() { return "C<" + std::to_string(V) + ">"; }
+ std::string to_string() const { return "C<" + std::to_string(V) + ">"; }
};
/******************************************************************************/
@@ -48,12 +48,12 @@ struct Capture {
Capture(float val) : value(val){};
- template <typename... Args>
- __HDI__ float operator()(Args&&... /*args*/) {
- return value;
+ template <typename T, typename... Args>
+ HOST_DEVICE_INLINE T operator()(const T& /*arg*/, const Args&... /*args*/) {
+ return T(value);
}
- std::string to_string() { return "Cap(" + std::to_string(value) + ")"; }
+ std::string to_string() const { return "Cap(" + std::to_string(value) + ")"; }
};
/******************************************************************************/
@@ -62,12 +62,12 @@ template <int N>
struct Var {
static constexpr auto index = N;
- template <typename... Args>
- __HDI__ float& operator()(Args&&... args) {
- return Select<N - 1>::apply(args...);
+ template <typename T, typename... Args>
+ HOST_DEVICE_INLINE T& operator()(T&& arg, Args&&... args) {
+ return Select<N - 1>::apply(arg, args...);
}
- std::string to_string() { return "Var<" + std::to_string(N) + ">"; }
+ std::string to_string() const { return "Var<" + std::to_string(N) + ">"; }
};
} // namespace functional
} // namespace marian
diff --git a/src/functional/operators.h b/src/functional/operators.h
new file mode 100755
index 00000000..be6bb2e5
--- /dev/null
+++ b/src/functional/operators.h
@@ -0,0 +1,606 @@
+#pragma once
+
+#include "common/types.h"
+#include <cmath>
+
+namespace marian {
+namespace functional {
+
+// General template, will be used for any type without specializations
+// and will fail at runtime with an abort message. Note that the
+// general template functions don't have named parameters on purpose,
+// because clang will warn about unused parameters during compilation.
+
+template <typename T>
+struct Ops {
+ static HOST_DEVICE_INLINE T tanh(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T sin(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T cos(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T tan(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T log(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T exp(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T abs(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T sqrt(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T neg(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T sgn(const T&) { ABORT("Unknown type"); }
+
+ static HOST_DEVICE_INLINE T add(const T&, const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T sub(const T&, const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T mul(const T&, const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T div(const T&, const T&) { ABORT("Unknown type"); }
+
+ static HOST_DEVICE_INLINE T max(const T&, const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T min(const T&, const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T pow(const T&, const T&) { ABORT("Unknown type"); }
+
+ static HOST_DEVICE_INLINE T negate(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T eq(const T&, const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T neq(const T&, const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T gt(const T&, const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T lt(const T&, const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T geq(const T&, const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T leq(const T&, const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T _and(const T&, const T&) { ABORT("Unknown type"); } // 'and' is used by gcc
+ static HOST_DEVICE_INLINE T _or(const T&, const T&) { ABORT("Unknown type"); } // 'or' is used by gcc
+
+ // Neural Networks specific functions
+ static HOST_DEVICE_INLINE T sigmoid(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T logaddexp(const T&, const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T clip(const T&, const T&) { ABORT("Unknown type"); }
+ // derivative of Clip, cut-off function
+ static HOST_DEVICE_INLINE T bump(const T&, const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T relu(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T reluBack(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T prelu(const T&, const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T preluBack(const T&, const T&) { ABORT("Unknown type"); }
+
+ static HOST_DEVICE_INLINE T if_then_else(const T&, const T&, const T&) { ABORT("Unknown type"); }
+
+ static HOST_DEVICE_INLINE T sumReduce(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T maxReduce(const T&) { ABORT("Unknown type"); }
+ static HOST_DEVICE_INLINE T minReduce(const T&) { ABORT("Unknown type"); }
+};
+
+// Specialization for float
+template <>
+struct Ops<float> {
+ typedef float Single;
+
+ static HOST_DEVICE_INLINE float tanh(const float& x) { return tanhf(x); }
+ static HOST_DEVICE_INLINE float sin(const float& x) { return sinf(x); }
+ static HOST_DEVICE_INLINE float cos(const float& x) { return cosf(x); }
+ static HOST_DEVICE_INLINE float tan(const float& x) { return tanf(x); }
+ static HOST_DEVICE_INLINE float log(const float& x) { return logf(x); }
+ static HOST_DEVICE_INLINE float exp(const float& x) { return expf(x); }
+ static HOST_DEVICE_INLINE float abs(const float& x) { return fabs(x); }
+ static HOST_DEVICE_INLINE float sqrt(const float& x) { return sqrtf(x); }
+ static HOST_DEVICE_INLINE float neg(const float& x) { return -x; }
+ static HOST_DEVICE_INLINE float sgn(const float& x) { return (float)((0 < x) - (x < 0)); }
+
+ static HOST_DEVICE_INLINE float add(const float& x, const float& y) { return x + y; }
+ static HOST_DEVICE_INLINE float sub(const float& x, const float& y) { return x - y; }
+ static HOST_DEVICE_INLINE float mul(const float& x, const float& y) { return x * y; }
+ static HOST_DEVICE_INLINE float div(const float& x, const float& y) { return x / y; }
+
+ static HOST_DEVICE_INLINE float max(const float& x, const float& y) { return x < y ? y : x; }
+ static HOST_DEVICE_INLINE float min(const float& x, const float& y) { return x < y ? x : y; }
+ static HOST_DEVICE_INLINE float pow(const float& x, const float& y) { return powf(x, y); }
+
+
+ static HOST_DEVICE_INLINE float negate(const float& x) { return !(bool)x; }
+ static HOST_DEVICE_INLINE float eq(const float& x, const float& y) { return x == y; }
+ static HOST_DEVICE_INLINE float neq(const float& x, const float& y) { return x != y; }
+ static HOST_DEVICE_INLINE float gt(const float& x, const float& y) { return x > y; }
+ static HOST_DEVICE_INLINE float lt(const float& x, const float& y) { return x < y; }
+ static HOST_DEVICE_INLINE float geq(const float& x, const float& y) { return x >= y; }
+ static HOST_DEVICE_INLINE float leq(const float& x, const float& y) { return x <= y; }
+ static HOST_DEVICE_INLINE float and_(const float& x, const float& y) { return x && y; } // 'and' is used by gcc
+ static HOST_DEVICE_INLINE float or_(const float& x, const float& y) { return x || y; } // 'or' is used by gcc
+
+ // Neural Networks specific functions
+ static HOST_DEVICE_INLINE float sigmoid(const float& x) {
+ return x > 0 ? (1.f / (1.f + exp(-x))) : (exp(x) / (1.f + exp(x)));
+ }
+
+ static HOST_DEVICE_INLINE float logaddexp(const float& x, const float& y) {
+ // Note: This may not be ideal for CUDA; cf. CNTK implementation
+ return x < y ? (y + log1pf(exp(x - y))) : (x + log1pf(exp(y - x)));
+ }
+
+ static HOST_DEVICE_INLINE float clip(const float& x, const float& y) { return abs(x) >= y ? sgn(x) * y : x; }
+ // derivative of Clip, cut-off function
+ static HOST_DEVICE_INLINE float bump(const float& x, const float& y) { return abs(x) >= y ? 0.f : 1.f; }
+
+ static HOST_DEVICE_INLINE float relu(const float& x) { return x > 0.f ? x : 0.f; }
+ static HOST_DEVICE_INLINE float reluBack(const float& x) { return x > 0.f ? 1.f : 0.f; }
+
+ static HOST_DEVICE_INLINE float prelu(const float& x, const float& y) { return x > 0.f ? x : x * y; }
+ static HOST_DEVICE_INLINE float preluBack(const float& x, const float& y) { return x > 0.f ? 1.f : y; }
+
+ static HOST_DEVICE_INLINE float if_then_else(const float& x, const float& y, const float& z) { return x ? y : z; }
+
+ static HOST_DEVICE_INLINE float sumReduce(const float& x) { return x; }
+ static HOST_DEVICE_INLINE float maxReduce(const float& x) { return x; }
+ static HOST_DEVICE_INLINE float minReduce(const float& x) { return x; }
+
+};
+
+// Specialization for double
+template <>
+struct Ops<double> {
+ typedef double Single;
+
+ static HOST_DEVICE_INLINE double tanh(const double& x) { return std::tanh(x); }
+ static HOST_DEVICE_INLINE double sin(const double& x) { return std::sin(x); }
+ static HOST_DEVICE_INLINE double cos(const double& x) { return std::cos(x); }
+ static HOST_DEVICE_INLINE double tan(const double& x) { return std::tan(x); }
+ static HOST_DEVICE_INLINE double log(const double& x) { return std::log(x); }
+ static HOST_DEVICE_INLINE double exp(const double& x) { return std::exp(x); }
+ static HOST_DEVICE_INLINE double abs(const double& x) { return std::abs(x); }
+ static HOST_DEVICE_INLINE double sqrt(const double& x) { return std::sqrt(x); }
+ static HOST_DEVICE_INLINE double neg(const double& x) { return -x; }
+ static HOST_DEVICE_INLINE double sgn(const double& x) { return (0 < x) - (x < 0); }
+
+ static HOST_DEVICE_INLINE double add(const double& x, const double& y) { return x + y; }
+ static HOST_DEVICE_INLINE double sub(const double& x, const double& y) { return x - y; }
+ static HOST_DEVICE_INLINE double mul(const double& x, const double& y) { return x * y; }
+ static HOST_DEVICE_INLINE double div(const double& x, const double& y) { return x / y; }
+
+ static HOST_DEVICE_INLINE double max(const double& x, const double& y) { return x < y ? y : x; }
+ static HOST_DEVICE_INLINE double min(const double& x, const double& y) { return x < y ? x : y; }
+ static HOST_DEVICE_INLINE double pow(const double& x, const double& y) { return std::pow(x, y); }
+
+
+ static HOST_DEVICE_INLINE double negate(const double& x) { return !(bool)x; }
+ static HOST_DEVICE_INLINE double eq(const double& x, const double& y) { return x == y; }
+ static HOST_DEVICE_INLINE double neq(const double& x, const double& y) { return x != y; }
+ static HOST_DEVICE_INLINE double gt(const double& x, const double& y) { return x > y; }
+ static HOST_DEVICE_INLINE double lt(const double& x, const double& y) { return x < y; }
+ static HOST_DEVICE_INLINE double geq(const double& x, const double& y) { return x >= y; }
+ static HOST_DEVICE_INLINE double leq(const double& x, const double& y) { return x <= y; }
+ static HOST_DEVICE_INLINE double and_(const double& x, const double& y) { return x && y; } // 'and' is used by gcc
+ static HOST_DEVICE_INLINE double or_(const double& x, const double& y) { return x || y; } // 'or' is used by gcc
+
+ // Neural Networks specific functions
+ static HOST_DEVICE_INLINE double sigmoid(const double& x) {
+ return x > 0 ? (1.f / (1.f + exp(-x))) : (exp(x) / (1.f + exp(x)));
+ }
+
+ static HOST_DEVICE_INLINE double logaddexp(const double& x, const double& y) {
+ // Note: This may not be ideal for CUDA; cf. CNTK implementation
+ return x < y ? (y + log1p(exp(x - y))) : (x + log1p(exp(y - x)));
+ }
+
+ static HOST_DEVICE_INLINE double clip(const double& x, const double& y) { return abs(x) >= y ? sgn(x) * y : x; }
+ // derivative of Clip, cut-off function
+ static HOST_DEVICE_INLINE double bump(const double& x, const double& y) { return abs(x) >= y ? 0.f : 1.f; }
+
+ static HOST_DEVICE_INLINE double relu(const double& x) { return x > 0.f ? x : 0.f; }
+ static HOST_DEVICE_INLINE double reluBack(const double& x) { return x > 0.f ? 1.f : 0.f; }
+
+ static HOST_DEVICE_INLINE double prelu(const double& x, const double& y) { return x > 0.f ? x : x * y; }
+ static HOST_DEVICE_INLINE double preluBack(const double& x, const double& y) { return x > 0.f ? 1.f : y; }
+
+ static HOST_DEVICE_INLINE double if_then_else(const double& x, const double& y, const double& z) { return x ? y : z; }
+
+ static HOST_DEVICE_INLINE double sumReduce(const double& x) { return x; }
+ static HOST_DEVICE_INLINE double maxReduce(const double& x) { return x; }
+ static HOST_DEVICE_INLINE double minReduce(const double& x) { return x; }
+
+};
+
+} // end namespace functional
+} // end namespace marian
+
+// stay invisible to NVCC as it seems to have problems with intrinsics;
+// will still be compiled into the binary by cpu-side gcc/g++
+// __CUDACC__ is defined when compiling with NVCC regardless of device type
+// __CUDA_ARCH__ is defined when compiling device (GPU) code
+#ifndef __CUDACC__
+
+#include "3rd_party/sse_mathfun.h"
+
+namespace marian {
+namespace functional {
+
+// Specialization for float32x8 (=__m128, CPU SSE intrisics)
+template <>
+struct Ops<float32x4> {
+ typedef float Single;
+
+ static inline float32x4 loop4(const std::function<float(const float&)>& f, const float32x4& x) {
+ float32x4 out;
+ for(int i = 0; i < 4; i++)
+ ((float*)&out)[i] = f(((const float*)&x)[i]);
+ return out;
+ }
+
+ static inline float32x4 loop4(const std::function<float(const float&, const float&)>& f, const float32x4& x, const float32x4& y) {
+ float32x4 out;
+ for(int i = 0; i < 4; i++)
+ ((float*)&out)[i] = f(((const float*)&x)[i], ((const float*)&y)[i]);
+ return out;
+ }
+
+ static inline float32x4 loop4(const std::function<float(const float&, const float&, const float&)>& f, const float32x4& x, const float32x4& y, const float32x4& z) {
+ float32x4 out;
+ for(int i = 0; i < 4; i++)
+ ((float*)&out)[i] = f(((const float*)&x)[i], ((const float*)&y)[i], ((const float*)&z)[i]);
+ return out;
+ }
+
+ // @TODO: why is this slow?
+ static inline float32x4 tanh(const float32x4& x) {
+ // ( e^x - e^-x )/( e^x + e^-x ) = (e^2x - 1) / (e^2x + 1)
+ float32x4 e2x = exp(mul(2.f, x));
+ return div(sub(e2x, 1.f), add(e2x, 1.f));
+ }
+
+ static inline float32x4 sin(const float32x4& x) { return sin_ps(x); }
+ static inline float32x4 cos(const float32x4& x) { return cos_ps(x); }
+ static inline float32x4 tan(const float32x4& x) { return div(sin(x), cos(x)); }
+ static inline float32x4 log(const float32x4& x) { return log_ps(x); }
+ static inline float32x4 exp(const float32x4& x) { return exp_ps(x); }
+
+ // @TODO: get rid of loop4 with proper intrisics
+ static inline float32x4 abs(const float32x4& x) { return loop4(Ops<float>::abs, x); }
+ static inline float32x4 sqrt(const float32x4& x) { return _mm_sqrt_ps(x); }
+ static inline float32x4 neg(const float32x4& x) { return sub(0.f, x); }
+
+ // @TODO: get rid of loop4 with proper intrisics
+ static inline float32x4 sgn(const float32x4& x) { return loop4(Ops<float>::sgn, x); }
+
+ static inline float32x4 add(const float32x4& x, const float32x4& y) { return _mm_add_ps(x, y); }
+ static inline float32x4 sub(const float32x4& x, const float32x4& y) { return _mm_sub_ps(x, y); }
+ static inline float32x4 mul(const float32x4& x, const float32x4& y) { return _mm_mul_ps(x, y); }
+ static inline float32x4 div(const float32x4& x, const float32x4& y) { return _mm_div_ps(x, y); }
+
+ static inline float32x4 max(const float32x4& x, const float32x4& y) { return _mm_max_ps(x, y); }
+ static inline float32x4 min(const float32x4& x, const float32x4& y) { return _mm_min_ps(x, y); }
+ static inline float32x4 pow(const float32x4& x, const float32x4& y) { return exp(mul(y, log(x))); }
+
+ // @TODO: get rid of loop4 with proper intrisics
+ static inline float32x4 negate(float32x4& x) { return loop4(Ops<float>::negate, x); }
+
+ static inline float32x4 eq(const float32x4& x, const float32x4& y) { return loop4(Ops<float>::eq, x, y); }
+ static inline float32x4 neq(const float32x4& x, const float32x4& y) { return loop4(Ops<float>::neq, x, y); }
+ static inline float32x4 gt(const float32x4& x, const float32x4& y) { return loop4(Ops<float>::gt, x, y); }
+ static inline float32x4 lt(const float32x4& x, const float32x4& y) { return loop4(Ops<float>::lt, x, y); }
+ static inline float32x4 geq(const float32x4& x, const float32x4& y) { return loop4(Ops<float>::geq, x, y); }
+ static inline float32x4 leq(const float32x4& x, const float32x4& y) { return loop4(Ops<float>::leq, x, y); }
+ static inline float32x4 and_(const float32x4& x, const float32x4& y) { return loop4(Ops<float>::and_, x, y); } // 'and' is used by gcc
+ static inline float32x4 or_(const float32x4& x, const float32x4& y) { return loop4(Ops<float>::or_, x, y); } // 'or' is used by gcc
+
+ // Neural Networks specific functions
+ // @TODO: this is unsafe
+ static inline float32x4 sigmoid(const float32x4& x) {
+ float32x4 e = exp(x);
+ return div(e, add(1.f, e));
+ }
+
+ // // Neural Networks specific functions
+ // static HOST_DEVICE_INLINE float sigmoid(const float& x) {
+ // return x > 0 ? (1.f / (1.f + exp(-x))) : (exp(x) / (1.f + exp(x)));
+ // }
+
+ static inline float32x4 logaddexp(const float32x4& x, const float32x4& y) { return loop4(Ops<float>::logaddexp, x, y); }
+
+ static inline float32x4 clip(const float32x4& x, const float32x4& y) { return loop4(Ops<float>::clip, x, y); }
+ static inline float32x4 bump(const float32x4& x, const float32x4& y) { return loop4(Ops<float>::bump, x, y); }
+
+ static inline float32x4 relu(const float32x4& x) { return max(0.f, x); }
+
+ static inline float32x4 reluBack(const float32x4& x) { return loop4(Ops<float>::reluBack, x); }
+ static inline float32x4 prelu(const float32x4& x, const float32x4& y) { return loop4(Ops<float>::prelu, x, y); }
+ static inline float32x4 preluBack(const float32x4& x, const float32x4& y) { return loop4(Ops<float>::preluBack, x, y); }
+
+ static inline float32x4 if_then_else(const float32x4& x, const float32x4& y, const float32x4& z) { return loop4(Ops<float>::if_then_else, x, y, z); }
+
+ static inline Single sumReduce(const float32x4& x) {
+ Single sum = 0;
+ for(int i = 0; i < 4; ++i)
+ sum = Ops<Single>::add(sum, x[i]);
+ return sum;
+ }
+
+ static inline Single maxReduce(const float32x4& x) {
+ Single maxs = x[0];
+ for(int i = 1; i < 4; ++i)
+ maxs = Ops<Single>::max(maxs, x[i]);
+ return maxs;
+ }
+
+ static inline Single minReduce(const float32x4& x) {
+ Single mins = x[0];
+ for(int i = 1; i < 4; ++i)
+ mins = Ops<Single>::min(mins, x[i]);
+ return mins;
+ }
+
+
+};
+
+} // end namespace functional
+} // end namespace marian
+#ifdef __AVX__
+#include "3rd_party/avx_mathfun.h"
+
+namespace marian {
+namespace functional {
+
+//*******************************************************************************************
+// Specialization for float32x8 (=__m256, CPU AVX intrisics)
+template <>
+struct Ops<float32x8> {
+ typedef float Single;
+
+
+ static inline float32x8 loop8(const std::function<float(const float&)>& f, const float32x8& x) {
+ float32x8 out;
+ for(int i = 0; i < 8; i++)
+ ((float*)&out)[i] = f(((const float*)&x)[i]);
+ return out;
+ }
+
+ static inline float32x8 loop8(const std::function<float(const float&, const float&)>& f, const float32x8& x, const float32x8& y) {
+ float32x8 out;
+ for(int i = 0; i < 8; i++)
+ ((float*)&out)[i] = f(((const float*)&x)[i], ((const float*)&y)[i]);
+ return out;
+ }
+
+ static inline float32x8 loop8(const std::function<float(const float&, const float&, const float&)>& f, const float32x8& x, const float32x8& y, const float32x8& z) {
+ float32x8 out;
+ for(int i = 0; i < 8; i++)
+ ((float*)&out)[i] = f(((const float*)&x)[i], ((const float*)&y)[i], ((const float*)&z)[i]);
+ return out;
+ }
+
+ static inline float32x8 tanh(const float32x8& x) { // ( e^x - e^-x )/( e^x + e^-x )
+ float32x8 e2x = exp(mul(2.f, x));
+ return div(sub(e2x, 1.f), add(e2x, 1.f));
+ }
+
+ static inline float32x8 sin(const float32x8& x) { return sin256_ps(x); }
+ static inline float32x8 cos(const float32x8& x) { return cos256_ps(x); }
+ static inline float32x8 tan(const float32x8& x) { return div(sin(x), cos(x)); } // @TODO: use sincos256_ps
+ static inline float32x8 log(const float32x8& x) { return log256_ps(x); }
+ static inline float32x8 exp(const float32x8& x) { return exp256_ps(x); }
+
+ // @TODO: get rid of loop8 with proper intrisics
+ static inline float32x8 abs(const float32x8& x) { return loop8(Ops<float>::abs, x); }
+ static inline float32x8 sqrt(const float32x8& x) { return _mm256_sqrt_ps(x); }
+ static inline float32x8 neg(const float32x8& x) { return sub(0.f, x); }
+
+ // @TODO: get rid of loop8 with proper intrisics
+ static inline float32x8 sgn(const float32x8& x) { return loop8(Ops<float>::sgn, x); }
+
+ static inline float32x8 add(const float32x8& x, const float32x8& y) { return _mm256_add_ps(x, y); }
+ static inline float32x8 sub(const float32x8& x, const float32x8& y) { return _mm256_sub_ps(x, y); }
+ static inline float32x8 mul(const float32x8& x, const float32x8& y) { return _mm256_mul_ps(x, y); }
+ static inline float32x8 div(const float32x8& x, const float32x8& y) { return _mm256_div_ps(x, y); }
+
+ static inline float32x8 max(const float32x8& x, const float32x8& y) { return _mm256_max_ps(x, y); }
+ static inline float32x8 min(const float32x8& x, const float32x8& y) { return _mm256_min_ps(x, y); }
+ static inline float32x8 pow(const float32x8& x, const float32x8& y) { return exp(mul(y, log(x))); }
+
+ // @TODO: get rid of loop8 with proper intrisics
+ static inline float32x8 negate(float32x8& x) { return loop8(Ops<float>::negate, x); }
+
+ static inline float32x8 eq(const float32x8& x, const float32x8& y) { return loop8(Ops<float>::eq, x, y); }
+ static inline float32x8 neq(const float32x8& x, const float32x8& y) { return loop8(Ops<float>::neq, x, y); }
+ static inline float32x8 gt(const float32x8& x, const float32x8& y) { return loop8(Ops<float>::gt, x, y); }
+ static inline float32x8 lt(const float32x8& x, const float32x8& y) { return loop8(Ops<float>::lt, x, y); }
+ static inline float32x8 geq(const float32x8& x, const float32x8& y) { return loop8(Ops<float>::geq, x, y); }
+ static inline float32x8 leq(const float32x8& x, const float32x8& y) { return loop8(Ops<float>::leq, x, y); }
+ static inline float32x8 and_(const float32x8& x, const float32x8& y) { return loop8(Ops<float>::and_, x, y); } // 'and' is used by gcc
+ static inline float32x8 or_(const float32x8& x, const float32x8& y) { return loop8(Ops<float>::or_, x, y); } // 'or' is used by gcc
+
+
+ // Neural Networks specific functions
+ // @TODO: this is unsafe
+ static inline float32x8 sigmoid(const float32x8& x) {
+ float32x8 e = exp(x);
+ return div(e, add(1.f, e));
+ }
+
+ static inline float32x8 logaddexp(const float32x8& x, const float32x8& y) { return loop8(Ops<float>::logaddexp, x, y); }
+
+ static inline float32x8 clip(const float32x8& x, const float32x8& y) { return loop8(Ops<float>::clip, x, y); }
+ static inline float32x8 bump(const float32x8& x, const float32x8& y) { return loop8(Ops<float>::bump, x, y); }
+
+ static inline float32x8 relu(const float32x8& x) { return max(0.f, x); }
+
+ static inline float32x8 reluBack(const float32x8& x) { return loop8(Ops<float>::reluBack, x); }
+ static inline float32x8 prelu(const float32x8& x, const float32x8& y) { return loop8(Ops<float>::prelu, x, y); }
+ static inline float32x8 preluBack(const float32x8& x, const float32x8& y) { return loop8(Ops<float>::preluBack, x, y); }
+
+ static inline float32x8 if_then_else(const float32x8& x, const float32x8& y, const float32x8& z) { return loop8(Ops<float>::if_then_else, x, y, z); }
+
+ static inline Single sumReduce(const float32x8& x) {
+ Single sum = 0;
+ for(int i = 0; i < 8; ++i)
+ sum = Ops<Single>::add(sum, x[i]);
+ return sum;
+ }
+
+ static inline Single maxReduce(const float32x8& x) {
+ Single maxs = x[0];
+ for(int i = 1; i < 8; ++i)
+ maxs = Ops<Single>::max(maxs, x[i]);
+ return maxs;
+ }
+
+ static inline Single minReduce(const float32x8& x) {
+ Single mins = x[0];
+ for(int i = 1; i < 8; ++i)
+ mins = Ops<Single>::min(mins, x[i]);
+ return mins;
+ }
+};
+
+} // end namespace functional
+} // end namespace marian
+#endif
+#endif // of "#ifndef __CUDACC__"
+
+#ifdef __CUDACC__
+#if COMPILE_FP16
+// only compile with fp16 support for compute_70, i.e. VOLTA 100 and above.
+#include <cuda_fp16.h>
+
+namespace marian {
+namespace functional {
+
+// Specialization for half
+template <>
+struct Ops<half> {
+
+ static DEVICE_INLINE half sin(const half& x) { return hsin(x); }
+ static DEVICE_INLINE half cos(const half& x) { return hcos(x); }
+ static DEVICE_INLINE half tan(const half& x) { return hsin(x) / hcos(x); }
+ static DEVICE_INLINE half log(const half& x) { return hlog(x); }
+ static DEVICE_INLINE half exp(const half& x) { return hexp(x); }
+ static DEVICE_INLINE half sqrt(const half& x) { return hsqrt(x); }
+ static DEVICE_INLINE half neg(const half& x) { return -x; }
+
+ static DEVICE_INLINE half abs(const half& x) { return fabs((float)x); }// @TODO half has this information somewhere in the struct, right?
+ static DEVICE_INLINE half sgn(const half& x) { half zero = 0.f; return (zero < x) - (x < zero); } // @TODO half has this information somewhere in the struct, right?
+
+ static DEVICE_INLINE half add(const half& x, const half& y) { return x + y; }
+ static DEVICE_INLINE half sub(const half& x, const half& y) { return x - y; }
+ static DEVICE_INLINE half mul(const half& x, const half& y) { return x * y; }
+ static DEVICE_INLINE half div(const half& x, const half& y) { return x / y; }
+
+ static DEVICE_INLINE half max(const half& x, const half& y) { return x < y ? y : x; }
+ static DEVICE_INLINE half min(const half& x, const half& y) { return x < y ? x : y; }
+ static DEVICE_INLINE half pow(const half& x, const half& y) { return exp(y * log(x)); }
+
+ static DEVICE_INLINE half negate(const half& x) { return !(bool)x; }
+ static DEVICE_INLINE half eq(const half& x, const half& y) { return x == y; }
+ static DEVICE_INLINE half neq(const half& x, const half& y) { return x != y; }
+ static DEVICE_INLINE half gt(const half& x, const half& y) { return x > y; }
+ static DEVICE_INLINE half lt(const half& x, const half& y) { return x < y; }
+ static DEVICE_INLINE half geq(const half& x, const half& y) { return x >= y; }
+ static DEVICE_INLINE half leq(const half& x, const half& y) { return x <= y; }
+ static DEVICE_INLINE half and_(const half& x, const half& y) { return x && y; } // 'and' is used by gcc
+ static DEVICE_INLINE half or_(const half& x, const half& y) { return x || y; } // 'or' is used by gcc
+
+ // Neural Networks specific functions
+ static DEVICE_INLINE half sigmoid(const half& x) {
+ const half zero = 0.f;
+ const half one = 1.f;
+ return x > zero ? (one / (one + exp(-x))) : (exp(x) / (one + exp(x))); // safe sigmoid
+ }
+
+ static DEVICE_INLINE half tanh(const half& x) {
+ // tanh(x) = 2 * sigmoid(2 * x) - 1
+ const half one = 1.f;
+ const half two = 2.f;
+ return two * sigmoid(two * x) - one; // safe sigmoid => safe tanh
+ }
+
+ static DEVICE_INLINE half log1p(const half& x) {
+ return log(x + (half)1.f); // probably acceptable loss of precision, it's half anyway
+ }
+
+ static DEVICE_INLINE half logaddexp(const half& x, const half& y) {
+ // Note: This may not be ideal for CUDA; cf. CNTK implementation
+ return x < y ? (y + log1p(exp(x - y))) : (x + log1p(exp(y - x)));
+ }
+
+ static DEVICE_INLINE half clip(const half& x, const half& y) { return abs(x) >= y ? sgn(x) * y : x; }
+
+ // derivative of Clip, cut-off function
+ static DEVICE_INLINE half bump(const half& x, const half& y) {
+ const half zero = 0.f;
+ const half one = 1.f;
+ return abs(x) >= y ? zero : one;
+ }
+ static DEVICE_INLINE half relu(const half& x) {
+ const half zero = 0.f;
+ return x > zero ? x : zero;
+ }
+ static DEVICE_INLINE half reluBack(const half& x) {
+ const half zero = 0.f;
+ const half one = 1.f;
+ return x > zero ? one : zero;
+ }
+
+ static DEVICE_INLINE half prelu(const half& x, const half& y) {
+ const half zero = 0.f;
+ return x > zero ? x : x * y;
+ }
+
+ static DEVICE_INLINE half preluBack(const half& x, const half& y) {
+ const half zero = 0.f;
+ const half one = 1.f;
+ return x > zero ? one : y;
+ }
+
+ static DEVICE_INLINE half if_then_else(const half& x, const half& y, const half& z) { return x ? y : z; }
+
+ static DEVICE_INLINE half sumReduce(const half& x) { return x; }
+ static DEVICE_INLINE half maxReduce(const half& x) { return x; }
+ static DEVICE_INLINE half minReduce(const half& x) { return x; }
+
+};
+
+} // end namespace functional
+} // end namespace marian
+
+#endif
+#endif
+
+//*******************************************************************************************
+
+#include "functional/defs.h"
+#include "functional/predicates.h"
+
+namespace marian {
+namespace functional {
+
+UNARY(Tanh, tanh, Ops<ElementType>::tanh(x));
+UNARY(Sin, sin, Ops<ElementType>::sin(x));
+UNARY(Cos, cos, Ops<ElementType>::cos(x));
+UNARY(Tan, tan, Ops<ElementType>::tan(x));
+UNARY(Log, log, Ops<ElementType>::log(x));
+UNARY(Exp, exp, Ops<ElementType>::exp(x));
+UNARY(Abs, abs, Ops<ElementType>::abs(x));
+UNARY(Sqrt, sqrt, Ops<ElementType>::sqrt(x));
+UNARY(Neg, operator-, Ops<ElementType>::neg(x));
+UNARY(Sgn, sgn, Ops<ElementType>::sgn(x));
+
+BINARY(Plus, operator+, Ops<ElementType>::add(x, y));
+BINARY(Minus, operator-, Ops<ElementType>::sub(x, y));
+BINARY(Mult, operator*, Ops<ElementType>::mul(x, y));
+BINARY(Div, operator/, Ops<ElementType>::div(x, y));
+BINARY(Max, max, Ops<ElementType>::max(x, y));
+BINARY(Min, min, Ops<ElementType>::min(x, y));
+UNARY(Negate, operator!, Ops<ElementType>::negate(x));
+BINARY(Eq, operator==, Ops<ElementType>::eq(x, y));
+BINARY(NEq, operator!=, Ops<ElementType>::neq(x, y));
+BINARY(Gt, operator>, Ops<ElementType>::gt(x, y));
+BINARY(Lt, operator<, Ops<ElementType>::lt(x, y));
+BINARY(Geq, operator>=, Ops<ElementType>::geq(x, y));
+BINARY(Leq, operator<=, Ops<ElementType>::leq(x, y));
+BINARY(And, operator&&, Ops<ElementType>::and_(x, y));
+BINARY(Or, operator||, Ops<ElementType>::or_(x, y));
+BINARY(Pow, pow, Ops<ElementType>::pow(x, y));
+
+TERNARY(IfThenElse, if_then_else, Ops<ElementType>::if_then_else(x, y, z));
+
+// Neural Networks specific functions
+BINARY(Clip, clip, Ops<ElementType>::clip(x, y));
+// derivative of Clip, cut-off function
+BINARY(Bump, bump, Ops<ElementType>::bump(x, y));
+
+UNARY(Sigmoid, sigmoid, Ops<ElementType>::sigmoid(x));
+BINARY(LogAddExp, logaddexp, Ops<ElementType>::logaddexp(x, y));
+UNARY(sReLU, ReLU, Ops<ElementType>::relu(x));
+UNARY(sReLUBack, ReLUback, Ops<ElementType>::reluBack(x));
+BINARY(sPReLU, PReLU, Ops<ElementType>::prelu(x, y));
+BINARY(sPReLUBack, PReLUback, Ops<ElementType>::preluBack(x, y));
+
+} // end namespace functional
+} // end namespace marian
diff --git a/src/functional/predicates.h b/src/functional/predicates.h
index fce5da47..420a88a3 100755..100644
--- a/src/functional/predicates.h
+++ b/src/functional/predicates.h
@@ -1,7 +1,5 @@
#pragma once
-#include <cmath>
-
#include "functional/defs.h"
#include "functional/operands.h"
@@ -15,12 +13,12 @@ struct UnaryFunctor {
template <class Arg>
UnaryFunctor(Arg a) : x(a) {}
- template <typename... Args>
- __HDI__ float operator()(Args&&... args) {
- return Function::apply(x(args...));
+ template <typename T, typename... Args>
+ HOST_DEVICE_INLINE T operator()(T arg, Args&&... args) {
+ return Function::apply(x(arg, args...));
}
- std::string to_string() { return Function::n() + "<" + x.to_string() + ">"; }
+ std::string to_string() const { return Function::n() + "<" + x.to_string() + ">"; }
};
template <class Function, class X, class Y>
@@ -31,112 +29,57 @@ struct BinaryFunctor {
template <class Arg1, class Arg2>
BinaryFunctor(Arg1 arg1, Arg2 arg2) : x(arg1), y(arg2) {}
- template <typename... Args>
- __HDI__ float operator()(Args&&... args) {
- return Function::apply(x(args...), y(args...));
+ template <typename T, typename... Args>
+ HOST_DEVICE_INLINE T operator()(T arg, Args&&... args) {
+ return Function::apply(x(arg, args...), y(arg, args...));
}
- std::string to_string() {
+ std::string to_string() const {
return Function::n() + "<" + x.to_string() + "," + y.to_string() + ">";
}
};
-#define UNARY(name, name2, func) \
- namespace elem { \
- struct name { \
- __HDI__ static float apply(float x) { return func; } \
- static std::string n() { return #name; } \
- }; \
- } \
- template <class X> \
- using name = UnaryFunctor<elem::name, X>; \
- template <typename X> \
- static inline name<IsClass<X>> name2(X x) { \
- return name<X>(x); \
- } \
+#define UNARY(name, name2, func) \
+ namespace elem { \
+ struct name { \
+ template <typename ElementType> \
+ HOST_DEVICE_INLINE static ElementType apply(const ElementType& x) { return func; } \
+ static std::string n() { return #name; } \
+ }; \
+ } \
+ template <class X> \
+ using name = UnaryFunctor<elem::name, X>; \
+ template <typename X> \
+ static inline name<IsClass<X>> name2(X x) { \
+ return name<X>(x); \
+ } \
static inline name<Capture> name2(Capture x) { return name<Capture>(x); }
#define BINARY(name, name2, func) \
namespace elem { \
struct name { \
- __HDI__ static float apply(float x, float y) { return func; } \
+ template <typename ElementType> \
+ HOST_DEVICE_INLINE static ElementType apply(const ElementType& x, \
+ const ElementType& y) \
+ { return func; } \
static std::string n() { return #name; } \
}; \
} \
template <class X, class Y> \
using name = BinaryFunctor<elem::name, X, Y>; \
template <class X, class Y> \
- name<IsClass<X>, IsClass<Y>> name2(X x, Y y) { \
+ name<IsClass<X>, IsClass<Y>> name2(const X& x, const Y& y) { \
return name<X, Y>(x, y); \
} \
template <class Y> \
- name<Capture, IsClass<Y>> name2(Capture x, Y y) { \
+ name<Capture, IsClass<Y>> name2(const Capture& x, const Y& y) { \
return name<Capture, Y>(x, y); \
} \
template <class X> \
- name<IsClass<X>, Capture> name2(X x, Capture y) { \
+ name<IsClass<X>, Capture> name2(const X& x, const Capture& y) { \
return name<X, Capture>(x, y); \
}
-UNARY(Tanh, tanh, tanhf(x));
-UNARY(Sin, sin, sinf(x));
-UNARY(Cos, cos, cosf(x));
-UNARY(Tan, tan, tanf(x));
-UNARY(Log, log, logf(x));
-UNARY(Exp, exp, expf(x));
-UNARY(Abs, abs, fabs(x));
-UNARY(Sqrt, sqrt, sqrtf(x));
-UNARY(Neg, operator-, - x);
-UNARY(Sigmoid,
- sigmoid,
- x > 0 ? (1.f / (1.f + expf(-x))) : (expf(x) / (1.f + expf(x))));
-
-BINARY(Plus, operator+, x + y);
-BINARY(Minus, operator-, x - y);
-BINARY(Mult, operator*, x* y);
-BINARY(Div, operator/, x / y);
-
-BINARY(LogAddExp,
- logaddexp,
- (/*if*/ (x < y)
- ? // Note: This may not be ideal for CUDA; cf. CNTK implementation
- (y + log1pf(expf(x - y)))
- /*else*/
- : (x + log1pf(expf(y - x)))));
-BINARY(Maximum,
- max,
- (x > y) ? y : x); // note: std::max not available on CUDA it seems
-BINARY(Minimum, min, (x < y) ? y : x);
-
-UNARY(Negate, operator!, !x);
-BINARY(Eq, operator==, x == y);
-BINARY(NEq, operator!=, x != y);
-BINARY(Gt, operator>, x> y);
-BINARY(Lt, operator<, x<y);
-BINARY(Geq, operator>=, x >= y);
-BINARY(Leq, operator<=, x <= y);
-BINARY(And, operator&&, x&& y);
-BINARY(Or, operator||, x || y);
-
-template <typename T>
-__HDI__ T sgn(T val) {
- return T((0 < val) - (val < 0));
-}
-
-UNARY(Sgn, sgn, sgn(x));
-
-BINARY(Pow, pow, pow(x, y));
-
-BINARY(Clip, clip, fabs(x) >= y ? sgn(x) * y : x);
-
-// derivative of Clip, cut-off function
-BINARY(Bump, bump, fabs(x) >= y ? 0.f : 1.f);
-
-UNARY(sReLU, ReLU, x > 0.f ? x : 0.f);
-UNARY(sReLUBack, ReLUback, x > 0.f ? 1.f : 0.f);
-BINARY(sPReLU, PReLU, x > 0.f ? x : x * y);
-BINARY(sPReLUBack, PReLUback, x > 0.f ? 1.f : y);
-
template <class Function, class X, class Y, class Z>
struct TernaryFunctor {
X x;
@@ -147,7 +90,7 @@ struct TernaryFunctor {
TernaryFunctor(Arg1 arg1, Arg2 arg2, Arg3 arg3) : x(arg1), y(arg2), z(arg3) {}
template <typename... Args>
- __HDI__ float operator()(Args&&... args) {
+ HOST_DEVICE_INLINE float operator()(Args&&... args) {
return Function::apply(x(args...), y(args...), z(args...));
}
};
@@ -155,7 +98,11 @@ struct TernaryFunctor {
#define TERNARY(name, name2, func) \
namespace elem { \
struct name { \
- __HDI__ static float apply(float x, float y, float z) { return func; } \
+ template <typename ElementType> \
+ HOST_DEVICE_INLINE static ElementType apply(ElementType x, \
+ ElementType y, \
+ ElementType z) \
+ { return func; } \
}; \
} \
template <class X, class Y, class Z> \
@@ -185,8 +132,6 @@ struct TernaryFunctor {
return name<Capture, Capture, Z>(x, y, z); \
}
-TERNARY(IfThenElse, if_then_else, x ? y : z);
-
template <class X, class Y>
struct Assign {
X x;
@@ -195,9 +140,13 @@ struct Assign {
template <class Arg1, class Arg2>
Assign(Arg1 arg1, Arg2 arg2) : x(arg1), y(arg2) {}
- template <typename... Args>
- __HDI__ float operator()(Args&&... args) {
- return x(args...) = y(args...);
+ template <typename T, typename... Args>
+ HOST_DEVICE_INLINE T operator()(T&& arg, Args&&... args) {
+ return x(arg, args...) = y(arg, args...);
+ }
+
+ std::string to_string() const {
+ return "Assign<" + x.to_string() + "," + y.to_string() + ">";
}
};
@@ -208,9 +157,9 @@ struct Assignee {
Assignee() {}
Assignee(Var<N> v) : var(v) {}
- template <typename... Args>
- __HDI__ float& operator()(Args&&... args) {
- return var(args...);
+ template <typename T, typename... Args>
+ HOST_DEVICE_INLINE T& operator()(T&& arg, Args&&... args) {
+ return var(arg, args...);
}
template <class X>
@@ -242,7 +191,7 @@ struct Assignee {
return *this = *this / x;
}
- std::string to_string() { return var.to_string(); }
+ std::string to_string() const { return var.to_string(); }
};
/******************************************************************************/
diff --git a/src/functional/shape.h b/src/functional/shape.h
index 144c1458..fd354e1e 100755
--- a/src/functional/shape.h
+++ b/src/functional/shape.h
@@ -13,6 +13,41 @@ namespace functional {
#define CONST_SHAPE_DIMS 4
+// attempts at low-level slicing and proper views, not integrated yet
+#if 0
+const int MAX_INT = std::numeric_limits<int>::max();
+struct Slice {
+ static const int END{MAX_INT}; // fix
+
+ int begin{0};
+ int end{END};
+ int stride{1};
+
+ Slice(int b, int e, int s = 1)
+ : begin(b), end(e), stride(s) {}
+
+ Slice()
+ : begin(0), end(END), stride(1) {}
+
+ Slice(int i)
+ : begin(i), end(i + 1), stride(1) {}
+
+ Slice(const std::initializer_list<int>& l) {
+ std::vector<int> v(l);
+ switch(v.size()) {
+ case 0: begin = 0; end = END; stride = 1; break;
+ case 1: begin = v[0]; end = v[0] + 1; stride = 1; break;
+ case 2: begin = v[0]; end = v[1]; stride = 1; break;
+ case 3: begin = v[0]; end = v[1]; stride = v[2]; break;
+ default:
+ ABORT("Too many elements in slice: {}", v.size());
+ }
+ }
+};
+
+const Slice All;
+#endif
+
/**
* @brief Represents the size of each dimension in a tensor.
*/
@@ -22,19 +57,42 @@ struct ConstantShape {
Array<int, N> shape_;
Array<int, N> stride_;
Array<int, N> bstride_;
+
size_t elements_{1};
+ size_t offset_{0};
- __HD__ ConstantShape() {
+ // @TODO: review all these constructors
+ HOST_DEVICE ConstantShape() {
shape_.fill(1);
stride_.fill(1);
bstride_.fill(0);
}
- __HD__ ConstantShape(const ConstantShape& shape)
+ HOST_DEVICE ConstantShape(const ConstantShape& shape)
: shape_(shape.shape_),
stride_(shape.stride_),
bstride_(shape.bstride_),
- elements_(shape.elements_) {}
+ elements_(shape.elements_),
+ offset_(shape.offset_) {}
+
+ template <size_t M>
+ HOST_DEVICE ConstantShape(const Array<int, M>& shape) {
+ ABORT_IF(M > N, "Recompile with CONST_SHAPE_DIMS >= {}", M);
+
+ std::copy(shape.begin(), shape.end(), shape_.begin() + N - M);
+ if(N - M)
+ std::fill_n(shape_.begin(), N - M, 1);
+
+ updateStrides();
+ updateElements();
+ }
+
+ HOST_DEVICE ConstantShape(const Array<int, N>& shape,
+ const Array<int, N>& stride,
+ size_t offset)
+ : shape_(shape), stride_(stride), offset_(offset) {
+ updateElements();
+ }
ConstantShape(const marian::Shape& shape) {
size_t filled = shape.size();
@@ -45,11 +103,13 @@ struct ConstantShape {
std::copy(shape.begin(), shape.end(), shape_.begin() + N - filled);
if(N - filled)
std::fill_n(shape_.begin(), N - filled, 1);
+
updateStrides();
updateElements();
}
- __HDI__ void updateStrides() {
+ // @TODO: do we need bstrides at all?
+ HOST_DEVICE_INLINE void updateStrides() {
stride_[N - 1] = 1;
bstride_[N - 1] = shape_[N - 1] == 1 ? 0 : stride_[N - 1];
@@ -59,67 +119,180 @@ struct ConstantShape {
}
}
- __HDI__ void updateElements() {
+ HOST_DEVICE_INLINE void updateElements() {
elements_ = 1;
for(int i = 0; i < N; ++i)
elements_ *= shape_[i];
}
- __HDI__ void set(int i, int dim) {
+ HOST_DEVICE_INLINE void set(int i, int dim) {
shape_[i] = dim;
updateStrides();
updateElements();
}
- __HDI__ int dim(int i) { return shape_[i]; }
+ HOST_DEVICE_INLINE const int& dim(int i) const { return shape_[i]; }
- __HDI__ int dim(int i) const {
- return const_cast<ConstantShape&>(*this).dim(i);
- }
+ HOST_DEVICE_INLINE const int& back() const { return dim(N - 1); }
+
+ HOST_DEVICE_INLINE const int& operator[](int i) const { return dim(i); }
- __HDI__ int back() const { return dim(N - 1); }
+ HOST_DEVICE_INLINE const int& stride(int i) const { return stride_[i]; }
- __HDI__ int operator[](int i) { return dim(i); }
+ HOST_DEVICE_INLINE const int& bstride(int i) const { return bstride_[i]; }
- __HDI__ int operator[](int i) const { return dim(i); }
+ HOST_DEVICE_INLINE static constexpr size_t size() { return N; }
- __HDI__ int stride(int i) const { return stride_[i]; }
+ HOST_DEVICE_INLINE int elements() const { return (int)elements_; }
- __HDI__ int bstride(int i) const { return bstride_[i]; }
+ // The following functions iterate over shape dimensions and use recursive
+ // templates. They unroll over a compile-time defined number of dimensions.
- __HDI__ static constexpr size_t size() { return N; }
+ // Struct for recurrent template calls over shape dimensions,
+ // version for K > 0
+ template <const int K, const int D> struct I {
+ HOST_DEVICE_INLINE static int index(const Array<int, D>& dims,
+ const Array<int, D>& stride) {
+ return dims[K] * stride[K] + I<K-1, D>::index(dims, stride);
+ }
- __HDI__ int elements() const { return (int)elements_; }
+ HOST_DEVICE_INLINE static int index(int si,
+ const Array<int, D>& shape,
+ const Array<int, D>& stride) {
+ return (si % shape[K]) * stride[K] + I<K-1, D>::index(si / shape[K], shape, stride);
+ }
- __HDI__ int index(const Array<int, N>& d) const {
- int i = 0;
- for(int j = 0; j < N; ++j)
- i += d[j] * stride_[j];
- return i;
+ HOST_DEVICE_INLINE static void dims(int si,
+ Array<int, D>& dims,
+ const Array<int, D>& shape) {
+ dims[K] = si % shape[K];
+ I<K-1, D>::dims(si / shape[K], dims, shape);
+ }
+
+ };
+
+ // Struct for recurrent template calls over shape dimensions,
+ // specialization for K == 0
+ template <const int D> struct I<0, D> {
+ HOST_DEVICE_INLINE static int index(const Array<int, D>& dims,
+ const Array<int, D>& stride) {
+ return dims[0] * stride[0];
+ }
+
+ HOST_DEVICE_INLINE static int index(int si,
+ const Array<int, D>& shape,
+ const Array<int, D>& stride) {
+ return (si % shape[0]) * stride[0];
+ }
+
+ HOST_DEVICE_INLINE static void dims(int si,
+ Array<int, D>& dims,
+ const Array<int, D>& shape) {
+ dims[0] = si % shape[0];
+ }
+ };
+
+ HOST_DEVICE_INLINE int index(const Array<int, N>& dims) const {
+ return (int)offset_ + I<N-1, N>::index(dims, stride_);
}
- __HDI__ int bindex(const Array<int, N>& d) const {
- int i = 0;
- for(int j = 0; j < N; ++j)
- i += d[j] * bstride_[j];
- return i;
+ HOST_DEVICE_INLINE int index(int si) const {
+ return (int)offset_ + I<N-1, N>::index(si, shape_, stride_);
+ }
+
+ HOST_DEVICE_INLINE void dims(int si, Array<int, N>& dims) const {
+ I<N-1, N>::dims(si, dims, shape_);
}
- __HDI__ void dims(int i, Array<int, N>& d) const {
+ HOST_DEVICE_INLINE int bindex(const Array<int, N>& dims) const {
+ int i = 0;
+ // ?? : return offset_ + I<N-1, N>::index(d, bstride_);
for(int j = 0; j < N; ++j)
- d[j] = (i / stride_[j]) % shape_[j];
+ i += dims[j] * bstride_[j];
+ return i;
}
- __HDI__ bool operator==(const ConstantShape& other) const {
+ // @TODO: should this check all the members?
+ HOST_DEVICE_INLINE bool operator==(const ConstantShape& other) const {
for(int i = 0; i < N; ++i)
if(shape_[i] != other[i])
return false;
return true;
}
- __HDI__ bool operator!=(const ConstantShape& other) const {
+ HOST_DEVICE_INLINE bool operator!=(const ConstantShape& other) const {
return !(*this == other);
}
+
+ std::string toString() const {
+ std::stringstream strm;
+ // @TODO: add more information
+ strm << "shape=" << (*this)[0];
+ for(int i = 1; i < size(); ++i)
+ strm << "x" << (*this)[i];
+ strm << " size=" << elements();
+ return strm.str();
+ }
+
+// @TODO: attempts at proper slicing. Works but not integrated anywhere. To be revisited.
+#if 0
+ // Performs numpy-like slicing on a given shape object. The number
+ // of slices corresponds to the number of dimensions.
+ HOST_DEVICE_INLINE ConstantShape<N> slice(const Array<Slice, N>& slices) {
+ // @TODO: add various checks
+ Array<int, N> offsets;
+ Array<int, N> shape;
+ Array<int, N> stride;
+ for(int i = 0; i < N; ++i) {
+ int beg = slices[i].begin;
+ // restrict maximum value to actual shape size if larger than shape size
+ int end = slices[i].end < shape_[i] ? slices[i].end : shape_[i];
+ int str = slices[i].stride;
+
+ // collect starting points for all coordinates
+ offsets[i] = beg;
+
+ // when calculating the new shape, take into account stride
+ // TODO: std::ceil does not work on the GPU
+ shape[i] = std::ceil((end - beg) / (float) str);
+
+ // new stride is just old stride multiplied by slice stride
+ stride[i] = str * stride_[i];
+ }
+
+ // map offset coordinates into single offset index
+ int offset = index(offsets);
+
+ return ConstantShape<N>(shape, stride, offset);
+ }
+
+// non-continguous slices cannot be reshaped! need to be copied
+// template <const int D>
+// HOST_DEVICE_INLINE ConstantShape<D> reshape(const ConstantShape<D>& other) const {
+// // @TODO: add various checks
+// #ifndef __CUDA__ARCH__
+// ABORT_IF(elements() != other.elements(),
+// "Reshaping operation requires matching number of elements");
+// #endif
+
+// Array<int, D> stride;
+// for(int i = 0; i < D; ++i) {
+// stride[i] = /*other.stride_[i] **/ stride_[i];
+// }
+
+// stride[D - 1] = stride_[N - 1];
+// for(int i = 2; i < D + 1; ++i) {
+// stride[D - i] = stride[D - i + 1] * stride_[N - i + 1] * shape_[D - i + 1];
+// }
+
+// return ConstantShape<D>(other.shape_, stride, offset_);
+// }
+#endif
+
+ friend std::ostream& operator<<(std::ostream& strm, const ConstantShape<N>& shape) {
+ strm << shape.toString();
+ return strm;
+ }
};
typedef ConstantShape<CONST_SHAPE_DIMS> Shape;
diff --git a/src/functional/tensor.h b/src/functional/tensor.h
index 45e2056c..c928ff08 100644..100755
--- a/src/functional/tensor.h
+++ b/src/functional/tensor.h
@@ -7,36 +7,208 @@
namespace marian {
namespace functional {
+// By default for single valued types like float do nothing. Usually the number of elements in a tensor
+// is correctly mirrored in the shape object. Only special multi-element types like float32x4 (4 floats),
+// float32x8 (8 floats) and half2 (2 half) require special handling done by specializations below.
+// Similar for multi-element integer types to be added later.
template <typename T>
-struct Tensor {
+inline marian::Shape adapt(const marian::Shape& shape) {
+ return shape;
+}
+
+// modify last shape dimension to automatically map to a larger stride. We are moving now by 4 floats
+// at once and need to stop earlier. This is a shallow typecast to bascially an array of 4 floats.
+
+#ifndef __CUDACC__ // vectorized types not available from .cu files
+
+template <>
+inline marian::Shape adapt<float32x4>(const marian::Shape& shape) {
+ ABORT_IF(shape[-1] % 4 != 0,
+ "Last dim ({}) is not a multiple of 4 while converting to Tensor<float32x4>",
+ shape[-1]);
+
+ marian::Shape x4Shape = shape;
+ x4Shape.set(-1, shape[-1] / 4);
+ return x4Shape;
+}
+#ifdef __AVX__
+template <>
+inline marian::Shape adapt<float32x8>(const marian::Shape& shape) {
+ ABORT_IF(shape[-1] % 8 != 0,
+ "Last dim ({}) is not a multiple of 8 while converting to Tensor<float32x8>",
+ shape[-1]);
+
+ marian::Shape x8Shape = shape;
+ x8Shape.set(-1, shape[-1] / 8);
+ return x8Shape;
+}
+#endif
+#endif
+
+template <typename T, const int D>
+struct View {
T* data_;
- functional::Shape shape_;
+ ConstantShape<D> shape_;
- __HD__ Tensor() {}
+ HOST_DEVICE View() {}
- __HD__ Tensor(T* ptr, const functional::Shape& shape)
+ HOST_DEVICE View(T* ptr, const ConstantShape<D>& shape)
: data_(ptr), shape_(shape) {}
- __H__ Tensor(marian::Tensor t) : data_(t->data()), shape_(t->shape()) {}
+ HOST View(marian::Tensor t) : data_(t->data<T>()), shape_(adapt<T>(t->shape())) {}
- __HDI__ float& operator[](size_t i) { return data_[i]; }
- __HDI__ const float& operator[](size_t i) const { return data_[i]; }
+ HOST_DEVICE_INLINE T& operator[](size_t i) {
+ return data_[shape_.index((int)i)];
+ }
- __HDI__ float& operator[](
- const functional::Array<int, functional::Shape::size()>& indices) {
- return data_[shape_.index(indices)];
+ HOST_DEVICE_INLINE const T& operator[](size_t i) const {
+ return data_[shape_.index(i)];
}
- __HDI__ const float& operator[](
- const functional::Array<int, functional::Shape::size()>& indices) const {
+ HOST_DEVICE_INLINE T& operator[](const Array<int, D>& indices) {
return data_[shape_.index(indices)];
}
- __HDI__ T* data() { return data_; }
- __HDI__ const T* data() const { return data_; }
+ HOST_DEVICE_INLINE const T& operator[](const Array<int, D>& indices) const {
+ return data_[shape_.index(indices)];
+ }
+
+ HOST_DEVICE_INLINE T* data() { return data_; }
+ HOST_DEVICE_INLINE const T* data() const { return data_; }
+
+ HOST_DEVICE_INLINE ConstantShape<D>& shape() { return shape_; }
+ HOST_DEVICE_INLINE const ConstantShape<D>& shape() const { return shape_; }
+
+ HOST_DEVICE_INLINE size_t size() const { return shape_.elements(); }
+
+ // @TODO: This is code duplication from marian::Tensor
+ std::string debug(int precision = 8, int dispCols = 5) {
+ std::stringstream strm;
+ assert(shape_.size());
+
+ strm << shape_;
+ strm << " type=" << request<T>();
+ strm << " ptr=" << (size_t)data_;
+ strm << std::endl;
+
+ size_t totSize = shape_.elements();
+ std::vector<T> values(totSize);
+ for(int i = 0; i < size(); ++i)
+ values[i] = operator[](i);
+
+ int colWidth = precision + 4;
+ strm << std::fixed << std::setprecision(precision) << std::setfill(' ');
+
+ for(int i = 0; i < values.size(); ++i) {
+ Array<int, D> dims;
+ shape().dims(i, dims);
+
+ bool disp = true;
+ for(int j = 0; j < dims.size(); ++j)
+ disp = disp && (dims[j] < dispCols || dims[j] >= shape()[j] - dispCols);
+
+ if(disp) {
+ if(dims.back() == 0) {
+ bool par = true;
+ std::vector<std::string> p;
+ for(int j = (int)dims.size() - 1; j >= 0; --j) {
+ if(dims[j] != 0)
+ par = false;
- __HDI__ Shape& shape() { return shape_; }
- __HDI__ const Shape& shape() const { return shape_; }
+ p.push_back(par ? "[" : " ");
+ }
+ for(auto it = p.rbegin(); it != p.rend(); ++it)
+ strm << *it;
+ strm << " ";
+ }
+
+ strm << std::setw(colWidth);
+ strm << values[i];
+ strm << " ";
+
+ if(dims.back() + 1 == shape().back()) {
+ for(int j = (int)dims.size() - 1; j >= 0; --j) {
+ if(dims[j] + 1 != shape()[j])
+ break;
+ strm << "]";
+ }
+ strm << std::endl;
+ }
+
+ bool prev = true;
+ for(int j = (int)dims.size() - 1; j >= 0; --j) {
+ if(j < (int)dims.size() - 1)
+ prev = prev && dims[j + 1] + 1 == shape()[j + 1];
+ if(prev && dims[j] + 1 == dispCols && shape()[j] > 2 * dispCols) {
+ if(j < (int)dims.size() - 1)
+ for(int k = 0; k <= j; ++k)
+ strm << " ";
+ strm << "... ";
+ if(j < (int)dims.size() - 1)
+ strm << std::endl;
+ break;
+ }
+ }
+ }
+ }
+ strm << std::endl;
+ return strm.str();
+ }
};
+
+// @TODO: Attempts at correct slicing, not supported anywhere yet.
+#if 0
+template <typename T, const int D>
+HOST_DEVICE_INLINE View<T, D> slice(View<T, D> view, const Array<Slice, D>& slices) {
+ const auto& slicedShape = view.shape().slice(slices);
+ return View<T, D>(view.data(), slicedShape);
+}
+
+// template <typename T, const int D, class ...Slices>
+// View<T, D> slice(View<T, D> view,
+// const Slices&... slices) {
+// return slice(view, {slices...});
+// }
+
+template <typename T>
+HOST_DEVICE_INLINE View<T, 1> slice(View<T, 1>& view,
+ const Slice& slice0) {
+ return slice(view, {slice0});
+}
+
+template <typename T>
+HOST_DEVICE_INLINE View<T, 2> slice(View<T, 2>& view,
+ const Slice& slice0,
+ const Slice& slice1) {
+ return slice(view, {slice0, slice1});
+}
+
+template <typename T>
+HOST_DEVICE_INLINE View<T, 3> slice(View<T, 3>& view,
+ const Slice& slice0,
+ const Slice& slice1,
+ const Slice& slice2) {
+ return slice(view, {slice0, slice1, slice2});
+}
+
+template <typename T>
+HOST_DEVICE_INLINE View<T, 4> slice(View<T, 4>& view,
+ const Slice& slice0,
+ const Slice& slice1,
+ const Slice& slice2,
+ const Slice& slice3) {
+ return slice(view, {slice0, slice1, slice2, slice3});
+}
+
+// template <typename T, const int D1, const int D2>
+// View<T, D2> reshape(View<T, D1>& view, const ConstantShape<D2>& shape) {
+// auto reshaped = view.shape().reshape(shape);
+// return View<T, D2>(view.data(), reshaped);
+// }
+#endif
+
+template <typename T>
+using Tensor = View<T, CONST_SHAPE_DIMS>;
+
} // namespace functional
-} // namespace marian \ No newline at end of file
+} // namespace marian
diff --git a/src/functional/tmp.h b/src/functional/tmp.h
index 720901bc..a83c0ff4 100644..100755
--- a/src/functional/tmp.h
+++ b/src/functional/tmp.h
@@ -1,3 +1,5 @@
+// TMP here stands for Template Meta-Programming
+
#pragma once
#include "functional/array.h"
@@ -7,145 +9,210 @@
namespace marian {
namespace functional {
-template <size_t K, class Functor>
+// This struct and its specializations are never used directly, only through apply and applyWithCast below.
+template <size_t K, class Functor, typename AccType> // K-ary application of Functor, elements are cast to AccType before application of Functor
struct FApply {};
-template <class Functor>
-struct FApply<1, Functor> {
- __HDI__ static float apply(
+template <class Functor, typename AccType>
+struct FApply<1, Functor, AccType> {
+ template <typename ElementType>
+ HOST_DEVICE_INLINE static AccType apply(
Functor functor,
- functional::Array<functional::Tensor<float>, 1>& in,
+ functional::Array<functional::Tensor<ElementType>, 1>& in,
const functional::Array<int, 1>& indices) {
- return functor(in[0][indices[0]]);
+ return functor((AccType)in[0].data()[indices[0]]); // indices is an array of offsets into multiple tensors, index[i] corresponds in[i] based on up to arity K
}
- __HDI__ static float apply(
+ template <typename ElementType>
+ HOST_DEVICE_INLINE static AccType apply(
Functor functor,
- functional::Array<functional::Tensor<float>, 1>& in,
+ functional::Array<functional::Tensor<ElementType>, 1>& in,
int index) {
- return functor(in[0][index]);
+ return functor((AccType)in[0].data()[index]);
}
};
-template <class Functor>
-struct FApply<2, Functor> {
- __HDI__ static float apply(
+template <class Functor, typename AccType>
+struct FApply<2, Functor, AccType> {
+ template <typename ElementType>
+ HOST_DEVICE_INLINE static AccType apply(
Functor functor,
- functional::Array<functional::Tensor<float>, 2>& in,
+ functional::Array<functional::Tensor<ElementType>, 2>& in,
const functional::Array<int, 2>& indices) {
- return functor(in[0][indices[0]], in[1][indices[1]]);
+ return functor((AccType)in[0].data()[indices[0]],
+ (AccType)in[1].data()[indices[1]]);
}
- __HDI__ static float apply(
+ template <typename ElementType>
+ HOST_DEVICE_INLINE static AccType apply(
Functor functor,
- functional::Array<functional::Tensor<float>, 2>& in,
+ functional::Array<functional::Tensor<ElementType>, 2>& in,
int index) {
- return functor(in[0][index], in[1][index]);
+ return functor((AccType)in[0].data()[index],
+ (AccType)in[1].data()[index]);
}
};
-template <class Functor>
-struct FApply<3, Functor> {
- __HDI__ static float apply(
+template <class Functor, typename AccType>
+struct FApply<3, Functor, AccType> {
+ template <typename ElementType>
+ HOST_DEVICE_INLINE static AccType apply(
Functor functor,
- functional::Array<functional::Tensor<float>, 3>& in,
+ functional::Array<functional::Tensor<ElementType>, 3>& in,
const functional::Array<int, 3>& indices) {
- return functor(in[0][indices[0]], in[1][indices[1]], in[2][indices[2]]);
+ return functor((AccType)in[0].data()[indices[0]],
+ (AccType)in[1].data()[indices[1]],
+ (AccType)in[2].data()[indices[2]]);
}
- __HDI__ static float apply(
+ template <typename ElementType>
+ HOST_DEVICE_INLINE static AccType apply(
Functor functor,
- functional::Array<functional::Tensor<float>, 3>& in,
+ functional::Array<functional::Tensor<ElementType>, 3>& in,
int index) {
- return functor(in[0][index], in[1][index], in[2][index]);
+ return functor((AccType)in[0].data()[index],
+ (AccType)in[1].data()[index],
+ (AccType)in[2].data()[index]);
}
};
-template <class Functor>
-struct FApply<4, Functor> {
- __HDI__ static float apply(
+template <class Functor, typename AccType>
+struct FApply<4, Functor, AccType> {
+ template <typename ElementType>
+ HOST_DEVICE_INLINE static AccType apply(
Functor functor,
- functional::Array<functional::Tensor<float>, 4>& in,
+ functional::Array<functional::Tensor<ElementType>, 4>& in,
const functional::Array<int, 4>& indices) {
- return functor(in[0][indices[0]],
- in[1][indices[1]],
- in[2][indices[2]],
- in[3][indices[3]]);
+ return functor((AccType)in[0].data()[indices[0]],
+ (AccType)in[1].data()[indices[1]],
+ (AccType)in[2].data()[indices[2]],
+ (AccType)in[3].data()[indices[3]]);
+ }
+
+ template <typename ElementType>
+ HOST_DEVICE_INLINE static AccType apply(
+ Functor functor,
+ functional::Array<functional::Tensor<ElementType>, 4>& in,
+ int index) {
+ return functor((AccType)in[0].data()[index],
+ (AccType)in[1].data()[index],
+ (AccType)in[2].data()[index],
+ (AccType)in[3].data()[index]);
}
+};
- __HDI__ static float apply(
+template <class Functor, typename AccType>
+struct FApply<5, Functor, AccType> {
+ template <typename ElementType>
+ HOST_DEVICE_INLINE static AccType apply(
Functor functor,
- functional::Array<functional::Tensor<float>, 4>& in,
+ functional::Array<functional::Tensor<ElementType>, 5>& in,
+ const functional::Array<int, 5>& indices) {
+ return functor((AccType)in[0].data()[indices[0]],
+ (AccType)in[1].data()[indices[1]],
+ (AccType)in[2].data()[indices[2]],
+ (AccType)in[3].data()[indices[3]],
+ (AccType)in[4].data()[indices[4]]);
+ }
+
+ template <typename ElementType>
+ HOST_DEVICE_INLINE static AccType apply(
+ Functor functor,
+ functional::Array<functional::Tensor<ElementType>, 5>& in,
int index) {
- return functor(in[0][index], in[1][index], in[2][index], in[3][index]);
+ return functor((AccType)in[0].data()[index],
+ (AccType)in[1].data()[index],
+ (AccType)in[2].data()[index],
+ (AccType)in[3].data()[index],
+ (AccType)in[4].data()[index]);
}
};
-template <size_t K, class Functor>
-__HDI__ float apply(Functor functor,
- functional::Array<functional::Tensor<float>, K>& in,
+/******************************************************************************/
+// Applying functor to sets of K tensors
+template <typename ElementType, size_t K, class Functor>
+HOST_DEVICE_INLINE ElementType apply(Functor functor,
+ functional::Array<functional::Tensor<ElementType>, K>& in,
const functional::Array<int, K>& indices) {
- return FApply<K, Functor>::apply(functor, in, indices);
+ return FApply<K, Functor, ElementType>::apply(functor, in, indices); // functor is applied to same type as input ElementType, no casting required
}
-template <size_t K, class Functor>
-__HDI__ float apply(Functor functor,
- functional::Array<functional::Tensor<float>, K>& in,
+template <typename ElementType, size_t K, class Functor>
+HOST_DEVICE_INLINE ElementType apply(Functor functor,
+ functional::Array<functional::Tensor<ElementType>, K>& in,
int index) {
- return FApply<K, Functor>::apply(functor, in, index);
+ return FApply<K, Functor, ElementType>::apply(functor, in, index); // functor is applied to same type as input ElementType, no casting required
+}
+
+template <typename AccType, typename ElementType, size_t K, class Functor>
+HOST_DEVICE_INLINE AccType applyWithCast(Functor functor,
+ functional::Array<functional::Tensor<ElementType>, K>& in,
+ const functional::Array<int, K>& indices) {
+ return FApply<K, Functor, AccType>::apply(functor, in, indices); // ElementType and AccType are potentially different, cast to AccType before applying functor.
+ // This is useful when accumulating e.g. 16-bit into 32-bit and we want to case to 32-bit before
+ // the functor is applied. L2-Norm is a good use-case since the square can be large.
+}
+
+template <typename AccType, typename ElementType, size_t K, class Functor>
+HOST_DEVICE_INLINE AccType applyWithCast(Functor functor,
+ functional::Array<functional::Tensor<ElementType>, K>& in,
+ int index) {
+ return FApply<K, Functor, AccType>::apply(functor, in, index); // ElementType and AccType are potentially different, cast to AccType before applying functor
}
/******************************************************************************/
+// @TODO: Rename this. It is a reduction loop.
template <size_t n, size_t N, size_t K>
struct Loop {
- template <class Functor>
- __HDI__ static float result(
- Functor functor,
- functional::Array<functional::Tensor<float>, K>& in,
+ template <class Functor, class AggFunctor, typename ElementType, typename AccType>
+ HOST_DEVICE_INLINE static AccType result(
+ Functor functor, AccType aggInit, AggFunctor aggFunctor,
+ functional::Array<functional::Tensor<ElementType>, K>& in,
const functional::Array<int, K>& pAcc,
const functional::Array<int, N>& length,
const functional::Array<int, N>& dim) {
- float sum = 0;
+ AccType agg = aggInit;
functional::Array<int, K> acc;
for(int i = 0; i < length[N - n]; ++i) {
for(size_t j = 0; j < K; ++j) {
acc[j] = pAcc[j] + (dim[N - n] + i) * in[j].shape().bstride(N - n);
}
- sum += Loop<n - 1, N, K>::result(functor, in, acc, length, dim);
+ agg = aggFunctor(agg, Loop<n - 1, N, K>::result(functor, aggInit, aggFunctor, in, acc, length, dim));
}
- return sum;
+ return agg;
}
};
template <size_t N, size_t K>
struct Loop<1, N, K> {
- template <class Functor>
- __HDI__ static float result(
- Functor functor,
- functional::Array<functional::Tensor<float>, K>& in,
+ template <class Functor, class AggFunctor, typename ElementType, typename AccType>
+ HOST_DEVICE_INLINE static AccType result(
+ Functor functor, AccType aggInit, AggFunctor aggFunctor,
+ functional::Array<functional::Tensor<ElementType>, K>& in,
const functional::Array<int, K>& pAcc,
const functional::Array<int, N>& length,
const functional::Array<int, N>& dim) {
- float sum = 0;
+ AccType agg = aggInit;
functional::Array<int, K> acc;
for(int i = 0; i < length[N - 1]; ++i) {
for(size_t j = 0; j < K; ++j) {
acc[j] = pAcc[j] + (dim[N - 1] + i) * in[j].shape().bstride(N - 1);
}
- sum += apply<K>(functor, in, acc);
+ agg = aggFunctor(agg, applyWithCast<AccType>(functor, in, acc));
}
- return sum;
+ return agg;
}
};
-template <size_t N, size_t K, class Functor>
-__HDI__ float loops(Functor functor,
- functional::Array<functional::Tensor<float>, K>& in,
+
+template <size_t N, size_t K, class Functor, class AggFunctor, typename ElementType, typename AccType>
+HOST_DEVICE_INLINE AccType loops(Functor functor, AccType aggInit, AggFunctor aggFunctor,
+ functional::Array<functional::Tensor<ElementType>, K>& in,
const functional::Array<int, N>& length,
const functional::Array<int, N>& dim) {
functional::Array<int, K> acc = {0};
- return Loop<N, N, K>::result(functor, in, acc, length, dim);
+ return Loop<N, N, K>::result(functor, aggInit, aggFunctor, in, acc, length, dim);
}
} // namespace functional
} // namespace marian
diff --git a/src/graph/auto_tuner.h b/src/graph/auto_tuner.h
index 7bf80c79..01f33085 100644..100755
--- a/src/graph/auto_tuner.h
+++ b/src/graph/auto_tuner.h
@@ -20,15 +20,25 @@ class AutoTuner : public AutoTunerRecorder {
private:
typedef std::function<Return(Args...)> Algorithm;
- const size_t max = 100;
-
+ // When the autotuner decides the fastest algorithm for a specific tensor operation (e.g. GEMM),
+ // the autotuner runs each algorithm at least this 'collectStatMax' number of times and
+ // collects the statistics.
+ const size_t collectStatMax = 50;
UPtr<timer::CPUTimer> timer_;
+ // This structure holds a hash key an algorithm function (e.g. int16, packed gemm, mkl gemm)
+ // for a specific operation size
+ // hash: a unique hash key for each operation size
+ // (e.g. m, n, k, transpose A, transpose B, bias size for GEMM)
+ // algorithm: a function that holds an algorithm
struct HashedAlgorithm {
size_t hash;
Algorithm algorithm;
};
+ // This structure represents the collected statistics.
+ // time: total accumulated time of this operator execution with the given algorithm
+ // runs: total time this algorithm was executed
struct Stat {
double time;
size_t runs;
@@ -53,7 +63,7 @@ private:
auto& stat = it->second;
// collect more stats
- if(stat.runs < max)
+ if(stat.runs < collectStatMax)
return i;
if(stat.time < bestTime) {
@@ -88,17 +98,16 @@ public:
if(stop && done_.count(hash) == 0) {
timer_->stop();
- typedef std::chrono::duration<double> sec;
- sec seconds = std::chrono::nanoseconds(timer_->elapsed().user);
+ auto seconds = timer_->elapsed();
auto it = stats_.find(hash);
if(it != stats_.end()) {
- if(it->second.runs < max) {
- it->second.time += seconds.count();
+ if(it->second.runs < collectStatMax) {
+ it->second.time += seconds;
it->second.runs += 1;
}
} else {
- stats_.emplace(hash, Stat({seconds.count(), 1}));
+ stats_.emplace(hash, Stat({seconds, 1}));
}
timer_.reset(nullptr);
diff --git a/src/graph/chainable.h b/src/graph/chainable.h
index 2679843e..b78eb485 100755..100644
--- a/src/graph/chainable.h
+++ b/src/graph/chainable.h
@@ -4,6 +4,7 @@
#include <memory>
#include <vector>
+#include <list>
namespace marian {
@@ -18,8 +19,8 @@ class Chainable;
* A convenience type to represent a shared pointer to a Chainable<Tensor>
* object.
*/
-typedef Ptr<Chainable<Tensor>> Expr;
-typedef Weak<Chainable<Tensor>> WExpr;
+typedef IPtr<Chainable<Tensor>> Expr;
+typedef IWeak<Chainable<Tensor>> WExpr;
class ExpressionGraph;
@@ -50,6 +51,9 @@ class ExpressionGraph;
*/
template <class DataType>
class Chainable {
+private:
+ ENABLE_INTRUSIVE_PTR(Chainable<DataType>)
+
public:
Chainable() {}
virtual ~Chainable(){};
@@ -103,5 +107,10 @@ public:
virtual bool equal(Expr) = 0;
virtual void record(Ptr<AutoTunerRecorder>, size_t, bool) = 0;
+
+ virtual void markCheckpoint() = 0;
+ virtual bool isCheckpoint() const = 0;
+ virtual void setSubtape(Ptr<std::list<Expr>>) = 0;
+ virtual Ptr<std::list<Expr>> getSubtape() = 0;
};
} // namespace marian
diff --git a/src/graph/expression_graph.cpp b/src/graph/expression_graph.cpp
index b3c237d1..ed8fef6d 100755
--- a/src/graph/expression_graph.cpp
+++ b/src/graph/expression_graph.cpp
@@ -1,18 +1,21 @@
#include "graph/expression_graph.h"
-#include <sstream>
-
#include "tensors/tensor_operators.h"
+#include <sstream>
+
namespace marian {
-ExpressionGraph::ExpressionGraph(bool inference, bool optimized)
- : inferenceOnly_(inference), optimized_(optimized), backend_(nullptr) {}
+ExpressionGraph::ExpressionGraph(bool inference)
+ : inferenceOnly_(inference),
+ backend_(nullptr) {}
void ExpressionGraph::setDevice(DeviceId deviceId, Ptr<Device> device) {
if(!backend_) {
backend_ = BackendByDeviceId(deviceId, Config::seed);
- params_ = New<Parameters>();
- params_->init(backend_);
+ auto params = New<Parameters>(defaultElementType_);
+ params->init(backend_);
+ paramsByElementType_[defaultElementType_] = params;
+
if(device)
tensors_ = New<Tensors>(backend_, device);
else
@@ -20,45 +23,244 @@ void ExpressionGraph::setDevice(DeviceId deviceId, Ptr<Device> device) {
}
}
-Expr ExpressionGraph::dropout(float prob, const Shape& shape) {
- return constant(shape, inits::dropout(prob));
+Expr ExpressionGraph::add(Expr node) {
+ auto found = tensors_->findOrRemember(node);
+ if(found) {
+ return found;
+ } else {
+ node->setId(count_++);
+
+ // record in foward graph
+ nodesForward_.push_back(node);
+
+ // record in backward graph if training, and keep track of roots
+ if(!inferenceOnly_ && node->trainable()) {
+ nodesBackward_.push_back(node);
+ topNodes_.insert(node); // opportunistically record all new nodes as roots (gets removed once consumed)
+ }
+
+ if(topNodes_.count(node)) // only erase children of nodes with are themselves in the topNodes list
+ for(auto child : node->children())
+ topNodes_.erase(child); // this child is consumed and therefore not a root
+
+ return node;
+ }
}
-void ExpressionGraph::checkNan(Tensor t) {
- ABORT_IF(throwNaN_, "Not implemented"); t;
- // ABORT_IF(throwNaN_ && IsNan(t), "Tensor has NaN");
+// Call on every checkpoint in backwards order
+void createSubtape(Expr node) {
+ auto subtape = New<std::list<Expr>>();
+
+ for(auto child : node->children()) {
+ if(child->isCheckpoint()) {
+ /* do not descend */
+ } else {
+ if(child->getSubtape()) {
+ /* already visited */
+ } else {
+ createSubtape(child);
+ subtape->splice(subtape->end(), *(child->getSubtape()));
+ }
+ }
+ }
+
+ if(!node->isCheckpoint())
+ subtape->push_back(node);
+
+ node->setSubtape(subtape);
+}
+
+void ExpressionGraph::forwardNext() {
+ // @TODO: check if allocation works properly
+ tensors_->clearShorttermMemory();
+
+ if(checkpointing_) {
+ for(auto top : topNodes_)
+ top->markCheckpoint();
+
+ auto it = nodesBackward_.rbegin();
+ while(it != nodesBackward_.rend()) {
+ auto v = *it;
+ if(v->isCheckpoint())
+ createSubtape(v);
+ it++;
+ }
+
+ // To avoid recomputation of range from last checkpoint to the top,
+ // turn all nodes on last subtape into checkpoints and clear subtape.
+ // @TODO: put this into special backprob function? Needs to know that we are done with adding nodes
+ for(auto top : topNodes_) {
+ if(top->getSubtape()) {
+ for(auto& node : *top->getSubtape())
+ node->markCheckpoint();
+ top->getSubtape()->clear();
+ }
+ }
+ }
+
+ forward(nodesForward_, /*finalPass=*/!checkpointing_); // if checkPointing, this is not final
+}
+
+void ExpressionGraph::forward(std::list<Expr>& forwardTape, bool finalPass) {
+ while(!forwardTape.empty()) {
+ auto v = forwardTape.front();
+
+ v->allocate();
+ v->init();
+
+ for(auto& child : v->children())
+ ABORT_IF(!child->val(), "De-allocated child {} {} of {} {}", child->getId(), child->type(), v->getId(), v->type());
+
+ v->forward();
+
+ if(v->trainable() && throwNaN_) {
+ bool isNaN = false, isInf = false;
+ checkNaN(v->val(), isNaN, isInf);
+ if(isNaN || isInf) {
+ LOG(critical, "Detected NaN ({}) or Inf ({}) in value (forward pass)", isNaN, isInf);
+ LOG(critical, "\tType: {}, Shape: {}, Name: {}, Id: {}, Hash: {}",
+ v->type(), v->shape(), v->name(), v->getId(), v->hash());
+ LOG(critical, "Children: {}", v->children().size());
+ for(auto&& child : v->children()) {
+ LOG(critical, "\tType: {}, Shape: {}, Name: {}, Id: {}, Hash: {}",
+ child->type(), child->shape(), child->name(), child->getId(), child->hash());
+ }
+ }
+ }
+
+ if(v->marked_for_debug()) {
+ Logger log = spdlog::get("general");
+ if(log) {
+ LOG(info, "Debug: {} op={}", v->debug_message(), v->type());
+ LOG(info, v->val()->debug());
+ }
+ else {
+ std::cerr << "Debug: " << v->debug_message() << " op=" << v->type() << std::endl;
+ std::cerr << v->val()->debug() << std::endl;
+ }
+ }
+
+ if(inferenceOnly_)
+ v->children().clear();
+
+ if(checkpointing_ && !finalPass) {
+ auto subtape = v->getSubtape();
+ if(subtape) {
+ for(auto& node : *subtape) {
+ node->free();
+ }
+ }
+ }
+
+ forwardTape.pop_front();
+ }
}
-void ExpressionGraph::save(std::vector<io::Item>& ioItems) {
- for(auto p : params()->getMap()) {
- std::string pName = p.first;
+void ExpressionGraph::backward(bool reset, float clipValue) {
+ if(topNodes_.size() > 1) {
+ LOG(critical, "There are more ({}) than one top most nodes for backward pass:", topNodes_.size());
+ for(auto node : topNodes_) {
+ LOG(critical,
+ "\tType: {}, Shape: {}, Name: {}, Id: {}, Hash: {}",
+ node->type(),
+ node->shape(),
+ node->name(),
+ node->getId(),
+ node->hash());
+ }
+ ABORT("Aborting");
+ }
+
+ for(auto kvParams : paramsByElementType_) {
+ kvParams.second->allocateBackward();
+ if(reset)
+ kvParams.second->set_zero_adjoint();
+ }
+
+ for(auto&& v : topNodes_)
+ v->init_dependent();
+
+ topNodes_.clear();
+
+ tensors_->clearShorttermMemory();
+
+ bool firstNaN = true;
+ while(!nodesBackward_.empty()) {
+ auto v = nodesBackward_.back();
+ nodesBackward_.pop_back();
+
+ for(auto&& child : v->children())
+ if(child->trainable() && child->type() != "param")
+ child->set_zero_adjoint();
- if(!namespace_.empty()) {
- if(pName.substr(0, namespace_.size() + 2) == namespace_ + "::")
- pName = pName.substr(namespace_.size() + 2);
+ if(checkpointing_ && v->getSubtape()) {
+ forward(*v->getSubtape(), /*finalPass=*/true);
}
- ABORT_IF(p.second->val()->type() != Type::float32,
- "Only float32 supported at the moment");
+ if(v->trainable() && v->marked_for_debug()) {
+ LOG(info, "Debug Grad: {} op={}", v->debug_message(), v->type());
+ LOG(info, v->grad()->debug());
+ }
+
+ if(v->trainable() && clipValue != 0) {
+ using namespace functional;
+ Element(_1 = clip(_1, clipValue), v->grad());
+ }
+
+ if(v->trainable())
+ v->backward();
- Tensor val = p.second->val();
+ if(throwNaN_ && firstNaN) {
+ for(auto&& child : v->children()) {
+ if(child->trainable()) {
+ bool isNaN = false, isInf = false;
+ checkNaN(child->grad(), isNaN, isInf);
+ if(isNaN) {
+ LOG(critical, "Detected NaN ({}) or Inf ({}) in gradient (backward pass) of child node", isNaN, isInf);
+ LOG(critical, "Child - Type: {}, Shape: {}, Name: {}, Id: {}, Hash: {}",
+ child->type(), child->shape(), child->name(), child->getId(), child->hash());
+ LOG(critical, "Parent - Type: {}, Shape: {}, Name: {}, Id: {}, Hash: {}",
+ v->type(), v->shape(), v->name(), v->getId(), v->hash());
+ firstNaN = false;
+ }
+ }
+ }
+ }
- io::Item item;
- item.name = pName;
- item.shape = val->shape();
- item.type = val->type();
+ v->children().clear();
+ }
+}
- // Use the actual memory as this will be aligned and padded.
- // When memory mapping this is required. Shape keeps track of
- // tensor size. Saving to *.npz will cut to size.
- auto mem = val->memory();
- item.bytes.resize(mem->size());
- copy(backend_,
- mem->data<char>(),
- mem->data<char>() + mem->size(),
- item.bytes.data());
+Expr ExpressionGraph::dropoutMask(float prob, const Shape& shape, Type valueType) {
+ return constant(shape, inits::dropout(prob), valueType);
+}
- ioItems.emplace_back(std::move(item));
+Expr ExpressionGraph::dropoutMask(float prob, const Shape& shape) {
+ return constant(shape, inits::dropout(prob), defaultElementType_);
+}
+
+void ExpressionGraph::checkNaN(Tensor t, bool& isNaN, bool& isInf) {
+ IsNaN(t, allocator(), isNaN, isInf);
+}
+
+void ExpressionGraph::save(std::vector<io::Item>& ioItems, Type saveElementType) {
+ // sorted by type in std::map
+ for(auto kvParams : paramsByElementType_) {
+ // sorted by name in std::map
+ for(auto p : kvParams.second->getMap()) {
+ std::string pName = p.first;
+
+ if(!namespace_.empty()) {
+ if(pName.substr(0, namespace_.size() + 2) == namespace_ + "::")
+ pName = pName.substr(namespace_.size() + 2);
+ }
+
+ Tensor val = p.second->val();
+ io::Item item;
+ val->get(item, pName);
+ item.convert(saveElementType);
+ ioItems.emplace_back(std::move(item));
+ }
}
}
diff --git a/src/graph/expression_graph.h b/src/graph/expression_graph.h
index 815d9080..f7a78126 100755
--- a/src/graph/expression_graph.h
+++ b/src/graph/expression_graph.h
@@ -63,7 +63,7 @@ public:
tensors_->allocate(node->grad(), node->shape(), node->value_type());
}
- void free(Tensor& tensor) { tensors_->free(tensor); }
+ void free(const Tensor& tensor) { tensors_->free(tensor); }
// @TODO: get rid of this, not really used or can be done better
Ptr<Allocator> allocator() { return tensors_->allocator(); }
@@ -90,14 +90,13 @@ public:
auto it = shortterm_->find(hash);
if(it != shortterm_->end()) {
- for(auto foundWeak : it->second) {
- auto found = foundWeak.lock();
+ for(auto found : it->second) {
if(node->equal(found)) {
return found;
}
}
}
- (*shortterm_)[hash].push_back(node);
+ (*shortterm_)[hash].push_back(node.get()); // weakPtr
return nullptr;
}
@@ -111,8 +110,9 @@ public:
void clearLongtermMemory() { longterm_->clear(); }
};
+typedef std::map<Type, Ptr<Parameters>> ElementTypeParamsMap; // keep it sorted, hence map not unordered map
+
class ExpressionGraph : public std::enable_shared_from_this<ExpressionGraph> {
-private:
size_t count_{0};
std::list<Expr> nodesForward_;
@@ -120,21 +120,19 @@ private:
std::unordered_set<Expr> topNodes_; // current set of roots. In the end, all but one must have been consumed.
- // Holds memory and expressions that correspond to graph parameters
- Ptr<Parameters> params_;
-
// Holds memory and expressions that correspond to temporary expressions.
// This gets cleared before a new graph is built.
Ptr<Tensors> tensors_;
std::unordered_map<size_t, std::vector<Expr>> memoized_;
+ Type defaultElementType_{Type::float32}; // Type used for storing parameters, currently all parameters have to have the same type
+
bool inferenceOnly_{false};
- bool optimized_{false};
- Ptr<Backend> backend_;
+
+ bool checkpointing_{false}; // use gradient checkpointing if true
bool reloaded_{false};
- std::string namespace_;
bool throwNaN_{false};
@@ -143,40 +141,51 @@ protected:
ExpressionGraph(const ExpressionGraph&) = delete;
ExpressionGraph(ExpressionGraph&&) = delete;
+ // Holds memory and expressions that correspond to graph parameters
+ // Now we can have multiple types of parameters in a separate parameters object per value type.
+ // This is currently only accessible through private functions during loading, will abort during training
+ // when params() is called (e.g. optimizer) and there is more or other types than the default parameter type.
+ // Currently the only usecase is inference. Trying to access params() for non-default parameter type is going
+ // to abort. Inference does not need to access a whole set of parameters.
+ ElementTypeParamsMap paramsByElementType_;
+ Ptr<Backend> backend_;
+
+ std::string namespace_;
+
public:
/** @brief Constructs a new expression graph
*
* Constructor should be used as New<ExpressionGraph>()
*/
- ExpressionGraph(bool inference = false, bool optimized = false);
-
- void setInference(bool inference) { inferenceOnly_ = inference; }
- bool isInference() { return inferenceOnly_; }
+ ExpressionGraph(bool inference = false);
- ~ExpressionGraph() {
+ virtual ~ExpressionGraph() {
clear();
- params_->clear();
+ for(auto kvParams : paramsByElementType_)
+ kvParams.second->clear();
}
- void setDevice(DeviceId deviceId = {0, DeviceType::gpu},
- Ptr<Device> device = nullptr);
+ virtual void setDevice(DeviceId deviceId = {0, DeviceType::gpu},
+ Ptr<Device> device = nullptr);
DeviceId getDeviceId() { return backend_->getDeviceId(); }
Ptr<Backend> getBackend() { return backend_; }
- void setOptimized(bool optimized) { optimized_ = optimized; }
- bool isOptimized() { return (optimized_ && inferenceOnly_); }
+ void setInference(bool inference) { inferenceOnly_ = inference; }
+ bool isInference() { return inferenceOnly_; }
+
+ void setCheckpointing(bool checkpointing) { checkpointing_ = checkpointing; }
+ bool isCheckpointing() { return checkpointing_; }
void switchParams(const std::string& newNamespace) {
namespace_ = newNamespace;
}
- void copyParams(Ptr<ExpressionGraph> graph) {
+ virtual void copyParams(Ptr<ExpressionGraph> graph) {
for(auto p : *graph->params())
- param(p->name(), p->shape(), inits::dummy);
- params()->allocateForward();
- params()->vals()->copyFrom(graph->params()->vals());
+ param(p->name(), p->shape(), inits::fromTensor(p->val()), p->value_type());
+ forward(); // this will allocate parameters, execute the intializers and therefore copy parameter values
}
void reserveWorkspaceMB(size_t num) {
@@ -212,86 +221,18 @@ public:
return true;
}
+ void checkNaN(Tensor t, bool& isNaN, bool& isInf);
+
void forward() {
- params_->allocateForward();
+ for(auto kvParams : paramsByElementType_)
+ kvParams.second->allocateForward();
forwardNext();
}
- void checkNan(Tensor t);
-
- void forwardNext() {
- // @TODO: check if allocation works properly
- tensors_->clearShorttermMemory();
-
- while(!nodesForward_.empty()) {
- auto v = nodesForward_.front();
- v->allocate();
- v->init();
- v->forward();
-
- checkNan(v->val());
-
- if(v->marked_for_debug()) {
- std::cerr << "Debug: " << v->debug_message() << " op=" << v->type()
- << std::endl;
- std::cerr << v->val()->debug() << std::endl;
- }
-
- if(inferenceOnly_)
- v->children().clear();
- nodesForward_.pop_front();
- }
- }
-
- void backward(bool zero = true) {
- if(topNodes_.size() > 1) {
- LOG(critical, "There are more ({}) than one top most node for backward step:", topNodes_.size());
- for(auto node : topNodes_) {
- LOG(critical,
- "\tType: {}, Shape: {}, Name: {}, Id: {}, Hash: {}",
- node->type(),
- node->shape(),
- node->name(),
- node->getId(),
- node->hash());
- }
- ABORT("Aborting");
- }
-
- params_->allocateBackward();
- if(zero)
- params_->set_zero_adjoint();
-
- for(auto&& v : topNodes_)
- v->init_dependent();
-
- // named_.clear();
- topNodes_.clear();
-
- tensors_->clearShorttermMemory();
+ void forwardNext();
+ void forward(std::list<Expr>& forwardTape, bool finalPass);
- while(!nodesBackward_.empty()) {
- auto v = nodesBackward_.back();
- nodesBackward_.pop_back();
-
- for(auto&& child : v->children()) {
- if(child->trainable() && child->type() != "param")
- child->set_zero_adjoint();
- }
-
- if(v->trainable())
- v->backward();
-
- checkNan(v->grad());
-
- if(v->trainable() && v->marked_for_debug()) {
- std::cerr << "Debug Grad: " << v->debug_message() << std::endl;
- std::cerr << v->grad()->debug() << std::endl;
- }
-
- v->children().clear();
- }
- }
+ void backward(bool reset = true, float clipValue = 0.f);
std::string graphviz() {
std::stringstream ss;
@@ -316,50 +257,118 @@ public:
dot.close();
}
+private:
+
+ // Find the named parameter and its typed parent parameter object (params) and return both.
+ // If the parameter is not found return the parent parameter object that the parameter should be added to.
+ // Return [nullptr, nullptr] if no matching parent parameter object exists.
+ std::tuple<Expr, Ptr<Parameters>> findParams(const std::string& name,
+ Type elementType,
+ bool typeSpecified) const {
+ Expr p; Ptr<Parameters> params;
+ if(typeSpecified) { // type has been specified, so we are only allowed to look for a parameter with that type
+ auto it = paramsByElementType_.find(elementType);
+ if(it != paramsByElementType_.end()) {
+ params = it->second;
+ p = params->get(name);
+ }
+ } else { // type has not been specified, so we take any type as long as the name matches
+ for(auto kvParams : paramsByElementType_) {
+ p = kvParams.second->get(name);
+
+ if(p) { // p has been found, return with matching params object
+ params = kvParams.second;
+ break;
+ }
+
+ if(kvParams.first == elementType) // even if p has not been found, set the params object to be returned
+ params = kvParams.second;
+ }
+ }
+
+ return std::make_tuple(p, params);
+ }
+
Expr param(const std::string& pname,
const Shape& shape,
- const NodeInitializer& init,
- bool fixed = false) {
+ const Ptr<inits::NodeInitializer>& init,
+ const Type elementType,
+ bool fixed,
+ bool typeSpecified) {
std::string name = pname;
if(!namespace_.empty())
name = namespace_ + "::" + name;
- // check first if parameter already exists
- auto p = params_->get(name);
- if(p) {
- // if yes add to tape and return
- ABORT_IF(shape != p->shape(),
- "Requested shape {} for existing parameter '{}' does not match "
- "original shape {}",
- shape,
- name,
- p->shape());
-
- p->setTrainable(!fixed);
- add(p);
- return p;
+ Expr p; Ptr<Parameters> params; std::tie
+ (p, params) = findParams(name, elementType, typeSpecified);
+
+ if(!params) {
+ params = New<Parameters>(elementType);
+ params->init(backend_);
+ paramsByElementType_.insert({elementType, params});
+ } else {
+ if(p) {
+ // if yes add to tape and return
+ ABORT_IF(shape != p->shape(),
+ "Requested shape {} for existing parameter '{}' does not match "
+ "original shape {}",
+ shape,
+ name,
+ p->shape());
+
+ p->setTrainable(!fixed);
+ add(p);
+ return p;
+ }
}
// if graph was reloaded do not allow creation of new parameters
ABORT_IF(reloaded_,
- "Graph was reloaded and parameter '{}' is newly created",
- name);
+ "Graph was reloaded and parameter '{}' with type {} (specified: {}) is newly created",
+ name, elementType, typeSpecified);
// if not check if name is not taken by other node
- ABORT_IF(get(name), "Non-parameter with name '{}' already exists", name);
+ auto other = get(name);
+ ABORT_IF(other, "Parameter with name '{}' already exists and has type {}", name, other->value_type());
// create parameter node (adds to tape)
- p = Expression<ParamNode>(shared_from_this(), shape, init, fixed);
+ p = Expression<ParamNode>(shared_from_this(), shape, init, elementType, fixed);
+ LOG(debug, "Created parameter {} with shape {} and type {}", name, shape, elementType);
// set name and id and add to list of parameters
p->set_name(name);
- params_->add(p, name);
+ params->add(p, name);
return p;
}
- Expr constant(const Shape& shape, const NodeInitializer& init, Type value_type = Type::float32) {
- return Expression<ConstantNode>(shared_from_this(), shape, init, value_type);
+public:
+ Expr param(const std::string& pname,
+ const Shape& shape,
+ const Ptr<inits::NodeInitializer>& init,
+ const Type elementType,
+ bool fixed = false) {
+ // this param is called with a specified type
+ return param(pname, shape, init, elementType, fixed, /*typeSpecified=*/true);
+ }
+
+ Expr param(const std::string& pname,
+ const Shape& shape,
+ const Ptr<inits::NodeInitializer>& init,
+ bool fixed = false) {
+ // since this param is called without a specified type, we assume defaultElementType but allow to check for a different type
+ return param(pname, shape, init, defaultElementType_, fixed, /*typeSpecified=*/false);
+ }
+
+ Expr constant(const Shape& shape,
+ const Ptr<inits::NodeInitializer>& init,
+ Type elementType) {
+ return Expression<ConstantNode>(shared_from_this(), shape, init, elementType);
+ }
+
+ Expr constant(const Shape& shape,
+ const Ptr<inits::NodeInitializer>& init) {
+ return Expression<ConstantNode>(shared_from_this(), shape, init, defaultElementType_);
}
// @TODO: add version with iterators
@@ -367,66 +376,78 @@ public:
// like rows or select
Expr indices(const std::vector<IndexType>& indicesVector) {
return constant({(int)indicesVector.size()},
- inits::from_vector(indicesVector),
+ inits::fromVector(indicesVector),
Type::uint32);
}
// this version sets up the shape such that the indices are in a given axis
- // Use this if you want to pass these indices to select().
+ // Use this if you want to pass these indices to gather().
// indexee shape = (3, 2, 5, 2); axis = 1 -> resulting shape = (1, size of indicesVector, 1, 1)
Expr indices(const std::vector<IndexType>& indicesVector, Expr indexee, int axis = -1) {
Shape shape;
shape.resize(indexee->shape().size());
shape.set(axis, indicesVector.size());
return constant(Shape(shape),
- inits::from_vector(indicesVector),
+ inits::fromVector(indicesVector),
Type::uint32);
}
+ Expr ones(const Shape& shape, Type elementType) {
+ return constant(shape, inits::ones(), elementType);
+ }
Expr ones(const Shape& shape) {
- return constant(shape, inits::ones);
+ return constant(shape, inits::ones(), defaultElementType_);
}
+ Expr zeros(const Shape& shape, Type elementType) {
+ return constant(shape, inits::zeros(), elementType);
+ }
Expr zeros(const Shape& shape) {
- return constant(shape, inits::zeros);
+ return constant(shape, inits::zeros(), defaultElementType_);
}
// prob = dropProb, e.g. 0.1 means 90% of values are kept
- Expr dropout(float dropProb, const Shape& shape);
+ Expr dropoutMask(float dropProb, const Shape& shape, Type elementType);
+ Expr dropoutMask(float dropProb, const Shape& shape);
Expr get(std::string name) {
if(!namespace_.empty())
name = namespace_ + "::" + name;
-
- auto e = params_->get(name);
- if(e)
- return e;
- return Expr();
+ Expr p; Ptr<Parameters> params; std::tie
+ (p, params) = findParams(name, defaultElementType_, /*specifiedType=*/false);
+ return p;
}
- Ptr<Parameters>& params() { return params_; }
-
- Expr add(Expr node) {
- auto found = tensors_->findOrRemember(node);
- if(found) {
- return found;
- } else {
- node->setId(count_++);
+ Expr get(std::string name, Type specifiedElementType) {
+ if(!namespace_.empty())
+ name = namespace_ + "::" + name;
+ Expr p; Ptr<Parameters> params; std::tie
+ (p, params) = findParams(name, specifiedElementType, /*specifiedType=*/true);
+ return p;
+ }
- // record in foward graph
- nodesForward_.push_back(node);
+ Ptr<Parameters>& params() {
+ // There are no parameter objects, that's weird.
+ ABORT_IF(paramsByElementType_.empty(), "No parameter object has been created");
+
+ // Safeguard against accessing parameters from the outside with multiple parameter types, not yet supported
+ ABORT_IF(paramsByElementType_.size() > 1, "Calling of params() is currently not supported with multiple ({}) parameters", paramsByElementType_.size());
+
+ // Safeguard against accessing parameters from the outside with other than default parameter type, not yet supported
+ auto it = paramsByElementType_.find(defaultElementType_);
+ ABORT_IF(it == paramsByElementType_.end(), "Parameter object for type {} does not exist", defaultElementType_);
- // record in backward graph if training, and keep track of roots
- if(!inferenceOnly_ && node->trainable()) {
- nodesBackward_.push_back(node);
- topNodes_.insert(node); // opportunistically record all new nodes as roots (gets removed once consumed)
- }
- for(auto child : node->children())
- topNodes_.erase(child); // this child is consumed and therefore not a root
+ return it->second;
+ }
- return node;
- }
+ void setDefaultElementType(Type defaultElementType) {
+ ABORT_IF(!paramsByElementType_.empty() && defaultElementType != defaultElementType_,
+ "Parameter objects already exist, cannot change default type from {} to {}",
+ defaultElementType_, defaultElementType);
+ defaultElementType_ = defaultElementType;
}
+ Expr add(Expr node);
+
void allocateForward(Expr node) {
if(tensors_)
tensors_->allocateForward(node);
@@ -437,7 +458,7 @@ public:
tensors_->allocateBackward(node);
}
- void free(Tensor& tensor) {
+ void free(const Tensor& tensor) {
if(tensors_)
tensors_->free(tensor);
}
@@ -456,22 +477,26 @@ public:
tensors_->clear();
}
- void clearParameters() { params_->clear(); }
-
void setReloaded(bool reloaded) { reloaded_ = reloaded; }
void setThrowNaN(bool throwNaN) { throwNaN_ = throwNaN; }
+ bool getThrowNaN() { return throwNaN_; }
public:
- // convert all parameters into an array of IoItem elements, for loading
- void load(const std::vector<io::Item>& ioItems, bool markReloaded = true) {
+ // loading from array of io::Items
+ void load(std::vector<io::Item>& ioItems, bool markReloaded = true) {
setReloaded(false);
for(auto& item : ioItems) {
std::string pName = item.name;
// skip over special parameters starting with "special:"
if(pName.substr(0, 8) == "special:")
continue;
- param(pName, item.shape, inits::from_item(item));
+
+ // if during loading the loaded type is of the same type class as the default element type, allow conversion;
+ // otherwise keep the loaded type. This is used when e.g. loading a float32 model as a float16 model as both
+ // have type class TypeClass::float_type.
+ auto loadElementType = isSameTypeClass(item.type, defaultElementType_) ? defaultElementType_ : item.type;
+ param(pName, item.shape, inits::fromItem(item), loadElementType, /*fixed=*/false);
}
if(markReloaded)
setReloaded(true);
@@ -479,39 +504,62 @@ public:
void load(const std::string& name, bool markReloaded = true) {
LOG(info, "Loading model from {}", name);
- load(io::loadItems(name), markReloaded);
+ auto items = io::loadItems(name);
+ load(items, markReloaded);
}
void load(const void* ptr, bool markReloaded = true) {
LOG(info, "Loading model from buffer at {}", ptr);
- load(io::loadItems(ptr), markReloaded);
+ auto items = io::loadItems(ptr);
+ load(items, markReloaded);
}
void mmap(const void* ptr, bool markReloaded = true) {
ABORT_IF(backend_->getDeviceId().type != DeviceType::cpu || !inferenceOnly_,
"Memory mapping only supported for CPU inference mode");
- params_ = New<MappedParameters>();
- params_->init(backend_);
-
LOG(info, "Memory mapping model at {}", ptr);
- load(io::mmapItems(ptr), markReloaded);
+ auto items = io::mmapItems(ptr);
+
+ // Deal with default parameter set object that might not be a mapped object.
+ // This gets assigned during ExpressionGraph::setDevice(...) and by default
+ // would contain allocated tensors. Here we replace it with a mmapped version.
+ auto it = paramsByElementType_.find(defaultElementType_);
+ if(it != paramsByElementType_.end()) {
+ // there is parameter object for that type
+ auto defaultParams = std::dynamic_pointer_cast<MappedParameters>(it->second);
+ if(!defaultParams) {
+ // but it's not mapped, so delete it and replace it with a mapped version
+ defaultParams = New<MappedParameters>(defaultElementType_);
+ defaultParams->init(backend_);
+ paramsByElementType_[defaultElementType_] = defaultParams;
+ }
+ }
+
+
+ // pre-populate parameters by type
+ for(auto& item : items) {
+ auto it1 = paramsByElementType_.find(item.type);
+ if(it1 == paramsByElementType_.end()) {
+ auto params = New<MappedParameters>(item.type);
+ params->init(backend_);
+ paramsByElementType_.insert({item.type, params});
+ }
+ }
+
+ load(items, markReloaded);
}
public:
// convert all parameters into an array of io::Item elements, for saving
- void save(std::vector<io::Item>& ioItems);
-
- void save(const std::string& name, const std::string& meta = "") {
- // LOG(info, "Saving model to {}", name);
+ void save(std::vector<io::Item>& ioItems, Type saveElementType = Type::float32);
+ void save(const std::string& name, const std::string& meta = "", Type saveElementType = Type::float32) {
std::vector<io::Item> ioItems;
- save(ioItems);
+ save(ioItems, saveElementType);
if(!meta.empty())
io::addMetaToItems(meta, "special:model.yml", ioItems);
io::saveItems(name, ioItems);
-
- // LOG(info, "Saved {} items.", ioItems.size());
}
};
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index 2f4f5ecf..f858d730 100755
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -7,6 +7,11 @@
#include "graph/auto_tuner.h"
#include "tensors/cpu/int16.h"
+#include "tensors/cpu/fbgemm/expanded_gemm.h"
+
+#if USE_FBGEMM
+#include "fbgemm/Utils.h"
+#endif
namespace marian {
@@ -15,6 +20,11 @@ Expr debug(Expr a, const std::string& message) {
return a;
}
+Expr checkpoint(Expr a) {
+ a->markCheckpoint();
+ return a;
+}
+
// logistic function. Note: scipy name is expit()
Expr sigmoid(Expr a) {
return Expression<SigmoidNodeOp>(a);
@@ -51,6 +61,10 @@ Expr swish(Expr a) {
return Expression<SwishNodeOp>(a);
}
+Expr gelu(Expr a) {
+ return Expression<SwishNodeOp>(a, 1.702f);
+}
+
Expr operator-(Expr a) {
return Expression<NegNodeOp>(a);
};
@@ -69,10 +83,15 @@ Expr softmax(Expr a, int axis /*=-1*/)
}
Expr softmax(Expr a, Expr zeroOneMask, int axis /*=-1*/) {
- auto logMask = (1 - zeroOneMask) * -99999999.f;
+ // This will return the smallest value / 2 for the input type converted to float
+ // So for Type::Float16 that will be the smallest fp16 value expressed as float
+ // We divide by 2 to allow for some tolerance and overflow protection.
+ float smallestFloat = NumericLimits<float>(a->value_type()).lowest / 2.f;
+ auto logMask = (1.f - zeroOneMask) * smallestFloat;
return softmax(a + logMask, axis);
}
+// @TODO: add mask
Expr logsoftmax(Expr a) {
return Expression<LogSoftmaxNodeOp>(a);
}
@@ -107,39 +126,78 @@ Expr minimum(Expr a, Expr b) {
return Expression<MinimumNodeOp>(a, b);
}
+Expr lt(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b, -1, false); }
+Expr eq(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b, 0, false); }
+Expr gt(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b, 1, false); }
+Expr ge(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b, -1, true); }
+Expr ne(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b, 0, true); }
+Expr le(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b, 1, true); }
+
+Expr lt(float a, Expr b) { return Expression<CmpNodeOp>(b->graph()->constant({}, inits::fromValue(a), b->value_type()), b, -1, false); }
+Expr eq(float a, Expr b) { return Expression<CmpNodeOp>(b->graph()->constant({}, inits::fromValue(a), b->value_type()), b, 0, false); }
+Expr gt(float a, Expr b) { return Expression<CmpNodeOp>(b->graph()->constant({}, inits::fromValue(a), b->value_type()), b, 1, false); }
+Expr ge(float a, Expr b) { return Expression<CmpNodeOp>(b->graph()->constant({}, inits::fromValue(a), b->value_type()), b, -1, true); }
+Expr ne(float a, Expr b) { return Expression<CmpNodeOp>(b->graph()->constant({}, inits::fromValue(a), b->value_type()), b, 0, true); }
+Expr le(float a, Expr b) { return Expression<CmpNodeOp>(b->graph()->constant({}, inits::fromValue(a), b->value_type()), b, 1, true); }
+
+Expr lt(Expr a, float b) { return Expression<CmpNodeOp>(a, a->graph()->constant({}, inits::fromValue(b), a->value_type()), -1, false); }
+Expr eq(Expr a, float b) { return Expression<CmpNodeOp>(a, a->graph()->constant({}, inits::fromValue(b), a->value_type()), 0, false); }
+Expr gt(Expr a, float b) { return Expression<CmpNodeOp>(a, a->graph()->constant({}, inits::fromValue(b), a->value_type()), 1, false); }
+Expr ge(Expr a, float b) { return Expression<CmpNodeOp>(a, a->graph()->constant({}, inits::fromValue(b), a->value_type()), -1, true); }
+Expr ne(Expr a, float b) { return Expression<CmpNodeOp>(a, a->graph()->constant({}, inits::fromValue(b), a->value_type()), 0, true); }
+Expr le(Expr a, float b) { return Expression<CmpNodeOp>(a, a->graph()->constant({}, inits::fromValue(b), a->value_type()), 1, true); }
+
/*********************************************************/
Expr operator+(Expr a, float b) {
- return Expression<ScalarAddNodeOp>(a, b);
+ if (b == 0)
+ return a;
+ else
+ return Expression<ScalarAddNodeOp>(a, b);
}
Expr operator+(float a, Expr b) {
- return Expression<ScalarAddNodeOp>(b, a);
+ if (a == 0)
+ return b;
+ else
+ return Expression<ScalarAddNodeOp>(b, a);
}
Expr operator-(Expr a, float b) {
- return Expression<ScalarAddNodeOp>(a, -b);
+ if (b == 0)
+ return a;
+ else
+ return Expression<ScalarAddNodeOp>(a, -b);
}
Expr operator-(float a, Expr b) {
- return Expression<ScalarAddNodeOp>(-b, a);
+ if (a == 0)
+ return -b;
+ else
+ return Expression<ScalarAddNodeOp>(-b, a);
}
Expr operator*(float a, Expr b) {
- return Expression<ScalarMultNodeOp>(b, a);
+ if (a == 1.0f)
+ return b;
+ else
+ return Expression<ScalarMultNodeOp>(b, a);
}
Expr operator*(Expr a, float b) {
- return Expression<ScalarMultNodeOp>(a, b);
+ if (b == 1.0f)
+ return a;
+ else
+ return Expression<ScalarMultNodeOp>(a, b);
}
Expr operator/(Expr a, float b) {
- return Expression<ScalarMultNodeOp>(a, 1.f / b);
+ return a * (1.f / b);
}
// TODO: efficient version of this without constant()
Expr operator/(float a, Expr b) {
- auto aExpr = b->graph()->constant({}, inits::from_value(a));
+ auto aExpr = b->graph()->constant({}, inits::fromValue(a));
return aExpr / b;
}
@@ -170,9 +228,17 @@ Expr repeat(Expr a, size_t repeats, int ax) {
}
Expr reshape(Expr a, Shape shape) {
+ if (a->shape() == shape)
+ return a;
return Expression<ReshapeNodeOp>(a, shape);
}
+// @TODO: remove this if it turns out that we can train FP16 without that
+Expr clipGradient(Expr a, float clipValue) {
+ // don't create node if no clipping
+ return clipValue != 0.f ? Expression<ClipGradientNodeOp>(a, clipValue) : a;
+}
+
Expr atleast_1d(Expr a) {
return atleast_nd(a, 1);
}
@@ -211,50 +277,126 @@ Expr flatten_2d(Expr a) {
return Expression<ReshapeNodeOp>(a, shape);
}
-Expr constant_like(Expr a, const NodeInitializer& init) {
- const auto& shape = a->shape();
- auto graph = a->graph();
- return graph->constant(shape, init);
+Expr stopGradient(Expr a) {
+ // implemented as a dummy reshape that is not trainable
+ auto res = Expression<ReshapeNodeOp>(a, a->shape());
+ res->setTrainable(false);
+ return res;
+}
+
+// gather() -- gather arbitrary elements along an axis; batched or non-batched
+Expr gather(Expr a, int axis, Expr indices) {
+ return Expression<GatherNodeOp>(a, axis, indices);
+}
+
+// index_select() -- gather arbitrary elements along an axis from an unbatched
+// input 'a'. Indices are specified as a 1D vector.
+// This is used e.g. for embedding lookup.
+// Note: To use a batch of index vectors, reshape them into a single vector,
+// call index_select(), then reshape the result back. Reshapes are cheap.
+// This function has the same semantics as PyTorch operation of the same name.
+Expr index_select(Expr a, int axis, Expr indices) {
+ ABORT_IF(indices->shape().size() != 1, "Indices must be a 1D tensor");
+ // We have specialized kernels for non-batched indexing of first or last axis of a 2D tensor.
+ auto rank = a->shape().size();
+ if (rank == 2) {
+ if (axis == 0 || axis == -2)
+ return Expression<RowsNodeOp>(a, indices);
+ else if (axis == -1 || axis == 1)
+ return Expression<ColsNodeOp>(a, indices);
+ }
+ // Delegate to gather() for any other axis or non-matrix input.
+ Shape shape;
+ shape.resize(a->shape().size());
+ shape.set(axis, indices->shape()[0]);
+ indices = reshape(indices, shape); // move index to axis
+ return gather(a, axis, indices);
}
-Expr rows(Expr a, Expr indices) {
- // @TODO:: replace with `select(a, indices, -2)`
- // as soon as select is efficient enough
- return Expression<RowsNodeOp>(a, indices);
+Expr index_select(Expr a, int axis, const std::vector<IndexType>& indices) {
+ auto indexExpr = a->graph()->indices(indices);
+ return index_select(a, axis, indexExpr);
}
-Expr rows(Expr a, const std::vector<IndexType>& indices) {
- auto indexExpr = a->graph()->indices(indices);
- return rows(a, indexExpr);
+static Expr sliceCopy(Expr a, int axis, const Slice& slice) { // copy a Slice via gather()
+ ABORT_IF(slice.stride < 0, "Negative strides are not supported yet");
+ ABORT_IF(slice.begin == slice.end, "Empty slices are not allowed"); // @TODO: Or are they?
+ std::vector<IndexType> indices;
+ indices.reserve((slice.end - slice.begin - 1) / slice.stride + 1);
+ for (int i = slice.begin; i < slice.end; i += slice.stride)
+ indices.push_back((IndexType)i);
+ return gather(a, axis, a->graph()->indices(indices, a, axis));
}
+static Expr sliceView(Expr a, int axis, const Slice& slice) { // view a slice (must be memory-consecutive)
+ return Expression<SliceViewNodeOp>(a, axis, slice);
+}
-Expr cols(Expr a, Expr indices) {
- // @TODO:: replace with `select(a, indices, -1)`
- // as soon as select is efficient enough
- return Expression<ColsNodeOp>(a, indices);
+// slice() -- gather a slice along an axis (step size > 1 allowed)
+Expr slice(Expr a, int axis, Slice slice) { // numpy __getslice__ semantics, but with axis parameter
+ const auto& shape = a->shape();
+ axis = shape.axis(axis); // normalize negative axis
+ slice = shape.slice(slice, axis); // normalize negative slice values
+ if (slice.begin == 0 && slice.end == shape[axis] && slice.stride == 1)
+ return a; // it's a no-op
+#if 1 // until strided views are supported, non-consecutive slices are implemented via gather()
+ if (slice.stride != 1)
+ return sliceCopy(a, axis, slice);
+ for (int i = 0; i < axis; ++i) {
+ if (shape[i] != 1) // this makes it non-consecutive
+ return sliceCopy(a, axis, slice);
+ }
+#endif
+ return sliceView(a, axis, slice);
}
-Expr cols(Expr a, const std::vector<IndexType>& indices) {
- auto indexExpr = a->graph()->indices(indices);
- return cols(a, indexExpr);
+Expr sum(Expr a, int ax) {
+ if(a->shape()[ax] == 1) // nothing to reduce, sum of itself is a
+ return a;
+ return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::sum);
}
-Expr select(Expr a, Expr indices, int axis) {
- return Expression<SelectNodeOp>(a, indices, axis);
+Expr mean(Expr a, int ax) {
+ if(a->shape()[ax] == 1) // nothing to reduce, mean of itself is a
+ return a;
+ return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::mean);
}
-Expr select(Expr a, const std::vector<IndexType>& indices, int axis) {
- auto indexExpr = a->graph()->indices(indices, a, axis);
- return select(a, indexExpr, axis);
+Expr std(Expr a, int ax) {
+ if(a->shape()[ax] == 1) // nothing to reduce, std(a) = 0
+ return a - a;
+ return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::rms);
}
-Expr sum(Expr a, int ax) {
- return Expression<SumNodeOp>(a, ax);
+Expr var(Expr a, int ax) {
+ if(a->shape()[ax] == 1) // nothing to reduce, var(a) = 0
+ return a - a;
+ return Expression<ReduceNodeOp>(a - mean(a, ax), ax, ReduceNodeOpCode::meanSqr);
}
-Expr mean(Expr a, int ax) {
- return Expression<MeanNodeOp>(a, ax);
+Expr max(Expr a, int ax) {
+ if(a->shape()[ax] == 1) // nothing to reduce, max of itself is a
+ return a;
+ return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::max);
+}
+
+Expr min(Expr a, int ax) {
+ if(a->shape()[ax] == 1) // nothing to reduce, min of itself is a
+ return a;
+ return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::min);
+}
+
+Expr prod(Expr a, int ax) {
+ if(a->shape()[ax] == 1) // nothing to reduce, prod of itself is a
+ return a;
+ return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::prod);
+}
+
+// log(sum(exp(a)))
+Expr logsumexp(Expr a, int ax) {
+ if(a->shape()[ax] == 1) // nothing to reduce, log(sum(exp(a))) = log(exp(a)) = a
+ return a;
+ return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::logSumExp);
}
Expr scalar_product(Expr a, Expr b, int ax) {
@@ -270,17 +412,50 @@ Expr weighted_average(Expr in, Expr weights, int ax) {
Expr dot(Expr a, Expr b, bool transA, bool transB, float scale) {
auto device = a->graph()->getDeviceId().type;
float clipValue = a->graph()->getBackend()->getClip();
+ // added support for packed GEMM API (fp16, int8)
+ Type aElementType = a->value_type();
+ Type bElementType = b->value_type();
// Currently only true when command line options
// --optimize --cpu-thread=N with N > 0 are set.
- if(a->graph()->isOptimized() && device == DeviceType::cpu) {
- // dotInt16 computes A * B.T, hence the transpose for B to get A * B
- // if transA = false and transB = false.
-
- return cpu::int16::dot(
- cpu::int16::quantize(transA ? transpose(a) : a, clipValue),
- cpu::int16::quantize(transB ? b : transpose(b), clipValue),
- scale);
+ if(device == DeviceType::cpu) {
+ if(isFloat(aElementType) && isFloat(bElementType)) {
+ if(a->graph()->getBackend()->isOptimized()) {
+ // dotInt16 computes A * B.T, hence the transpose for B to get A * B
+ // if transA = false and transB = false.
+
+ return cpu::int16::dot(
+ cpu::int16::quantize(transA ? transpose(a) : a, clipValue),
+ cpu::int16::quantize(transB ? b : transpose(b), clipValue),
+ scale);
+ } else {
+ return Expression<DotNodeOp>(
+ clip(a, clipValue), clip(b, clipValue), transA, transB, scale);
+ }
+ } else if(isFloat(aElementType) && isPacked(bElementType)) {
+#if USE_FBGEMM
+ // 07/10/2019 - Use packed GEMM only if the cpu architecture supports AVX2
+ // one of the fbgemm's sub modules, cpuinfo (https://github.com/pytorch/cpuinfo).
+ // It looks at the cpu register
+ // (https://github.com/pytorch/cpuinfo/blob/master/src/x86/isa.c#L391),
+ // and this cpu lookup is executed only once and the state is kept in FBGEMM.
+ if(fbgemm::fbgemmHasAvx2Support()) {
+ // This variant of dot product can handle matrix multiplications with packed8 and packed16 weight matrix (B).
+ return cpu::variant::dot(clip(a, clipValue),
+ b,
+ b->shape(),
+ transA,
+ transB,
+ scale);
+ } else {
+ ABORT("AVX2 is not available. At least, AVX2 is needed to use fbgemm-based packed GEMM");
+ }
+#else
+ ABORT("Packed GEMM is not available in this build");
+#endif // USE_FBGEMM
+ } else {
+ ABORT("Combination of types A: {} B: {} not supported", aElementType, bElementType);
+ }
} else {
return Expression<DotNodeOp>(
clip(a, clipValue), clip(b, clipValue), transA, transB, scale);
@@ -291,104 +466,93 @@ Expr bdot(Expr a, Expr b, bool transA, bool transB, float scale) {
return Expression<DotBatchedNodeOp>(a, b, transA, transB, scale);
}
-Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
- auto device = a->graph()->getDeviceId().type;
+static Expr affineDefault(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
+ // general version, MKL, CBlas or CUDA
+ // if clipValue > 0, the inputs will be clipped to range [-clipValue,
+ // clipValue] This is meant to keep values at the same range as used during
+ // training when optimizing for 8-bit integer products. Likely to be removed
+ // in the future when we explore better ways to handle this.
float clipValue = a->graph()->getBackend()->getClip();
- if(a->graph()->isOptimized() && device == DeviceType::cpu) {
- bool autotune = true;
- if(autotune) {
- thread_local Ptr<AutoTuner<Expr>> tuner = New<AutoTuner<Expr>>();
-
- // start with new set of algorithms
- tuner->clear();
-
- // lower precicion for shapes, reduces data sparsity
- auto sh = [](Shape sh) {
- for(size_t i = 0; i < sh.size(); ++i)
- sh.set(i, sh[i] / 4);
- return sh;
- };
-
- // create context for current call as hash
- std::size_t hash = sh(a->shape()).hash();
- util::hash_combine(hash, sh(b->shape()).hash());
- util::hash_combine(hash, sh(bias->shape()).hash());
- util::hash_combine(hash, transA);
- util::hash_combine(hash, transB);
-
- // add first algorithm variant (Int16)
- size_t hash1 = hash;
- util::hash_combine(hash1, 1);
- auto rec1 = [=](Expr e, bool stop = false) {
- e->record(tuner, hash1, stop);
- return e;
- };
- auto alg1 = [=]() {
- return rec1(
- cpu::int16::affine(
- rec1(cpu::int16::quantize(transA ? rec1(transpose(a)) : a,
- clipValue)),
- cpu::int16::quantize(transB ? b : transpose(b), clipValue),
- bias,
- scale),
- true);
- };
- tuner->insert({hash1, alg1});
-
- // add second algorithm variant (CBlas)
- size_t hash2 = hash;
- util::hash_combine(hash2, 2);
- auto rec2 = [=](Expr e, bool stop = false) {
- e->record(tuner, hash2, stop);
- return e;
- };
-
- auto alg2 = [=]() {
- auto ac = clip(a, clipValue);
- if(ac != a)
- ac = rec2(ac);
-
- auto bc = clip(b, clipValue);
- if(bc != b)
- bc = rec2(bc);
-
- int rows = ac->shape().elements() / ac->shape()[-1];
- Expr ones = ac->graph()->ones({rows, 1});
- std::vector<Expr> nodes = {ac, bc, bias, ones};
- return rec2(Expression<AffineNodeOp>(nodes, transA, transB, scale),
- true);
- };
- tuner->insert({hash2, alg2});
-
- // execute algorithm with autotuning
- return tuner->run();
+ int rows = a->shape().elements() / a->shape()[-1];
+ Expr ones = a->graph()->ones({ rows, 1 });
+ std::vector<Expr> nodes
+ = { clip(a, clipValue), clip(b, clipValue), bias, ones };
+ return Expression<AffineNodeOp>(nodes, transA, transB, scale);
+}
- } else {
- // cpu int16 version
- return cpu::int16::affine(
+// This operation used to implement auto-tuning. We have removed it for now due to complexity, but plan to revisit it in the future.
+// The last branch with auto-tuner is:
+// youki/packed-model-pr-backup1031
+// https://machinetranslation.visualstudio.com/Marian/_git/marian-dev?version=GByouki%2Fpacked-model-pr-backup1031
+// SHA: 3456a7ed1d1608cfad74cd2c414e7e8fe141aa52
+Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
+ auto device = a->graph()->getDeviceId().type;
+
+ float clipValue = a->graph()->getBackend()->getClip();
+ Type aElementType = a->value_type();
+ Type bElementType = b->value_type();
+
+ if(device == DeviceType::cpu) {
+ if(isFloat(aElementType) && isFloat(bElementType)) {
+ if(a->graph()->getBackend()->isOptimized()) {
+ // cpu int16 version
+ return cpu::int16::affine(
cpu::int16::quantize(transA ? transpose(a) : a, clipValue),
cpu::int16::quantize(transB ? b : transpose(b), clipValue),
bias,
scale);
+ } else {
+ return affineDefault(a, b, bias, transA, transB, scale);
+ }
+ } else if(isFloat(aElementType) && isPacked(bElementType)) {
+#if USE_FBGEMM
+ // 07/10/2019 - Use packed GEMM only if the cpu architecture supports AVX2
+ // one of the fbgemm's sub modules, cpuinfo (https://github.com/pytorch/cpuinfo).
+ // It looks at the cpu register
+ // (https://github.com/pytorch/cpuinfo/blob/master/src/x86/isa.c#L391),
+ // and this cpu lookup is executed only once and the state is kept in FBGEMM.
+ if(fbgemm::fbgemmHasAvx2Support()) {
+ // This variant of affine product can handle matrix multiplications with packed8 and packed16 weight matrix (B).
+ return cpu::variant::affine(clip(a, clipValue),
+ b,
+ b->shape(),
+ bias,
+ transA,
+ transB,
+ scale);
+ } else {
+ ABORT("AVX2 is not available. At least, AVX2 is needed to use fbgemm-based packed GEMM");
+ }
+#else
+ ABORT("Packed GEMM is not available in this build");
+#endif // USE_FBGEMM
+ } else {
+ ABORT("Combination of types A: {} B: {} not supported", aElementType, bElementType);
}
} else {
- // general version, MKL, CBlas or CUDA
-
- // if clipValue > 0, the inputs will be clipped to range [-clipValue,
- // clipValue] This is meant to keep values at the same range as used during
- // training when optimizing for 8-bit integer products. Likely to be removed
- // in the future when we explore better ways to handle this.
-
- int rows = a->shape().elements() / a->shape()[-1];
- Expr ones = a->graph()->ones({rows, 1});
- std::vector<Expr> nodes
- = {clip(a, clipValue), clip(b, clipValue), bias, ones};
- return Expression<AffineNodeOp>(nodes, transA, transB, scale);
+ // Default GEMM
+ ABORT_IF(!isFloat(aElementType) || !isFloat(bElementType),
+ "GPU-based GEMM only supports float types, you have A: {} and B: {}",
+ aElementType, bElementType);
+ return affineDefault(a, b, bias, transA, transB, scale);
}
}
+// multiply a CSR matrix A with a matrix B
+// A[i,j] is at A_values[A_offsets[i]+k], where k is position of j in A_indices[A_offsets[i]:A_offsets[i+1]]
+// @TODO: Define a proper sparse tensor type.
+Expr csr_dot(const Shape& A_shape, Expr A_values, Expr A_indices, Expr A_offsets, Expr B, bool transA /*= false*/) {
+ return Expression<CSRDotNodeOp>(A_shape, A_values, A_indices, A_offsets, B, transA, /*swapOperands=*/false);
+}
+
+// multiply a matrix A with a CSR matrix B
+// @TODO: Define a proper sparse tensor type.
+Expr dot_csr(Expr A, const Shape& B_shape, Expr B_values, Expr B_indices, Expr B_offsets, bool transB /*= false*/) {
+ return Expression<CSRDotNodeOp>(B_shape, B_values, B_indices, B_offsets, A, transB, /*swapOperands=*/true);
+}
+
// swap the last two axes
// @TODO: change to swapAxes(a, -1, -2)
Expr transpose(Expr a) {
@@ -409,32 +573,70 @@ Expr transpose(Expr a, const std::vector<int>& axes) {
Expr swapAxes(Expr x, int axis1, int axis2)
{
- axis1 = x->shape().axis(axis1);
- axis2 = x->shape().axis(axis2);
+ const auto& shape = x->shape();
+ axis1 = shape.axis(axis1);
+ axis2 = shape.axis(axis2);
if (axis1 == axis2)
return x;
+ if (shape[axis1] == 1 || shape[axis2] == 1) { // can we use a reshape instead?
+ if (axis1 > axis2)
+ std::swap(axis1, axis2);
+ bool canReshape = true;
+ for (int ax = axis1 + 1; ax < axis2 && canReshape; ax++)
+ canReshape &= (shape[ax] == 1);
+ if (canReshape) {
+ auto newShape = shape;
+ newShape.set(axis1, shape[axis2]);
+ newShape.set(axis2, shape[axis1]);
+ //LOG(info, "SwapAxes() did a reshape from {} to {}", shape.toString(), newShape.toString());
+ return reshape(x, newShape);
+ }
+ }
// TODO: This is code dup from transpose(x). Implement transpose(x) as swapAxes(x, 0, 1)
- std::vector<int> axes(x->shape().size());
- for (int i = 0; i < axes.size(); ++i)
+ std::vector<int> axes(shape.size());
+ for (int i = 0; i < axes.size(); ++i) // @TODO: use std::iota()
axes[i] = i;
std::swap(axes[axis1], axes[axis2]);
return transpose(x, axes);
}
-Expr step(Expr a, int step, int axis) {
- return Expression<StepNodeOp>(a, step, axis);
+Expr cast(Expr a, Type type) {
+ if(a->value_type() == type) {
+ return a;
+ } else {
+ return Expression<CastNodeOp>(a, type);
+ }
}
-Expr cross_entropy(Expr a, Expr indices) {
- return Expression<CrossEntropyNodeOp>(a, indices);
+Expr cross_entropy(Expr logits, Expr indices) {
+ return Expression<CrossEntropyNodeOp>(logits, indices);
}
-Expr plus(const std::vector<Expr>&) {
- ABORT("Not implemented");
+// Unlikelihood loss based on https://arxiv.org/abs/1908.04319
+Expr unlikelihood(Expr logits, Expr indices) {
+ int dimBatch = logits->shape()[-2];
+ int dimTime = logits->shape()[-3];
+
+ // @TODO: fix this outside of this function in decoder.h etc.
+ auto indicesWithLayout = reshape(indices, {1, dimTime, dimBatch, 1});
+
+ // This is currently implemented with multiple ops, might be worth doing a special operation like for cross_entropy
+ return -log(gather(1.f - softmax(logits), /*axis=*/-1, indicesWithLayout));
}
-Expr swish(const std::vector<Expr>&) {
- ABORT("Not implemented");
+Expr plus(const std::vector<Expr>& nodes) {
+ ABORT_IF(nodes.size() > 1, "Not implemented");
+ return nodes[0];
+}
+
+Expr swish(const std::vector<Expr>& nodes) {
+ ABORT_IF(nodes.size() > 1, "Not implemented");
+ return swish(nodes[0]);
+}
+
+Expr gelu(const std::vector<Expr>& nodes) {
+ ABORT_IF(nodes.size() > 1, "Not implemented");
+ return gelu(nodes[0]);
}
Expr tanh(const std::vector<Expr>& nodes) {
@@ -445,8 +647,9 @@ Expr sigmoid(const std::vector<Expr>&) {
ABORT("Not implemented");
}
-Expr relu(const std::vector<Expr>&) {
- ABORT("Not implemented");
+Expr relu(const std::vector<Expr>& nodes) {
+ ABORT_IF(nodes.size() > 1, "Not implemented");
+ return relu(nodes[0]);
}
Expr leakyrelu(const std::vector<Expr>&) {
@@ -469,6 +672,8 @@ Expr layerNorm(Expr x,
Expr gamma,
Expr beta /*= nullptr*/,
float eps /*= 1e-9*/) {
+
+ // layerNorm accumulates in float, so small eps is fine
std::vector<Expr> nodes = {x, gamma};
if(beta)
nodes.push_back(beta);
@@ -483,39 +688,25 @@ Expr highway(Expr y, Expr x, Expr t) {
Expr highway(const std::string prefix, Expr x) {
// clang-format off
size_t outDim = x->shape()[-1];
- auto g = mlp::dense(x->graph())
+ auto graph = x->graph();
+ auto g = mlp::dense()
("prefix", prefix + "_highway_d1")
("dim", outDim)
- ("activation", mlp::act::sigmoid)
- .construct()->apply(x);
- auto relued = mlp::dense(x->graph())
+ ("activation", (int)mlp::act::sigmoid)
+ .construct(graph)->apply(x);
+ auto relued = mlp::dense()
("prefix", prefix + "_highway_d2")
("dim", outDim)
- ("activation", mlp::act::ReLU)
- .construct()->apply(x);
+ ("activation", (int)mlp::act::ReLU)
+ .construct(graph)->apply(x);
return (g * relued) + ((1 - g) * x);
// clang-format on
}
-// Expr batch_norm(Expr x, Expr gamma, Expr beta) {
-// auto mju = mean(x, keywords::axis=0);
-// auto xmmju = x - mju;
-// auto std = sqrt(mean(square(xmmju), keywords::axis=0), 1e-9);
-//
-// if(beta)
-// return gamma * (xmmju / std) + beta;
-// else
-// return gamma * (xmmju / std);
-//}
-
Expr shift(Expr a, Shape shift, float padValue) {
return Expression<ShiftNodeOp>(a, shift, padValue);
}
-// Expr lexical_bias(Expr logits, Expr att, float eps, Ptr<sparse::CSR> lf) {
-// return Expression<LexicalProbNodeOp>(logits, att, eps, lf);
-//}
-
#ifdef CUDA_FOUND
#ifdef CUDNN
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 6c0a11b3..7600d2be 100644..100755
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -6,6 +6,8 @@ namespace marian {
Expr debug(Expr a, const std::string& message = "");
+Expr checkpoint(Expr a);
+
typedef Expr(ActivationFunction)(Expr);
Expr plus(const std::vector<Expr>&);
@@ -17,6 +19,9 @@ Expr sigmoid(const std::vector<Expr>&);
Expr swish(Expr a);
Expr swish(const std::vector<Expr>&);
+Expr gelu(Expr a);
+Expr gelu(const std::vector<Expr>&);
+
Expr tanh(const std::vector<Expr>&);
template <typename... Args>
@@ -66,9 +71,32 @@ Expr operator/(Expr a, float b);
Expr logaddexp(Expr a, Expr b);
-Expr max(Expr a, Expr b); // TODO: haggle over the name (max vs. elementMax)
-
-Expr min(Expr a, Expr b); // TODO: haggle over the name
+// Note: Following numpy, minimum() is element-wise, while min() is along an axis in both Numpy and PyTorch.
+Expr maximum(Expr a, Expr b);
+Expr minimum(Expr a, Expr b);
+
+// Note: We cannot overload the relational operators, as they also mean something for Expr itself.
+// Note: These names follow PyTorch convention.
+Expr lt(Expr a, Expr b);
+Expr eq(Expr a, Expr b);
+Expr gt(Expr a, Expr b);
+Expr ge(Expr a, Expr b);
+Expr ne(Expr a, Expr b);
+Expr le(Expr a, Expr b);
+
+Expr lt(float a, Expr b);
+Expr eq(float a, Expr b);
+Expr gt(float a, Expr b);
+Expr ge(float a, Expr b);
+Expr ne(float a, Expr b);
+Expr le(float a, Expr b);
+
+Expr lt(Expr a, float b);
+Expr eq(Expr a, float b);
+Expr gt(Expr a, float b);
+Expr ge(Expr a, float b);
+Expr ne(Expr a, float b);
+Expr le(Expr a, float b);
Expr dot(Expr a,
Expr b,
@@ -89,16 +117,23 @@ Expr affine(Expr a,
bool transB = false,
float scalar = 1.f);
+Expr csr_dot(const Shape& A_shape, Expr Avalues, Expr Aindices, Expr Aoffsets, Expr B, bool transA = false);
+Expr dot_csr(Expr A, const Shape& B_shape, Expr B_values, Expr B_indices, Expr B_offsets, bool transB = false);
+
Expr transpose(Expr a);
Expr transpose(Expr a, const std::vector<int>& axes);
Expr swapAxes(Expr x, int axis1, int axis2);
+Expr cast(Expr a, Type type = Type::float32);
+
Expr concatenate(const std::vector<Expr>& concats, int ax = 0);
Expr repeat(Expr a, size_t repeats, int ax = 0);
Expr reshape(Expr a, Shape shape);
+Expr clipGradient(Expr a, float clipValue);
+
Expr atleast_1d(Expr a);
Expr atleast_2d(Expr a);
Expr atleast_3d(Expr a);
@@ -106,23 +141,64 @@ Expr atleast_4d(Expr a);
Expr atleast_nd(Expr a, size_t dims);
// create a constant of shape a->shape() and initialize with init
-Expr constant_like(Expr a, const NodeInitializer& init);
+// @TODO: add a && version, to avoid a ref count. NodeInitializers are typically temps.
+// @TODO: and/or make this a template on init
+static inline Expr constant_like(Expr a, const Ptr<inits::NodeInitializer>& init) {
+ return a->graph()->constant(a->shape(), init, a->value_type());
+}
+
+// short-cut to init from std::vector, since we do this so often
+template<typename ElementType>
+Expr constant_like(Expr a, const std::vector<ElementType>& v) { return constant_like(a, inits::fromVector(std::move(v))); }
+template<typename ElementType>
+Expr constant_like(Expr a, std::vector<ElementType>&& v) { return constant_like(a, inits::fromVector(v)); }
Expr flatten(Expr a);
Expr flatten_2d(Expr a);
-Expr rows(Expr a, Expr indices);
-Expr rows(Expr a, const std::vector<IndexType>& indices);
+Expr stopGradient(Expr a);
+
+Expr gather(Expr a, int axis, Expr indices);
+
+// Warning: Don't try to pass a scalar literal 0 as indices; it will compile but pass nullptr...
+Expr index_select(Expr a, int axis, Expr indices);
+
+// convenience wrappers for index_select()
+Expr index_select(Expr a, int axis, const std::vector<IndexType>& indices);
+static inline Expr rows(Expr a, Expr indices) {
+ return index_select(a, 0, indices);
+}
+static inline Expr rows(Expr a, const std::vector<IndexType>& indexVector) {
+ return index_select(a, 0, indexVector);
+}
+static inline Expr cols(Expr a, Expr indices) {
+ return index_select(a, -1, indices);
+}
+static inline Expr cols(Expr a, const std::vector<IndexType>& indexVector) {
+ return index_select(a, -1, indexVector);
+}
+
+Expr slice(Expr a, int axis, Slice slice);
-Expr cols(Expr a, Expr indices);
-Expr cols(Expr a, const std::vector<IndexType>& indices);
+// convenience wrappers for slice()
+static inline Expr slice(Expr a, int axis, int index) { // single index @NOTE: This was formerlly called step()
+ return slice(a, axis, Slice(index));
+}
-Expr select(Expr a, Expr indices, int axis);
-Expr select(Expr a, const std::vector<IndexType>& indices, int axis);
+static inline Expr narrow(Expr a, int axis, size_t start, size_t length) { // PyTorch name
+ return slice(a, axis, Slice((int)start, (int)(start + length)));
+}
/*********************************************************/
Expr sum(Expr a, int ax = 0);
+Expr mean(Expr a, int ax = 0);
+Expr std(Expr a, int ax);
+Expr var(Expr a, int ax);
+Expr max(Expr a, int ax);
+Expr min(Expr a, int ax);
+Expr prod(Expr a, int ax);
+Expr logsumexp(Expr a, int ax);
Expr softmax(Expr x, int axis = -1);
@@ -132,16 +208,14 @@ Expr softmax(Expr a, Expr zeroOneMask, int axis = -1);
Expr logsoftmax(Expr a);
-Expr mean(Expr a, int ax = 0);
-
Expr cross_entropy(Expr a, Expr b);
+Expr unlikelihood(Expr a, Expr b);
+
Expr scalar_product(Expr a, Expr b, int ax = 0);
Expr weighted_average(Expr in, Expr weights, int ax = 0);
-Expr step(Expr a, int step, int axis);
-
Expr sqrt(Expr a, float eps = 0.f);
Expr square(Expr a);
@@ -151,14 +225,17 @@ Expr highway(Expr y, Expr x, Expr t);
Expr highway(const std::string prefix, Expr x);
static inline Expr dropout(Expr x, Expr mask) {
- return x * mask;
+ if (mask)
+ return x * mask;
+ else
+ return x;
}
static inline Expr dropout(Expr x, float dropProb, Shape shape) {
if(dropProb == 0)
return x;
auto graph = x->graph();
- auto mask = graph->dropout(dropProb, shape);
+ auto mask = graph->dropoutMask(dropProb, shape);
return dropout(x, mask);
}
diff --git a/src/graph/node.cpp b/src/graph/node.cpp
index c11531da..c15c4eb6 100755
--- a/src/graph/node.cpp
+++ b/src/graph/node.cpp
@@ -8,18 +8,25 @@ namespace marian {
size_t Node::allocate() {
size_t elements = 0;
if(!val_) {
- graph()->allocateForward(shared_from_this());
+ graph()->allocateForward(this);
elements = val_->shape().elements();
}
return elements;
}
void Node::free() {
- if(graph()) {
- if(val_)
- graph()->free(val_);
- if(adj_)
- graph()->free(adj_);
+ if(destroy_) { // don't free views, @TODO: better naming
+ //std::cerr << "Freeing" << std::endl;
+ if(graph()) {
+ if(val_) {
+ graph()->free(val_);
+ val_ = nullptr;
+ }
+ if(adj_) {
+ graph()->free(adj_);
+ adj_ = nullptr;
+ }
+ }
}
}
@@ -30,7 +37,7 @@ void Node::free() {
*/
void Node::init_dependent() {
if(!adj_) {
- graph()->allocateBackward(shared_from_this());
+ graph()->allocateBackward(this);
adj_->set(1.f);
}
}
@@ -43,7 +50,7 @@ void Node::init_dependent() {
*/
void Node::set_zero_adjoint() {
if(!adj_) {
- graph()->allocateBackward(shared_from_this());
+ graph()->allocateBackward(this);
adj_->set(0.f);
}
}
diff --git a/src/graph/node.h b/src/graph/node.h
index 1397e74b..c017eeb2 100755..100644
--- a/src/graph/node.h
+++ b/src/graph/node.h
@@ -17,8 +17,7 @@ namespace marian {
* implements most common functions demanded by Chainable.
* Each operation in a computation graph is a node.
*/
-class Node : public Chainable<Tensor>,
- public std::enable_shared_from_this<Node> {
+class Node : public Chainable<Tensor> {
protected:
size_t id_{0};
size_t edges_{0};
@@ -30,7 +29,7 @@ protected:
Weak<ExpressionGraph> graph_;
Shape shape_{1, 1, 1, 1};
- Type value_type_{Type::float32};
+ Type valueType_{Type::float32};
std::string name_{"none"};
@@ -40,18 +39,19 @@ protected:
bool markedForDebug_{false};
std::string debugMessage_;
+ Ptr<std::list<Expr>> subtape_; // a subtape is used to keep track of nodes that need to be freed and recomputed with gradient-checkpointing.
+ bool isCheckpoint_{false}; // true if this node has been selected to be a checkpoint, currently only done manually.
+
Ptr<AutoTunerRecorder> recorder_;
size_t recorderHash_;
bool recorderStop_;
public:
- Node(Ptr<ExpressionGraph> graph, Shape shape, Type value_type = Type::float32)
- : graph_(graph), shape_(shape), value_type_(value_type) {}
+ Node(Ptr<ExpressionGraph> graph, const Shape& shape, const Type& valueType = Type::float32)
+ : graph_(graph), shape_(shape), valueType_(valueType) {}
virtual ~Node() {
- if(destroy_) {
- free();
- }
+ free();
}
virtual float scalar() override;
@@ -85,7 +85,7 @@ public:
virtual void setId(size_t id) override { id_ = id; }
virtual size_t getId() override { return id_; }
-
+
virtual void increaseEdges(size_t edges = 1) { edges_ += edges; };
virtual void decreaseEdges(size_t edges = 1) { edges_ -= edges; };
virtual size_t edges() { return edges_; };
@@ -104,7 +104,7 @@ public:
virtual void free() override;
- virtual void init() override{};
+ virtual void init() override {};
virtual void init_dependent() override;
@@ -115,7 +115,7 @@ public:
virtual Tensor& grad() override { return adj_; };
virtual const Shape& shape() override { return shape_; }
- virtual const Type& value_type() override { return value_type_; }
+ virtual const Type& value_type() override { return valueType_; }
void set_name(const std::string& name) override { name_ = name; }
@@ -138,10 +138,21 @@ public:
virtual std::string graphviz() override {
std::stringstream ss;
- ss << "\"" << this << "\" [shape=\"" << form() << "\", label=" << label()
- << ", style=\"filled\", fillcolor=\"" << color() << "\"]" << std::endl;
+ ss << "\"" << this << "\" ["
+ << "shape=\"" << form() << "\", "
+ << "label=" << label() << ", "
+ << "style=\"filled\", "
+ << (isCheckpoint_ ? "penwidth=3, " : "penwidth=1, ")
+ << "fillcolor=\"" << color() << "\"];" << std::endl;
+
for(auto&& child : children())
- ss << "\"" << child << "\" -> \"" << this << "\"" << std::endl;
+ ss << "\"" << child << "\" -> \"" << this << "\";" << std::endl;
+
+ if(subtape_) {
+ for(auto&& dep : *subtape_)
+ ss << "\"" << dep << "\" -> \"" << this << "\" [style=dotted];" << std::endl;
+ }
+
ss << std::endl;
return ss.str();
}
@@ -153,15 +164,55 @@ public:
Ptr<Backend> getBackend();
void record(Ptr<AutoTunerRecorder>, size_t, bool) override;
+
+ // this is currently only called manually by checkpoint(Expr). In the future we will figure out a general algorithm
+ virtual void markCheckpoint() override {
+ isCheckpoint_ = true;
+ }
+
+ virtual bool isCheckpoint() const override {
+ return (children_.empty() || isCheckpoint_); // this node is a checkPoint if it's a leaf or if it has been marked.
+ }
+
+ virtual void setSubtape(Ptr<std::list<Expr>> subtape) override {
+ subtape_ = subtape;
+ }
+
+ virtual Ptr<std::list<Expr>> getSubtape() override {
+ return subtape_;
+ };
};
struct NaryNodeOp : public Node {
size_t hash_{0};
+ // Deduce type automatically, but then all types must be the same
+ // this is called automatically when no output type is specified.
+ // If the input types are mixed, the output type needs to be specified
+ // in the constructor.
+ static Type commonType(const std::vector<Expr>& nodes) {
+ ABORT_IF(nodes.size() == 0, "NaryNodeOp has no children");
+ Type type = nodes[0]->value_type();
+ for(int i = 1; i < nodes.size(); ++i)
+ ABORT_IF(nodes[i]->value_type() != type,
+ "Child {} has different type (first: {} != child: {})",
+ i, type, nodes[i]->value_type());
+ return type;
+ }
+
+ NaryNodeOp(const std::vector<Expr>& nodes)
+ : NaryNodeOp(nodes, nodes[0]->shape()) {}
+
+ // this contructor will try to deduce the node type automatically
+ NaryNodeOp(const std::vector<Expr>& nodes, Shape shape)
+ : NaryNodeOp(nodes, shape, commonType(nodes)) {}
+
+ // this contructor will takes a node type
NaryNodeOp(const std::vector<Expr>& nodes,
Shape shape,
- Type value_type = Type::float32)
+ Type value_type)
: Node(nodes.front()->graph(), shape, value_type) {
+
children_.resize(nodes.size());
for(size_t i = 0; i < nodes.size(); ++i)
children_[i] = nodes[i];
@@ -174,9 +225,6 @@ struct NaryNodeOp : public Node {
nodes.begin(), nodes.end(), [](Expr a) { return a->memoize(); }));
}
- NaryNodeOp(const std::vector<Expr>& nodes)
- : NaryNodeOp(nodes, nodes[0]->shape()) {}
-
virtual ~NaryNodeOp() {}
std::vector<Expr>& children() override { return children_; }
@@ -185,6 +233,7 @@ struct NaryNodeOp : public Node {
if(!hash_) {
std::size_t seed = util::hash<std::string>()(name());
util::hash_combine(seed, type());
+ util::hash_combine(seed, (size_t)value_type());
for(size_t i = 0; i < children_.size(); ++i)
util::hash_combine(seed, child(i)->hash());
hash_ = seed;
@@ -195,14 +244,18 @@ struct NaryNodeOp : public Node {
virtual bool equal(Expr node) override {
if(type() != node->type())
return false;
- if(name() != node->name())
+ else if(name() != node->name())
+ return false;
+ else if(value_type() != node->value_type())
return false;
- if(children().size() != node->children().size())
+ else if(children().size() != node->children().size())
return false;
- for(size_t i = 0; i < children().size(); ++i)
- if(children()[i]->getId() != node->children()[i]->getId())
- return false;
- return true;
+ else {
+ for(size_t i = 0; i < children().size(); ++i)
+ if(children()[i]->getId() != node->children()[i]->getId())
+ return false;
+ return true;
+ }
}
};
} // namespace marian
diff --git a/src/graph/node_initializers.cpp b/src/graph/node_initializers.cpp
index dc71562d..27796d33 100755
--- a/src/graph/node_initializers.cpp
+++ b/src/graph/node_initializers.cpp
@@ -11,117 +11,165 @@ namespace marian {
namespace inits {
-void zeros(Tensor t) {
- t->set(0.f);
+class LambdaInit : public NodeInitializer {
+ private:
+ std::function<void(Tensor)> lambda_;
+
+ public:
+ LambdaInit(std::function<void(Tensor)>&& lambda) : lambda_(std::move(lambda)) {}
+
+ void apply(Tensor tensor) override {
+ lambda_(tensor);
+ }
+};
+
+class LambdaInitConvert : public NodeInitializer {
+ private:
+ std::function<void(Tensor)> lambda_;
+ Type intermediateType_; // is used for the creation of a temporary intermedia tensor on which the lambda actually operates.
+ // This tensor is then automatically cast and copied to the type of the actual tensor.
+
+ public:
+ LambdaInitConvert(std::function<void(Tensor)>&& lambda,
+ Type intermediateType)
+ : lambda_(std::move(lambda)), intermediateType_(intermediateType) {}
+
+ void apply(Tensor tensor) override {
+ if(tensor->type() != intermediateType_) {
+ auto sharedAllocator = allocator_.lock();
+ ABORT_IF(!sharedAllocator, "Allocator in LambdaInitConvert has not been set or expired");
+
+ auto memory = sharedAllocator->alloc(requiredBytes(tensor->shape(), intermediateType_));
+ auto temp = TensorBase::New(memory,
+ tensor->shape(),
+ intermediateType_,
+ tensor->getBackend());
+ lambda_(temp);
+ CopyCast(tensor, temp); // Cast and copy from temp to tensor
+ sharedAllocator->free(memory);
+ }
+ else {
+ lambda_(tensor);
+ }
+ }
+};
+
+Ptr<NodeInitializer> fromLambda(std::function<void(Tensor)>&& func) {
+ return New<LambdaInit>(std::move(func));
}
-void ones(Tensor t) {
- t->set(1.0f);
+Ptr<NodeInitializer> fromLambda(std::function<void(Tensor)>&& func, Type intermediateType) {
+ return New<LambdaInitConvert>(std::move(func), intermediateType);
}
-NodeInitializer from_value(float v) {
- return [v](Tensor t) { t->set(v); };
+Ptr<NodeInitializer> fromValue(float v) {
+ return fromLambda([v](Tensor t){ t->set(v); });
}
// diagonal matrix with value val along diagonal
-NodeInitializer eye(float val) {
- return [val](Tensor t) {
+Ptr<NodeInitializer> eye(float val) {
+ auto eyeLambda = [val](Tensor t) {
ABORT_IF(t->shape().size() != 2 || t->shape()[-1] != t->shape()[-2],
- "eye(val) is defined only for quadratic tensors, shape is {}",
- t->shape());
+ "eye(val) is defined only for quadratic tensors, shape is {}",
+ t->shape());
// @TODO: implement efficient version on the GPU
std::vector<float> vec(t->size(), 0);
for(int i = 0; i < t->shape()[-1]; ++i)
vec[i * t->shape()[0] + i] = val;
+
t->set(vec);
};
+
+ return fromLambda(eyeLambda, Type::float32);
}
-NodeInitializer uniform(float a, float b) {
- return [a, b](Tensor tensor) {
- tensor->getBackend()->getRandomGenerator()->uniform(tensor, a, b);
- };
+Ptr<NodeInitializer> uniform(float a, float b) {
+ // only works for float, hence the conversion through intermedia type Type::float32
+ return fromLambda([a, b](Tensor t) { t->getBackend()->getRandomGenerator()->uniform(t, a, b); }, Type::float32);
}
-NodeInitializer normal(float mean, float stddev) {
- return [mean, stddev](Tensor tensor) {
- tensor->getBackend()->getRandomGenerator()->normal(tensor, mean, stddev);
- };
+Ptr<NodeInitializer> normal(float mean, float stddev) {
+ // only works for float, hence the conversion through intermedia type Type::float32
+ return fromLambda([mean, stddev](Tensor t) { t->getBackend()->getRandomGenerator()->normal(t, mean, stddev); }, Type::float32);
}
-void glorot_uniform(Tensor tensor) {
- float scale = sqrtf(6.0f / (tensor->shape()[-2] + tensor->shape()[-1]));
- uniform(-scale, scale)(tensor);
+Ptr<NodeInitializer> glorotUniform(bool fanIn, bool fanOut, float scalingFactor) {
+ return fromLambda([fanIn, fanOut, scalingFactor](Tensor t) {
+ float scale = sqrtf(6.0f / (t->shape()[-2] + t->shape()[-1]));
+ if(fanIn && !fanOut)
+ scale = sqrtf(3.0f / t->shape()[-2]); // results in columns of matrix to be ~unit length
+ if(!fanIn && fanOut)
+ scale = sqrtf(3.0f / t->shape()[-1]);
+
+ scale *= scalingFactor;
+
+ t->getBackend()->getRandomGenerator()->uniform(t, -scale, scale);
+ }, Type::float32);
}
-void glorot_normal(Tensor tensor) {
- float scale = sqrtf(2.0f / (tensor->shape()[-2] + tensor->shape()[-1]));
- normal(0.f, scale)(tensor);
+Ptr<NodeInitializer> glorotNormal(bool fanIn, bool fanOut, float scalingFactor) {
+ return fromLambda([fanIn, fanOut, scalingFactor](Tensor t) {
+ float scale = sqrtf(2.0f / (t->shape()[-2] + t->shape()[-1]));
+ if(fanIn && !fanOut)
+ scale = sqrtf(1.0f / t->shape()[-2]);
+ if(!fanIn && fanOut)
+ scale = sqrtf(1.0f / t->shape()[-1]);
+
+ scale *= scalingFactor;
+
+ t->getBackend()->getRandomGenerator()->normal(t, 0.f, scale);
+ }, Type::float32);
}
-NodeInitializer bernoulli(float prob, float scale) {
- return [prob, scale](Tensor tensor) {
- Bernoulli(tensor, prob, scale);
- };
+Ptr<NodeInitializer> bernoulli(float prob, float scale, float shift) {
+ return fromLambda([prob, scale, shift](Tensor t) { Bernoulli(t, prob, scale, shift); }, Type::float32);
}
-NodeInitializer dropout(float dropProb) {
- return [dropProb](Tensor t) {
- Dropout(t, dropProb);
- };
+Ptr<NodeInitializer> dropout(float dropProb) {
+ return fromLambda([dropProb](Tensor t) { Dropout(t, dropProb); }, Type::float32);
}
// gumbel noise:
// -log(-log(uniform(0.f + eps, 1.f - eps)));
-void gumbel(Tensor tensor) {
- using namespace functional;
- // @TODO: make eps a parameter? Seems to influence amplitude quite heavily
- float eps = 1e-05f;
- uniform(0.f + eps, 1.f - eps)(tensor);
- Element(_1 = -log(-log(_1)), tensor);
-}
-
-NodeInitializer from_vector(const std::vector<float>& v) {
- auto vPtr = New<std::vector<float>>(v.begin(), v.end());
- return
- [vPtr](Tensor t) { t->set(vPtr->data(), vPtr->data() + vPtr->size()); };
-}
-
-// @TODO: handle this better with proper type support, the NodeInitializer
-// should be able to inform the calling function about the tensor type it
-// is expecting. Probably needs to turn into struct with type information.
-NodeInitializer from_vector(const std::vector<IndexType>& v) {
- auto vPtr = New<std::vector<IndexType>>(v.begin(), v.end());
- return
- [vPtr](Tensor t) { t->set(vPtr->data(), vPtr->data() + vPtr->size()); };
-}
-
-NodeInitializer from_sparse_vector(
- std::pair<std::vector<size_t>, std::vector<float>>& v) {
- return [v](Tensor t) {
- t->set(1e-6);
- t->setSparse(v.first, v.second);
- };
+Ptr<NodeInitializer> gumbel(float eps) {
+ return fromLambda([eps](Tensor tensor) {
+ tensor->getBackend()->getRandomGenerator()->uniform(tensor, 0.f + eps, 1.f - eps);
+ using namespace functional;
+ Element(_1 = -log(-log(_1)), tensor);
+ }, Type::float32);
+}
+
+template <typename T>
+Ptr<NodeInitializer> fromVector(const std::vector<T>& v) {
+ return fromLambda([v](Tensor t) { t->set(v.data(), v.data() + v.size()); }, typeId<T>());
+}
+
+template <typename T>
+Ptr<NodeInitializer> fromVector(std::vector<T>&& v) {
+ return fromLambda([v](Tensor t) { t->set(v.data(), v.data() + v.size()); }, typeId<T>());
}
-// NodeInitializer from_numpy(const cnpy::NpyArrayPtr& np) {
-// return [np](Tensor t) {
-// size_t size = 1;
-// for(size_t dim : np->shape)
-// size *= dim;
-// t->set((float*)np->data(), (float*)np->data() + size);
-// };
-//}
+template Ptr<NodeInitializer> fromVector<float16>(const std::vector<float16>& v);
+template Ptr<NodeInitializer> fromVector<float>(const std::vector<float>& v);
+template Ptr<NodeInitializer> fromVector<IndexType>(const std::vector<IndexType>& v);
+
+// @TODO: can we remove the const& ones above? They always make a copy anyways, and often from a temp
+template Ptr<NodeInitializer> fromVector<float16> (std::vector<float16> && v);
+template Ptr<NodeInitializer> fromVector<float> (std::vector<float> && v);
+template Ptr<NodeInitializer> fromVector<IndexType>(std::vector<IndexType>&& v);
+
+Ptr<NodeInitializer> fromSparseVector(std::pair<std::vector<size_t>, std::vector<float>>& v) {
+ return fromLambda([v](Tensor t) { t->set(1e-6); t->setSparse(v.first, v.second); });
+}
// move this somewhere else
-NodeInitializer from_word2vec(const std::string& file,
+Ptr<NodeInitializer> fromWord2vec(const std::string& file,
int dimVoc,
int dimEmb,
bool normalize /*= false*/) {
- return [file, dimVoc, dimEmb, normalize](Tensor t) {
+ return fromLambda([file, dimVoc, dimEmb, normalize](Tensor t) {
auto embs = Word2VecReader().read(file, dimVoc, dimEmb);
-
if(normalize) {
float norm = 0;
for(auto e : embs)
@@ -132,31 +180,59 @@ NodeInitializer from_word2vec(const std::string& file,
e = e / norm;
}
t->set(embs);
- };
+ });
}
-NodeInitializer from_item(const io::Item& item) {
+Ptr<NodeInitializer> fromItem(const io::Item& item) {
if(item.mapped) {
- return [item](Tensor t) {
+ return fromLambda([item](Tensor tensor) {
// @TODO: implement other types, for now croak loudly.
- ABORT_IF(t->getBackend()->getDeviceId().type != DeviceType::cpu,
+ ABORT_IF(tensor->getBackend()->getDeviceId().type != DeviceType::cpu,
"Memory mapping only works for CPU tensors");
- ABORT_IF(!matchType<float>(t->type()),
- "Tensor type and type for mapping do not match");
- auto mp = New<MemoryPiece>((uint8_t*)item.ptr, t->size() * sizeof(float));
- t->reset(mp);
- };
+ ABORT_IF(tensor->type() != item.type,
+ "Tensor type ({}) and type for mapping ({}) do not match",
+ tensor->type(),
+ item.type);
+ ABORT_IF(tensor->shape() != item.shape,
+ "Tensor shape ({}) and shape of mapped item ({}) do not match",
+ tensor->shape(),
+ item.shape);
+ auto mp = MemoryPiece::New((uint8_t*)item.ptr, item.size()); // @TODO: this is not properly aligned now
+ tensor->reset(mp);
+ });
} else {
- return [item](Tensor t) {
- // @TODO: implement other types, for now croak loudly.
- ABORT_IF(!matchType<float>(t->type()),
- "Tensor type and type for mapping do not match");
- t->set((const float*)item.bytes.data(),
- (const float*)item.bytes.data() + t->size());
- };
+ return fromLambda(
+ [item](Tensor tensor) { tensor->set(item); },
+ item.type);
}
}
+Ptr<NodeInitializer> fromTensor(Tensor externalTensor) {
+ return fromLambda([externalTensor](Tensor t) { t->copyFrom(externalTensor); }, externalTensor->type());
+}
+
+// Computes Google's sinusoidal position embeddings
+Ptr<NodeInitializer> sinusoidalPositionEmbeddings(int start) {
+ return fromLambda([start](Tensor t) {
+ int dimEmb = t->shape()[-1];
+ int dimWords = (int)t->size() / dimEmb;
+
+ float numTimescales = (float)dimEmb / 2;
+ float logTimescaleIncrement = std::log(10000.f) / (numTimescales - 1.f);
+
+ std::vector<float> vPos(dimEmb * dimWords, 0);
+ for(int p = start; p < dimWords + start; ++p) {
+ for(int i = 0; i < numTimescales; ++i) {
+ float v = p * std::exp(i * -logTimescaleIncrement);
+ vPos[(p - start) * dimEmb + i ] = std::sin(v);
+ vPos[(p - start) * dimEmb + (int)numTimescales + i] = std::cos(v); // @TODO: is int vs. float correct for num_timescales?
+ }
+ }
+
+ t->set(vPos);
+ }, Type::float32);
+}
+
} // namespace inits
} // namespace marian
diff --git a/src/graph/node_initializers.h b/src/graph/node_initializers.h
index 58690e4f..21dcc95d 100755
--- a/src/graph/node_initializers.h
+++ b/src/graph/node_initializers.h
@@ -3,54 +3,179 @@
#include "common/config.h"
#include "tensors/tensor.h"
+#include "tensors/tensor_operators.h"
#include <functional>
#include <random>
namespace marian {
-typedef std::function<void(Tensor)> NodeInitializer;
+class ExpressionGraph; // Forward declaration
namespace inits {
-void zeros(Tensor t);
+/**
+ * Base class for specialized NodeInitializers.
+ *
+ * A NodeInitializer is a functor that is associated with parameters
+ * and constants, and is invoked on a tensor during node intialization.
+ * You need to override NodeIntializer::apply(Tensor) with your own
+ * functionality or use a fromLambda intializer.
+ *
+ * See node_initializers.cpp for examples.
+ */
+class NodeInitializer {
+protected:
+ Weak<Allocator> allocator_;
+
+public:
+ virtual void apply(Tensor t) = 0;
+ void setAllocator(Ptr<Allocator> allocator) { allocator_ = allocator; }
+ virtual ~NodeInitializer() {}
+};
+
+/**
+ * Use a lambda function of form [](Tensor t) { do something with t } to initalize tensor
+ */
+Ptr<NodeInitializer> fromLambda(std::function<void(Tensor)>&& func);
+
+/**
+ * Use a lambda function of form [](Tensor t) { do something with t } to initalize tensor
+ * Create temporary tensor of Type intermediateType first, initialize and then copy/convert to actual Tensor
+ * Useful for functions that can only operate on a specific type of tensor
+ */
+Ptr<NodeInitializer> fromLambda(std::function<void(Tensor)>&& func, Type intermediateType);
+
+/**
+ * Initialize tensor with given value
+ *
+ * Creates a NodeInitializer that will intialize the given tensor
+ * with `value`. Works with any underlying numeric tensor type.
+ *
+ * @return A NodeInitializer
+ */
+Ptr<NodeInitializer> fromValue(float value);
+
+/**
+ * Fill tensor with `0`
+ *
+ * Creates a NodeInitializer that will intialize the given tensor
+ * with `0`. Works with any underlying numeric tensor type.
+ *
+ * @return A NodeInitializer
+ */
+static Ptr<NodeInitializer> zeros() { return fromValue(0.0f); }
+
+/**
+ * Fill tensor with `1`
+ *
+ * Creates a NodeInitializer that will intialize the given tensor
+ * with `1`. Works with any underlying numeric tensor type.
+ *
+ * @return A NodeInitializer
+ */
+static Ptr<NodeInitializer> ones() { return fromValue(1.0f); }
+
+/**
+ * Set diagonal of two dimensional quadratic matrix to `value`.
+ *
+ * Sets all values of the tensor to 0 and intializes the diagonal with
+ * the given `value`. If no value is specified `1` is used by default.
+ *
+ * @return A NodeInitializer
+ */
+Ptr<NodeInitializer> eye(float value = 1.f);
+
+/**
+ * Intialize tensor with normally distributed random numbers
+ *
+ * Be default this generates floating point numbers from the
+ * normal distribution Normal(0, 1) unless specified differently.
+ *
+ * If compiled with `CUDA`, `marian` will use the `cuRand` library
+ * for both, GPU and CPU computation. The random sequences generated
+ * are the same on both devices.
+ *
+ * If `marian` is compiled without `CUDA`, a random generator
+ * from the C++ standard library is used. These random generators
+ * do not have the same random sequences.
+ *
+ * @return A NodeInitializer
+ */
+Ptr<NodeInitializer> normal(float mean = 0.f, float stddev = 1.f);
+
+/**
+ * Intialize tensor with uniformly distributed random numbers
+ *
+ * Be default this generates floating point numbers from the
+ * uniform distribution Uniform(0, 1) unless specified differently.
+ *
+ * If compiled with `CUDA`, `marian` will use the `cuRand` library
+ * for both, GPU and CPU computation. The random sequences generated
+ * are the same on both devices.
+ *
+ * If `marian` is compiled without `CUDA`, a random generator
+ * from the C++ standard library is used. These random generators
+ * do not have the same random sequences.
+ *
+ * @return A NodeInitializer
+ */
+Ptr<NodeInitializer> uniform(float a = 0.f, float b = 1.f);
+
+// @TODO: add documentation
+Ptr<NodeInitializer> bernoulli(float p, float scale = 1.f, float shift = 0.f);
+
+// @TODO: add documentation
+Ptr<NodeInitializer> glorotUniform(bool fanIn = false, bool fanOut = false, float scale = 1.f);
+
+// @TODO: add documentation
+Ptr<NodeInitializer> glorotNormal(bool fanIn = false, bool fanOut = false, float scale = 1.f);
+
+// @TODO: add documentation
+Ptr<NodeInitializer> dropout(float dropoutProbabilty);
+
+/**
+ * Intialize with gumbel noise, i.e. -log(-log(u)) where u ~ Uniform(0 + eps, 1 - eps)
+ *
+ * @return A NodeInitializer
+ */
+Ptr<NodeInitializer> gumbel(float eps = 1e-5f);
+
+// @TODO: add documentation
+template <typename T>
+Ptr<NodeInitializer> fromVector(const std::vector<T>& v);
+template <typename T>
+Ptr<NodeInitializer> fromVector(std::vector<T>&& v);
+
+// @TODO: add documentation
+Ptr<NodeInitializer> fromSparseVector(std::pair<std::vector<size_t>, std::vector<float>>& v);
+
+// @TODO: add documentation
+Ptr<NodeInitializer> fromItem(const io::Item& item);
+
+// @TODO: add documentation
+Ptr<NodeInitializer> fromTensor(Tensor tensor);
+
+// @TODO: add documentation
+Ptr<NodeInitializer> fromWord2vec(const std::string& file,
+ int dimVoc,
+ int dimEmb,
+ bool normalize = false);
+
+/**
+ * Computes Google's Transformer-style sinusoidal position embeddings
+ * starting from position 'start' taking into account batch and time
+ * dimensions of the tensor.
+ *
+ * Expected tensor layout {-2: time, -1: model}
+ *
+ * Usually gets later reshaped to {time, 1, model} and
+ * added with a broadcast to learned embeddings. Positional
+ * embeddings are the same for each batch entry and change
+ * over time steps.
+ */
+Ptr<NodeInitializer> sinusoidalPositionEmbeddings(int start);
-void ones(Tensor t);
-
-NodeInitializer from_value(float v);
-
-NodeInitializer eye(float val = 1.f);
-
-NodeInitializer normal(float mean = 0.f, float stddev = 1.f);
-
-NodeInitializer uniform(float a = 0.f, float b = 1.f);
-
-void glorot_uniform(Tensor t);
-
-void glorot_normal(Tensor t);
-
-NodeInitializer bernoulli(float p, float scale = 1.f);
-
-NodeInitializer dropout(float dropProb);
-
-void gumbel(Tensor t);
-
-static inline void dummy(Tensor) {}
-
-NodeInitializer from_vector(const std::vector<float>& v);
-NodeInitializer from_vector(const std::vector<IndexType>& v);
-
-NodeInitializer from_item(const io::Item& item);
-
-NodeInitializer from_sparse_vector(
- std::pair<std::vector<size_t>, std::vector<float>>& v);
-
-// NodeInitializer from_numpy(const cnpy::NpyArrayPtr& np);
-
-NodeInitializer from_word2vec(const std::string& file,
- int dimVoc,
- int dimEmb,
- bool normalize = false);
} // namespace inits
} // namespace marian
diff --git a/src/graph/node_operators.cpp b/src/graph/node_operators.cpp
index 00e0a319..932fee88 100644
--- a/src/graph/node_operators.cpp
+++ b/src/graph/node_operators.cpp
@@ -5,10 +5,21 @@
namespace marian {
+ConstantNode::ConstantNode(Ptr<ExpressionGraph> graph,
+ const Shape& shape,
+ const Ptr<inits::NodeInitializer>& init,
+ Type valueType)
+ : Node(graph, shape, valueType),
+ init_(init),
+ initialized_(false) {
+ init_->setAllocator(graph->allocator());
+ setTrainable(false);
+}
+
size_t ConstantNode::allocate() {
size_t elements = 0;
if(!val_) {
- graph()->allocateForward(shared_from_this());
+ graph()->allocateForward(this);
elements = val_->shape().elements();
}
return elements;
@@ -16,7 +27,7 @@ size_t ConstantNode::allocate() {
void ConstantNode::init() {
if(!initialized_) {
- (*init_)(val_);
+ init_->apply(val_);
initialized_ = true;
}
init_.reset();
@@ -24,18 +35,26 @@ void ConstantNode::init() {
ParamNode::ParamNode(Ptr<ExpressionGraph> graph,
const Shape& shape,
- const NodeInitializer& init,
+ const Ptr<inits::NodeInitializer>& init,
+ bool fixed)
+ : ParamNode(graph, shape, init, Type::float32, fixed) {}
+
+ParamNode::ParamNode(Ptr<ExpressionGraph> graph,
+ const Shape& shape,
+ const Ptr<inits::NodeInitializer>& init,
+ Type valueType,
bool fixed)
- : Node(graph, shape), // TODO: add value_type
- init_(new NodeInitializer(init)),
+ : Node(graph, shape, valueType),
+ init_(init),
initialized_(false) {
+ init_->setAllocator(graph->allocator());
setTrainable(!fixed);
setMemoize(graph->isInference());
}
void ParamNode::init() {
if(!initialized_) {
- (*init_)(val_);
+ init_->apply(val_);
initialized_ = true;
}
init_.reset();
diff --git a/src/graph/node_operators.h b/src/graph/node_operators.h
index 018d0533..7479bd69 100644
--- a/src/graph/node_operators.h
+++ b/src/graph/node_operators.h
@@ -9,13 +9,8 @@ namespace marian {
struct ConstantNode : public Node {
ConstantNode(Ptr<ExpressionGraph> graph,
const Shape& shape,
- const NodeInitializer& init,
- Type value_type = Type::float32)
- : Node(graph, shape, value_type),
- init_(new NodeInitializer(init)),
- initialized_(false) {
- setTrainable(false);
- }
+ const Ptr<inits::NodeInitializer>& init,
+ Type valueType = Type::float32);
~ConstantNode() {}
@@ -37,20 +32,26 @@ struct ConstantNode : public Node {
virtual void record(Ptr<AutoTunerRecorder>, size_t, bool) override{};
private:
- UPtr<NodeInitializer> init_;
+ Ptr<inits::NodeInitializer> init_;
bool initialized_;
};
struct ParamNode : public Node {
ParamNode(Ptr<ExpressionGraph> graph,
const Shape& shape,
- const NodeInitializer& init,
+ const Ptr<inits::NodeInitializer>& init,
+ bool fixed = false);
+
+ ParamNode(Ptr<ExpressionGraph> graph,
+ const Shape& shape,
+ const Ptr<inits::NodeInitializer>& init,
+ Type valueType,
bool fixed = false);
~ParamNode() {}
virtual size_t allocate() override {
- ABORT_IF(!val_, "Parameters should be allocated by their graph");
+ ABORT_IF(!val_, "Parameters should be allocated by their graph. Parameter {} was not", name_);
return 0;
}
@@ -72,7 +73,7 @@ struct ParamNode : public Node {
virtual void record(Ptr<AutoTunerRecorder>, size_t, bool) override{};
private:
- UPtr<NodeInitializer> init_;
+ Ptr<inits::NodeInitializer> init_;
bool initialized_;
};
} // namespace marian
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 7da85443..63158ffa 100755
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -42,7 +42,7 @@ public:
Shape outShape = shapeA;
outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
- "matrix product requires dimensions to match");
+ "Matrix product requires inner dimensions to match");
return outShape;
}
@@ -63,7 +63,6 @@ public:
// df/dB += alpha * dot(op(A).T, D)
// beta set to 1.0 in gemm, C = alpha * dot(op(A), op(B)) + beta * C
// to sum gradients from different graph parts
-
if(!transA_ && transB_)
return {NodeOp(Prod(child(0)->grad(),
adj_,
@@ -128,7 +127,30 @@ public:
scalar_))};
}
- const std::string type() override { return "•"; }
+ const std::string type() override { return "dot"; }
+
+ virtual size_t hash() override {
+ size_t seed = NaryNodeOp::hash();
+ util::hash_combine(seed, transA_);
+ util::hash_combine(seed, transB_);
+ util::hash_combine(seed, scalar_);
+ return seed;
+ }
+
+ virtual bool equal(Expr node) override {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<DotNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(transA_ != cnode->transA_)
+ return false;
+ if(transB_ != cnode->transB_)
+ return false;
+ if(scalar_ != cnode->scalar_)
+ return false;
+ return true;
+ }
const std::string color() override { return "orange"; }
};
@@ -165,7 +187,7 @@ public:
Shape outShape = shapeA;
outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
- "matrix product requires dimensions to match");
+ "Matrix product requires inner dimensions to match");
return outShape;
}
@@ -211,7 +233,6 @@ public:
scalar_)),
NodeOp(Prod(
child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f))
- // NodeOp(Add(_1, child(2)->grad(), adj_))
};
if(transA_ && !transB_)
@@ -232,7 +253,6 @@ public:
scalar_)),
NodeOp(Prod(
child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f))
- // NodeOp(Add(_1, child(2)->grad(), adj_))
};
if(transA_ && transB_)
@@ -253,7 +273,6 @@ public:
scalar_)),
NodeOp(Prod(
child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f))
- // NodeOp(Add(_1, child(2)->grad(), adj_))
};
return {
@@ -273,11 +292,34 @@ public:
scalar_)),
NodeOp(Prod(
child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f))
- // NodeOp(Add(_1, child(2)->grad(), adj_))
};
}
const std::string type() override { return "affine"; }
+
+ virtual size_t hash() override {
+ size_t seed = NaryNodeOp::hash();
+ util::hash_combine(seed, transA_);
+ util::hash_combine(seed, transB_);
+ util::hash_combine(seed, scalar_);
+ return seed;
+ }
+
+ virtual bool equal(Expr node) override {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<AffineNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(transA_ != cnode->transA_)
+ return false;
+ if(transB_ != cnode->transB_)
+ return false;
+ if(scalar_ != cnode->scalar_)
+ return false;
+ return true;
+ }
+
};
class DotBatchedNodeOp : public NaryNodeOp {
@@ -309,7 +351,7 @@ public:
Shape outShape = shapeA;
outShape.set(-1, shapeB[-1]);
ABORT_IF(shapeA[-1] != shapeB[-2],
- "matrix product requires dimensions to match");
+ "Batched matrix product requires inner dimensions to match");
return outShape;
}
@@ -404,7 +446,110 @@ public:
scalar_))};
}
- const std::string type() override { return "•"; }
+ const std::string type() override { return "bdot"; }
+
+ virtual size_t hash() override {
+ size_t seed = NaryNodeOp::hash();
+ util::hash_combine(seed, transA_);
+ util::hash_combine(seed, transB_);
+ util::hash_combine(seed, scalar_);
+ return seed;
+ }
+
+ virtual bool equal(Expr node) override {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<DotBatchedNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(transA_ != cnode->transA_)
+ return false;
+ if(transB_ != cnode->transB_)
+ return false;
+ if(scalar_ != cnode->scalar_)
+ return false;
+ return true;
+ }
+
+ const std::string color() override { return "orange"; }
+};
+
+// Note: To reduce code duplication, we use the same NodeOp for C = op(S) x D and C = D x op(S).
+// Set swapOperands to select the latter.
+class CSRDotNodeOp : public NaryNodeOp {
+ bool transS_;
+ bool swapOperands_;
+public:
+ CSRDotNodeOp(const Shape& S_shape, Expr S_values, Expr S_indices,
+ Expr S_offsets, Expr D, bool transS, bool swapOperands)
+ : NaryNodeOp({ S_values, S_indices, S_offsets, D },
+ newShape(S_shape, S_values, S_indices, S_offsets, D, transS, swapOperands),
+ NaryNodeOp::commonType({S_values, D})),
+ transS_(transS), swapOperands_(swapOperands) {
+ matchOrAbort<IndexType>(S_indices->value_type());
+ matchOrAbort<IndexType>(S_offsets->value_type());
+ }
+
+ Shape newShape(const Shape& S_shape, Expr S_values, Expr S_indices, Expr S_offsets, Expr D, bool transS, bool swapOperands) {
+ ABORT_IF(S_values->shape().size() != 1 || S_indices->shape().size() != 1 || S_offsets->shape().size() != 1,
+ "Sparse matrix components must all be vectors");
+ ABORT_IF(S_values->shape() != S_indices->shape(),
+ "Sparse matrix values and indices must have the same shape");
+ ABORT_IF(S_shape.size() != 2,
+ "Sparse matrix must have rank 2");
+ ABORT_IF(S_offsets->shape()[0] - 1 != S_shape[0],
+ "Sparse matrix offset vector has incorrect size");
+ auto outShape = D->shape();
+ ABORT_IF(S_shape[transS == swapOperands ? 1 : 0] != outShape[-(int)swapOperands],
+ "Matrix product requires inner dimensions to match");
+ outShape.set(-(int)swapOperands, S_shape[transS != swapOperands]);
+ return outShape;
+ }
+
+ NodeOps forwardOps() override {
+ return {NodeOp(CSRProd(val_,
+ graph()->allocator(),
+ child(0)->val(), child(1)->val(), child(2)->val(),
+ child(3)->val(),
+ /*transS=*/transS_, /*swapOperands=*/swapOperands_, /*beta=*/0))};
+ }
+
+ NodeOps backwardOps() override {
+ return { nullptr, // can't backprop into the sparse matrix (the gradient is dense)
+ nullptr,
+ nullptr,
+ NodeOp(CSRProd(child(3)->grad(), // child(3) = D
+ graph()->allocator(),
+ child(0)->val(), child(1)->val(), child(2)->val(), // children(0..2) = A
+ adj_,
+ /*transS=*/!transS_, /*swapOperands=*/swapOperands_, /*beta=*/1))};
+ }
+
+ const std::string type() override { return "csr_dot"; }
+
+ virtual size_t hash() override {
+ size_t seed = NaryNodeOp::hash();
+ for(auto s : shape())
+ util::hash_combine(seed, s);
+ util::hash_combine(seed, transS_);
+ util::hash_combine(seed, swapOperands_);
+ return seed;
+ }
+
+ virtual bool equal(Expr node) override {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<CSRDotNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(transS_ != cnode->transS_)
+ return false;
+ if(shape() != cnode->shape())
+ return false;
+ if(swapOperands_ != cnode->swapOperands_)
+ return false;
+ return true;
+ }
const std::string color() override { return "orange"; }
};
@@ -460,7 +605,7 @@ struct ScalarProductNodeOp : public NaryNodeOp {
struct RowsNodeOp : public NaryNodeOp {
RowsNodeOp(Expr a, Expr indices)
- : NaryNodeOp({a, indices}, newShape(a, indices->shape().elements())) {
+ : NaryNodeOp({a, indices}, newShape(a, indices), a->value_type()) {
matchOrAbort<IndexType>(indices->value_type());
}
@@ -473,12 +618,11 @@ struct RowsNodeOp : public NaryNodeOp {
return {NodeOp(PasteRows(child(0)->grad(), adj_, child(1)->val()))};
}
- template <class... Args>
- Shape newShape(Expr a, size_t num) {
+ Shape newShape(Expr a, Expr indices) {
Shape shape = a->shape();
ABORT_IF(shape.size() != 2,
"rows operator can only be used with 2-dimensional tensors");
- shape.set(0, num);
+ shape.set(0, (int)indices->shape().elements());
return shape;
}
@@ -487,10 +631,10 @@ struct RowsNodeOp : public NaryNodeOp {
const std::string color() override { return "orange"; }
};
-// This operation indexes a tensor along an axis.
-// This is similar to the common gather() operation in other toolkits.
+// This operation gathers elements of a tensor along an axis.
+// This is like PyTorch gather().
// For example, this can be used for:
-// - Same index applied to all batch items (today's select()):
+// - Same index applied to all batch items:
// 'index' has 1 in the axes that match batch axes in the input, and axis set to the one axis that gets selected over.
// Example: Selecting Transformer head 0, i.e. return a[:,1,:,:]
// axis = -3
@@ -505,7 +649,7 @@ struct RowsNodeOp : public NaryNodeOp {
// idx: (#(B*S)#, 1) B=batch size, S=source length, idx values are in range 0..V-1
// out: ( (B*S) , E) out[b, s, e] == e[/*0,*/ idx[b, s, 0], e]
// - Batched selection (x-ent scenario): Both 'index' and 'data' have matching batch axes.
-// Example: Cross-entropy loss as -select(logSoftmax(logits), groundTruth, axis=-1):
+// Example: Cross-entropy loss as -gather(logSoftmax(logits), groundTruth, axis=-1):
// axis = -1
// lp : (B, T, V ) B=batch size, T=trg length, V=vocab size
// idx: (B, T, #1#) idx values are in range 0..V-1
@@ -520,14 +664,10 @@ struct RowsNodeOp : public NaryNodeOp {
// out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
// out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
// out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
-// If 'a' and 'indices' do not have the same rank, then negative 'axis' is
-// interpreted relative to 'a', and 'indices' must have the resulting axis.
-// Broadcasting is supported as usual.
-// @TODO: The current implementation does not support batched indices (third scenario above).
-// I.e. all axes of 'indices' except 'axis' must have dimension 1.
-struct SelectNodeOp : public NaryNodeOp {
- SelectNodeOp(Expr a, Expr indices, int axis)
- : NaryNodeOp({a, indices}, newShape(a, indices, axis), a->value_type()),
+// 'a' and 'indices' must have the same rank.
+struct GatherNodeOp : public NaryNodeOp {
+ GatherNodeOp(Expr a, int axis, Expr indices)
+ : NaryNodeOp({a, indices}, newShape(a, axis, indices), a->value_type()),
axis_(a->shape().axis(axis)) {
matchOrAbort<IndexType>(indices->value_type());
}
@@ -542,24 +682,23 @@ struct SelectNodeOp : public NaryNodeOp {
Insert(child(0)->grad(), adj_, child(1)->val(), axis_))};
}
- Shape newShape(Expr a, Expr indices, int axis) {
- axis = a->shape().axis(axis);
- auto indicesRank = indices->shape().size();
- ABORT_IF(axis >= indicesRank, "Axis {} is invalid for indices shape {}", axis, std::string(indices->shape()));
+ Shape newShape(Expr a, int axis, Expr indices) {
Shape shape = a->shape();
- if (shape.size() < indicesRank) // pad
- shape.resize(indicesRank);
+ axis = shape.axis(axis);
+ auto rank = shape.size();
+ ABORT_IF(rank != indices->shape().size(), "Mismatching ranks for input ({}) and indices ({})", std::string(shape), std::string(indices->shape()));
+ axis = a->shape().axis(axis);
shape.set(axis, indices->shape()[axis]);
-#if 1 // presently, this implementation does not support batched indices
- for (size_t i = 0; i < indicesRank; ++i) {
- ABORT_IF(indices->shape()[i] != 1 && i + shape.size() - indicesRank != axis,
- "Presently, select() does not implement batched indices");
+ for (size_t i = 0; i < rank; ++i) {
+ if (i != axis) {
+ ABORT_IF(indices->shape()[i] != shape[i] && indices->shape()[i] != 1,
+ "Dimensions must match or broadcast for input ({}) and indices ({})", std::string(shape), std::string(indices->shape()));
+ }
}
-#endif
return shape;
}
- const std::string type() override { return "select"; }
+ const std::string type() override { return "gather"; }
const std::string color() override { return "orange"; }
@@ -575,7 +714,7 @@ struct SelectNodeOp : public NaryNodeOp {
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
- Ptr<SelectNodeOp> cnode = std::dynamic_pointer_cast<SelectNodeOp>(node);
+ auto cnode = std::dynamic_pointer_cast<GatherNodeOp>(node);
if(!cnode)
return false;
if(axis_ != cnode->axis_)
@@ -588,7 +727,7 @@ struct SelectNodeOp : public NaryNodeOp {
struct ColsNodeOp : public NaryNodeOp {
ColsNodeOp(Expr a, Expr indices)
- : NaryNodeOp({a, indices}, newShape(a, indices->shape().elements())) {
+ : NaryNodeOp({a, indices}, newShape(a, indices), a->value_type()) {
matchOrAbort<IndexType>(indices->value_type());
}
@@ -600,10 +739,9 @@ struct ColsNodeOp : public NaryNodeOp {
return {NodeOp(PasteCols(child(0)->grad(), adj_, child(1)->val()))};
}
- template <class... Args>
- Shape newShape(Expr a, size_t num) {
+ Shape newShape(Expr a, Expr indices) {
Shape shape = a->shape();
- shape.set(1, num);
+ shape.set(1, (int)indices->shape().elements());
return shape;
}
@@ -614,7 +752,8 @@ struct ColsNodeOp : public NaryNodeOp {
struct ElementBinaryNodeOp : public NaryNodeOp {
- ElementBinaryNodeOp(Expr a, Expr b) : NaryNodeOp({a, b}, newShape(a, b)) {}
+ ElementBinaryNodeOp(Expr a, Expr b)
+ : NaryNodeOp({a, b}, newShape(a, b)) {}
Shape newShape(Expr a, Expr b) { return Shape::broadcast({a, b}); }
@@ -814,13 +953,71 @@ struct MinimumNodeOp : public ElementBinaryNodeOp {
const std::string type() override { return "min"; }
};
+struct CmpNodeOp : public ElementBinaryNodeOp {
+ CmpNodeOp(Expr a, Expr b, int cmp_, bool not_) : ElementBinaryNodeOp(a, b), cmp_(cmp_), not_(not_) {
+ //setTrainable(false); // has no gradient
+ // Note: ^^ Disabled because it currently causing Marian to choke, for unknown reasons.
+ // Not setting this will not change the result since the vector of gradient functions is empty.
+ }
+
+ NodeOps forwardOps() override {
+ using namespace functional;
+
+ return {
+ NodeOp(Element(_1 = ((((_2 > _3) - (_2 < _3)) == (float)cmp_) != not_),
+ val_, child(0)->val(), child(1)->val()))};
+ }
+
+ NodeOps backwardOps() override { return {}; }
+
+ const std::string type() override {
+ switch (cmp_) {
+ case -1: return not_ ? "ge" : "lt";
+ case 0: return not_ ? "ne" : "eq";
+ case 1: return not_ ? "le" : "gt";
+ }
+ ABORT("Should not get here??");
+ }
+
+ virtual size_t hash() override {
+ if(!hash_) {
+ size_t seed = NaryNodeOp::hash();
+ util::hash_combine(seed, cmp_);
+ util::hash_combine(seed, not_);
+ hash_ = seed;
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) override {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<CmpNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(cmp_ != cnode->cmp_)
+ return false;
+ if(not_ != cnode->not_)
+ return false;
+ return true;
+ }
+
+private:
+ int cmp_; // -1: less; 0: equal; 1: greater
+ bool not_; // invert result if true
+};
+
// In each j-th row, take the corresponding j-th label index i from indices and compute:
// For each vocabulary item v, the only non-zero element in a row in the sum is the item
// that matches the label indexed by i (the picked element).
// C = sum_{v in V}(-logsoftmax(A) * delta(v, i) = -logsoftmax(A)[i]
struct CrossEntropyNodeOp : public NaryNodeOp {
- CrossEntropyNodeOp(Expr a, Expr indices) : NaryNodeOp({a, indices}, newShape(a)) {
+ CrossEntropyNodeOp(Expr a, Expr indices)
+ : NaryNodeOp({a, indices}, newShape(a), a->value_type()) {
matchOrAbort<IndexType>(indices->value_type());
+ int rows = a->shape().elements() / a->shape()[-1];
+ int labels = indices->shape().elements();
+ ABORT_IF(rows != labels, "Number of examples and labels does not match: {} != {}", rows, labels);
}
Shape newShape(Expr a) {
@@ -847,12 +1044,21 @@ struct ConcatenateNodeOp : public NaryNodeOp {
}
Shape newShape(const std::vector<Expr>& nodes, int ax) {
- Shape shape = nodes.back()->shape();
+ ABORT_IF(nodes.empty(), "No child nodes given");
+
+ Shape shape = nodes[0]->shape();
ax_ = shape.axis(ax);
int sum = 0;
- for(auto child : nodes)
+ auto checkShape = shape;
+ for(auto child : nodes) {
+ checkShape.set(ax_, child->shape()[ax_]); // don't abort on different sizes on axis dim.
+ ABORT_IF(checkShape != child->shape(),
+ "Child shapes {} and {} cannot be concatenated along axis {}",
+ shape, child->shape(), ax);
+
sum += child->shape()[ax_];
+ }
shape.set(ax_, sum);
return shape;
@@ -869,8 +1075,7 @@ struct ConcatenateNodeOp : public NaryNodeOp {
std::vector<Tensor> deconcatenees;
for(size_t i = 0; i < children_.size(); ++i) {
auto childPtr = child(i);
- childPtr
- ->set_zero_adjoint(); // @TODO: this is a hotfix, do this properly
+ childPtr->set_zero_adjoint(); // @TODO: this is a hotfix, do this properly
deconcatenees.push_back(childPtr->grad());
}
Deconcatenate(deconcatenees, adj_, ax_);
@@ -913,7 +1118,9 @@ public:
}
NodeOps backwardOps() override {
- return {NodeOp(LayerNormalizationGrad(
+ return {NodeOp(
+ LayerNormalizationGrad(
+ graph()->allocator(),
child(0)->grad(),
child(1)->grad(),
(children_.size() == 3) ? child(2)->grad() : nullptr,
@@ -927,6 +1134,23 @@ public:
const std::string type() override { return "layer_normalization"; }
+ virtual size_t hash() override {
+ size_t seed = NaryNodeOp::hash();
+ util::hash_combine(seed, eps_);
+ return seed;
+ }
+
+ virtual bool equal(Expr node) override {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<LayerNormalizationOp>(node);
+ if(!cnode)
+ return false;
+ if(eps_ != cnode->eps_)
+ return false;
+ return true;
+ }
+
private:
float eps_;
};
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index b8b19208..a52ecf0e 100644..100755
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -14,12 +14,18 @@
namespace marian {
struct UnaryNodeOp : public NaryNodeOp {
- UnaryNodeOp(Expr a, Shape shape, Type value_type = Type::float32)
+ UnaryNodeOp(Expr a, Shape shape, Type value_type)
: NaryNodeOp({a}, shape, value_type) {}
- UnaryNodeOp(Expr a, Type value_type = Type::float32)
+ UnaryNodeOp(Expr a, Type value_type)
: NaryNodeOp({a}, a->shape(), value_type) {}
+ UnaryNodeOp(Expr a, Shape shape)
+ : NaryNodeOp({a}, shape, a->value_type()) {}
+
+ UnaryNodeOp(Expr a)
+ : NaryNodeOp({a}, a->shape(), a->value_type()) {}
+
const std::string color() override { return "yellow"; }
};
@@ -62,6 +68,24 @@ public:
}
};
+// Cast a tensor to a different type
+struct CastNodeOp : public UnaryNodeOp {
+public:
+ CastNodeOp(Expr a, Type type) : UnaryNodeOp(a, type) {}
+
+ NodeOps forwardOps() override {
+ using namespace functional;
+ return { NodeOp(CopyCast(val_, child(0)->val())) };
+ }
+
+ NodeOps backwardOps() override {
+ using namespace functional;
+ return { NodeOp(CopyCast(child(0)->grad(), adj_)) };
+ }
+
+ const std::string type() override { return "cast"; }
+};
+
struct ScalarMultNodeOp : public UnaryNodeOp {
private:
float scalar_{0};
@@ -343,31 +367,52 @@ private:
* in an expression graph.
*
* This node implements the activation function
- * \f$ f(x) = x \cdot \sigma(x) \f$
+ * \f$ f(x) = x \cdot \sigma(bx) \f$
* and its derivative
- * \f$ f^\prime(x) = f(x) + \sigma(x)(1 - f(x)) \f$ .
+ * \f$ f^\prime(x) = bf(x) + \sigma(bx)(1 - bf(x)) \f$ .
*
*/
struct SwishNodeOp : public UnaryNodeOp {
- SwishNodeOp(Expr a) : UnaryNodeOp(a) {}
+ SwishNodeOp(Expr a, float b = 1.f) : UnaryNodeOp(a), b_{b} {}
NodeOps forwardOps() override {
using namespace functional;
- return {NodeOp(Element(_1 = _2 * sigmoid(_2), val_, child(0)->val()))};
+ return {NodeOp(Element(_1 = _2 * sigmoid(b_ * _2), val_, child(0)->val()))};
}
NodeOps backwardOps() override {
using namespace functional;
- // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) )
- return {NodeOp(Add(_1 * (_3 + sigmoid(_2) * (1.f - _3)),
+ // dJ/dx += dJ/df * (b*f(x) + sigmoid(b*x) * (1 - b*f(x)))
+ return {NodeOp(Add(_1 * (b_ * _3 + sigmoid(b_ * _2) * (1.f - (b_ * _3))),
child(0)->grad(), // dJ/dx
adj_, // _1 := dJ/df
child(0)->val(), // _2 := x
- val_ // _3 := f(x) = x*sigma(x)
+ val_ // _3 := f(x) = x*sigmoid(b*x)
))};
}
const std::string type() override { return "swish"; }
+
+ virtual size_t hash() override {
+ if(!hash_) {
+ hash_ = NaryNodeOp::hash();
+ util::hash_combine(hash_, b_);
+ }
+ return hash_;
+ }
+
+ virtual bool equal(Expr node) override {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<SwishNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(b_ != cnode->b_)
+ return false;
+ return true;
+ }
+
+ float b_;
};
struct SoftmaxNodeOp : public UnaryNodeOp {
@@ -412,20 +457,79 @@ struct LogSoftmaxNodeOp : public UnaryNodeOp {
const std::string type() override { return "logsoftmax"; }
};
-struct SumNodeOp : public UnaryNodeOp {
+enum class ReduceNodeOpCode {
+ sum, mean, rms, meanSqr, min, max, prod, logSumExp
+};
+
+struct ReduceNodeOp : public UnaryNodeOp {
int axis_;
+ ReduceNodeOpCode opCode_;
+ int reducedDim_; // dimension of axis being reduced, e.g. used in mean()
- SumNodeOp(Expr a, int axis) : UnaryNodeOp(a, newShape(a, axis)) {}
+ ReduceNodeOp(Expr a, int axis, ReduceNodeOpCode opCode)
+ : UnaryNodeOp(a, newShape(a, axis)), opCode_(opCode)
+ {
+ reducedDim_ = a->shape()[axis]; // e.g. used in mean()
+ ABORT_IF(reducedDim_ != a->shape().elements() / shape().elements(), "bug in determining reducedDim");
+ }
NodeOps forwardOps() override {
using namespace functional;
- return {NodeOp(Reduce(_1, val_, child(0)->val()))};
+ switch (opCode_) {
+ case ReduceNodeOpCode::sum:
+ return {NodeOp(Reduce(_1, val_, child(0)->val()))};
+ case ReduceNodeOpCode::mean:
+ return {NodeOp(Reduce(_1, 1.0f / (float)reducedDim_, val_, child(0)->val()))};
+ case ReduceNodeOpCode::rms:
+ return {NodeOp(Reduce(_1 * _1, 1.0f / (float)reducedDim_, val_, child(0)->val());
+ Element(_1 = sqrt(_1), val_))};
+ case ReduceNodeOpCode::meanSqr:
+ return {NodeOp(Reduce(_1 * _1, 1.0f / (float)reducedDim_, val_, child(0)->val()))};
+ case ReduceNodeOpCode::min:
+ return {NodeOp(Reduce(_1, min(_1,_2), std::numeric_limits<float>::max(), val_, child(0)->val()))};
+ case ReduceNodeOpCode::max:
+ return {NodeOp(Reduce(_1, max(_1,_2), std::numeric_limits<float>::lowest(), val_, child(0)->val()))};
+ case ReduceNodeOpCode::prod:
+ return {NodeOp(Reduce(_1, _1 * _2, 1.0f, val_, child(0)->val()))};
+ case ReduceNodeOpCode::logSumExp:
+ return {NodeOp(Reduce(_1, logaddexp(_1,_2), std::numeric_limits<float>::lowest(), val_, child(0)->val()))};
+ default:
+ ABORT("Unexpected reduction op-code {}", (int)opCode_);
+ }
}
NodeOps backwardOps() override {
using namespace functional;
- return {NodeOp(Add(_1, child(0)->grad(), adj_))};
+#if 1 // @BUGBUG: This is a workaround for not correctly propagating non-trainable information. @TODO: Do this the right and general way.
+ if (adj_ == nullptr)
+ return {};
+#endif
+ switch (opCode_) {
+ case ReduceNodeOpCode::sum:
+ return {NodeOp(Add(_1, child(0)->grad(), adj_))};
+ case ReduceNodeOpCode::mean:
+ return {NodeOp(Add(_1, 1.0f / (float)reducedDim_, child(0)->grad(), adj_))};
+ case ReduceNodeOpCode::rms: // WARNING: UNTESTED!!
+ // y = (sum_j x_j^2)^0.5
+ // dJ/dx_i = dJ/dy * 0.5 (sum_j x_j^2)^-0.5 * 2 x_i = dJ/dy * x_i / y --@REVIEW: is this correct?
+ // @TODO: do we need protection against div by 0? L'hospital rule?
+ return {NodeOp(Add(_1 * _2 / _3, child(0)->grad(), adj_, child(0)->val(), val_))};
+ case ReduceNodeOpCode::meanSqr: // WARNING: UNTESTED!!
+ // y = sum_j x_j^2
+ // dJ/dx_i = dJ/dy * sum_j dx_j^2/dx_i = dJ/dy * 2 dx_i --@REVIEW: is this correct?
+ return {NodeOp(Add(_1 * 2.0f * _2, child(0)->grad(), adj_, child(0)->val()))};
+ case ReduceNodeOpCode::min: // WARNING: UNTESTED!!
+ case ReduceNodeOpCode::max: // WARNING: UNTESTED!!
+ // adj_ gets routed into the min/max value --@REVIEW: is this correct?
+ return {NodeOp(Add((_1 == _2) * _3, child(0)->grad(), child(0)->val(), val_, adj_))};
+ case ReduceNodeOpCode::logSumExp:
+ // y = log(sum_j exp(x_j))
+ // dJ/dx_i = dJ/dy * 1/(sum_j exp(x_j)) exp(x_i) = dJ/dy * exp(x_i - y)) --@REVIEW: is this correct?
+ return {NodeOp(Add(_1 * exp(_2 - _3), child(0)->grad(), adj_, child(0)->val(), val_))};
+ default:
+ ABORT("Unexpected reduction op-code {}", (int)opCode_);
+ }
}
Shape newShape(Expr a, int axis) {
@@ -436,66 +540,27 @@ struct SumNodeOp : public UnaryNodeOp {
return shape;
}
- const std::string type() override { return "sum"; }
-
- const std::string color() override { return "orange"; }
-
- virtual size_t hash() override {
- if(!hash_) {
- hash_ = NaryNodeOp::hash();
- util::hash_combine(hash_, axis_);
+ const std::string type() override {
+ switch (opCode_) {
+ case ReduceNodeOpCode::sum: return "sum";
+ case ReduceNodeOpCode::mean: return "mean";
+ case ReduceNodeOpCode::rms: return "rms";
+ case ReduceNodeOpCode::meanSqr: return "meanSqr";
+ case ReduceNodeOpCode::min: return "min";
+ case ReduceNodeOpCode::max: return "max";
+ case ReduceNodeOpCode::prod: return "prod";
+ case ReduceNodeOpCode::logSumExp: return "logSumExp";
+ default: ABORT("Unexpected reduction op-code {}", (int)opCode_);
}
- return hash_;
- }
-
- virtual bool equal(Expr node) override {
- if(!NaryNodeOp::equal(node))
- return false;
- Ptr<SumNodeOp> cnode = std::dynamic_pointer_cast<SumNodeOp>(node);
- if(!cnode)
- return false;
- if(axis_ != cnode->axis_)
- return false;
- return true;
- }
-};
-
-struct MeanNodeOp : public UnaryNodeOp {
- int axis_;
-
- MeanNodeOp(Expr a, int axis) : UnaryNodeOp(a, newShape(a, axis)) {}
-
- NodeOps forwardOps() override {
- using namespace functional;
- int left = child(0)->shape().elements() / val_->shape().elements();
- float scale = 1.f / left;
-
- return {NodeOp(Reduce(_1, scale, val_, child(0)->val()))};
- }
-
- NodeOps backwardOps() override {
- using namespace functional;
- int left = child(0)->shape().elements() / val_->shape().elements();
- float scale = 1.f / left;
-
- return {NodeOp(Add(_1, scale, child(0)->grad(), adj_))};
}
- Shape newShape(Expr a, int axis) {
- Shape shape = a->shape();
- axis_ = shape.axis(axis);
- shape.set(axis_, 1);
- return shape;
- }
-
- const std::string type() override { return "mean"; }
-
const std::string color() override { return "orange"; }
virtual size_t hash() override {
if(!hash_) {
hash_ = NaryNodeOp::hash();
util::hash_combine(hash_, axis_);
+ util::hash_combine(hash_, (int)opCode_);
}
return hash_;
}
@@ -503,10 +568,10 @@ struct MeanNodeOp : public UnaryNodeOp {
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
- Ptr<MeanNodeOp> cnode = std::dynamic_pointer_cast<MeanNodeOp>(node);
+ auto cnode = std::dynamic_pointer_cast<ReduceNodeOp>(node);
if(!cnode)
return false;
- if(axis_ != cnode->axis_)
+ if(axis_ != cnode->axis_ || opCode_ != cnode->opCode_)
return false;
return true;
}
@@ -575,7 +640,7 @@ struct SqrtNodeOp : public UnaryNodeOp {
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
- Ptr<SqrtNodeOp> cnode = std::dynamic_pointer_cast<SqrtNodeOp>(node);
+ auto cnode = std::dynamic_pointer_cast<SqrtNodeOp>(node);
if(!cnode)
return false;
if(epsilon_ != cnode->epsilon_)
@@ -635,7 +700,6 @@ struct TransposeNodeOp : public UnaryNodeOp {
return {NodeOp(TransposeNDGrad(child(0)->grad(), adj_, axesBw_))};
}
- template <class... Args>
Shape newShape(Expr a, const std::vector<int>& axes) {
Shape shape = a->shape();
@@ -661,8 +725,7 @@ struct TransposeNodeOp : public UnaryNodeOp {
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
- Ptr<TransposeNodeOp> cnode
- = std::dynamic_pointer_cast<TransposeNodeOp>(node);
+ auto cnode = std::dynamic_pointer_cast<TransposeNodeOp>(node);
if(!cnode)
return false;
if(axes_ != cnode->axes_)
@@ -680,8 +743,9 @@ private:
Expr reshapee_;
public:
- template <typename... Args>
ReshapeNodeOp(Expr a, Shape shape) : UnaryNodeOp(a, shape), reshapee_(a) {
+ ABORT_IF(a->shape().elements() != shape.elements(),
+ "Reshape must not change the number of elements (from {} to {})", a->shape().toString(), shape.toString());
Node::destroy_ = false;
}
@@ -699,15 +763,15 @@ public:
Tensor& val() override {
auto childVal = reshapee_->val();
- val_.reset(
- new TensorBase(childVal->memory(), shape(), childVal->getBackend()));
+ auto temp = TensorBase::New(childVal->memory(), shape(), childVal->type(), childVal->getBackend());
+ val_.swap(temp);
return val_;
};
Tensor& grad() override {
auto childGrad = reshapee_->grad();
- adj_.reset(
- new TensorBase(childGrad->memory(), shape(), childGrad->getBackend()));
+ auto temp = TensorBase::New(childGrad->memory(), shape(), childGrad->type(), childGrad->getBackend());
+ adj_.swap(temp);
return adj_;
};
@@ -728,7 +792,7 @@ public:
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
- Ptr<ReshapeNodeOp> cnode = std::dynamic_pointer_cast<ReshapeNodeOp>(node);
+ auto cnode = std::dynamic_pointer_cast<ReshapeNodeOp>(node);
if(!cnode)
return false;
if(shape() != cnode->shape())
@@ -737,29 +801,105 @@ public:
}
};
-class StepNodeOp : public UnaryNodeOp {
+// @TODO: review if still required as this is an ugly hack anyway.
+// Memory less operator that clips gradients during backward step
+// Executes this as an additional operation on the gradient.
+class ClipGradientNodeOp : public UnaryNodeOp {
private:
- Expr stepNode_;
- int step_;
- int axis_;
+ Expr clipee_;
+ float clipValue_{0};
public:
- StepNodeOp(Expr a, int step, int axis)
- : UnaryNodeOp(a, newShape(a, axis)), stepNode_(a), step_(step) {
+ ClipGradientNodeOp(Expr a, float clipValue)
+ : UnaryNodeOp(a), clipee_(a), clipValue_(clipValue) {
Node::destroy_ = false;
}
- Shape newShape(Expr a, int axis) {
- Shape outShape = a->shape();
+ ~ClipGradientNodeOp() {}
- axis_ = outShape.axis(axis);
-#if 0 // this check currently fails in translation; I think should not fail for
- // step==0
- for(int i = 0; i < axis_; ++i)
- ABORT_IF(outShape[i] != 1, "non-consecutive slices are presently not supported by step()");
-#endif
- outShape.set(axis_, 1);
+ size_t allocate() override { return 0; }
+ void free() override {}
+
+ void forward() override {}
+
+ void backward() override {
+ using namespace marian::functional;
+ Element(_1 = clip(_1, clipValue_), adj_);
+ }
+
+ void init_dependent() override { clipee_->init_dependent(); }
+
+ void set_zero_adjoint() override { clipee_->set_zero_adjoint(); }
+
+ Tensor& val() override {
+ auto childVal = clipee_->val();
+ auto temp = TensorBase::New(childVal->memory(), shape(), childVal->type(), childVal->getBackend());
+ val_.swap(temp);
+ return val_;
+ };
+
+ Tensor& grad() override {
+ auto childGrad = clipee_->grad();
+ auto temp = TensorBase::New(childGrad->memory(), shape(), childGrad->type(), childGrad->getBackend());
+ adj_.swap(temp);
+ return adj_;
+ };
+
+ const std::string type() override { return "clipGradient"; }
+
+ const std::string color() override { return "grey"; }
+
+ virtual size_t hash() override {
+ if(!hash_) {
+ size_t seed = NaryNodeOp::hash();
+ util::hash_combine(seed, clipValue_);
+ hash_ = seed;
+ }
+ return hash_;
+ }
+ virtual bool equal(Expr node) override {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<ClipGradientNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(clipValue_ != cnode->clipValue_)
+ return false;
+ return true;
+ }
+};
+
+// narrow an axis to [begin, end)
+// The resulting object must be consecutive in memory.
+class SliceViewNodeOp : public UnaryNodeOp {
+private:
+ Expr viewedNode_; // viewed underlying node
+ Slice slice_; // index range
+ int axis_; // and axis along which it is viewed
+ size_t byteOffset_, byteSize_; // viewed segment in bytes (memory-consecutive)
+
+public:
+ SliceViewNodeOp(Expr a, int axis, Slice slice)
+ : UnaryNodeOp(a, newShape(a, axis, slice), a->value_type()), viewedNode_(a), slice_(slice), axis_(axis) {
+ Node::destroy_ = false;
+ auto byteStride = a->shape().stride(axis) * sizeOf(value_type());
+ byteOffset_ = slice.begin * byteStride;
+ byteSize_ = shape()[axis] * byteStride;
+ }
+
+ static Shape newShape(Expr a, int& axis, Slice& slice) { // note: normalizes slice and axis in-place
+ const auto& shape = a->shape();
+ axis = shape.axis(axis); // normalize negative axis
+ slice = shape.slice(slice, axis); // normalize negative slice values
+ // enforce consecutive memory
+ if (slice.begin != 0 || slice.end != shape[axis] || slice.stride != 1) { // unless it's a no-op
+ ABORT_IF(slice.stride != 1, "Strides other than 1 are presently not supported by sliceView()");
+ for(int i = 0; i < axis; ++i)
+ ABORT_IF(shape[i] != 1, "Non-consecutive slices are presently not supported by sliceView()");
+ }
+ Shape outShape = shape;
+ outShape.set(axis, slice.end - slice.begin);
return outShape;
}
@@ -769,36 +909,36 @@ public:
void forward() override {}
void backward() override {}
- void init_dependent() override { stepNode_->init_dependent(); }
+ void init_dependent() override { viewedNode_->init_dependent(); }
- void set_zero_adjoint() override { stepNode_->set_zero_adjoint(); }
+ void set_zero_adjoint() override { viewedNode_->set_zero_adjoint(); } // lazily allocate and zero out gradient (only runs once)
Tensor& val() override {
- auto childVal = stepNode_->val();
- size_t offset = step_ * shape().elements() * sizeof(float);
- auto mem = New<MemoryPiece>(childVal->memory()->data() + offset,
- childVal->memory()->size());
- val_.reset(new TensorBase(mem, shape(), childVal->getBackend()));
+ auto childVal = viewedNode_->val();
+ auto mem = MemoryPiece::New(childVal->memory()->data() + byteOffset_, byteSize_);
+ auto temp = TensorBase::New(mem, shape(), childVal->type(), childVal->getBackend());
+ val_.swap(temp);
return val_;
};
Tensor& grad() override {
- auto childGrad = stepNode_->grad();
- size_t offset = step_ * shape().elements() * sizeof(float);
- auto mem = New<MemoryPiece>(childGrad->memory()->data() + offset,
- childGrad->memory()->size());
- adj_.reset(new TensorBase(mem, shape(), childGrad->getBackend()));
+ auto childGrad = viewedNode_->grad();
+ auto mem = MemoryPiece::New(childGrad->memory()->data() + byteOffset_, byteSize_);
+ auto temp = TensorBase::New(mem, shape(), childGrad->type(), childGrad->getBackend());
+ adj_.swap(temp);
return adj_;
};
- const std::string type() override { return "step"; }
+ const std::string type() override { return "sliceView"; }
const std::string color() override { return "grey"; }
virtual size_t hash() override {
if(!hash_) {
hash_ = NaryNodeOp::hash();
- util::hash_combine(hash_, step_);
+ util::hash_combine(hash_, slice_.begin);
+ util::hash_combine(hash_, slice_.end);
+ util::hash_combine(hash_, slice_.stride);
util::hash_combine(hash_, axis_);
}
return hash_;
@@ -807,10 +947,10 @@ public:
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
- Ptr<StepNodeOp> cnode = std::dynamic_pointer_cast<StepNodeOp>(node);
+ auto cnode = std::dynamic_pointer_cast<SliceViewNodeOp>(node);
if(!cnode)
return false;
- if(step_ != cnode->step_)
+ if(slice_ != cnode->slice_)
return false;
if(axis_ != cnode->axis_)
return false;
@@ -849,10 +989,12 @@ struct ShiftNodeOp : public UnaryNodeOp {
virtual bool equal(Expr node) override {
if(!NaryNodeOp::equal(node))
return false;
- Ptr<ShiftNodeOp> cnode = std::dynamic_pointer_cast<ShiftNodeOp>(node);
+ auto cnode = std::dynamic_pointer_cast<ShiftNodeOp>(node);
if(!cnode)
return false;
if(shift_ != cnode->shift_)
+ return false;
+ if(padValue_ != cnode->padValue_)
return false;
return true;
}
diff --git a/src/graph/parameters.h b/src/graph/parameters.h
index 32f88a1e..8b4af9dd 100755..100644
--- a/src/graph/parameters.h
+++ b/src/graph/parameters.h
@@ -10,8 +10,16 @@
namespace marian {
+// @TODO: Currently an ExpressionGraph only supports one Parameters object and
+// the type of parameters has to be the inside on Parameters object. This limits
+// parameter types to a single chosen type, e.g. only fp32 or only fp16. This should
+// be extended to allow multiple sets of parameters.
+// The reason here is to be able to efficiently compute updates of whole parameter
+// sets of one type.
class Parameters {
protected:
+ Type acceptedElementType_; // this parameter object only takes paramters of this type
+
/** @brief List of all parameter nodes of this expression graph. */
std::vector<Expr> params_;
std::map<std::string, Expr> named_;
@@ -22,12 +30,20 @@ protected:
size_t totalCapacity(Ptr<TensorAllocator> alloc) {
size_t sum = 0;
for(auto p : params_) {
- sum += alloc->capacity(p->shape(), Type::float32);
+ sum += alloc->capacity(p->shape(), p->value_type());
}
return sum;
}
public:
+ Parameters(Type acceptedType) : acceptedElementType_(acceptedType) {
+ LOG(debug, "Created parameter object of type {}", acceptedElementType_);
+ }
+
+ virtual ~Parameters() {
+ LOG(debug, "Destroyed parameter object of type {}", acceptedElementType_);
+ }
+
auto begin() -> decltype(params_.begin()) { return params_.begin(); }
auto end() -> decltype(params_.begin()) { return params_.end(); }
@@ -46,8 +62,13 @@ public:
size_t size() { return params_.size(); }
void add(Expr p, const std::string& name) {
- params_.push_back(p);
+ LOG(debug, "Adding parameter {} to parameter object of type {}", name, acceptedElementType_);
+
ABORT_IF(named_.count(name), "Parameter '{}' already exists", name);
+ ABORT_IF(p->value_type() != acceptedElementType_,
+ "Requested parameter type ({}) is different from chosen parameter type ({})",
+ p->value_type(), acceptedElementType_);
+ params_.push_back(p);
named_[name] = p;
}
@@ -56,12 +77,21 @@ public:
grads_ = New<TensorAllocator>(backend);
}
+ virtual void init(Ptr<Backend> backend, Ptr<Device> device) {
+ vals_ = New<TensorAllocator>(backend, device);
+ grads_ = New<TensorAllocator>(backend, device);
+ }
+
virtual void allocateForward() {
if(!params_.empty() && vals_->size() == 0) {
vals_->reserveExact(totalCapacity(vals_));
+
+ // sort parameters by name before allocation to make sure the memory layout after allocation is always the same
+ std::sort(params_.begin(), params_.end(), [](Expr n1, Expr n2){ return n1->name() < n2->name(); });
+
for(auto p : params_) {
if(!p->val()) {
- vals_->allocate(p->val(), p->shape());
+ vals_->allocate(p->val(), p->shape(), p->value_type());
}
}
}
@@ -69,18 +99,22 @@ public:
virtual void allocateBackward() {
if(!params_.empty() && grads_->size() == 0) {
+
+ // sort parameters by name before allocation to make sure the memory layout after allocation is always the same
+ std::sort(params_.begin(), params_.end(), [](Expr n1, Expr n2){ return n1->name() < n2->name(); });
+
grads_->reserveExact(totalCapacity(grads_));
for(auto p : params_)
if(!p->grad())
- grads_->allocate(p->grad(), p->shape());
+ grads_->allocate(p->grad(), p->shape(), p->value_type());
}
}
virtual void set_zero_adjoint() { grads()->set(0.f); }
- virtual Tensor vals() { return vals_->asTensor(); }
+ virtual Tensor vals() { return vals_->asTensor(acceptedElementType_); }
- virtual Tensor grads() { return grads_->asTensor(); }
+ virtual Tensor grads() { return grads_->asTensor(acceptedElementType_); }
virtual void clear() {
params_.clear();
@@ -96,14 +130,18 @@ private:
Ptr<Backend> backend_;
public:
+ MappedParameters(Type acceptedElementType) : Parameters(acceptedElementType) {
+ LOG(debug, "Created mapped parameter object of type {}", acceptedElementType);
+ }
+
virtual void init(Ptr<Backend> backend) override { backend_ = backend; }
+ virtual void init(Ptr<Backend> backend, Ptr<Device>) override { init(backend); }
virtual void allocateForward() override {
if(!params_.empty()) {
for(auto p : params_) {
if(!p->val()) {
- p->val() = Tensor(
- new TensorBase(nullptr, p->shape(), Type::float32, backend_));
+ p->val() = TensorBase::New(nullptr, p->shape(), p->value_type(), backend_);
}
}
}
diff --git a/src/layers/constructors.h b/src/layers/constructors.h
index c063f44c..a2c38197 100644..100755
--- a/src/layers/constructors.h
+++ b/src/layers/constructors.h
@@ -10,23 +10,7 @@ namespace mlp {
* Base class for layer factories, can be used in a multi-layer network factory.
*/
struct LayerFactory : public Factory {
- LayerFactory(Ptr<ExpressionGraph> graph) : Factory(graph) {}
- LayerFactory(const LayerFactory&) = default;
- LayerFactory(LayerFactory&&) = default;
-
- virtual ~LayerFactory() {}
-
- template <typename Cast>
- inline Ptr<Cast> as() {
- return std::dynamic_pointer_cast<Cast>(shared_from_this());
- }
-
- template <typename Cast>
- inline bool is() {
- return as<Cast>() != nullptr;
- }
-
- virtual Ptr<Layer> construct() = 0;
+ virtual Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) = 0;
};
/**
@@ -34,15 +18,12 @@ struct LayerFactory : public Factory {
*/
class DenseFactory : public LayerFactory {
public:
- DenseFactory(Ptr<ExpressionGraph> graph) : LayerFactory(graph) {}
-
- Ptr<Layer> construct() override {
- auto dense = New<Dense>(graph_, options_);
- return dense;
+ Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override {
+ return New<Dense>(graph, options_);
}
DenseFactory clone() {
- DenseFactory aClone(graph_);
+ DenseFactory aClone;
aClone.options_->merge(options_);
return aClone;
}
@@ -54,37 +35,39 @@ typedef Accumulator<DenseFactory> dense;
/**
* Factory for output layers, can be used in a multi-layer network factory.
*/
-class OutputFactory : public LayerFactory {
+struct LogitLayerFactory : public Factory {
+ using Factory::Factory;
+ virtual Ptr<IUnaryLogitLayer> construct(Ptr<ExpressionGraph> graph) = 0;
+};
+
+// @TODO: In the long run, I hope we can get rid of the abstract factories altogether.
+class OutputFactory : public LogitLayerFactory {
+ using LogitLayerFactory::LogitLayerFactory;
protected:
- std::vector<std::pair<std::string, std::string>> tiedParamsTransposed_;
+ std::string tiedTransposedName_;
Ptr<data::Shortlist> shortlist_;
public:
- OutputFactory(Ptr<ExpressionGraph> graph) : LayerFactory(graph) {}
-
- Accumulator<OutputFactory> tie_transposed(const std::string& param,
- const std::string& tied) {
- tiedParamsTransposed_.push_back({param, tied});
+ Accumulator<OutputFactory> tieTransposed(const std::string& tied) {
+ tiedTransposedName_ = tied;
return Accumulator<OutputFactory>(*this);
}
- Accumulator<OutputFactory> set_shortlist(Ptr<data::Shortlist> shortlist) {
+ void setShortlist(Ptr<data::Shortlist> shortlist) {
shortlist_ = shortlist;
- return Accumulator<OutputFactory>(*this);
}
- Ptr<Layer> construct() override {
- auto output = New<Output>(graph_, options_);
- for(auto& p : tiedParamsTransposed_)
- output->tie_transposed(p.first, p.second);
- output->set_shortlist(shortlist_);
+ Ptr<IUnaryLogitLayer> construct(Ptr<ExpressionGraph> graph) override {
+ auto output = New<Output>(graph, options_);
+ output->tieTransposed(graph->get(tiedTransposedName_));
+ output->setShortlist(shortlist_);
return output;
}
OutputFactory clone() {
- OutputFactory aClone(graph_);
+ OutputFactory aClone;
aClone.options_->merge(options_);
- aClone.tiedParamsTransposed_ = tiedParamsTransposed_;
+ aClone.tiedTransposedName_ = tiedTransposedName_;
aClone.shortlist_ = shortlist_;
return aClone;
}
@@ -96,21 +79,18 @@ typedef Accumulator<OutputFactory> output;
/**
* Multi-layer network, holds and applies layers.
*/
-class MLP {
+class MLP : public IUnaryLogitLayer, public IHasShortList {
protected:
Ptr<ExpressionGraph> graph_;
Ptr<Options> options_;
- std::vector<Ptr<Layer>> layers_;
+ std::vector<Ptr<IUnaryLayer>> layers_;
public:
MLP(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: graph_(graph), options_(options) {}
- template <typename... Args>
- Expr apply(Args... args) {
- std::vector<Expr> av = {args...};
-
+ Expr apply(const std::vector<Expr>& av) override {
Expr output;
if(av.size() == 1)
output = layers_[0]->apply(av[0]);
@@ -123,7 +103,47 @@ public:
return output;
}
- void push_back(Ptr<Layer> layer) { layers_.push_back(layer); }
+ Logits applyAsLogits(const std::vector<Expr>& av) override {
+ // same as apply() except for the last layer, we invoke applyAsLogits(), which has a different return type
+ auto lastLayer = std::dynamic_pointer_cast<IUnaryLogitLayer>(layers_.back());
+ ABORT_IF(!lastLayer, "MLP::applyAsLogits() was called on an MLP whose last layer is not an IUnaryLogitLayer");
+ if (layers_.size() == 1) {
+ if (av.size() == 1)
+ return lastLayer->applyAsLogits(av[0]);
+ else
+ return lastLayer->applyAsLogits(av);
+ }
+ else {
+ Expr output;
+ if (av.size() == 1)
+ output = layers_[0]->apply(av[0]);
+ else
+ output = layers_[0]->apply(av);
+ for (size_t i = 1; i < layers_.size() - 1; ++i)
+ output = layers_[i]->apply(output);
+ return lastLayer->applyAsLogits(output);
+ }
+ }
+
+ Expr apply(Expr e) override { return apply(std::vector<Expr>{ e }); }
+ Logits applyAsLogits(Expr e) override { return applyAsLogits(std::vector<Expr>{ e }); }
+
+ void push_back(Ptr<IUnaryLayer> layer) { layers_.push_back(layer); }
+ void push_back(Ptr<IUnaryLogitLayer> layer) { layers_.push_back(layer); }
+
+ void setShortlist(Ptr<data::Shortlist> shortlist) override final {
+ auto p = tryAsHasShortlist();
+ ABORT_IF(!p, "setShortlist() called on an MLP with an output layer that does not support short lists");
+ p->setShortlist(shortlist);
+ }
+
+ void clear() override final {
+ auto p = tryAsHasShortlist();
+ if (p)
+ p->clear();
+ }
+private:
+ Ptr<IHasShortList> tryAsHasShortlist() const { return std::dynamic_pointer_cast<IHasShortList>(layers_.back()); }
};
/**
@@ -131,31 +151,58 @@ public:
* to accumulate options for later lazy construction.
*/
class MLPFactory : public Factory {
+ using Factory::Factory;
private:
std::vector<Ptr<LayerFactory>> layers_;
public:
- MLPFactory(Ptr<ExpressionGraph> graph) : Factory(graph) {}
-
- Ptr<MLP> construct() {
- auto mlp = New<MLP>(graph_, options_);
+ Ptr<MLP> construct(Ptr<ExpressionGraph> graph) {
+ auto mlp = New<MLP>(graph, options_);
for(auto layer : layers_) {
- layer->getOptions()->merge(options_);
- mlp->push_back(layer->construct());
+ layer->mergeOpts(options_);
+ mlp->push_back(layer->construct(graph));
}
return mlp;
}
- Ptr<MLP> operator->() { return construct(); }
-
template <class LF>
Accumulator<MLPFactory> push_back(const LF& lf) {
layers_.push_back(New<LF>(lf));
return Accumulator<MLPFactory>(*this);
}
+
+ // Special case for last layer, which may be a IUnaryLogitLayer. Requires some hackery,
+ // which will go away if we get rid of the abstract factories, and instead just construct
+ // all layers immediately, which is my long-term goal for Marian.
+private:
+ template<class WrappedFactory>
+ class AsLayerFactory : public LayerFactory {
+ WrappedFactory us;
+ public:
+ AsLayerFactory(const WrappedFactory& wrapped) : us(wrapped) {}
+ Ptr<IUnaryLayer> construct(Ptr<ExpressionGraph> graph) override final {
+ auto p = std::static_pointer_cast<IUnaryLayer>(us.construct(graph));
+ ABORT_IF(!p, "Attempted to cast a Factory to LayerFactory that isn't one");
+ return p;
+ }
+ };
+ template<class WrappedFactory>
+ static inline AsLayerFactory<WrappedFactory> asLayerFactory(const WrappedFactory& wrapped) { return wrapped; }
+public:
+ Accumulator<MLPFactory> push_back(const Accumulator<OutputFactory>& lf) {
+ push_back(AsLayerFactory<OutputFactory>(lf));
+ //layers_.push_back(New<AsLayerFactory<OutputFactory>>(asLayerFactory((OutputFactory&)lf)));
+ return Accumulator<MLPFactory>(*this);
+ }
};
// @TODO: change naming convention.
typedef Accumulator<MLPFactory> mlp;
} // namespace mlp
+
+typedef ConstructingFactory<Embedding> EmbeddingFactory;
+typedef ConstructingFactory<ULREmbedding> ULREmbeddingFactory;
+
+typedef Accumulator<EmbeddingFactory> embedding;
+typedef Accumulator<ULREmbeddingFactory> ulr_embedding;
} // namespace marian
diff --git a/src/layers/convolution.cpp b/src/layers/convolution.cpp
index 225058e7..d46ede36 100644
--- a/src/layers/convolution.cpp
+++ b/src/layers/convolution.cpp
@@ -4,9 +4,11 @@
namespace marian {
#ifdef CUDNN
-Convolution::Convolution(Ptr<ExpressionGraph> graph) : Factory(graph) {}
+Convolution::Convolution(Ptr<ExpressionGraph> graph) {}
Expr Convolution::apply(Expr x) {
+ auto graph = x->graph();
+
auto prefix = opt<std::string>("prefix");
auto kernelDims = opt<std::pair<int, int>>("kernel-dims");
auto kernelNum = opt<int>("kernel-num");
@@ -15,12 +17,12 @@ Expr Convolution::apply(Expr x) {
int layerIn = x->shape()[1];
auto kernel
- = graph_->param(prefix + "_conv_kernels",
+ = graph->param(prefix + "_conv_kernels",
{layerIn, kernelNum, kernelDims.first, kernelDims.second},
- inits::glorot_uniform);
+ inits::glorotUniform());
- auto bias = graph_->param(
- prefix + "_conv_bias", {1, kernelNum, 1, 1}, inits::zeros);
+ auto bias = graph->param(
+ prefix + "_conv_bias", {1, kernelNum, 1, 1}, inits::zeros());
std::vector<Expr> nodes = {x, kernel, bias};
return Expression<ConvolutionOp>(
diff --git a/src/layers/factory.h b/src/layers/factory.h
index 0e84fd16..f9e4ddf9 100644..100755
--- a/src/layers/factory.h
+++ b/src/layers/factory.h
@@ -7,58 +7,102 @@ namespace marian {
class Factory : public std::enable_shared_from_this<Factory> {
protected:
Ptr<Options> options_;
- Ptr<ExpressionGraph> graph_;
public:
- Factory(Ptr<ExpressionGraph> graph)
- : options_(New<Options>()), graph_(graph) {}
+ // construct with empty options
+ Factory() : options_(New<Options>()) {}
+ // construct with options
+ Factory(Ptr<Options> options) : Factory() {
+ options_->merge(options);
+ }
+ // construct with one or more individual option parameters
+ // Factory("var1", val1, "var2", val2, ...)
+ template <typename T, typename... Args>
+ Factory(const std::string& key, T value, Args&&... moreArgs) : Factory() {
+ setOpts(key, value, std::forward<Args>(moreArgs)...);
+ }
+ // construct with options and one or more individual option parameters
+ // Factory(options, "var1", val1, "var2", val2, ...)
+ template <typename... Args>
+ Factory(Ptr<Options> options, Args&&... args) : Factory(options) {
+ setOpts(std::forward<Args>(args)...);
+ }
+ Factory(const Factory& factory) = default;
virtual ~Factory() {}
- Ptr<Options> getOptions() { return options_; }
+ std::string asYamlString() { return options_->asYamlString(); }
- std::string str() { return options_->str(); }
+ // retrieve an option
+ // auto val = opt<T>("var");
+ template <typename T>
+ T opt(const char* const key) { return options_->get<T>(key); }
template <typename T>
- T opt(const std::string& key) {
- return options_->get<T>(key);
- }
+ T opt(const char* const key, T defaultValue) { return options_->get<T>(key, defaultValue); }
template <typename T>
- T opt(const std::string& key, T defaultValue) {
- return options_->get<T>(key, defaultValue);
+ T opt(const std::string& key) { return options_->get<T>(key.c_str()); }
+
+ template <typename T>
+ T opt(const std::string& key, T defaultValue) { return options_->get<T>(key.c_str(), defaultValue); }
+
+ // set a single option
+ // setOpt("var", val);
+ template <typename T>
+ void setOpt(const std::string& key, T value) { options_->set(key, value); }
+
+ // set one or more options at once
+ // setOpts("var1", val1, "var2", val2, ...);
+ template <typename T, typename... Args>
+ void setOpts(const std::string& key, T value, Args&&... moreArgs) { options_->set(key, value, std::forward<Args>(moreArgs)...); }
+
+ void mergeOpts(Ptr<Options> options) { options_->merge(options); }
+
+ template <class Cast>
+ inline Ptr<Cast> as() { return std::dynamic_pointer_cast<Cast>(shared_from_this()); }
+
+ // @TODO: this fails with 'target type must be a pointer or reference to a defined class'
+ //template <class Cast>
+ //inline bool is() { return dynamic_cast<Cast>(this) != nullptr; }
+ template <class Cast>
+ inline bool is() { return std::dynamic_pointer_cast<Cast>(shared_from_this()) != nullptr; }
+};
+
+// simplest form of Factory that just passes on options to the constructor of a layer type
+template<class Class>
+struct ConstructingFactory : public Factory {
+ using Factory::Factory;
+
+ Ptr<Class> construct(Ptr<ExpressionGraph> graph) {
+ return New<Class>(graph, options_);
}
};
-template <class BaseFactory>
+template <class BaseFactory> // where BaseFactory : Factory
class Accumulator : public BaseFactory {
typedef BaseFactory Factory;
public:
- Accumulator() : Factory(nullptr) {}
- Accumulator(Ptr<ExpressionGraph> graph) : Factory(graph) {}
+ Accumulator() : Factory() {}
+ Accumulator(Ptr<Options> options) : Factory(options) {}
+ template <typename... Args>
+ Accumulator(Ptr<Options> options, Args&&... moreArgs) : Factory(options, std::forward<Args>(moreArgs)...) {}
+ template <typename T, typename... Args>
+ Accumulator(const std::string& key, T value, Args&&... moreArgs) : Factory(key, value, std::forward<Args>(moreArgs)...) {}
Accumulator(const Factory& factory) : Factory(factory) {}
Accumulator(const Accumulator&) = default;
Accumulator(Accumulator&&) = default;
+ // deprecated chaining syntax
template <typename T>
Accumulator& operator()(const std::string& key, T value) {
- Factory::getOptions()->set(key, value);
- return *this;
- }
-
- Accumulator& operator()(const std::string& yaml) {
- Factory::getOptions()->parse(yaml);
- return *this;
- }
-
- Accumulator& operator()(Config::YamlNode yaml) {
- Factory::getOptions()->merge(yaml);
+ Factory::setOpt(key, value);
return *this;
}
Accumulator& operator()(Ptr<Options> options) {
- Factory::getOptions()->merge(options);
+ Factory::mergeOpts(options);
return *this;
}
diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp
new file mode 100755
index 00000000..45d66041
--- /dev/null
+++ b/src/layers/generic.cpp
@@ -0,0 +1,566 @@
+#include "marian.h"
+
+#include "layers/generic.h"
+#include "layers/constructors.h"
+#include "layers/loss.h"
+#include "data/factored_vocab.h"
+#include "rnn/types.h" // for State::select()
+#include "models/states.h" // for EncoderState
+
+//using std::size_t; // not sure why this is needed
+
+namespace marian {
+ Logits::Logits(Expr logits) : Logits(New<RationalLoss>(logits, nullptr)) {} // single-output constructor from Expr only (RationalLoss has no count)
+
+ Ptr<ExpressionGraph> Logits::graph() const {
+ ABORT_IF(logits_.empty(), "Empty logits object??");
+ return logits_.front()->loss()->graph();
+ }
+
+ // This function assumes that the object holds one or more factor logits.
+ // It applies the supplied loss function to each, and then returns the aggregate loss over all factors.
+ Expr Logits::applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/, Expr/*indices*/)>& lossFn) const {
+ LOG_ONCE(info, "[logits] applyLossFunction() for {} factors", logits_.size());
+ ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
+
+ auto firstLogits = logits_.front()->loss();
+ ABORT_IF(labels.size() * firstLogits->shape()[-1] != firstLogits->shape().elements(),
+ "Labels not matching logits shape ({} != {}, {})??",
+ labels.size() * firstLogits->shape()[-1],
+ firstLogits->shape().elements(),
+ firstLogits->shape());
+
+ // base case (no factors)
+ if (!factoredVocab_) {
+ ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
+ return lossFn(firstLogits, indices(toWordIndexVector(labels)));
+ }
+
+ auto numGroups = factoredVocab_->getNumGroups();
+
+ // split labels into individual factor labels
+ auto allMaskedFactoredLabels = factorizeWords(labels); // [numGroups][labels.size()] = [numGroups][B... flattened]
+
+ //Expr indices = this->indices(toWordIndexVector(labels));
+ // accumulate all CEs for all words that have the factor
+ // Memory-wise, this is cheap, all temp objects below are batches of scalars or lookup vectors.
+ Expr loss;
+ for (size_t g = 0; g < numGroups; g++) {
+ if (!logits_[g])
+ continue; // empty factor --@TODO: use an array of indices of non-empty logits_[]
+ const auto& maskedFactoredLabels = allMaskedFactoredLabels[g]; // array of (word index, mask)
+ auto factorIndices = indices (maskedFactoredLabels.indices); // [B... flattened] factor-label indices, or 0 if factor does not apply
+ auto factorMask = constant(maskedFactoredLabels.masks); // [B... flattened] loss values get multiplied with 0 for labels that don't have this factor
+ auto factorLogits = logits_[g]; // [B... * Ug] label-wise loss values (not aggregated yet)
+ // For each location in [B...] select [indices[B...]]. If not using factor, select [0] and mask it out next.
+ auto factorLoss = lossFn(factorLogits->loss(), factorIndices); // [B... x 1]
+ factorLoss = factorLoss * reshape(factorMask, factorLoss->shape()); // mask out factor for words that do not have that factor
+ loss = loss ? (loss + factorLoss) : factorLoss; // [B... x 1]
+ }
+ return loss;
+ }
+
+ // This function assumes this object holds a single factor that represents a rational loss (with count).
+ //Ptr<RationalLoss> Logits::getRationalLoss() const {
+ // ABORT_IF(logits_.size() != 1 || factoredVocab_, "getRationalLoss() cannot be used on multi-factor outputs");
+ // ABORT_IF(!logits_.front()->count(), "getRationalLoss() used on rational loss without count");
+ // return logits_.front();
+ //}
+
+ // get logits for one factor group
+ // For groupIndex == 0, the function also requires the shortlist if there is one.
+ Expr Logits::getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist /*= nullptr*/, const std::vector<IndexType>& hypIndices /*= {}*/, size_t beamSize /*= 0*/) const {
+ ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
+ auto sel = logits_[groupIndex]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab]
+
+ // normalize for decoding:
+ // - all secondary factors: subtract their max
+ // - lemma: add all maxes of applicable factors
+ if (groupIndex > 0) {
+ sel = sel - max(sel, -1);
+ }
+ else {
+ auto numGroups = getNumFactorGroups();
+ for (size_t g = 1; g < numGroups; g++) {
+ auto factorMaxima = max(logits_[g]->loss(), -1);
+ auto factorMasks = constant(getFactorMasks(g, shortlist ? shortlist->indices() : std::vector<WordIndex>()));
+ sel = sel + factorMaxima * factorMasks; // those lemmas that don't have a factor get multiplied with 0
+ }
+ }
+
+ // if selIdx are given, then we must reshuffle accordingly
+ if (!hypIndices.empty()) // use the same function that shuffles decoder state
+ sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false);
+ return sel;
+ }
+
+ // used for breakDown() only
+ // Index is flattened
+ Tensor Logits::getFactoredLogitsTensor(size_t groupIndex) const {
+ ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
+ return logits_[groupIndex]->loss()->val();
+ }
+
+ // This function assumes that the object holds one or more factor logits, which are summed up
+ // into output-vocab logits according to the factored model (with correct normalization of factors).
+ // This is infeasible for realistic factor sets, and therefore only implemented for 1 factor.
+ // @TODO: remove altogether
+ Expr Logits::getLogits() const {
+ ABORT_IF(empty(), "Attempted to read out logits on empty Logits object");
+ if (!factoredVocab_) {
+ ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
+ return getFactoredLogits(0);
+ }
+
+#ifdef FACTOR_FULL_EXPANSION
+ // compute normalized factor log probs
+ std::vector<Expr> logProbs(logits_.size());
+ for (size_t g = 0; g < logits_.size(); g++)
+ logProbs[g] = logsoftmax(logits_[g]->loss());
+ auto y = concatenate(logProbs, /*axis=*/ -1);
+
+ // sum up the unit logits across factors for each target word
+ auto graph = y->graph();
+ auto factorMatrix = factoredVocab_->getGlobalFactorMatrix(); // [V x U]
+ y = dot_csr(
+ y, // [B x U]
+ factorMatrix.shape,
+ graph->constant({(int)factorMatrix.weights.size()}, inits::fromVector(factorMatrix.weights), Type::float32),
+ graph->constant({(int)factorMatrix.indices.size()}, inits::fromVector(factorMatrix.indices), Type::uint32),
+ graph->constant({(int)factorMatrix.offsets.size()}, inits::fromVector(factorMatrix.offsets), Type::uint32),
+ /*transB=*/ true); // -> [B x V]
+
+ // mask out gaps
+ auto gapLogMask = factoredVocab_->getGapLogMask(); // [V]
+ y = y + graph->constant({ (int)gapLogMask.size() }, inits::fromVector(gapLogMask), Type::float32);
+
+ return y;
+#else
+ ABORT("getLogits() no longer supported for actual factored vocab"); // because it is infeasible
+#endif
+ }
+
+ void Logits::MaskedFactorIndices::push_back(size_t factorIndex) {
+ bool isValid = FactoredVocab::isFactorValid(factorIndex);
+ indices.push_back(isValid ? (WordIndex)factorIndex : 0);
+ masks.push_back((float)isValid);
+ }
+
+ std::vector<Logits::MaskedFactorIndices> Logits::factorizeWords(const Words& words) const { // [numGroups][words.size()] -> breaks encoded Word into individual factor indices
+ if (!factoredVocab_) {
+ ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
+ return {MaskedFactorIndices(words)};
+ }
+ auto numGroups = factoredVocab_->getNumGroups();
+ std::vector<MaskedFactorIndices> res(numGroups);
+ for (size_t g = 0; g < numGroups; g++) {
+ auto& resg = res[g];
+ resg.reserve(words.size());
+ for (const auto& word : words)
+ resg.push_back(factoredVocab_->getFactor(word, g));
+ }
+ return res;
+ }
+
+ //// use first factor of each word to determine whether it has a specific factor
+ //std::vector<float> Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0 for words that do have this factor; else 0
+ // std::vector<float> res;
+ // res.reserve(words.size());
+ // for (const auto& word : words) {
+ // auto lemma = factoredVocab_->getFactor(word, 0);
+ // res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
+ // }
+ // return res;
+ //}
+
+ // return a vector of 1 or 0 indicating for each lemma whether it has a specific factor
+ // If 'indices' is given, then return the masks for the indices; otherwise for all lemmas
+ std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
+ size_t n = indices.empty() ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first) : indices.size();
+ std::vector<float> res;
+ res.reserve(n);
+ // @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this into FactoredVocab
+ for (size_t i = 0; i < n; i++) {
+ auto lemma = indices.empty() ? i : (indices[i] - factoredVocab_->getGroupRange(0).first);
+ res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
+ }
+ return res;
+ }
+
+ Logits Logits::applyUnaryFunction(const std::function<Expr(Expr)>& f) const { // clone this but apply f to all loss values
+ std::vector<Ptr<RationalLoss>> newLogits;
+ for (const auto& l : logits_)
+ newLogits.emplace_back(New<RationalLoss>(f(l->loss()), l->count()));
+ return Logits(std::move(newLogits), factoredVocab_);
+ }
+
+ Logits Logits::applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const {
+ std::vector<Ptr<RationalLoss>> newLogits;
+ bool first = true;
+ for (const auto& l : logits_) {
+ newLogits.emplace_back(New<RationalLoss>((first?f1:fother)(l->loss()), l->count())); // f1 for first, fother for all others
+ first = false;
+ }
+ return Logits(std::move(newLogits), factoredVocab_);
+ }
+
+ // @TODO: code dup with above; we can merge it into applyToRationalLoss()
+ Logits Logits::withCounts(const Expr& count) const { // create new Logits with 'count' implanted into all logits_
+ std::vector<Ptr<RationalLoss>> newLogits;
+ for (const auto& l : logits_)
+ newLogits.emplace_back(New<RationalLoss>(l->loss(), count));
+ return Logits(std::move(newLogits), factoredVocab_);
+ }
+
+ namespace mlp {
+ /*private*/ void Output::lazyConstruct(int inputDim) {
+ // We must construct lazily since we won't know tying nor input dim in constructor.
+ if (Wt_)
+ return;
+
+ auto name = options_->get<std::string>("prefix");
+ auto numOutputClasses = options_->get<int>("dim");
+
+ factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
+ if (factoredVocab_) {
+ numOutputClasses = (int)factoredVocab_->factorVocabSize();
+ LOG_ONCE(info, "[embedding] Factored outputs enabled");
+ }
+
+ if(tiedParam_) {
+ Wt_ = tiedParam_;
+ } else {
+ if (graph_->get(name + "_W")) { // support of legacy models that did not transpose
+ Wt_ = graph_->param(name + "_W", {inputDim, numOutputClasses}, inits::glorotUniform(true, false));
+ isLegacyUntransposedW = true;
+ }
+ else // this is the regular case:
+ Wt_ = graph_->param(name + "_Wt", {numOutputClasses, inputDim}, inits::glorotUniform(false, true));
+ }
+
+ b_ = graph_->param(name + "_b", {1, numOutputClasses}, inits::zeros());
+
+ /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
+ ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary");
+ if (lemmaDimEmb > 0) { // > 0 means to embed the (expected) word with a different embedding matrix
+#define HARDMAX_HACK
+#ifdef HARDMAX_HACK
+ lemmaDimEmb = lemmaDimEmb & 0xfffffffe; // hack to select hard-max: use an odd number
+#endif
+ auto range = factoredVocab_->getGroupRange(0);
+ auto lemmaVocabDim = (int)(range.second - range.first);
+ auto initFunc = inits::glorotUniform(/*fanIn=*/true, /*fanOut=*/false); // -> embedding vectors have roughly unit length
+ lemmaEt_ = graph_->param(name + "_lemmaEt", {lemmaDimEmb, lemmaVocabDim}, initFunc); // [L x U] L=lemmaDimEmb; transposed for speed
+ }
+ }
+
+ Logits Output::applyAsLogits(Expr input) /*override final*/ {
+ lazyConstruct(input->shape()[-1]);
+
+ if (shortlist_ && !cachedShortWt_) { // shortlisted versions of parameters are cached within one batch, then clear()ed
+ cachedShortWt_ = index_select(Wt_, isLegacyUntransposedW ? -1 : 0, shortlist_->indices());
+ cachedShortb_ = index_select(b_ , -1, shortlist_->indices());
+ }
+
+ if (factoredVocab_) {
+ auto graph = input->graph();
+
+ // project each factor separately
+ auto numGroups = factoredVocab_->getNumGroups();
+ std::vector<Ptr<RationalLoss>> allLogits(numGroups, nullptr); // (note: null entries for absent factors)
+ Expr input1 = input; // [B... x D]
+ Expr Plemma = nullptr; // used for lemmaDimEmb=-1
+ Expr inputLemma = nullptr; // used for lemmaDimEmb=-2, -3
+ for (size_t g = 0; g < numGroups; g++) {
+ auto range = factoredVocab_->getGroupRange(g);
+ if (g > 0 && range.first == range.second) // empty entry
+ continue;
+ ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g-1).second, "Factor groups must be consecutive (group {} vs predecessor)", g);
+ // slice this group's section out of W_
+ Expr factorWt, factorB;
+ if (g == 0 && shortlist_) {
+ factorWt = cachedShortWt_;
+ factorB = cachedShortb_;
+ }
+ else {
+ factorWt = slice(Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second));
+ factorB = slice(b_, -1, Slice((int)range.first, (int)range.second));
+ }
+ /*const*/ int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
+ if ((lemmaDimEmb == -2 || lemmaDimEmb == -3) && g > 0) { // -2/-3 means a gated transformer-like structure (-3 = hard-max)
+ LOG_ONCE(info, "[embedding] using lemma conditioning with gate");
+ // this mimics one transformer layer
+ // - attention over two inputs:
+ // - e = current lemma. We use the original embedding vector; specifically, expectation over all lemmas.
+ // - input = hidden state FF(h_enc+h_dec)
+ // - dot-prod attention to allow both sides to influence (unlike our recurrent self-attention)
+ // - multi-head to allow for multiple conditions to be modeled
+ // - add & norm, for gradient flow and scaling
+ // - FF layer --this is expensive; it is per-factor
+ // multi-head attention
+ int inputDim = input->shape()[-1];
+ int heads = 8;
+ auto name = options_->get<std::string>("prefix") + "_factor" + std::to_string(g);
+ auto Wq = graph_->param(name + "_Wq", { inputDim, inputDim }, inits::glorotUniform());
+ auto Wk = graph_->param(name + "_Wk", { inputDim, inputDim }, inits::glorotUniform());
+ auto Wv = graph_->param(name + "_Wv", { inputDim, inputDim }, inits::glorotUniform());
+ auto toMultiHead = [&](Expr x, int heads) {
+ const auto& shape = x->shape();
+ int inputDim = shape[-1];
+ int otherDim = shape.elements() / inputDim;
+ ABORT_IF(inputDim / heads * heads != inputDim, "inputDim ({}) must be multiple of number of heads ({})", inputDim, heads);
+ return reshape(x, { otherDim, heads, 1, inputDim / heads });
+ };
+ input1 = inputLemma;
+ auto qm = toMultiHead(dot(input1, Wq), heads); // [B... x H x D/H] projected query
+ auto kdm = toMultiHead(dot(input1 - input, Wk), heads); // [B... x H x D/H] the two data vectors projected as keys. Use diff and sigmoid, instead of softmax.
+ auto vem = toMultiHead(dot(input1, Wv), heads); // [B... x H x D/H] one of the two data vectors projected as values
+ auto vim = toMultiHead(dot( input, Wv), heads); // [B... x H x D/H] the other
+ auto zm = bdot(qm, kdm, false, true); // [B... x H x 1]
+ auto sm = sigmoid(zm); // [B... x H x 1]
+ auto rm = sm * (vem - vim) + vim; // [B... x H x D/H]
+ auto r = reshape(rm, input->shape()); // [B... x D]
+ // add & norm
+ input1 = r + input1;
+ input1 = layerNorm(input1, name + "_att");
+ // FF layer
+ auto ffnDropProb = 0.1f; // @TODO: get as a parameter
+ auto ffnDim = inputDim * 2; // @TODO: get as a parameter
+ auto f = denseInline(input1, name + "_ffn", /*suffix=*/"1", ffnDim, (ActivationFunction*)relu, ffnDropProb);
+ f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim);
+ // add & norm
+ input1 = f + input1;
+ input1 = layerNorm(input1, name + "_ffn");
+ }
+ // @TODO: b_ should be a vector, not a matrix; but shotlists use cols() in, which requires a matrix
+ auto factorLogits = affine(input1, factorWt, factorB, false, /*transB=*/isLegacyUntransposedW ? false : true, /*scale=*/1.0f); // [B... x U] factor logits
+ // optionally add lemma-dependent bias
+ if (Plemma) { // [B... x U0]
+ int lemmaVocabDim = Plemma->shape()[-1];
+ int factorVocabDim = factorLogits->shape()[-1];
+ auto name = options_->get<std::string>("prefix");
+ Expr lemmaBt = graph_->param(name + "_lemmaBt_" + std::to_string(g), {factorVocabDim, lemmaVocabDim}, inits::zeros()); // [U x U0] U0=#lemmas one bias per class per lemma
+ auto b = dot(Plemma, lemmaBt, false, true); // [B... x U]
+ factorLogits = factorLogits + b;
+ }
+ allLogits[g] = New<RationalLoss>(factorLogits, nullptr);
+ // optionally add a soft embedding of lemma back to create some lemma dependency
+ // @TODO: if this works, move it into lazyConstruct
+ if (lemmaDimEmb == -2 && g == 0) { // -2 means a gated transformer-like structure
+ LOG_ONCE(info, "[embedding] using lemma conditioning with gate, soft-max version");
+ // get expected lemma embedding vector
+ auto factorLogSoftmax = logsoftmax(factorLogits); // [B... x U] note: with shortlist, this is not the full lemma set
+ auto factorSoftmax = exp(factorLogSoftmax);
+ inputLemma = dot(factorSoftmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
+ }
+ else if (lemmaDimEmb == -3 && g == 0) { // same as -2 except with hard max
+ LOG_ONCE(info, "[embedding] using lemma conditioning with gate, hard-max version");
+ // get max-lemma embedding vector
+ auto maxVal = max(factorLogits, -1); // [B... x U] note: with shortlist, this is not the full lemma set
+ auto factorHardmax = eq(factorLogits, maxVal);
+ inputLemma = dot(factorHardmax, factorWt, false, /*transB=*/isLegacyUntransposedW ? true : false); // [B... x D]
+ }
+ else if (lemmaDimEmb == -1 && g == 0) { // -1 means learn a lemma-dependent bias
+ ABORT_IF(shortlist_, "Lemma-dependent bias with short list is not yet implemented");
+ LOG_ONCE(info, "[embedding] using lemma-dependent bias");
+ auto factorLogSoftmax = logsoftmax(factorLogits); // (we do that again later, CSE will kick in)
+ auto z = /*stopGradient*/(factorLogSoftmax);
+ Plemma = exp(z); // [B... x U]
+ }
+ else if (lemmaDimEmb > 0 && g == 0) { // > 0 means learn a re-embedding matrix
+ LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb);
+ // compute softmax. We compute logsoftmax() separately because this way, computation will be reused later via CSE
+ auto factorLogSoftmax = logsoftmax(factorLogits);
+ auto factorSoftmax = exp(factorLogSoftmax);
+#ifdef HARDMAX_HACK
+ bool hardmax = (lemmaDimEmb & 1) != 0; // odd value triggers hardmax for now (for quick experimentation)
+ if (hardmax) {
+ lemmaDimEmb = lemmaDimEmb & 0xfffffffe;
+ LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", lemmaDimEmb);
+ auto maxVal = max(factorSoftmax, -1);
+ factorSoftmax = eq(factorSoftmax, maxVal);
+ }
+#endif
+ // re-embedding lookup, soft-indexed by softmax
+ if (shortlist_ && !cachedShortLemmaEt_) // short-listed version of re-embedding matrix
+ cachedShortLemmaEt_ = index_select(lemmaEt_, -1, shortlist_->indices());
+ auto e = dot(factorSoftmax, cachedShortLemmaEt_ ? cachedShortLemmaEt_ : lemmaEt_, false, true); // [B... x L]
+ // project it back to regular hidden dim
+ int inputDim = input1->shape()[-1];
+ auto name = options_->get<std::string>("prefix");
+ // note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also length 1
+ Expr lemmaWt = inputDim == lemmaDimEmb ? nullptr : graph_->param(name + "_lemmaWt", { inputDim, lemmaDimEmb }, inits::glorotUniform()); // [D x L] D=hidden-vector dimension
+ auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e; // [B... x D]
+ // augment the original hidden vector with this additional information
+ input1 = input1 + f;
+ }
+ }
+ return Logits(std::move(allLogits), factoredVocab_);
+ }
+ else if (shortlist_)
+ return Logits(affine(input, cachedShortWt_, cachedShortb_, false, /*transB=*/isLegacyUntransposedW ? false : true));
+ else
+ return Logits(affine(input, Wt_, b_, false, /*transB=*/isLegacyUntransposedW ? false : true));
+ }
+ }
+
+ Embedding::Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options) : LayerBase(graph, options) {
+ std::string name = opt<std::string>("prefix");
+ int dimVoc = opt<int>("dimVocab");
+ int dimEmb = opt<int>("dimEmb");
+
+ bool fixed = opt<bool>("fixed", false);
+
+ factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
+ if (factoredVocab_) {
+ dimVoc = (int)factoredVocab_->factorVocabSize();
+ LOG_ONCE(info, "[embedding] Factored embeddings enabled");
+ }
+
+ // Embedding layer initialization should depend only on embedding size, hence fanIn=false
+ auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true); // -> embedding vectors have roughly unit length
+
+ if (options_->has("embFile")) {
+ std::string file = opt<std::string>("embFile");
+ if (!file.empty()) {
+ bool norm = opt<bool>("normalization", false);
+ initFunc = inits::fromWord2vec(file, dimVoc, dimEmb, norm);
+ }
+ }
+
+ E_ = graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed);
+ }
+
+ // helper to embed a sequence of words (given as indices) via factored embeddings
+ /*private*/ Expr Embedding::multiRows(const Words& data, float dropProb) const
+ {
+ auto graph = E_->graph();
+ auto factoredData = factoredVocab_->csr_rows(data);
+ // multi-hot factor vectors are represented as a sparse CSR matrix
+ // [row index = word position index] -> set of factor indices for word at this position
+ ABORT_IF(factoredData.shape != Shape({(int)factoredData.offsets.size()-1/*=rows of CSR*/, E_->shape()[0]}), "shape mismatch??");
+ // the CSR matrix is passed in pieces
+ auto weights = graph->constant({ (int)factoredData.weights.size() }, inits::fromVector(factoredData.weights), Type::float32);
+ auto indices = graph->constant({ (int)factoredData.indices.size() }, inits::fromVector(factoredData.indices), Type::uint32);
+ auto offsets = graph->constant({ (int)factoredData.offsets.size() }, inits::fromVector(factoredData.offsets), Type::uint32);
+ // apply dropout
+ // We apply it to the weights, i.e. factors get dropped out separately, but always as entire vectors.
+ weights = dropout(weights, dropProb);
+ // perform the product
+ return csr_dot(factoredData.shape, weights, indices, offsets, E_);
+ }
+
+ std::tuple<Expr/*embeddings*/, Expr/*mask*/> Embedding::apply(Ptr<data::SubBatch> subBatch) const /*override final*/ {
+ auto graph = E_->graph();
+ int dimBatch = (int)subBatch->batchSize();
+ int dimEmb = E_->shape()[-1];
+ int dimWidth = (int)subBatch->batchWidth();
+
+ // factored embeddings:
+ // - regular:
+ // - y = x @ E x:[B x 1ofV] ; E:[V x D] ; y:[B x D]
+ // - factored:
+ // - u = x @ M one-hot to U-dimensional multi-hot (all factors in one concatenated space)
+ // - each row of M contains the set of factors for one word => we want a CSR matrix
+ // - y = (x @ M) @ E (x:[B x 1ofV] ; M:[V x U]) ; E:[U x D] ; y:[B x D]
+ // - first compute x @ M on the CPU
+ // - (Uvalues, Uindices, Uoffsets) = csr_rows(Mvalues, Mindices, Moffsets, subBatch->data()):
+ // - shape (U, specifically) not actually needed here
+ // - foreach input x[i]
+ // - locate row M[i,*]
+ // - copy through its index values (std::vector<push_back>)
+ // - create a matching ones vector (we can keep growing)
+ // - convert to GPU-side CSR matrix. CSR matrix now has #rows equal to len(x)
+ // - CSR matrix product with E
+ // - csr_dot(Uvalues, Uindices, Uoffsets, E_, transposeU)
+ // - double-check if all dimensions are specified. Probably not for transpose (which would be like csc_dot()).
+ // - weighting:
+ // - core factors' gradients are sums over all words that use the factors;
+ // - core factors' embeddings move very fast
+ // - words will need to make up for the move; rare words cannot
+ // - so, we multiply each factor with 1/refCount
+ // - core factors get weighed down a lot
+ // - no impact on gradients, as Adam makes up for it; embeddings still move fast just as before
+ // - but forward pass weighs them down, so that all factors are in a similar numeric range
+ // - if it is required to be in a different range, the embeddings can still learn that, but more slowly
+
+ auto batchEmbeddings = apply(subBatch->data(), {dimWidth, dimBatch, dimEmb});
+#if 0
+ auto batchMask = graph->constant({dimWidth, dimBatch, 1},
+ inits::fromVector(subBatch->mask()));
+#else
+ // experimental: hide inline-fix source tokens from cross attention
+ auto batchMask = graph->constant({dimWidth, dimBatch, 1},
+ inits::fromVector(subBatch->crossMaskWithInlineFixSourceSuppressed()));
+#endif
+
+ return std::make_tuple(batchEmbeddings, batchMask);
+ }
+
+ Expr Embedding::apply(const Words& words, const Shape& shape) const /*override final*/ {
+ if (factoredVocab_) {
+ Expr selectedEmbs = multiRows(words, options_->get<float>("dropout", 0.0f)); // [(B*W) x E]
+ selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
+ //selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 }); // @TODO: replace with factor dropout
+ return selectedEmbs;
+ }
+ else
+ return applyIndices(toWordIndexVector(words), shape);
+ }
+
+ Expr Embedding::applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const /*override final*/ {
+ ABORT_IF(factoredVocab_, "Embedding: applyIndices must not be used with a factored vocabulary");
+ auto selectedEmbs = rows(E_, embIdx); // [(B*W) x E]
+ selectedEmbs = reshape(selectedEmbs, shape); // [W, B, E]
+ // @BUGBUG: We should not broadcast along dimBatch=[-2]. Then we can also dropout before reshape() (test that separately)
+ selectedEmbs = dropout(selectedEmbs, options_->get<float>("dropout", 0.0f), { selectedEmbs->shape()[-3], 1, 1 });
+ return selectedEmbs;
+ }
+
+ // standard encoder word embeddings
+ /*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createEmbeddingLayer() const {
+ auto options = New<Options>(
+ "dimVocab", opt<std::vector<int>>("dim-vocabs")[batchIndex_],
+ "dimEmb", opt<int>("dim-emb"),
+ "dropout", dropout_,
+ "prefix", (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all")) ? "Wemb" : prefix_ + "_Wemb",
+ "fixed", embeddingFix_,
+ "vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored embeddings
+ if(options_->hasAndNotEmpty("embedding-vectors")) {
+ auto embFiles = opt<std::vector<std::string>>("embedding-vectors");
+ options->set(
+ "embFile", embFiles[batchIndex_],
+ "normalization", opt<bool>("embedding-normalization"));
+ }
+ return New<Embedding>(graph_, options);
+ }
+
+ // ULR word embeddings
+ /*private*/ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::createULREmbeddingLayer() const {
+ return New<ULREmbedding>(graph_, New<Options>(
+ "dimSrcVoc", opt<std::vector<int>>("dim-vocabs")[0], // ULR multi-lingual src
+ "dimTgtVoc", opt<std::vector<int>>("dim-vocabs")[1], // ULR monon tgt
+ "dimUlrEmb", opt<int>("ulr-dim-emb"),
+ "dimEmb", opt<int>("dim-emb"),
+ "ulr-dropout", opt<float>("ulr-dropout"),
+ "dropout", dropout_,
+ "ulrTrainTransform", opt<bool>("ulr-trainable-transformation"),
+ "ulrQueryFile", opt<std::string>("ulr-query-vectors"),
+ "ulrKeysFile", opt<std::string>("ulr-keys-vectors")));
+ }
+
+ // get embedding layer for this encoder or decoder
+ // This is lazy mostly because the constructors of the consuming objects are not
+ // guaranteed presently to have access to their graph.
+ Ptr<IEmbeddingLayer> EncoderDecoderLayerBase::getEmbeddingLayer(bool ulr) const {
+ if (embeddingLayers_.size() <= batchIndex_ || !embeddingLayers_[batchIndex_]) { // lazy
+ if (embeddingLayers_.size() <= batchIndex_)
+ embeddingLayers_.resize(batchIndex_ + 1);
+ if (ulr)
+ embeddingLayers_[batchIndex_] = createULREmbeddingLayer(); // embedding uses ULR
+ else
+ embeddingLayers_[batchIndex_] = createEmbeddingLayer();
+ }
+ return embeddingLayers_[batchIndex_];
+ }
+} // namespace marian
diff --git a/src/layers/generic.h b/src/layers/generic.h
index b5c53b46..a3b9bac4 100755
--- a/src/layers/generic.h
+++ b/src/layers/generic.h
@@ -5,47 +5,167 @@
#include "data/shortlist.h"
#include "layers/factory.h"
-namespace marian {
-namespace mlp {
-/**
- * @brief Activation functions
- */
-enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish };
-} // namespace mlp
-} // namespace marian
-
-YAML_REGISTER_TYPE(marian::mlp::act, int)
+namespace marian { namespace mlp {
+ /**
+ * @brief Activation functions
+ */
+ enum struct act : int { linear, tanh, sigmoid, ReLU, LeakyReLU, PReLU, swish };
+}}
namespace marian {
-namespace mlp {
-class Layer {
+// Each layer consists of LayerBase and IXXXLayer which defines one or more apply()
+// functions for the respective layer type (different layers may require different signatures).
+// This base class contains configuration info for creating parameters and executing apply().
+class LayerBase {
protected:
Ptr<ExpressionGraph> graph_;
Ptr<Options> options_;
public:
- Layer(Ptr<ExpressionGraph> graph, Ptr<Options> options)
+ LayerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: graph_(graph), options_(options) {}
template <typename T>
- T opt(const std::string key) {
+ T opt(const std::string key) const {
return options_->get<T>(key);
}
template <typename T>
- T opt(const std::string key, T defaultValue) {
+ T opt(const std::string key, const T& defaultValue) const {
return options_->get<T>(key, defaultValue);
}
+};
- virtual Expr apply(const std::vector<Expr>&) = 0;
+// Simplest layer interface: Unary function
+struct IUnaryLayer {
+ virtual ~IUnaryLayer() {}
virtual Expr apply(Expr) = 0;
+ virtual Expr apply(const std::vector<Expr>& es) {
+ ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
+ return apply(es.front());
+ }
+};
+
+struct IHasShortList {
+ virtual void setShortlist(Ptr<data::Shortlist> shortlist) = 0;
+ virtual void clear() = 0;
+};
+
+// Embedding from corpus sub-batch to (emb, mask)
+struct IEmbeddingLayer {
+ virtual std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const = 0;
+
+ virtual Expr apply(const Words& embIdx, const Shape& shape) const = 0;
+
+ // alternative from indices directly
+ virtual Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const = 0;
+ virtual ~IEmbeddingLayer() {}
+};
+
+// base class for Encoder and Decoder classes, which have embeddings and a batch index (=stream index)
+class EncoderDecoderLayerBase : public LayerBase {
+protected:
+ const std::string prefix_;
+ const bool embeddingFix_;
+ const float dropout_;
+ const bool inference_;
+ const size_t batchIndex_;
+ mutable std::vector<Ptr<IEmbeddingLayer>> embeddingLayers_; // (lazily created)
+
+ EncoderDecoderLayerBase(Ptr<ExpressionGraph> graph, Ptr<Options> options, const std::string& prefix, size_t batchIndex,
+ float dropout,
+ bool embeddingFix) :
+ LayerBase(graph, options),
+ prefix_(options->get<std::string>("prefix", prefix)),
+ embeddingFix_(embeddingFix),
+ dropout_(dropout),
+ inference_(options->get<bool>("inference", false)),
+ batchIndex_(options->get<size_t>("index", batchIndex)) {}
+
+ virtual ~EncoderDecoderLayerBase() {}
+
+private:
+ Ptr<IEmbeddingLayer> createEmbeddingLayer() const;
+ Ptr<IEmbeddingLayer> createULREmbeddingLayer() const;
+
+public:
+ // get embedding layer; lazily create on first call
+ Ptr<IEmbeddingLayer> getEmbeddingLayer(bool ulr = false) const;
};
-class Dense : public Layer {
+class FactoredVocab;
+
+// To support factors, any output projection (that is followed by a softmax) must
+// retain multiple outputs, one for each factor. Such layer returns not a single Expr,
+// but a Logits object that contains multiple.
+// This allows to compute softmax values in a factored manner, where we never create
+// a fully expanded list of all factor combinations.
+class RationalLoss;
+class Logits {
+public:
+ Logits() {}
+ explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor
+ logits_.push_back(logits);
+ }
+ explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count)
+ Logits(std::vector<Ptr<RationalLoss>>&& logits, Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor
+ : logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {}
+ Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors
+ Expr getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist = nullptr, const std::vector<IndexType>& hypIndices = {}, size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle
+ //Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that
+ Expr applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/,Expr/*indices*/)>& lossFn) const;
+ Logits applyUnaryFunction(const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values
+ Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const; // clone this but apply f1 to first and fother to to all other values
+
+ struct MaskedFactorIndices {
+ std::vector<WordIndex> indices; // factor index, or 0 if masked
+ std::vector<float> masks;
+ void reserve(size_t n) { indices.reserve(n); masks.reserve(n); }
+ void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 for invalid entries
+ MaskedFactorIndices() {}
+ MaskedFactorIndices(const Words& words) { indices = toWordIndexVector(words); } // we can leave masks uninitialized for this special use case
+ };
+ std::vector<MaskedFactorIndices> factorizeWords(const Words& words) const; // breaks encoded Word into individual factor indices
+ Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only
+ size_t getNumFactorGroups() const { return logits_.size(); }
+ bool empty() const { return logits_.empty(); }
+ Logits withCounts(const Expr& count) const; // create new Logits with 'count' implanted into all logits_
+private:
+ // helper functions
+ Ptr<ExpressionGraph> graph() const;
+ Expr constant(const Shape& shape, const std::vector<float>& data) const { return graph()->constant(shape, inits::fromVector(data), Type::float32); }
+ Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const { return graph()->constant(shape, inits::fromVector(data), Type::uint32); }
+ template<typename T> Expr constant(const std::vector<T>& data) const { return constant(Shape{(int)data.size()}, data); } // same as constant() but assuming vector
+ Expr indices(const std::vector<uint32_t>& data) const { return graph()->indices(data); } // actually the same as constant(data) for this data type
+ std::vector<float> getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const;
+private:
+ // members
+ // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just by the Expr
+ std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group]
+ Ptr<FactoredVocab> factoredVocab_;
+};
+
+// Unary function that returns a Logits object
+// Also implements IUnaryLayer, since Logits can be cast to Expr.
+// This interface is implemented by all layers that are of the form of a unary function
+// that returns multiple logits, to support factors.
+struct IUnaryLogitLayer : public IUnaryLayer {
+ virtual Logits applyAsLogits(Expr) = 0;
+ virtual Logits applyAsLogits(const std::vector<Expr>& es) {
+ ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
+ return applyAsLogits(es.front());
+ }
+ virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); }
+ virtual Expr apply(const std::vector<Expr>& es) override { return applyAsLogits(es).getLogits(); }
+};
+
+namespace mlp {
+
+class Dense : public LayerBase, public IUnaryLayer {
public:
Dense(Ptr<ExpressionGraph> graph, Ptr<Options> options)
- : Layer(graph, options) {}
+ : LayerBase(graph, options) {}
Expr apply(const std::vector<Expr>& inputs) override {
ABORT_IF(inputs.empty(), "No inputs");
@@ -55,7 +175,7 @@ public:
auto useLayerNorm = opt<bool>("layer-normalization", false);
auto useNematusNorm = opt<bool>("nematus-normalization", false);
- auto activation = opt<act>("activation", act::linear);
+ auto activation = (act)opt<int>("activation", (int)act::linear);
auto g = graph_;
@@ -68,24 +188,23 @@ public:
num = std::to_string(i);
Expr W = g->param(
- name + "_W" + num, {in->shape()[-1], dim}, inits::glorot_uniform);
- Expr b = g->param(name + "_b" + num, {1, dim}, inits::zeros);
+ name + "_W" + num, {in->shape()[-1], dim}, inits::glorotUniform());
+ Expr b = g->param(name + "_b" + num, {1, dim}, inits::zeros());
if(useLayerNorm) {
if(useNematusNorm) {
auto ln_s = g->param(
- name + "_ln_s" + num, {1, dim}, inits::from_value(1.f));
- auto ln_b = g->param(name + "_ln_b" + num, {1, dim}, inits::zeros);
+ name + "_ln_s" + num, {1, dim}, inits::fromValue(1.f));
+ auto ln_b = g->param(name + "_ln_b" + num, {1, dim}, inits::zeros());
outputs.push_back(
layerNorm(affine(in, W, b), ln_s, ln_b, NEMATUS_LN_EPS));
} else {
auto gamma = g->param(
- name + "_gamma" + num, {1, dim}, inits::from_value(1.0));
+ name + "_gamma" + num, {1, dim}, inits::fromValue(1.0));
outputs.push_back(layerNorm(dot(in, W), gamma, b));
}
-
} else {
outputs.push_back(affine(in, W, b));
}
@@ -109,124 +228,121 @@ public:
Expr apply(Expr input) override { return apply(std::vector<Expr>({input})); }
};
-class Output : public Layer {
+class Output : public LayerBase, public IUnaryLogitLayer, public IHasShortList {
private:
- std::map<std::string, Expr> tiedParams_;
- Ptr<data::Shortlist> shortlist_;
-
- Expr W_;
+ // parameters held by this layer
+ Expr Wt_; // weight matrix is stored transposed for efficiency
Expr b_;
- bool transposeW_{false};
+ Expr lemmaEt_; // re-embedding matrix for lemmas [lemmaDimEmb x lemmaVocabSize]
+ bool isLegacyUntransposedW{false}; // legacy-model emulation: W is stored in non-transposed form
+ Expr cachedShortWt_; // short-listed version, cached (cleared by clear())
+ Expr cachedShortb_; // these match the current value of shortlist_
+ Expr cachedShortLemmaEt_;
+ Ptr<FactoredVocab> factoredVocab_;
+
+ // optional parameters set/updated after construction
+ Expr tiedParam_;
+ Ptr<data::Shortlist> shortlist_;
+ void lazyConstruct(int inputDim);
public:
Output(Ptr<ExpressionGraph> graph, Ptr<Options> options)
- : Layer(graph, options) {}
-
- void tie_transposed(const std::string& param, const std::string& tied) {
- tiedParams_[param] = graph_->get(tied);
+ : LayerBase(graph, options) {
+ clear();
}
- void set_shortlist(Ptr<data::Shortlist> shortlist) { shortlist_ = shortlist; }
-
- Expr apply(Expr input) override {
- if(!W_) {
- auto name = options_->get<std::string>("prefix");
- auto dim = options_->get<int>("dim");
- std::string nameW = "W";
-
- if(tiedParams_.count(nameW)) {
- transposeW_ = true;
- W_ = tiedParams_[nameW];
- if(shortlist_)
- W_ = rows(W_, shortlist_->indices());
- } else {
- W_ = graph_->param(name + "_" + nameW,
- {input->shape()[-1], dim},
- inits::glorot_uniform);
- if(shortlist_)
- W_ = cols(W_, shortlist_->indices());
- }
+ void tieTransposed(Expr tied) {
+ if (Wt_)
+ ABORT_IF(tiedParam_.get() != tied.get(), "Tied output projection cannot be changed once weights have been created");
+ else
+ tiedParam_ = tied;
+ }
- b_ = graph_->param(name + "_b", {1, dim}, inits::zeros);
- if(shortlist_)
- b_ = cols(b_, shortlist_->indices());
+ void setShortlist(Ptr<data::Shortlist> shortlist) override final {
+ if (shortlist_)
+ ABORT_IF(shortlist.get() != shortlist_.get(), "Output shortlist cannot be changed except after clear()");
+ else {
+ ABORT_IF(cachedShortWt_ || cachedShortb_ || cachedShortLemmaEt_, "No shortlist but cached parameters??");
+ shortlist_ = shortlist;
}
+ // cachedShortWt_ and cachedShortb_ will be created lazily inside apply()
+ }
- return affine(input, W_, b_, false, transposeW_);
+ // this is expected to be called in sync with graph->clear(), which invalidates
+ // cachedShortWt_ etc. in the graph's short-term cache
+ void clear() override final {
+ shortlist_ = nullptr;
+ cachedShortWt_ = nullptr;
+ cachedShortb_ = nullptr;
+ cachedShortLemmaEt_ = nullptr;
}
- virtual Expr apply(const std::vector<Expr>& /*inputs*/) override {
- ABORT("Not implemented");
- };
+ Logits applyAsLogits(Expr input) override final;
};
} // namespace mlp
-struct EmbeddingFactory : public Factory {
- EmbeddingFactory(Ptr<ExpressionGraph> graph) : Factory(graph) {}
+// A regular embedding layer.
+// Note that this also applies dropout if the option is passed (pass 0 when in inference mode).
+// It is best to not use Embedding directly, but rather via getEmbeddingLayer() in
+// EncoderDecoderLayerBase, which knows to pass on all required parameters from options.
+class Embedding : public LayerBase, public IEmbeddingLayer {
+ Expr E_;
+ Ptr<FactoredVocab> factoredVocab_;
+ Expr multiRows(const Words& data, float dropProb) const;
+public:
+ Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options);
- Expr construct() {
- std::string name = opt<std::string>("prefix");
- int dimVoc = opt<int>("dimVocab");
- int dimEmb = opt<int>("dimEmb");
+ std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final;
- bool fixed = opt<bool>("fixed", false);
+ Expr apply(const Words& words, const Shape& shape) const override final;
- NodeInitializer initFunc = inits::glorot_uniform;
- if (options_->has("embFile")) {
- std::string file = opt<std::string>("embFile");
- if (!file.empty()) {
- bool norm = opt<bool>("normalization", false);
- initFunc = inits::from_word2vec(file, dimVoc, dimEmb, norm);
- }
- }
-
- return graph_->param(name, {dimVoc, dimEmb}, initFunc, fixed);
- }
+ Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final;
};
-
-struct ULREmbeddingFactory : public Factory {
-ULREmbeddingFactory(Ptr<ExpressionGraph> graph) : Factory(graph) {}
-
- std::vector<Expr> construct() {
+class ULREmbedding : public LayerBase, public IEmbeddingLayer {
+ std::vector<Expr> ulrEmbeddings_; // @TODO: These could now better be written as 6 named class members
+public:
+ ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options) : LayerBase(graph, options) {
std::string name = "url_embed"; //opt<std::string>("prefix");
int dimKeys = opt<int>("dimTgtVoc");
int dimQueries = opt<int>("dimSrcVoc");
int dimEmb = opt<int>("dimEmb");
int dimUlrEmb = opt<int>("dimUlrEmb"); // ULR mono embed size
bool fixed = opt<bool>("fixed", false);
- std::vector<Expr> ulrEmbeds;
- NodeInitializer initFunc = inits::glorot_uniform;
+
+ // Embedding layer initialization should depend only on embedding size, hence fanIn=false
+ auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true);
+
std::string queryFile = opt<std::string>("ulrQueryFile");
std::string keyFile = opt<std::string>("ulrKeysFile");
bool trainTrans = opt<bool>("ulrTrainTransform", false);
if (!queryFile.empty() && !keyFile.empty()) {
- initFunc = inits::from_word2vec(queryFile, dimQueries, dimUlrEmb, false);
+ initFunc = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false);
name = "ulr_query";
fixed = true;
auto query_embed = graph_->param(name, { dimQueries, dimUlrEmb }, initFunc, fixed);
- ulrEmbeds.push_back(query_embed);
+ ulrEmbeddings_.push_back(query_embed);
// keys embeds
- initFunc = inits::from_word2vec(keyFile, dimKeys, dimUlrEmb, false);
+ initFunc = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false);
name = "ulr_keys";
fixed = true;
auto key_embed = graph_->param(name, { dimKeys, dimUlrEmb }, initFunc, fixed);
- ulrEmbeds.push_back(key_embed);
+ ulrEmbeddings_.push_back(key_embed);
// actual trainable embedding
- initFunc = inits::glorot_uniform;
+ initFunc = inits::glorotUniform();
name = "ulr_embed";
fixed = false;
auto ulr_embed = graph_->param(name, {dimKeys , dimEmb }, initFunc, fixed); // note the reverse dim
- ulrEmbeds.push_back(ulr_embed);
+ ulrEmbeddings_.push_back(ulr_embed);
// init trainable src embedding
name = "ulr_src_embed";
auto ulr_src_embed = graph_->param(name, { dimQueries, dimEmb }, initFunc, fixed);
- ulrEmbeds.push_back(ulr_src_embed);
+ ulrEmbeddings_.push_back(ulr_src_embed);
// ulr transformation matrix
//initFunc = inits::eye(1.f); // identity matrix - is it ok to init wiht identity or shall we make this to the fixed case only
if (trainTrans) {
- initFunc = inits::glorot_uniform;
+ initFunc = inits::glorotUniform();
fixed = false;
}
else
@@ -235,21 +351,95 @@ ULREmbeddingFactory(Ptr<ExpressionGraph> graph) : Factory(graph) {}
fixed = true;
}
name = "ulr_transform";
- auto ulr_transform = graph_->param(name, { dimUlrEmb, dimUlrEmb }, initFunc, fixed);
- ulrEmbeds.push_back(ulr_transform);
+ auto ulrTransform = graph_->param(name, { dimUlrEmb, dimUlrEmb }, initFunc, fixed);
+ ulrEmbeddings_.push_back(ulrTransform);
- initFunc = inits::from_value(1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no universal embeddings - should be zero for top freq only
+ initFunc = inits::fromValue(1.f); // TBD: we should read sharable flags here - 1 means all sharable - 0 means no universal embeddings - should be zero for top freq only
fixed = true;
name = "ulr_shared";
auto share_embed = graph_->param(name, { dimQueries, 1 }, initFunc, fixed);
- ulrEmbeds.push_back(share_embed);
-
+ ulrEmbeddings_.push_back(share_embed);
}
+ }
+
+ std::tuple<Expr/*embeddings*/, Expr/*mask*/> apply(Ptr<data::SubBatch> subBatch) const override final {
+ auto queryEmbed = ulrEmbeddings_[0]; // Q : dimQueries*dimUlrEmb
+ auto keyEmbed = ulrEmbeddings_[1]; // K : dimKeys*dimUlrEmb
+ auto uniEmbed = ulrEmbeddings_[2]; // E : dimQueries*dimEmb
+ auto srcEmbed = ulrEmbeddings_[3]; // I : dimQueries*dimEmb
+ auto ulrTransform = ulrEmbeddings_[4]; // A : dimUlrEmb *dimUlrEmb
+ auto ulrSharable = ulrEmbeddings_[5]; // alpha : dimQueries*1
+ int dimBatch = (int)subBatch->batchSize();
+ int dimEmb = uniEmbed->shape()[-1];
+ int dimWords = (int)subBatch->batchWidth();
+ // D = K.A.QT
+ // dimm(K) = univ_tok_vocab*uni_embed_size
+ // dim A = uni_embed_size*uni_embed_size
+ // dim Q: uni_embed_size * total_merged_vocab_size
+ // dim D = univ_tok_vocab * total_merged_vocab_size
+ // note all above can be precombuted and serialized if A is not trainiable and during decoding (TBD)
+ // here we need to handle the mini-batch
+ // extract raws corresponding to Xs in this minibatch from Q
+ auto embIdx = toWordIndexVector(subBatch->data());
+ auto queryEmbeddings = rows(queryEmbed, embIdx);
+ auto srcEmbeddings = rows(srcEmbed, embIdx); // extract trainable src embeddings
+ auto alpha = rows(ulrSharable, embIdx); // extract sharable flags
+ auto qt = dot(queryEmbeddings, ulrTransform, false, false); //A: transform embeddings based on similarity A : dimUlrEmb*dimUlrEmb
+ auto sqrtDim=std::sqrt((float)queryEmbeddings->shape()[-1]);
+ qt = qt/sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in magnitude with larger embeds sizes
+ auto z = dot(qt, keyEmbed, false, true); // query-key similarity
+ float dropProb = this->options_->get<float>("ulr-dropout", 0.0f); // default no dropout
+ z = dropout(z, dropProb);
+ float tau = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature
+ // temperature in softmax is to control randomness of predictions
+ // high temperature Softmax outputs are more close to each other
+ // low temperatures the softmax become more similar to "hardmax"
+ auto weights = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ??
+ auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE
+ auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast
+ auto batchEmbeddings = reshape(chosenEmbeddings_mix, { dimWords, dimBatch, dimEmb });
+ auto graph = ulrEmbeddings_.front()->graph();
+ auto batchMask = graph->constant({ dimWords, dimBatch, 1 },
+ inits::fromVector(subBatch->mask()));
+ batchEmbeddings = dropout(batchEmbeddings, options_->get<float>("dropout", 0.0f), {batchEmbeddings->shape()[-3], 1, 1});
+ return std::make_tuple(batchEmbeddings, batchMask);
+ }
- return ulrEmbeds;
+ Expr apply(const Words& words, const Shape& shape) const override final {
+ return applyIndices(toWordIndexVector(words), shape);
+ }
+
+ Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final {
+ embIdx; shape;
+ ABORT("not implemented"); // @TODO: implement me
}
};
-typedef Accumulator<EmbeddingFactory> embedding;
-typedef Accumulator<ULREmbeddingFactory> ulr_embedding;
+// --- a few layers with built-in parameters created on the fly, without proper object
+// @TODO: change to a proper layer object
+
+// like affine() but with built-in parameters, activation, and dropout
+static inline
+Expr denseInline(Expr x, std::string prefix, std::string suffix, int outDim, const std::function<Expr(Expr)>& actFn = nullptr, float dropProb = 0.0f)
+{
+ auto graph = x->graph();
+
+ auto W = graph->param(prefix + "_W" + suffix, { x->shape()[-1], outDim }, inits::glorotUniform());
+ auto b = graph->param(prefix + "_b" + suffix, { 1, outDim }, inits::zeros());
+
+ x = affine(x, W, b);
+ if (actFn)
+ x = actFn(x);
+ x = dropout(x, dropProb);
+ return x;
+}
+
+static inline
+Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) {
+ int dimModel = x->shape()[-1];
+ auto scale = x->graph()->param(prefix + "_ln_scale" + suffix, { 1, dimModel }, inits::ones());
+ auto bias = x->graph()->param(prefix + "_ln_bias" + suffix, { 1, dimModel }, inits::zeros());
+ return marian::layerNorm(x, scale, bias, 1e-6f);
+}
+
} // namespace marian
diff --git a/src/layers/guided_alignment.h b/src/layers/guided_alignment.h
index df8607ca..f08d3f09 100755
--- a/src/layers/guided_alignment.h
+++ b/src/layers/guided_alignment.h
@@ -1,53 +1,75 @@
#pragma once
-#include "marian.h"
+#include "layers/loss.h"
+#include "common/logging.h"
namespace marian {
-static inline Expr guidedAlignmentCost(Ptr<ExpressionGraph> graph,
- Ptr<data::CorpusBatch> batch,
- Ptr<Options> options,
- Expr att) {
- int dimBatch = att->shape()[-2];
- int dimSrc = att->shape()[-3];
- int dimTrg = att->shape()[-1];
+static inline RationalLoss guidedAlignmentCost(Ptr<ExpressionGraph> /*graph*/,
+ Ptr<data::CorpusBatch> batch,
+ Ptr<Options> options,
+ Expr attention) { // [beam depth=1, max src length, batch size, tgt length]
- //debug(att, "Attention");
+ std::string guidedLossType = options->get<std::string>("guided-alignment-cost"); // @TODO: change "cost" to "loss"
+ float guidedLossWeight = options->get<float>("guided-alignment-weight");
- auto aln = graph->constant(att->shape(),
- inits::from_vector(batch->getGuidedAlignment()));
-
- //debug(aln, "Alignment");
-
- std::string guidedCostType
- = options->get<std::string>("guided-alignment-cost");
-
- std::string costType = options->get<std::string>("cost-type");
-
- int div = 1;
- if(costType == "ce-mean-words") {
- div = dimBatch * dimSrc * dimTrg;
- } else if(costType == "perplexity") {
- div = dimBatch * dimSrc * dimTrg;
- } else if(costType == "ce-sum") {
- div = 1;
- } else {
- div = dimBatch;
- }
-
- Expr alnCost;
+ const auto& shape = attention->shape(); // [beam depth=1, max src length, batch size, tgt length]
float epsilon = 1e-6f;
- if(guidedCostType == "mse") {
- alnCost = sum(flatten(square(att - aln))) / (float)(2 * div);
- } else if(guidedCostType == "mult") {
- alnCost = -log(sum(flatten(att * aln)) + epsilon) / (float)div;
- } else if(guidedCostType == "ce") {
- alnCost = -sum(flatten(aln * log(att + epsilon))) / (float)div;
+ Expr alignmentLoss; // sum up loss over all attention/alignment positions
+ size_t numLabels;
+ if(guidedLossType == "ce") {
+ // normalizedAlignment is multi-hot, but ce requires normalized probabilities, so need to normalize to P(s|t)
+ auto dimBatch = shape[-2];
+ auto dimTrgWords = shape[-1];
+ auto dimSrcWords = shape[-3];
+ ABORT_IF(shape[-4] != 1, "Guided alignments with beam??");
+ auto normalizedAlignment = batch->getGuidedAlignment(); // [dimSrcWords, dimBatch, dimTrgWords] flattened, matches shape of 'attention'
+ auto srcBatch = batch->front();
+ const auto& srcMask = srcBatch->mask();
+ ABORT_IF(shape.elements() != normalizedAlignment.size(), "Attention-matrix and alignment shapes differ??");
+ ABORT_IF(dimBatch != batch->size() || dimTrgWords != batch->widthTrg() || dimSrcWords != batch->width(), "Attention-matrix and batch shapes differ??");
+ auto locate = [=](size_t s, size_t b, size_t t) { return ((s * dimBatch) + b) * dimTrgWords + t; };
+ for (size_t b = 0; b < dimBatch; b++) {
+ for (size_t t = 0; t < dimTrgWords; t++) {
+ for (size_t s = 0; s < dimSrcWords; s++)
+ ABORT_IF(locate(s, b, t) != batch->locateInGuidedAlignments(b, s, t), "locate() and locateInGuidedAlignments() differ??");
+ // renormalize the alignment such that it sums up to 1
+ float sum = 0;
+ for (size_t s = 0; s < dimSrcWords; s++)
+ sum += srcMask[srcBatch->locate(b, s)] * normalizedAlignment[locate(s, b, t)]; // these values are 0 or 1
+ if (sum != 0 && sum != 1)
+ for (size_t s = 0; s < dimSrcWords; s++)
+ normalizedAlignment[locate(s, b, t)] /= sum;
+ }
+ }
+ auto alignment = constant_like(attention, std::move(normalizedAlignment));
+ alignmentLoss = -sum(flatten(alignment * log(attention + epsilon)));
+ numLabels = batch->back()->batchWords();
+ ABORT_IF(numLabels > shape.elements() / shape[-3], "Num labels of guided alignment cost is off??");
} else {
- ABORT("Unknown alignment cost type");
+ auto alignment = constant_like(attention, batch->getGuidedAlignment());
+ if(guidedLossType == "mse")
+ alignmentLoss = sum(flatten(square(attention - alignment))) / 2.f;
+ else if(guidedLossType == "mult") // @TODO: I don't know what this criterion is for. Can we remove it?
+ alignmentLoss = -log(sum(flatten(attention * alignment)) + epsilon);
+ else
+ ABORT("Unknown alignment cost type: {}", guidedLossType);
+ // every position is a label as they should all agree
+ // @TODO: there should be positional masking here ... on the other hand, positions that are not
+ // in a sentence should always agree (both being 0). Lack of masking affects label count only which is
+ // probably negligible?
+ numLabels = shape.elements();
}
- float guidedScalar = options->get<float>("guided-alignment-weight");
- return guidedScalar * alnCost;
+ // Create label node, also weigh by scalar so labels and cost are in the same domain.
+ // Fractional label counts are OK. But only if combined as "sum".
+ // @TODO: It is ugly to check the multi-loss type here, but doing this right requires
+ // a substantial rewrite of the multi-loss architecture, which is planned anyways.
+ std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
+ if (multiLossType == "sum") // sum of sums
+ return RationalLoss(guidedLossWeight * alignmentLoss, guidedLossWeight * numLabels);
+ else
+ return RationalLoss(guidedLossWeight * alignmentLoss, (float)numLabels);
}
+
} // namespace marian
diff --git a/src/layers/loss.cpp b/src/layers/loss.cpp
index 03b79682..408f6fc7 100755
--- a/src/layers/loss.cpp
+++ b/src/layers/loss.cpp
@@ -2,122 +2,38 @@
namespace marian {
-Ptr<LossBase> LossFactory(Ptr<Options> options, bool inference) {
+// @TODO, simplify this. Currently here for back-compat
+Ptr<LabelwiseLoss> newLoss(Ptr<Options> options, bool inference) {
float smoothing = inference ? 0.f : options->get<float>("label-smoothing");
+ float factorWeight = options->get<float>("factor-weight", 1.0f);
std::string costType = options->get<std::string>("cost-type", "ce-mean");
- if(costType == "ce-mean" || costType == "cross-entropy") {
- return New<CrossEntropyMeanLoss>(smoothing);
- } else if(costType == "ce-mean-words") {
- return New<CrossEntropyMeanWordsLoss>(smoothing);
- } else if(costType == "ce-sum") {
- return New<CrossEntropySumLoss>(smoothing);
- } else if(costType == "perplexity") {
- return New<PerplexityLoss>(smoothing);
- } else if(costType == "ce-rescore") {
- return New<CrossEntropyRescoreLoss>(smoothing);
- } else if(costType == "ce-rescore-mean") {
- return New<CrossEntropyRescoreMeanLoss>(smoothing);
- } else { // same as ce-mean
- return New<CrossEntropyMeanLoss>(smoothing);
+ bool unlikelihood = options->get<bool>("unlikelihood-loss", false);
+
+ if(costType == "ce-rescore") { // returns per-batch-item scores (while ce-mean reduces over batch)
+ return New<RescorerLoss>();
+ } else if(unlikelihood) {
+ ABORT_IF(!options->hasAndNotEmpty("data-weighting")
+ && options->get<std::string>("data-weighting-type") != "word",
+ "Unlikelihood loss training requires error annotation in form of per-target-label scores");
+ return New<SequenceUnlikelihoodLoss>(smoothing, factorWeight); // this is a mix of CE-loss and unlikelihood less depending on values given for data-weighting
+ } else { // same as ce-mean --@TODO: better check all allowed values, and fail for invalid ones. E.g. what about ce-sum?
+ return New<CrossEntropyLoss>(smoothing, factorWeight);
}
}
-Expr LossBase::getCrossEntropy(Expr logits,
- Expr indices,
- Expr mask,
- Expr weights) {
- auto ce = cross_entropy(logits, indices);
+// see loss.h for detailed explanations of each class
+Ptr<MultiRationalLoss> newMultiLoss(Ptr<Options> options) {
+ std::string multiLossType = options->get<std::string>("multi-loss-type", "sum");
+ if(multiLossType == "sum") // sum of sums
+ return New<SumMultiRationalLoss>();
+ else if(multiLossType == "scaled") // sum of scaled sums, first element is reference scale
+ return New<ScaledMultiRationalLoss>();
+ else if(multiLossType == "mean") // sum of means
+ return New<MeanMultiRationalLoss>();
+ else
+ ABORT("Unknown multi-loss-type {}", multiLossType);
- if(smoothing_ > 0) {
- // @TODO: add this to CE kernels instead
- auto ceq = mean(logsoftmax(logits), /*axis=*/ -1);
- ce = (1 - smoothing_) * ce - smoothing_ * ceq;
- }
-
- if(mask)
- ce = ce * mask;
-
- if(weights)
- ce = ce * weights;
-
- return ce;
-}
-
-Expr CrossEntropyMeanLoss::getCost(Expr logits,
- Expr indices,
- Expr mask,
- Expr weights) {
- auto ce = getCrossEntropy(logits, indices, mask, weights);
- // Time axis (words): -3
- // Batch axis (sentences): -2
- // if(weights) {
- // return sum(sum(ce, /*axis =*/ -3) /*axis =*/ -2);
- // / sum(mean(mask * weights, /*axis =*/ -3) /*axis =*/ -2);
- // }
- // else {
- return mean(sum(ce, /*axis =*/ -3), /*axis =*/ -2);
- // }
-}
-
-Expr CrossEntropyMeanWordsLoss::getCost(Expr logits,
- Expr indices,
- Expr mask,
- Expr weights) {
- auto ce = getCrossEntropy(logits, indices, mask, weights);
- // if(weights) {
- // return (sum(sum(ce, /*axis =*/ -3), /*axis =*/ -2)
- // / sum(sum(mask * weights, /*axis =*/ -3), /*axis =*/ -2));
- // }
- // else {
- return sum(sum(ce, /*axis =*/ -3), /*axis =*/ -2) // sum CE over all words in the batch
- / sum(sum(mask, /*axis =*/ -3), /*axis =*/ -2); // divide by number of words (sum over mask)
- // }
-}
-
-Expr CrossEntropySumLoss::getCost(Expr logits,
- Expr indices,
- Expr mask,
- Expr weights) {
- auto ce = getCrossEntropy(logits, indices, mask, weights);
- // if(weights) {
- // return sum(sum(ce, /*axis =*/ -3), /*axis =*/ -2)
- // / mean(mean(mask * weights, /*axis =*/ -3), /*axis =*/ -2);
- // }
- // else {
- return sum(sum(ce, /*axis =*/ -3), /*axis =*/ -2);
- // }
-}
-
-Expr PerplexityLoss::getCost(Expr logits,
- Expr indices,
- Expr mask,
- Expr weights) {
- auto ce = getCrossEntropy(logits, indices, mask, weights);
- // if(weights) {
- // return exp(sum(sum(ce, /*axis =*/ -3), /*axis =*/ -2)
- // / sum(sum(mask * weights, /*axis =*/ -3), /*axis =*/ -2));
- // }
- // else {
- return exp(sum(sum(ce, /*axis =*/ -3), /*axis =*/ -2) // sum CE over all words in the batch
- / sum(sum(mask, /*axis =*/ -3), /*axis =*/ -2)); // divide by number of words (sum over mask)
- // }
-}
-
-Expr CrossEntropyRescoreLoss::getCost(Expr logits,
- Expr indices,
- Expr mask,
- Expr weights) {
- auto ce = getCrossEntropy(logits, indices, mask, weights);
- return -sum(ce, /*axis =*/ -3);
-}
-
-Expr CrossEntropyRescoreMeanLoss::getCost(Expr logits,
- Expr indices,
- Expr mask,
- Expr weights) {
- auto ce = getCrossEntropy(logits, indices, mask, weights);
- // divide by number of words in sentence
- return -sum(ce, /*axis =*/ -3) / sum(mask, /*axis =*/ -3);
+ return nullptr;
}
} // namespace marian
diff --git a/src/layers/loss.h b/src/layers/loss.h
index ebf71147..43e89c1d 100644..100755
--- a/src/layers/loss.h
+++ b/src/layers/loss.h
@@ -1,76 +1,447 @@
#pragma once
-#include "marian.h"
+#include "graph/expression_operators.h"
+#include "layers/generic.h" // for Logits (Frank's factor hack)
+#include "data/types.h"
namespace marian {
-class LossBase {
+
+/**
+ * We represent loss as pair of expressions, where loss_ is usually a sum
+ * of all accumulated loss values per label and count_ is the total number
+ * of labels over which the loss was collected.
+ *
+ * These two values can then be used to represent various cost variants -
+ * for instance label-wise cross-entropy or perplexity. Optimization is
+ * only performed with regard to the summed loss_.
+ *
+ * Since both, loss_ and count_ are dynamic graph nodes they can be further
+ * combined into larger structures. See multi-objective losses below.
+ *
+ * @TODO: This seems also used to hold a pair of (logits, mask)
+ */
+class RationalLoss {
protected:
- float smoothing_;
+ Expr loss_; // numerator
+ Expr count_; // denominator
+
+ RationalLoss() = default; // protected
+
+public:
+ RationalLoss(Expr loss, Expr count)
+ : loss_(loss), count_(count) {}
+
+ RationalLoss(Expr loss, float count)
+ : loss_(loss),
+ count_(constant_like(loss, inits::fromValue(count))) {}
+
+ RationalLoss(const RationalLoss& other)
+ : loss_(other.loss_), count_(other.count_) {}
+
+ virtual ~RationalLoss() = default;
+
+ Expr loss() const { return loss_; }
+
+ // @TODO: remove this function, as it does not add too much value over loss()->get(...)
+ template <typename T>
+ void loss(std::vector<T>& losses) const {
+ ABORT_IF(!loss_, "Loss has not been defined");
+ loss_->val()->get(losses);
+ }
+
+ template <typename T>
+ T loss() const { // this will fail if loss is not a single value
+ ABORT_IF(!loss_, "Loss has not been defined");
+ return loss_->val()->scalar<T>();
+ }
+
+ Expr count() const { return count_; }
+
+ // @TODO: remove this function, as it does not add too much value over count()->get(...)
+ template <typename T>
+ void count(std::vector<T>& labels) const {
+ ABORT_IF(!count_, "Labels have not been defined");
+ count_->val()->get(labels);
+ }
+
+ template <typename T>
+ T count() const { // this will fail if loss is not a single value
+ ABORT_IF(!count_, "Labels have not been defined");
+ return count_->val()->scalar<T>();
+ }
+
+ // @TODO: add a funtion for returning maybe ratio?
+
+ size_t size() const {
+ ABORT_IF(!count_, "Labels have not been defined");
+ return count_->shape().elements();
+ }
+};
+
+/**
+ * POD for accumulating loss values after forward/backward used in
+ * Scheduler for updating statistics. This can only be used after a
+ * successful forward step in a computation graph that owns the assigned
+ * RationalLoss object.
+ */
+struct StaticLoss {
+ float loss; // numerator
+ float count; // denominator
+
+ StaticLoss() : loss(0.f), count(0.f) {}
+
+ StaticLoss(const RationalLoss& dynamic)
+ : loss(dynamic.loss<float>()), count(dynamic.count<float>()) {}
+
+ StaticLoss& operator +=(const StaticLoss& other) {
+ loss = loss + other.loss;
+ count = count + other.count;
+ return *this;
+ }
+
+ void reset() {
+ loss = 0.f;
+ count = 0.f;
+ }
+};
+
+/**
+ * @brief Base class for multi-objective losses
+ * Base class for multi-objective losses which is a list of RationalLoss
+ * but also defines how to accumulate that list into a single RationalLoss
+ */
+class MultiRationalLoss : public RationalLoss {
+protected:
+ std::vector<RationalLoss> partialLosses_;
+
+ /**
+ * @brief Accumulation rule for losses
+ * In the default case this would just be a sum, see SumMultiRationalLoss, but there are
+ * special cases like ScaledMultiRationalLoss (scale other loses according to first label count)
+ * or MeanMultiRationalLoss (sum of means) where the accumulation is more complex.
+ */
+ virtual Expr accumulateLoss(const RationalLoss& current) = 0;
+
+ /**
+ * @brief Accumulation rule for labels
+ * Similar as above, the naive case is summation, but for instance MeanMultiRationalLoss
+ * is including all label counts in the loss hence label counts are always just 1 which is
+ * passed through without summation or other modifications.
+ */
+ virtual Expr accumulateCount(const RationalLoss& current) = 0;
+
+public:
+ MultiRationalLoss() : RationalLoss() {}
+
+ MultiRationalLoss(const RationalLoss& rl) : RationalLoss() {
+ push_back(rl);
+ }
+
+ virtual void push_back(const RationalLoss& current) {
+ loss_ = accumulateLoss(current);
+ count_ = accumulateCount(current);
+ partialLosses_.push_back(current);
+ }
+
+ const RationalLoss& operator[](size_t i) {
+ return partialLosses_[i];
+ }
+
+ auto begin() -> decltype(partialLosses_.begin()) const {
+ return partialLosses_.begin();
+ }
+
+ auto end() -> decltype(partialLosses_.end()) const {
+ return partialLosses_.end();
+ }
+
+ size_t size() const {
+ return partialLosses_.size();
+ }
+
+};
+
+/**
+ * @brief Simple sum of losses.
+ * Using this makes sense when the two loss types are similar in scale and
+ * number of labels. For instance two decoders over similarly sized vocabularies
+ */
+class SumMultiRationalLoss : public MultiRationalLoss {
+private:
+ virtual Expr accumulateLoss(const RationalLoss& current) override {
+ if(loss_)
+ return loss_ + current.loss();
+ else
+ return current.loss();
+ }
+
+ virtual Expr accumulateCount(const RationalLoss& current) override {
+ if(count_)
+ return count_ + current.count();
+ else
+ return current.count();
+ }
public:
- explicit LossBase(float smoothing = 0) : smoothing_(smoothing){};
-
- Expr getCrossEntropy(Expr logits, Expr indices, Expr mask, Expr weights);
- virtual Expr getCost(Expr logits,
- Expr indices,
- Expr mask,
- Expr weights = nullptr)
- = 0;
+ SumMultiRationalLoss() : MultiRationalLoss() {}
+ SumMultiRationalLoss(const RationalLoss& rl) : MultiRationalLoss(rl) {}
};
-/*
- * @brief The cross entropy loss function
+/**
+ * @brief Scaled sum of losses.
+ * This can weigh losses equally by choosing the first loss_0 as a reference
+ * and scaling all remaining losses loss_i by count_0 / count_i. Labels are
+ * summed up by the same rule. By this we simulate a sum of losses at similar
+ * scales. Dividing by scaled label counts yields a value close to an equally
+ * weighted sum of means.
*
- * A sum over words and average over sentences
+ * L = sum_i^N L_i + N/M sum_j^M L_j
+ *
+ * We set labels to N. When reporting L/N this is equvalient to sum of means.
+ * Compare to sum of means below where N is factored into the loss, but labels
+ * are set to 1.
*/
-class CrossEntropyMeanLoss : public LossBase {
+class ScaledMultiRationalLoss : public MultiRationalLoss {
+private:
+ virtual Expr accumulateLoss(const RationalLoss& current) override {
+ if(loss_) {
+ const auto& first = partialLosses_.front();
+ return loss_ + current.loss() * first.count() / current.count(); // scale up/down to match scale of first loss
+ } else {
+ return current.loss(); // first reference loss, keeps to scale with this one
+ }
+ }
+
+ virtual Expr accumulateCount(const RationalLoss& current) override {
+ if(count_) {
+ return count_; // Keep first label count // or: count_ + first.count() / current.count();
+ } else {
+ return current.count(); // This is the first loss
+ }
+ }
+
public:
- explicit CrossEntropyMeanLoss(float smoothing = 0) : LossBase(smoothing){};
- Expr getCost(Expr logits, Expr indices, Expr mask, Expr weights) override;
+ ScaledMultiRationalLoss() : MultiRationalLoss() {}
+ ScaledMultiRationalLoss(const RationalLoss& rl) : MultiRationalLoss(rl) {}
};
-/*
- * @brief The cross entropy loss function as an average over target tokens
+/**
+ * @brief Sum of mean losses.
+ * Not really a rational loss as labels are factored into loss. Contribution of
+ * losses is equal, same as for ScaledMultiRationalLoss, just divided by different
+ * number of labels. See:
+ *
+ * L = (1/N sum_i^N L_i + 1/M sum_j^M L_j) = (sum_i^N L_i + N/M sum_j^M L_j) / N
+ *
+ * We set labels to 1. During reporting, we would see the same numbers, but gradients
+ * are scaled differently which may result in different learning curves.
*/
-class CrossEntropyMeanWordsLoss : public LossBase {
+class MeanMultiRationalLoss : public MultiRationalLoss {
+private:
+ virtual Expr accumulateLoss(const RationalLoss& current) override {
+ if(loss_)
+ return loss_ + current.loss() / current.count();
+ else
+ return current.loss() / current.count();
+ }
+
+ virtual Expr accumulateCount(const RationalLoss& current) override {
+ if(count_)
+ return count_; // keep the existing '1'
+ else
+ return current.count()->graph()->ones({1}, current.loss()->value_type()); // just '1' as labels are factored into loss_
+ }
+
public:
- explicit CrossEntropyMeanWordsLoss(float smoothing = 0)
- : LossBase(smoothing){};
- Expr getCost(Expr logits, Expr indices, Expr mask, Expr weights) override;
+ MeanMultiRationalLoss() : MultiRationalLoss() {}
+ MeanMultiRationalLoss(const RationalLoss& rl) : MultiRationalLoss(rl) {}
};
-/*
- * @brief The cross entropy loss function as a sum over target tokens
+/**
+ * @brief Factory for multi-objective rational loss functions
*/
-class CrossEntropySumLoss : public LossBase {
+Ptr<MultiRationalLoss> newMultiLoss(Ptr<Options> options);
+
+//***********************************************************************************//
+// This needs to be refactored. Currently easiest route for backwards compat, but
+// still feels somewhat hacky.
+
+/**
+ * @brief Computes loss per given groundtruth label and then reduces to RationalLoss
+ */
+class LabelwiseLoss {
+protected:
+ std::vector<int> axes_;
+
+ virtual Expr compute(Logits logits, const Words& labels,
+ Expr mask = nullptr, Expr labelWeights = nullptr) = 0;
+
+ // label counts are available, reduce together with loss to obtain counts
+ RationalLoss reduce(Expr loss, Expr labels) {
+ ABORT_IF(!loss, "Loss has not been computed");
+ ABORT_IF(!labels, "Labels have not been computed");
+
+ Expr lossSum = cast(loss, Type::float32); // accumulate in float32
+ Expr labelsSum = cast(labels, Type::float32); // accumulate in float32
+ for(int i = 0; i < axes_.size(); ++i) {
+ lossSum = sum(lossSum, axes_[i]);
+ labelsSum = sum(labelsSum, axes_[i]);
+ }
+
+ return RationalLoss(lossSum, labelsSum);
+ }
+
+ // label counts are not available, assume every element of tensor corresponds to label count 1
+ RationalLoss reduce(Expr loss) {
+ ABORT_IF(!loss, "Loss has not been computed");
+
+ Expr lossSum = cast(loss, Type::float32);
+ for(int i = 0; i < axes_.size(); ++i)
+ lossSum = sum(lossSum, axes_[i]);
+
+ // reduction factor tells how over how many labels we reduced in total.
+ float reducedLabels = (float)loss->shape().elements() / (float)lossSum->shape().elements();
+ return RationalLoss(lossSum, reducedLabels);
+ }
+
public:
- explicit CrossEntropySumLoss(float smoothing = 0) : LossBase(smoothing){};
- Expr getCost(Expr logits, Expr indices, Expr mask, Expr weights) override;
+ LabelwiseLoss(const std::vector<int>& axes)
+ : axes_(axes) { }
+
+ virtual RationalLoss apply(Logits logits, const Words& labels,
+ Expr mask = nullptr, Expr labelWeights = nullptr) {
+ Expr loss = compute(logits, labels, mask, labelWeights);
+
+ if(mask)
+ return reduce(loss, mask); // mask can be used as element-wise label count with broadcasting
+ else
+ return reduce(loss); // we have no mask, assume all items are labels
+ }
};
-/*
- * @brief The perplexity loss function
+/**
+ * @brief Cross entropy loss across last axis, summed up over batch and time dimensions
*/
-class PerplexityLoss : public LossBase {
+class CrossEntropyLoss : public LabelwiseLoss {
public:
- explicit PerplexityLoss(float smoothing = 0) : LossBase(smoothing){};
- Expr getCost(Expr logits, Expr indices, Expr mask, Expr weights) override;
+ CrossEntropyLoss(float labelSmoothing, float factorWeight)
+ : CrossEntropyLoss(/*axes=*/{-2, -3}, labelSmoothing, factorWeight) {} // cross-entropy already reduces over axis -1
+
+ CrossEntropyLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight)
+ : LabelwiseLoss(axes), // cross-entropy already reduces over axis -1
+ labelSmoothing_(labelSmoothing), factorWeight_(factorWeight) {}
+
+ virtual ~CrossEntropyLoss() {}
+protected:
+ float labelSmoothing_; // interpolation factor for label smoothing, see below
+ float factorWeight_; // give extra weight to factors
+
+ virtual Expr compute(Logits logits, const Words& labels,
+ Expr mask = nullptr, Expr labelWeights = nullptr) override {
+ // logits may be factored; in that case, the getLoss() function computes one loss for each, and sums them up
+ int inFactor = false;
+ auto ce = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {
+ logits = atleast_3d(logits); // we always assuma a time and batch dimension exists.
+ // for bert training or classification the time dimension is lot.
+ // Here safeguard against 2d classifier output, adds 1 on the left, non-op.
+ Expr ce = cast(cross_entropy(logits, indices), Type::float32);
+ if (inFactor && factorWeight_ != 1.0f) {
+ LOG_ONCE(info, "scaling factor losses with weight {}", factorWeight_);
+ ce = ce * factorWeight_;
+ }
+ if (labelSmoothing_ > 0) {
+ // ce = -sum_i y^_i log y_i(h)
+ // with smoothing:
+ // ce' = -sum_i ((1-labelSmoothing_) y^_i + labelSmoothing_/N) log y_i(h)
+ // = -(1-labelSmoothing_) sum_i y^_i log y_i(h) - labelSmoothing_ mean_i log y_i(h)
+ // = (1-labelSmoothing_) ce - labelSmoothing_ mean_i log y_i(h)
+ auto logits32 = cast(logits, Type::float32);
+ auto ceqNeg = mean(logits32, /*axis=*/ -1) - logsumexp(logits32, /*axis=*/ -1);
+ ce = (1 - labelSmoothing_) * ce - labelSmoothing_ * ceqNeg;
+ //ce = ce - labelSmoothing_ * (ce + ceqNeg); // writing it this way saves one op :)
+ inFactor = true;
+ }
+ return ce;
+ });
+
+ if(mask)
+ ce = ce * cast(mask, Type::float32);
+
+ if(labelWeights) {
+ // We currently do not know how to use target factors and word-level label weights together
+ bool wordlevel = labelWeights->shape()[-3] > 1; // Time-dimension is not trivially 1, hence we have word-level weights.
+ ABORT_IF(wordlevel && logits.getNumFactorGroups() > 1, "CE loss with word-level label weights is not implemented for factors");
+ ce = ce * cast(labelWeights, Type::float32);
+ }
+
+ return ce;
+ }
};
-/*
- * @brief The cross entropy loss function that keeps sentence-level costs
+
+/**
+ * @brief Unlikelihood loss across last axis, summed up over batch and time dimensions. This is an
+ * implementation of sequence-level unlikelihood loss from https://arxiv.org/abs/1908.04319.
+ * We rely on word-level label weights where 1 is correct and 0 is marking an error. If there are not
+ * zeros for a sentence it going to be trained with normal CE loss if there is at least one 0 it is going
+ * to flip over to use SUL for that sentence to penalize the selected word.
+ *
+ * SUL is implemented as:
+ * -log(gather(1 - softmax(logits), -1, indices))
+ *
+ * Factors are currently not supported.
*/
-class CrossEntropyRescoreLoss : public LossBase {
+class SequenceUnlikelihoodLoss : public CrossEntropyLoss {
public:
- explicit CrossEntropyRescoreLoss(float smoothing = 0) : LossBase(smoothing){};
- Expr getCost(Expr logits, Expr indices, Expr mask, Expr weights) override;
+ SequenceUnlikelihoodLoss(float labelSmoothing, float factorWeight)
+ : CrossEntropyLoss(labelSmoothing, factorWeight) {} // cross-entropy already reduces over axis -1
+
+ SequenceUnlikelihoodLoss(const std::vector<int>& axes, float labelSmoothing, float factorWeight)
+ : CrossEntropyLoss(axes, labelSmoothing, factorWeight) {}
+
+protected:
+ virtual Expr compute(Logits logits, const Words& labels,
+ Expr mask = nullptr, Expr labelWeights = nullptr) override {
+ auto ce = CrossEntropyLoss::compute(logits, labels, mask, /*labelWeights=*/nullptr); // don't pass label-weights to CE
+ if(!labelWeights)
+ return ce; // for validation, @TODO: maybe put rather abort or LOG_ONCE(warn, ...)?
+
+ // We currently do not know how to use target factors and word-level label weights together
+ ABORT_IF(logits.getNumFactorGroups() > 1, "Unlikelihood loss is not implemented for factors");
+
+ ABORT_IF(!mask, "mask is required"); // @TODO: check this, it seems weights for padding are by default 1, which would make this obsolete.
+ // use label weights, where 1 is GOOD and 0 is BAD. After inversion here, now 1 marks, mask again to eliminate padding (might be obsolete)
+ auto errorMask = (1.f - cast(labelWeights, Type::float32)) * cast(mask, Type::float32);
+
+ auto ceUl = logits.applyLossFunction(labels, [&](Expr logits, Expr indices) {
+ return cast(unlikelihood(logits, indices), Type::float32);
+ });
+
+ // compute if want to use CE or UL. If there are no errors train with CE, otherwise train _only on_ the errors with UL. This is the "mixed" training
+ // schedule from https://arxiv.org/abs/1908.04319. Providing labels with or without error scores we can easily switch between CE and UL.
+ auto onlyCe = eq(sum(errorMask, /*axis=*/-3), 0.f); // [1, 1, dimBatch, 1] - equal 1 if no errors are present
+ ceUl = errorMask * ceUl; // don't use for correct label or padding
+
+ auto cost = onlyCe * ce + (1.f - onlyCe) * ceUl; // ce or unlikelihood part are never simultanously used as cost per batch entry
+
+ return cost;
+ }
};
-class CrossEntropyRescoreMeanLoss : public LossBase {
+
+/**
+ * @brief Cross entropy in rescorer used for computing sentences-level log probabilities
+ */
+class RescorerLoss : public CrossEntropyLoss {
public:
- explicit CrossEntropyRescoreMeanLoss(float smoothing = 0) : LossBase(smoothing){};
- Expr getCost(Expr logits, Expr indices, Expr mask, Expr weights) override;
+ // sentence-wise CE, hence reduce only over time axis
+ // This class differs from CrossEntropy in the different 'axes' setting, and that label smoothing is disabled.
+ RescorerLoss() : CrossEntropyLoss(/*axes=*/{-3} /*time axis*/, /*smoothing=*/0.f, /*factorWeight=*/1.0f) {}
};
-Ptr<LossBase> LossFactory(Ptr<Options> options, bool inference);
+/**
+ * @brief Factory for label-wise loss functions
+ */
+Ptr<LabelwiseLoss> newLoss(Ptr<Options> options, bool inference);
+
} // namespace marian
diff --git a/src/layers/weight.cpp b/src/layers/weight.cpp
index ab7fe072..86d55c2e 100755..100644
--- a/src/layers/weight.cpp
+++ b/src/layers/weight.cpp
@@ -3,7 +3,7 @@
namespace marian {
Ptr<WeightingBase> WeightingFactory(Ptr<Options> options) {
- ABORT_IF(!options->has("data-weighting"),
+ ABORT_IF(!options->hasAndNotEmpty("data-weighting"),
"No data-weighting specified in options");
return New<DataWeighting>(options->get<std::string>("data-weighting-type"));
}
@@ -15,8 +15,16 @@ Expr DataWeighting::getWeights(Ptr<ExpressionGraph> graph,
bool sentenceWeighting = weightingType_ == "sentence";
int dimBatch = (int)batch->size();
int dimWords = sentenceWeighting ? 1 : (int)batch->back()->batchWidth();
+
+ // This would abort anyway in fromVector(...), but has clearer error message
+ // here for this particular case
+ ABORT_IF(batch->getDataWeights().size() != dimWords * dimBatch,
+ "Number of sentence/word-level weights ({}) does not match tensor size ({})",
+ batch->getDataWeights().size(), dimWords * dimBatch);
+
auto weights = graph->constant({1, dimWords, dimBatch, 1},
- inits::from_vector(batch->getDataWeights()));
- return weights;
+ inits::fromVector(batch->getDataWeights()));
+ return weights; // [1, dimWords, dimBatch, 1] in case of word-level weights or
+ // [1, 1, dimBatch, 1] in case of sentence-level weights
}
} // namespace marian
diff --git a/src/layers/weight.h b/src/layers/weight.h
index 787c2376..74cf47d6 100755..100644
--- a/src/layers/weight.h
+++ b/src/layers/weight.h
@@ -17,6 +17,7 @@ public:
virtual void debugWeighting(std::vector<float> /*weightedMask*/,
std::vector<float> /*freqMask*/,
Ptr<data::CorpusBatch> /*batch*/){};
+ virtual ~WeightingBase() {}
};
class DataWeighting : public WeightingBase {
diff --git a/src/layers/word2vec_reader.h b/src/layers/word2vec_reader.h
index a7e85592..4bfc6709 100755..100644
--- a/src/layers/word2vec_reader.h
+++ b/src/layers/word2vec_reader.h
@@ -33,12 +33,12 @@ public:
"Unexpected length of embedding vectors");
// Read embedding vectors into a map
- std::unordered_map<Word, std::vector<float>> word2vec;
+ std::unordered_map<WordIndex, std::vector<float>> word2vec;
while(io::getline(embFile, line)) {
values.clear();
utils::split(line, values);
- Word word = std::stoi(values.front());
+ WordIndex word = std::stoi(values.front());
if(word >= (size_t)dimVoc)
continue;
@@ -54,7 +54,7 @@ public:
embs.reserve(dimVoc * dimEmb);
// Populate output vector with embedding
- for(Word word = 0; word < (Word)dimVoc; ++word) {
+ for(WordIndex word = 0; word < (WordIndex)dimVoc; ++word) {
// For words not occuring in the file use uniform distribution
if(word2vec.find(word) == word2vec.end()) {
auto randVals = randomEmbeddings(dimVoc, dimEmb);
@@ -64,13 +64,14 @@ public:
}
}
+ embs.resize(dimVoc * dimEmb, 0); // @TODO: is it correct to zero out the remaining embeddings?
return embs;
}
private:
std::vector<float> randomEmbeddings(int dimVoc, int dimEmb) {
std::vector<float> values;
- values.reserve(dimEmb);
+ values.resize(dimEmb);
// Glorot numal distribution
float scale = sqrtf(2.0f / (dimVoc + dimEmb));
diff --git a/src/marian.h b/src/marian.h
index 3e34a7b7..3e34a7b7 100755..100644
--- a/src/marian.h
+++ b/src/marian.h
diff --git a/src/microsoft/quicksand.cpp b/src/microsoft/quicksand.cpp
index 962bda18..2f7687aa 100755
--- a/src/microsoft/quicksand.cpp
+++ b/src/microsoft/quicksand.cpp
@@ -1,7 +1,7 @@
#include "quicksand.h"
#include "marian.h"
-#ifdef MKL_FOUND
+#if MKL_FOUND
#include "mkl.h"
#endif
@@ -9,6 +9,12 @@
#include "translator/beam_search.h"
#include "translator/scorers.h"
#include "data/alignment.h"
+#include "data/vocab_base.h"
+#include "tensors/cpu/fbgemm/expression_graph_packable.h"
+
+#if USE_FBGEMM
+#include "fbgemm/Utils.h"
+#endif
namespace marian {
@@ -31,6 +37,18 @@ Ptr<Options> newOptions() {
return New<Options>();
}
+class VocabWrapper : public IVocabWrapper {
+ Ptr<Vocab> pImpl_;
+public:
+ VocabWrapper(Ptr<Vocab> vocab) : pImpl_(vocab) {}
+ virtual ~VocabWrapper() {}
+ WordIndex encode(const std::string& word) const override { return (*pImpl_)[word].toWordIndex(); }
+ std::string decode(WordIndex id) const override { return (*pImpl_)[Word::fromWordIndex(id)]; }
+ size_t size() const override { return pImpl_->size(); }
+ void transcodeToShortlistInPlace(WordIndex* ptr, size_t num) const override { pImpl_->transcodeToShortlistInPlace(ptr, num); }
+ Ptr<Vocab> getVocab() const { return pImpl_; }
+};
+
class BeamSearchDecoder : public IBeamSearchDecoder {
private:
Ptr<ExpressionGraph> graph_;
@@ -38,20 +56,26 @@ private:
std::vector<Ptr<Scorer>> scorers_;
+ std::vector<Ptr<Vocab>> vocabs_;
+
public:
BeamSearchDecoder(Ptr<Options> options,
const std::vector<const void*>& ptrs,
- Word eos)
- : IBeamSearchDecoder(options, ptrs, eos) {
-
+ const std::vector<Ptr<IVocabWrapper>>& vocabs)
+ : IBeamSearchDecoder(options, ptrs) {
+
+ // copy the vocabs
+ for (auto vi : vocabs)
+ vocabs_.push_back(std::dynamic_pointer_cast<VocabWrapper>(vi)->getVocab());
+
// setting 16-bit optimization to false for now. Re-enable with better caching or pre-computation
- graph_ = New<ExpressionGraph>(/*inference=*/true, /*optimize=*/false);
+ graph_ = New<ExpressionGraph>(/*inference=*/true);
DeviceId deviceId{0, DeviceType::cpu};
device_ = New<cpu::WrappedDevice>(deviceId);
graph_->setDevice(deviceId, device_);
-#ifdef MKL_FOUND
+#if MKL_FOUND
mkl_set_num_threads(options->get<int>("mkl-threads", 1));
#endif
@@ -70,9 +94,9 @@ public:
modelOpts->merge(options_);
modelOpts->merge(config);
- std::cerr << modelOpts->str() << std::flush;
+ std::cerr << modelOpts->asYamlString() << std::flush; // @TODO: take a look at why this is even here.
- auto encdec = models::from_options(modelOpts, models::usage::translation);
+ auto encdec = models::createModelFromOptions(modelOpts, models::usage::translation);
if(io::isBin(models[i]) && ptrs_[i] != nullptr) {
// if file ends in *.bin and has been mapped by QuickSAND
@@ -94,7 +118,7 @@ public:
QSNBestBatch decode(const QSBatch& qsBatch,
size_t maxLength,
- const std::unordered_set<Word>& shortlist) override {
+ const std::unordered_set<WordIndex>& shortlist) override {
if(shortlist.size() > 0) {
auto shortListGen = New<data::FakeShortlistGenerator>(shortlist);
for(auto scorer : scorers_)
@@ -103,40 +127,40 @@ public:
// form source batch, by interleaving the words over sentences in the batch, and setting the mask
size_t batchSize = qsBatch.size();
- auto subBatch = New<data::SubBatch>(batchSize, maxLength, nullptr);
+ auto subBatch = New<data::SubBatch>(batchSize, maxLength, vocabs_[0]);
for(size_t i = 0; i < maxLength; ++i) {
for(size_t j = 0; j < batchSize; ++j) {
const auto& sent = qsBatch[j];
if(i < sent.size()) {
size_t idx = i * batchSize + j;
- subBatch->data()[idx] = (unsigned int)sent[i];
+ subBatch->data()[idx] = marian::Word::fromWordIndex(sent[i]);
subBatch->mask()[idx] = 1;
}
}
}
- std::vector<Ptr<data::SubBatch>> subBatches;
- subBatches.push_back(subBatch);
+ auto tgtSubBatch = New<data::SubBatch>(batchSize, 0, vocabs_[1]); // only holds a vocab, but data is dummy
+ std::vector<Ptr<data::SubBatch>> subBatches{ subBatch, tgtSubBatch };
std::vector<size_t> sentIds(batchSize, 0);
auto batch = New<data::CorpusBatch>(subBatches);
batch->setSentenceIds(sentIds);
// decode
- auto search = New<BeamSearch>(options_, scorers_, eos_);
+ auto search = New<BeamSearch>(options_, scorers_, vocabs_[1]);
Histories histories = search->search(graph_, batch);
// convert to QuickSAND format
QSNBestBatch qsNbestBatch;
for(const auto& history : histories) { // loop over batch entries
QSNBest qsNbest;
- NBestList nbestHyps = history->NBest(SIZE_MAX); // request as many N as we have
+ NBestList nbestHyps = history->nBest(SIZE_MAX); // request as many N as we have
for (const Result& result : nbestHyps) { // loop over N-best entries
// get hypothesis word sequence and normalized sentence score
auto words = std::get<0>(result);
auto score = std::get<2>(result);
// determine alignment if present
AlignmentSets alignmentSets;
- if (options_->has("alignment"))
+ if (options_->hasAndNotEmpty("alignment"))
{
float alignmentThreshold;
auto alignment = options_->get<std::string>("alignment"); // @TODO: this logic now exists three times in Marian
@@ -147,14 +171,14 @@ public:
else
alignmentThreshold = std::max(std::stof(alignment), 0.f);
auto hyp = std::get<1>(result);
- data::WordAlignment align = data::ConvertSoftAlignToHardAlign(hyp->TracebackAlignment(), alignmentThreshold);
+ data::WordAlignment align = data::ConvertSoftAlignToHardAlign(hyp->tracebackAlignment(), alignmentThreshold);
// convert to QuickSAND format
alignmentSets.resize(words.size());
- for (const auto& p : align) // @TODO: Does the feature_model param max_alignment_links apply here?
- alignmentSets[p.tgtPos].insert({p.srcPos, p.prob});
+ for (const auto& p : align)
+ alignmentSets[p.tgtPos].insert({p.srcPos, p.prob}); // [trgPos] -> {(srcPos, P(srcPos|trgPos))}
}
// form hypothesis to return
- qsNbest.emplace_back(words, std::move(alignmentSets), score);
+ qsNbest.emplace_back(toWordIndexVector(words), std::move(alignmentSets), score);
}
qsNbestBatch.push_back(qsNbest);
}
@@ -165,8 +189,88 @@ public:
Ptr<IBeamSearchDecoder> newDecoder(Ptr<Options> options,
const std::vector<const void*>& ptrs,
- Word eos) {
- return New<BeamSearchDecoder>(options, ptrs, eos);
+ const std::vector<Ptr<IVocabWrapper>>& vocabs,
+ WordIndex eosDummy) { // @TODO: remove this parameter
+ marian::setThrowExceptionOnAbort(true); // globally defined to throw now
+ ABORT_IF(marian::Word::fromWordIndex(eosDummy) != std::dynamic_pointer_cast<VocabWrapper>(vocabs[1])->getVocab()->getEosId(), "Inconsistent eos vs. vocabs_[1]");
+
+ return New<BeamSearchDecoder>(options, ptrs, vocabs/*, eos*/);
+}
+
+std::vector<Ptr<IVocabWrapper>> loadVocabs(const std::vector<std::string>& vocabPaths) {
+ std::vector<Ptr<IVocabWrapper>> res(vocabPaths.size());
+ for (size_t i = 0; i < vocabPaths.size(); i++) {
+ if (i > 0 && vocabPaths[i] == vocabPaths[i-1]) {
+ res[i] = res[i-1];
+ LOG(info, "[data] Input {} sharing vocabulary with input {}", i, i-1);
+ }
+ else {
+ auto vocab = New<Vocab>(New<Options>(), i); // (empty options, since they are only used for creating vocabs)
+ auto size = vocab->load(vocabPaths[i]);
+ LOG(info, "[data] Loaded vocabulary size for input {} of size {}", i, size);
+ res[i] = New<VocabWrapper>(vocab);
+ }
+ }
+ return res;
+}
+
+// query CPU AVX version
+DecoderCpuAvxVersion getCpuAvxVersion() {
+#if USE_FBGEMM
+ // Default value is AVX
+ DecoderCpuAvxVersion cpuAvxVer = DecoderCpuAvxVersion::AVX;
+ if (fbgemm::fbgemmHasAvx512Support())
+ cpuAvxVer = DecoderCpuAvxVersion::AVX512;
+ else if (fbgemm::fbgemmHasAvx2Support())
+ cpuAvxVer = DecoderCpuAvxVersion::AVX2;
+
+ return cpuAvxVer;
+#else
+ // Default value is AVX
+ return DecoderCpuAvxVersion::AVX;
+#endif
+}
+
+DecoderCpuAvxVersion parseCpuAvxVersion(std::string name) {
+ if (name == "avx") {
+ return DecoderCpuAvxVersion::AVX;
+ } else if (name == "avx2") {
+ return DecoderCpuAvxVersion::AVX2;
+ } else if (name == "avx512") {
+ return DecoderCpuAvxVersion::AVX512;
+ } else {
+ ABORT("Unknown CPU Instruction Set: {}", name);
+ return DecoderCpuAvxVersion::AVX;
+ }
+}
+
+// @TODO: clean-up this code and unify with marian-conv. The targetPrec parameter is not clear enought etc.
+bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec) {
+ std::cout << "Converting from: " << inputFile << ", to: " << outputFile << std::endl;
+
+ YAML::Node config;
+ std::stringstream configStr;
+ marian::io::getYamlFromModel(config, "special:model.yml", inputFile);
+ configStr << config;
+
+ auto graph = New<ExpressionGraphPackable>();
+ graph->setDevice(CPU0);
+ graph->getBackend()->setOptimized(false);
+
+ graph->load(inputFile);
+ graph->forward();
+ auto saveGemmType = Type::float32;
+ if (targetPrec == 16)
+ saveGemmType = Type::packed16;
+ else if (targetPrec == 8)
+ saveGemmType = Type::packed8avx2; // We currently use avx2 by default.
+
+ // added a flag if the weights needs to be packed or not
+ graph->packAndSave(outputFile, configStr.str(), saveGemmType);
+
+ std::cout << "Conversion Finished." << std::endl;
+
+ return true;
}
} // namespace quicksand
diff --git a/src/microsoft/quicksand.h b/src/microsoft/quicksand.h
index 93308fe9..87de1948 100755
--- a/src/microsoft/quicksand.h
+++ b/src/microsoft/quicksand.h
@@ -16,35 +16,49 @@ class Options;
namespace quicksand {
typedef uint32_t IndexType;
-typedef IndexType Word;
-typedef std::vector<Word> Words;
-typedef std::vector<Words> QSBatch;
-typedef std::vector<std::set<std::pair<size_t, float>>> AlignmentSets; // [tgtPos] -> set of (srcPos, score)
+typedef IndexType WordIndex;
+typedef std::vector<WordIndex> WordIndices;
+typedef std::vector<WordIndices> QSBatch;
+typedef std::vector<std::set<std::pair<size_t, float>>> AlignmentSets; // [trgPos] -> set of (srcPos, P(srcPos|trgPos))
-typedef std::tuple<Words, AlignmentSets, float> QSSentenceWithProb;
+typedef std::tuple<WordIndices, AlignmentSets, float> QSSentenceWithProb;
typedef std::vector<QSSentenceWithProb> QSNBest;
typedef std::vector<QSNBest> QSNBestBatch;
+enum class DecoderCpuAvxVersion {
+ AVX,
+ AVX2,
+ AVX512
+};
+
Ptr<Options> newOptions();
template <class T>
void set(Ptr<Options> options, const std::string& key, const T& value);
+class IVocabWrapper {
+public:
+ virtual WordIndex encode(const std::string& word) const = 0;
+ virtual std::string decode(WordIndex id) const = 0;
+ virtual size_t size() const = 0;
+ virtual void transcodeToShortlistInPlace(WordIndex* ptr, size_t num) const = 0;
+};
+
class IBeamSearchDecoder {
protected:
Ptr<Options> options_;
std::vector<const void*> ptrs_;
- Word eos_;
public:
IBeamSearchDecoder(Ptr<Options> options,
- const std::vector<const void*>& ptrs,
- Word eos)
- : options_(options), ptrs_(ptrs), eos_(eos) {}
+ const std::vector<const void*>& ptrs)
+ : options_(options), ptrs_(ptrs) {}
+
+ virtual ~IBeamSearchDecoder() {}
virtual QSNBestBatch decode(const QSBatch& qsBatch,
size_t maxLength,
- const std::unordered_set<Word>& shortlist)
+ const std::unordered_set<WordIndex>& shortlist)
= 0;
virtual void setWorkspace(uint8_t* data, size_t size) = 0;
@@ -52,7 +66,17 @@ public:
Ptr<IBeamSearchDecoder> newDecoder(Ptr<Options> options,
const std::vector<const void*>& ptrs,
- Word eos);
+ const std::vector<Ptr<IVocabWrapper>>& vocabs,
+ WordIndex eos/*dummy --@TODO: remove*/);
+
+// load src and tgt vocabs
+std::vector<Ptr<IVocabWrapper>> loadVocabs(const std::vector<std::string>& vocabPaths);
+
+// query CPU AVX version
+DecoderCpuAvxVersion getCpuAvxVersion();
+DecoderCpuAvxVersion parseCpuAvxVersion(std::string name);
+
+bool convertModel(std::string inputFile, std::string outputFile, int32_t targetPrec);
} // namespace quicksand
} // namespace marian
diff --git a/src/models/amun.h b/src/models/amun.h
index 35b65206..811d130e 100755
--- a/src/models/amun.h
+++ b/src/models/amun.h
@@ -8,30 +8,30 @@ namespace marian {
class Amun : public EncoderDecoder {
public:
- Amun(Ptr<Options> options) : EncoderDecoder(options) {
+ Amun(Ptr<ExpressionGraph> graph, Ptr<Options> options) : EncoderDecoder(graph, options) {
ABORT_IF(opt<int>("enc-depth") > 1,
- "--type amun does not currently support multiple encoder "
+ "--type amun does not support multiple encoder "
"layers, use --type s2s");
ABORT_IF(opt<int>("enc-cell-depth") > 1,
- "--type amun does not currently support stacked encoder "
+ "--type amun does not support stacked encoder "
"cells, use --type s2s");
ABORT_IF(opt<bool>("skip"),
- "--type amun does not currently support skip connections, "
+ "--type amun does not support skip connections, "
"use --type s2s");
ABORT_IF(opt<int>("dec-depth") > 1,
- "--type amun does not currently support multiple decoder "
+ "--type amun does not support multiple decoder "
"layers, use --type s2s");
ABORT_IF(opt<int>("dec-cell-base-depth") != 2,
- "--type amun does not currently support multiple decoder "
+ "--type amun does not support multiple decoder "
"base cells, use --type s2s");
ABORT_IF(opt<int>("dec-cell-high-depth") > 1,
- "--type amun does not currently support multiple decoder "
+ "--type amun does not support multiple decoder "
"high cells, use --type s2s");
ABORT_IF(opt<std::string>("enc-cell") != "gru",
- "--type amun does not currently support other rnn cells than gru, "
+ "--type amun does not support other rnn cells than gru, "
"use --type s2s");
ABORT_IF(opt<std::string>("dec-cell") != "gru",
- "--type amun does not currently support other rnn cells than gru, "
+ "--type amun does not support other rnn cells than gru, "
"use --type s2s");
}
@@ -92,15 +92,20 @@ public:
LOG(info, "Loading model from {}", name);
// load items from .npz file
auto ioItems = io::loadItems(name);
- // map names and remove a dummy matrix 'decoder_c_tt' from items to avoid creating isolated node
+ // map names and remove a dummy matrices
for(auto it = ioItems.begin(); it != ioItems.end();) {
if(it->name == "decoder_c_tt") {
it = ioItems.erase(it);
+ } else if(it->name == "uidx") {
+ it = ioItems.erase(it);
+ } else if(it->name == "history_errs") {
+ it = ioItems.erase(it);
+ } else {
+ auto pair = nameMap.find(it->name);
+ if(pair != nameMap.end())
+ it->name = pair->second;
+ it++;
}
- auto pair = nameMap.find(it->name);
- if(pair != nameMap.end())
- it->name = pair->second;
- ++it;
}
// load items into the graph
graph->load(ioItems);
@@ -185,16 +190,24 @@ private:
void createAmunConfig(const std::string& name) {
Config::YamlNode amun;
auto vocabs = options_->get<std::vector<std::string>>("vocabs");
- amun["source-vocab"] = vocabs[0];
- amun["target-vocab"] = vocabs[1];
- amun["devices"] = options_->get<std::vector<size_t>>("devices");
- amun["normalize"] = opt<float>("normalize") > 0;
- amun["beam-size"] = opt<size_t>("beam-size");
- amun["relative-paths"] = false;
- amun["scorers"]["F0"]["path"] = name;
+ if(options_->get<bool>("relative-paths")) {
+ amun["relative-paths"] = true;
+ auto dirPath = filesystem::Path{name}.parentPath();
+ amun["source-vocab"] = filesystem::relative(filesystem::Path{vocabs[0]}, dirPath).string();
+ amun["target-vocab"] = filesystem::relative(filesystem::Path{vocabs[1]}, dirPath).string();
+ amun["scorers"]["F0"]["path"] = filesystem::Path{name}.filename().string();
+ } else {
+ amun["relative-paths"] = false;
+ amun["source-vocab"] = vocabs[0];
+ amun["target-vocab"] = vocabs[1];
+ amun["scorers"]["F0"]["path"] = name;
+ }
+
amun["scorers"]["F0"]["type"] = "Nematus";
amun["weights"]["F0"] = 1.0f;
+ amun["normalize"] = opt<float>("normalize") > 0;
+ amun["beam-size"] = opt<size_t>("beam-size");
io::OutputFileStream out(name + ".amun.yml");
out << amun;
diff --git a/src/models/bert.h b/src/models/bert.h
new file mode 100755
index 00000000..51427457
--- /dev/null
+++ b/src/models/bert.h
@@ -0,0 +1,375 @@
+#pragma once
+
+#include "data/corpus_base.h"
+#include "models/encoder_classifier.h"
+#include "models/transformer.h" // @BUGBUG: transformer.h is large and was meant to be compiled separately
+#include "data/rng_engine.h"
+
+namespace marian {
+
+/**
+ * This file contains nearly all BERT-related code and adds BERT-funtionality
+ * on top of existing classes like TansformerEncoder and Classifier.
+ */
+
+namespace data {
+
+/**
+ * BERT-specific mini-batch that computes masking for Masked LM training.
+ * Expects symbols [MASK], [SEP], [CLS] to be present in vocabularies unless
+ * other symbols are specified in the config.
+ *
+ * This takes a normal CorpusBatch and extends it with additional data. Luckily
+ * all the BERT-functionality can be inferred from a CorpusBatch alone.
+ */
+class BertBatch : public CorpusBatch {
+private:
+ std::vector<IndexType> maskedPositions_;
+ Words maskedWords_;
+ std::vector<IndexType> sentenceIndices_;
+
+ std::string maskSymbol_;
+ std::string sepSymbol_;
+ std::string clsSymbol_;
+
+ // Selects a random word from the vocabulary
+ std::unique_ptr<std::uniform_int_distribution<WordIndex>> randomWord_;
+
+ // Selects a random integer between 0 and 99
+ std::unique_ptr<std::uniform_real_distribution<float>> randomPercent_;
+
+ // Word ids of words that should not be masked, e.g. separators, padding
+ std::unordered_set<Word> dontMask_;
+
+ // Masking function, i.e. replaces a chosen word with either
+ // a [MASK] symbol, itself or a random word
+ Word maskOut(Word word, Word mask, std::mt19937& engine) {
+ auto subBatch = subBatches_.front();
+
+ // @TODO: turn those threshold into parameters, adjustable from command line
+ float r = (*randomPercent_)(engine);
+ if (r < 0.1f) { // for 10% of cases return same word
+ return word;
+ } else if (r < 0.2f) { // for 10% return random word
+ Word randWord = Word::fromWordIndex((*randomWord_)(engine));
+ if(dontMask_.count(randWord) > 0) // some words, e.g. [CLS] or </s>, may not be used as random words
+ return mask; // for those, return the mask symbol instead
+ else
+ return randWord; // else return the random word
+ } else { // for 80% of words apply mask symbol
+ return mask;
+ }
+ }
+
+public:
+
+ // Takes a corpus batch, random engine (for deterministic behavior) and the masking percentage.
+ // Also sets special vocabulary items given on command line.
+ BertBatch(Ptr<CorpusBatch> batch,
+ std::mt19937& engine,
+ float maskFraction,
+ const std::string& maskSymbol,
+ const std::string& sepSymbol,
+ const std::string& clsSymbol,
+ int dimTypeVocab)
+ : CorpusBatch(*batch),
+ maskSymbol_(maskSymbol), sepSymbol_(sepSymbol), clsSymbol_(clsSymbol) {
+
+ // BERT expects a textual first stream and a second stream with class labels
+ auto subBatch = subBatches_.front();
+ const auto& vocab = *subBatch->vocab();
+
+ // Initialize to sample random vocab id
+ randomWord_.reset(new std::uniform_int_distribution<WordIndex>(0, (WordIndex)vocab.size()));
+
+ // Intialize to sample random percentage
+ randomPercent_.reset(new std::uniform_real_distribution<float>(0.f, 1.f));
+
+ auto& words = subBatch->data();
+
+ // Get word id of special symbols
+ Word maskId = vocab[maskSymbol_];
+ Word clsId = vocab[clsSymbol_];
+ Word sepId = vocab[sepSymbol_];
+
+ ABORT_IF(maskId == vocab.getUnkId(),
+ "BERT masking symbol {} not found in vocabulary", maskSymbol_);
+
+ ABORT_IF(sepId == vocab.getUnkId(),
+ "BERT separator symbol {} not found in vocabulary", sepSymbol_);
+
+ ABORT_IF(clsId == vocab.getUnkId(),
+ "BERT class symbol {} not found in vocabulary", clsSymbol_);
+
+ dontMask_.insert(clsId); // don't mask class token
+ dontMask_.insert(sepId); // don't mask separator token
+ dontMask_.insert(vocab.getEosId()); // don't mask </s>
+ // it's ok to mask <unk>
+
+ std::vector<int> selected;
+ selected.reserve(words.size());
+ for(int i = 0; i < words.size(); ++i) // collect words among which we will mask
+ if(dontMask_.count(words[i]) == 0) // do not add indices of special words
+ selected.push_back(i);
+ std::shuffle(selected.begin(), selected.end(), engine); // randomize positions
+ selected.resize((size_t)std::ceil(selected.size() * maskFraction)); // select first x percent from shuffled indices
+
+ for(int i : selected) {
+ maskedPositions_.push_back(i); // where is the original word?
+ maskedWords_.push_back(words[i]); // what is the original word?
+ words[i] = maskOut(words[i], maskId, engine); // mask that position
+ }
+
+ annotateSentenceIndices(dimTypeVocab);
+ }
+
+ BertBatch(Ptr<CorpusBatch> batch,
+ const std::string& sepSymbol,
+ const std::string& clsSymbol,
+ int dimTypeVocab)
+ : CorpusBatch(*batch),
+ maskSymbol_("dummy"), sepSymbol_(sepSymbol), clsSymbol_(clsSymbol) {
+ annotateSentenceIndices(dimTypeVocab);
+ }
+
+ void annotateSentenceIndices(int dimTypeVocab) {
+ // BERT expects a textual first stream and a second stream with class labels
+ auto subBatch = subBatches_.front();
+ const auto& vocab = *subBatch->vocab();
+ auto& words = subBatch->data();
+
+ // Get word id of special symbols
+ Word sepId = vocab[sepSymbol_];
+ ABORT_IF(sepId == vocab.getUnkId(),
+ "BERT separator symbol {} not found in vocabulary", sepSymbol_);
+
+ int dimBatch = (int)subBatch->batchSize();
+ int dimWords = (int)subBatch->batchWidth();
+
+ const size_t maxSentPos = dimTypeVocab;
+
+ // create indices for BERT sentence embeddings A and B
+ sentenceIndices_.resize(words.size()); // each word is either in sentence A or B
+ std::vector<IndexType> sentPos(dimBatch, 0); // initialize each batch entry with being A [0]
+ for(int i = 0; i < dimWords; ++i) { // advance word-wise
+ for(int j = 0; j < dimBatch; ++j) { // scan batch-wise
+ int k = i * dimBatch + j;
+ sentenceIndices_[k] = sentPos[j]; // set to current sentence position for batch entry, max position 1.
+ if(words[k] == sepId && sentPos[j] < maxSentPos) { // if current word is a separator and not beyond range
+ sentPos[j]++; // then increase sentence position for batch entry (to B [1])
+ }
+ }
+ }
+ }
+
+ const std::vector<IndexType>& bertMaskedPositions() { return maskedPositions_; }
+ const Words& bertMaskedWords() { return maskedWords_; }
+ const std::vector<IndexType>& bertSentenceIndices() { return sentenceIndices_; }
+};
+
+}
+
+/**
+ * BERT-specific version of EncoderClassifier, mostly here to automatically convert a
+ * CorpusBatch to BertBatch.
+ */
+class BertEncoderClassifier : public EncoderClassifier, public data::RNGEngine { // @TODO: this random engine is not being serialized right now
+public:
+ BertEncoderClassifier(Ptr<Options> options)
+ : EncoderClassifier(options) {}
+
+ std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, bool clearGraph) override {
+ std::string modelType = opt<std::string>("type");
+ int dimTypeVocab = opt<int>("bert-type-vocab-size");
+
+ // intercept batch and annotate with BERT-specific concepts
+ Ptr<data::BertBatch> bertBatch;
+ if(modelType == "bert") { // full BERT pre-training
+ bertBatch = New<data::BertBatch>(batch,
+ eng_,
+ opt<float>("bert-masking-fraction", 0.15f), // 15% by default according to paper
+ opt<std::string>("bert-mask-symbol"),
+ opt<std::string>("bert-sep-symbol"),
+ opt<std::string>("bert-class-symbol"),
+ dimTypeVocab);
+ } else if(modelType == "bert-classifier") { // we are probably fine-tuning a BERT model for a classification task
+ bertBatch = New<data::BertBatch>(batch,
+ opt<std::string>("bert-sep-symbol"),
+ opt<std::string>("bert-class-symbol"),
+ dimTypeVocab); // only annotate sentence separators
+ } else {
+ ABORT("Unknown BERT-style model: {}", modelType);
+ }
+
+ return EncoderClassifier::apply(graph, bertBatch, clearGraph);
+ }
+
+ // for externally created BertBatch for instance in BertValidator
+ std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph> graph, Ptr<data::BertBatch> bertBatch, bool clearGraph) {
+ return EncoderClassifier::apply(graph, bertBatch, clearGraph);
+ }
+};
+
+/**
+ * BERT-specific modifications to EncoderTransformer
+ * Actually all that is needed is to intercept the creation of special embeddings,
+ * here sentence embeddings for sentence A and B.
+ * @BUGBUG: transformer.h was meant to be compiled separately. I.e., one cannot derive from it.
+ * Is there a way to maybe instead include a reference in here, instead of deriving from it?
+ */
+class BertEncoder : public EncoderTransformer {
+ using EncoderTransformer::EncoderTransformer;
+public:
+ Expr addSentenceEmbeddings(Expr embeddings,
+ Ptr<data::CorpusBatch> batch,
+ bool learnedPosEmbeddings) const {
+ Ptr<data::BertBatch> bertBatch = std::dynamic_pointer_cast<data::BertBatch>(batch);
+ ABORT_IF(!bertBatch, "Batch must be BertBatch for BERT training or fine-tuning");
+
+ int dimEmb = embeddings->shape()[-1];
+ int dimBatch = embeddings->shape()[-2];
+ int dimWords = embeddings->shape()[-3];
+
+ int dimTypeVocab = opt<int>("bert-type-vocab-size", 2);
+
+ Expr signal;
+ if(learnedPosEmbeddings) {
+ auto sentenceEmbeddings = embedding()
+ ("prefix", "Wtype")
+ ("dimVocab", dimTypeVocab) // sentence A or sentence B
+ ("dimEmb", dimEmb)
+ .construct(graph_);
+ signal = sentenceEmbeddings->applyIndices(bertBatch->bertSentenceIndices(), {dimWords, dimBatch, dimEmb});
+ } else {
+ // @TODO: factory for positional embeddings?
+ // constant sinusoidal position embeddings, no backprob
+ auto sentenceEmbeddingsExpr = graph_->constant({2, dimEmb}, inits::sinusoidalPositionEmbeddings(0));
+ signal = rows(sentenceEmbeddingsExpr, bertBatch->bertSentenceIndices());
+ signal = reshape(signal, {dimWords, dimBatch, dimEmb});
+ }
+
+ return embeddings + signal;
+ }
+
+ virtual Expr addSpecialEmbeddings(Expr input, int start = 0, Ptr<data::CorpusBatch> batch = nullptr) const override {
+ bool trainPosEmbeddings = opt<bool>("transformer-train-position-embeddings", true);
+ bool trainTypeEmbeddings = opt<bool>("bert-train-type-embeddings", true);
+ input = addPositionalEmbeddings(input, start, trainPosEmbeddings);
+ input = addSentenceEmbeddings(input, batch, trainTypeEmbeddings);
+ return input;
+ }
+};
+
+/**
+ * BERT-specific classifier
+ * Can be used for next sentence prediction task or other fine-tuned down-stream tasks
+ * Does not actually need a BertBatch, works with CorpusBatch.
+ *
+ * @TODO: This is in fact so generic that we might move it out of here as the typical classifier implementation
+ */
+class BertClassifier : public ClassifierBase {
+ using ClassifierBase::ClassifierBase;
+public:
+ Ptr<ClassifierState> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
+ ABORT_IF(encoderStates.size() != 1, "Currently we only support a single encoder BERT model");
+
+ auto context = encoderStates[0]->getContext();
+ auto classEmbeddings = slice(context, /*axis=*/-3, /*i=*/0); // [CLS] symbol is first symbol in each sequence
+
+ int dimModel = classEmbeddings->shape()[-1];
+ int dimTrgCls = opt<std::vector<int>>("dim-vocabs")[batchIndex_]; // Target vocab is used as class labels
+
+ auto output = mlp::mlp() //
+ .push_back(mlp::dense() //
+ ("prefix", prefix_ + "_ff_logit_l1") //
+ ("dim", dimModel) //
+ ("activation", (int)mlp::act::tanh)) // @TODO: do we actually need this?
+ .push_back(mlp::output() //
+ ("dim", dimTrgCls)) //
+ ("prefix", prefix_ + "_ff_logit_l2") //
+ .construct(graph);
+
+ auto logits = output->apply(classEmbeddings); // class logits for each batch entry
+
+ auto state = New<ClassifierState>();
+ state->setLogProbs(logits);
+
+ // Filled externally, for BERT these are NextSentence prediction labels
+ const auto& classLabels = (*batch)[batchIndex_]->data();
+ state->setTargetWords(classLabels);
+
+ return state;
+ }
+
+ virtual void clear() override {}
+};
+
+/**
+ * This is a model that pretrains BERT for classification.
+ * This is also a Classifier, but compared to the BertClassifier above needs the BERT-specific information from BertBatch
+ * as this is self-generating its labels from the source. Labels are dynamically created as complements of the masking process.
+ */
+class BertMaskedLM : public ClassifierBase {
+ using ClassifierBase::ClassifierBase;
+public:
+ Ptr<ClassifierState> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
+ Ptr<data::BertBatch> bertBatch = std::dynamic_pointer_cast<data::BertBatch>(batch);
+
+ ABORT_IF(!bertBatch, "Batch must be BertBatch for BERT training");
+ ABORT_IF(encoderStates.size() != 1, "Currently we only support a single encoder BERT model");
+
+ auto context = encoderStates[0]->getContext();
+
+ auto bertMaskedPositions = graph->indices(bertBatch->bertMaskedPositions()); // positions in batch of masked entries
+ const auto& bertMaskedWords = bertBatch->bertMaskedWords(); // vocab ids of entries that have been masked
+
+ int dimModel = context->shape()[-1];
+ int dimBatch = context->shape()[-2];
+ int dimTime = context->shape()[-3];
+
+ auto maskedContext = rows(reshape(context, {dimBatch * dimTime, dimModel}), bertMaskedPositions); // subselect stuff that has actually been masked out
+
+ int dimVoc = opt<std::vector<int>>("dim-vocabs")[batchIndex_];
+
+ auto layer1 = mlp::mlp()
+ .push_back(mlp::dense()
+ ("prefix", prefix_ + "_ff_logit_l1")
+ ("dim", dimModel))
+ .construct(graph);
+
+ auto intermediate = layer1->apply(maskedContext);
+
+ std::string activationType = opt<std::string>("transformer-ffn-activation");
+ if(activationType == "relu")
+ intermediate = relu(intermediate);
+ else if(activationType == "swish")
+ intermediate = swish(intermediate);
+ else if(activationType == "gelu")
+ intermediate = gelu(intermediate);
+ else
+ ABORT("Activation function {} not supported in BERT masked LM", activationType);
+
+ auto gamma = graph->param(prefix_ + "_ff_ln_scale", {1, dimModel}, inits::ones());
+ auto beta = graph->param(prefix_ + "_ff_ln_bias", {1, dimModel}, inits::zeros());
+ intermediate = layerNorm(intermediate, gamma, beta);
+
+ auto layer2 = mlp::mlp()
+ .push_back(mlp::output(
+ "prefix", prefix_ + "_ff_logit_l2",
+ "dim", dimVoc)
+ .tieTransposed("Wemb"))
+ .construct(graph);
+
+ auto logits = layer2->apply(intermediate); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vocab dim]
+
+ auto state = New<ClassifierState>();
+ state->setLogProbs(logits);
+ state->setTargetWords(bertMaskedWords);
+
+ return state;
+ }
+
+ virtual void clear() override {}
+};
+
+}
diff --git a/src/models/char_s2s.h b/src/models/char_s2s.h
index 6d5d1db1..3b9bb2fa 100644..100755
--- a/src/models/char_s2s.h
+++ b/src/models/char_s2s.h
@@ -8,24 +8,15 @@
namespace marian {
class CharS2SEncoder : public EncoderS2S {
+ using EncoderS2S::EncoderS2S;
+
public:
- CharS2SEncoder(Ptr<Options> options) : EncoderS2S(options) {}
-
virtual Ptr<EncoderState> build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch) override {
- auto embeddings = buildSourceEmbeddings(graph);
-
+ graph_ = graph;
// select embeddings that occur in the batch
- Expr batchEmbeddings, batchMask;
- std::tie(batchEmbeddings, batchMask)
- = EncoderBase::lookup(graph, embeddings, batch);
-
- // apply dropout over source words
- float dropProb = inference_ ? 0 : opt<float>("dropout-src");
- if(dropProb) {
- int srcWords = batchEmbeddings->shape()[-3];
- batchEmbeddings = dropout(batchEmbeddings, dropProb, {srcWords, 1, 1});
- }
+ Expr batchEmbeddings, batchMask; std::tie
+ (batchEmbeddings, batchMask) = getEmbeddingLayer()->apply(batch->front());
int dimEmb = opt<int>("dim-emb");
auto convSizes = options_->get<std::vector<int>>("char-conv-filters-num");
@@ -67,7 +58,7 @@ protected:
}
size_t dimWords = strided.size() / dimBatch;
auto stridedMask
- = graph->constant({(int)dimWords, (int)dimBatch, 1}, inits::from_vector(strided));
+ = graph->constant({(int)dimWords, (int)dimBatch, 1}, inits::fromVector(strided));
return stridedMask;
}
};
diff --git a/src/models/classifier.h b/src/models/classifier.h
new file mode 100755
index 00000000..9faa907e
--- /dev/null
+++ b/src/models/classifier.h
@@ -0,0 +1,42 @@
+#pragma once
+
+#include "marian.h"
+#include "models/states.h"
+#include "layers/constructors.h"
+#include "layers/factory.h"
+
+namespace marian {
+
+/**
+ * Simple base class for Classifiers to be used in EncoderClassifier framework
+ * Currently only implementations are in bert.h
+ */
+class ClassifierBase :public LayerBase {
+ using LayerBase::LayerBase;
+protected:
+ Ptr<Options> options_;
+ const std::string prefix_{"classifier"};
+ const bool inference_{false};
+ const size_t batchIndex_{0};
+
+public:
+ ClassifierBase(Ptr<ExpressionGraph> graph, Ptr<Options> options)
+ : LayerBase(graph, options),
+ prefix_(options->get<std::string>("prefix", "classifier")),
+ inference_(options->get<bool>("inference", false)),
+ batchIndex_(options->get<size_t>("index", 1)) {} // assume that training input has batch index 0 and labels has 1
+
+ virtual ~ClassifierBase() {}
+
+ virtual Ptr<ClassifierState> apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, const std::vector<Ptr<EncoderState>>&) = 0;
+
+ template <typename T>
+ T opt(const std::string& key) const {
+ return options_->get<T>(key);
+ }
+
+ // Should be used to clear any batch-wise temporary objects if present
+ virtual void clear() = 0;
+};
+
+} \ No newline at end of file
diff --git a/src/models/costs.h b/src/models/costs.h
index d38a445e..2a66e549 100755
--- a/src/models/costs.h
+++ b/src/models/costs.h
@@ -5,42 +5,56 @@
#include "layers/loss.h"
#include "layers/weight.h"
#include "models/encoder_decoder.h"
+#include "models/encoder_classifier.h"
namespace marian {
namespace models {
-class CostBase {
+// @TODO: this whole file is an unholy mess and needs to be refactored.
+// Using MultiRationalLoss is a first improvement, but we can probably
+// unify classifier and decoder costs. Also rethink step-wise cost.
+
+// @TODO: inheritance and polymorphism is used here in a rather unclear way.
+// E.g. returns Ptr<MultiRationalLoss> which should be Ptr<RationalLoss>?
+// Other functions return RationalLoss directly without Ptr<...>, but also
+// they do not need polymorphism here.
+
+class ICost {
public:
- virtual Expr apply(Ptr<ModelBase> model,
- Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true)
- = 0;
+ virtual Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
+ Ptr<ExpressionGraph> graph, // @TODO: why needed? Can it be gotten from model?
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) = 0;
+ virtual ~ICost() {}
};
-class EncoderDecoderCE : public CostBase {
+class EncoderDecoderCECost : public ICost {
protected:
Ptr<Options> options_;
- bool inference_{false};
- bool toBeWeighted_{false};
- Ptr<LossBase> loss_;
+ const bool inference_{false};
+ /*const*/ bool toBeWeighted_{false};
+
+ // @TODO: single loss seems wrong
+ Ptr<LabelwiseLoss> loss_;
Ptr<WeightingBase> weighter_;
public:
- EncoderDecoderCE(Ptr<Options> options)
+ EncoderDecoderCECost(Ptr<Options> options)
: options_(options), inference_(options->get<bool>("inference", false)) {
- loss_ = LossFactory(options_, inference_);
+ loss_ = newLoss(options_, inference_);
toBeWeighted_
- = (options_->has("data-weighting") && !inference_)
- || (options_->has("dynamic-weighting")
- && options_->get<bool>("dynamic-weighting") && !inference_);
+ = (options_->hasAndNotEmpty("data-weighting") && !inference_)
+ || (options_->has("dynamic-weighting") && options_->get<bool>("dynamic-weighting")
+ && !inference_);
if(toBeWeighted_)
weighter_ = WeightingFactory(options_);
}
- Expr apply(Ptr<ModelBase> model,
+ virtual ~EncoderDecoderCECost() {}
+
+ Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph = true) override {
@@ -53,35 +67,81 @@ public:
if(toBeWeighted_)
weights = weighter_->getWeights(graph, corpusBatch);
- Expr cost;
- cost = loss_->getCost(state->getLogProbs(),
- state->getTargetIndices(),
- state->getTargetMask(),
- weights);
+ // multi-objective training
+ Ptr<MultiRationalLoss> multiLoss = newMultiLoss(options_);
+
+ // @TODO: adapt to multi-objective training with multiple decoders
+ auto partialLoss = loss_->apply(state->getLogProbs(),
+ state->getTargetWords(),
+ state->getTargetMask(),
+ weights);
+ multiLoss->push_back(partialLoss);
if(options_->get("guided-alignment", std::string("none")) != "none" && !inference_) {
- auto alignments = encdec->getDecoders()[0]->getAlignments();
- ABORT_IF(alignments.empty(), "Model does not seem to support alignments");
+ auto attentionVectors = encdec->getDecoders()[0]->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1]
+ ABORT_IF(attentionVectors.empty(), "Model does not seem to support alignments");
+
+ auto attention = concatenate(attentionVectors, /*axis =*/ -1);
+
+ auto alignmentLoss = guidedAlignmentCost(graph, corpusBatch, options_, attention);
+ multiLoss->push_back(alignmentLoss);
+ }
+
+ return multiLoss;
+ }
+};
+
+// Wraps an EncoderClassifier so it can produce a cost from raw logits. @TODO: Needs refactoring
+class EncoderClassifierCECost : public ICost {
+protected:
+ Ptr<Options> options_;
+ const bool inference_{false};
- auto att = concatenate(alignments, /*axis =*/ -1);
+ // @TODO: single loss seems wrong, especially since we support multiple objectives here,
+ // also not sure this needs to be a member at all.
+ Ptr<LabelwiseLoss> loss_;
- return cost + guidedAlignmentCost(graph, corpusBatch, options_, att);
- } else {
- return cost;
+public:
+ EncoderClassifierCECost(Ptr<Options> options)
+ : options_(options), inference_(options->get<bool>("inference", false)) {
+ loss_ = newLoss(options_, inference_);
+ }
+
+ Ptr<MultiRationalLoss> apply(Ptr<IModel> model,
+ Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override {
+
+ auto enccls = std::static_pointer_cast<EncoderClassifier>(model);
+ auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
+
+ auto states = enccls->apply(graph, corpusBatch, clearGraph);
+
+ // multi-objective training
+ Ptr<MultiRationalLoss> multiLoss = newMultiLoss(options_);
+ for(int i = 0; i < states.size(); ++i) {
+ auto partialLoss = loss_->apply(Logits(states[i]->getLogProbs()),
+ states[i]->getTargetWords(),
+ /*mask=*/nullptr,
+ /*weights=*/nullptr);
+ multiLoss->push_back(partialLoss);
}
+ return multiLoss;
}
};
-class Trainer : public ModelBase {
+class Trainer : public ICriterionFunction {
protected:
- Ptr<ModelBase> model_;
- Ptr<CostBase> cost_;
+ Ptr<IModel> model_;
+ Ptr<ICost> cost_;
public:
- Trainer(Ptr<ModelBase> model, Ptr<CostBase> cost)
+ Trainer(Ptr<IModel> model, Ptr<ICost> cost)
: model_(model), cost_(cost) {}
- Ptr<ModelBase> getModel() { return model_; }
+ virtual ~Trainer() {}
+
+ Ptr<IModel> getModel() { return model_; }
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
@@ -95,31 +155,74 @@ public:
model_->save(graph, name, saveTranslatorConfig);
}
- virtual Expr build(Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true) override {
+ virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override {
return cost_->apply(model_, graph, batch, clearGraph);
};
virtual void clear(Ptr<ExpressionGraph> graph) override { model_->clear(graph); };
};
-typedef Trainer Scorer;
+class ILogProb {
+public:
+ virtual Logits apply(Ptr<IModel> model,
+ Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) = 0;
+};
+
+// @TODO: Name 'scorer' is ambiguous: Does it compute scores for all classes, or the loss value for the ground truth?
+// Beam search uses it for the former meaning, while 'marian score' and validation in the latter.
+// This class is for the former use. The latter is done using Trainer.
+class Scorer : public IModel {
+protected:
+ Ptr<IModel> model_;
+ Ptr<ILogProb> logProb_;
+
+public:
+ Scorer(Ptr<IModel> model, Ptr<ILogProb> cost)
+ : model_(model), logProb_(cost) {}
+
+ virtual ~Scorer(){}
+
+ Ptr<IModel> getModel() { return model_; }
+
+ virtual void load(Ptr<ExpressionGraph> graph,
+ const std::string& name,
+ bool markedReloaded = true) override {
+ model_->load(graph, name, markedReloaded);
+ };
+
+ virtual void save(Ptr<ExpressionGraph> graph,
+ const std::string& name,
+ bool saveTranslatorConfig = false) override {
+ model_->save(graph, name, saveTranslatorConfig);
+ }
+
+ virtual Logits build(Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override {
+ return logProb_->apply(model_, graph, batch, clearGraph);
+ };
+
+ virtual void clear(Ptr<ExpressionGraph> graph) override { model_->clear(graph); };
+};
-class CostStep {
+class ILogProbStep {
public:
+ // @BUGBUG: This is not a function application. Rather, it updates 'state' in-place.
+ // Suggest to call it updateState, and not return the state object.
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) = 0;
};
-class LogSoftmaxStep : public CostStep {
+class LogSoftmaxStep : public ILogProbStep {
public:
+ virtual ~LogSoftmaxStep() {}
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
// decoder needs normalized probabilities (note: skipped if beam 1 and --skip-cost)
- auto logits = state->getLogProbs();
-
- auto logprobs = logsoftmax(logits);
-
- state->setLogProbs(logprobs);
+ state->setLogProbs(state->getLogProbs().applyUnaryFunction(logsoftmax));
+ // @TODO: This is becoming more and more opaque ^^. Can we simplify this?
return state;
}
};
@@ -127,28 +230,30 @@ public:
// Gumbel-max noising for sampling during beam-search
// Seems to work well enough with beam-size=1. Turn on
// with --output-sampling during translation with marian-decoder
-class GumbelSoftmaxStep : public CostStep {
+class GumbelSoftmaxStep : public ILogProbStep {
public:
+ virtual ~GumbelSoftmaxStep() {}
virtual Ptr<DecoderState> apply(Ptr<DecoderState> state) override {
- auto logits = state->getLogProbs();
-
- auto logprobs = logsoftmax(logits + constant_like(logits, inits::gumbel));
-
- state->setLogProbs(logprobs);
+ state->setLogProbs(state->getLogProbs().applyUnaryFunctions(
+ [](Expr logits){ // lemma gets gumbelled
+ return logsoftmax(logits + constant_like(logits, inits::gumbel()));
+ },
+ logsoftmax)); // factors don't
return state;
}
};
-// class to wrap an EncoderDecoderBase and a CostStep that are executed in sequence,
-// wrapped again in the EncoderDecoderBase interface
+// class to wrap an IEncoderDecoder and a ILogProbStep that are executed in sequence,
+// wrapped again in the IEncoderDecoder interface
// @TODO: seems we are conflating an interface defition with its implementation?
-class Stepwise : public EncoderDecoderBase {
+// @TODO: needs a better name. Stepwise is an adjective. Classes are things=nouns. StepwiseWhat?
+class Stepwise : public IEncoderDecoder {
protected:
- Ptr<EncoderDecoderBase> encdec_;
- Ptr<CostStep> cost_;
+ Ptr<IEncoderDecoder> encdec_;
+ Ptr<ILogProbStep> cost_;
public:
- Stepwise(Ptr<EncoderDecoderBase> encdec, Ptr<CostStep> cost)
+ Stepwise(Ptr<IEncoderDecoder> encdec, Ptr<ILogProbStep> cost)
: encdec_(encdec), cost_(cost) {}
virtual void load(Ptr<ExpressionGraph> graph,
@@ -171,9 +276,9 @@ public:
virtual void clear(Ptr<ExpressionGraph> graph) override { encdec_->clear(graph); }
- virtual Expr build(Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true) override {
+ virtual Logits build(Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override {
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
return build(graph, corpusBatch, clearGraph);
}
@@ -185,26 +290,24 @@ public:
virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state,
- const std::vector<IndexType>& hypIndices,
- const std::vector<IndexType>& embIndices,
- int dimBatch,
+ const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
+ const Words& words, // [beamIndex * activeBatchSize + batchIndex]
+ const std::vector<IndexType>& batchIndices, // [batchIndex]
int beamSize) override {
- auto nextState = encdec_->step(
- graph, state, hypIndices, embIndices, dimBatch, beamSize);
+ auto nextState = encdec_->step(graph, state, hypIndices, words, batchIndices, beamSize);
return cost_->apply(nextState);
}
- virtual Expr build(Ptr<ExpressionGraph> /*graph*/,
- Ptr<data::CorpusBatch> /*batch*/,
- bool /*clearGraph*/ = true) override {
+ virtual Logits build(Ptr<ExpressionGraph> /*graph*/,
+ Ptr<data::CorpusBatch> /*batch*/,
+ bool /*clearGraph*/ = true) override {
ABORT("Wrong wrapper. Use models::Trainer or models::Scorer");
- return nullptr;
}
virtual Ptr<Options> getOptions() override { return encdec_->getOptions(); };
virtual void setShortlistGenerator(
- Ptr<data::ShortlistGenerator> shortlistGenerator) override {
+ Ptr<const data::ShortlistGenerator> shortlistGenerator) override {
encdec_->setShortlistGenerator(shortlistGenerator);
};
@@ -215,21 +318,5 @@ public:
virtual data::SoftAlignment getAlignment() override { return encdec_->getAlignment(); }
};
-inline Ptr<ModelBase> add_cost(Ptr<EncoderDecoder> encdec,
- Ptr<Options> options) {
- switch(options->get<usage>("usage", usage::raw)) {
- case usage::training:
- return New<Trainer>(encdec, New<EncoderDecoderCE>(options));
- case usage::scoring:
- return New<Scorer>(encdec, New<EncoderDecoderCE>(options));
- case usage::translation:
- if(options->get<bool>("output-sampling", false))
- return New<Stepwise>(encdec, New<GumbelSoftmaxStep>());
- else
- return New<Stepwise>(encdec, New<LogSoftmaxStep>());
- case usage::raw:
- default: return encdec;
- }
-}
} // namespace models
} // namespace marian
diff --git a/src/models/decoder.h b/src/models/decoder.h
index 39e56e1f..725573c1 100755
--- a/src/models/decoder.h
+++ b/src/models/decoder.h
@@ -4,25 +4,20 @@
#include "states.h"
#include "data/shortlist.h"
+#include "layers/constructors.h"
#include "layers/generic.h"
namespace marian {
-class DecoderBase {
+class DecoderBase : public EncoderDecoderLayerBase {
protected:
- Ptr<Options> options_;
- std::string prefix_{"decoder"};
- bool inference_{false};
- size_t batchIndex_{1};
-
Ptr<data::Shortlist> shortlist_;
public:
- DecoderBase(Ptr<Options> options)
- : options_(options),
- prefix_(options->get<std::string>("prefix", "decoder")),
- inference_(options->get<bool>("inference", false)),
- batchIndex_(options->get<size_t>("index", 1)) {}
+ DecoderBase(Ptr<ExpressionGraph> graph, Ptr<Options> options) :
+ EncoderDecoderLayerBase(graph, options, "decoder", /*batchIndex=*/1,
+ options->get<float>("dropout-trg", 0.0f),
+ options->get<bool>("embedding-fix-trg", false)) {}
virtual Ptr<DecoderState> startState(Ptr<ExpressionGraph>,
Ptr<data::CorpusBatch> batch,
@@ -34,103 +29,56 @@ public:
virtual void embeddingsFromBatch(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state,
Ptr<data::CorpusBatch> batch) {
-
- int dimVoc = opt<std::vector<int>>("dim-vocabs")[batchIndex_];
- int dimEmb = opt<int>("dim-emb");
-
- auto yEmbFactory = embedding(graph) //
- ("dimVocab", dimVoc) //
- ("dimEmb", dimEmb);
-
- if(opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all"))
- yEmbFactory("prefix", "Wemb");
- else
- yEmbFactory("prefix", prefix_ + "_Wemb");
-
- if(options_->has("embedding-fix-trg"))
- yEmbFactory("fixed", opt<bool>("embedding-fix-trg"));
-
- if(options_->has("embedding-vectors")) {
- auto embFiles = opt<std::vector<std::string>>("embedding-vectors");
- yEmbFactory("embFile", embFiles[batchIndex_]) //
- ("normalization", opt<bool>("embedding-normalization"));
- }
-
- auto yEmb = yEmbFactory.construct();
+ graph_ = graph;
auto subBatch = (*batch)[batchIndex_];
- int dimBatch = (int)subBatch->batchSize();
- int dimWords = (int)subBatch->batchWidth();
- auto chosenEmbeddings = rows(yEmb, subBatch->data());
+ Expr y, yMask; std::tie
+ (y, yMask) = getEmbeddingLayer()->apply(subBatch);
- auto y
- = reshape(chosenEmbeddings, {dimWords, dimBatch, opt<int>("dim-emb")});
+ // @TODO: during training there is currently no code path that leads to using a shortlist
+#if 0
+ const Words& data =
+ /*if*/ (shortlist_) ?
+ shortlist_->mappedIndices()
+ /*else*/ :
+ subBatch->data();
+#endif
- auto yMask = graph->constant({dimWords, dimBatch, 1},
- inits::from_vector(subBatch->mask()));
-
- Expr yData;
- if(shortlist_) {
- yData = graph->indices(shortlist_->mappedIndices());
- } else {
- yData = graph->indices(subBatch->data());
- }
+ ABORT_IF(shortlist_, "How did a shortlist make it into training?");
auto yShifted = shift(y, {1, 0, 0});
- state->setTargetEmbeddings(yShifted);
+ state->setTargetHistoryEmbeddings(yShifted);
state->setTargetMask(yMask);
- state->setTargetIndices(yData);
+
+ const Words& data = subBatch->data();
+ state->setTargetWords(data);
}
virtual void embeddingsFromPrediction(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state,
- const std::vector<IndexType>& embIdx,
+ const Words& words,
int dimBatch,
int dimBeam) {
- int dimTrgEmb = opt<int>("dim-emb");
- int dimTrgVoc = opt<std::vector<int>>("dim-vocabs")[batchIndex_];
-
- // embeddings are loaded from model during translation, no fixing required
- auto yEmbFactory = embedding(graph) //
- ("dimVocab", dimTrgVoc) //
- ("dimEmb", dimTrgEmb);
-
- if(opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all"))
- yEmbFactory("prefix", "Wemb");
- else
- yEmbFactory("prefix", prefix_ + "_Wemb");
-
- auto yEmb = yEmbFactory.construct();
-
+ graph_ = graph;
+ auto embeddingLayer = getEmbeddingLayer();
Expr selectedEmbs;
- if(embIdx.empty()) {
- selectedEmbs = graph->constant({1, 1, dimBatch, dimTrgEmb}, inits::zeros);
- } else {
- selectedEmbs = rows(yEmb, embIdx);
- selectedEmbs = reshape(selectedEmbs, {dimBeam, 1, dimBatch, dimTrgEmb});
- }
- state->setTargetEmbeddings(selectedEmbs);
+ int dimEmb = opt<int>("dim-emb");
+ if(words.empty())
+ selectedEmbs = graph_->constant({1, 1, dimBatch, dimEmb}, inits::zeros());
+ else
+ selectedEmbs = embeddingLayer->apply(words, {dimBeam, 1, dimBatch, dimEmb});
+ state->setTargetHistoryEmbeddings(selectedEmbs);
}
- virtual const std::vector<Expr> getAlignments(int /*i*/ = 0) { return {}; };
+ virtual const std::vector<Expr> getAlignments(int /*i*/ = 0) { return {}; }; // [tgt index][beam depth, max src length, batch size, 1]
virtual Ptr<data::Shortlist> getShortlist() { return shortlist_; }
virtual void setShortlist(Ptr<data::Shortlist> shortlist) {
shortlist_ = shortlist;
}
- template <typename T>
- T opt(const std::string& key) const {
- return options_->get<T>(key);
- }
-
- template <typename T>
- T opt(const std::string& key, const T& def) {
- return options_->get<T>(key, def);
- }
-
virtual void clear() = 0;
};
diff --git a/src/models/encoder.h b/src/models/encoder.h
index 6d1ee852..61bb0d88 100755
--- a/src/models/encoder.h
+++ b/src/models/encoder.h
@@ -5,88 +5,15 @@
namespace marian {
-class EncoderBase {
-protected:
- Ptr<Options> options_;
- std::string prefix_{"encoder"};
- bool inference_{false};
- size_t batchIndex_{0};
-
- // @TODO: This used to be virtual, but is never overridden.
- // virtual
- std::tuple<Expr, Expr> lookup(Ptr<ExpressionGraph> graph,
- Expr srcEmbeddings,
- Ptr<data::CorpusBatch> batch) const {
- auto subBatch = (*batch)[batchIndex_];
- int dimBatch = (int)subBatch->batchSize();
- int dimEmb = srcEmbeddings->shape()[-1];
- int dimWords = (int)subBatch->batchWidth();
- auto chosenEmbeddings = rows(srcEmbeddings, subBatch->data());
- auto batchEmbeddings = reshape(chosenEmbeddings, { dimWords, dimBatch, dimEmb });
- auto batchMask = graph->constant({ dimWords, dimBatch, 1 },
- inits::from_vector(subBatch->mask()));
-
- return std::make_tuple(batchEmbeddings, batchMask);
- }
-
- std::tuple<Expr, Expr> ulrLookup(Ptr<ExpressionGraph> graph,
- std::vector<Expr> urlEmbeddings,
- Ptr<data::CorpusBatch> batch) const {
- auto subBatch = (*batch)[batchIndex_];
- // is their a better way to do this?
- assert(urlEmbeddings.size() == 6);
- auto queryEmbed = urlEmbeddings[0]; //Q : dimQueries*dimUlrEmb
- auto keyEmbed = urlEmbeddings[1]; // K : dimKeys*dimUlrEmb
- auto uniEmbed = urlEmbeddings[2]; // E : dimQueries*dimEmb
- auto srcEmbed = urlEmbeddings[3]; // I : dimQueries*dimEmb
- auto ulrTransform = urlEmbeddings[4]; //A : dimUlrEmb *dimUlrEmb
- auto ulrSharable = urlEmbeddings[5]; //alpha : dimQueries*1
- int dimBatch = (int)subBatch->batchSize();
- int dimEmb = uniEmbed->shape()[-1];
- int dimWords = (int)subBatch->batchWidth();
- // D = K.A.QT
- // dimm(K) = univ_tok_vocab*uni_embed_size
- // dim A = uni_embed_size*uni_embed_size
- // dim Q: uni_embed_size * total_merged_vocab_size
- // dim D = univ_tok_vocab * total_merged_vocab_size
- // note all above can be precombuted and serialized if A is not trainiabale and during decoding (TBD)
- // here we need to handle the mini-batch
- // extract raws corresponding to Xs in this mini batch from Q
- auto queryEmbeddings = rows(queryEmbed, subBatch->data());
- auto srcEmbeddings = rows(srcEmbed, subBatch->data()); // extract trainable src embeddings
- auto alpha = rows(ulrSharable, subBatch->data()); // extract sharable flags
- auto qt = dot(queryEmbeddings, ulrTransform, false, false); //A: transform embeddings based on similarity A : dimUlrEmb *dimUlrEmb
- auto sqrtDim=std::sqrt((float)queryEmbeddings->shape()[-1]);
- qt = qt/sqrtDim; // normalize accordin to embed size to avoid dot prodcut growing large in magintude with larger embeds sizes
- auto z = dot(qt, keyEmbed, false, true); // query-key similarity
- float dropProb = this->options_->get<float>("ulr-dropout", 0.0f); // default no dropout
- z = dropout(z, dropProb);
- float tau = this->options_->get<float>("ulr-softmax-temperature", 1.0f); // default no temperature
- // temperature in softmax is to control randomness of predictions
- // high temperature Softmax outputs are more close to each other
- // low temperatures the softmax become more similar to "hardmax"
- auto weights = softmax(z / tau); // assume default is dim=-1, what about temprature? - scaler ??
- auto chosenEmbeddings = dot(weights, uniEmbed); // AVERAGE
- auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings; // this should be elementwise broadcast
- auto batchEmbeddings = reshape(chosenEmbeddings_mix, { dimWords, dimBatch, dimEmb });
- auto batchMask = graph->constant({ dimWords, dimBatch, 1 },
- inits::from_vector(subBatch->mask()));
- return std::make_tuple(batchEmbeddings, batchMask);
- }
-public:
- EncoderBase(Ptr<Options> options)
- : options_(options),
- prefix_(options->get<std::string>("prefix", "encoder")),
- inference_(options->get<bool>("inference", false)),
- batchIndex_(options->get<size_t>("index", 0)) {}
-
- virtual Ptr<EncoderState> build(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>)
- = 0;
-
- template <typename T>
- T opt(const std::string& key) const {
- return options_->get<T>(key);
- }
+class EncoderBase : public EncoderDecoderLayerBase {
+public:
+ EncoderBase(Ptr<ExpressionGraph> graph, Ptr<Options> options) :
+ EncoderDecoderLayerBase(graph, options, "encoder", /*batchIndex=*/0,
+ options->get<float>("dropout-src", 0.0f),
+ options->get<bool>("embedding-fix-src", false)) {}
+
+ // @TODO: turn into an interface. Also see if we can get rid of the graph parameter.
+ virtual Ptr<EncoderState> build(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>) = 0;
virtual void clear() = 0;
};
diff --git a/src/models/encoder_classifier.h b/src/models/encoder_classifier.h
new file mode 100644
index 00000000..b551205d
--- /dev/null
+++ b/src/models/encoder_classifier.h
@@ -0,0 +1,227 @@
+#pragma once
+
+#include "marian.h"
+
+#include "models/encoder.h"
+#include "models/classifier.h"
+#include "models/model_base.h"
+#include "models/states.h"
+
+namespace marian {
+
+/**
+ * Combines sequence encoders with generic classifiers
+ * Can be used to train sequence classifiers like language detection, BERT-next-sentence-prediction etc.
+ * Already has support for multi-objective training.
+ *
+ * @TODO: this should probably be unified somehow with EncoderDecoder which could allow for deocder/classifier
+ * multi-objective training.
+ */
+class EncoderClassifierBase : public models::IModel {
+public:
+ virtual ~EncoderClassifierBase() {}
+
+ virtual void load(Ptr<ExpressionGraph> graph,
+ const std::string& name,
+ bool markedReloaded = true) override
+ = 0;
+
+ virtual void mmap(Ptr<ExpressionGraph> graph,
+ const void* ptr,
+ bool markedReloaded = true)
+ = 0;
+
+ virtual void save(Ptr<ExpressionGraph> graph,
+ const std::string& name,
+ bool saveTranslatorConfig = false) override
+ = 0;
+
+ virtual void clear(Ptr<ExpressionGraph> graph) override = 0;
+
+ virtual std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph>, Ptr<data::CorpusBatch>, bool) = 0;
+
+
+ virtual Logits build(Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override = 0;
+
+ virtual Logits build(Ptr<ExpressionGraph> graph,
+ Ptr<data::CorpusBatch> batch,
+ bool clearGraph = true) = 0;
+
+ virtual Ptr<Options> getOptions() = 0;
+};
+
+class EncoderClassifier : public EncoderClassifierBase {
+protected:
+ Ptr<Options> options_;
+
+ std::string prefix_;
+
+ std::vector<Ptr<EncoderBase>> encoders_;
+ std::vector<Ptr<ClassifierBase>> classifiers_;
+
+ bool inference_{false};
+
+ std::set<std::string> modelFeatures_;
+
+ Config::YamlNode getModelParameters() {
+ Config::YamlNode modelParams;
+ auto clone = options_->cloneToYamlNode();
+ for(auto& key : modelFeatures_)
+ modelParams[key] = clone[key];
+
+ if(options_->has("original-type"))
+ modelParams["type"] = clone["original-type"];
+
+ modelParams["version"] = buildVersion();
+ return modelParams;
+ }
+
+ std::string getModelParametersAsString() {
+ auto yaml = getModelParameters();
+ YAML::Emitter out;
+ cli::OutputYaml(yaml, out);
+ return std::string(out.c_str());
+ }
+
+public:
+ typedef data::Corpus dataset_type;
+
+ // @TODO: lots of code-duplication with EncoderDecoder
+ EncoderClassifier(Ptr<Options> options)
+ : options_(options),
+ prefix_(options->get<std::string>("prefix", "")),
+ inference_(options->get<bool>("inference", false)) {
+ modelFeatures_ = {"type",
+ "dim-vocabs",
+ "dim-emb",
+ "dim-rnn",
+ "enc-cell",
+ "enc-type",
+ "enc-cell-depth",
+ "enc-depth",
+ "dec-depth",
+ "dec-cell",
+ "dec-cell-base-depth",
+ "dec-cell-high-depth",
+ "skip",
+ "layer-normalization",
+ "right-left",
+ "input-types",
+ "special-vocab",
+ "tied-embeddings",
+ "tied-embeddings-src",
+ "tied-embeddings-all"};
+
+ modelFeatures_.insert("transformer-heads");
+ modelFeatures_.insert("transformer-no-projection");
+ modelFeatures_.insert("transformer-dim-ffn");
+ modelFeatures_.insert("transformer-ffn-depth");
+ modelFeatures_.insert("transformer-ffn-activation");
+ modelFeatures_.insert("transformer-dim-aan");
+ modelFeatures_.insert("transformer-aan-depth");
+ modelFeatures_.insert("transformer-aan-activation");
+ modelFeatures_.insert("transformer-aan-nogate");
+ modelFeatures_.insert("transformer-preprocess");
+ modelFeatures_.insert("transformer-postprocess");
+ modelFeatures_.insert("transformer-postprocess-emb");
+ modelFeatures_.insert("transformer-decoder-autoreg");
+ modelFeatures_.insert("transformer-tied-layers");
+ modelFeatures_.insert("transformer-guided-alignment-layer");
+ modelFeatures_.insert("transformer-train-position-embeddings");
+
+ modelFeatures_.insert("bert-train-type-embeddings");
+ modelFeatures_.insert("bert-type-vocab-size");
+
+ modelFeatures_.insert("ulr");
+ modelFeatures_.insert("ulr-trainable-transformation");
+ modelFeatures_.insert("ulr-dim-emb");
+ modelFeatures_.insert("lemma-dim-emb");
+ }
+
+ virtual Ptr<Options> getOptions() override { return options_; }
+
+ std::vector<Ptr<EncoderBase>>& getEncoders() { return encoders_; }
+ std::vector<Ptr<ClassifierBase>>& getClassifiers() { return classifiers_; }
+
+ void push_back(Ptr<EncoderBase> encoder) { encoders_.push_back(encoder); }
+ void push_back(Ptr<ClassifierBase> classifier) { classifiers_.push_back(classifier); }
+
+ void load(Ptr<ExpressionGraph> graph,
+ const std::string& name,
+ bool markedReloaded) override {
+ graph->load(name, markedReloaded && !opt<bool>("ignore-model-config", false));
+ }
+
+ void mmap(Ptr<ExpressionGraph> graph,
+ const void* ptr,
+ bool markedReloaded) override {
+ graph->mmap(ptr, markedReloaded && !opt<bool>("ignore-model-config", false));
+ }
+
+ void save(Ptr<ExpressionGraph> graph,
+ const std::string& name,
+ bool /*saveModelConfig*/) override {
+ LOG(info, "Saving model weights and runtime parameters to {}", name);
+ graph->save(name , getModelParametersAsString());
+ }
+
+ void clear(Ptr<ExpressionGraph> graph) override {
+ graph->clear();
+
+ for(auto& enc : encoders_)
+ enc->clear();
+ for(auto& cls : classifiers_)
+ cls->clear();
+ }
+
+ template <typename T>
+ T opt(const std::string& key) {
+ return options_->get<T>(key);
+ }
+
+ template <typename T>
+ T opt(const std::string& key, const T& def) {
+ return options_->get<T>(key, def);
+ }
+
+ template <typename T>
+ void set(std::string key, T value) {
+ options_->set(key, value);
+ }
+
+ /*********************************************************************/
+
+ virtual std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, bool clearGraph) override {
+ if(clearGraph)
+ clear(graph);
+
+ std::vector<Ptr<EncoderState>> encoderStates;
+ for(auto& encoder : encoders_)
+ encoderStates.push_back(encoder->build(graph, batch));
+
+ std::vector<Ptr<ClassifierState>> classifierStates;
+ for(auto& classifier : classifiers_)
+ classifierStates.push_back(classifier->apply(graph, batch, encoderStates));
+
+ return classifierStates;
+ }
+
+ virtual Logits build(Ptr<ExpressionGraph> graph,
+ Ptr<data::CorpusBatch> batch,
+ bool clearGraph = true) override {
+ auto states = apply(graph, batch, clearGraph);
+ // returns raw logits
+ return Logits(states[0]->getLogProbs());
+ }
+
+ virtual Logits build(Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override {
+ auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
+ return build(graph, corpusBatch, clearGraph);
+ }
+};
+
+} // namespace marian
diff --git a/src/models/encoder_decoder.cpp b/src/models/encoder_decoder.cpp
index 3edb7b33..09856a9c 100755
--- a/src/models/encoder_decoder.cpp
+++ b/src/models/encoder_decoder.cpp
@@ -1,32 +1,39 @@
-#include "encoder_decoder.h"
+#include "models/encoder_decoder.h"
#include "common/cli_helper.h"
+#include "common/filesystem.h"
#include "common/version.h"
namespace marian {
-EncoderDecoder::EncoderDecoder(Ptr<Options> options)
- : options_(options),
+EncoderDecoder::EncoderDecoder(Ptr<ExpressionGraph> graph, Ptr<Options> options)
+ : LayerBase(graph, options),
prefix_(options->get<std::string>("prefix", "")),
inference_(options->get<bool>("inference", false)) {
- modelFeatures_ = {"type",
- "dim-vocabs",
- "dim-emb",
- "dim-rnn",
- "enc-cell",
- "enc-type",
- "enc-cell-depth",
- "enc-depth",
- "dec-depth",
- "dec-cell",
- "dec-cell-base-depth",
- "dec-cell-high-depth",
- "skip",
- "layer-normalization",
- "right-left",
- "special-vocab",
- "tied-embeddings",
- "tied-embeddings-src",
- "tied-embeddings-all"};
+
+ std::vector<std::string> encoderDecoderModelFeatures =
+ {"type",
+ "dim-vocabs",
+ "dim-emb",
+ "dim-rnn",
+ "enc-cell",
+ "enc-type",
+ "enc-cell-depth",
+ "enc-depth",
+ "dec-depth",
+ "dec-cell",
+ "dec-cell-base-depth",
+ "dec-cell-high-depth",
+ "skip",
+ "layer-normalization",
+ "right-left",
+ "input-types",
+ "special-vocab",
+ "tied-embeddings",
+ "tied-embeddings-src",
+ "tied-embeddings-all"};
+
+ for(auto feature : encoderDecoderModelFeatures)
+ modelFeatures_.insert(feature);
modelFeatures_.insert("transformer-heads");
modelFeatures_.insert("transformer-no-projection");
@@ -43,6 +50,15 @@ EncoderDecoder::EncoderDecoder(Ptr<Options> options)
modelFeatures_.insert("transformer-decoder-autoreg");
modelFeatures_.insert("transformer-tied-layers");
modelFeatures_.insert("transformer-guided-alignment-layer");
+ modelFeatures_.insert("transformer-train-position-embeddings");
+
+ modelFeatures_.insert("bert-train-type-embeddings");
+ modelFeatures_.insert("bert-type-vocab-size");
+
+ modelFeatures_.insert("ulr");
+ modelFeatures_.insert("ulr-trainable-transformation");
+ modelFeatures_.insert("ulr-dim-emb");
+ modelFeatures_.insert("lemma-dim-emb");
}
std::vector<Ptr<EncoderBase>>& EncoderDecoder::getEncoders() {
@@ -63,18 +79,40 @@ void EncoderDecoder::push_back(Ptr<DecoderBase> decoder) {
void EncoderDecoder::createDecoderConfig(const std::string& name) {
Config::YamlNode decoder;
- decoder["models"] = std::vector<std::string>({name});
- decoder["vocabs"] = options_->get<std::vector<std::string>>("vocabs");
+
+ if(options_->get<bool>("relative-paths")) {
+ decoder["relative-paths"] = true;
+ // we can safely use a bare model file name here, because the config file is created in the same
+ // directory as the model file
+ auto modelFileName = filesystem::Path{name}.filename().string();
+ decoder["models"] = std::vector<std::string>({modelFileName});
+
+ // create relative paths to vocabs with regard to saved model checkpoint
+ auto dirPath = filesystem::Path{name}.parentPath();
+ std::vector<std::string> relativeVocabs;
+ const auto& vocabs = options_->get<std::vector<std::string>>("vocabs");
+ std::transform(
+ vocabs.begin(),
+ vocabs.end(),
+ std::back_inserter(relativeVocabs),
+ [&](const std::string& p) -> std::string {
+ return filesystem::relative(filesystem::Path{p}, dirPath).string();
+ });
+
+ decoder["vocabs"] = relativeVocabs;
+ } else {
+ decoder["relative-paths"] = false;
+ decoder["models"] = std::vector<std::string>({name});
+ decoder["vocabs"] = options_->get<std::vector<std::string>>("vocabs");
+ }
+
decoder["beam-size"] = opt<size_t>("beam-size");
decoder["normalize"] = opt<float>("normalize");
decoder["word-penalty"] = opt<float>("word-penalty");
decoder["mini-batch"] = opt<size_t>("valid-mini-batch");
decoder["maxi-batch"] = opt<size_t>("valid-mini-batch") > 1 ? 100 : 1;
- decoder["maxi-batch-sort"]
- = opt<size_t>("valid-mini-batch") > 1 ? "src" : "none";
-
- decoder["relative-paths"] = false;
+ decoder["maxi-batch-sort"] = opt<size_t>("valid-mini-batch") > 1 ? "src" : "none";
io::OutputFileStream out(name + ".decoder.yml");
out << decoder;
@@ -82,11 +120,12 @@ void EncoderDecoder::createDecoderConfig(const std::string& name) {
Config::YamlNode EncoderDecoder::getModelParameters() {
Config::YamlNode modelParams;
+ auto clone = options_->cloneToYamlNode();
for(auto& key : modelFeatures_)
- modelParams[key] = options_->getYaml()[key];
+ modelParams[key] = clone[key];
if(options_->has("original-type"))
- modelParams["type"] = options_->getYaml()["original-type"];
+ modelParams["type"] = clone["original-type"];
modelParams["version"] = buildVersion();
return modelParams;
@@ -149,18 +188,17 @@ Ptr<DecoderState> EncoderDecoder::startState(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> EncoderDecoder::step(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state,
- const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
- const std::vector<IndexType>& embIndices, // [beamIndex * activeBatchSize + batchIndex]
- int dimBatch,
+ const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
+ const Words& words, // [beamIndex * activeBatchSize + batchIndex]
+ const std::vector<IndexType>& batchIndices, // [batchIndex]
int beamSize) {
// create updated state that reflects reordering and dropping of hypotheses
- state = hypIndices.empty() ? state : state->select(hypIndices, beamSize);
+ state = hypIndices.empty() ? state : state->select(hypIndices, batchIndices, beamSize);
- // Fill stte with embeddings based on last prediction
- decoders_[0]->embeddingsFromPrediction(
- graph, state, embIndices, dimBatch, beamSize);
+ // Fill state with embeddings based on last prediction
+ decoders_[0]->embeddingsFromPrediction(graph, state, words, (int) batchIndices.size(), beamSize);
auto nextState = decoders_[0]->step(graph, state);
-
+
return nextState;
}
@@ -177,12 +215,12 @@ Ptr<DecoderState> EncoderDecoder::stepAll(Ptr<ExpressionGraph> graph,
decoders_[0]->embeddingsFromBatch(graph, state, batch);
auto nextState = decoders_[0]->step(graph, state);
nextState->setTargetMask(state->getTargetMask());
- nextState->setTargetIndices(state->getTargetIndices());
+ nextState->setTargetWords(state->getTargetWords());
return nextState;
}
-Expr EncoderDecoder::build(Ptr<ExpressionGraph> graph,
+Logits EncoderDecoder::build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph) {
auto state = stepAll(graph, batch, clearGraph);
@@ -191,7 +229,7 @@ Expr EncoderDecoder::build(Ptr<ExpressionGraph> graph,
return state->getLogProbs();
}
-Expr EncoderDecoder::build(Ptr<ExpressionGraph> graph,
+Logits EncoderDecoder::build(Ptr<ExpressionGraph> graph,
Ptr<data::Batch> batch,
bool clearGraph) {
auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
diff --git a/src/models/encoder_decoder.h b/src/models/encoder_decoder.h
index 5818c237..92c1647f 100644..100755
--- a/src/models/encoder_decoder.h
+++ b/src/models/encoder_decoder.h
@@ -2,15 +2,16 @@
#include "marian.h"
-#include "decoder.h"
-#include "encoder.h"
-#include "model_base.h"
-#include "states.h"
+#include "models/decoder.h"
+#include "models/encoder.h"
+#include "models/model_base.h"
+#include "models/states.h"
namespace marian {
-class EncoderDecoderBase : public models::ModelBase {
+class IEncoderDecoder : public models::IModel {
public:
+ virtual ~IEncoderDecoder() {}
virtual void load(Ptr<ExpressionGraph> graph,
const std::string& name,
bool markedReloaded = true) override
@@ -28,32 +29,29 @@ public:
virtual void clear(Ptr<ExpressionGraph> graph) override = 0;
- virtual Expr build(Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true) override
- = 0;
+ virtual Logits build(Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override = 0;
+
+ virtual Logits build(Ptr<ExpressionGraph> graph,
+ Ptr<data::CorpusBatch> batch,
+ bool clearGraph = true) = 0;
virtual Ptr<DecoderState> startState(Ptr<ExpressionGraph> graph,
- Ptr<data::CorpusBatch> batch)
- = 0;
+ Ptr<data::CorpusBatch> batch) = 0;
virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state,
- const std::vector<IndexType>& hypIndices,
- const std::vector<IndexType>& embIndices,
- int dimBatch,
+ const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
+ const Words& words, // [beamIndex * activeBatchSize + batchIndex]
+ const std::vector<IndexType>& batchIndices, // [batchIndex]
int beamSize)
= 0;
- virtual Expr build(Ptr<ExpressionGraph> graph,
- Ptr<data::CorpusBatch> batch,
- bool clearGraph = true)
- = 0;
-
virtual Ptr<Options> getOptions() = 0;
virtual void setShortlistGenerator(
- Ptr<data::ShortlistGenerator> shortlistGenerator)
+ Ptr<const data::ShortlistGenerator> shortlistGenerator)
= 0;
virtual Ptr<data::Shortlist> getShortlist() = 0;
@@ -61,18 +59,16 @@ public:
virtual data::SoftAlignment getAlignment() = 0;
};
-class EncoderDecoder : public EncoderDecoderBase {
+class EncoderDecoder : public IEncoderDecoder, public LayerBase {
protected:
- Ptr<Options> options_;
- Ptr<data::ShortlistGenerator> shortlistGenerator_;
+ Ptr<const data::ShortlistGenerator> shortlistGenerator_;
- std::string prefix_;
+ const std::string prefix_;
+ const bool inference_{ false };
std::vector<Ptr<EncoderBase>> encoders_;
std::vector<Ptr<DecoderBase>> decoders_;
- bool inference_{false};
-
std::set<std::string> modelFeatures_;
Config::YamlNode getModelParameters();
@@ -83,7 +79,7 @@ protected:
public:
typedef data::Corpus dataset_type;
- EncoderDecoder(Ptr<Options> options);
+ EncoderDecoder(Ptr<ExpressionGraph> graph, Ptr<Options> options);
virtual Ptr<Options> getOptions() override { return options_; }
@@ -125,7 +121,7 @@ public:
}
virtual void setShortlistGenerator(
- Ptr<data::ShortlistGenerator> shortlistGenerator) override {
+ Ptr<const data::ShortlistGenerator> shortlistGenerator) override {
shortlistGenerator_ = shortlistGenerator;
};
@@ -133,13 +129,15 @@ public:
return decoders_[0]->getShortlist();
};
+ // convert alignment tensors that live GPU-side into a CPU-side vector of vectors
virtual data::SoftAlignment getAlignment() override {
- data::SoftAlignment aligns;
- for(auto aln : decoders_[0]->getAlignments()) {
- aligns.push_back({});
- aln->val()->get(aligns.back());
+ data::SoftAlignment softAlignments;
+ auto alignments = decoders_[0]->getAlignments(); // [tgt index][beam depth, max src length, batch size, 1]
+ for(auto alignment : alignments) { // [beam depth, max src length, batch size, 1]
+ softAlignments.push_back({});
+ alignment->val()->get(softAlignments.back());
}
- return aligns;
+ return softAlignments; // [tgt index][beam depth * max src length * batch size]
};
/*********************************************************************/
@@ -150,21 +148,21 @@ public:
virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state,
const std::vector<IndexType>& hypIndices,
- const std::vector<IndexType>& embIndices,
- int dimBatch,
+ const Words& words,
+ const std::vector<IndexType>& batchIndices,
int beamSize) override;
virtual Ptr<DecoderState> stepAll(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
bool clearGraph = true);
- virtual Expr build(Ptr<ExpressionGraph> graph,
- Ptr<data::CorpusBatch> batch,
- bool clearGraph = true) override;
+ virtual Logits build(Ptr<ExpressionGraph> graph,
+ Ptr<data::CorpusBatch> batch,
+ bool clearGraph = true) override;
- virtual Expr build(Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true) override;
+ virtual Logits build(Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true) override;
};
} // namespace marian
diff --git a/src/models/model_base.h b/src/models/model_base.h
index 47039841..5f76b380 100644
--- a/src/models/model_base.h
+++ b/src/models/model_base.h
@@ -2,6 +2,8 @@
#include <string>
#include "marian.h"
+#include "layers/loss.h"
+#include "layers/generic.h"
namespace marian {
namespace models {
@@ -15,7 +17,8 @@ YAML_REGISTER_TYPE(marian::models::usage, int)
namespace marian {
namespace models {
-class ModelBase {
+// model = input -> predictions
+class IModel {
public:
virtual void load(Ptr<ExpressionGraph>,
const std::string&,
@@ -26,9 +29,32 @@ public:
bool saveTranslatorConfig = false)
= 0;
- virtual Expr build(Ptr<ExpressionGraph> graph,
- Ptr<data::Batch> batch,
- bool clearGraph = true)
+ virtual Logits build(Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true)
+ = 0;
+
+ virtual void clear(Ptr<ExpressionGraph> graph) = 0;
+};
+
+// criterion = (input, reference) -> loss
+// @TODO: Is there a better name?
+class ICriterionFunction {
+public:
+ virtual ~ICriterionFunction() {}
+
+ virtual void load(Ptr<ExpressionGraph>,
+ const std::string&,
+ bool markReloaded = true)
+ = 0;
+ virtual void save(Ptr<ExpressionGraph>,
+ const std::string&,
+ bool saveTranslatorConfig = false)
+ = 0;
+
+ virtual Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph,
+ Ptr<data::Batch> batch,
+ bool clearGraph = true)
= 0;
virtual void clear(Ptr<ExpressionGraph> graph) = 0;
diff --git a/src/models/model_factory.cpp b/src/models/model_factory.cpp
index d42f07c8..3d9e1e27 100644..100755
--- a/src/models/model_factory.cpp
+++ b/src/models/model_factory.cpp
@@ -1,7 +1,9 @@
#include "marian.h"
-#include "models/encoder_decoder.h"
#include "models/model_factory.h"
+#include "models/encoder_decoder.h"
+#include "models/encoder_classifier.h"
+#include "models/bert.h"
#include "models/costs.h"
@@ -24,97 +26,132 @@
namespace marian {
namespace models {
-Ptr<EncoderBase> EncoderFactory::construct() {
+Ptr<EncoderBase> EncoderFactory::construct(Ptr<ExpressionGraph> graph) {
if(options_->get<std::string>("type") == "s2s")
- return New<EncoderS2S>(options_);
+ return New<EncoderS2S>(graph, options_);
#ifdef CUDNN
if(options_->get<std::string>("type") == "char-s2s")
- return New<CharS2SEncoder>(options_);
+ return New<CharS2SEncoder>(graph, options_);
#endif
if(options_->get<std::string>("type") == "transformer")
- // return New<EncoderTransformer>(options_);
- return NewEncoderTransformer(options_);
+ return NewEncoderTransformer(graph, options_);
+
+ if(options_->get<std::string>("type") == "bert-encoder")
+ return New<BertEncoder>(graph, options_);
ABORT("Unknown encoder type");
}
-Ptr<DecoderBase> DecoderFactory::construct() {
+Ptr<DecoderBase> DecoderFactory::construct(Ptr<ExpressionGraph> graph) {
if(options_->get<std::string>("type") == "s2s")
- return New<DecoderS2S>(options_);
+ return New<DecoderS2S>(graph, options_);
if(options_->get<std::string>("type") == "transformer")
- // return New<DecoderTransformer>(options_);
- return NewDecoderTransformer(options_);
+ return NewDecoderTransformer(graph, options_);
ABORT("Unknown decoder type");
}
-Ptr<ModelBase> EncoderDecoderFactory::construct() {
- Ptr<EncoderDecoder> encdec;
+Ptr<ClassifierBase> ClassifierFactory::construct(Ptr<ExpressionGraph> graph) {
+ if(options_->get<std::string>("type") == "bert-masked-lm")
+ return New<BertMaskedLM>(graph, options_);
+ else if(options_->get<std::string>("type") == "bert-classifier")
+ return New<BertClassifier>(graph, options_);
+ else
+ ABORT("Unknown classifier type");
+}
+Ptr<IModel> EncoderDecoderFactory::construct(Ptr<ExpressionGraph> graph) {
+ Ptr<EncoderDecoder> encdec;
if(options_->get<std::string>("type") == "amun")
- encdec = New<Amun>(options_);
- if(options_->get<std::string>("type") == "nematus")
- encdec = New<Nematus>(options_);
-
- if(!encdec)
- encdec = New<EncoderDecoder>(options_);
+ encdec = New<Amun>(graph, options_);
+ else if(options_->get<std::string>("type") == "nematus")
+ encdec = New<Nematus>(graph, options_);
+ else
+ encdec = New<EncoderDecoder>(graph, options_);
for(auto& ef : encoders_)
- encdec->push_back(ef(options_).construct());
+ encdec->push_back(ef(options_).construct(graph));
for(auto& df : decoders_)
- encdec->push_back(df(options_).construct());
+ encdec->push_back(df(options_).construct(graph));
+
+ return encdec;
+}
+
+Ptr<IModel> EncoderClassifierFactory::construct(Ptr<ExpressionGraph> graph) {
+ Ptr<EncoderClassifier> enccls;
+ if(options_->get<std::string>("type") == "bert")
+ enccls = New<BertEncoderClassifier>(options_);
+ else if(options_->get<std::string>("type") == "bert-classifier")
+ enccls = New<BertEncoderClassifier>(options_);
+ else
+ enccls = New<EncoderClassifier>(options_);
+
+ for(auto& ef : encoders_)
+ enccls->push_back(ef(options_).construct(graph));
+
+ for(auto& cf : classifiers_)
+ enccls->push_back(cf(options_).construct(graph));
- return add_cost(encdec, options_);
+ return enccls;
}
-Ptr<ModelBase> by_type(std::string type, usage use, Ptr<Options> options) {
+Ptr<IModel> createBaseModelByType(std::string type, usage use, Ptr<Options> options) {
+ Ptr<ExpressionGraph> graph = nullptr; // graph unknown at this stage
// clang-format off
if(type == "s2s" || type == "amun" || type == "nematus") {
- return models::encoder_decoder()(options)
- ("usage", use)
- ("original-type", type)
- .push_back(models::encoder()("type", "s2s"))
- .push_back(models::decoder()("type", "s2s"))
- .construct();
+ return models::encoder_decoder(options->with(
+ "usage", use,
+ "original-type", type))
+ .push_back(models::encoder()("type", "s2s"))
+ .push_back(models::decoder()("type", "s2s"))
+ .construct(graph);
}
- if(type == "transformer") {
- return models::encoder_decoder()(options)
- ("usage", use)
+ else if(type == "transformer") {
+#if 1
+ auto newOptions = options->with("usage", use);
+ auto res = New<EncoderDecoder>(graph, newOptions);
+ res->push_back(New<EncoderTransformer>(graph, newOptions->with("type", "transformer")));
+ res->push_back(New<DecoderTransformer>(graph, newOptions->with("type", "transformer")));
+ return res;
+#else
+ return models::encoder_decoder(options->with(
+ "usage", use))
.push_back(models::encoder()("type", "transformer"))
.push_back(models::decoder()("type", "transformer"))
- .construct();
+ .construct(graph);
+#endif
}
- if(type == "transformer_s2s") {
+ else if(type == "transformer_s2s") {
return models::encoder_decoder()(options)
("usage", use)
("original-type", type)
- .push_back(models::encoder()("type", "transformer"))
- .push_back(models::decoder()("type", "s2s"))
- .construct();
+ .push_back(models::encoder()("type", "transformer"))
+ .push_back(models::decoder()("type", "s2s"))
+ .construct(graph);
}
- if(type == "lm") {
+ else if(type == "lm") {
auto idx = options->has("index") ? options->get<size_t>("index") : 0;
std::vector<int> dimVocabs = options->get<std::vector<int>>("dim-vocabs");
int vocab = dimVocabs[0];
dimVocabs.resize(idx + 1);
std::fill(dimVocabs.begin(), dimVocabs.end(), vocab);
- return models::encoder_decoder()(options)
- ("usage", use)
- ("type", "s2s")
- ("original-type", type)
- .push_back(models::decoder()
- ("index", idx)
- ("dim-vocabs", dimVocabs))
- .construct();
+ return models::encoder_decoder(options->with(
+ "usage", use,
+ "type", "s2s",
+ "original-type", type))
+ .push_back(models::decoder()
+ ("index", idx)
+ ("dim-vocabs", dimVocabs))
+ .construct(graph);
}
- if(type == "multi-s2s") {
+ else if(type == "multi-s2s") {
size_t numEncoders = 2;
auto ms2sFactory = models::encoder_decoder()(options)
("usage", use)
@@ -128,10 +165,10 @@ Ptr<ModelBase> by_type(std::string type, usage use, Ptr<Options> options) {
ms2sFactory.push_back(models::decoder()("index", numEncoders));
- return ms2sFactory.construct();
+ return ms2sFactory.construct(graph);
}
- if(type == "shared-multi-s2s") {
+ else if(type == "shared-multi-s2s") {
size_t numEncoders = 2;
auto ms2sFactory = models::encoder_decoder()(options)
("usage", use)
@@ -145,10 +182,10 @@ Ptr<ModelBase> by_type(std::string type, usage use, Ptr<Options> options) {
ms2sFactory.push_back(models::decoder()("index", numEncoders));
- return ms2sFactory.construct();
+ return ms2sFactory.construct(graph);
}
- if(type == "multi-transformer") {
+ else if(type == "multi-transformer") {
size_t numEncoders = 2;
auto mtransFactory = models::encoder_decoder()(options)
("usage", use)
@@ -161,10 +198,10 @@ Ptr<ModelBase> by_type(std::string type, usage use, Ptr<Options> options) {
}
mtransFactory.push_back(models::decoder()("index", numEncoders));
- return mtransFactory.construct();
+ return mtransFactory.construct(graph);
}
- if(type == "shared-multi-transformer") {
+ else if(type == "shared-multi-transformer") {
size_t numEncoders = 2;
auto mtransFactory = models::encoder_decoder()(options)
("usage", use)
@@ -177,10 +214,10 @@ Ptr<ModelBase> by_type(std::string type, usage use, Ptr<Options> options) {
}
mtransFactory.push_back(models::decoder()("index", numEncoders));
- return mtransFactory.construct();
+ return mtransFactory.construct(graph);
}
- if(type == "lm-transformer") {
+ else if(type == "lm-transformer") {
auto idx = options->has("index") ? options->get<size_t>("index") : 0;
std::vector<int> dimVocabs = options->get<std::vector<int>>("dim-vocabs");
int vocab = dimVocabs[0];
@@ -191,54 +228,120 @@ Ptr<ModelBase> by_type(std::string type, usage use, Ptr<Options> options) {
("usage", use)
("type", "transformer")
("original-type", type)
- .push_back(models::decoder()
- ("index", idx)
- ("dim-vocabs", dimVocabs))
- .construct();
+ .push_back(models::decoder()
+ ("index", idx)
+ ("dim-vocabs", dimVocabs))
+ .construct(graph);
}
-#ifdef COMPILE_EXAMPLES
- // @TODO: examples should be compiled optionally
- if(type == "mnist-ffnn") {
- auto mnist = New<MnistFeedForwardNet>(options);
- if(use == usage::scoring)
- return New<Scorer>(mnist, New<MNISTLogsoftmax>());
- else if(use == usage::training)
- return New<Trainer>(mnist, New<MNISTCrossEntropyCost>());
- else
- return mnist;
+ else if(type == "bert") { // for full BERT training
+ return models::encoder_classifier()(options) //
+ ("original-type", "bert") // so we can query this
+ ("usage", use) //
+ .push_back(models::encoder() //
+ ("type", "bert-encoder") // close to original transformer encoder
+ ("index", 0)) //
+ .push_back(models::classifier() //
+ ("prefix", "masked-lm") // prefix for parameter names
+ ("type", "bert-masked-lm") //
+ ("index", 0)) // multi-task learning with MaskedLM
+ .push_back(models::classifier() //
+ ("prefix", "next-sentence") // prefix for parameter names
+ ("type", "bert-classifier") //
+ ("index", 1)) // next sentence prediction
+ .construct(graph);
}
-#endif
+ else if(type == "bert-classifier") { // for BERT fine-tuning on non-BERT classification task
+ return models::encoder_classifier()(options) //
+ ("original-type", "bert-classifier") // so we can query this if needed
+ ("usage", use) //
+ .push_back(models::encoder() //
+ ("type", "bert-encoder") //
+ ("index", 0)) // close to original transformer encoder
+ .push_back(models::classifier() //
+ ("type", "bert-classifier") //
+ ("index", 1)) // next sentence prediction
+ .construct(graph);
+ }
+
+#ifdef COMPILE_EXAMPLES
+ else if(type == "mnist-ffnn")
+ return New<MnistFeedForwardNet>(options);
+#endif
#ifdef CUDNN
#ifdef COMPILE_EXAMPLES
- if(type == "mnist-lenet") {
- auto mnist = New<MnistLeNet>(options);
- if(use == usage::scoring)
- return New<Scorer>(mnist, New<MNISTLogsoftmax>());
- else if(use == usage::training)
- return New<Trainer>(mnist, New<MNISTCrossEntropyCost>());
- else
- return mnist;
- }
+ else if(type == "mnist-lenet")
+ return New<MnistLeNet>(options);
#endif
- if(type == "char-s2s") {
+ else if(type == "char-s2s") {
return models::encoder_decoder()(options)
("usage", use)
("original-type", type)
- .push_back(models::encoder()("type", "char-s2s"))
- .push_back(models::decoder()("type", "s2s"))
- .construct();
+ .push_back(models::encoder()("type", "char-s2s"))
+ .push_back(models::decoder()("type", "s2s"))
+ .construct(graph);
}
#endif
// clang-format on
- ABORT("Unknown model type: {}", type);
+ else
+ ABORT("Unknown model type: {}", type);
+}
+
+Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage use) {
+ std::string type = options->get<std::string>("type");
+ auto baseModel = createBaseModelByType(type, use, options);
+
+ // add (log)softmax if requested
+ if (use == usage::translation) {
+ if(std::dynamic_pointer_cast<EncoderDecoder>(baseModel)) {
+ if(options->get<bool>("output-sampling", false))
+ return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<GumbelSoftmaxStep>());
+ else
+ return New<Stepwise>(std::dynamic_pointer_cast<EncoderDecoder>(baseModel), New<LogSoftmaxStep>());
+ }
+#ifdef COMPILE_EXAMPLES
+ // note: 'usage::translation' here means 'inference'
+ else if (std::dynamic_pointer_cast<MnistFeedForwardNet>(baseModel))
+ return New<Scorer>(baseModel, New<MNISTLogsoftmax>());
+#ifdef CUDNN
+ else if (std::dynamic_pointer_cast<MnistLeNet>(baseModel))
+ return New<Scorer>(baseModel, New<MNISTLogsoftmax>());
+#endif
+#endif
+ else
+ ABORT("'usage' parameter 'translation' cannot be applied to model type: {}", type);
+ }
+ else if (use == usage::raw)
+ return baseModel;
+ else
+ ABORT("'Usage' parameter must be 'translation' or 'raw'");
}
-Ptr<ModelBase> from_options(Ptr<Options> options, usage use) {
+Ptr<ICriterionFunction> createCriterionFunctionFromOptions(Ptr<Options> options, usage use) {
std::string type = options->get<std::string>("type");
- return by_type(type, use, options);
+ auto baseModel = createBaseModelByType(type, use, options);
+
+ // add cost function
+ ABORT_IF(use != usage::training && use != usage::scoring, "'Usage' parameter must be 'training' or 'scoring'");
+ // note: usage::scoring means "score the loss function", hence it uses a Trainer (not Scorer, which is for decoding)
+ // @TODO: Should we define a new class that does not compute gradients?
+ if (std::dynamic_pointer_cast<EncoderDecoder>(baseModel))
+ return New<Trainer>(baseModel, New<EncoderDecoderCECost>(options));
+ else if (std::dynamic_pointer_cast<EncoderClassifier>(baseModel))
+ return New<Trainer>(baseModel, New<EncoderClassifierCECost>(options));
+#ifdef COMPILE_EXAMPLES
+ // @TODO: examples should be compiled optionally
+ else if (std::dynamic_pointer_cast<MnistFeedForwardNet>(baseModel))
+ return New<Trainer>(baseModel, New<MNISTCrossEntropyCost>());
+#ifdef CUDNN
+ else if (std::dynamic_pointer_cast<MnistLeNet>(baseModel))
+ return New<Trainer>(baseModel, New<MNISTCrossEntropyCost>());
+#endif
+#endif
+ else
+ ABORT("Criterion function unknown for model type: {}", type);
}
} // namespace models
diff --git a/src/models/model_factory.h b/src/models/model_factory.h
index 2ec7fe75..1840df8f 100644..100755
--- a/src/models/model_factory.h
+++ b/src/models/model_factory.h
@@ -4,37 +4,42 @@
#include "layers/factory.h"
#include "models/encoder_decoder.h"
+#include "models/encoder_classifier.h"
namespace marian {
namespace models {
class EncoderFactory : public Factory {
+ using Factory::Factory;
public:
- EncoderFactory(Ptr<ExpressionGraph> graph = nullptr) : Factory(graph) {}
-
- virtual Ptr<EncoderBase> construct();
+ virtual Ptr<EncoderBase> construct(Ptr<ExpressionGraph> graph);
};
typedef Accumulator<EncoderFactory> encoder;
class DecoderFactory : public Factory {
+ using Factory::Factory;
public:
- DecoderFactory(Ptr<ExpressionGraph> graph = nullptr) : Factory(graph) {}
-
- virtual Ptr<DecoderBase> construct();
+ virtual Ptr<DecoderBase> construct(Ptr<ExpressionGraph> graph);
};
typedef Accumulator<DecoderFactory> decoder;
+class ClassifierFactory : public Factory {
+ using Factory::Factory;
+public:
+ virtual Ptr<ClassifierBase> construct(Ptr<ExpressionGraph> graph);
+};
+
+typedef Accumulator<ClassifierFactory> classifier;
+
class EncoderDecoderFactory : public Factory {
+ using Factory::Factory;
private:
std::vector<encoder> encoders_;
std::vector<decoder> decoders_;
public:
- EncoderDecoderFactory(Ptr<ExpressionGraph> graph = nullptr)
- : Factory(graph) {}
-
Accumulator<EncoderDecoderFactory> push_back(encoder enc) {
encoders_.push_back(enc);
return Accumulator<EncoderDecoderFactory>(*this);
@@ -45,15 +50,37 @@ public:
return Accumulator<EncoderDecoderFactory>(*this);
}
- virtual Ptr<ModelBase> construct();
+ virtual Ptr<IModel> construct(Ptr<ExpressionGraph> graph);
};
typedef Accumulator<EncoderDecoderFactory> encoder_decoder;
-Ptr<ModelBase> by_type(std::string type, usage, Ptr<Options> options);
+class EncoderClassifierFactory : public Factory {
+ using Factory::Factory;
+private:
+ std::vector<encoder> encoders_;
+ std::vector<classifier> classifiers_;
+
+public:
+ Accumulator<EncoderClassifierFactory> push_back(encoder enc) {
+ encoders_.push_back(enc);
+ return Accumulator<EncoderClassifierFactory>(*this);
+ }
+
+ Accumulator<EncoderClassifierFactory> push_back(classifier cls) {
+ classifiers_.push_back(cls);
+ return Accumulator<EncoderClassifierFactory>(*this);
+ }
+
+ virtual Ptr<IModel> construct(Ptr<ExpressionGraph> graph);
+};
+
+typedef Accumulator<EncoderClassifierFactory> encoder_classifier;
+
+Ptr<IModel> createBaseModelByType(std::string type, usage, Ptr<Options> options);
-Ptr<ModelBase> from_options(Ptr<Options> options, usage);
+Ptr<IModel> createModelFromOptions(Ptr<Options> options, usage);
-Ptr<ModelBase> from_config(Ptr<Config> config, usage);
+Ptr<ICriterionFunction> createCriterionFunctionFromOptions(Ptr<Options> options, usage);
} // namespace models
} // namespace marian
diff --git a/src/models/model_task.h b/src/models/model_task.h
index 932fba62..96dfadd0 100644
--- a/src/models/model_task.h
+++ b/src/models/model_task.h
@@ -5,11 +5,12 @@
namespace marian {
struct ModelTask {
+ virtual ~ModelTask() {}
virtual void run() = 0;
};
struct ModelServiceTask {
- virtual void init() = 0;
+ virtual ~ModelServiceTask() {}
virtual std::string run(const std::string&) = 0;
};
} // namespace marian
diff --git a/src/models/nematus.h b/src/models/nematus.h
index 88e9854b..3c0a549e 100755
--- a/src/models/nematus.h
+++ b/src/models/nematus.h
@@ -8,17 +8,16 @@ namespace marian {
class Nematus : public EncoderDecoder {
public:
- template <class... Args>
- Nematus(Ptr<Options> options) : EncoderDecoder(options), nameMap_(createNameMap()) {
+ Nematus(Ptr<ExpressionGraph> graph, Ptr<Options> options) : EncoderDecoder(graph, options), nameMap_(createNameMap()) {
ABORT_IF(options_->get<std::string>("enc-type") != "bidirectional",
- "--type nematus does not currently support other encoder "
+ "--type nematus does not support other encoder "
"type than bidirectional, use --type s2s");
ABORT_IF(options_->get<std::string>("enc-cell") != "gru-nematus",
- "--type nematus does not currently support other rnn cells "
+ "--type nematus does not support other rnn cells "
"than gru-nematus, use --type s2s");
ABORT_IF(options_->get<std::string>("dec-cell") != "gru-nematus",
- "--type nematus does not currently support other rnn cells "
+ "--type nematus does not support other rnn cells "
"than gru-nematus, use --type s2s");
ABORT_IF(options_->get<int>("dec-cell-high-depth") > 1,
@@ -36,11 +35,16 @@ public:
for(auto it = ioItems.begin(); it != ioItems.end();) {
if(it->name == "decoder_c_tt") {
it = ioItems.erase(it);
+ } else if(it->name == "uidx") {
+ it = ioItems.erase(it);
+ } else if(it->name == "history_errs") {
+ it = ioItems.erase(it);
+ } else {
+ auto pair = nameMap_.find(it->name);
+ if(pair != nameMap_.end())
+ it->name = pair->second;
+ it++;
}
- auto pair = nameMap_.find(it->name);
- if(pair != nameMap_.end())
- it->name = pair->second;
- ++it;
}
// load items into the graph
graph->load(ioItems);
diff --git a/src/models/s2s.h b/src/models/s2s.h
index 92ba9a7d..8d12f96e 100755
--- a/src/models/s2s.h
+++ b/src/models/s2s.h
@@ -9,7 +9,9 @@
namespace marian {
class EncoderS2S : public EncoderBase {
+ using EncoderBase::EncoderBase;
public:
+ virtual ~EncoderS2S() {}
Expr applyEncoderRNN(Ptr<ExpressionGraph> graph,
Expr embeddings,
Expr mask,
@@ -34,7 +36,7 @@ public:
float dropoutRnn = inference_ ? 0 : opt<float>("dropout-rnn");
- auto rnnFw = rnn::rnn(graph) //
+ auto rnnFw = rnn::rnn() //
("type", opt<std::string>("enc-cell")) //
("direction", (int)forward) //
("dimInput", embeddings->shape()[-1]) //
@@ -44,7 +46,7 @@ public:
("skip", opt<bool>("skip"));
for(int i = 1; i <= first; ++i) {
- auto stacked = rnn::stacked_cell(graph);
+ auto stacked = rnn::stacked_cell();
for(int j = 1; j <= opt<int>("enc-cell-depth"); ++j) {
std::string paramPrefix = prefix_ + "_bi";
if(i > 1)
@@ -53,14 +55,14 @@ public:
paramPrefix += "_cell" + std::to_string(j);
bool transition = (j > 1);
- stacked.push_back(rnn::cell(graph) //
+ stacked.push_back(rnn::cell() //
("prefix", paramPrefix) //
("transition", transition));
}
rnnFw.push_back(stacked);
}
- auto rnnBw = rnn::rnn(graph) //
+ auto rnnBw = rnn::rnn() //
("type", opt<std::string>("enc-cell")) //
("direction", (int)backward) //
("dimInput", embeddings->shape()[-1]) //
@@ -70,7 +72,7 @@ public:
("skip", opt<bool>("skip"));
for(int i = 1; i <= first; ++i) {
- auto stacked = rnn::stacked_cell(graph);
+ auto stacked = rnn::stacked_cell();
for(int j = 1; j <= opt<int>("enc-cell-depth"); ++j) {
std::string paramPrefix = prefix_ + "_bi_r";
if(i > 1)
@@ -79,15 +81,15 @@ public:
paramPrefix += "_cell" + std::to_string(j);
bool transition = (j > 1);
- stacked.push_back(rnn::cell(graph) //
+ stacked.push_back(rnn::cell() //
("prefix", paramPrefix) //
("transition", transition));
}
rnnBw.push_back(stacked);
}
- auto context = concatenate({rnnFw->transduce(embeddings, mask),
- rnnBw->transduce(embeddings, mask)},
+ auto context = concatenate({rnnFw.construct(graph)->transduce(embeddings, mask),
+ rnnBw.construct(graph)->transduce(embeddings, mask)},
/*axis =*/ -1);
if(second > 0) {
@@ -95,7 +97,7 @@ public:
// previous bidirectional RNN through multiple layers
// construct RNN first
- auto rnnUni = rnn::rnn(graph) //
+ auto rnnUni = rnn::rnn() //
("type", opt<std::string>("enc-cell")) //
("dimInput", 2 * opt<int>("dim-rnn")) //
("dimState", opt<int>("dim-rnn")) //
@@ -104,68 +106,32 @@ public:
("skip", opt<bool>("skip"));
for(int i = first + 1; i <= second + first; ++i) {
- auto stacked = rnn::stacked_cell(graph);
+ auto stacked = rnn::stacked_cell();
for(int j = 1; j <= opt<int>("enc-cell-depth"); ++j) {
std::string paramPrefix = prefix_ + "_l" + std::to_string(i) + "_cell"
+ std::to_string(j);
- stacked.push_back(rnn::cell(graph)("prefix", paramPrefix));
+ stacked.push_back(rnn::cell()("prefix", paramPrefix));
}
rnnUni.push_back(stacked);
}
// transduce context to new context
- context = rnnUni->transduce(context);
+ context = rnnUni.construct(graph)->transduce(context);
}
return context;
}
- Expr buildSourceEmbeddings(Ptr<ExpressionGraph> graph) {
- // create source embeddings
- int dimVoc = opt<std::vector<int>>("dim-vocabs")[batchIndex_];
- int dimEmb = opt<int>("dim-emb");
-
- auto embFactory = embedding(graph) //
- ("dimVocab", dimVoc) //
- ("dimEmb", dimEmb);
-
- if(opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all"))
- embFactory("prefix", "Wemb");
- else
- embFactory("prefix", prefix_ + "_Wemb");
-
- if(options_->has("embedding-fix-src"))
- embFactory("fixed", opt<bool>("embedding-fix-src"));
-
- if(options_->has("embedding-vectors")) {
- auto embFiles = opt<std::vector<std::string>>("embedding-vectors");
- embFactory //
- ("embFile", embFiles[batchIndex_]) //
- ("normalization", opt<bool>("embedding-normalization"));
- }
-
- return embFactory.construct();
- }
-
- EncoderS2S(Ptr<Options> options) : EncoderBase(options) {}
+ //EncoderS2S(Ptr<Options> options) : EncoderBase(options) {}
virtual Ptr<EncoderState> build(Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch) override {
- auto embeddings = buildSourceEmbeddings(graph);
-
+ graph_ = graph;
// select embeddings that occur in the batch
- Expr batchEmbeddings, batchMask;
- std::tie(batchEmbeddings, batchMask)
- = EncoderBase::lookup(graph, embeddings, batch);
-
- // apply dropout over source words
- float dropProb = inference_ ? 0 : opt<float>("dropout-src");
- if(dropProb) {
- int srcWords = batchEmbeddings->shape()[-3];
- batchEmbeddings = dropout(batchEmbeddings, dropProb, {srcWords, 1, 1});
- }
+ Expr batchEmbeddings, batchMask; std::tie
+ (batchEmbeddings, batchMask) = getEmbeddingLayer()->apply((*batch)[batchIndex_]);
Expr context = applyEncoderRNN(
- graph, batchEmbeddings, batchMask, opt<std::string>("enc-type"));
+ graph_, batchEmbeddings, batchMask, opt<std::string>("enc-type"));
return New<EncoderState>(context, batchMask, batch);
}
@@ -174,14 +140,18 @@ public:
};
class DecoderS2S : public DecoderBase {
+ using DecoderBase::DecoderBase;
private:
Ptr<rnn::RNN> rnn_;
Ptr<mlp::MLP> output_;
+ int lastDimBatch_{-1}; // monitor dimBatch to take into account batch-pruning during decoding
+ // may require to lazily rebuild decoder RNN
Ptr<rnn::RNN> constructDecoderRNN(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state) {
float dropoutRnn = inference_ ? 0 : opt<float>("dropout-rnn");
- auto rnn = rnn::rnn(graph) //
+
+ auto rnn = rnn::rnn() //
("type", opt<std::string>("dec-cell")) //
("dimInput", opt<int>("dim-emb")) //
("dimState", opt<int>("dim-rnn")) //
@@ -197,11 +167,11 @@ private:
size_t decoderHighDepth = opt<size_t>("dec-cell-high-depth");
// setting up conditional (transitional) cell
- auto baseCell = rnn::stacked_cell(graph);
+ auto baseCell = rnn::stacked_cell();
for(size_t i = 1; i <= decoderBaseDepth; ++i) {
bool transition = (i > 2);
auto paramPrefix = prefix_ + "_cell" + std::to_string(i);
- baseCell.push_back(rnn::cell(graph) //
+ baseCell.push_back(rnn::cell() //
("prefix", paramPrefix) //
("final", i > 1) //
("transition", transition));
@@ -213,8 +183,7 @@ private:
auto encState = state->getEncoderStates()[k];
- baseCell.push_back(
- rnn::attention(graph)("prefix", attPrefix).set_state(encState));
+ baseCell.push_back(rnn::attention()("prefix", attPrefix).set_state(encState));
}
}
}
@@ -224,24 +193,22 @@ private:
// Add more cells to RNN (stacked RNN)
for(size_t i = 2; i <= decoderLayers; ++i) {
// deep transition
- auto highCell = rnn::stacked_cell(graph);
+ auto highCell = rnn::stacked_cell();
for(size_t j = 1; j <= decoderHighDepth; j++) {
auto paramPrefix
= prefix_ + "_l" + std::to_string(i) + "_cell" + std::to_string(j);
- highCell.push_back(rnn::cell(graph)("prefix", paramPrefix));
+ highCell.push_back(rnn::cell()("prefix", paramPrefix));
}
// Add cell to RNN (more layers)
rnn.push_back(highCell);
}
- return rnn.construct();
+ return rnn.construct(graph);
}
public:
- DecoderS2S(Ptr<Options> options) : DecoderBase(options) {}
-
virtual Ptr<DecoderState> startState(
Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
@@ -258,43 +225,48 @@ public:
Expr start;
if(!meanContexts.empty()) {
// apply single layer network to mean to map into decoder space
- auto mlp = mlp::mlp(graph).push_back(
- mlp::dense(graph) //
+ auto mlp = mlp::mlp().push_back(
+ mlp::dense() //
("prefix", prefix_ + "_ff_state") //
("dim", opt<int>("dim-rnn")) //
("activation", (int)mlp::act::tanh) //
("layer-normalization", opt<bool>("layer-normalization")) //
("nematus-normalization",
- options_->has("original-type")
+ options_->has("original-type")
&& opt<std::string>("original-type") == "nematus") //
- );
+ )
+ .construct(graph);
start = mlp->apply(meanContexts);
} else {
int dimBatch = (int)batch->size();
int dimRnn = opt<int>("dim-rnn");
- start = graph->constant({dimBatch, dimRnn}, inits::zeros);
+ start = graph->constant({dimBatch, dimRnn}, inits::zeros());
}
rnn::States startStates(opt<size_t>("dec-depth"), {start, start});
- return New<DecoderState>(startStates, nullptr, encStates, batch);
+ return New<DecoderState>(startStates, Logits(), encStates, batch);
}
virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state) override {
- auto embeddings = state->getTargetEmbeddings();
-
- // dropout target words
- float dropoutTrg = inference_ ? 0 : opt<float>("dropout-trg");
- if(dropoutTrg) {
- int trgWords = embeddings->shape()[-3];
- embeddings = dropout(embeddings, dropoutTrg, {trgWords, 1, 1});
- }
-
- if(!rnn_)
- rnn_ = constructDecoderRNN(graph, state);
+ auto embeddings = state->getTargetHistoryEmbeddings();
+
+ // The batch dimension of the inputs can change due to batch-pruning, in that case
+ // cached elements need to be rebuilt, in this case the mapped encoder context in the
+ // attention mechanism of the decoder RNN.
+ int currDimBatch = embeddings->shape()[-2];
+ if(!rnn_ || lastDimBatch_ != currDimBatch) // if currDimBatch is different, rebuild the cached RNN
+ rnn_ = constructDecoderRNN(graph, state); // @TODO: add a member to encoder/decoder state `bool batchDimChanged()`
+ lastDimBatch_ = currDimBatch;
+ // Also @TODO: maybe implement a Cached(build, updateIf) that runs a check and rebuild if required
+ // at dereferecing :
+ // rnn_ = Cached<decltype(constructDecoderRNN(graph, state))>(
+ // /*build=*/[]{ return constructDecoderRNN(graph, state); },
+ // /*updateIf=*/[]{ return state->batchDimChanged() });
+ // rnn_->transduce(...);
// apply RNN to embeddings, initialized with encoder context mapped into
// decoder space
@@ -322,10 +294,11 @@ public:
if(!output_) {
// construct deep output multi-layer network layer-wise
- auto hidden = mlp::dense(graph) //
+
+ auto hidden = mlp::dense() //
("prefix", prefix_ + "_ff_logit_l1") //
("dim", opt<int>("dim-emb")) //
- ("activation", mlp::act::tanh) //
+ ("activation", (int)mlp::act::tanh) //
("layer-normalization", opt<bool>("layer-normalization")) //
("nematus-normalization",
options_->has("original-type")
@@ -333,7 +306,7 @@ public:
int dimTrgVoc = opt<std::vector<int>>("dim-vocabs")[batchIndex_];
- auto last = mlp::output(graph) //
+ auto last = mlp::output() //
("prefix", prefix_ + "_ff_logit_l2") //
("dim", dimTrgVoc);
@@ -341,29 +314,31 @@ public:
std::string tiedPrefix = prefix_ + "_Wemb";
if(opt<bool>("tied-embeddings-all") || opt<bool>("tied-embeddings-src"))
tiedPrefix = "Wemb";
- last.tie_transposed("W", tiedPrefix);
+ last.tieTransposed(tiedPrefix);
}
-
- if(shortlist_)
- last.set_shortlist(shortlist_);
+ last("vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_]); // for factored outputs
+ last("lemma-dim-emb", opt<int>("lemma-dim-emb", 0)); // for factored outputs
// assemble layers into MLP and apply to embeddings, decoder context and
// aligned source context
- output_ = mlp::mlp(graph) //
+ output_ = mlp::mlp() //
.push_back(hidden) //
.push_back(last)
- .construct();
+ .construct(graph);
}
- Expr logits;
+ if (shortlist_)
+ output_->setShortlist(shortlist_);
+
+ Logits logits;
if(alignedContext)
- logits = output_->apply(embeddings, decoderContext, alignedContext);
+ logits = output_->applyAsLogits({embeddings, decoderContext, alignedContext});
else
- logits = output_->apply(embeddings, decoderContext);
+ logits = output_->applyAsLogits({embeddings, decoderContext});
// return unormalized(!) probabilities
auto nextState = New<DecoderState>(
- decoderStates, logits, state->getEncoderStates(), state->getBatch());
+ decoderStates, logits, state->getEncoderStates(), state->getBatch());
// Advance current target token position by one
nextState->setPosition(state->getPosition() + 1);
@@ -379,7 +354,8 @@ public:
void clear() override {
rnn_ = nullptr;
- output_ = nullptr;
+ if (output_)
+ output_->clear();
}
};
} // namespace marian
diff --git a/src/models/states.h b/src/models/states.h
index 96461734..c2f9ee05 100755
--- a/src/models/states.h
+++ b/src/models/states.h
@@ -1,6 +1,7 @@
#pragma once
#include "marian.h"
+#include "layers/generic.h" // @HACK: for factored embeddings only so far
#include "rnn/types.h"
namespace marian {
@@ -8,7 +9,7 @@ namespace marian {
class EncoderState {
private:
Expr context_;
- Expr mask_;
+ Expr mask_; // [beam depth=1, max length, batch size, vector dim=1] source mask
Ptr<data::CorpusBatch> batch_;
public:
@@ -16,73 +17,81 @@ public:
: context_(context), mask_(mask), batch_(batch) {}
EncoderState() {}
+ virtual ~EncoderState() {}
- virtual Expr getContext() { return context_; }
- virtual Expr getAttended() { return context_; }
- virtual Expr getMask() { return mask_; }
+ virtual Expr getContext() const { return context_; }
+ virtual Expr getAttended() const { return context_; }
+ virtual Expr getMask() const { return mask_; } // source batch mask; may have additional positions suppressed
virtual const Words& getSourceWords() {
return batch_->front()->data();
}
+
+ // Sub-select active batch entries from encoder context and context mask
+ Ptr<EncoderState> select(const std::vector<IndexType>& batchIndices) { // [batchIndex] indices of active batch entries
+ // Dimension -2 is OK for both, RNN and Transformer models as the encoder context in Transformer gets transposed to the same dimension layout
+ return New<EncoderState>(index_select(context_, -2, batchIndices), index_select(mask_, -2, batchIndices), batch_);
+ }
};
class DecoderState {
protected:
rnn::States states_; // states of individual decoder layers
- Expr logProbs_;
+ Logits logProbs_;
std::vector<Ptr<EncoderState>> encStates_;
Ptr<data::CorpusBatch> batch_;
- Expr targetEmbeddings_;
+ Expr targetHistoryEmbeddings_; // decoder history (teacher-forced or from decoding), embedded
Expr targetMask_;
- Expr targetIndices_;
+ Words targetWords_; // target labels
// Keep track of current target token position during translation
size_t position_{0};
public:
DecoderState(const rnn::States& states,
- Expr logProbs,
+ Logits logProbs,
const std::vector<Ptr<EncoderState>>& encStates,
Ptr<data::CorpusBatch> batch)
: states_(states), logProbs_(logProbs), encStates_(encStates), batch_(batch) {}
+ virtual ~DecoderState() {}
// @TODO: Do we need all these to be virtual?
virtual const std::vector<Ptr<EncoderState>>& getEncoderStates() const {
return encStates_;
}
- virtual Expr getLogProbs() const { return logProbs_; }
- virtual void setLogProbs(Expr logProbs) { logProbs_ = logProbs; }
+ virtual Logits getLogProbs() const { return logProbs_; }
+ virtual void setLogProbs(Logits logProbs) { logProbs_ = logProbs; }
// @TODO: should this be a constructor? Then derived classes can call this without the New<> in the loop
- virtual Ptr<DecoderState> select(const std::vector<IndexType>& selIdx,
+ virtual Ptr<DecoderState> select(const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
+ const std::vector<IndexType>& batchIndices, // [batchIndex]
int beamSize) const {
+
+ std::vector<Ptr<EncoderState>> newEncStates;
+ for(auto& es : encStates_)
+ // If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries
+ newEncStates.push_back(es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));
+
+ // hypindices matches batchIndices in terms of batch dimension, so we only need hypIndices
auto selectedState = New<DecoderState>(
- states_.select(selIdx, beamSize, /*isBatchMajor=*/false), logProbs_, encStates_, batch_);
+ states_.select(hypIndices, beamSize, /*isBatchMajor=*/false), logProbs_, newEncStates, batch_);
- // Set positon of new state based on the target token position of current
- // state
+ // Set positon of new state based on the target token position of current state
selectedState->setPosition(getPosition());
return selectedState;
}
virtual const rnn::States& getStates() const { return states_; }
- virtual Expr getTargetEmbeddings() const { return targetEmbeddings_; };
+ virtual Expr getTargetHistoryEmbeddings() const { return targetHistoryEmbeddings_; };
+ virtual void setTargetHistoryEmbeddings(Expr targetHistoryEmbeddings) { targetHistoryEmbeddings_ = targetHistoryEmbeddings; }
- virtual void setTargetEmbeddings(Expr targetEmbeddings) {
- targetEmbeddings_ = targetEmbeddings;
- }
-
- virtual Expr getTargetIndices() const { return targetIndices_; };
-
- virtual void setTargetIndices(Expr targetIndices) {
- targetIndices_ = targetIndices;
- }
+ virtual const Words& getTargetWords() const { return targetWords_; };
+ virtual void setTargetWords(const Words& targetWords) { targetWords_ = targetWords; }
virtual Expr getTargetMask() const { return targetMask_; };
-
virtual void setTargetMask(Expr targetMask) { targetMask_ = targetMask; }
virtual const Words& getSourceWords() const {
@@ -99,4 +108,31 @@ public:
virtual void blacklist(Expr /*totalCosts*/, Ptr<data::CorpusBatch> /*batch*/) {}
};
+
+/**
+ * Classifier output based on DecoderState
+ * @TODO: should be unified with DecoderState or not be used at all as Classifier do not really have stateful output.
+ */
+class ClassifierState {
+private:
+ Expr logProbs_;
+ std::vector<Ptr<EncoderState>> encStates_;
+ Ptr<data::CorpusBatch> batch_;
+
+ Expr targetMask_;
+ Words targetWords_;
+
+public:
+ virtual ~ClassifierState() {}
+ virtual Expr getLogProbs() const { return logProbs_; }
+ virtual void setLogProbs(Expr logProbs) { logProbs_ = logProbs; }
+
+ virtual const Words& getTargetWords() const { return targetWords_; };
+ virtual void setTargetWords(const Words& targetWords) { targetWords_ = targetWords; }
+
+ virtual Expr getTargetMask() const { return targetMask_; };
+
+ virtual void setTargetMask(Expr targetMask) { targetMask_ = targetMask; }
+};
+
} // namespace marian
diff --git a/src/models/transformer.h b/src/models/transformer.h
index c30d06c8..84f1cd65 100755
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -6,7 +6,6 @@
#include "marian.h"
#include "layers/constructors.h"
-#include "layers/factory.h"
#include "models/decoder.h"
#include "models/encoder.h"
#include "models/states.h"
@@ -23,50 +22,79 @@ namespace marian {
template<class EncoderOrDecoderBase>
class Transformer : public EncoderOrDecoderBase {
typedef EncoderOrDecoderBase Base;
+ using Base::Base;
protected:
- using Base::options_; using Base::inference_;
- std::unordered_map<std::string, Expr> cache_;
+ using Base::options_; using Base::inference_; using Base::batchIndex_; using Base::graph_;
+ std::unordered_map<std::string, Expr> cache_; // caching transformation of the encoder that should not be created again
// attention weights produced by step()
// If enabled, it is set once per batch during training, and once per step during translation.
// It can be accessed by getAlignments(). @TODO: move into a state or return-value object
std::vector<Expr> alignments_; // [max tgt len or 1][beam depth, max src length, batch size, 1]
- template <typename T> T opt(const std::string& key) const { Ptr<Options> options = options_; return options->get<T>(key); } // need to duplicate, since somehow using Base::opt is not working
- // FIXME: that separate options assignment is weird
+ // @TODO: make this go away
+ template <typename T>
+ T opt(const char* const key) const { Ptr<Options> options = options_; return options->get<T>(key); }
- template <typename T> T opt(const std::string& key, const T& def) const { Ptr<Options> options = options_; if (options->has(key)) return options->get<T>(key); else return def; }
+ template <typename T>
+ T opt(const std::string& key) const { return opt<T>(key.c_str()); }
- Ptr<ExpressionGraph> graph_;
+ template <typename T>
+ T opt(const char* const key, const T& def) const { Ptr<Options> options = options_; return options->get<T>(key, def); }
-public:
- Transformer(Ptr<Options> options)
- : EncoderOrDecoderBase(options) {
- }
+ template <typename T>
+ T opt(const std::string& key, const T& def) const { opt<T>(key.c_str(), def); }
+public:
static Expr transposeTimeBatch(Expr input) { return transpose(input, {0, 2, 1, 3}); }
- Expr addPositionalEmbeddings(Expr input, int start = 0) const {
+ Expr addPositionalEmbeddings(Expr input, int start = 0, bool trainPosEmbeddings = false) const {
int dimEmb = input->shape()[-1];
int dimWords = input->shape()[-3];
- float num_timescales = (float)dimEmb / 2;
- float log_timescale_increment = std::log(10000.f) / (num_timescales - 1.f);
+ Expr embeddings = input;
- std::vector<float> vPos(dimEmb * dimWords, 0);
- for(int p = start; p < dimWords + start; ++p) {
- for(int i = 0; i < num_timescales; ++i) {
- float v = p * std::exp(i * -log_timescale_increment);
- vPos[(p - start) * dimEmb + i] = std::sin(v);
- vPos[(p - start) * dimEmb + (int)num_timescales + i] = std::cos(v); // @TODO: is int vs. float correct for num_timescales?
- }
+ if(trainPosEmbeddings) {
+ int maxLength = opt<int>("max-length");
+
+ // Hack for translating with length longer than trained embeddings
+ // We check if the embedding matrix "Wpos" already exist so we can
+ // check the number of positions in that loaded parameter.
+ // We then have to restict the maximum length to the maximum positon
+ // and positions beyond this will be the maximum position.
+ Expr seenEmb = graph_->get("Wpos");
+ int numPos = seenEmb ? seenEmb->shape()[-2] : maxLength;
+
+ auto embeddingLayer = embedding(
+ "prefix", "Wpos", // share positional embeddings across all encoders/decorders
+ "dimVocab", numPos,
+ "dimEmb", dimEmb)
+ .construct(graph_);
+
+ // fill with increasing numbers until current length or maxPos
+ std::vector<IndexType> positions(dimWords, numPos - 1);
+ for(int i = 0; i < std::min(dimWords, numPos); ++i)
+ positions[i] = i;
+
+ auto signal = embeddingLayer->applyIndices(positions, {dimWords, 1, dimEmb});
+ embeddings = embeddings + signal;
+ } else {
+ // @TODO : test if embeddings should be scaled when trainable
+ // according to paper embeddings are scaled up by \sqrt(d_m)
+ embeddings = std::sqrt((float)dimEmb) * embeddings; // embeddings were initialized to unit length; so norms will be in order of sqrt(dimEmb)
+
+ auto signal = graph_->constant({dimWords, 1, dimEmb},
+ inits::sinusoidalPositionEmbeddings(start));
+ embeddings = embeddings + signal;
}
- // shared across batch entries
- auto signal
- = graph_->constant({dimWords, 1, dimEmb}, inits::from_vector(vPos));
- return input + signal;
+ return embeddings;
+ }
+
+ virtual Expr addSpecialEmbeddings(Expr input, int start = 0, Ptr<data::CorpusBatch> /*batch*/ = nullptr) const {
+ bool trainPosEmbeddings = opt<bool>("transformer-train-positions", false);
+ return addPositionalEmbeddings(input, start, trainPosEmbeddings);
}
Expr triangleMask(int length) const {
@@ -75,13 +103,14 @@ public:
for(int i = 0; i < length; ++i)
for(int j = 0; j <= i; ++j)
vMask[i * length + j] = 1.f;
- return graph_->constant({1, length, length}, inits::from_vector(vMask));
+ return graph_->constant({1, length, length}, inits::fromVector(vMask));
}
// convert multiplicative 1/0 mask to additive 0/-inf log mask, and transpose to match result of bdot() op in Attention()
static Expr transposedLogMask(Expr mask) { // mask: [-4: beam depth=1, -3: batch size, -2: vector dim=1, -1: max length]
auto ms = mask->shape();
- mask = (1 - mask) * -99999999.f;
+ float maskFactor = std::max(NumericLimits<float>(mask->value_type()).lowest / 2.f, -99999999.f); // to make sure we do not overflow for fp16
+ mask = (1 - mask) * maskFactor;
return reshape(mask, {ms[-3], 1, ms[-2], ms[-1]}); // [-4: batch size, -3: num heads broadcast=1, -2: max length broadcast=1, -1: max length]
}
@@ -113,30 +142,6 @@ public:
return reshape(output, {dimBeam, dimBatch, dimSteps, dimModel});
}
- // like affine() but with built-in parameters, activation, and dropout
- static inline
- Expr dense(Expr x, std::string prefix, std::string suffix, int outDim, const std::function<Expr(Expr)>& actFn = nullptr, float dropProb = 0.0f)
- {
- auto graph = x->graph();
-
- auto W = graph->param(prefix + "_W" + suffix, { x->shape()[-1], outDim }, inits::glorot_uniform);
- auto b = graph->param(prefix + "_b" + suffix, { 1, outDim }, inits::zeros);
-
- x = affine(x, W, b);
- if (actFn)
- x = actFn(x);
- if (dropProb)
- x = dropout(x, dropProb);
- return x;
- }
-
- Expr layerNorm(Expr x, std::string prefix, std::string suffix = std::string()) const {
- int dimModel = x->shape()[-1];
- auto scale = graph_->param(prefix + "_ln_scale" + suffix, { 1, dimModel }, inits::ones);
- auto bias = graph_->param(prefix + "_ln_bias" + suffix, { 1, dimModel }, inits::zeros);
- return marian::layerNorm(x, scale, bias, 1e-6f);
- }
-
Expr preProcess(std::string prefix, std::string ops, Expr input, float dropProb = 0.0f) const {
auto output = input;
for(auto op : ops) {
@@ -164,7 +169,7 @@ public:
// highway connection
else if(op == 'h') {
int dimModel = input->shape()[-1];
- auto t = dense(prevInput, prefix, /*suffix=*/"h", dimModel);
+ auto t = denseInline(prevInput, prefix, /*suffix=*/"h", dimModel);
output = highway(output, prevInput, t);
}
// layer normalization
@@ -178,23 +183,23 @@ public:
void collectOneHead(Expr weights, int dimBeam) {
// select first head, this is arbitrary as the choice does not really matter
- auto head0 = select(weights, std::vector<IndexType>({0}), -3); // @TODO: implement an index() or slice() operator and use that
+ auto head0 = slice(weights, -3, 0);
int dimBatchBeam = head0->shape()[-4];
- int srcWords = head0->shape()[-1];
- int trgWords = head0->shape()[-2];
+ int srcWords = head0->shape()[-1]; // (max) length of src sequence
+ int trgWords = head0->shape()[-2]; // (max) length of trg sequence, or 1 in decoding
int dimBatch = dimBatchBeam / dimBeam;
// reshape and transpose to match the format guided_alignment expects
head0 = reshape(head0, {dimBeam, dimBatch, trgWords, srcWords});
- head0 = transpose(head0, {0, 3, 1, 2}); // [-4: beam depth, -3: max src length, -2: batch size, -1: max tgt length]
+ head0 = transpose(head0, {0, 3, 1, 2}); // [beam depth, max src length, batch size, max tgt length]
// save only last alignment set. For training this will be all alignments,
// for translation only the last one. Also split alignments by target words.
// @TODO: make splitting obsolete
alignments_.clear();
- for(int i = 0; i < trgWords; ++i) {
- alignments_.push_back(select(head0, std::vector<IndexType>({(IndexType)i}), -1)); // [tgt index][-4: beam depth, -3: max src length, -2: batch size, -1: 1]
+ for(int i = 0; i < trgWords; ++i) { // loop over all trg positions. In decoding, there is only one.
+ alignments_.push_back(slice(head0, -1, i)); // [tgt index][beam depth, max src length, batch size, 1] P(src pos|trg pos, beam index, batch index)
}
}
@@ -221,17 +226,16 @@ public:
// take softmax along src sequence axis (-1)
auto weights = softmax(z); // [-4: beam depth * batch size, -3: num heads, -2: max tgt length, -1: max src length]
-
+
if(saveAttentionWeights)
collectOneHead(weights, dimBeam);
// optional dropout for attention weights
- float dropProb
- = inference_ ? 0 : opt<float>("transformer-dropout-attention");
- weights = dropout(weights, dropProb);
+ weights = dropout(weights, inference_ ? 0 : opt<float>("transformer-dropout-attention"));
// apply attention weights to values
auto output = bdot(weights, v); // [-4: beam depth * batch size, -3: num heads, -2: max tgt length, -1: split vector dim]
+
return output;
}
@@ -246,8 +250,8 @@ public:
bool saveAttentionWeights = false) {
int dimModel = q->shape()[-1];
// @TODO: good opportunity to implement auto-batching here or do something manually?
- auto Wq = graph_->param(prefix + "_Wq", {dimModel, dimModel}, inits::glorot_uniform);
- auto bq = graph_->param(prefix + "_bq", { 1, dimModel}, inits::zeros);
+ auto Wq = graph_->param(prefix + "_Wq", {dimModel, dimModel}, inits::glorotUniform());
+ auto bq = graph_->param(prefix + "_bq", { 1, dimModel}, inits::zeros());
auto qh = affine(q, Wq, bq);
qh = SplitHeads(qh, dimHeads); // [-4: beam depth * batch size, -3: num heads, -2: max length, -1: split vector dim]
@@ -255,28 +259,32 @@ public:
// Caching transformation of the encoder that should not be created again.
// @TODO: set this automatically by memoizing encoder context and
// memoization propagation (short-term)
- if (!cache || (cache && cache_.count(prefix + "_keys") == 0)) {
- auto Wk = graph_->param(prefix + "_Wk", {dimModel, dimModel}, inits::glorot_uniform);
- auto bk = graph_->param(prefix + "_bk", {1, dimModel}, inits::zeros);
+ if (cache // if caching
+ && cache_.count(prefix + "_keys") > 0 // and the keys expression has been seen
+ && cache_[prefix + "_keys"]->shape().elements() == keys->shape().elements()) { // and the underlying element size did not change
+ kh = cache_[prefix + "_keys"]; // then return cached tensor
+ }
+ else {
+ auto Wk = graph_->param(prefix + "_Wk", {dimModel, dimModel}, inits::glorotUniform());
+ auto bk = graph_->param(prefix + "_bk", {1, dimModel}, inits::zeros());
- kh = affine(keys,Wk, bk); // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
+ kh = affine(keys, Wk, bk); // [-4: beam depth, -3: batch size, -2: max length, -1: vector dim]
kh = SplitHeads(kh, dimHeads); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
cache_[prefix + "_keys"] = kh;
}
- else {
- kh = cache_[prefix + "_keys"];
- }
Expr vh;
- if (!cache || (cache && cache_.count(prefix + "_values") == 0)) {
- auto Wv = graph_->param(prefix + "_Wv", {dimModel, dimModel}, inits::glorot_uniform);
- auto bv = graph_->param(prefix + "_bv", {1, dimModel}, inits::zeros);
+ if (cache
+ && cache_.count(prefix + "_values") > 0
+ && cache_[prefix + "_values"]->shape().elements() == values->shape().elements()) {
+ vh = cache_[prefix + "_values"];
+ } else {
+ auto Wv = graph_->param(prefix + "_Wv", {dimModel, dimModel}, inits::glorotUniform());
+ auto bv = graph_->param(prefix + "_bv", {1, dimModel}, inits::zeros());
vh = affine(values, Wv, bv); // [-4: batch size, -3: num heads, -2: max length, -1: split vector dim]
vh = SplitHeads(vh, dimHeads);
cache_[prefix + "_values"] = vh;
- } else {
- vh = cache_[prefix + "_values"];
}
int dimBeam = q->shape()[-4];
@@ -292,8 +300,8 @@ public:
bool project = !opt<bool>("transformer-no-projection");
if(project || dimAtt != dimOut) {
auto Wo
- = graph_->param(prefix + "_Wo", {dimAtt, dimOut}, inits::glorot_uniform);
- auto bo = graph_->param(prefix + "_bo", {1, dimOut}, inits::zeros);
+ = graph_->param(prefix + "_Wo", {dimAtt, dimOut}, inits::glorotUniform());
+ auto bo = graph_->param(prefix + "_bo", {1, dimOut}, inits::zeros());
output = affine(output, Wo, bo);
}
@@ -317,7 +325,7 @@ public:
// multi-head self-attention over previous input
output = MultiHead(prefix, dimModel, heads, output, keys, values, mask, cache, saveAttentionWeights);
-
+
auto opsPost = opt<std::string>("transformer-postprocess");
output = postProcess(prefix + "_Wo", opsPost, output, input, dropProb);
@@ -349,6 +357,8 @@ public:
return (ActivationFunction*)relu;
else if (actName == "swish")
return (ActivationFunction*)swish;
+ else if (actName == "gelu")
+ return (ActivationFunction*)gelu;
ABORT("Invalid activation name '{}'", actName);
}
@@ -369,8 +379,8 @@ public:
// the stack of FF layers
for(int i = 1; i < depthFfn; ++i)
- output = dense(output, prefix, /*suffix=*/std::to_string(i), dimFfn, actFn, ffnDropProb);
- output = dense(output, prefix, /*suffix=*/std::to_string(depthFfn), dimModel);
+ output = denseInline(output, prefix, /*suffix=*/std::to_string(i), dimFfn, actFn, ffnDropProb);
+ output = denseInline(output, prefix, /*suffix=*/std::to_string(depthFfn), dimModel);
auto opsPost = opt<std::string>("transformer-postprocess");
output
@@ -397,14 +407,14 @@ public:
// the stack of AAN layers
for(int i = 1; i < depthAan; ++i)
- y = dense(y, prefix, /*suffix=*/std::to_string(i), dimAan, actFn, aanDropProb);
+ y = denseInline(y, prefix, /*suffix=*/std::to_string(i), dimAan, actFn, aanDropProb);
if(y->shape()[-1] != dimModel) // bring it back to the desired dimension if needed
- y = dense(y, prefix, std::to_string(depthAan), dimModel);
+ y = denseInline(y, prefix, std::to_string(depthAan), dimModel);
bool noGate = opt<bool>("transformer-aan-nogate");
if(!noGate) {
- auto gi = dense(x, prefix, /*suffix=*/"i", dimModel, (ActivationFunction*)sigmoid);
- auto gf = dense(y, prefix, /*suffix=*/"f", dimModel, (ActivationFunction*)sigmoid);
+ auto gi = denseInline(x, prefix, /*suffix=*/"i", dimModel, (ActivationFunction*)sigmoid);
+ auto gf = denseInline(y, prefix, /*suffix=*/"f", dimModel, (ActivationFunction*)sigmoid);
y = gi * x + gf * y;
}
@@ -440,7 +450,8 @@ public:
return LayerAAN(prefix, input, output);
}
- Expr DecoderLayerRNN(rnn::State& decoderState,
+ Expr DecoderLayerRNN(std::unordered_map<std::string, Ptr<rnn::RNN>>& perLayerRnn, // @TODO: rewrite this whole organically grown mess
+ rnn::State& decoderState,
const rnn::State& prevDecoderState,
std::string prefix,
Expr input,
@@ -448,15 +459,18 @@ public:
int /*startPos*/) const {
float dropoutRnn = inference_ ? 0.f : opt<float>("dropout-rnn");
- auto rnn = rnn::rnn(graph_) //
- ("type", opt<std::string>("dec-cell")) //
- ("prefix", prefix) //
- ("dimInput", opt<int>("dim-emb")) //
- ("dimState", opt<int>("dim-emb")) //
- ("dropout", dropoutRnn) //
- ("layer-normalization", opt<bool>("layer-normalization")) //
- .push_back(rnn::cell(graph_)) //
- .construct();
+ if(!perLayerRnn[prefix]) // lazily create and cache RNNs in the decoder to avoid costly recreation @TODO: turn this into class members
+ perLayerRnn[prefix] = rnn::rnn(
+ "type", opt<std::string>("dec-cell"),
+ "prefix", prefix,
+ "dimInput", opt<int>("dim-emb"),
+ "dimState", opt<int>("dim-emb"),
+ "dropout", dropoutRnn,
+ "layer-normalization", opt<bool>("layer-normalization"))
+ .push_back(rnn::cell())
+ .construct(graph_);
+
+ auto rnn = perLayerRnn[prefix];
float dropProb = inference_ ? 0 : opt<float>("transformer-dropout");
auto opsPre = opt<std::string>("transformer-preprocess");
@@ -475,102 +489,52 @@ public:
};
class EncoderTransformer : public Transformer<EncoderBase> {
+ typedef Transformer<EncoderBase> Base;
+ using Base::Base;
public:
- EncoderTransformer(Ptr<Options> options) : Transformer(options) {}
-
- // returns the embedding matrix based on options
- // and based on batchIndex_.
-
- std::vector<Expr> ULREmbeddings() const {
- // standard encoder word embeddings
- int dimSrcVoc = opt<std::vector<int>>("dim-vocabs")[0]; //ULR multi-lingual src
- int dimTgtVoc = opt<std::vector<int>>("dim-vocabs")[1]; //ULR monon tgt
- int dimEmb = opt<int>("dim-emb");
- int dimUlrEmb = opt<int>("ulr-dim-emb");
- auto embFactory = ulr_embedding(graph_)("dimSrcVoc", dimSrcVoc)("dimTgtVoc", dimTgtVoc)
- ("dimUlrEmb", dimUlrEmb)("dimEmb", dimEmb)
- ("ulrTrainTransform", opt<bool>("ulr-trainable-transformation"))
- ("ulrQueryFile", opt<std::string>("ulr-query-vectors"))
- ("ulrKeysFile", opt<std::string>("ulr-keys-vectors"));
- return embFactory.construct();
- }
-
- Expr wordEmbeddings(size_t subBatchIndex) const {
- // standard encoder word embeddings
- int dimVoc = opt<std::vector<int>>("dim-vocabs")[subBatchIndex];
- int dimEmb = opt<int>("dim-emb");
- auto embFactory = embedding(graph_)("dimVocab", dimVoc)("dimEmb", dimEmb);
- if (opt<bool>("tied-embeddings-src") || opt<bool>("tied-embeddings-all"))
- embFactory("prefix", "Wemb");
- else
- embFactory("prefix", prefix_ + "_Wemb");
- if (options_->has("embedding-fix-src"))
- embFactory("fixed", opt<bool>("embedding-fix-src"));
- if (options_->has("embedding-vectors")) {
- auto embFiles = opt<std::vector<std::string>>("embedding-vectors");
- embFactory("embFile", embFiles[subBatchIndex])
- ("normalization", opt<bool>("embedding-normalization"));
- }
- return embFactory.construct();
- }
-
- Ptr<EncoderState> build(Ptr<ExpressionGraph> graph,
- Ptr<data::CorpusBatch> batch) override {
+ virtual Ptr<EncoderState> build(Ptr<ExpressionGraph> graph,
+ Ptr<data::CorpusBatch> batch) override {
graph_ = graph;
return apply(batch);
}
Ptr<EncoderState> apply(Ptr<data::CorpusBatch> batch) {
- int dimEmb = opt<int>("dim-emb");
int dimBatch = (int)batch->size();
int dimSrcWords = (int)(*batch)[batchIndex_]->batchWidth();
// create the embedding matrix, considering tying and some other options
// embed the source words in the batch
Expr batchEmbeddings, batchMask;
- if (options_->has("ulr") && options_->get<bool>("ulr") == true) {
- auto embeddings = ULREmbeddings(); // embedding uses ULR
- std::tie(batchEmbeddings, batchMask)
- = EncoderBase::ulrLookup(graph_, embeddings, batch);
- }
- else
- {
- auto embeddings = wordEmbeddings(batchIndex_);
- std::tie(batchEmbeddings, batchMask)
- = EncoderBase::lookup(graph_, embeddings, batch);
- }
- // apply dropout over source words
- float dropoutSrc = inference_ ? 0 : opt<float>("dropout-src");
- if(dropoutSrc) {
- int srcWords = batchEmbeddings->shape()[-3];
- batchEmbeddings = dropout(batchEmbeddings, dropoutSrc, {srcWords, 1, 1});
- }
- // according to paper embeddings are scaled up by \sqrt(d_m)
- auto scaledEmbeddings = std::sqrt((float)dimEmb) * batchEmbeddings;
- scaledEmbeddings = addPositionalEmbeddings(scaledEmbeddings);
+
+ auto embeddingLayer = getEmbeddingLayer(opt<bool>("ulr", false));
+ std::tie(batchEmbeddings, batchMask) = embeddingLayer->apply((*batch)[batchIndex_]);
+ batchEmbeddings = addSpecialEmbeddings(batchEmbeddings, /*start=*/0, batch);
+
// reorganize batch and timestep
- scaledEmbeddings = atleast_nd(scaledEmbeddings, 4);
- batchMask = atleast_nd(batchMask, 4);
- auto layer = transposeTimeBatch(scaledEmbeddings); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
- auto layerMask
- = reshape(transposeTimeBatch(batchMask), {1, dimBatch, 1, dimSrcWords}); // [-4: beam depth=1, -3: batch size, -2: vector dim=1, -1: max length]
+ batchEmbeddings = atleast_nd(batchEmbeddings, 4); // [beam depth=1, max length, batch size, vector dim]
+ batchMask = atleast_nd(batchMask, 4); // [beam depth=1, max length, batch size, vector dim=1]
- auto opsEmb = opt<std::string>("transformer-postprocess-emb");
+ auto layer = transposeTimeBatch(batchEmbeddings); // [beam depth=1, batch size, max length, vector dim]
+ auto layerMask = transposeTimeBatch(batchMask); // [beam depth=1, batch size, max length, vector dim=1]
+ auto opsEmb = opt<std::string>("transformer-postprocess-emb");
float dropProb = inference_ ? 0 : opt<float>("transformer-dropout");
layer = preProcess(prefix_ + "_emb", opsEmb, layer, dropProb);
- layerMask = transposedLogMask(layerMask); // [-4: batch size, -3: 1, -2: vector dim=1, -1: max length]
+ // LayerAttention expects mask in a different layout
+ layerMask = reshape(layerMask, {1, dimBatch, 1, dimSrcWords}); // [1, batch size, 1, max length]
+ layerMask = transposedLogMask(layerMask); // [batch size, num heads broadcast=1, max length broadcast=1, max length]
// apply encoder layers
+ // This is the Transformer Encoder stack.
auto encDepth = opt<int>("enc-depth");
for(int i = 1; i <= encDepth; ++i) {
layer = LayerAttention(prefix_ + "_l" + std::to_string(i) + "_self",
layer, // query
layer, // keys
layer, // values
- layerMask);
-
+ layerMask); // [batch size, num heads broadcast=1, max length broadcast=1, max length]
layer = LayerFFN(prefix_ + "_l" + std::to_string(i) + "_ffn", layer);
+ checkpoint(layer); // sets a manually specified checkpoint if gradient checkpointing is enabled, does nothing otherwise.
}
// restore organization of batch and time steps. This is currently required
@@ -581,21 +545,30 @@ public:
return New<EncoderState>(context, batchMask, batch);
}
- void clear() override {}
+ virtual void clear() override {}
};
class TransformerState : public DecoderState {
public:
TransformerState(const rnn::States& states,
- Expr logProbs,
+ Logits logProbs,
const std::vector<Ptr<EncoderState>>& encStates,
Ptr<data::CorpusBatch> batch)
: DecoderState(states, logProbs, encStates, batch) {}
- virtual Ptr<DecoderState> select(const std::vector<IndexType>& selIdx,
+ virtual Ptr<DecoderState> select(const std::vector<IndexType>& hypIndices, // [beamIndex * activeBatchSize + batchIndex]
+ const std::vector<IndexType>& batchIndices, // [batchIndex]
int beamSize) const override {
+
+ // @TODO: code duplication with DecoderState only because of isBatchMajor=true, should rather be a contructor argument of DecoderState?
+
+ std::vector<Ptr<EncoderState>> newEncStates;
+ for(auto& es : encStates_)
+ // If the size of the batch dimension of the encoder state context changed, subselect the correct batch entries
+ newEncStates.push_back(es->getContext()->shape()[-2] == batchIndices.size() ? es : es->select(batchIndices));
+
// Create hypothesis-selected state based on current state and hyp indices
- auto selectedState = New<TransformerState>(states_.select(selIdx, beamSize, /*isBatchMajor=*/true), logProbs_, encStates_, batch_);
+ auto selectedState = New<TransformerState>(states_.select(hypIndices, beamSize, /*isBatchMajor=*/true), logProbs_, newEncStates, batch_);
// Set the same target token position as the current state
// @TODO: This is the same as in base function.
@@ -605,42 +578,37 @@ public:
};
class DecoderTransformer : public Transformer<DecoderBase> {
+ typedef Transformer<DecoderBase> Base;
+ using Base::Base;
private:
- Ptr<mlp::MLP> output_;
+ Ptr<mlp::Output> output_;
+
+ // This caches RNN objects to avoid reconstruction between batches or deocoding steps.
+ // To be removed after refactoring of transformer.h
+ std::unordered_map<std::string, Ptr<rnn::RNN>> perLayerRnn_;
private:
- void LazyCreateOutputLayer()
+ // @TODO: move this out for sharing with other models
+ void lazyCreateOutputLayer()
{
if(output_) // create it lazily
return;
int dimTrgVoc = opt<std::vector<int>>("dim-vocabs")[batchIndex_];
- auto layerOut = mlp::output(graph_) //
- ("prefix", prefix_ + "_ff_logit_out") //
- ("dim", dimTrgVoc);
+ auto outputFactory = mlp::OutputFactory(
+ "prefix", prefix_ + "_ff_logit_out",
+ "dim", dimTrgVoc,
+ "vocab", opt<std::vector<std::string>>("vocabs")[batchIndex_], // for factored outputs
+ "lemma-dim-emb", opt<int>("lemma-dim-emb", 0)); // for factored outputs
- if(opt<bool>("tied-embeddings") || opt<bool>("tied-embeddings-all")) {
- std::string tiedPrefix = prefix_ + "_Wemb";
- if(opt<bool>("tied-embeddings-all") || opt<bool>("tied-embeddings-src"))
- tiedPrefix = "Wemb";
- layerOut.tie_transposed("W", tiedPrefix);
- }
+ if(opt<bool>("tied-embeddings") || opt<bool>("tied-embeddings-all"))
+ outputFactory.tieTransposed(opt<bool>("tied-embeddings-all") || opt<bool>("tied-embeddings-src") ? "Wemb" : prefix_ + "_Wemb");
- if(shortlist_)
- layerOut.set_shortlist(shortlist_);
-
- // [-4: beam depth=1, -3: max length, -2: batch size, -1: vocab dim]
- // assemble layers into MLP and apply to embeddings, decoder context and
- // aligned source context
- output_ = mlp::mlp(graph_) //
- .push_back(layerOut) //
- .construct();
+ output_ = std::dynamic_pointer_cast<mlp::Output>(outputFactory.construct(graph_)); // (construct() returns only the underlying interface)
}
public:
- DecoderTransformer(Ptr<Options> options) : Transformer(options) {}
-
virtual Ptr<DecoderState> startState(
Ptr<ExpressionGraph> graph,
Ptr<data::CorpusBatch> batch,
@@ -652,54 +620,41 @@ public:
int dimBatch = (int)batch->size();
int dim = opt<int>("dim-emb");
- auto start = graph->constant({1, 1, dimBatch, dim}, inits::zeros);
+ auto start = graph->constant({1, 1, dimBatch, dim}, inits::zeros());
rnn::States startStates(opt<size_t>("dec-depth"), {start, start});
// don't use TransformerState for RNN layers
- return New<DecoderState>(startStates, nullptr, encStates, batch);
+ return New<DecoderState>(startStates, Logits(), encStates, batch);
}
else {
rnn::States startStates;
- return New<TransformerState>(startStates, nullptr, encStates, batch);
+ return New<TransformerState>(startStates, Logits(), encStates, batch);
}
}
virtual Ptr<DecoderState> step(Ptr<ExpressionGraph> graph,
Ptr<DecoderState> state) override {
ABORT_IF(graph != graph_, "An inconsistent graph parameter was passed to step()");
- LazyCreateOutputLayer();
+ lazyCreateOutputLayer();
return step(state);
}
Ptr<DecoderState> step(Ptr<DecoderState> state) {
- auto embeddings = state->getTargetEmbeddings(); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vector dim]
- auto decoderMask = state->getTargetMask(); // [max length, batch size, 1] --this is a hypothesis
-
- // dropout target words
- float dropoutTrg = inference_ ? 0 : opt<float>("dropout-trg");
- if(dropoutTrg) {
- int trgWords = embeddings->shape()[-3];
- embeddings = dropout(embeddings, dropoutTrg, {trgWords, 1, 1});
- }
+ auto embeddings = state->getTargetHistoryEmbeddings(); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vector dim]
+ auto decoderMask = state->getTargetMask(); // [max length, batch size, 1] --this is a hypothesis
//************************************************************************//
- int dimEmb = embeddings->shape()[-1];
int dimBeam = 1;
if(embeddings->shape().size() > 3)
dimBeam = embeddings->shape()[-4];
- // according to paper embeddings are scaled by \sqrt(d_m)
- auto scaledEmbeddings = std::sqrt((float)dimEmb) * embeddings;
-
// set current target token position during decoding or training. At training
// this should be 0. During translation the current length of the translation.
// Used for position embeddings and creating new decoder states.
int startPos = (int)state->getPosition();
- scaledEmbeddings
- = addPositionalEmbeddings(scaledEmbeddings, startPos);
-
+ auto scaledEmbeddings = addSpecialEmbeddings(embeddings, startPos);
scaledEmbeddings = atleast_nd(scaledEmbeddings, 4);
// reorganize batch and timestep
@@ -722,25 +677,33 @@ public:
std::vector<Expr> encoderContexts;
std::vector<Expr> encoderMasks;
-
for(auto encoderState : state->getEncoderStates()) {
- auto encoderContext = encoderState->getContext();
- auto encoderMask = encoderState->getMask();
+ auto encoderContext = encoderState->getContext(); // encoder output
+ auto encoderMask = encoderState->getMask(); // note: may differ from Encoder self-attention mask in that additional positions are banned for cross-attention
+ encoderMask = atleast_nd(encoderMask, 4);
- encoderContext = transposeTimeBatch(encoderContext); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
+ encoderContext = transposeTimeBatch(encoderContext); // [beam depth=1, batch size, max length, vector dim]
+ encoderMask = transposeTimeBatch(encoderMask); // [beam depth=1, max length, batch size, vector dim=1]
int dimSrcWords = encoderContext->shape()[-2];
- //int dims = encoderMask->shape().size();
- encoderMask = atleast_nd(encoderMask, 4);
- encoderMask = reshape(transposeTimeBatch(encoderMask),
- {1, dimBatch, 1, dimSrcWords});
- encoderMask = transposedLogMask(encoderMask);
+ // This would happen if something goes wrong during batch pruning.
+ ABORT_IF(encoderContext->shape()[-3] != dimBatch,
+ "Context and query batch dimension do not match {} != {}",
+ encoderContext->shape()[-3],
+ dimBatch);
+
+ // LayerAttention expects mask in a different layout
+ encoderMask = reshape(encoderMask, { 1, dimBatch, 1, dimSrcWords }); // [1, batch size, 1, max length]
+ encoderMask = transposedLogMask(encoderMask); // [batch size, num heads broadcast=1, max length broadcast=1, max length]
if(dimBeam > 1)
encoderMask = repeat(encoderMask, dimBeam, /*axis=*/ -4);
encoderContexts.push_back(encoderContext);
encoderMasks.push_back(encoderMask);
+
+ checkpoint(encoderContext);
+ checkpoint(encoderMask);
}
rnn::States prevDecoderStates = state->getStates();
@@ -771,11 +734,13 @@ public:
else if(layerType == "average-attention")
query = DecoderLayerAAN(decoderState, prevDecoderState, prefix_ + "_l" + layerNo + "_aan", query, selfMask, startPos);
else if(layerType == "rnn")
- query = DecoderLayerRNN(decoderState, prevDecoderState, prefix_ + "_l" + layerNo + "_rnn", query, selfMask, startPos);
+ query = DecoderLayerRNN(perLayerRnn_, decoderState, prevDecoderState, prefix_ + "_l" + layerNo + "_rnn", query, selfMask, startPos);
else
ABORT("Unknown auto-regressive layer type in transformer decoder {}",
layerType);
+ checkpoint(query);
+
// source-target attention
// Iterate over multiple encoders and simply stack the attention blocks
if(encoderContexts.size() > 0) {
@@ -789,7 +754,7 @@ public:
// decoding or scoring return the attention weights of one head of the last layer.
// @TODO: maybe allow to return average or max over all heads?
bool saveAttentionWeights = false;
- if(j == 0 && (options_->get("guided-alignment", std::string("none")) != "none" || options_->has("alignment"))) {
+ if(j == 0 && (options_->get("guided-alignment", std::string("none")) != "none" || options_->hasAndNotEmpty("alignment"))) {
size_t attLayer = decDepth - 1;
std::string gaStr = options_->get<std::string>("transformer-guided-alignment-layer", "last");
if(gaStr != "last")
@@ -812,10 +777,14 @@ public:
}
}
+ checkpoint(query);
+
// remember decoder state
decoderStates.push_back(decoderState);
query = LayerFFN(prefix_ + "_l" + layerNo + "_ffn", query); // [-4: beam depth=1, -3: batch size, -2: max length, -1: vector dim]
+
+ checkpoint(query);
}
auto decoderContext = transposeTimeBatch(query); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vector dim]
@@ -823,16 +792,18 @@ public:
//************************************************************************//
// final feed-forward layer (output)
- Expr logits = output_->apply(decoderContext); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vocab dim]
-
+ if(shortlist_)
+ output_->setShortlist(shortlist_);
+ auto logits = output_->applyAsLogits(decoderContext); // [-4: beam depth=1, -3: max length, -2: batch size, -1: vocab or shortlist dim]
+
// return unormalized(!) probabilities
Ptr<DecoderState> nextState;
if (opt<std::string>("transformer-decoder-autoreg", "self-attention") == "rnn") {
nextState = New<DecoderState>(
- decoderStates, logits, state->getEncoderStates(), state->getBatch());
+ decoderStates, logits, state->getEncoderStates(), state->getBatch());
} else {
nextState = New<TransformerState>(
- decoderStates, logits, state->getEncoderStates(), state->getBatch());
+ decoderStates, logits, state->getEncoderStates(), state->getBatch());
}
nextState->setPosition(state->getPosition() + 1);
return nextState;
@@ -841,27 +812,21 @@ public:
// helper function for guided alignment
// @TODO: const vector<> seems wrong. Either make it non-const or a const& (more efficient but dangerous)
virtual const std::vector<Expr> getAlignments(int /*i*/ = 0) override {
- return alignments_;
+ return alignments_; // [tgt index][beam depth, max src length, batch size, 1]
}
void clear() override {
- output_ = nullptr;
+ if (output_)
+ output_->clear();
cache_.clear();
alignments_.clear();
+ perLayerRnn_.clear(); // this needs to be cleared between batches.
+ // @TODO: figure out how to detect stale nodes i.e. nodes that are referenced,
+ // but where underlying memory has been deallocated by dropping all tensors
+ // from a TensorAllocator object. This can happen during ExpressionGraph::clear()
}
};
-// factory functions
-Ptr<EncoderBase> NewEncoderTransformer(Ptr<Options> options)
-{
- return New<EncoderTransformer>(options);
-}
-
-Ptr<DecoderBase> NewDecoderTransformer(Ptr<Options> options)
-{
- return New<DecoderTransformer>(options);
-}
-
// clang-format on
} // namespace marian
diff --git a/src/models/transformer_factory.h b/src/models/transformer_factory.h
index aa31e4d1..b282d819 100755
--- a/src/models/transformer_factory.h
+++ b/src/models/transformer_factory.h
@@ -1,14 +1,12 @@
+// @TODO: rename to transformer.h eventually. This is not a Factory as in factory.h.
#pragma once
#include "marian.h"
#include "models/decoder.h"
#include "models/encoder.h"
-//#include "models/states.h"
-//#include "layers/constructors.h"
-//#include "layers/factory.h"
namespace marian {
-Ptr<EncoderBase> NewEncoderTransformer(Ptr<Options> options);
-Ptr<DecoderBase> NewDecoderTransformer(Ptr<Options> options);
+Ptr<EncoderBase> NewEncoderTransformer(Ptr<ExpressionGraph> graph, Ptr<Options> options);
+Ptr<DecoderBase> NewDecoderTransformer(Ptr<ExpressionGraph> graph, Ptr<Options> options);
} // namespace marian
diff --git a/src/models/transformer_stub.cpp b/src/models/transformer_stub.cpp
index 420b7781..871ee009 100644..100755
--- a/src/models/transformer_stub.cpp
+++ b/src/models/transformer_stub.cpp
@@ -1,4 +1,14 @@
-// TODO: This is a wrapper around transformer.h. We kept the .H name to minimize confusing git, until this is code-reviewed.
-// This is meant to speed-up builds, and to support Ctrl-F7 to rebuild.
-
-#include "models/transformer.h"
+#include "models/transformer.h"
+
+namespace marian {
+// factory functions
+Ptr<EncoderBase> NewEncoderTransformer(Ptr<ExpressionGraph> graph, Ptr<Options> options)
+{
+ return New<EncoderTransformer>(graph, options);
+}
+
+Ptr<DecoderBase> NewDecoderTransformer(Ptr<ExpressionGraph> graph, Ptr<Options> options)
+{
+ return New<DecoderTransformer>(graph, options);
+}
+} // namespace marian
diff --git a/src/optimizers/clippers.cpp b/src/optimizers/clippers.cpp
index 2507ccc4..7f3b8119 100644
--- a/src/optimizers/clippers.cpp
+++ b/src/optimizers/clippers.cpp
@@ -11,7 +11,7 @@ void Elementwise::clip(Tensor t) {
void Norm::clip(Tensor t) {
using namespace functional;
- float l2Norm = L2Norm(t);
+ float l2Norm = L2Norm(t, nullptr); // @TODO: this is a placeholder for a memory allocator, will be replaced with better version in a PR or two.
if(l2Norm >= c_)
Element(_1 = (c_ / l2Norm) * _1, t);
}
diff --git a/src/optimizers/clippers.h b/src/optimizers/clippers.h
index f953c2ec..67a50004 100644
--- a/src/optimizers/clippers.h
+++ b/src/optimizers/clippers.h
@@ -16,6 +16,7 @@ namespace marian {
class ClipperBase {
public:
virtual void clip(Tensor) = 0;
+ virtual ~ClipperBase() {}
};
typedef std::shared_ptr<ClipperBase> ClipperPtr;
diff --git a/src/optimizers/optimizers.cpp b/src/optimizers/optimizers.cpp
index d7a58be3..083f94c0 100755
--- a/src/optimizers/optimizers.cpp
+++ b/src/optimizers/optimizers.cpp
@@ -2,10 +2,12 @@
#include "common/io.h"
#include "tensors/tensor_operators.h"
+#include <array>
namespace marian {
-void Sgd::updateImpl(Tensor params, Tensor grads) {
+void Sgd::updateImpl(Tensor params, Tensor grads, size_t actualMBSize, size_t refMBWords) {
+ actualMBSize, refMBWords; // (no correction for base update needed beyond using ce-sum)
using namespace functional;
Element(_1 -= eta_ * _2,
params,
@@ -14,9 +16,10 @@ void Sgd::updateImpl(Tensor params, Tensor grads) {
params->getBackend()->synchronize();
}
-// Aagrad
+// Adagrad
-void Adagrad::updateImpl(Tensor params, Tensor grads) {
+void Adagrad::updateImpl(Tensor params, Tensor grads, size_t actualMBSize, size_t refMBWords) {
+ ABORT_IF(actualMBSize != refMBWords, "Adagrad does not support rational hyper-parameter adjustment");
if(!alloc_)
alloc_ = New<TensorAllocator>(params->getBackend());
@@ -52,7 +55,6 @@ void Adagrad::load(const std::string& name,
std::vector<float> vGt;
- // @TODO: use new IO
auto items = io::loadItems(name);
for(auto item : items) {
// get the size of gt_
@@ -61,8 +63,7 @@ void Adagrad::load(const std::string& name,
// extract data into vectors
if(item.name == "adagrad_gt") {
vGt.resize(totalSize);
- std::copy(
- (float*)item.data(), (float*)item.data() + totalSize, vGt.begin());
+ std::copy((float*)item.data(), ((float*)item.data()) + totalSize, vGt.begin());
}
}
if(vGt.empty()) {
@@ -76,7 +77,7 @@ void Adagrad::load(const std::string& name,
if(!opt->gt_) {
if(!opt->alloc_)
opt->alloc_ = New<TensorAllocator>(backends[localDeviceIndex]);
- auto size = end-begin;
+ auto size = end - begin;
opt->alloc_->reserveExact(sizeof(float) * size);
opt->alloc_->allocate(opt->gt_, {1, (int)size});
}
@@ -108,8 +109,7 @@ void Adagrad::save(const std::string& name,
item.shape = Shape({1, (int)vGt.size()});
item.type = Type::float32;
item.bytes.resize(vGt.size() * sizeOf(item.type));
- std::copy(
- (char*)vGt.data(), (char*)vGt.data() + vGt.size(), item.bytes.begin());
+ std::copy((char*)vGt.data(), (char*)(vGt.data() + vGt.size()), item.bytes.begin());
io::saveItems(name, {item});
}
@@ -121,7 +121,8 @@ void Adagrad::resetStats() {
// Adam
-void Adam::updateImpl(Tensor params, Tensor grads) {
+void Adam::updateImpl(Tensor params, Tensor grads, size_t actualMBSize, size_t refMBWords) {
+ // lazy allocation
if(!alloc_)
alloc_ = New<TensorAllocator>(params->getBackend());
@@ -130,29 +131,40 @@ void Adam::updateImpl(Tensor params, Tensor grads) {
alloc_->reserveExact(2 * params->memory()->size());
alloc_->allocate(mt_, {1, elements});
mt_->set(0.f);
-
alloc_->allocate(vt_, {1, elements});
vt_->set(0.f);
}
- t_++;
- float denom1 = 1 - (float)std::pow(beta1_, t_);
- float denom2 = 1 - (float)std::pow(beta2_, t_);
+ double T = (double)actualMBSize;
+ double Tref = (double)refMBWords;
- using namespace functional;
-
- Element(_1 = (beta1_ * _1) + ((1 - beta1_) * _2), mt_, grads);
- Element(_1 = (beta2_ * _1) + ((1 - beta2_) * (_2 * _2)), vt_, grads);
+ // adjust for minibatch-size changes if Adam parameters are given a reference size (else do nothing)
+ double eta = eta_ * (T/Tref);
+ double beta1 = beta1_;
+ double beta2 = beta2_;
+ double decay = w_ ;
- Element(_1 -= eta_ // learning-rate: x_t = x_{t-1} - \eta * (...)
- * ((_2 / denom1) // 1st moment: m_{t-1}
- / (sqrt(_3 / denom2) + eps_) // 2nd moment: \sqrt(v_{t-1})
- + w_ * _1), // weight-decay: w * x_{t-1}
- params,
- mt_,
- vt_);
+ // denominators. At steady state: =1. This recursion does the same as the Adam beta correction term.
+ denom1_ = (beta1 * denom1_) + (1 - beta1); // momentum smoothing
+ denom2_ = (beta2 * denom2_) + (1 - beta2); // RMS normalization
- params->getBackend()->synchronize();
+ // numerators. Divide by T to convert ce-sum gradient to avg gradient.
+ using namespace functional;
+ Element(_1 = ((float)beta1 * _1) + float((1 - beta1) / T ) * _2, mt_, grads); // momentum smoothing. At steady state: =smoothed avg gradient
+ Element(_1 = ((float)beta2 * _1) + float((1 - beta2) / T / T) * (_2 * _2), vt_, grads); // RMS normalization. At steady state: =mean square of the avg gradients
+
+ // apply Adam normalization
+ float etaf = (float)eta, denom1f = (float)denom1_, denom2f = (float)denom2_, decayf = (float)decay; // (get casts out of Element expression for readability)
+ Element(_1 -= etaf // learning-rate: x_t = x_{t-1} - \eta * (...)
+ * (( ( _2 / denom1f) // momentum-smoothed per-sample gradient: m_{t-1}
+ / (sqrt(_3 / denom2f) + eps_)) // normalize by RMS: \sqrt(v_{t-1})
+ + decayf * _1), // weight-decay: w * x_{t-1}
+ params, // =_1
+ mt_, // =_2
+ vt_ // =_3
+ );
+
+ params->getBackend()->synchronize(); // @TODO: This should not be in here. Maybe in the wrapper. Why is it needed at all?
}
void Adam::load(const std::string& name,
@@ -168,6 +180,7 @@ void Adam::load(const std::string& name,
std::vector<float> vMt;
std::vector<float> vVt;
+ std::array<double, 2> vDenoms;
auto items = io::loadItems(name);
for(auto item : items) {
@@ -178,12 +191,18 @@ void Adam::load(const std::string& name,
if(item.name == "adam_mt") {
vMt.resize(totalSize);
std::copy(
- (float*)item.data(), (float*)item.data() + totalSize, vMt.begin());
+ (float*)item.data(), ((float*)item.data()) + totalSize, vMt.begin());
}
- if(item.name == "adam_vt") {
+ else if(item.name == "adam_vt") {
vVt.resize(totalSize);
std::copy(
- (float*)item.data(), (float*)item.data() + totalSize, vVt.begin());
+ (float*)item.data(), ((float*)item.data()) + totalSize, vVt.begin());
+ }
+ else if(item.name == "adam_denoms") {
+ ABORT_IF(totalSize != 2, "adam_denoms should have 2 entries");
+ std::copy(
+ (double*)item.data(), ((double*)item.data()) + totalSize, vDenoms.begin());
+ // Back compat note: Old files lacked "adam_denoms". For those, vDenoms will remain 0, which reproduces the old behavior.
}
}
if(vMt.empty() || vVt.empty()) {
@@ -212,6 +231,9 @@ void Adam::load(const std::string& name,
auto opt = std::dynamic_pointer_cast<Adam>(opts[id]);
opt->vt_->set(std::vector<float>(begin, end));
});
+
+ denom1_ = vDenoms[0];
+ denom2_ = vDenoms[1];
//LOG(info, "done loading Adam params");
}
@@ -248,7 +270,7 @@ void Adam::save(const std::string& name,
itemMt.type = Type::float32;
itemMt.bytes.resize(vMt.size() * sizeOf(itemMt.type));
std::copy(
- (char*)vMt.data(), (char*)vMt.data() + vMt.size(), itemMt.bytes.begin());
+ (char*)vMt.data(), (char*)(vMt.data() + vMt.size()), itemMt.bytes.begin());
io::Item itemVt;
itemVt.name = "adam_vt";
@@ -256,9 +278,19 @@ void Adam::save(const std::string& name,
itemVt.type = Type::float32;
itemVt.bytes.resize(vVt.size() * sizeOf(itemVt.type));
std::copy(
- (char*)vVt.data(), (char*)vVt.data() + vVt.size(), itemVt.bytes.begin());
+ (char*)vVt.data(), (char*)(vVt.data() + vVt.size()), itemVt.bytes.begin());
+
+ // @TODO: this pattern is duplicated several times; refactor it
+ std::array<double, 2> vDenoms{denom1_, denom2_};
+ io::Item itemDenoms;
+ itemDenoms.name = "adam_denoms";
+ itemDenoms.shape = Shape({1, (int)vDenoms.size()});
+ itemDenoms.type = Type::float64;
+ itemDenoms.bytes.resize(vDenoms.size() * sizeOf(itemDenoms.type));
+ std::copy(
+ (char*)vDenoms.data(), (char*)(vDenoms.data() + vDenoms.size()), itemDenoms.bytes.begin());
- io::saveItems(name, {itemMt, itemVt});
+ io::saveItems(name, {itemMt, itemVt, itemDenoms});
}
void Adam::resetStats() {
@@ -267,29 +299,32 @@ void Adam::resetStats() {
if(vt_)
vt_->set(0.f);
+
+ denom1_ = 0; // @BUGBUG: or 1 or refMBWords if so specified. Fix once we have proper parameterization for that.
+ denom2_ = 0;
}
Ptr<OptimizerBase> Optimizer(Ptr<Options> options) {
- float lrate = (float)options->get<double>("learn-rate"); // @TODO: should this be <float>?
- auto params = options->has("optimizer-params")
- ? options->get<std::vector<float>>("optimizer-params")
- : std::vector<float>({});
+ float lrate = options->get<float>("learn-rate");
+ auto params = options->get<std::vector<float>>("optimizer-params", std::vector<float>({}));
+ // adjust hyper-parameters as if our MB size (in target labels) was this value
+ size_t refMBWordsParam = options->get<size_t>("mini-batch-words-ref");
Ptr<ClipperBase> clipper = nullptr;
- float clipNorm = (float)options->get<double>("clip-norm"); // @TODO: should this be <float>?
+ float clipNorm = options->get<float>("clip-norm");
if(clipNorm > 0)
- clipper = Clipper<Norm>(clipNorm);
+ clipper = Clipper<Norm>(clipNorm); // @BUGBUG: this is not scaling by number of labels?
auto opt = options->get<std::string>("optimizer");
if(opt == "sgd") {
- return Optimizer<Sgd>(lrate, clipper, params);
+ return Optimizer<Sgd>(lrate, refMBWordsParam, clipper, params);
} else if(opt == "adagrad") {
- return Optimizer<Adagrad>(lrate, clipper, params);
+ return Optimizer<Adagrad>(lrate, refMBWordsParam, clipper, params);
} else if(opt == "adam") {
- return Optimizer<Adam>(lrate, clipper, params);
+ return Optimizer<Adam>(lrate, refMBWordsParam, clipper, params);
} else {
- ABORT("Unknown optimizer: {}", opt);
+ ABORT("Unknown optimizer kind: {}", opt);
}
}
} // namespace marian
diff --git a/src/optimizers/optimizers.h b/src/optimizers/optimizers.h
index 9b83afb6..d10af919 100755..100644
--- a/src/optimizers/optimizers.h
+++ b/src/optimizers/optimizers.h
@@ -18,22 +18,42 @@ namespace marian {
*/
class OptimizerBase : public TrainingObserver {
public:
- OptimizerBase(float eta, Ptr<ClipperBase> clipper = nullptr)
- : eta_(eta), clipper_(clipper) {}
+ OptimizerBase(float eta, size_t refMBWordsParam, Ptr<ClipperBase> clipper)
+ : eta_(eta), refMBWordsParam_(refMBWordsParam), clipper_(clipper) {
+
+ // automatic learning-rate adjustment
+ // If users provide, in addition to the hyper-parameters, a reference minibatch size,
+ // that these hyper-parameters were originally tuned for, then the learning-rate gets
+ // adjusted accordingly. Note: Requires user to also use ce-sum criterion.
+ if (refMBWordsParam_ != 0)
+ LOG(info, "[optimizers] Learning rate gets automatically adjusted as if minibatch size was {}", refMBWordsParam_);
+ }
+
+ virtual ~OptimizerBase() {}
- void update(Ptr<ExpressionGraph> graph) {
+ static constexpr size_t mbSizeNotProvided = SIZE_MAX;
+
+ void update(Ptr<ExpressionGraph> graph, size_t mbSize = mbSizeNotProvided) {
Tensor p = graph->params()->vals();
Tensor g = graph->params()->grads();
- update(p, g);
+ update(p, g, mbSize);
}
- void update(Tensor params, Tensor grads) {
+ void update(Tensor params, Tensor grads, size_t mbSize = mbSizeNotProvided) {
if(clipper_)
- clipper_->clip(grads);
-
- // In case we want to add a multiply factor to our learning rate
- updateImpl(params, grads);
+ clipper_->clip(grads); //@BUGBUG: take into account actual mini-batch size since gradients are not normalized
+
+ size_t refMBWords = refMBWordsParam_;
+ if (refMBWords == 0) { // optimizer not configured to use hyper-parameter auto-adjustment
+ refMBWords = mbSize = 1; // neutral settings that keep the standard behavior
+ }
+ else { // optimizer is configured to auto-adjust hyper-parameters
+ ABORT_IF(mbSize == mbSizeNotProvided, "Using rational optimizer auto-adjustment with trainer that does not provide MB size");
+ // note: this behavior is only meaningful if using the ce-sum criterion
+ }
+
+ updateImpl(params, grads, mbSize, refMBWords);
}
virtual void init(TrainingState& state) override {
@@ -58,7 +78,7 @@ public:
resetStats();
}
- void setParams(const std::vector<float>& params) { parseParams(params); }
+ virtual void setParams(const std::vector<float>& params) = 0;
typedef std::function<void(size_t /*localDeviceIndex*/,
std::vector<float>::const_iterator /*begin*/,
@@ -78,12 +98,13 @@ public:
bool /*isMainProcess*/ = true) {}
protected:
- virtual void updateImpl(Tensor params, Tensor grads) = 0;
- virtual void parseParams(const std::vector<float>& params) = 0;
+ virtual void updateImpl(Tensor params, Tensor grads, size_t actualMBSize, size_t refMBWords) = 0;
virtual void resetStats() = 0;
// Learning rate
float eta_;
+ // Reference MB size. This enables automatic adjustment of optimizer hyper-parameters to MB size.
+ size_t refMBWordsParam_{0}; // 0 means no adjustment
// Clip gradient norm
Ptr<ClipperBase> clipper_;
};
@@ -93,13 +114,13 @@ protected:
*/
class Sgd : public OptimizerBase {
public:
- Sgd(float eta, Ptr<ClipperBase> clipper = nullptr)
- : OptimizerBase(eta, clipper) {}
-
+ Sgd(float eta, size_t refMBWordsParam = 0, Ptr<ClipperBase> clipper = nullptr)
+ : OptimizerBase(eta, refMBWordsParam, clipper) {}
+ virtual ~Sgd() {}
+ virtual void setParams(const std::vector<float>& /*params*/) override {}
private:
- void updateImpl(Tensor params, Tensor grads) override;
+ void updateImpl(Tensor params, Tensor grads, size_t actualMBSize, size_t refMBWords) override;
- virtual void parseParams(const std::vector<float>& /*params*/) override {}
virtual void resetStats() override {}
};
@@ -110,8 +131,8 @@ private:
*/
class Adagrad : public OptimizerBase {
public:
- Adagrad(float eta, Ptr<ClipperBase> clipper = nullptr)
- : OptimizerBase(eta, clipper) {}
+ Adagrad(float eta, size_t refMBWordsParam = 0, Ptr<ClipperBase> clipper = nullptr)
+ : OptimizerBase(eta, refMBWordsParam, clipper) {}
void load(const std::string& name,
const std::vector<Ptr<OptimizerBase>>& opts,
@@ -122,15 +143,15 @@ public:
const GatherStateFunc& gatherFn,
bool /*isMainProcess*/ = true) override;
-private:
- void updateImpl(Tensor params, Tensor grads) override;
- void resetStats() override;
-
- void parseParams(const std::vector<float>& params) override {
+ void setParams(const std::vector<float>& params) override {
if(params.size() > 0)
eps_ = params[0];
}
+private:
+ void updateImpl(Tensor params, Tensor grads, size_t actualMBSize, size_t refMBWords) override;
+ void resetStats() override;
+
float eps_ = 1e-8f;
Ptr<TensorAllocator> alloc_;
Tensor gt_;
@@ -140,11 +161,13 @@ private:
* @brief Adam optimizer
*
* https://arxiv.org/pdf/1412.6980v8.pdf
+ *
+ * with Frank's modifications for automatic hyper-parameter adjustment.
*/
class Adam : public OptimizerBase {
public:
- Adam(float eta, Ptr<ClipperBase> clipper = nullptr)
- : OptimizerBase(eta, clipper), t_(0) {}
+ Adam(float eta, size_t refMBWordsParam = 0, Ptr<ClipperBase> clipper = nullptr)
+ : OptimizerBase(eta, refMBWordsParam, clipper) {}
void load(const std::string& name,
const std::vector<Ptr<OptimizerBase>>& opts,
@@ -156,10 +179,12 @@ public:
bool isMainProcess = true) override;
private:
- void updateImpl(Tensor params, Tensor grads) override;
+ void updateImpl(Tensor params, Tensor grads, size_t actualMBSize, size_t refMBWords) override;
void resetStats() override;
- virtual void parseParams(const std::vector<float>& params) override {
+ // Adam parameters:
+ // [beta1, beta2, eps, w, refMBWords]
+ virtual void setParams(const std::vector<float>& params) override {
if(params.size() > 0)
beta1_ = params[0];
if(params.size() > 1)
@@ -169,25 +194,30 @@ private:
// weighted decay for AdamW, to be explored, disabled by default
if(params.size() > 3)
- w_ = params[3];
+ w_ = params[3]; // default (disabled): 0
}
+ // hyper-parameters
float beta1_ = 0.9f;
float beta2_ = 0.999f;
float eps_ = 1e-8f;
float w_ = 0.0f;
- size_t t_;
+ // CPU-side running accumulators
+ double denom1_ = 0;
+ double denom2_ = 0;
+
+ // GPU-side running accumulators
Ptr<TensorAllocator> alloc_;
Tensor mt_;
Tensor vt_;
};
template <class Algorithm>
-Ptr<OptimizerBase> Optimizer(float eta,
+Ptr<OptimizerBase> Optimizer(float eta, size_t refMBWordsParam = 0,
Ptr<ClipperBase> clipper = nullptr,
std::vector<float> params = {}) {
- auto opt = Ptr<OptimizerBase>(new Algorithm(eta, clipper));
+ auto opt = Ptr<OptimizerBase>(new Algorithm(eta, refMBWordsParam, clipper));
opt->setParams(params);
return opt;
}
diff --git a/src/rescorer/rescorer.h b/src/rescorer/rescorer.h
index fa456856..a201f3fb 100644..100755
--- a/src/rescorer/rescorer.h
+++ b/src/rescorer/rescorer.h
@@ -19,23 +19,23 @@ using namespace data;
class Rescorer {
private:
- Ptr<models::ModelBase> builder_;
+ Ptr<models::ICriterionFunction> builder_;
public:
Rescorer(Ptr<Options> options)
- : builder_(models::from_options(options, models::usage::scoring)) {}
+ : builder_(models::createCriterionFunctionFromOptions(options, models::usage::scoring)) {}
void load(Ptr<ExpressionGraph> graph, const std::string& modelFile) {
builder_->load(graph, modelFile);
}
- Expr build(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
+ Ptr<RationalLoss> build(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
return builder_->build(graph, batch);
}
data::SoftAlignment getAlignment() {
- auto model = std::static_pointer_cast<models::Scorer>(builder_)->getModel();
- return std::static_pointer_cast<EncoderDecoderBase>(model)->getAlignment();
+ auto model = std::static_pointer_cast<models::Trainer>(builder_)->getModel();
+ return std::static_pointer_cast<IEncoderDecoder>(model)->getAlignment();
}
};
@@ -49,15 +49,15 @@ private:
public:
Rescore(Ptr<Options> options) : options_(options) {
- ABORT_IF(options_->has("summary") && options_->has("alignment"),
+ ABORT_IF(options_->hasAndNotEmpty("summary") && options_->hasAndNotEmpty("alignment"),
"Alignments can not be produced with summarized score");
- ABORT_IF(options_->has("summary") && options_->get<bool>("normalize"),
+ ABORT_IF(options_->hasAndNotEmpty("summary") && options_->get<bool>("normalize"),
"Normalization by length cannot be used with summary scores");
options_->set("inference", true);
- // @TODO: make normalize here a float and pass into loss to compute the same way as in decoding
- options_->set("cost-type", options_->get<bool>("normalize") ? "ce-rescore-mean" : "ce-rescore");
+ options_->set("shuffle", "none");
+ options_->set("cost-type", "ce-rescore"); // indicates that to keep separate per-batch-item scoresForSummary
if(options_->get<bool>("n-best"))
corpus_ = New<CorpusNBest>(options_);
@@ -68,8 +68,16 @@ public:
auto devices = Config::getDevices(options_);
for(auto device : devices) {
- auto graph = New<ExpressionGraph>(true, options_->get<bool>("optimize"));
+ auto graph = New<ExpressionGraph>(true);
+
+ auto precison = options_->get<std::vector<std::string>>("precision", {"float32"});
+ graph->setDefaultElementType(typeFromString(precison[0])); // only use first type, used for parameter type in graph
graph->setDevice(device);
+ graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
+ if (device.type == DeviceType::cpu) {
+ graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
+ }
+
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_.push_back(graph);
}
@@ -92,20 +100,20 @@ public:
LOG(info, "Scoring");
auto batchGenerator = New<BatchGenerator<CorpusBase>>(corpus_, options_);
- batchGenerator->prepare(false);
+ batchGenerator->prepare();
Ptr<ScoreCollector> output = options_->get<bool>("n-best")
? std::static_pointer_cast<ScoreCollector>(
New<ScoreCollectorNBest>(options_))
: New<ScoreCollector>(options_);
- std::string alignment = options_->get<std::string>("alignment", "");
- bool summarize = options_->has("summary");
+ auto alignment = options_->get<std::string>("alignment", "");
+ auto summary = options_->get<std::string>("summary", "");
+ bool summarize = !summary.empty();
+ // @TODO: make normalize here a float and pass into loss to compute the same way as in decoding
bool normalize = options_->get<bool>("normalize");
- std::string summary = summarize ? options_->get<std::string>("summary") : "cross-entropy";
-
- float sumCost = 0;
+ float sumLoss = 0;
size_t sumWords = 0;
size_t sumSamples = 0;
size_t batchId = 0;
@@ -115,7 +123,7 @@ public:
ThreadPool pool(graphs_.size(), graphs_.size());
for(auto batch : *batchGenerator) {
- auto task = [=, &sumCost, &sumWords, &sumSamples, &smutex](size_t id) {
+ auto task = [=, &sumLoss, &sumWords, &sumSamples, &smutex](size_t id) {
thread_local Ptr<ExpressionGraph> graph;
thread_local Ptr<Model> builder;
@@ -126,29 +134,41 @@ public:
// @TODO: normalize by length as in normalize
// Once we have Frank's concept of ce-sum with sample size by words we will return a pair
- // here which will make it trivial to report all variants.
- auto costNode = builder->build(graph, batch);
+ // here which will make it trivial to report all variants.
+ auto dynamicLoss = builder->build(graph, batch);
graph->forward();
- std::vector<float> scores;
- costNode->val()->get(scores);
+ // get loss
+ std::vector<float> scoresForSummary;
+ dynamicLoss->loss(scoresForSummary);
+ std::vector<float> sentScores(scoresForSummary); // if '--normalize' then report scoresForSummary length-normalized
+ if (normalize) {
+ std::vector<float> sentLengths;
+ dynamicLoss->count(sentLengths);
+ for (size_t i = 0; i < scoresForSummary.size(); i++) {
+ if (sentScores[i] != 0) // (avoid 0/0)
+ sentScores[i] /= (sentLengths.size() == 1 ? sentLengths[0] : sentLengths[i]); // emulate broadcasting semantics
+ }
+ }
// soft alignments for each sentence in the batch
- std::vector<data::SoftAlignment> aligns(batch->size());
+ std::vector<data::SoftAlignment> aligns(batch->size()); // @TODO: do this resize inside getAlignmentsForBatch()
if(!alignment.empty()) {
getAlignmentsForBatch(builder->getAlignment(), batch, aligns);
}
std::unique_lock<std::mutex> lock(smutex);
- for(auto s : scores)
- sumCost += s;
+ for(auto s : scoresForSummary)
+ sumLoss += s;
sumWords += batch->back()->batchWords();
sumSamples += batch->size();
if(!summarize) {
for(size_t i = 0; i < batch->size(); ++i) {
- output->Write((long)batch->getSentenceIds()[i], scores[i], aligns[i]);
+ output->Write((long)batch->getSentenceIds()[i],
+ -1.f * sentScores[i], // report logProb while score is CE, hence negate
+ aligns[i]);
}
}
@@ -168,22 +188,22 @@ public:
}
if(normalize) {
- LOG(info, "Total normalized log probs {} : Total sentences {} : Total words {}", sumCost, sumSamples, sumWords);
+ LOG(info, "Total normalized log probs {} : Total sentences {} : Total words {}", sumLoss, sumSamples, sumWords);
LOG(warn, "Sum of normalized log probs is a sum of averages");
} else {
- LOG(info, "Total log probs {} : Total sentences {} : Total words {}", sumCost, sumSamples, sumWords);
+ LOG(info, "Total log probs {} : Total sentences {} : Total words {}", sumLoss, sumSamples, sumWords);
}
- if(summarize) {
+ if(summarize) { // @TODO: use one function from loss
float cost = 0;
if(summary == "perplexity")
- cost = std::exp(-(float)sumCost / (float)sumWords);
+ cost = std::exp(sumLoss / (float)sumWords);
else if(summary == "ce-sum")
- cost = -sumCost;
+ cost = sumLoss;
else if(summary == "ce-mean-words")
- cost = -(float)sumCost / (float)sumWords;
+ cost = sumLoss / (float)sumWords;
else
- cost = -sumCost / sumSamples;
+ cost = sumLoss / sumSamples;
LOG(info, "Reporting {} summary", summary);
std::cout << cost << std::endl;
diff --git a/src/rescorer/score_collector.cpp b/src/rescorer/score_collector.cpp
index ac118a6a..1577feba 100644
--- a/src/rescorer/score_collector.cpp
+++ b/src/rescorer/score_collector.cpp
@@ -13,7 +13,7 @@ ScoreCollector::ScoreCollector(const Ptr<Options>& options)
alignmentThreshold_(getAlignmentThreshold(alignment_)) {
if(options->get<std::string>("output") == "stdout")
- outStrm_.reset(new io::OutputFileStream(std::cout));
+ outStrm_.reset(new std::ostream(std::cout.rdbuf()));
else
outStrm_.reset(new io::OutputFileStream(options->get<std::string>("output")));
}
@@ -113,8 +113,7 @@ std::string ScoreCollectorNBest::addToNBest(const std::string nbest,
const std::string feature,
float score,
const data::SoftAlignment& align) {
- std::vector<std::string> fields;
- utils::split(nbest, fields, "|||");
+ auto fields = utils::split(nbest, "|||");
std::stringstream ss;
if(!alignment_.empty() && !align.empty())
ss << " " << getAlignment(align) << " |||";
diff --git a/src/rescorer/score_collector.h b/src/rescorer/score_collector.h
index e09681ca..38ba7a08 100644
--- a/src/rescorer/score_collector.h
+++ b/src/rescorer/score_collector.h
@@ -13,6 +13,7 @@ namespace marian {
class ScoreCollector {
public:
ScoreCollector(const Ptr<Options>& options);
+ virtual ~ScoreCollector() {}
virtual void Write(long id, const std::string& message);
virtual void Write(long id,
@@ -21,7 +22,7 @@ public:
protected:
long nextId_{0};
- UPtr<io::OutputFileStream> outStrm_;
+ UPtr<std::ostream> outStrm_;
std::mutex mutex_;
typedef std::map<long, std::string> Outputs;
diff --git a/src/rnn/attention.h b/src/rnn/attention.h
index cbb16a34..6b30cb55 100644..100755
--- a/src/rnn/attention.h
+++ b/src/rnn/attention.h
@@ -9,6 +9,8 @@ namespace rnn {
Expr attOps(Expr va, Expr context, Expr state);
+// Attitive attention used in RNN cells.
+// @TODO: come up with common framework for attention in RNNs and Transformer.
class GlobalAttention : public CellInput {
private:
Expr Wa_, ba_, Ua_, va_;
@@ -50,34 +52,33 @@ public:
Wa_ = graph->param(prefix + "_W_comb_att",
{dimDecState, dimEncState},
- inits::glorot_uniform);
+ inits::glorotUniform());
Ua_ = graph->param(
- prefix + "_Wc_att", {dimEncState, dimEncState}, inits::glorot_uniform);
+ prefix + "_Wc_att", {dimEncState, dimEncState}, inits::glorotUniform());
va_ = graph->param(
- prefix + "_U_att", {dimEncState, 1}, inits::glorot_uniform);
- ba_ = graph->param(prefix + "_b_att", {1, dimEncState}, inits::zeros);
+ prefix + "_U_att", {dimEncState, 1}, inits::glorotUniform());
+ ba_ = graph->param(prefix + "_b_att", {1, dimEncState}, inits::zeros());
if(dropout_ > 0.0f) {
- dropMaskContext_ = graph->dropout(dropout_, {1, dimEncState});
- dropMaskState_ = graph->dropout(dropout_, {1, dimDecState});
+ dropMaskContext_ = graph->dropoutMask(dropout_, {1, dimEncState});
+ dropMaskState_ = graph->dropoutMask(dropout_, {1, dimDecState});
}
- if(dropMaskContext_)
- contextDropped_ = dropout(contextDropped_, dropMaskContext_);
+ contextDropped_ = dropout(contextDropped_, dropMaskContext_);
if(layerNorm_) {
if(nematusNorm_) {
// instead of gammaContext_
Wc_att_lns_ = graph->param(
- prefix + "_Wc_att_lns", {1, dimEncState}, inits::from_value(1.f));
+ prefix + "_Wc_att_lns", {1, dimEncState}, inits::fromValue(1.f));
Wc_att_lnb_ = graph->param(
- prefix + "_Wc_att_lnb", {1, dimEncState}, inits::zeros);
+ prefix + "_Wc_att_lnb", {1, dimEncState}, inits::zeros());
// instead of gammaState_
W_comb_att_lns_ = graph->param(prefix + "_W_comb_att_lns",
{1, dimEncState},
- inits::from_value(1.f));
+ inits::fromValue(1.f));
W_comb_att_lnb_ = graph->param(
- prefix + "_W_comb_att_lnb", {1, dimEncState}, inits::zeros);
+ prefix + "_W_comb_att_lnb", {1, dimEncState}, inits::zeros());
mappedContext_ = layerNorm(affine(contextDropped_, Ua_, ba_),
Wc_att_lns_,
@@ -85,9 +86,9 @@ public:
NEMATUS_LN_EPS);
} else {
gammaContext_ = graph->param(
- prefix + "_att_gamma1", {1, dimEncState}, inits::from_value(1.0));
+ prefix + "_att_gamma1", {1, dimEncState}, inits::fromValue(1.0));
gammaState_ = graph->param(
- prefix + "_att_gamma2", {1, dimEncState}, inits::from_value(1.0));
+ prefix + "_att_gamma2", {1, dimEncState}, inits::fromValue(1.0));
mappedContext_
= layerNorm(dot(contextDropped_, Ua_), gammaContext_, ba_);
@@ -113,8 +114,7 @@ public:
if(recState->shape().size() > 3)
dimBeam = recState->shape()[-4];
- if(dropMaskState_)
- recState = dropout(recState, dropMaskState_);
+ recState = dropout(recState, dropMaskState_);
auto mappedState = dot(recState, Wa_);
if(layerNorm_) {
diff --git a/src/rnn/attention_constructors.h b/src/rnn/attention_constructors.h
index 9fd1e966..a878f57f 100644
--- a/src/rnn/attention_constructors.h
+++ b/src/rnn/attention_constructors.h
@@ -15,11 +15,11 @@ protected:
Ptr<EncoderState> state_;
public:
- AttentionFactory(Ptr<ExpressionGraph> graph) : InputFactory(graph) {}
+// AttentionFactory(Ptr<ExpressionGraph> graph) : InputFactory(graph) {}
- Ptr<CellInput> construct() override {
+ Ptr<CellInput> construct(Ptr<ExpressionGraph> graph) override {
ABORT_IF(!state_, "EncoderState not set");
- return New<Attention>(graph_, options_, state_);
+ return New<Attention>(graph, options_, state_);
}
Accumulator<AttentionFactory> set_state(Ptr<EncoderState> state) {
diff --git a/src/rnn/cells.cpp b/src/rnn/cells.cpp
index e9467544..3d7eee21 100644
--- a/src/rnn/cells.cpp
+++ b/src/rnn/cells.cpp
@@ -43,6 +43,23 @@ struct GRUFastNodeOp : public NaryNodeOp {
const std::string type() override { return "GRU-ops"; }
const std::string color() override { return "yellow"; }
+
+ virtual size_t hash() override {
+ size_t seed = NaryNodeOp::hash();
+ util::hash_combine(seed, final_);
+ return seed;
+ }
+
+ virtual bool equal(Expr node) override {
+ if(!NaryNodeOp::equal(node))
+ return false;
+ auto cnode = std::dynamic_pointer_cast<GRUFastNodeOp>(node);
+ if(!cnode)
+ return false;
+ if(final_ != cnode->final_)
+ return false;
+ return true;
+ }
};
Expr gruOps(const std::vector<Expr>& nodes, bool final) {
diff --git a/src/rnn/cells.h b/src/rnn/cells.h
index 12db6ea3..9fbc8852 100644..100755
--- a/src/rnn/cells.h
+++ b/src/rnn/cells.h
@@ -36,26 +36,26 @@ public:
dropout_ = options_->get<float>("dropout", 0);
U_ = graph->param(
- prefix + "_U", {dimState, dimState}, inits::glorot_uniform);
+ prefix + "_U", {dimState, dimState}, inits::glorotUniform());
if(dimInput)
W_ = graph->param(
- prefix + "_W", {dimInput, dimState}, inits::glorot_uniform);
+ prefix + "_W", {dimInput, dimState}, inits::glorotUniform());
- b_ = graph->param(prefix + "_b", {1, dimState}, inits::zeros);
+ b_ = graph->param(prefix + "_b", {1, dimState}, inits::zeros());
if(dropout_ > 0.0f) {
if(dimInput)
- dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
- dropMaskS_ = graph->dropout(dropout_, {1, dimState});
+ dropMaskX_ = graph->dropoutMask(dropout_, {1, dimInput});
+ dropMaskS_ = graph->dropoutMask(dropout_, {1, dimState});
}
if(layerNorm_) {
if(dimInput)
gamma1_ = graph->param(
- prefix + "_gamma1", {1, 3 * dimState}, inits::from_value(1.f));
+ prefix + "_gamma1", {1, 3 * dimState}, inits::fromValue(1.f));
gamma2_ = graph->param(
- prefix + "_gamma2", {1, 3 * dimState}, inits::from_value(1.f));
+ prefix + "_gamma2", {1, 3 * dimState}, inits::fromValue(1.f));
}
}
@@ -72,8 +72,7 @@ public:
else
input = inputs.front();
- if(dropMaskX_)
- input = dropout(input, dropMaskX_);
+ input = dropout(input, dropMaskX_);
auto xW = dot(input, W_);
@@ -87,8 +86,7 @@ public:
Expr recState = state.output;
auto stateDropped = recState;
- if(dropMaskS_)
- stateDropped = dropout(recState, dropMaskS_);
+ stateDropped = dropout(recState, dropMaskS_);
auto sU = dot(stateDropped, U_);
if(layerNorm_)
sU = layerNorm(sU, gamma2_);
@@ -133,20 +131,20 @@ public:
if(dimInput)
W_ = graph->param(
- prefix + "_W", {dimInput, dimState}, inits::glorot_uniform);
+ prefix + "_W", {dimInput, dimState}, inits::glorotUniform());
- b_ = graph->param(prefix + "_b", {1, dimState}, inits::zeros);
+ b_ = graph->param(prefix + "_b", {1, dimState}, inits::zeros());
if(dropout_ > 0.0f) {
if(dimInput)
- dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
- dropMaskS_ = graph->dropout(dropout_, {1, dimState});
+ dropMaskX_ = graph->dropoutMask(dropout_, {1, dimInput});
+ dropMaskS_ = graph->dropoutMask(dropout_, {1, dimState});
}
if(layerNorm_) {
if(dimInput)
- gamma1_ = graph->param(prefix + "_gamma1", {1, dimState}, inits::ones);
- gamma2_ = graph->param(prefix + "_gamma2", {1, dimState}, inits::ones);
+ gamma1_ = graph->param(prefix + "_gamma1", {1, dimState}, inits::ones());
+ gamma2_ = graph->param(prefix + "_gamma2", {1, dimState}, inits::ones());
}
}
@@ -163,8 +161,7 @@ public:
else
input = inputs.front();
- if(dropMaskX_)
- input = dropout(input, dropMaskX_);
+ input = dropout(input, dropMaskX_);
auto xW = dot(input, W_);
@@ -178,8 +175,7 @@ public:
Expr recState = state.output;
auto stateDropped = recState;
- if(dropMaskS_)
- stateDropped = dropout(recState, dropMaskS_);
+ stateDropped = dropout(recState, dropMaskS_);
auto sU = dot(stateDropped, U_);
if(layerNorm_)
sU = layerNorm(sU, gamma2_);
@@ -229,43 +225,43 @@ public:
final_ = opt<bool>("final", false);
auto U = graph->param(
- prefix + "_U", {dimState, 2 * dimState}, inits::glorot_uniform);
+ prefix + "_U", {dimState, 2 * dimState}, inits::glorotUniform());
auto Ux = graph->param(
- prefix + "_Ux", {dimState, dimState}, inits::glorot_uniform);
+ prefix + "_Ux", {dimState, dimState}, inits::glorotUniform());
U_ = concatenate({U, Ux}, /*axis =*/ -1);
if(dimInput > 0) {
auto W = graph->param(
- prefix + "_W", {dimInput, 2 * dimState}, inits::glorot_uniform);
+ prefix + "_W", {dimInput, 2 * dimState}, inits::glorotUniform());
auto Wx = graph->param(
- prefix + "_Wx", {dimInput, dimState}, inits::glorot_uniform);
+ prefix + "_Wx", {dimInput, dimState}, inits::glorotUniform());
W_ = concatenate({W, Wx}, /*axis =*/ -1);
}
- auto b = graph->param(prefix + "_b", {1, 2 * dimState}, inits::zeros);
- auto bx = graph->param(prefix + "_bx", {1, dimState}, inits::zeros);
+ auto b = graph->param(prefix + "_b", {1, 2 * dimState}, inits::zeros());
+ auto bx = graph->param(prefix + "_bx", {1, dimState}, inits::zeros());
b_ = concatenate({b, bx}, /*axis =*/ -1);
// @TODO use this and adjust Amun model type saving and loading
// U_ = graph->param(prefix + "_U", {dimState, 3 * dimState},
- // (Expr a) : UnaryNodeOp(a)inits::glorot_uniform);
+ // (Expr a) : UnaryNodeOp(a)inits::glorotUniform());
// W_ = graph->param(prefix + "_W", {dimInput, 3 * dimState},
- // (Expr a) : UnaryNodeOp(a)inits::glorot_uniform);
+ // (Expr a) : UnaryNodeOp(a)inits::glorotUniform());
// b_ = graph->param(prefix + "_b", {1, 3 * dimState},
- // (Expr a) : UnaryNodeOp(a)inits::zeros);
+ // (Expr a) : UnaryNodeOp(a)inits::zeros());
if(dropout_ > 0.0f) {
if(dimInput)
- dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
- dropMaskS_ = graph->dropout(dropout_, {1, dimState});
+ dropMaskX_ = graph->dropoutMask(dropout_, {1, dimInput});
+ dropMaskS_ = graph->dropoutMask(dropout_, {1, dimState});
}
if(layerNorm_) {
if(dimInput)
gamma1_ = graph->param(
- prefix + "_gamma1", {1, 3 * dimState}, inits::from_value(1.f));
+ prefix + "_gamma1", {1, 3 * dimState}, inits::fromValue(1.f));
gamma2_ = graph->param(
- prefix + "_gamma2", {1, 3 * dimState}, inits::from_value(1.f));
+ prefix + "_gamma2", {1, 3 * dimState}, inits::fromValue(1.f));
}
}
@@ -284,8 +280,7 @@ public:
else
input = inputs[0];
- if(dropMaskX_)
- input = dropout(input, dropMaskX_);
+ input = dropout(input, dropMaskX_);
auto xW = dot(input, W_);
if(layerNorm_)
@@ -299,8 +294,7 @@ public:
Expr mask = nullptr) override {
auto stateOrig = state.output;
auto stateDropped = stateOrig;
- if(dropMaskS_)
- stateDropped = dropout(stateOrig, dropMaskS_);
+ stateDropped = dropout(stateOrig, dropMaskS_);
auto sU = dot(stateDropped, U_);
if(layerNorm_)
@@ -309,7 +303,7 @@ public:
Expr xW;
if(xWs.empty()) {
if(!fakeInput_ || fakeInput_->shape() != sU->shape())
- fakeInput_ = sU->graph()->constant(sU->shape(), inits::zeros);
+ fakeInput_ = sU->graph()->constant(sU->shape(), inits::zeros());
xW = fakeInput_;
} else {
xW = xWs.front();
@@ -376,9 +370,9 @@ public:
final_ = opt<bool>("final", false);
auto U = graph->param(
- prefix + "_U", {dimState, 2 * dimState}, inits::glorot_uniform);
+ prefix + "_U", {dimState, 2 * dimState}, inits::glorotUniform());
auto Ux = graph->param(
- prefix + "_Ux", {dimState, dimState}, inits::glorot_uniform);
+ prefix + "_Ux", {dimState, dimState}, inits::glorotUniform());
if(layerNorm_) {
U_ = U;
@@ -389,9 +383,9 @@ public:
if(dimInput > 0) {
auto W = graph->param(
- prefix + "_W", {dimInput, 2 * dimState}, inits::glorot_uniform);
+ prefix + "_W", {dimInput, 2 * dimState}, inits::glorotUniform());
auto Wx = graph->param(
- prefix + "_Wx", {dimInput, dimState}, inits::glorot_uniform);
+ prefix + "_Wx", {dimInput, dimState}, inits::glorotUniform());
if(layerNorm_) {
W_ = W;
Wx_ = Wx;
@@ -400,8 +394,8 @@ public:
}
}
- auto b = graph->param(prefix + "_b", {1, 2 * dimState}, inits::zeros);
- auto bx = graph->param(prefix + "_bx", {1, dimState}, inits::zeros);
+ auto b = graph->param(prefix + "_b", {1, 2 * dimState}, inits::zeros());
+ auto bx = graph->param(prefix + "_bx", {1, dimState}, inits::zeros());
if(layerNorm_) {
b_ = b;
@@ -409,10 +403,10 @@ public:
// in specific cases we need to pass bx to the kernel
if(encoder_ && transition_) {
- auto b0 = graph->constant({1, 2 * dimState}, inits::zeros);
+ auto b0 = graph->constant({1, 2 * dimState}, inits::zeros());
bbx_ = concatenate({b0, bx}, /*axis =*/ -1);
} else {
- bbx_ = graph->constant({1, 3 * dimState}, inits::zeros);
+ bbx_ = graph->constant({1, 3 * dimState}, inits::zeros());
}
} else {
bbx_ = concatenate({b, bx}, /*axis =*/ -1);
@@ -420,28 +414,28 @@ public:
if(dropout_ > 0.0f) {
if(dimInput)
- dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
- dropMaskS_ = graph->dropout(dropout_, {1, dimState});
+ dropMaskX_ = graph->dropoutMask(dropout_, {1, dimInput});
+ dropMaskS_ = graph->dropoutMask(dropout_, {1, dimState});
}
if(layerNorm_) {
if(dimInput) {
W_lns_ = graph->param(
- prefix + "_W_lns", {1, 2 * dimState}, inits::from_value(1.f));
+ prefix + "_W_lns", {1, 2 * dimState}, inits::fromValue(1.f));
W_lnb_
- = graph->param(prefix + "_W_lnb", {1, 2 * dimState}, inits::zeros);
+ = graph->param(prefix + "_W_lnb", {1, 2 * dimState}, inits::zeros());
Wx_lns_ = graph->param(
- prefix + "_Wx_lns", {1, 1 * dimState}, inits::from_value(1.f));
+ prefix + "_Wx_lns", {1, 1 * dimState}, inits::fromValue(1.f));
Wx_lnb_
- = graph->param(prefix + "_Wx_lnb", {1, 1 * dimState}, inits::zeros);
+ = graph->param(prefix + "_Wx_lnb", {1, 1 * dimState}, inits::zeros());
}
U_lns_ = graph->param(
- prefix + "_U_lns", {1, 2 * dimState}, inits::from_value(1.f));
- U_lnb_ = graph->param(prefix + "_U_lnb", {1, 2 * dimState}, inits::zeros);
+ prefix + "_U_lns", {1, 2 * dimState}, inits::fromValue(1.f));
+ U_lnb_ = graph->param(prefix + "_U_lnb", {1, 2 * dimState}, inits::zeros());
Ux_lns_ = graph->param(
- prefix + "_Ux_lns", {1, 1 * dimState}, inits::from_value(1.f));
+ prefix + "_Ux_lns", {1, 1 * dimState}, inits::fromValue(1.f));
Ux_lnb_
- = graph->param(prefix + "_Ux_lnb", {1, 1 * dimState}, inits::zeros);
+ = graph->param(prefix + "_Ux_lnb", {1, 1 * dimState}, inits::zeros());
}
}
@@ -460,8 +454,7 @@ public:
else
input = inputs[0];
- if(dropMaskX_)
- input = dropout(input, dropMaskX_);
+ input = dropout(input, dropMaskX_);
Expr xW;
if(layerNorm_) {
@@ -494,8 +487,7 @@ public:
auto stateOrig = state.output;
auto stateDropped = stateOrig;
- if(dropMaskS_)
- stateDropped = dropout(stateOrig, dropMaskS_);
+ stateDropped = dropout(stateOrig, dropMaskS_);
Expr sU;
if(layerNorm_) {
@@ -530,7 +522,7 @@ public:
Expr xW;
if(transition_) {
if(!fakeInput_ || fakeInput_->shape() != sU->shape())
- fakeInput_ = sU->graph()->constant(sU->shape(), inits::zeros);
+ fakeInput_ = sU->graph()->constant(sU->shape(), inits::zeros());
xW = fakeInput_;
} else {
xW = xWs.front();
@@ -575,25 +567,25 @@ public:
dropout_ = opt<float>("dropout", 0);
U_ = graph->param(
- prefix + "_U", {dimState, 4 * dimState}, inits::glorot_uniform);
+ prefix + "_U", {dimState, 4 * dimState}, inits::glorotUniform());
if(dimInput)
W_ = graph->param(
- prefix + "_W", {dimInput, 4 * dimState}, inits::glorot_uniform);
+ prefix + "_W", {dimInput, 4 * dimState}, inits::glorotUniform());
- b_ = graph->param(prefix + "_b", {1, 4 * dimState}, inits::zeros);
+ b_ = graph->param(prefix + "_b", {1, 4 * dimState}, inits::zeros());
if(dropout_ > 0.0f) {
if(dimInput)
- dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
- dropMaskS_ = graph->dropout(dropout_, {1, dimState});
+ dropMaskX_ = graph->dropoutMask(dropout_, {1, dimInput});
+ dropMaskS_ = graph->dropoutMask(dropout_, {1, dimState});
}
if(layerNorm_) {
if(dimInput)
gamma1_ = graph->param(
- prefix + "_gamma1", {1, 4 * dimState}, inits::from_value(1.f));
+ prefix + "_gamma1", {1, 4 * dimState}, inits::fromValue(1.f));
gamma2_ = graph->param(
- prefix + "_gamma2", {1, 4 * dimState}, inits::from_value(1.f));
+ prefix + "_gamma2", {1, 4 * dimState}, inits::fromValue(1.f));
}
}
@@ -612,8 +604,7 @@ public:
} else
input = inputs.front();
- if(dropMaskX_)
- input = dropout(input, dropMaskX_);
+ input = dropout(input, dropMaskX_);
auto xW = dot(input, W_);
@@ -630,8 +621,7 @@ public:
auto cellState = state.cell;
auto recStateDropped = recState;
- if(dropMaskS_)
- recStateDropped = dropout(recState, dropMaskS_);
+ recStateDropped = dropout(recState, dropMaskS_);
auto sU = dot(recStateDropped, U_);
@@ -641,7 +631,7 @@ public:
Expr xW;
if(xWs.empty()) {
if(!fakeInput_ || fakeInput_->shape() != sU->shape())
- fakeInput_ = sU->graph()->constant(sU->shape(), inits::zeros);
+ fakeInput_ = sU->graph()->constant(sU->shape(), inits::zeros());
xW = fakeInput_;
} else {
xW = xWs.front();
@@ -677,17 +667,17 @@ public:
std::string prefix = options->get<std::string>("prefix");
Um_ = graph->param(
- prefix + "_Um", {dimState, dimState}, inits::glorot_uniform);
+ prefix + "_Um", {dimState, dimState}, inits::glorotUniform());
Wm_ = graph->param(
- prefix + "_Wm", {dimInput, dimState}, inits::glorot_uniform);
- bm_ = graph->param(prefix + "_bm", {1, dimState}, inits::zeros);
- bwm_ = graph->param(prefix + "_bwm", {1, dimState}, inits::zeros);
+ prefix + "_Wm", {dimInput, dimState}, inits::glorotUniform());
+ bm_ = graph->param(prefix + "_bm", {1, dimState}, inits::zeros());
+ bwm_ = graph->param(prefix + "_bwm", {1, dimState}, inits::zeros());
if(CellType::layerNorm_) {
gamma1m_ = graph->param(
- prefix + "_gamma1m", {1, dimState}, inits::from_value(1.f));
+ prefix + "_gamma1m", {1, dimState}, inits::fromValue(1.f));
gamma2m_ = graph->param(
- prefix + "_gamma2m", {1, dimState}, inits::from_value(1.f));
+ prefix + "_gamma2m", {1, dimState}, inits::fromValue(1.f));
}
}
@@ -746,28 +736,28 @@ public:
std::string prefix = options->get<std::string>("prefix");
Uf_ = graph->param(
- prefix + "_Uf", {dimState, dimState}, inits::glorot_uniform);
+ prefix + "_Uf", {dimState, dimState}, inits::glorotUniform());
Wf_ = graph->param(
- prefix + "_Wf", {dimInput, dimState}, inits::glorot_uniform);
- bf_ = graph->param(prefix + "_bf", {1, dimState}, inits::zeros);
+ prefix + "_Wf", {dimInput, dimState}, inits::glorotUniform());
+ bf_ = graph->param(prefix + "_bf", {1, dimState}, inits::zeros());
Ui_ = graph->param(
- prefix + "_Ui", {dimState, dimState}, inits::glorot_uniform);
+ prefix + "_Ui", {dimState, dimState}, inits::glorotUniform());
Wi_ = graph->param(
- prefix + "_Wi", {dimInput, dimState}, inits::glorot_uniform);
- bi_ = graph->param(prefix + "_bi", {1, dimState}, inits::zeros);
+ prefix + "_Wi", {dimInput, dimState}, inits::glorotUniform());
+ bi_ = graph->param(prefix + "_bi", {1, dimState}, inits::zeros());
Uc_ = graph->param(
- prefix + "_Uc", {dimState, dimState}, inits::glorot_uniform);
+ prefix + "_Uc", {dimState, dimState}, inits::glorotUniform());
Wc_ = graph->param(
- prefix + "_Wc", {dimInput, dimState}, inits::glorot_uniform);
- bc_ = graph->param(prefix + "_bc", {1, dimState}, inits::zeros);
+ prefix + "_Wc", {dimInput, dimState}, inits::glorotUniform());
+ bc_ = graph->param(prefix + "_bc", {1, dimState}, inits::zeros());
Uo_ = graph->param(
- prefix + "_Uo", {dimState, dimState}, inits::glorot_uniform);
+ prefix + "_Uo", {dimState, dimState}, inits::glorotUniform());
Wo_ = graph->param(
- prefix + "_Wo", {dimInput, dimState}, inits::glorot_uniform);
- bo_ = graph->param(prefix + "_bo", {1, dimState}, inits::zeros);
+ prefix + "_Wo", {dimInput, dimState}, inits::glorotUniform());
+ bo_ = graph->param(prefix + "_bo", {1, dimState}, inits::zeros());
}
State apply(std::vector<Expr> inputs, State state, Expr mask = nullptr) {
@@ -828,28 +818,28 @@ public:
std::string prefix = options->get<std::string>("prefix");
auto Uf = graph->param(
- prefix + "_Uf", {dimState, dimState}, inits::glorot_uniform);
+ prefix + "_Uf", {dimState, dimState}, inits::glorotUniform());
auto Wf = graph->param(
- prefix + "_Wf", {dimInput, dimState}, inits::glorot_uniform);
- auto bf = graph->param(prefix + "_bf", {1, dimState}, inits::zeros);
+ prefix + "_Wf", {dimInput, dimState}, inits::glorotUniform());
+ auto bf = graph->param(prefix + "_bf", {1, dimState}, inits::zeros());
auto Ui = graph->param(
- prefix + "_Ui", {dimState, dimState}, inits::glorot_uniform);
+ prefix + "_Ui", {dimState, dimState}, inits::glorotUniform());
auto Wi = graph->param(
- prefix + "_Wi", {dimInput, dimState}, inits::glorot_uniform);
- auto bi = graph->param(prefix + "_bi", {1, dimState}, inits::zeros);
+ prefix + "_Wi", {dimInput, dimState}, inits::glorotUniform());
+ auto bi = graph->param(prefix + "_bi", {1, dimState}, inits::zeros());
auto Uc = graph->param(
- prefix + "_Uc", {dimState, dimState}, inits::glorot_uniform);
+ prefix + "_Uc", {dimState, dimState}, inits::glorotUniform());
auto Wc = graph->param(
- prefix + "_Wc", {dimInput, dimState}, inits::glorot_uniform);
- auto bc = graph->param(prefix + "_bc", {1, dimState}, inits::zeros);
+ prefix + "_Wc", {dimInput, dimState}, inits::glorotUniform());
+ auto bc = graph->param(prefix + "_bc", {1, dimState}, inits::zeros());
auto Uo = graph->param(
- prefix + "_Uo", {dimState, dimState}, inits::glorot_uniform);
+ prefix + "_Uo", {dimState, dimState}, inits::glorotUniform());
auto Wo = graph->param(
- prefix + "_Wo", {dimInput, dimState}, inits::glorot_uniform);
- auto bo = graph->param(prefix + "_bo", {1, dimState}, inits::zeros);
+ prefix + "_Wo", {dimInput, dimState}, inits::glorotUniform());
+ auto bo = graph->param(prefix + "_bo", {1, dimState}, inits::zeros());
U_ = concatenate({Uf, Ui, Uc, Uo}, /*axis =*/ -1);
W_ = concatenate({Wf, Wi, Wc, Wo}, /*axis =*/ -1);
@@ -919,25 +909,25 @@ public:
layerNorm_ = opt<bool>("layer-normalization", false);
W_ = graph->param(
- prefix + "_W", {dimInput, dimInput}, inits::glorot_uniform);
+ prefix + "_W", {dimInput, dimInput}, inits::glorotUniform());
Wf_ = graph->param(
- prefix + "_Wf", {dimInput, dimInput}, inits::glorot_uniform);
- bf_ = graph->param(prefix + "_bf", {1, dimInput}, inits::zeros);
+ prefix + "_Wf", {dimInput, dimInput}, inits::glorotUniform());
+ bf_ = graph->param(prefix + "_bf", {1, dimInput}, inits::zeros());
Wr_ = graph->param(
- prefix + "_Wr", {dimInput, dimInput}, inits::glorot_uniform);
- br_ = graph->param(prefix + "_br", {1, dimInput}, inits::zeros);
+ prefix + "_Wr", {dimInput, dimInput}, inits::glorotUniform());
+ br_ = graph->param(prefix + "_br", {1, dimInput}, inits::zeros());
if(dropout_ > 0.0f) {
- dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
+ dropMaskX_ = graph->dropoutMask(dropout_, {1, dimInput});
}
if(layerNorm_) {
if(dimInput)
- gamma_ = graph->param(prefix + "_gamma", {1, dimState}, inits::ones);
- gammar_ = graph->param(prefix + "_gammar", {1, dimState}, inits::ones);
- gammaf_ = graph->param(prefix + "_gammaf", {1, dimState}, inits::ones);
+ gamma_ = graph->param(prefix + "_gamma", {1, dimState}, inits::ones());
+ gammar_ = graph->param(prefix + "_gammar", {1, dimState}, inits::ones());
+ gammaf_ = graph->param(prefix + "_gammaf", {1, dimState}, inits::ones());
}
}
@@ -954,7 +944,7 @@ public:
else
input = inputs.front();
- auto inputDropped = dropMaskX_ ? dropout(input, dropMaskX_) : input;
+ auto inputDropped = dropout(input, dropMaskX_);
Expr x, f, r;
if(layerNorm_) {
@@ -1014,20 +1004,20 @@ public:
layerNorm_ = opt<bool>("layer-normalization", false);
W_ = graph->param(
- prefix + "_W", {dimInput, dimInput}, inits::glorot_uniform);
+ prefix + "_W", {dimInput, dimInput}, inits::glorotUniform());
Wf_ = graph->param(
- prefix + "_Wf", {dimInput, dimInput}, inits::glorot_uniform);
- bf_ = graph->param(prefix + "_bf", {1, dimInput}, inits::zeros);
+ prefix + "_Wf", {dimInput, dimInput}, inits::glorotUniform());
+ bf_ = graph->param(prefix + "_bf", {1, dimInput}, inits::zeros());
if(dropout_ > 0.0f) {
- dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
+ dropMaskX_ = graph->dropoutMask(dropout_, {1, dimInput});
}
if(layerNorm_) {
if(dimInput)
- gamma_ = graph->param(prefix + "_gamma", {1, dimState}, inits::ones);
- gammaf_ = graph->param(prefix + "_gammaf", {1, dimState}, inits::ones);
+ gamma_ = graph->param(prefix + "_gamma", {1, dimState}, inits::ones());
+ gammaf_ = graph->param(prefix + "_gammaf", {1, dimState}, inits::ones());
}
}
@@ -1044,7 +1034,7 @@ public:
else
input = inputs.front();
- auto inputDropped = dropMaskX_ ? dropout(input, dropMaskX_) : input;
+ auto inputDropped = dropout(input, dropMaskX_);
Expr x, f;
if(layerNorm_) {
@@ -1095,16 +1085,16 @@ public:
// W_ = graph->param(prefix + "_W",
// {dimInput, dimInput},
-// inits::glorot_uniform);
+// inits::glorotUniform());
// Wf_ = graph->param(prefix + "_Wf",
// {dimInput, dimInput},
-// inits::glorot_uniform);
+// inits::glorotUniform());
// bf_ = graph->param(
-// prefix + "_bf", {1, dimInput}, inits::zeros);
+// prefix + "_bf", {1, dimInput}, inits::zeros());
// if(dropout_ > 0.0f) {
-// dropMaskX_ = graph->dropout(dropout_, {1, dimInput});
+// dropMaskX_ = graph->dropoutMask(dropout_, {1, dimInput});
// }
// }
@@ -1121,7 +1111,7 @@ public:
// else
// input = inputs.front();
-// auto inputDropped = dropMaskX_ ? dropout(input, dropMaskX_) : input;
+// auto inputDropped = dropout(input, dropMaskX_);
// auto x = dot(inputDropped, W_);
// auto f = affine(inputDropped, Wf_, bf_);
diff --git a/src/rnn/constructors.h b/src/rnn/constructors.h
index ec0e1bd8..beb1fce1 100755
--- a/src/rnn/constructors.h
+++ b/src/rnn/constructors.h
@@ -7,27 +7,25 @@
namespace marian {
namespace rnn {
-struct StackableFactory : public Factory {
- StackableFactory(Ptr<ExpressionGraph> graph) : Factory(graph) {}
- StackableFactory(const StackableFactory&) = default;
- StackableFactory(StackableFactory&&) = default;
-
- virtual ~StackableFactory() {}
-
- template <typename Cast>
- inline Ptr<Cast> as() {
- return std::dynamic_pointer_cast<Cast>(shared_from_this());
- }
-
- template <typename Cast>
- inline bool is() {
- return as<Cast>() != nullptr;
- }
-};
+typedef Factory StackableFactory;
+//struct StackableFactory : public Factory {StackableFactory
+// using Factory::Factory;
+//
+// virtual ~StackableFactory() {}
+//
+// template <typename Cast>
+// inline Ptr<Cast> as() {
+// return std::dynamic_pointer_cast<Cast>(shared_from_this());
+// }
+//
+// template <typename Cast>
+// inline bool is() {
+// return as<Cast>() != nullptr;
+// }
+//};
struct InputFactory : public StackableFactory {
- InputFactory(Ptr<ExpressionGraph> graph) : StackableFactory(graph) {}
- virtual Ptr<CellInput> construct() = 0;
+ virtual Ptr<CellInput> construct(Ptr<ExpressionGraph> graph) = 0;
};
class CellFactory : public StackableFactory {
@@ -35,44 +33,42 @@ protected:
std::vector<std::function<Expr(Ptr<rnn::RNN>)>> inputs_;
public:
- CellFactory(Ptr<ExpressionGraph> graph) : StackableFactory(graph) {}
-
- virtual Ptr<Cell> construct() {
+ virtual Ptr<Cell> construct(Ptr<ExpressionGraph> graph) {
std::string type = options_->get<std::string>("type");
if(type == "gru") {
- auto cell = New<GRU>(graph_, options_);
+ auto cell = New<GRU>(graph, options_);
cell->setLazyInputs(inputs_);
return cell;
} else if(type == "gru-nematus") {
- auto cell = New<GRUNematus>(graph_, options_);
+ auto cell = New<GRUNematus>(graph, options_);
cell->setLazyInputs(inputs_);
return cell;
} else if(type == "lstm") {
- auto cell = New<LSTM>(graph_, options_);
+ auto cell = New<LSTM>(graph, options_);
cell->setLazyInputs(inputs_);
return cell;
} else if(type == "mlstm") {
- auto cell = New<MLSTM>(graph_, options_);
+ auto cell = New<MLSTM>(graph, options_);
cell->setLazyInputs(inputs_);
return cell;
} else if(type == "mgru") {
- auto cell = New<MGRU>(graph_, options_);
+ auto cell = New<MGRU>(graph, options_);
cell->setLazyInputs(inputs_);
return cell;
} else if(type == "tanh") {
- auto cell = New<Tanh>(graph_, options_);
+ auto cell = New<Tanh>(graph, options_);
cell->setLazyInputs(inputs_);
return cell;
} else if(type == "relu") {
- auto cell = New<ReLU>(graph_, options_);
+ auto cell = New<ReLU>(graph, options_);
cell->setLazyInputs(inputs_);
return cell;
} else if(type == "sru") {
- auto cell = New<SRU>(graph_, options_);
+ auto cell = New<SRU>(graph, options_);
cell->setLazyInputs(inputs_);
return cell;
} else if(type == "ssru") {
- auto cell = New<SSRU>(graph_, options_);
+ auto cell = New<SSRU>(graph, options_);
cell->setLazyInputs(inputs_);
return cell;
} else {
@@ -81,7 +77,7 @@ public:
}
CellFactory clone() {
- CellFactory aClone(graph_);
+ CellFactory aClone;
aClone.options_->merge(options_);
aClone.inputs_ = inputs_;
return aClone;
@@ -103,10 +99,8 @@ protected:
std::vector<Ptr<StackableFactory>> stackableFactories_;
public:
- StackedCellFactory(Ptr<ExpressionGraph> graph) : CellFactory(graph) {}
-
- Ptr<Cell> construct() override {
- auto stacked = New<StackedCell>(graph_, options_);
+ Ptr<Cell> construct(Ptr<ExpressionGraph> graph) override {
+ auto stacked = New<StackedCell>(graph, options_);
int lastDimInput = options_->get<int>("dimInput");
@@ -115,20 +109,20 @@ public:
if(sf->is<CellFactory>()) {
auto cellFactory = sf->as<CellFactory>();
- cellFactory->getOptions()->merge(options_);
+ cellFactory->mergeOpts(options_);
- sf->getOptions()->set("dimInput", lastDimInput);
+ sf->setOpt("dimInput", lastDimInput);
lastDimInput = 0;
if(i == 0)
for(auto f : inputs_)
cellFactory->add_input(f);
- stacked->push_back(cellFactory->construct());
+ stacked->push_back(cellFactory->construct(graph));
} else {
auto inputFactory = sf->as<InputFactory>();
- inputFactory->getOptions()->merge(options_);
- auto input = inputFactory->construct();
+ inputFactory->mergeOpts(options_);
+ auto input = inputFactory->construct(graph);
stacked->push_back(input);
lastDimInput += input->dimOutput();
}
@@ -146,49 +140,46 @@ public:
typedef Accumulator<StackedCellFactory> stacked_cell;
class RNNFactory : public Factory {
+ using Factory::Factory;
protected:
std::vector<Ptr<CellFactory>> layerFactories_;
public:
- RNNFactory(Ptr<ExpressionGraph> graph) : Factory(graph) {}
-
- Ptr<RNN> construct() {
- auto rnn = New<RNN>(graph_, options_);
+ Ptr<RNN> construct(Ptr<ExpressionGraph> graph) {
+ auto rnn = New<RNN>(graph, options_);
for(size_t i = 0; i < layerFactories_.size(); ++i) {
auto lf = layerFactories_[i];
- lf->getOptions()->merge(options_);
+ lf->mergeOpts(options_);
if(i > 0) {
int dimInput
- = layerFactories_[i - 1]->getOptions()->get<int>("dimState")
- + lf->getOptions()->get<int>("dimInputExtra", 0);
+ = layerFactories_[i - 1]->opt<int>("dimState")
+ + lf->opt<int>("dimInputExtra", 0);
- lf->getOptions()->set("dimInput", dimInput);
+ lf->setOpt("dimInput", dimInput);
}
if((rnn::dir)opt<int>("direction", (int)rnn::dir::forward)
== rnn::dir::alternating_forward) {
if(i % 2 == 0)
- lf->getOptions()->set("direction", (int)rnn::dir::forward);
+ lf->setOpt("direction", (int)rnn::dir::forward);
else
- lf->getOptions()->set("direction", (int)rnn::dir::backward);
+ lf->setOpt("direction", (int)rnn::dir::backward);
}
if((rnn::dir)opt<int>("direction", (int)rnn::dir::forward)
== rnn::dir::alternating_backward) {
if(i % 2 == 1)
- lf->getOptions()->set("direction", (int)rnn::dir::forward);
+ lf->setOpt("direction", (int)rnn::dir::forward);
else
- lf->getOptions()->set("direction", (int)rnn::dir::backward);
+ lf->setOpt("direction", (int)rnn::dir::backward);
}
- rnn->push_back(lf->construct());
+ rnn->push_back(lf->construct(graph));
}
return rnn;
}
- Ptr<RNN> operator->() { return construct(); }
-
template <class F>
Accumulator<RNNFactory> push_back(const F& f) {
layerFactories_.push_back(New<F>(f));
@@ -196,7 +187,7 @@ public:
}
RNNFactory clone() {
- RNNFactory aClone(graph_);
+ RNNFactory aClone;
aClone.options_->merge(options_);
for(auto lf : layerFactories_)
aClone.push_back(lf->clone());
diff --git a/src/rnn/rnn.h b/src/rnn/rnn.h
index 74a535d3..4efc569c 100755..100644
--- a/src/rnn/rnn.h
+++ b/src/rnn/rnn.h
@@ -35,7 +35,7 @@ protected:
public:
BaseRNN(Ptr<ExpressionGraph> graph, Ptr<Options> options)
: graph_(graph), options_(options) {}
-
+ virtual ~BaseRNN() {}
virtual Expr transduce(Expr, Expr = nullptr) = 0;
virtual Expr transduce(Expr, State, Expr = nullptr) = 0;
virtual Expr transduce(Expr, States, Expr = nullptr) = 0;
@@ -75,11 +75,11 @@ private:
std::vector<Expr> steps(xWs.size());
std::transform(xWs.begin(), xWs.end(), steps.begin(), [j](Expr e) {
- return step(e, j, -3);
+ return slice(e, -3, j);
});
if(mask)
- state = cell_->applyState(steps, state, step(mask, j, -3));
+ state = cell_->applyState(steps, state, slice(mask, -3, j));
else
state = cell_->applyState(steps, state);
@@ -113,6 +113,7 @@ private:
public:
friend RNN;
+ virtual ~SingleLayerRNN() {}
// @TODO: benchmark whether this concatenation is a good idea
virtual Expr transduce(Expr input, Expr mask = nullptr) override {
diff --git a/src/rnn/types.h b/src/rnn/types.h
index 672c600b..47424b75 100755..100644
--- a/src/rnn/types.h
+++ b/src/rnn/types.h
@@ -18,7 +18,7 @@ struct State {
select(cell, selIdx, beamSize, isBatchMajor) };
}
-private:
+ // this function is also called by Logits
static Expr select(Expr sel, // [beamSize, dimTime, dimBatch, dimDepth] or [beamSize, dimBatch, dimTime, dimDepth] (dimTime = 1 for RNN)
const std::vector<IndexType>& selIdx, // [beamIndex * activeBatchSize + batchIndex]
int beamSize, bool isBatchMajor)
@@ -34,6 +34,7 @@ private:
ABORT_IF(dimTime != 1 && !isBatchMajor, "unexpected time extent for RNN state"); // (the reshape()/rows() trick won't work in this case)
int numCols = isBatchMajor ? dimDepth * dimTime : dimDepth;
+ // @TODO: Can this complex operation be more easily written using index_select()?
sel = reshape(sel, { sel->shape().elements() / numCols, numCols }); // [beamSize * dimBatch, dimDepth] or [beamSize * dimBatch, dimTime * dimDepth]
sel = rows(sel, selIdx);
sel = reshape(sel, { beamSize, isBatchMajor ? dimBatch : dimTime, isBatchMajor ? dimTime : dimBatch, dimDepth });
diff --git a/src/tensors/allocator.h b/src/tensors/allocator.h
index 1e402ba6..9dc44f58 100755..100644
--- a/src/tensors/allocator.h
+++ b/src/tensors/allocator.h
@@ -88,14 +88,10 @@ private:
bool throw_{false};
std::set<Gap> gaps_;
- std::unordered_map<uint8_t*, Ptr<MemoryPiece>> allocated_;
-
- size_t align(size_t size) {
- return (size_t)(ceil(size / (float)alignment_) * alignment_);
- }
+ std::unordered_map<uint8_t*, MemoryPiece::PtrType> allocated_;
void grow(size_t add) {
- add = align(add);
+ add = alignedSize(add);
uint8_t* oldData = device_->data();
size_t oldSize = device_->size();
@@ -109,7 +105,7 @@ private:
gap.size()));
insertGap(Gap(device_->data() + oldSize, add));
- std::unordered_map<uint8_t*, Ptr<MemoryPiece>> oldAllocated;
+ std::unordered_map<uint8_t*, MemoryPiece::PtrType> oldAllocated;
allocated_.swap(oldAllocated);
for(auto it : oldAllocated) {
uint8_t* newPtr = device_->data() + std::distance(oldData, it.first);
@@ -119,13 +115,15 @@ private:
}
Gap getGap(size_t size) {
- size = align(size);
+ size = alignedSize(size);
auto it = std::lower_bound(gaps_.begin(), gaps_.end(), Gap(nullptr, size));
if(throw_ && it == gaps_.end()) {
+ //ABORT("Trying to allocate {}, but only {} available.", available_, size);
throw AllocationException(available_, size);
}
+ // @TODO: compact memory before re-allocation attempt, maybe by left shifting memory over currently largest gap
while(it == gaps_.end()) {
grow(step_);
it = std::lower_bound(gaps_.begin(), gaps_.end(), Gap(nullptr, size));
@@ -177,10 +175,14 @@ public:
reserve(bytes);
}
+ size_t alignedSize(size_t size) {
+ return (size_t)(ceil(size / (double)alignment_) * alignment_);
+ }
+
void throwAtReallocation(bool throwRealloc) { throw_ = throwRealloc; }
void reserve(size_t bytes) {
- bytes = align(bytes);
+ bytes = alignedSize(bytes);
if(bytes > 0)
device_->reserve(bytes);
clear();
@@ -188,22 +190,16 @@ public:
template <typename T>
size_t capacity(size_t num) {
- return align(num * sizeof(T));
- }
-
- size_t capacity(size_t num, Type type) { return align(num * sizeOf(type)); }
-
- Ptr<MemoryPiece> alloc(size_t num, Type type) {
- return alloc(num * sizeOf(type));
+ return alignedSize(num * sizeof(T));
}
template <typename T>
- Ptr<MemoryPiece> alloc(size_t num) {
+ MemoryPiece::PtrType alloc(size_t num) {
return alloc(capacity<T>(num));
}
- Ptr<MemoryPiece> alloc(size_t bytes) {
- bytes = align(bytes);
+ MemoryPiece::PtrType alloc(size_t bytes) {
+ bytes = alignedSize(bytes);
Gap gap = getGap(bytes);
if(gap.size() > bytes) {
@@ -211,13 +207,13 @@ public:
}
auto ptr = gap.data();
- auto mp = New<MemoryPiece>(ptr, bytes);
+ auto mp = MemoryPiece::New(ptr, bytes);
allocated_[ptr] = mp;
return mp;
}
bool free(uint8_t* ptr, size_t bytes) {
- bytes = align(bytes);
+ bytes = alignedSize(bytes);
ABORT_IF(ptr == 0, "Double free?");
@@ -233,7 +229,7 @@ public:
return false;
}
- bool free(Ptr<MemoryPiece> mp) {
+ bool free(MemoryPiece::PtrType mp) {
if(free(mp->data(), mp->size())) {
mp->set(nullptr, 0);
return true;
@@ -248,8 +244,8 @@ public:
insertGap({device_->data(), device_->size()}, false);
}
- Ptr<MemoryPiece> memory() {
- return New<MemoryPiece>(device_->data(), device_->size());
+ MemoryPiece::PtrType memory() {
+ return MemoryPiece::New(device_->data(), device_->size());
}
size_t size() { return device_->size(); }
diff --git a/src/tensors/backend.h b/src/tensors/backend.h
index 36e537dc..ce8a5e60 100644
--- a/src/tensors/backend.h
+++ b/src/tensors/backend.h
@@ -16,10 +16,8 @@ protected:
public:
Backend(DeviceId deviceId, size_t seed)
- : deviceId_(deviceId),
- seed_(seed),
- randomGenerator_(createRandomGenerator(seed, deviceId)) {}
-
+ : deviceId_(deviceId), seed_(seed), randomGenerator_(createRandomGenerator(seed, deviceId)) {}
+ virtual ~Backend() {};
virtual DeviceId getDeviceId() { return deviceId_; };
virtual Ptr<RandomGenerator> getRandomGenerator() { return randomGenerator_; }
@@ -29,6 +27,11 @@ public:
virtual void setClip(float clipValue) { clipValue_ = clipValue; }
float getClip() { return clipValue_; }
+
+ // for CPU, sets to use optimized code for inference.
+ // for GPU, this is invalid. for gpu, isOptimized() function always returns false.
+ virtual void setOptimized(bool optimize) = 0;
+ virtual bool isOptimized() = 0;
};
Ptr<Backend> BackendByDeviceId(DeviceId deviceId, size_t seed);
diff --git a/src/tensors/cpu/add.h b/src/tensors/cpu/add.h
index 38a0684d..4bae5bb5 100755..100644
--- a/src/tensors/cpu/add.h
+++ b/src/tensors/cpu/add.h
@@ -15,8 +15,8 @@ namespace marian {
namespace cpu {
-template <size_t K, class Functor>
-void gAddGeneric(Functor functor,
+template <size_t K, class Functor, class AggFunctor>
+void gAggregateGeneric(Functor functor, float aggInit, AggFunctor aggFunctor,
const functional::Shape full,
functional::Tensor<float> out,
functional::Array<functional::Tensor<float>, K> ins,
@@ -34,16 +34,16 @@ void gAddGeneric(Functor functor,
functional::Array<int, N> dims;
for(int index = 0; index < outLength; ++index) {
if(same) {
- out[index] += functional::apply(functor, ins, index) * scale;
+ out[index] = aggFunctor(out[index], functional::apply(functor, ins, index) * scale);
} else {
out.shape().dims(index, dims);
- out[index] += functional::loops(functor, ins, len, dims) * scale;
+ out[index] = aggFunctor(out[index], functional::loops(functor, aggInit, aggFunctor, ins, len, dims) * scale);
}
}
}
-template <size_t K, class Functor>
-void gAddEqual(Functor functor,
+template <size_t K, class Functor, class AggFunctor>
+void gAggregateEqual(Functor functor, AggFunctor aggFunctor,
functional::Tensor<float> out,
functional::Array<functional::Tensor<float>, K> ins,
float scale,
@@ -61,12 +61,12 @@ void gAddEqual(Functor functor,
indices[i] = ins[i].shape().bindex(dims);
}
- out[index] += functional::apply(functor, ins, indices) * scale;
+ out[index] = aggFunctor(out[index], functional::apply(functor, ins, indices) * scale);
}
}
-template <size_t K, class Functor>
-void gAddReduce(Functor functor,
+template <size_t K, class Functor, class AggFunctor>
+void gAggregateReduce(Functor functor, float aggInit, AggFunctor aggFunctor,
const functional::Shape full,
functional::Tensor<float> out,
functional::Array<functional::Tensor<float>, K> ins,
@@ -79,10 +79,10 @@ void gAddReduce(Functor functor,
same = same && ins[i].shape().elements() == full.elements();
for(int j = 0; j < rows; ++j) {
- float sum = 0;
+ float colSum = aggInit;
if(same) {
for(int id = 0; id < cols; ++id)
- sum += functional::apply(functor, ins, j * cols + id);
+ colSum = aggFunctor(colSum, functional::apply(functor, ins, j * cols + id));
} else {
functional::Array<int, functional::Shape::size()> dims;
for(int id = 0; id < cols; ++id) {
@@ -90,15 +90,15 @@ void gAddReduce(Functor functor,
functional::Array<int, K> indices;
for(size_t i = 0; i < K; ++i)
indices[i] = ins[i].shape().bindex(dims);
- sum += functional::apply(functor, ins, indices);
+ colSum = aggFunctor(colSum, functional::apply(functor, ins, indices));
}
}
- out[j] += sum * scale;
+ out[j] = aggFunctor(out[j], colSum * scale);
}
}
-template <class Functor, class... Tensors>
-void Add(Functor functor, float scale, marian::Tensor out, Tensors... tensors) {
+template <class Functor, class AggFunctor, class... Tensors>
+void Aggregate(Functor functor, float aggInit, AggFunctor aggFunctor, float scale, marian::Tensor out, Tensors... tensors) {
auto full = marian::Shape::broadcast({out, tensors...});
//int length = out->shape().elements();
@@ -111,15 +111,16 @@ void Add(Functor functor, float scale, marian::Tensor out, Tensors... tensors) {
if(full.back() != 1 && out->shape().back() == 1) {
//size_t m = full.elements() / length;
//size_t k = full.back();
- cpu::gAddReduce(functor, full, gOut, gIns, scale);
+ cpu::gAggregateReduce(functor, aggInit, aggFunctor, full, gOut, gIns, scale);
} else if(out->shape() == full) {
bool broadcast = false;
for(size_t i = 0; i < K; ++i)
broadcast = broadcast || gOut.shape() != gIns[i].shape();
- cpu::gAddEqual(functor, gOut, gIns, scale, broadcast);
+ cpu::gAggregateEqual(functor, aggFunctor, gOut, gIns, scale, broadcast);
} else {
- cpu::gAddGeneric(functor, full, gOut, gIns, scale);
+ cpu::gAggregateGeneric(functor, aggInit, aggFunctor, full, gOut, gIns, scale);
}
}
+
} // namespace cpu
} // namespace marian
diff --git a/src/tensors/cpu/backend.h b/src/tensors/cpu/backend.h
index 7264902e..74bbf808 100644
--- a/src/tensors/cpu/backend.h
+++ b/src/tensors/cpu/backend.h
@@ -10,10 +10,17 @@ namespace marian {
namespace cpu {
class Backend : public marian::Backend {
+protected:
+ bool optimized_{false};
+
public:
Backend(DeviceId deviceId, size_t seed) : marian::Backend(deviceId, seed) {}
void setDevice() override {}
void synchronize() override {}
+
+ // for CPU & inference only, sets to use optimized code for inference. Does nothing for GPU.
+ void setOptimized(bool optimize) override { optimized_ = optimize; }
+ bool isOptimized() override { return optimized_; }
};
} // namespace cpu
} // namespace marian
diff --git a/src/tensors/cpu/device.cpp b/src/tensors/cpu/device.cpp
index 04a79ae6..40bb558b 100644
--- a/src/tensors/cpu/device.cpp
+++ b/src/tensors/cpu/device.cpp
@@ -8,46 +8,53 @@
namespace marian {
namespace cpu {
-
-Device::~Device() {
- free(data_);
- data_ = nullptr;
- size_ = 0;
-}
+namespace {
// allocate function for tensor reserve() below.
-// Needed for AVX512, while not available on all compilers. It seems clang
-// does not have aligned_alloc for all cstlib versions. If AVX512 is not used
-// a simple malloc is probably fine.
-// Should generate a runtime error otherwise as we have a check in the AVX512
-// functions which tests for alignment.
-#ifdef _WIN32
-#define MALLOC(size) _aligned_malloc(size, alignment_)
-#elif __GNUC__
-#define MALLOC(size) aligned_alloc(alignment_, size)
-#else
-#define MALLOC(size) malloc(size)
-#endif
+// Alignment is needed because we use AVX512 and AVX2 vectors. We should fail if we can't allocate aligned memory.
#ifdef _WIN32
-#define FREE(ptr) _aligned_free(ptr)
+void *genericMalloc(size_t alignment, size_t size) {
+ void *ret = _aligned_malloc(size, alignment);
+ ABORT_IF(!ret, "Failed to allocate memory on CPU");
+ return ret;
+}
+void genericFree(void *ptr) {
+ _aligned_free(ptr);
+}
#else
-#define FREE(ptr) free(ptr)
+// Linux and OS X. There is no fallback to malloc because we need it to be aligned.
+void *genericMalloc(size_t alignment, size_t size) {
+ // On macos, aligned_alloc is available only on c++17
+ // Furthermore, it requires that the memory requested is an exact multiple of the alignment, otherwise it fails.
+ // posix_memalign is available both Mac (Since 2016) and Linux and in both gcc and clang
+ void *result;
+ // Error could be detected by return value or just remaining nullptr.
+ ABORT_IF(posix_memalign(&result, alignment, size), "Failed to allocate memory on CPU");
+ return result;
+}
+void genericFree(void *ptr) {
+ free(ptr);
+}
#endif
+} // namespace
+
+Device::~Device() {
+ genericFree(data_);
+}
+
void Device::reserve(size_t size) {
size = align(size);
ABORT_IF(size < size_ || size == 0,
"New size must be larger than old size and larger than 0");
+ uint8_t *temp = static_cast<uint8_t*>(genericMalloc(alignment_, size));
if(data_) {
- uint8_t *temp = static_cast<uint8_t*>(MALLOC(size));
std::copy(data_, data_ + size_, temp);
- FREE(data_);
- data_ = temp;
- } else {
- data_ = static_cast<uint8_t*>(MALLOC(size));
+ genericFree(data_);
}
+ data_ = temp;
size_ = size;
}
} // namespace cpu
diff --git a/src/tensors/cpu/element.h b/src/tensors/cpu/element.h
index ca554112..d98a0f9b 100644
--- a/src/tensors/cpu/element.h
+++ b/src/tensors/cpu/element.h
@@ -17,13 +17,16 @@ namespace cpu {
// on strides, it correctly broadcasts to all dimensions without additional
// computation.
// Compiler optimizes this to single construct with nested(?) loops.
+
+namespace F = marian::functional;
+
template <size_t I = 0>
struct E {
- template <size_t K, class Functor>
+ template <size_t numArg, class Functor, typename ElementType>
static inline void element(
const Functor& functor,
- functional::Array<functional::Tensor<float>, K>& tensors,
- functional::Array<int, K> indices) {
+ F::Array<F::Tensor<ElementType>, numArg>& tensors,
+ F::Array<int, numArg> indices) {
const auto& shape = tensors[0].shape();
// loop over outer-most dimension
@@ -34,8 +37,11 @@ struct E {
// increase index for current dimension by stride or 0 if broadcasting.
// bstride(i) is look-up value, either equal to stride if the
// corresponding dim is larger 1 or 0 if the dim is 1.
- for(size_t k = 0; k < K; ++k)
+ for(size_t k = 0; k < numArg; ++k) {
+ //int stride = tensors[k].shape().stride(I);
+ //indices[k] += stride == 1 ? 0 : stride;
indices[k] += tensors[k].shape().bstride(I);
+ }
}
}
};
@@ -43,31 +49,81 @@ struct E {
// specialization for inner-most single element (recursive stopping criterion)
// using const reference for indices here to avoid copying. No loop.
template <>
-struct E<functional::Shape::size()> {
- template <size_t K, class Functor>
+struct E<F::Shape::size()> {
+ template <size_t numArg, class Functor, typename ElementType>
static inline void element(
const Functor& functor,
- functional::Array<functional::Tensor<float>, K>& tensors,
- const functional::Array<int, K>& indices) {
+ F::Array<F::Tensor<ElementType>, numArg>& tensors,
+ const F::Array<int, numArg>& indices) {
// just apply the function for all indexed elements across all tensors
- tensors[0][indices[0]] = functional::apply(functor, tensors, indices);
+ // @TODO: use converting operator[] on tensor
+ tensors[0].data()[indices[0]] = F::apply(functor, tensors, indices);
}
};
-// main call to function executing element-wise operation
-template <class Functor, class... Tensors>
-void Element(const Functor& functor, marian::Tensor out, Tensors... tensors) {
- constexpr size_t K = sizeof...(tensors) + 1;
- functional::Array<functional::Tensor<float>, K> gTensors = {out, tensors...};
+template <typename ElementType, class Functor, class... Tensors>
+void element(const Functor& functor, marian::Tensor out, Tensors... tensors) {
- // create and initialize indices to 0
- functional::Array<int, K> indices;
+ // Number of input tensors + 1 (output tensor)
+ constexpr size_t argNum = sizeof...(tensors) + 1;
+ // create and initialize indices to 0, one index per tensor
+ F::Array<int, argNum> indices;
indices.fill(0);
// call elementwise operation going from outer-most dimension
// to inner-most element.
+ F::Array<F::Tensor<ElementType>, argNum> gTensors = {out, tensors...};
E<0>::element(functor, gTensors, indices);
}
+// Dispatch elementwise functions with float element type based on number of
+// elements. If dividable by 8 and AVX2 is available (TODO: check this?) use
+// AVX2 specific intrinsics. Similar for 4 and AVX. TODO: Add AVX512 support.
+template <class Functor, class... Tensors>
+void elementFloat(const Functor& functor, marian::Tensor out, Tensors... tensors) {
+#ifndef __CUDACC__
+ std::vector<marian::Tensor> ts({tensors...});
+ bool div8 = true;
+ bool div4 = true;
+
+ if(out->shape()[-1] % 8 != 0)
+ div8 = false;
+ if(out->shape()[-1] % 4 != 0)
+ div4 = false;
+ for(auto t : ts) {
+ if(t->shape()[-1] % 8 != 0)
+ div8 = false;
+ if(t->shape()[-1] % 4 != 0)
+ div4 = false;
+ }
+
+ if(div8) {
+ // std::cerr << "8: " << functor.to_string() << std::endl;
+#ifdef __AVX__
+ element<float32x8>(functor, out, tensors...);
+ return;
+#endif
+ }
+
+ if(div4) {
+ // std::cerr << "4: " << functor.to_string() << std::endl;
+ element<float32x4>(functor, out, tensors...);
+ return;
+ }
+#endif
+ // std::cerr << "1: " << functor.to_string() << std::endl;
+ element<float>(functor, out, tensors...);
+}
+
+// main call to function executing element-wise operation
+template <class Functor, class... Tensors>
+void Element(const Functor& functor, marian::Tensor out, Tensors... tensors) {
+ switch(out->type()) {
+ case Type::float32: elementFloat(functor, out, tensors...); break;
+ //case Type::uint32: element<uint32_t>(functor, out, tensors...); break;
+ default: ABORT("Unsupported type for element-wise operation: {}", out->type()); break;
+ }
+}
+
} // namespace cpu
} // namespace marian
diff --git a/src/tensors/cpu/fbgemm/expanded_gemm.h b/src/tensors/cpu/fbgemm/expanded_gemm.h
new file mode 100644
index 00000000..32cc6b12
--- /dev/null
+++ b/src/tensors/cpu/fbgemm/expanded_gemm.h
@@ -0,0 +1,407 @@
+#pragma once
+
+#include "graph/node.h"
+#include "packed_gemm.h"
+#include "tensors/cpu/sharp/int_gemm.h"
+
+#if USE_FBGEMM
+#ifdef __GNUC__
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-variable"
+#endif
+
+#include "3rd_party/fbgemm/include/fbgemm/FbgemmFP16.h"
+
+#ifdef __GNUC__
+#pragma GCC diagnostic pop
+#endif
+
+using namespace fbgemm;
+// @TODO: don't use using namespace ...; in header files. Just don't. [UG]
+#endif // USE_FBGEMM
+
+namespace marian {
+namespace cpu {
+namespace variant {
+
+// Enumeration for the Matrix used in pack functions
+// A matrix - 0, B matrix - 1
+enum class PackMatrix : uint8_t {
+ A = 0x00,
+ B = 0x01
+};
+
+// Pack a matrix (fp16) into cache utilization efficient way (block format) together with quantization into fp16
+// PackMatrix packMat_: the type of packed matrix - A or B matrix
+// bool transpose_: transpose
+// int nrow_: the number of rows
+// int ncol_: the number of columns
+// int kernel_ncol_blocks_: the number of column blocks
+// int brow_: the number of rows in a block
+// int bcol_: the number of columns in a block
+// int last_brow_: the number of rows in the last block
+// int nbrow_: row index in a block
+// int nbcol_: column index in a block
+// uint64_t packsize_: the size of the packed matrix
+// (the number of fp16 elements + padding (1024) + extra temporary memory (256))
+struct FbgemmPacked16PackNodeOp : public UnaryNodeOp {
+ PackMatrix packMat_;
+ bool transpose_;
+ int nrow_;
+ int ncol_;
+ int kernel_ncol_blocks_;
+ int brow_;
+ int bcol_;
+ int last_brow_;
+ int nbrow_;
+ int nbcol_;
+ uint64_t packsize_;
+
+ FbgemmPacked16PackNodeOp(Expr a, PackMatrix packMat, bool transpose, float clipValue)
+ : UnaryNodeOp(a, newShape(a, transpose), Type::uint8),
+ packMat_(packMat),
+ transpose_(transpose) {
+ if(packMat != PackMatrix::B)
+ ABORT("Only prepacking of B (weight matrix) is supported");
+ if(clipValue != 0)
+ ABORT("Clipping is not supported");
+ if(!memoize_)
+ ABORT("Only constant weight node can be packed");
+ }
+
+ NodeOps forwardOps() override {
+#if USE_FBGEMM
+ return {NodeOp(fbgemmPacked16Pack(val_,
+ child(0)->val()->data(),
+ transpose_,
+ nrow_,
+ ncol_,
+ kernel_ncol_blocks_,
+ brow_,
+ bcol_,
+ last_brow_,
+ nbrow_,
+ nbcol_,
+ packsize_))
+ };
+#else // USE_FBGEMM
+ ABORT("FbgemmPacked16PackNodeOp can only be used with FBGEMM enabled.");
+ return { NodeOp(0) };
+#endif // USE_FBGEMM
+ }
+
+ NodeOps backwardOps() override {
+ ABORT("FbgemmPacked16PackNodeOp only available for inference");
+ return {NodeOp(0)};
+ }
+
+ const std::string type() override { return "packMatFp16"; }
+
+ Shape newShape(Expr MAYBE_UNUSED a, bool MAYBE_UNUSED transpose) {
+#if USE_FBGEMM
+ auto shapeMat = a->shape();
+ // Should be 2D - weight matrix
+ ABORT_IF(shapeMat.size() != 2,
+ "Weight Matrix should be 2D");
+ fbgemmPacked16PackInfo(shapeMat,
+ transpose,
+ nrow_,
+ ncol_,
+ kernel_ncol_blocks_,
+ brow_,
+ bcol_,
+ last_brow_,
+ nbrow_,
+ nbcol_,
+ packsize_);
+
+ Shape outShape({(int)packsize_});
+ return outShape;
+#else
+ ABORT("Packed GEMM requires a build with USE_FBGEMM enabled");
+ return Shape();
+#endif // USE_FBGEMM
+ }
+};
+ ;
+// Pack a matrix (int8) into cache utilization efficient way (block format) together with quantization into int8
+// PackMatrix packMat_: the type of packed matrix - A or B matrix
+// marian::Type packType_: the type the input matrix is packed - packed8avx2 or packed8avx512
+// bool transpose_: transpose
+// int nrow_: the number of rows
+// int ncol_: the number of columns
+// uint64_t packsize_: the size of the packed matrix
+// (the size of int8 packed B from fbgemm:PackAWithQuantRowOffset + quantization scale, offset and zero point)
+
+struct FbgemmPacked8PackNodeOp : public UnaryNodeOp {
+ PackMatrix packMat_;
+ marian::Type packType_;
+ bool transpose_;
+ int nrow_;
+ int ncol_;
+ uint64_t packsize_;
+
+ FbgemmPacked8PackNodeOp(Expr a,
+ PackMatrix packMat,
+ marian::Type packType,
+ bool transpose,
+ float clipValue)
+ : UnaryNodeOp(a, newShape(a, transpose), Type::uint8),
+ packMat_(packMat),
+ packType_(packType),
+ transpose_(transpose) {
+ if(packMat != PackMatrix::B)
+ ABORT("Only prepacking of B (weight matrix) is supported");
+ if(clipValue != 0)
+ ABORT("Clipping is not supported");
+ if(!memoize_)
+ ABORT("Only constant weight node can be packed");
+ }
+
+ NodeOps forwardOps() override {
+#if USE_FBGEMM
+ return {NodeOp(fbgemmPacked8Pack(val_,
+ child(0)->val()->data(),
+ packType_,
+ transpose_,
+ nrow_,
+ ncol_,
+ packsize_))
+ };
+#else // USE_FBGEMM
+ ABORT("FbgemmPacked8PackNodeOp can only be used with FBGEMM enabled.");
+ return { NodeOp(0) };
+#endif // USE_FBGEMM
+ }
+
+ NodeOps backwardOps() override {
+ ABORT("FbgemmPacked8PackNodeOp only available for inference");
+ return {NodeOp(0)};
+ }
+
+ const std::string type() override { return "packMatInt8"; }
+
+#if USE_FBGEMM
+ Shape newShape(Expr a, bool transpose) {
+ fbgemmPacked8PackInfo(a->shape(), packType_, transpose, nrow_, ncol_, packsize_);
+ Shape outShape({(int)packsize_});
+ return outShape;
+ }
+#else
+ Shape newShape(Expr /*a*/, bool /*transpose*/) {
+ ABORT("Packed GEMM requires a build with USE_FBGEMM enabled");
+ return Shape();
+ }
+#endif // USE_FBGEMM
+};
+
+
+// Affine transform (matrix multiplication) using packed B matrix
+// float scalar_: scalar multiplier
+// size_t m_: the number of rows in A and C
+// size_t n_: the number of columns in B and C
+// size_t k_: the number of columns in A and the number of rows in C
+// bool transA_: transpose A
+// bool transB_: transpose B
+class FbgemmPacked16AffineNodeOp : public NaryNodeOp {
+private:
+ size_t m_;
+ size_t n_;
+ size_t k_;
+ bool transA_;
+ bool transB_;
+
+public:
+ FbgemmPacked16AffineNodeOp(const std::vector<Expr>& nodes, Shape bShape, bool transA, bool transB, float /*scalar*/)
+ : NaryNodeOp(nodes, newShape(nodes[0], bShape, transA, transB), Type::float32)/*, scalar_(scalar)*/ {
+ transA_ = transA;
+ transB_ = transB;
+ m_ = nodes[0]->shape().elements() / nodes[0]->shape()[-1];
+ k_ = nodes[0]->shape().back();
+ if(transA)
+ std::swap(m_, k_);
+
+ size_t l = bShape.elements() / bShape[-1];
+ n_ = bShape[-1];
+ if(transB)
+ std::swap(l, n_);
+ }
+
+ Shape newShape(Expr a, Shape bShape, bool transA, bool transB) {
+ auto shapeA = a->shape();
+ if(transA) {
+ shapeA.set(shapeA.size() - 2, a->shape()[shapeA.size() - 1]);
+ shapeA.set(shapeA.size() - 1, a->shape()[shapeA.size() - 2]);
+ }
+
+ auto shapeB = bShape;
+ if(transB) {
+ shapeB.set(shapeB.size() - 2, bShape[shapeB.size() - 1]);
+ shapeB.set(shapeB.size() - 1, bShape[shapeB.size() - 2]);
+ }
+
+ Shape outShape = shapeA;
+ outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
+ ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
+ "Matrix product requires inner dimensions to match");
+ return outShape;
+ }
+
+ NodeOps forwardOps() override {
+#if USE_FBGEMM
+ return {
+ NodeOp(fbgemmPacked16Gemm(val_,
+ child(0)->val(),
+ child(1)->val(),
+ children().size() > 2 ? child(2)->val() : nullptr, // pass only if it has a bias
+ m_,
+ n_,
+ transA_))
+ };
+#else // USE_FBGEMM
+ ABORT("FbgemmPacked16AffineNodeOp can only be used with FBGEMM enabled.");
+ return { NodeOp(0) };
+#endif // USE_FBGEMM
+ }
+
+ NodeOps backwardOps() override {
+ ABORT("Only used for inference");
+ return {NodeOp(0)};
+ }
+
+ const std::string type() override { return "gemmPacked16"; }
+};
+
+// Affine transform (matrix multiplication) using packed B matrix
+// Especially, this gemm performs quantized gemms in 8-bit integers.
+// float scalar_: scalar multiplier
+// size_t m_: the number of rows in A and C
+// size_t n_: the number of columns in B and C
+// size_t k_: the number of columns in A and the number of rows in C
+// bool transA_: transpose A
+// bool transB_: transpose B
+class FbgemmPacked8AffineNodeOp : public NaryNodeOp {
+private:
+ size_t m_;
+ size_t n_;
+ size_t k_;
+ bool transA_;
+ bool transB_;
+
+public:
+ FbgemmPacked8AffineNodeOp(const std::vector<Expr>& nodes, Shape bShape, bool transA, bool transB, float /*scalar*/)
+ : NaryNodeOp(nodes, newShape(nodes[0], bShape, transA, transB), Type::float32)/*, scalar_(scalar) */ {
+ transA_ = transA;
+ transB_ = transB;
+ m_ = nodes[0]->shape().elements() / nodes[0]->shape()[-1];
+ k_ = nodes[0]->shape().back();
+ if(transA)
+ std::swap(m_, k_);
+
+ size_t l = bShape.elements() / bShape[-1];
+ n_ = bShape[-1];
+ if(transB)
+ std::swap(l, n_);
+ }
+
+ Shape newShape(Expr a, Shape bShape, bool transA, bool transB) {
+ auto shapeA = a->shape();
+ if(transA) {
+ shapeA.set(shapeA.size() - 2, a->shape()[shapeA.size() - 1]);
+ shapeA.set(shapeA.size() - 1, a->shape()[shapeA.size() - 2]);
+ }
+
+ auto shapeB = bShape;
+ if(transB) {
+ shapeB.set(shapeB.size() - 2, bShape[shapeB.size() - 1]);
+ shapeB.set(shapeB.size() - 1, bShape[shapeB.size() - 2]);
+ }
+
+ Shape outShape = shapeA;
+ outShape.set(outShape.size() - 1, shapeB[shapeB.size() - 1]);
+ ABORT_IF(shapeA[shapeA.size() - 1] != shapeB[shapeB.size() - 2],
+ "Matrix product requires inner dimensions to match");
+ return outShape;
+ }
+
+ NodeOps forwardOps() override {
+ NodeOps nodeOps;
+#if USE_FBGEMM
+ // Do addBias only if it has a bias term
+ if (children().size() > 2) {
+ nodeOps = { NodeOp(fbgemmPacked8Gemm(val_,
+ child(0)->val(),
+ child(1)->val(),
+ m_,
+ n_,
+ k_,
+ transA_,
+ transB_);
+ marian::cpu::int16::AddBias(val_, child(2)->val())) };
+ } else {
+ nodeOps = { NodeOp(fbgemmPacked8Gemm(val_,
+ child(0)->val(),
+ child(1)->val(),
+ m_,
+ n_,
+ k_,
+ transA_,
+ transB_)) };
+ }
+#else // USE_FBGEMM
+ ABORT("FbgemmPacked8AffineNodeOp can only be used with FBGEMM enabled.");
+#endif // USE_FBGEMM
+
+ return nodeOps;
+ }
+
+ NodeOps backwardOps() override {
+ ABORT("Only used for inference");
+ return {NodeOp(0)};
+ }
+
+ const std::string type() override { return "gemmPacked8"; }
+};
+
+static inline Expr affine(Expr a, Expr b, Shape bShape, Expr c, bool transA, bool transB, float scalar) {
+ std::vector<Expr> nodes = {a, b, c};
+ Type elementType = b->value_type();
+
+ if (elementType == Type::packed16)
+ return Expression<FbgemmPacked16AffineNodeOp>(nodes, bShape, transA, transB, scalar);
+ else if (isPacked(elementType) && sizeOf(elementType) == 1)
+ return Expression<FbgemmPacked8AffineNodeOp>(nodes, bShape, transA, transB, scalar);
+ else {
+ ABORT("Only int8 and fp16 are available. {}", elementType);
+ return nullptr;
+ }
+}
+
+static inline Expr pack(Type elementType, Expr a, PackMatrix packMat, bool transpose, float clipValue) {
+ if (elementType == Type::packed16)
+ return Expression<FbgemmPacked16PackNodeOp>(a, packMat, transpose, clipValue);
+ else if (isPacked(elementType) && sizeOf(elementType) == 1)
+ return Expression<FbgemmPacked8PackNodeOp>(a, packMat, elementType, transpose, clipValue);
+ else {
+ ABORT("Only int8 and fp16 are available. {}", elementType);
+ return nullptr;
+ }
+}
+
+static inline Expr dot(Expr a, Expr b, Shape bShape, bool transA, bool transB, float scalar) {
+ std::vector<Expr> nodes = {a, b};
+ Type elementType = b->value_type();
+
+ if (elementType == Type::packed16)
+ return Expression<FbgemmPacked16AffineNodeOp>(nodes, bShape, transA, transB, scalar);
+ else if (isPacked(elementType) && sizeOf(elementType) == 1)
+ return Expression<FbgemmPacked8AffineNodeOp>(nodes, bShape, transA, transB, scalar);
+ else {
+ ABORT("Only int8 and fp16 are available. {}", elementType);
+ return nullptr;
+ }
+}
+
+} // namespace variant
+} // namespace cpu
+} // namespace marian
diff --git a/src/tensors/cpu/fbgemm/expression_graph_packable.h b/src/tensors/cpu/fbgemm/expression_graph_packable.h
new file mode 100644
index 00000000..743b7c8c
--- /dev/null
+++ b/src/tensors/cpu/fbgemm/expression_graph_packable.h
@@ -0,0 +1,153 @@
+#pragma once
+
+#include "graph/expression_graph.h"
+#include "packed_gemm.h"
+
+namespace marian {
+
+// When FBGEMM based packed GEMM is used, some weight matrices need to be packed offline.
+// The decision which weights can be packed or not should be done walking through the graph.
+// This requires some more changes, but we temporarily do this just by name ("_W") of the weights.
+// And, this introduces a low level packed_gemm.h apis interact with high level graph class.
+// So, we make a subclass of ExpressionGraph and put those immature codes in this class.
+// We will improve this in the near future.
+class ExpressionGraphPackable : public ExpressionGraph {
+public:
+ ExpressionGraphPackable()
+ : ExpressionGraph( /* inference = */ true) {} // Packable expression graph only supports inference
+
+ virtual ~ExpressionGraphPackable() {}
+
+ // Convert model weights into packed format and save to IO items.
+ // @TODO: review this
+ void packAndSave(const std::string& name, const std::string& meta, Type gemmElementType = Type::float32, Type saveElementType = Type::float32) {
+ std::vector<io::Item> ioItems;
+
+ // sorted by name in std::map
+ for (auto p : params()->getMap()) {
+ std::string pName = p.first;
+
+ if (!namespace_.empty()) {
+ if (pName.substr(0, namespace_.size() + 2) == namespace_ + "::")
+ pName = pName.substr(namespace_.size() + 2);
+ }
+
+ Tensor val = p.second->val();
+
+ // save as packed format
+ // @TODO Hardcoded to find packable weights - all the weights used for affine op (fp16), all the weights used for affine op and dot op (int8)
+ if ((gemmElementType == Type::packed8avx2 || gemmElementType == Type::packed8avx512)
+ && (pName.find("_W") == pName.length() - 3 || pName.find("_W") == pName.length() - 2)) {
+ #if USE_FBGEMM
+ using namespace marian::cpu::variant;
+ // packing information - size
+ int nrow;
+ int ncol;
+ uint64_t packsize;
+
+ fbgemmPacked8PackInfo(val->shape(),
+ gemmElementType,
+ pName.find("Wemb") != std::string::npos,
+ nrow,
+ ncol,
+ packsize);
+
+ auto allocator = New<TensorAllocator>(getBackend());
+
+ // buffer tensor to save packed matrix
+ Tensor packedTensor;
+ allocator->allocate(packedTensor, { 1, (int32_t)packsize }, Type::uint8);
+
+ //Pack B matrix into int8
+ fbgemmPacked8Pack(packedTensor,
+ val->data(),
+ gemmElementType,
+ pName.find("Wemb") != std::string::npos,
+ nrow,
+ ncol,
+ packsize);
+ io::Item item;
+ item.name = pName;
+ item.shape = val->shape();
+ item.type = gemmElementType;
+
+ // Use the actual memory as this will be aligned and padded.
+ // When memory mapping this is required. Shape keeps track of
+ // tensor size. Saving to *.npz will cut to size.
+ auto mem = packedTensor->memory();
+ item.bytes.resize(mem->size());
+ copy(backend_, mem->data<char>(), mem->data<char>() + mem->size(), item.bytes.data());
+
+ ioItems.emplace_back(std::move(item));
+#else
+ ABORT("Packed type {} only supported when compiled with -DUSE_FBGEMM=on", gemmElementType);
+#endif
+ } else if (gemmElementType == Type::packed16 && pName.find("_W") == pName.length() - 3) {
+#if USE_FBGEMM
+ using namespace marian::cpu::variant;
+
+ // packing information
+ int nrow, ncol, kernel_ncol_blocks, brow, bcol, last_brow, nbrow, nbcol;
+ uint64_t packsize;
+
+ fbgemmPacked16PackInfo(val->shape(),
+ false,
+ nrow,
+ ncol,
+ kernel_ncol_blocks,
+ brow,
+ bcol,
+ last_brow,
+ nbrow,
+ nbcol,
+ packsize);
+
+ auto allocator = New<TensorAllocator>(getBackend());
+
+ Tensor packedTensor;
+ allocator->allocate(packedTensor, { 1, (int32_t)packsize }, Type::uint8);
+
+ // fbgemmPacked16Pack
+ fbgemmPacked16Pack(packedTensor,
+ val->data(),
+ false,
+ nrow,
+ ncol,
+ kernel_ncol_blocks,
+ brow,
+ bcol,
+ last_brow,
+ nbrow,
+ nbcol,
+ packsize);
+ io::Item item;
+ item.name = pName;
+ item.shape = val->shape();
+ item.type = gemmElementType;
+
+ // Use the actual memory as this will be aligned and padded.
+ // When memory mapping this is required. Shape keeps track of
+ // tensor size. Saving to *.npz will cut to size.
+ auto mem = packedTensor->memory();
+ item.bytes.resize(mem->size());
+ copy(backend_, mem->data<char>(), mem->data<char>() + mem->size(), item.bytes.data());
+
+ ioItems.emplace_back(std::move(item));
+#else
+ ABORT("Packed type {} only supported when compiled with -DUSE_FBGEMM=on", gemmElementType);
+#endif
+ } else {
+ io::Item item;
+ val->get(item, pName);
+ item.convert(saveElementType);
+ ioItems.emplace_back(std::move(item));
+ }
+ }
+
+ if (!meta.empty())
+ io::addMetaToItems(meta, "special:model.yml", ioItems);
+ io::saveItems(name, ioItems);
+ }
+};
+
+} // namespace marian \ No newline at end of file
diff --git a/src/tensors/cpu/fbgemm/packed_gemm.cpp b/src/tensors/cpu/fbgemm/packed_gemm.cpp
new file mode 100644
index 00000000..a98d5e4a
--- /dev/null
+++ b/src/tensors/cpu/fbgemm/packed_gemm.cpp
@@ -0,0 +1,550 @@
+#include "packed_gemm.h"
+#include "tensors/tensor_allocator.h"
+#include "tensors/tensor_operators.h"
+
+#include <emmintrin.h>
+#include <immintrin.h>
+#include <tmmintrin.h>
+#include <xmmintrin.h>
+#include <cassert>
+#include <cstddef>
+#include <unordered_map>
+//#include <chrono>
+
+#if USE_FBGEMM
+#ifdef _MSC_VER
+#pragma warning(disable: 4505) // 'fbgemmAlignedAlloc' in fbgemm.h: unreferenced local function has been removed (missing 'static inline')
+#pragma warning(disable: 4251) // 'fbgemm::CompressedSparseColumn::colptr_': class 'std::vector<int,std::allocator<_Ty>>' needs to have dll-interface to be used by clients of class 'fbgemm::CompressedSparseColumn'
+#pragma warning(disable: 4661) // 'fbgemm::PackMatrix<fbgemm::PackBMatrix<int8_t,int32_t>,int8_t,int32_t>::PackMatrix(int32_t,int32_t,inpType *,int,const fbgemm::BlockingFactors *)': no suitable definition provided for explicit template instantiation request
+#pragma warning(disable: 4244) // fbgemm\quantutils.h(51): warning C4244: 'return': conversion from 'const _Ty' to 'T2', possible loss of data
+#pragma warning(disable: 4717) // 'fbgemm::PackMatrix<fbgemm::PackAWithQuantRowOffset<unsigned char,int>,unsigned char,int>::isThisLastKBlock': recursive on all control paths, function will cause runtime stack overflow
+// the following does not work; need to manually disable them in Linker options
+//#pragma comment(linker, "/ignore:4049") // locally defined symbol ...asmjit... imported
+//#pragma comment(linker, "/ignore:4217") // locally defined symbol ...asmjit... imported
+#endif
+
+#ifdef __GNUC__
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-variable"
+#endif
+#include "3rd_party/fbgemm/include/fbgemm/FbgemmFP16.h"
+#include "3rd_party/fbgemm/include/fbgemm/QuantUtils.h"
+#include "3rd_party/fbgemm/include/fbgemm/Fbgemm.h"
+#ifdef __GNUC__
+#pragma GCC diagnostic pop
+#endif
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#if MKL_FOUND
+#include <mkl.h>
+#include <mkl_types.h>
+#endif
+
+using namespace fbgemm;
+#endif // USE_FBGEMM
+
+namespace marian {
+namespace cpu {
+namespace variant { // Variants of GEMM implementations
+
+#if USE_FBGEMM
+// initialize with a dummy
+// When this class is instantiated,
+// the actual packing operation is happening. If we create this instance every time we call GEMM,
+// we are doing packing every time and very slow.
+// In Caffe2, the operator is stateful and hold an instance of this.
+// But, we don't have any logic for this in marian. We can only cache a tensor (which means a memory chunk).
+// So, for now, we keep the packed memory on our own 1D tensor, then when we call GEMM,
+// we just reuse this instance again and again by replacing the class members (including memory pointer). Eventually,
+// I will add a new constructor to the class in FBGEMM which accepts
+// pre - allocated and pre - packed memory as a parameter.After it's done,
+// this temporary buffer will be removed.
+// When constructing this dummy buffer, ones are used for all the parameters to allocate minimum amount of memory.
+//
+// In a multi marian instance setting (as a dynamic library),
+// different marian instances should not share this variable.
+static thread_local PackedGemmMatrixFP16 packedPlaceholder(1, 1, 1, 1, 1, 1, 1, 1);
+
+// Copied code from fbgemm. It's padding required from some kernel in FBGEMM
+// Verbatim - 'required by sw pipelined kernels'
+// https://github.com/marian-nmt/FBGEMM/blob/master/include/fbgemm/FbgemmFP16.h#L109
+const int PACK16_PADDING = 1024;
+
+// This is a memory space to store auxiliary variables for FBGEMM (e.g. block row, block column, kernel_ncol_blocks and etc.)
+const int PACK16_SPECIALMEM = 256;
+
+// This is copied from FBGEMM code
+// A better way?
+// will be removed, when FBGEMM api is changed
+// blocked row-major format address arithmetic
+/**
+ * Returns the memory address in the packed (block formatted) matrix array of a specific element
+ * indexed by the original non-packed array.
+ *
+ * @param r_ row index in the original matrix
+ * @param c_ column index in the original matrix
+ * @param brow_ row wide block index
+ * @param bcol_ column wide block index
+ * @param nbrow_ number of blocks in row
+ * @param nbcol_ number of blocks in column
+ * @param last_brow_ row number of the last block
+ */
+inline uint64_t addr(const int r_,
+ const int c_,
+ const int brow_,
+ const int bcol_,
+ const int nbrow_,
+ const int nbcol_,
+ const int last_brow_) {
+ uint64_t r = (uint64_t)r_;
+ uint64_t c = (uint64_t)c_;
+
+ uint64_t block_row_id = r / brow_;
+ uint64_t brow_offset = (block_row_id * nbcol_) * (brow_ * bcol_);
+ uint64_t block_col_id = c / bcol_;
+ uint64_t bcol_offset
+ = block_col_id * ((block_row_id != nbrow_ - 1) ? (brow_ * bcol_) : (last_brow_ * bcol_));
+ uint64_t block_offset = brow_offset + bcol_offset;
+ uint64_t inblock_offset = r % brow_ * bcol_ + c % bcol_;
+
+ uint64_t index = block_offset + inblock_offset;
+ return index;
+}
+
+// Memory blocking factors (parameters) for packing into AVX2 int8
+static const fbgemm::BlockingFactors Packed8Avx2BlockingFactors = {
+ PackingTraits<int8_t, int32_t, inst_set_t::avx2>::MR,
+ PackingTraits<int8_t, int32_t, inst_set_t::avx2>::NR,
+ PackingTraits<int8_t, int32_t, inst_set_t::avx2>::NR_MIN,
+ PackingTraits<int8_t, int32_t, inst_set_t::avx2>::ROW_INTERLEAVE,
+ PackingTraits<int8_t, int32_t, inst_set_t::avx2>::MCB,
+ PackingTraits<int8_t, int32_t, inst_set_t::avx2>::KCB,
+ PackingTraits<int8_t, int32_t, inst_set_t::avx2>::NCB
+};
+
+// Memory blocking factors (parameters) for packing into AVX512 int8
+static const fbgemm::BlockingFactors Packed8Avx512BlockingFactors = {
+ PackingTraits<int8_t, int32_t, inst_set_t::avx512>::MR,
+ PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NR,
+ PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NR_MIN,
+ PackingTraits<int8_t, int32_t, inst_set_t::avx512>::ROW_INTERLEAVE,
+ PackingTraits<int8_t, int32_t, inst_set_t::avx512>::MCB,
+ PackingTraits<int8_t, int32_t, inst_set_t::avx512>::KCB,
+ PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NCB
+};
+
+// This function returns the correct blocking factors structure for given packing type.
+inline const fbgemm::BlockingFactors* getBlockingFactors(marian::Type packType) {
+ if(packType == Type::packed8avx2) {
+ return &Packed8Avx2BlockingFactors;
+ } else if(packType == Type::packed8avx512) {
+ return &Packed8Avx512BlockingFactors;
+ } else {
+ ABORT("Only avx2 and avx512 instruction sets are supported for int8. {}", packType);
+ }
+}
+
+void fbgemmPacked16PackInfo(const marian::Shape& shape,
+ const bool transpose,
+ uint64_t& packsize) {
+ int nrow, ncol, kernel_ncol_blocks, brow = 512, bcol, last_brow, nbrow, nbcol;
+ fbgemmPacked16PackInfo(shape, transpose, nrow, ncol, kernel_ncol_blocks, brow, bcol, last_brow, nbrow, nbcol, packsize);
+}
+
+void fbgemmPacked16PackInfo(const marian::Shape& shape,
+ const bool transpose,
+ int& nrow,
+ int& ncol,
+ int& kernel_ncol_blocks,
+ int& brow,
+ int& bcol,
+ int& last_brow,
+ int& nbrow,
+ int& nbcol,
+ uint64_t& packsize) {
+ nrow = transpose ? shape[1] : shape[0];
+ ncol = transpose ? shape[0] : shape[1];
+ kernel_ncol_blocks = 2;
+ brow = 512;
+ bcol = 8 * kernel_ncol_blocks;
+ last_brow = nrow % brow == 0 ? brow : nrow % brow;
+ nbrow = nrow % brow == 0 ? nrow / brow : (nrow + brow) / brow;
+ nbcol = ncol % bcol == 0 ? ncol / bcol : (ncol + bcol) / bcol;
+ ABORT_IF(ncol % bcol != 0, "ncol (number of columns) should be multiple of 16. {}", ncol);
+ packsize = ((nbrow * brow) * (nbcol * bcol)) * sizeof(fbgemm::float16) + PACK16_PADDING
+ + PACK16_SPECIALMEM;
+}
+
+void fbgemmPacked8PackInfo(const marian::Shape& shape,
+ const marian::Type packType,
+ const bool transpose,
+ int& nrow,
+ int& ncol,
+ uint64_t& packsize) {
+ // Should be 2D - weight matrix
+ ABORT_IF(shape.size() != 2,
+ "Weight Matrix should be 2D");
+ nrow = transpose ? shape[1] : shape[0];
+ ncol = transpose ? shape[0] : shape[1];
+
+ const fbgemm::BlockingFactors* params = getBlockingFactors(packType);
+
+ packsize = fbgemm::PackMatrix<fbgemm::PackBMatrix<int8_t>, int8_t>::packedBufferSize(
+ transpose ? shape[1] : shape[0],
+ transpose ? shape[0] : shape[1], params);
+ // add extra space for storing some other variables specific to B matrix
+ // quantization sacles: 1 per column and float
+ // quantization offset: 1 per column and int32
+ // column offsets: 1 per column and int32
+ packsize += ncol * (sizeof(float) + sizeof(int32_t) + sizeof(int32_t));
+}
+
+// This function computes the offset values for each column which are used for compensating the remainders of quantized values
+// More detailed math is avilable in the FBGEMM's blog - https://engineering.fb.com/ml-applications/fbgemm/
+inline void col_offsets_with_zero_pt_s8acc32(
+ bool transpose,
+ int K,
+ int N,
+ const int8_t* Bint8,
+ const int32_t* B_zero_point,
+ int32_t* col_offsets,
+ int ncols_per_quant_group) {
+ for (int n = 0; n < N; ++n) {
+ int32_t sum = 0;
+ for (int k = 0; k < K; ++k) {
+ sum += transpose ? Bint8[k + n * K] : Bint8[k * N + n];
+ }
+ col_offsets[n] = sum - B_zero_point[n / ncols_per_quant_group] * K;
+ }
+}
+
+void fbgemmPacked16Pack(marian::Tensor out,
+ const float* inData, // Packing is only available for 2D weight matrix in Marian. Otherwise, it's aborted in expanded_gemm.h.
+ const bool transpose,
+ const int nrow,
+ const int ncol,
+ const int kernel_ncol_blocks,
+ const int brow,
+ const int bcol,
+ const int last_brow,
+ const int nbrow,
+ const int nbcol,
+ const uint64_t packsize) {
+ // initialize memory
+ uint8_t* outmemorg = out->data<uint8_t>();
+ for(auto i = 0; i < packsize; i++) {
+ outmemorg[i] = 0;
+ }
+ // save the other auxiliary variables
+ uint64_t* auxmemsize = (uint64_t*)outmemorg;
+ auxmemsize[0] = packsize;
+ // save FBGEMM related parameters into the header of the allocated memory by marian
+ int32_t header[8];
+ header[0] = nrow;
+ header[1] = ncol;
+ header[2] = kernel_ncol_blocks;
+ header[3] = brow;
+ header[4] = bcol;
+ header[5] = last_brow;
+ header[6] = nbrow;
+ header[7] = nbcol;
+ memcpy(auxmemsize + 1, header, sizeof(header));
+ // cast to float16
+ fbgemm::float16* outmem = (fbgemm::float16*)(outmemorg + 256);
+ fbgemm::float16* dummy = new fbgemm::float16;
+ // pack the matrix
+ for(int i = 0; i < nrow; i++) {
+ for(int j = 0; j < ncol; j++) {
+ outmem[addr(i, j, brow, bcol, nbrow, nbcol, last_brow)]
+ = tconv(!transpose ? inData[i * ncol + j] : inData[i + nrow * j], *dummy);
+ }
+ }
+ delete dummy;
+}
+
+void fbgemmPacked8Pack(marian::Tensor out,
+ const float* inData,
+ const marian::Type packType,
+ const bool transpose,
+ const int nrow,
+ const int ncol,
+ const uint64_t packsize) {
+ int k = nrow;
+ int n = ncol;
+ int len = k * n;
+
+ // 1. collect stats for each column
+ float* bqScale = new float[n];
+ int32_t* bqZeropoint = new int32_t[n];
+
+ const float* data = inData;
+ float val = 0;
+
+ if (transpose) {
+ for (int jj = 0; jj < n; jj++) {
+ float min = std::numeric_limits<float>::max(), max = std::numeric_limits<float>::min();
+ double mean = 0, sqrsum = 0;
+ for (int ii = 0; ii < k; ii++) {
+ val = data[jj * k + ii];
+ mean += val;
+ sqrsum += val * val;
+ }
+ mean /= k;
+ sqrsum /= k;
+ sqrsum -= mean * mean;
+ sqrsum = sqrt(sqrsum);
+
+ min = (float)(mean - 7.0f*sqrsum);
+ max = (float)(mean + 7.0f*sqrsum);
+ bqScale[jj] = (max - min) / 255;
+ bqZeropoint[jj] = (int32_t)(127 - max / bqScale[jj]);
+ }
+ } else {
+ for (int jj = 0; jj < n; jj++) {
+ float min = std::numeric_limits<float>::max(), max = std::numeric_limits<float>::min();
+ double mean = 0, sqrsum = 0;
+ for (int ii = 0; ii < k; ii++) {
+ val = data[jj + ii * n];
+ mean += val;
+ sqrsum += val * val;
+ }
+ mean /= k;
+ sqrsum /= k;
+ sqrsum -= mean * mean;
+ sqrsum = sqrt(sqrsum);
+
+ min = (float)(mean - 7.0f*sqrsum);
+ max = (float)(mean + 7.0f*sqrsum);
+ bqScale[jj] = (max - min) / 255;
+ bqZeropoint[jj] = (int32_t)(127 - max / bqScale[jj]);
+ }
+ }
+
+ // 2. quantize
+ int8_t* quantized = 0;
+#ifdef _MSC_VER
+ quantized = (int8_t*)_aligned_malloc(len, 256);
+#else
+ int result = posix_memalign((void**)&quantized, 256, len); result;
+ assert(result == 0);
+#endif
+ for (int jj = 0; jj < n; jj++) {
+ TensorQuantizationParams bQuantParam;
+ bQuantParam.scale = bqScale[jj];
+ bQuantParam.zero_point = bqZeropoint[jj];
+ bQuantParam.precision = 8;
+
+ if (transpose)
+ fbgemm::Quantize<int8_t>(data + jj * k, quantized + jj * k, k, bQuantParam);
+ else {
+ for (int ii = 0; ii < k; ii++) {
+ quantized[ii*n + jj] = fbgemm::Quantize<int8_t>(data[ii*n + jj], bQuantParam);
+ }
+ }
+ }
+
+ // 3. compute column offsets
+ int32_t* col_offsets = new int32_t[n];
+ col_offsets_with_zero_pt_s8acc32(transpose, k, n, quantized, bqZeropoint, col_offsets, 1);
+
+
+ int8_t* packedbuf = out->data<int8_t>();
+ for(auto i = 0; i < packsize; i++) {
+ packedbuf[i] = 0;
+ }
+
+ // 4. packing
+ const fbgemm::BlockingFactors* params = getBlockingFactors(packType);
+
+ PackBMatrix<int8_t> packedBN(
+ transpose ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
+ nrow, ncol, quantized, transpose ? nrow : ncol, packedbuf, 1, params);
+
+ // copy quantization scale
+ memcpy(packedbuf + (packsize - n * (sizeof(float) + sizeof(int32_t) + sizeof(int32_t))), bqScale, n * sizeof(float));
+ // copy quantization offset
+ memcpy(packedbuf + (packsize - n * (sizeof(int32_t) + sizeof(int32_t))), bqZeropoint, n * sizeof(int32_t));
+ // copy column offsets to the memory
+ memcpy(packedbuf + (packsize - n * sizeof(int32_t)), col_offsets, n * sizeof(int32_t));
+
+#ifdef _MSC_VER
+ _aligned_free(quantized);
+#else
+ free(quantized);
+#endif
+ delete[] col_offsets;
+ delete[] bqScale;
+ delete[] bqZeropoint;
+}
+
+// GEMM operation on the packed B matrix
+// C: output matrix
+// A: A matrix
+// B: B matrix (packed)
+// m: the number of rows in A and C
+// n: the number of columns in B and C
+// transA: transpose of A matrix
+// B is already packed. So, we don't need transB
+void fbgemmPacked16Gemm(marian::Tensor C,
+ const marian::Tensor A,
+ const marian::Tensor B,
+ const marian::Tensor bias,
+ const size_t m,
+ const size_t n,
+ const int transA) {
+ // row major
+ // keep the original mem
+ fbgemm::float16* pmat = packedPlaceholder.pmat_;
+ // retreive aux fields from the memory
+ uint64_t* packedmemSize = (uint64_t*)B->data();
+ packedPlaceholder.size_ = packedmemSize[0];
+ int32_t header[8];
+ memcpy(header, packedmemSize + 1, sizeof(header));
+ packedPlaceholder.nrow_ = header[0];
+ packedPlaceholder.ncol_ = header[1];
+ packedPlaceholder.kernel_ncol_blocks_ = header[2];
+ packedPlaceholder.brow_ = header[3];
+ packedPlaceholder.bcol_ = header[4];
+ packedPlaceholder.last_brow_ = header[5];
+ packedPlaceholder.nbrow_ = header[6];
+ packedPlaceholder.nbcol_ = header[7];
+
+ // packed matrix
+ packedPlaceholder.pmat_ = (fbgemm::float16*)(B->data<uint8_t>() + 256);
+
+ if(bias != nullptr) {
+#if MKL_FOUND
+ for(int i = 0; i < m; ++i) {
+ mkl_somatcopy('R', 'N', 1, n, 1, bias->data(), n, C->data() + n * i, n);
+ }
+#else
+ for(int i = 0; i < m; ++i) {
+ std::copy(bias->data(), bias->data() + n, C->data() + n * i);
+ }
+#endif
+ }
+
+#ifdef _OPENMP
+#pragma omp parallel
+#endif
+ {
+#ifdef _OPENMP
+ int num_threads = omp_get_num_threads();
+ int tid = omp_get_thread_num();
+#else
+ int num_threads = 1;
+ int tid = 0;
+#endif
+ fbgemm::cblas_gemm_compute(transA ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
+ (int)m,
+ A->data(),
+ packedPlaceholder,
+ bias != nullptr ? 1.0f : 0.0f,
+ C->data(),
+ tid,
+ num_threads);
+ }
+
+ // return back the original mem
+ packedPlaceholder.pmat_ = pmat;
+}
+
+// GEMM operation on the packed B matrix in 8 bit integers
+// C: output matrix
+// A: A matrix
+// B: B matrix (packed)
+// m: the number of rows in A and C
+// n: the number of columns in B and C
+// k: the number of columns in A and the number of rows in B
+// transA: whether A matrix is transposed or not
+// transB: whether B matrix is transposed or not
+void fbgemmPacked8Gemm(marian::Tensor C,
+ const marian::Tensor A,
+ const marian::Tensor B,
+ const size_t m,
+ const size_t n,
+ const size_t k,
+ const int transA,
+ const int transB) {
+ // pack type
+ marian::Type packType = B->type();
+
+ const fbgemm::BlockingFactors* params = getBlockingFactors(packType);
+
+ if((packType == Type::packed8avx2 && fbgemmHasAvx512Support())
+ || (packType == Type::packed8avx512 && !fbgemmHasAvx512Support())) {
+ ABORT("FBGEMM doesn't allow to use {} packing order on {} CPUs",
+ packType == Type::packed8avx2 ? "AVX2" : "AVX512",
+ fbgemmHasAvx512Support() ? "AVX512" : "AVX2");
+ }
+
+ // compute range to quantize A (activations) - (min/max quantization)
+ float min_est = std::numeric_limits<float>::max(), max_est = std::numeric_limits<float>::min();
+
+ int elem = A->shape().elements();
+ float* data = A->data();
+ // AVX based find min/max
+ FindMinMax(data, &min_est, &max_est, elem);
+
+ float ascale = (max_est - min_est) / 255;
+ int32_t azeropoint = (int32_t)(255 - max_est / ascale);
+
+ std::vector<int32_t> row_offset_buf(PackAWithQuantRowOffset<uint8_t>::rowOffsetBufferSize());
+ PackAWithQuantRowOffset<uint8_t> packAN(
+ transA ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
+ (int32_t)(transA ? k : m),
+ (int32_t)(transA ? m : k),
+ A->data(),
+ (int32_t)(transA ? m : k),
+ nullptr, /*buffer for packed matrix*/
+ ascale,
+ azeropoint,
+ 1, /*groups*/
+ row_offset_buf.data(),
+ params);
+
+ // packed matrix size of B
+ int bPackSize = PackMatrix<PackBMatrix<int8_t>, int8_t>::packedBufferSize((int32_t)k, (int32_t)n);
+
+ // retrieve B matrix
+ int8_t* bdata = B->data<int8_t>();
+ float* bqScale = new float[n];
+ memcpy(bqScale, bdata + bPackSize, n * sizeof(float));
+
+ int32_t* bqZeropoint = new int32_t[n];
+ memcpy(bqZeropoint, bdata + bPackSize + n * sizeof(float), n * sizeof(int32_t));
+
+ int32_t* col_offsets = new int32_t[n];
+ memcpy(col_offsets, bdata + bPackSize + n * (sizeof(float) + sizeof(int32_t)), n * sizeof(int32_t));
+
+ DoNothing<float, float> doNothingObj{};
+ ReQuantizeForFloat<false, QuantizationGranularity::OUT_CHANNEL> outputProcObj(
+ doNothingObj,
+ ascale,
+ bqScale,
+ azeropoint,
+ bqZeropoint,
+ packAN.getRowOffsetBuffer(),
+ col_offsets,
+ nullptr,
+ (std::uint32_t) n);
+
+ PackBMatrix<int8_t> repackedBN(
+ transB ? matrix_op_t::Transpose : matrix_op_t::NoTranspose, (int32_t) k, (int32_t) n, bdata, (int32_t) (transB ? k : n), 1, params);
+
+ // gemm computation
+ fbgemmPacked(packAN, repackedBN, C->data(), (int32_t*)C->data(), (int32_t) n, outputProcObj, 0, 1, params);
+
+ delete[] col_offsets;
+ delete[] bqZeropoint;
+ delete[] bqScale;
+}
+
+#endif // USE_FBGEMM
+
+} // namespace variant
+} // namespace cpu
+} // namespace marian
diff --git a/src/tensors/cpu/fbgemm/packed_gemm.h b/src/tensors/cpu/fbgemm/packed_gemm.h
new file mode 100644
index 00000000..d0a63ea9
--- /dev/null
+++ b/src/tensors/cpu/fbgemm/packed_gemm.h
@@ -0,0 +1,141 @@
+#pragma once
+
+#include "tensors/tensor.h"
+
+namespace marian {
+namespace cpu {
+namespace variant { // Variants of GEMM implementations
+
+// Returns the byte size of packed matrix in fp16. It's calculated by fbgemm's internal logic due to the paddings and different layouts.
+// Packing with fp16 only targets AVX2 instruction sets for now.
+// See '3rd_party/fbgemm/include/fbgemm/FbgemmFP16.h'.
+// shape: shape of the tensor to be packed
+// transpose: the matrix is transposed
+// packsize (out): the size of the packed matrix in byte
+void fbgemmPacked16PackInfo(const marian::Shape& shape,
+ const bool transpose,
+ /*out*/uint64_t& packsize);
+
+// Returns the byte size of packed matrix in fp16. It's calculated by fbgemm's internal logic due to the paddings and different layouts.
+// This function returns some other extra variables
+// Packing with fp16 only targets AVX2 instruction sets for now.
+// See '3rd_party/fbgemm/include/fbgemm/FbgemmFP16.h'.
+// shape: shape of the tensor to be packed
+// transpose: the matrix is transposed
+// nrow (out): the number of rows
+// ncol (out): the number of columns
+// kernel_ncol_blocks (out): the number of column blocks
+// brow (out): the number of rows in a block
+// bcol (out): the number of columns in a block
+// last_brow (out): the number of rows in the last block
+// nbrow (out): row index in a block
+// nbcol (out): column index in a block
+// packsize (out): the size of the packed matrix in byte
+void fbgemmPacked16PackInfo(const marian::Shape& shape,
+ const bool transpose,
+ /*out*/int& nrow,
+ /*out*/int& ncol,
+ /*out*/int& kernel_ncol_blocks,
+ /*out*/int& brow,
+ /*out*/int& bcol,
+ /*out*/int& last_brow,
+ /*out*/int& nbrow,
+ /*out*/int& nbcol,
+ /*out*/uint64_t& packsize); // @TODO: change to size_t where appropriate
+
+// Returns the byte size of packed matrix in int8. It's calculated by fbgemm's internal logic due to the paddings and different layouts.
+// See '3rd_party/fbgemm/src/PackBMatrix.cc'.
+// shape: shape of the tensor to be packed
+// packType: Type to be packed - packed8avx2 or packed8avx512
+// transpose: the matrix is transposed
+// nrow (out): the number of rows
+// ncol (out): the number of columns
+// packsize (out): the size of the packed matrix in byte
+void fbgemmPacked8PackInfo(const marian::Shape& shape,
+ const marian::Type packType,
+ const bool transpose,
+ /*out*/int& nrow,
+ /*out*/int& ncol,
+ /*out*/uint64_t& packsize);
+
+// Pack a matrix (fp16) into cache utilization efficient way (block format) into fp16
+// out: output tensor - packed format
+// inData: input tensor data - pointer of float data
+// transpose: the matrix is transposed
+// nrow: the number of rows
+// ncol: the number of columns
+// kernel_ncol_blocks: the number of column blocks
+// brow: the number of rows in a block
+// bcol: the number of columns in a block
+// last_brow: the number of rows in the last block
+// nbrow: row index in a block
+// nbcol: column index in a block
+// packsize: the size of the packed matrix
+// (the number of fp16 elements + padding (1024) + extra temporary memory (256))
+void fbgemmPacked16Pack(marian::Tensor out,
+ const float* inData,
+ const bool transpose,
+ const int nrow,
+ const int ncol,
+ const int kernel_ncol_blocks,
+ const int brow,
+ const int bcol,
+ const int last_brow,
+ const int nbrow,
+ const int nbcol,
+ const uint64_t packsize); // @TODO: change to size_t where appropriate
+
+// Pack a matrix (int8) into cache utilization efficient way (block format) together with quantization into int8
+// out: output tensor - packed format and quantized into int8
+// inData: input tensor data - pointer of float data
+// packType: Type to be packed - packed8avx2 or packed8avx512
+// transpose: the matrix is transposed
+// nrow: the number of rows
+// ncol: the number of columns
+// packsize: the size of the packed matrix
+// (the size of int8 packed B from fbgemm:PackAWithQuantRowOffset + quantization scale, offset and zero point)
+void fbgemmPacked8Pack(marian::Tensor out,
+ const float* inData,
+ const marian::Type packType,
+ const bool transpose,
+ const int nrow,
+ const int ncol,
+ const uint64_t packsize); // @TODO: change to size_t where appropriate
+
+// GEMM operation on the packed B matrix
+// C: output matrix
+// A: A matrix
+// B: B matrix (packed)
+// m: the number of rows in A and C
+// n: the number of columns in B and C
+// transA: transpose of A matrix
+// B is already packed. So, we don't need transB
+void fbgemmPacked16Gemm(marian::Tensor C,
+ const marian::Tensor A,
+ const marian::Tensor B,
+ const marian::Tensor bias,
+ const size_t m,
+ const size_t n,
+ const int transA = 0);
+
+// GEMM operation on the packed B matrix in 8 bit integers
+// C: output matrix
+// A: A matrix
+// B: B matrix (packed)
+// m: the number of rows in A and C
+// n: the number of columns in B and C
+// k: the number of columns in A and rows in B
+// transA: transpose of A matrix
+// transB: transpose of B matrix
+void fbgemmPacked8Gemm(marian::Tensor C,
+ const marian::Tensor A,
+ const marian::Tensor B,
+ const size_t m,
+ const size_t n,
+ const size_t k,
+ const int transA = 0,
+ const int transB = 0);
+
+} // namespace variant
+} // namespace cpu
+} // namespace marian
diff --git a/src/tensors/cpu/int16.h b/src/tensors/cpu/int16.h
index b63cc91d..abb465b6 100644
--- a/src/tensors/cpu/int16.h
+++ b/src/tensors/cpu/int16.h
@@ -31,7 +31,7 @@ private:
public:
DotNodeOp(Expr a, Expr b, float scalar)
- : NaryNodeOp({a, b}, newShape(a, b)), scalar_(scalar) {}
+ : NaryNodeOp({a, b}, newShape(a, b), Type::float32), scalar_(scalar) {}
Shape newShape(Expr a, Expr b) {
auto shapeA = a->shape();
@@ -66,7 +66,7 @@ private:
public:
AffineNodeOp(const std::vector<Expr>& nodes, float scalar)
- : NaryNodeOp(nodes, newShape(nodes[0], nodes[1])), scalar_(scalar) {}
+ : NaryNodeOp(nodes, newShape(nodes[0], nodes[1]), Type::float32), scalar_(scalar) {}
Shape newShape(Expr a, Expr b) {
auto shapeA = a->shape();
diff --git a/src/tensors/cpu/prod.cpp b/src/tensors/cpu/prod.cpp
index 69923f87..ac13ccee 100755
--- a/src/tensors/cpu/prod.cpp
+++ b/src/tensors/cpu/prod.cpp
@@ -167,5 +167,76 @@ void ProdWithBias(marian::Tensor C,
cpu::int16::AddBias(C, bias);
}
+void CSRProd(marian::Tensor C,
+ Ptr<Allocator> /*allocator*/,
+ const marian::Tensor& S_values,
+ const marian::Tensor& S_indices,
+ const marian::Tensor& S_offsets,
+ const marian::Tensor& D,
+ bool transS,
+ bool swapOperands,
+ float beta) {
+ C, S_values, S_indices, S_offsets, D;
+
+ // Note: The CPU implementation currently only implements what's needed for decoding.
+
+ // interpret tensor dimensions as matrix dimensions
+ const auto& shapeC = C->shape();
+ const auto& shapeD = D->shape();
+ // If swapOperands, S and D are swapped (C = D x S instead of C = S x D).
+ // In that case, in the next 6 lines, please read all dimensions as if they were reversed in order.
+ auto rowsC = shapeC[-(int)swapOperands];
+ auto colsC = shapeC.elements() / rowsC;
+ auto rowsD = shapeD[-(int)swapOperands];
+ auto colsD = shapeD.elements() / rowsD;
+ auto rowsS = transS ? rowsD : rowsC;
+ auto colsS = transS ? rowsC : rowsD;
+ ABORT_IF(colsD != colsC, "Inconsistent outer dimensions in CSR product");
+ if (swapOperands) { // make rowsX actual row dimensions again, likewise colsX
+ std::swap(rowsC, colsC);
+ std::swap(rowsD, colsD);
+ std::swap(rowsS, colsS);
+ }
+ // sparse arrays
+ auto numOffsets = S_offsets->shape().elements() - 1; // -1 since last value is length
+ ABORT_IF(numOffsets != rowsS, "Unexpected number of rows in CSR argument"); numOffsets;
+ ABORT_IF(S_values->shape() != S_indices->shape(), "CSR values and indices must have the same size");
+ if (!transS && !swapOperands) {
+ // C = S * D, where D = CSR matrix
+ const auto* offsets = S_offsets->data<IndexType>();
+ const auto* indices = S_indices->data<IndexType>();
+ const auto* values = S_values->data<float>();
+ const auto* dataD = D->data<float>();
+ auto* dataC = C->data<float>();
+ ABORT_IF(beta != 0 && beta != 1, "cpu::CSRProd only supports beta = 0 or 1");
+ for (size_t i = 0; i < rowsC; i++) {
+ auto add = (beta == 1); // first element: overwrite or add according to beta; subsequent elements: add
+ for (size_t kk = offsets[i]; kk < offsets[i + 1]; kk++) {
+ auto k = indices[kk]; // fetch the non-zero row
+ auto valS = values[kk]; // and the value from that row
+ // This code is written with the hope for good vectorization, and the hope
+ // that adding to memory will be done efficiently by the caching system.
+ if (valS == 1)
+ if (!add)
+ for (size_t j = 0; j < colsC; j++)
+ dataC[i * colsC/*==colsD*/ + j] = dataD[k * colsC/*==colsD*/ + j]; // this is a memcpy()
+ else
+ for (size_t j = 0; j < colsC; j++)
+ dataC[i * colsC/*==colsD*/ + j] += dataD[k * colsC/*==colsD*/ + j]; // this is a contiguous-vector addition
+ else
+ if (!add)
+ for (size_t j = 0; j < colsC; j++)
+ dataC[i * colsC/*==colsD*/ + j] = valS * dataD[k * colsC/*==colsD*/ + j];
+ else
+ for (size_t j = 0; j < colsC; j++)
+ dataC[i * colsC/*==colsD*/ + j] += valS * dataD[k * colsC/*==colsD*/ + j]; // notice the +=
+ add = true; // next iteration will add to existing result
+ }
+ }
+ }
+ else
+ ABORT("CSRProd for transS={}, swapOperands={} is not yet implemented for CPU", transS, swapOperands);
+}
+
} // namespace cpu
} // namespace marian
diff --git a/src/tensors/cpu/sharp/avx_gemm.cpp b/src/tensors/cpu/sharp/avx_gemm.cpp
index c41b73eb..61f75fea 100644
--- a/src/tensors/cpu/sharp/avx_gemm.cpp
+++ b/src/tensors/cpu/sharp/avx_gemm.cpp
@@ -495,6 +495,7 @@ void AVX_MatrixMult8(const __m512i *A,
put.Write(C + (i + 4) * num_B_rows + j, Reduce16to32(sum5, sum6));
put.Write(C + (i + 6) * num_B_rows + j, Reduce16to32(sum7));
}
+ /* fall through */
case 6:
for(int j = 0; j < num_B_rows; j++) {
const __m512i *B_row = B + j * sse_width;
@@ -518,6 +519,7 @@ void AVX_MatrixMult8(const __m512i *A,
put.Write(C + i * num_B_rows + j, Reduce16to32(sum1, sum2, sum3, sum4));
put.Write(C + (i + 4) * num_B_rows + j, Reduce16to32(sum5, sum6));
}
+ /* fall through */
case 5:
for(int j = 0; j < num_B_rows; j++) {
const __m512i *B_row = B + j * sse_width;
@@ -539,6 +541,7 @@ void AVX_MatrixMult8(const __m512i *A,
put.Write(C + i * num_B_rows + j, Reduce16to32(sum1, sum2, sum3, sum4));
put.Write(C + (i + 4) * num_B_rows + j, Reduce16to32(sum5));
}
+ /* fall through */
case 4:
for(int j = 0; j < num_B_rows; j++) {
const __m512i *B_row = B + j * sse_width;
@@ -557,6 +560,7 @@ void AVX_MatrixMult8(const __m512i *A,
}
put.Write(C + i * num_B_rows + j, Reduce16to32(sum1, sum2, sum3, sum4));
}
+ /* fall through */
case 3:
for(int j = 0; j < num_B_rows; j++) {
const __m512i *B_row = B + j * sse_width;
@@ -574,6 +578,7 @@ void AVX_MatrixMult8(const __m512i *A,
put.Write(C + i * num_B_rows + j, Reduce16to32(sum1, sum2));
put.Write(C + (i + 2) * num_B_rows + j, Reduce16to32(sum3));
}
+ /* fall through */
case 2:
for(int j = 0; j < num_B_rows; j++) {
const __m512i *B_row = B + j * sse_width;
@@ -588,6 +593,7 @@ void AVX_MatrixMult8(const __m512i *A,
}
put.Write(C + i * num_B_rows + j, Reduce16to32(sum1, sum2));
}
+ /* fall through */
case 1:
for(int j = 0; j < num_B_rows; j++) {
const __m512i *B_row = B + j * sse_width;
diff --git a/src/tensors/cpu/sharp/int_gemm.cpp b/src/tensors/cpu/sharp/int_gemm.cpp
index e04446bc..cdb7cbf0 100644
--- a/src/tensors/cpu/sharp/int_gemm.cpp
+++ b/src/tensors/cpu/sharp/int_gemm.cpp
@@ -89,17 +89,13 @@ void AddBias(marian::Tensor C, const marian::Tensor Bias) {
const float* x = C->data();
const float* bias = Bias->data();
- int m = C->shape().elements() / C->shape()[-1];
- int n = C->shape()[-1];
-#ifdef __AVX512F__
- int n16 = n & ~15;
-#else
- int n4 = (n / 4) * 4;
-#endif
+ const int m = C->shape().elements() / C->shape()[-1];
+ const int n = C->shape()[-1];
for(int j = 0; j < m; ++j) {
int i = 0;
#ifdef __AVX512F__
+ int n16 = n & ~15;
for(; i < n16; i += 16) {
__m512 ai = _mm512_loadu_ps(x + j * n + i);
__m512 bi = _mm512_loadu_ps(bias + i);
@@ -107,6 +103,7 @@ void AddBias(marian::Tensor C, const marian::Tensor Bias) {
_mm512_storeu_ps(y + j * n + i, yi);
}
#else
+ int n4 = (n / 4) * 4;
for(; i < n4; i += 4) {
__m128 ai = _mm_loadu_ps(x + j * n + i);
__m128 bi = _mm_loadu_ps(bias + i);
diff --git a/src/tensors/cpu/sharp/int_gemm.h b/src/tensors/cpu/sharp/int_gemm.h
index 3ae23156..3ae23156 100755..100644
--- a/src/tensors/cpu/sharp/int_gemm.h
+++ b/src/tensors/cpu/sharp/int_gemm.h
diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp
index a3b6bf9a..ae5eeed5 100755
--- a/src/tensors/cpu/tensor_operators.cpp
+++ b/src/tensors/cpu/tensor_operators.cpp
@@ -5,22 +5,55 @@
#include "tensors/tensor_operators.h"
#include "tensors/cpu/backend.h"
+#include "tensors/allocator.h"
#include "functional/approx.h"
#include "functional/functional.h"
#include "functional/tensor.h"
+#include "functional/operators.h"
+
+#if MKL_FOUND
+#include <mkl.h>
+#endif
namespace marian {
namespace cpu {
-inline float stableSigmoid(float x) {
- if(x >= 0) {
- float z = expf(-x);
- return 1.0f / (1.0f + z);
+ void IsNaN(const Tensor /*in*/, Ptr<Allocator> /*allocator*/, bool& /*isNaN*/, bool& /*isInf*/) {
+ ABORT("Not implemented");
+}
+
+template <typename To, typename From>
+void CopyCastTo(To* out, const From* in, int length) {
+ for(int i = 0; i < length; ++i)
+ out[i] = in[i];
+}
+
+// Casting has been factored into two functions "CopyCastFrom" and
+// "CopyCastTo". This only serves the purpuse to autmatically create
+// the full Carthesian product of possible type cast via template magic.
+// Extending CopyCast and CopyCastFrom with a new branch in the "if" clause
+// adds all possible variants.
+template <typename T>
+void CopyCastFrom(Tensor out, const T* in, int length) {
+ if(out->type() == Type::float32) {
+ CopyCastTo(out->data<float>(), in, length);
+ } else if(out->type() == Type::float16) {
+ CopyCastTo(out->data<float16>(), in, length);
} else {
- float z = expf(x);
- return z / (1.0f + z);
+ ABORT("CopyCastTo to type {} not implemented", out->type());
+ }
+}
+
+// currently useless on the CPU until more types are added
+void CopyCast(Tensor out, const Tensor in) {
+ if(in->type() == Type::float32) {
+ CopyCastFrom(out, in->data<float>(), (int)in->size());
+ } else if(in->type() == Type::float16) {
+ CopyCastFrom(out, in->data<float16>(), (int)in->size());
+ } else {
+ ABORT("CopyCastFrom from type {} not implemented", in->type());
}
}
@@ -181,6 +214,71 @@ void Transpose0213(Tensor out, Tensor in) {
}
}
+// This function is called only when MKL is available.
+#if MKL_FOUND
+// Given a 4D array, transpose (swap) the initial 3 dimensions while keeping the last dimension.
+// e.g. 1234 --> 2134, 1234 --> 3214 (4 is always kept).
+// This is an optimized version for swapping first 3 dimensions
+// assuming the last dimension is large enough to get benefits from vectorized copy.
+//
+// @param out output tensor
+// @param in input tensor
+// @param vAxis target (transposed) axes of each given axes
+template <bool add>
+void TransposeFirst3In4(Tensor out, Tensor in, const std::vector<int>& vAxis) {
+ ABORT_IF(vAxis.size() != 4, "This function handles only 4D arrays.");
+ int innermost = in->shape()[-1];
+
+ int l1 = in->shape()[vAxis[0]];
+ int l2 = in->shape()[vAxis[1]];
+ int l3 = in->shape()[vAxis[2]];
+
+ // find the mapping between the transposed output dimensional indices (oi, oj, ok)
+ // and original input dimensional indices (i, j, k)
+ int oi, oj, ok;
+#pragma omp parallel for
+ for(int k = 0; k < l1; ++k) {
+ int shift = k * l2 * l3;
+ for(int j = 0; j < l2; ++j) {
+ for(int i = 0; i < l3; ++i) {
+ if(vAxis[0] == 0) {
+ if(vAxis[1] == 1) {
+ oi = i; oj = j; ok = k;
+ } else {
+ oi = j; oj = i; ok = k;
+ }
+ } else if(vAxis[0] == 1) {
+ if(vAxis[1] == 0) {
+ oi = i; oj = k; ok = j;
+ } else {
+ oi = j; oj = k; ok = i;
+ }
+ } else {
+ if(vAxis[1] == 0) {
+ oi = k; oj = i; ok = j;
+ } else {
+ oi = k; oj = j; ok = i;
+ }
+ }
+ int src = ok * in->shape()[1] * in->shape()[2] + oj * in->shape()[2] + oi;
+ int dst = l3 * j + shift + i;
+
+ const float* inRow = in->data() + src * innermost;
+ float* outRow = out->data() + dst * innermost;
+
+ if(!add) {
+ mkl_somatcopy('R', 'N', 1, innermost, 1.0f, inRow, innermost, outRow, innermost);
+ } else {
+ for(int ii = 0; ii < innermost; ++ii) {
+ outRow[ii] += inRow[ii];
+ }
+ }
+ }
+ }
+ }
+}
+#endif // MKL_FOUND
+
inline void transpose4x4_SSE(const float* A,
float* B,
const int lda,
@@ -247,16 +345,25 @@ void TransposeGeneric(Tensor out, Tensor in, const std::vector<int>& vAxis) {
gOut.shape().dims(index, oDims);
for(size_t i = 0; i < N; ++i)
pDims[permute[i]] = oDims[i];
+
+ // @TODO: where does this change come from?
+ int inIndex = gIn.shape().index(pDims);
+
+ // @TODO: use internal conversion instead of raw indices
if(add)
- gOut[index] += gIn[pDims];
+ gOut.data()[index] += gIn.data()[inIndex];
else
- gOut[index] = gIn[pDims];
+ gOut.data()[index] = gIn.data()[inIndex];
}
}
void TransposeND(Tensor out, Tensor in, const std::vector<int>& vAxis) {
if(vAxis == std::vector<int>({0, 2, 1, 3}))
Transpose0213<false>(out, in);
+#if MKL_FOUND
+ else if(vAxis.size() == 4 && vAxis[3] == 3)
+ TransposeFirst3In4<false>(out, in, vAxis);
+#endif // MKL_FOUND
else if(vAxis == std::vector<int>({1, 0}) && in->shape()[-1] % 16 == 0
&& in->shape()[-2] % 16 == 0)
Transpose10(out, in);
@@ -271,64 +378,120 @@ void TransposeNDGrad(Tensor out, Tensor in, const std::vector<int>& vAxis) {
TransposeGeneric<true>(out, in, vAxis);
}
+template <typename ElementType>
void Softmax(Tensor out, Tensor in) {
- float* pOut = out->data();
- const float* pIn = in->data();
+ using namespace functional;
+ functional::Tensor<ElementType> fout = out;
+ const functional::Tensor<ElementType> fin = in;
- int rows = out->shape().elements() / out->shape().back();
- int cols = out->shape().back();
+ ElementType* pOut = fout.data();
+ const ElementType* pIn = fin.data();
+
+ int rows = fout.shape().elements() / fout.shape().back();
+ int cols = fout.shape().back();
for(int j = 0; j < rows; ++j) {
- float* so = pOut + j * cols;
- const float* sp = pIn + j * cols;
+ ElementType* so = pOut + j * cols;
+ const ElementType* sp = pIn + j * cols;
- float max = sp[0];
- for(int i = 1; i < cols; ++i)
- max = std::max(max, sp[i]);
+ ElementType max = sp[0];
+ for(int i = 1; i < cols; ++i) {
+ max = Ops<ElementType>::max(max, sp[i]);
+ }
- float sum = 0.f;
+ // if ElementType is a complex type, e.g. float32x8, find the max of these 8 values
+ typename Ops<ElementType>::Single maxs = Ops<ElementType>::maxReduce(max);
+
+ ElementType sum = 0.f;
for(int i = 0; i < cols; ++i) {
- float ex = expf(sp[i] - max);
+ ElementType ex = Ops<ElementType>::exp(Ops<ElementType>::sub(sp[i], maxs));
+ sum = Ops<ElementType>::add(sum, ex);
so[i] = ex;
- sum += ex;
}
+ // if ElementType is a complex type, e.g. float32x8, sum these 8 values
+ typename Ops<ElementType>::Single sums = Ops<ElementType>::sumReduce(sum);
+
for(int i = 0; i < cols; ++i) {
- so[i] /= sum;
+ so[i] = Ops<ElementType>::div(so[i], sums);
}
}
}
+
+void Softmax(Tensor out, Tensor in) {
+ matchOrAbort<float>(out->type());
+ matchOrAbort<float>(in->type());
+
+#ifdef __AVX__
+ if(out->shape()[-1] % 8 == 0) {
+ Softmax<float32x8>(out, in);
+ return;
+ }
+#endif
+ if(out->shape()[-1] % 4 == 0) {
+ Softmax<float32x4>(out, in);
+ } else {
+ Softmax<float>(out, in);
+ }
+}
+
+
+template <typename ElementType>
void LogSoftmax(Tensor out, Tensor in) {
- float* pOut = out->data();
- const float* pIn = in->data();
- int rows = out->shape().elements() / out->shape().back();
- int cols = out->shape().back();
+ using namespace functional;
+ functional::Tensor<ElementType> fout = out;
+ const functional::Tensor<ElementType> fin = in;
+
+ ElementType* pOut = fout.data();
+ const ElementType* pIn = fin.data();
+
+ int rows = fout.shape().elements() / fout.shape().back();
+ int cols = fout.shape().back();
for(int j = 0; j < rows; ++j) {
- float* so = pOut + j * cols;
- const float* sp = pIn + j * cols;
+ ElementType* so = pOut + j * cols;
+ const ElementType* sp = pIn + j * cols;
- float max = sp[0];
+ ElementType max = sp[0];
for(int i = 1; i < cols; ++i) {
- max = std::max(max, sp[i]);
+ max = Ops<ElementType>::max(max, sp[i]);
}
+ typename Ops<ElementType>::Single maxs = Ops<ElementType>::maxReduce(max); // global maximum
- float sum = 0.f;
+ ElementType sum = 0.f;
for(int i = 0; i < cols; ++i) {
- float sm = sp[i] - max;
- float ex = expf(sm);
+ ElementType sm = Ops<ElementType>::sub(sp[i], maxs);
+ sum = Ops<ElementType>::add(sum, Ops<ElementType>::exp(sm));
so[i] = sm;
- sum += ex;
}
+ typename Ops<ElementType>::Single sums = Ops<ElementType>::sumReduce(sum); // global sum
+ ElementType logSum = Ops<ElementType>::log(sums); // broadcasts Single to ElementType
for(int i = 0; i < cols; ++i) {
- so[i] -= logf(sum);
+ so[i] = Ops<ElementType>::sub(so[i], logSum);
}
}
}
+void LogSoftmax(Tensor out, Tensor in) {
+ matchOrAbort<float>(out->type());
+ matchOrAbort<float>(in->type());
+
+#ifdef __AVX__
+ if(out->shape()[-1] % 8 == 0) {
+ LogSoftmax<float32x8>(out, in);
+ return;
+ }
+#endif
+ if(out->shape()[-1] % 4 == 0) {
+ LogSoftmax<float32x4>(out, in);
+ } else {
+ LogSoftmax<float>(out, in);
+ }
+}
+
// @TODO: Remove remaining underscores in CPU kernels
void SoftmaxGrad(Tensor grad_, Tensor adj_, Tensor val_) {
int rows = grad_->shape().elements() / grad_->shape()[-1];
@@ -387,6 +550,7 @@ void CopyRows(Tensor out_,
size_t cols = in_->shape()[-1];
size_t rows = indices->size();
+ // note: may also be applied to IndexType; works by luck. Fix with fp16
float* out = out_->data();
const float* in = in_->data();
@@ -394,7 +558,7 @@ void CopyRows(Tensor out_,
for(size_t j = 0; j < rows; ++j) {
size_t dst = j;
- // @TODO: consider moving type checking to this function
+ // @TODO: consider moving type checking to this function
// instead of matchOrAbort above
size_t src = (size_t)indices->data<IndexType>()[j];
@@ -480,6 +644,37 @@ void PasteCols(Tensor out_,
}
}
+// Optimized version of Select for axis=2
+// @TODO: make this generally fast without this special version
+void SelectAxis2(Tensor out,
+ const Tensor in,
+ const Tensor indices) {
+
+ matchOrAbort<IndexType>(indices->type());
+
+ functional::Shape outShape = out->shape();
+ functional::Shape inShape = in->shape();
+
+ auto idxData = indices->data<IndexType>();
+ auto odata = out->data();
+ const auto idata = in->data();
+
+ int size = outShape[3];
+
+ for(int k = 0; k < outShape[0]; ++k) {
+ for(int j = 0; j < outShape[1]; ++j) {
+ int outOffset = k * j * outShape[2] * size + j * outShape[2] * size;
+ int inOffset = k * j * inShape[2] * size + j * inShape[2] * size;
+ for(int i = 0; i < outShape[2]; ++i) {
+ auto idx = idxData[i];
+ int outIndex = outOffset + i * size;
+ int inIndex = inOffset + idx * size;
+ std::copy(idata + inIndex, idata + inIndex + size, odata + outIndex);
+ }
+ }
+ }
+}
+
void Select(Tensor out,
const Tensor in,
const Tensor indices,
@@ -489,19 +684,23 @@ void Select(Tensor out,
// @TODO: make this efficient
functional::Shape outShape = out->shape();
- functional::Shape inShape = in->shape();
+ functional::Shape inShape = in->shape();
+ functional::Shape idxShape = indices->shape();
int length = outShape.elements();
functional::Array<int, functional::Shape::size()> dims;
int axisCPU = (int)(axis + functional::Shape::size() - out->shape().size());
-
+
+ if(axisCPU == 2 && outShape == idxShape) // specialization for axis==2 when there is no broadcasting, @TODO to be removed once we have a faster implementation below
+ return SelectAxis2(out, in, indices);
+
for(int index = 0; index < length; ++index) {
- outShape.dims(index, dims);
- dims[axisCPU] = (int)indices->data<IndexType>()[dims[axisCPU]];
- int inIndex = inShape.index(dims);
- out->data()[index] = in->data()[inIndex];
+ outShape.dims(index, dims); // compute dimension-based indices from global index;
+ int idxIndex = idxShape.bindex(dims); // return global index for indices based on dimension-specific indices from out, take broadcasting into account;
+ dims[axisCPU] = (int)indices->data<IndexType>()[idxIndex]; // substitute index of out-tensor with corresponding axis-local position from in-tensor;
+ int inIndex = inShape.index(dims); // compute global index from dimension-specific indices, no broadcasting as out and in match in all dimensions apart from axis
+ out->data()[index] = in->data()[inIndex]; // assign corresponding values.
}
-
}
void Insert(Tensor out,
@@ -513,7 +712,8 @@ void Insert(Tensor out,
// @TODO: make this efficient
functional::Shape outShape = out->shape();
- functional::Shape inShape = in->shape();
+ functional::Shape inShape = in->shape();
+ functional::Shape idxShape = indices->shape();
int length = inShape.elements();
functional::Array<int, functional::Shape::size()> dims;
@@ -521,7 +721,8 @@ void Insert(Tensor out,
for(int index = 0; index < length; ++index) {
inShape.dims(index, dims);
- dims[axisCPU] = (int)indices->data<IndexType>()[dims[axisCPU]];
+ int idxIndex = idxShape.bindex(dims); // broadcast index into indices tensor
+ dims[axisCPU] = (int)indices->data<IndexType>()[idxIndex];
int outIndex = outShape.index(dims);
out->data()[outIndex] += in->data()[index];
}
@@ -550,11 +751,11 @@ void GRUFastForward(Tensor out_, std::vector<Tensor> inputs, bool final) {
#pragma omp simd
for(int i = 0; i < cols; ++i) {
- float r = stableSigmoid(xWrow[i] + sUrow[i] + b[i]);
+ float r = functional::Ops<float>::sigmoid(xWrow[i] + sUrow[i] + b[i]);
int k = i + cols;
- float z = stableSigmoid(xWrow[k] + sUrow[k] + b[k]);
+ float z = functional::Ops<float>::sigmoid(xWrow[k] + sUrow[k] + b[k]);
int l = i + 2 * cols;
float h;
@@ -606,8 +807,8 @@ void GRUFastBackward(std::vector<Tensor> outputs,
int k = i + cols;
int l = i + 2 * cols;
- float r = stableSigmoid(rowXW[i] + rowSU[i] + b[i]);
- float z = stableSigmoid(rowXW[k] + rowSU[k] + b[k]);
+ float r = functional::Ops<float>::sigmoid(rowXW[i] + rowSU[i] + b[i]);
+ float z = functional::Ops<float>::sigmoid(rowXW[k] + rowSU[k] + b[k]);
float h;
if(final)
@@ -661,57 +862,52 @@ void GRUFastBackward(std::vector<Tensor> outputs,
}
}
-void CrossEntropyPick(Tensor out_, Tensor in_, Tensor pick_) {
- matchOrAbort<IndexType>(pick_->type());
+void CrossEntropyPick(Tensor out, Tensor in, Tensor labelIndices) {
+ matchOrAbort<IndexType>(labelIndices->type());
- float* out = out_->data();
// Shape& outShape = out_->shape();
- const float* in = in_->data();
- Shape& inShape = in_->shape();
+ Shape& inShape = in->shape();
int rows = inShape.elements() / inShape.back();
int cols = inShape.back();
-#pragma omp parallel for
+ #pragma omp parallel for
for(int j = 0; j < rows; ++j) {
- const float* sp = in + j * cols;
+ const float* sp = in->data() + j * cols;
float max = sp[0];
-#pragma omp simd reduction(max : max)
+ #pragma omp simd reduction(max : max)
for(int i = 1; i < cols; ++i) {
max = std::max(max, sp[i]);
}
float sum = 0.f;
-#pragma omp simd reduction(+ : sum)
+ #pragma omp simd reduction(+ : sum)
for(int i = 0; i < cols; ++i) {
sum += std::exp(sp[i] - max);
}
- // cross-entropy
- int i = (int)pick_->data<IndexType>()[j];
+ // Groundtruth label index
+ IndexType i = labelIndices->data<IndexType>()[j];
// This appears to be safe i.e. that i >= 0 && i < cols is known
- out[j] = std::log(sum) - sp[i] + max;
+ out->data()[j] = std::log(sum) - sp[i] + max; // -log(p_i) = - logsoftmax(x_i - max) = - (x_i - max) - log(sum_j exp(x_j - max))
}
}
-void CrossEntropyPickBackward(Tensor out_,
- Tensor adj_,
- Tensor a,
- Tensor pick_) {
+void CrossEntropyPickBackward(Tensor out,
+ Tensor adj,
+ Tensor in,
+ Tensor labelIndices) {
- matchOrAbort<IndexType>(pick_->type());
- float* out = out_->data();
- Shape& outShape = out_->shape();
- const float* adj = adj_->data();
- const float* in = a->data();
+ matchOrAbort<IndexType>(labelIndices->type());
+ Shape& outShape = out->shape();
int rows = outShape.elements() / outShape.back();
int cols = outShape.back();
#pragma omp parallel for
for(int j = 0; j < rows; ++j) {
- const float* sp = in + j * cols;
- float* so = out + j * cols;
+ const float* sp = in->data() + j * cols;
+ float* so = out->data() + j * cols;
float max = sp[0];
for(int i = 1; i < cols; ++i) {
@@ -725,13 +921,14 @@ void CrossEntropyPickBackward(Tensor out_,
// cross-entropy
for(int i = 0; i < cols; ++i) {
- float sub = (float)(i == (int)pick_->data<IndexType>()[j]);
- so[i] += adj[j] * (std::exp(sp[i] - max) / sum - sub);
+ float sub = (float)(i == (int)labelIndices->data<IndexType>()[j]); // delta, true if label index and column index match
+ auto softmax = std::exp(sp[i] - max) / sum;
+ so[i] += adj->data()[j] * (softmax - sub);
}
}
}
-float L2Norm(Tensor in) {
+float L2Norm(Tensor in, Ptr<Allocator> /*not used*/) {
float sum = 0.f;
size_t size = in->size();
const float* data = in->data();
@@ -851,7 +1048,7 @@ void LayerNormalization(Tensor out_,
sqSum += ex * ex;
}
- float sigma = std::sqrt(eps + sqSum / cols);
+ float sigma = std::sqrt(sqSum / cols + eps);
#pragma omp simd
for(int i = 0; i < cols; ++i) {
@@ -913,7 +1110,7 @@ void LayerNormalizationGrad(Tensor gradX_,
sum_sqr += ex * ex;
}
- float sigma = std::sqrt(eps + sum_sqr / cols);
+ float sigma = std::sqrt(sum_sqr / cols + eps);
#pragma omp simd
for(size_t i = 0; i < cols; ++i) {
float grad_x = 0.f;
@@ -955,7 +1152,7 @@ void LayerNormalizationGrad(Tensor gradX_,
sum_sqr += ex * ex;
}
- float sigma = std::sqrt(eps + sum_sqr / cols);
+ float sigma = std::sqrt(sum_sqr / cols + eps);
#pragma omp simd
for(size_t i = 0; i < cols; ++i) {
float grad_x = 0.f;
@@ -1049,10 +1246,10 @@ void LSTMCellForward(Tensor out_, std::vector<Tensor> inputs) {
const float* sUrow = sU + j * cols * 4;
for(int i = 0; i < cols; ++i) {
- float gf = stableSigmoid(xWrow[i] + sUrow[i] + b[i]);
+ float gf = functional::Ops<float>::sigmoid(xWrow[i] + sUrow[i] + b[i]);
int k = i + cols;
- float gi = stableSigmoid(xWrow[k] + sUrow[k] + b[k]);
+ float gi = functional::Ops<float>::sigmoid(xWrow[k] + sUrow[k] + b[k]);
int l = i + 2 * cols;
float gc = std::tanh(xWrow[l] + sUrow[l] + b[l]);
@@ -1082,7 +1279,7 @@ void LSTMOutputForward(Tensor out_, std::vector<Tensor> inputs) {
for(int i = 0; i < cols; ++i) {
int k = i + 3 * cols;
- float go = stableSigmoid(xWrow[k] + sUrow[k] + b[k]);
+ float go = functional::Ops<float>::sigmoid(xWrow[k] + sUrow[k] + b[k]);
rowOut[i] = go * std::tanh(rowCell[i]);
}
@@ -1122,10 +1319,10 @@ void LSTMCellBackward(std::vector<Tensor> outputs,
const float* rowAdj = adj + j * cols;
for(int i = 0; i < cols; ++i) {
- float gf = stableSigmoid(xWrow[i] + sUrow[i] + b[i]);
+ float gf = functional::Ops<float>::sigmoid(xWrow[i] + sUrow[i] + b[i]);
int k = i + cols;
- float gi = stableSigmoid(xWrow[k] + sUrow[k] + b[k]);
+ float gi = functional::Ops<float>::sigmoid(xWrow[k] + sUrow[k] + b[k]);
int l = i + 2 * cols;
float gc = std::tanh(xWrow[l] + sUrow[l] + b[l]);
@@ -1207,7 +1404,7 @@ void LSTMOutputBackward(std::vector<Tensor> outputs,
for(int i = 0; i < cols; ++i) {
int k = i + 3 * cols;
- float go = stableSigmoid(xWrow[k] + sUrow[k] + b[k]);
+ float go = functional::Ops<float>::sigmoid(xWrow[k] + sUrow[k] + b[k]);
float t = std::tanh(rowCell[i]);
@@ -1233,39 +1430,26 @@ void LSTMOutputBackward(std::vector<Tensor> outputs,
}
}
-// void HighwayForward(Tensor out,
-// const Tensor in1,
-// const Tensor in2,
-// const Tensor t) {
-// size_t length = out->shape().elements();
-// for(size_t i = 0; i < length; ++i) {
-// float sigma = stableSigmoid(t->data()[i]);
-// out->data()[i] = sigma * in1->data()[i] + (1.f - sigma) * in2->data()[i];
-// }
-//}
-
void HighwayForward(Tensor out,
- const Tensor in1,
- const Tensor in2,
- const Tensor t) {
- size_t length = out->shape().elements();
-
- static functional::Approx<10, 0, 100> approxSigmoid(stableSigmoid);
-
- for(size_t i = 0; i < length; ++i) {
- float sigma = approxSigmoid(t->data()[i]);
- out->data()[i] = sigma * in1->data()[i] + (1.f - sigma) * in2->data()[i];
- }
+ const Tensor in1,
+ const Tensor in2,
+ const Tensor t) {
+ using namespace functional;
+ cpu::Element(_1 = sigmoid(_2), out, t);
+ cpu::Element(_1 = _1 * _2 + (1.f - _1) * _3, out, in1, in2);
}
-void HighwayBackward(Tensor /*out1*/,
- Tensor /*out2*/,
- Tensor /*outt*/,
- const Tensor /*in1*/,
- const Tensor /*in2*/,
- const Tensor /*t*/,
- const Tensor /*adj*/) {
- ABORT("Not implemented!");
+void HighwayBackward(Tensor out1,
+ Tensor out2,
+ Tensor outt,
+ const Tensor in1,
+ const Tensor in2,
+ const Tensor t,
+ const Tensor adj) {
+ using namespace functional;
+ cpu::Element(_1 += sigmoid(_2) * _3, out1, t, adj);
+ cpu::Element(_1 += (1.f - sigmoid(_2)) * _3, out2, t, adj);
+ cpu::Element(_1 += sigmoid(_2) * (1.f - sigmoid(_2)) * (_3 - _4) * _5, outt, t, in1, in2, adj);
}
void PoolingWithMaskingForward(Tensor /*out*/,
diff --git a/src/tensors/device.h b/src/tensors/device.h
index 68a9492e..0be6c076 100755..100644
--- a/src/tensors/device.h
+++ b/src/tensors/device.h
@@ -21,7 +21,7 @@ protected:
public:
Device(DeviceId deviceId, size_t alignment = 256)
- : deviceId_(deviceId), data_(0), size_(0), alignment_(alignment) {}
+ : deviceId_(deviceId), alignment_(alignment) {}
virtual ~Device(){};
diff --git a/src/tensors/gpu/add.cu b/src/tensors/gpu/add.cu
index 2431948e..30fc115d 100644..100755
--- a/src/tensors/gpu/add.cu
+++ b/src/tensors/gpu/add.cu
@@ -1,9 +1,5 @@
-/* All or part of this file was contributed by Intel under license:
- * Copyright (C) 2017-2018 Intel Corporation
- * SPDX-License-Identifier: MIT
- */
-
#include "tensors/gpu/add.h"
+#include "tensors/gpu/add_all.h"
#include "tensors/gpu/cuda_helpers.h"
@@ -16,12 +12,14 @@ namespace marian {
namespace gpu {
-template <size_t K, class Functor>
-__global__ void gAddGeneric(Functor functor,
- const functional::Shape full,
- functional::Tensor<float> out,
- functional::Array<functional::Tensor<float>, K> ins,
- float scale = 1.0) {
+template <size_t K, class Functor, class AggFunctor, typename T, typename AccType>
+__global__ void gAggregateGeneric(Functor functor, // functor applied to single corresponding elements in tensors (via broadcasting),
+ AccType aggInit, // aggInit is starting value of accumulation (e.g. 0 for sum),
+ AggFunctor aggFunctor, // aggFunctor is used to accumulate values (e.g. sum),
+ const functional::Shape full, // maximal combined shape of all tensors via broadcasting
+ functional::Tensor<T> out, // output tensor
+ functional::Array<functional::Tensor<T>, K> ins, // input tensors
+ AccType scale = 1.0) { // scale accumulation result by scale. e.g. used for computing mean from sum over N elements with scale 1/N
int outLength = out.shape().elements();
bool same = outLength == full.elements();
for(int i = 0; i < K; ++i)
@@ -37,21 +35,21 @@ __global__ void gAddGeneric(Functor functor,
int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
if(index < outLength) {
if(same) {
- out[index] += functional::apply(functor, ins, index) * scale;
+ out[index] = (T)aggFunctor((AccType)out[index], functional::applyWithCast<AccType>(functor, ins, index) * scale); // apply functors to with arguments cast to AccType
} else {
out.shape().dims(index, dims);
- out[index] += functional::loops(functor, ins, len, dims) * scale;
+ out[index] = (T)aggFunctor((AccType)out[index], functional::loops(functor, aggInit, aggFunctor, ins, len, dims) * scale); // apply functors to with arguments cast to AccType
}
}
}
}
-template <size_t K, class Functor>
-__global__ void gAddEqual(Functor functor,
- functional::Tensor<float> out,
- functional::Array<functional::Tensor<float>, K> ins,
- float scale,
- bool broadcast) {
+template <size_t K, class Functor, class AggFunctor, typename T, typename AccType>
+__global__ void gAggregateEqual(Functor functor, AggFunctor aggFunctor,
+ functional::Tensor<T> out,
+ functional::Array<functional::Tensor<T>, K> ins,
+ AccType scale,
+ bool broadcast) {
int length = out.shape().elements();
functional::Array<int, functional::Shape::size()> dims;
@@ -67,40 +65,42 @@ __global__ void gAddEqual(Functor functor,
indices[i] = ins[i].shape().bindex(dims);
}
- out[index] += functional::apply(functor, ins, indices) * scale;
+ out[index] = (T)aggFunctor((AccType)out[index], functional::applyWithCast<AccType>(functor, ins, indices) * scale);
}
}
}
-template <size_t K, class Functor>
-__global__ void gAddReduce(Functor functor,
- const functional::Shape full,
- functional::Tensor<float> out,
- functional::Array<functional::Tensor<float>, K> ins,
- float scale = 1.0) {
+template <size_t K, class Functor, class AggFunctor, typename T, typename AccType = float>
+__global__ void gAggregateReduce(Functor functor, AccType aggInit, AggFunctor aggFunctor,
+ const functional::Shape full,
+ functional::Tensor<T> out,
+ functional::Array<functional::Tensor<T>, K> ins,
+ AccType scale = 1.0) {
int rows = full.elements() / full.back();
int cols = full.back();
- bool same = true;
+ bool same = true; // do all inputs have the same number of elements?
for(int i = 0; i < K; ++i)
same = same && ins[i].shape().elements() == full.elements();
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- extern __shared__ float _share[];
- float* _sum = _share + blockDim.x;
+ // make sure shared memory is the same for different types
+ // by using bytes instead of type T
+ extern __shared__ uint8_t _sharedBytes[];
+ AccType* _sum = (AccType*)_sharedBytes;
if(same) {
- _sum[threadIdx.x] = 0;
+ _sum[threadIdx.x] = aggInit;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols)
- _sum[threadIdx.x] += functional::apply(functor, ins, j * cols + id);
+ _sum[threadIdx.x] = aggFunctor(_sum[threadIdx.x], functional::applyWithCast<AccType>(functor, ins, j * cols + id)); // casts to AccType before applying functor which then performs operation in AccType
}
} else {
functional::Array<int, functional::Shape::size()> dims;
- _sum[threadIdx.x] = 0;
+ _sum[threadIdx.x] = aggInit;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
@@ -109,7 +109,7 @@ __global__ void gAddReduce(Functor functor,
functional::Array<int, K> indices;
for(int i = 0; i < K; ++i)
indices[i] = ins[i].shape().bindex(dims);
- _sum[threadIdx.x] += functional::apply(functor, ins, indices);
+ _sum[threadIdx.x] = aggFunctor(_sum[threadIdx.x], functional::applyWithCast<AccType>(functor, ins, indices));// casts to AccType before applying functor which then performs operation in AccType
}
}
}
@@ -119,18 +119,20 @@ __global__ void gAddReduce(Functor functor,
__syncthreads();
int skip = (len + 1) >> 1;
if(threadIdx.x < (len >> 1)) {
- _sum[threadIdx.x] += _sum[threadIdx.x + skip];
+ _sum[threadIdx.x] = aggFunctor(_sum[threadIdx.x], _sum[threadIdx.x + skip]);
}
len = (len + 1) >> 1;
}
__syncthreads();
- out[j] += _sum[0] * scale;
+ if(threadIdx.x == 0) // only set value when in thread 0 in block
+ out[j] = aggFunctor(out[j], (T)(_sum[0] * scale));
}
+ __syncthreads();
}
}
-template <class Functor, class... Tensors>
-void Add(Functor functor, float scale, marian::Tensor out, Tensors... tensors) {
+template <typename T, typename AccType, class Functor, class AggFunctor, class... Tensors>
+void AggregateTyped(Functor functor, AccType aggInit, AggFunctor aggFunctor, AccType scale, marian::Tensor out, Tensors... tensors) {
cudaSetDevice(out->getDeviceId().no);
auto full = marian::Shape::broadcast({out, tensors...});
@@ -139,37 +141,56 @@ void Add(Functor functor, float scale, marian::Tensor out, Tensors... tensors) {
constexpr size_t K = sizeof...(Tensors);
- functional::Tensor<float> gOut = out;
- functional::Array<functional::Tensor<float>, K> gIns = {tensors...};
+ functional::Tensor<T> gOut = out;
+ functional::Array<functional::Tensor<T>, K> gIns = {tensors...};
- if(full.back() != 1 && out->shape().back() == 1) {
- size_t m = full.elements() / length;
- size_t k = full.back();
+ if(out->shape().elements() == 1) { // reduce everything into a single element
+ AggregateAll<T, AccType>(nullptr, functor, aggInit, aggFunctor, scale, out, tensors...); // @TODO: pass allocator in here, currently uses cudaMalloc
+ } else if(full.back() != 1 && out->shape().back() == 1 && full.elements() / full.back() == length) { // element number of out and full shape on axis that are not reduced must match
+ size_t m = full.elements() / full.back(); // how many rows are we iterating over?
+ size_t k = full.back(); // how many columns are being reduced to 1 in each row?
- int blocks = std::min(MAX_BLOCKS, (int)m);
+ int blocks = std::min(MAX_BLOCKS, (int)m);
int threads = std::min(MAX_THREADS, (int)k);
- int shared = sizeof(float) * threads * 2;
-
- gAddReduce<<<blocks, threads, shared>>>(functor, full, gOut, gIns, scale);
-
+ int shared = sizeof(AccType) * threads;
+ gAggregateReduce<K, Functor, AggFunctor, T, AccType><<<blocks, threads, shared>>>(functor, aggInit, aggFunctor, full, gOut, gIns, scale);
} else if(out->shape() == full) {
int threads = std::min(MAX_THREADS, length);
- int blocks
- = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
+ int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
bool broadcast = false;
for(int i = 0; i < K; ++i)
broadcast = broadcast || gOut.shape() != gIns[i].shape();
- gAddEqual<<<blocks, threads>>>(functor, gOut, gIns, scale, broadcast);
+ gAggregateEqual<<<blocks, threads>>>(functor, aggFunctor, gOut, gIns, scale, broadcast);
} else {
int threads = std::min(MAX_THREADS, length);
- int blocks
- = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
+ int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
- gAddGeneric<<<blocks, threads>>>(functor, full, gOut, gIns, scale);
+ gAggregateGeneric<<<blocks, threads>>>(functor, aggInit, aggFunctor, full, gOut, gIns, scale);
}
}
+template <class Functor, class AggFunctor, class... Tensors>
+void Aggregate(Functor functor, float aggInit, AggFunctor aggFunctor, float scale, marian::Tensor out, Tensors... tensors) {
+ if(out->type() == Type::float32) {
+ AggregateTyped<float, float>(functor, aggInit, aggFunctor, scale, out, tensors...);
+ } else if(out->type() == Type::float16) {
+#if COMPILE_FP16
+ AggregateTyped<half, float>(functor, aggInit, aggFunctor, scale, out, tensors...);
+#else
+ ABORT("FP16 not supported with current hardware or CUDA version");
+#endif
+ } else {
+ ABORT("Type {} not yet supported", out->type());
+ }
+}
+
+template <class Functor, class... Tensors>
+void Add(Functor functor, float scale, marian::Tensor out, Tensors... tensors) {
+ auto addFunctor = functional::_1 + functional::_2;
+ Aggregate(functor, 0.f, addFunctor, scale, out, tensors...);
+}
+
#include "tensors/gpu/add.inc"
} // namespace gpu
} // namespace marian
diff --git a/src/tensors/gpu/add.h b/src/tensors/gpu/add.h
index e5e22d88..33ef0966 100644..100755
--- a/src/tensors/gpu/add.h
+++ b/src/tensors/gpu/add.h
@@ -8,5 +8,8 @@ namespace gpu {
template <class Functor, class... Tensors>
void Add(Functor functor, float scale, marian::Tensor out, Tensors... tensors);
+
+template <class Functor, class AggFunctor, class... Tensors>
+void Aggregate(Functor functor, float initAgg, AggFunctor aggFunctor, float scale, marian::Tensor out, Tensors... tensors);
}
} // namespace marian
diff --git a/src/tensors/gpu/add.inc b/src/tensors/gpu/add.inc
index 27f35b95..98723b9d 100644..100755
--- a/src/tensors/gpu/add.inc
+++ b/src/tensors/gpu/add.inc
@@ -1,3 +1,4 @@
+// see element.inc for instructions on how to maintain this
using namespace functional;
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
@@ -14,11 +15,21 @@ template void Add<BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Div
template void Add<BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::sPReLUBack, Assignee<2>, Capture>>, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::sPReLUBack, Assignee<2>, Capture>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::sReLUBack, Assignee<2>>>, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::sReLUBack, Assignee<2>>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<2>>>, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<2>>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
-template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2>>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2> >, float, std::shared_ptr<marian::TensorBase>, marian::Tensor, marian::Tensor);
+template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2>>, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<2>, Assignee<3>>, Assignee<1>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<2>, Assignee<3>>, Assignee<1>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Geq, Assignee<2>, Assignee<3>>, Assignee<1>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Geq, Assignee<2>, Assignee<3>>, Assignee<1>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<3>, Assignee<2>>>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<3>, Assignee<2>>>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Gt, Assignee<2>, Assignee<3>>, Assignee<1>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Gt, Assignee<2>, Assignee<3>>, Assignee<1>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Leq, Assignee<2>, Assignee<3>>, Assignee<1>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Leq, Assignee<2>, Assignee<3>>, Assignee<1>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
template void Add<BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, marian::Tensor, marian::Tensor, marian::Tensor>(BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
-template void Add<BinaryFunctor<elem::Div, Assignee<1>, Assignee<2> >, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase> >(BinaryFunctor<elem::Div, Assignee<1>, Assignee<2> >, float, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>, std::shared_ptr<marian::TensorBase>);
+template void Add<BinaryFunctor<elem::Div, Assignee<1>, Assignee<2>>, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Div, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Add<BinaryFunctor<elem::Mult, Assignee<1>, Assignee<1>>, marian::Tensor >(BinaryFunctor<elem::Mult, Assignee<1>, Assignee<1>>, float, marian::Tensor, marian::Tensor);
+template void Aggregate<Assignee<1>, BinaryFunctor<elem::Min, Assignee<1>, Assignee<2>>, marian::Tensor >(Assignee<1>, float, BinaryFunctor<elem::Min, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void Aggregate<Assignee<1>, BinaryFunctor<elem::Max, Assignee<1>, Assignee<2>>, marian::Tensor >(Assignee<1>, float, BinaryFunctor<elem::Max, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void Aggregate<Assignee<1>, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, marian::Tensor >(Assignee<1>, float, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void Aggregate<Assignee<1>, BinaryFunctor<elem::LogAddExp, Assignee<1>, Assignee<2>>, marian::Tensor >(Assignee<1>, float, BinaryFunctor<elem::LogAddExp, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Eq, Assignee<1>, Assignee<2>>, Assignee<3>>, marian::Tensor, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Eq, Assignee<1>, Assignee<2>>, Assignee<3>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Add<BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, marian::Tensor, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Add<BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, Assignee<3>>, marian::Tensor, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, Assignee<3>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Add<BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Add<BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<3>>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Mult, Capture, Assignee<2>>>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Capture, Assignee<3>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor >(BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<3>>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Mult, Capture, Assignee<2>>>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Capture, Assignee<3>>>>>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor); \ No newline at end of file
diff --git a/src/tensors/gpu/add_all.cu b/src/tensors/gpu/add_all.cu
new file mode 100644
index 00000000..ad3ac252
--- /dev/null
+++ b/src/tensors/gpu/add_all.cu
@@ -0,0 +1,116 @@
+#include "tensors/gpu/add_all.h"
+#include "tensors/gpu/cuda_helpers.h"
+#include "functional/functional.h"
+#include "tensors/tensor_operators.h"
+#include "3rd_party/reduce_all.h" // only works with CUDA >9.0, we are dropping CUDA 8.0 support, also changed in CMakeLists.txt
+
+namespace marian {
+
+#if COMPILE_FP16
+// local overload to determine tensor type
+template <> inline Type typeId<half>() { return Type::float16; }
+#endif
+
+// Version with variadic template arguments, called by version with explicit arguments below
+template <typename T, typename AccType, class Functor, class AggFunctor, class... Tensors>
+void AggregateAllVar(Ptr<Allocator> allocator,
+ Functor functor,
+ AccType aggInit,
+ AggFunctor aggFunctor,
+ AccType scale,
+ Tensor out,
+ const Tensors... tensors) {
+ cudaSetDevice(out->getDeviceId().no);
+
+ static_assert(CUDA_VERSION >= 9000, "Marian requires CUDA_VERSION >= 9000 (9.0)");
+
+ constexpr size_t K = sizeof...(Tensors); // obtain arity K of tensors...
+ functional::Array<functional::Tensor<T>, K> gIns = {tensors...}; // convert to array of K objects of type functional::Tensor<T>
+ functional::Shape full = marian::Shape::broadcast({tensors...}); // compute maximal broadcasted shape
+
+ int size = full.elements();
+ int threads = (size < MAX_THREADS * 2) ? nextPow2((size + 1) / 2) : MAX_THREADS; // suggested in NVidia example for the all_reduce kernel
+ int blocks = std::min(MAX_BLOCKS, (size + (threads * 2 - 1)) / (threads * 2)); // suggested in NVidia example for the all_reduce kernel
+
+ // The all_reduce kernel by nivida needs to perform multiple passes if the number of blocks needed to perform the reduction is larger than 1.
+ // Here we allocate the memory for the intermediate reductions for each block.
+ Tensor blockMem;
+ if(blocks > 1 || out->type() != typeId<AccType>()) { // if the out tensor does not have elementType AccType we need to allocate and convert later
+ MemoryPiece::PtrType temporaryMemory;
+ if(allocator) {
+ temporaryMemory = allocator->alloc<AccType>(blocks);
+ } else { // @TODO: get rid of this branch
+ uint8_t* temporaryMemoryPtr = 0;
+ CUDA_CHECK(cudaMalloc(&temporaryMemoryPtr, sizeof(AccType) * blocks));
+ temporaryMemory = MemoryPiece::New(temporaryMemoryPtr, sizeof(AccType) * blocks); // @TODO: consider implementing MemoryPiece::cudaMalloc<T>(size) for managed memory
+ }
+ blockMem = TensorBase::New(temporaryMemory,
+ Shape({blocks}),
+ typeId<AccType>(),
+ out->getBackend());
+ blockMem->set(aggInit); // set temporary memory to aggInit
+ }
+ else { // we are reducing into a single element now and the type matches, just use out as memory
+ blockMem = out; // do not set final output memory as we might be summing gradients... needs to be handled outside this function
+ }
+
+ functional::Tensor<AccType> gBlockMem = blockMem;
+ reduceSinglePass<T, AccType>(functor, aggInit, aggFunctor, scale, full, /*out=*/gBlockMem, /*in=*/gIns, threads, blocks); // First pass reduction into intermediate memory
+
+ // If we actually needed more than one block to perform the first pass reduction, recursively run a second pass reduction over block memory until block memory has size 1.
+ if(blocks > 1) {
+ using namespace functional;
+ auto identity = _1; // transformation was done in first pass, hence only identity
+ AggregateAll<AccType, AccType>(allocator, identity, aggInit, aggFunctor, scale, out, /*in=*/blockMem); // Reducing AccType in AccType now (meta-reduction)
+ } else if(out->type() != typeId<AccType>()) { // it's only a single block, but we need to convert to different type, as mentioned above
+ CopyCast(out, blockMem);
+ }
+
+ if(blockMem != out) {
+ // Free temporary memory whether allocated in allocator or via cudaMalloc
+ if(allocator)
+ allocator->free(blockMem->memory());
+ else if(blockMem->memory()->data())
+ CUDA_CHECK(cudaFree(blockMem->memory()->data()));
+ }
+}
+
+template <typename T, typename AccType, class Functor, class AggFunctor>
+void AggregateAll(Ptr<Allocator> allocator,
+ Functor functor,
+ AccType aggInit,
+ AggFunctor aggFunctor,
+ AccType scale,
+ Tensor out,
+ const Tensor in1) {
+ AggregateAllVar<T, AccType>(allocator, functor, aggInit, aggFunctor, scale, out, in1);
+}
+
+template <typename T, typename AccType, class Functor, class AggFunctor>
+void AggregateAll(Ptr<Allocator> allocator,
+ Functor functor,
+ AccType aggInit,
+ AggFunctor aggFunctor,
+ AccType scale,
+ Tensor out,
+ const Tensor in1,
+ const Tensor in2) {
+ AggregateAllVar<T, AccType>(allocator, functor, aggInit, aggFunctor, scale, out, in1, in2);
+}
+
+template <typename T, typename AccType, class Functor, class AggFunctor>
+void AggregateAll(Ptr<Allocator> allocator,
+ Functor functor,
+ AccType aggInit,
+ AggFunctor aggFunctor,
+ AccType scale,
+ Tensor out,
+ const Tensor in1,
+ const Tensor in2,
+ const Tensor in3) {
+ AggregateAllVar<T, AccType>(allocator, functor, aggInit, aggFunctor, scale, out, in1, in2, in3);
+}
+
+#include "tensors/gpu/add_all.inc"
+
+} \ No newline at end of file
diff --git a/src/tensors/gpu/add_all.h b/src/tensors/gpu/add_all.h
new file mode 100644
index 00000000..2e37fd49
--- /dev/null
+++ b/src/tensors/gpu/add_all.h
@@ -0,0 +1,87 @@
+#pragma once
+
+// This header file provides wrappers around NVidia's reduce_all kernel with our custom aggregation functionality
+// This kernel reduces a tensor into a single value. We have modified it to allow for different types of aggregations
+// like summing or max etc.
+
+#include "tensors/gpu/cuda_helpers.h"
+#include "tensors/tensor.h"
+#include "tensors/allocator.h"
+#include "functional/tensor.h"
+#include "tensors/tensor_operators.h"
+
+namespace marian {
+
+// These function declarations are repeated as template specialization with variadic template arguments does not seem to work.
+// Here I am just creating version for 1, 2, and 3 arguments. To be extended if required.
+template <typename T, typename AccType, class Functor, class AggFunctor>
+void AggregateAll(Ptr<Allocator> allocator,
+ Functor functor,
+ AccType aggInit,
+ AggFunctor aggFunctor,
+ AccType scale,
+ Tensor out,
+ const Tensor in1);
+
+template <typename T, typename AccType, class Functor, class AggFunctor>
+void AggregateAll(Ptr<Allocator> allocator,
+ Functor functor,
+ AccType aggInit,
+ AggFunctor aggFunctor,
+ AccType scale,
+ Tensor out,
+ const Tensor in1,
+ const Tensor in2);
+
+template <typename T, typename AccType, class Functor, class AggFunctor>
+void AggregateAll(Ptr<Allocator> allocator,
+ Functor functor,
+ AccType aggInit,
+ AggFunctor aggFunctor,
+ AccType scale,
+ Tensor out,
+ const Tensor in1,
+ const Tensor in2,
+ const Tensor in3);
+
+// Aggregates all values into a single tensor and returns the value of that tensor as a float
+// This does a GPU to CPU memory copy via TensorBase::scalar().
+// Used currently only for L2Norm computation
+template <typename T, typename AccType, class Functor, class AggFunctor, class... Tensors>
+AccType AggregateAllAndReturn(Ptr<Allocator> allocator,
+ Functor functor,
+ AccType aggInit,
+ AggFunctor aggFunctor,
+ AccType scale,
+ const Tensors... tensors) {
+ MemoryPiece::PtrType temporaryMemory;
+ if(allocator) {
+ temporaryMemory = allocator->alloc<AccType>(1);
+ } else { // @TODO: get rid of this branch
+ uint8_t* temporaryMemoryPtr = 0;
+ CUDA_CHECK(cudaMalloc(&temporaryMemoryPtr, sizeof(AccType)));
+ temporaryMemory = MemoryPiece::New(temporaryMemoryPtr, sizeof(AccType));
+ }
+
+ std::tuple<Tensors...> in(tensors...);
+
+ // Create a temporary tensor of size 1 to reduce into
+ auto out = TensorBase::New(temporaryMemory,
+ Shape({1}),
+ typeId<AccType>(),
+ std::get<0>(in)->getBackend());
+ out->set(aggInit); // init to aggInit
+
+ AggregateAll<T, AccType>(allocator, functor, aggInit, aggFunctor, scale, out, tensors...);
+
+ AccType outScalar = out->template scalar<AccType>(); // convert to float also if other underlying type
+
+ if(allocator)
+ allocator->free(out->memory());
+ else if(out->memory()->data()) // @TODO: get rid of this branch
+ CUDA_CHECK(cudaFree(out->memory()->data()));
+
+ return outScalar;
+}
+
+} \ No newline at end of file
diff --git a/src/tensors/gpu/add_all.inc b/src/tensors/gpu/add_all.inc
new file mode 100644
index 00000000..73b0bda9
--- /dev/null
+++ b/src/tensors/gpu/add_all.inc
@@ -0,0 +1,71 @@
+// see element.inc for instructions on how to maintain this
+using namespace functional;
+
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Assignee<2>, Assignee<2>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Assignee<2>, Assignee<2>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Neg, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Mult, Assignee<3>, Assignee<3>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Neg, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Mult, Assignee<3>, Assignee<3>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, UnaryFunctor<elem::Neg, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, UnaryFunctor<elem::Neg, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, Assignee<3>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<3>>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, Assignee<3>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<3>>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, Assignee<2>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, Assignee<2>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Div, Capture, Assignee<2>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Div, Capture, Assignee<2>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::sPReLUBack, Assignee<2>, Capture>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::sPReLUBack, Assignee<2>, Capture>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::sReLUBack, Assignee<2>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::sReLUBack, Assignee<2>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<2>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<2>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<2>, Assignee<3>>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<2>, Assignee<3>>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Geq, Assignee<2>, Assignee<3>>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Geq, Assignee<2>, Assignee<3>>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<3>, Assignee<2>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<3>, Assignee<2>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Gt, Assignee<2>, Assignee<3>>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Gt, Assignee<2>, Assignee<3>>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Leq, Assignee<2>, Assignee<3>>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Leq, Assignee<2>, Assignee<3>>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Div, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Div, Assignee<1>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, Assignee<1>, BinaryFunctor<elem::Min, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::Min, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, Assignee<1>, BinaryFunctor<elem::Max, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::Max, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, Assignee<1>, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, Assignee<1>, BinaryFunctor<elem::LogAddExp, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::LogAddExp, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Eq, Assignee<1>, Assignee<2>>, Assignee<3>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Eq, Assignee<1>, Assignee<2>>, Assignee<3>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, Assignee<3>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, Assignee<3>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<3>>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Mult, Capture, Assignee<2>>>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Capture, Assignee<3>>>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<3>>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Mult, Capture, Assignee<2>>>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Capture, Assignee<3>>>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<float, float, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+
+#if COMPILE_FP16
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Div, Capture, Assignee<1>>>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Assignee<2>, Assignee<2>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Assignee<2>, Assignee<2>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Neg, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Mult, Assignee<3>, Assignee<3>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Neg, Assignee<1>>, Assignee<2>>, BinaryFunctor<elem::Mult, Assignee<3>, Assignee<3>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, UnaryFunctor<elem::Neg, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, UnaryFunctor<elem::Neg, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, Assignee<3>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<3>>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, Assignee<3>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<3>>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, Assignee<2>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, Assignee<2>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Div, Capture, Assignee<2>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Div, Capture, Assignee<2>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::sPReLUBack, Assignee<2>, Capture>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::sPReLUBack, Assignee<2>, Capture>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::sReLUBack, Assignee<2>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::sReLUBack, Assignee<2>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<2>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Minus, Capture, Assignee<2>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Bump, Assignee<1>, Capture>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<2>, Assignee<3>>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<2>, Assignee<3>>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Geq, Assignee<2>, Assignee<3>>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Geq, Assignee<2>, Assignee<3>>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<3>, Assignee<2>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<3>, Assignee<2>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Gt, Assignee<2>, Assignee<3>>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Gt, Assignee<2>, Assignee<3>>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Leq, Assignee<2>, Assignee<3>>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Leq, Assignee<2>, Assignee<3>>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Div, Assignee<1>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Div, Assignee<1>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, Assignee<1>, BinaryFunctor<elem::Min, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::Min, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, Assignee<1>, BinaryFunctor<elem::Max, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::Max, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, Assignee<1>, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, Assignee<1>, BinaryFunctor<elem::LogAddExp, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::LogAddExp, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Eq, Assignee<1>, Assignee<2>>, Assignee<3>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Eq, Assignee<1>, Assignee<2>>, Assignee<3>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, UnaryFunctor<elem::Exp, BinaryFunctor<elem::Minus, Assignee<2>, Assignee<3>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, Assignee<3>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<2>>, Assignee<3>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Assignee<1>, Capture>, Assignee<2>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<3>>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Mult, Capture, Assignee<2>>>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Capture, Assignee<3>>>>>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<3>>, BinaryFunctor<elem::Mult, UnaryFunctor<elem::Sigmoid, BinaryFunctor<elem::Mult, Capture, Assignee<2>>>, BinaryFunctor<elem::Minus, Capture, BinaryFunctor<elem::Mult, Capture, Assignee<3>>>>>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, Assignee<1>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, Assignee<1>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+template void AggregateAll<__half, float, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<1>>, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>>(std::shared_ptr<Allocator>, BinaryFunctor<elem::Mult, Assignee<1>, Assignee<1>>, float, BinaryFunctor<elem::Plus, Assignee<1>, Assignee<2>>, float, marian::Tensor, marian::Tensor);
+#endif \ No newline at end of file
diff --git a/src/tensors/gpu/algorithm.cu b/src/tensors/gpu/algorithm.cu
index bdf66bac..926e9deb 100755..100644
--- a/src/tensors/gpu/algorithm.cu
+++ b/src/tensors/gpu/algorithm.cu
@@ -20,13 +20,12 @@ template void copy<int8_t>(Ptr<Backend>, const int8_t*, const int8_t*, int8_t*);
template void copy<int16_t>(Ptr<Backend>, const int16_t*, const int16_t*, int16_t*);
template void copy<int32_t>(Ptr<Backend>, const int32_t*, const int32_t*, int32_t*);
template void copy<int64_t>(Ptr<Backend>, const int64_t*, const int64_t*, int64_t*);
-
template void copy<uint8_t>(Ptr<Backend>, const uint8_t*, const uint8_t*, uint8_t*);
template void copy<uint16_t>(Ptr<Backend>, const uint16_t*, const uint16_t*, uint16_t*);
template void copy<uint32_t>(Ptr<Backend>, const uint32_t*, const uint32_t*, uint32_t*);
template void copy<uint64_t>(Ptr<Backend>, const uint64_t*, const uint64_t*, uint64_t*);
-
template void copy<char>(Ptr<Backend>, const char*, const char*, char*);
+template void copy<float16>(Ptr<Backend>, const float16*, const float16*, float16*);
template void copy<float>(Ptr<Backend>, const float*, const float*, float*);
template void copy<double>(Ptr<Backend>, const double*, const double*, double*);
// clang-format on
@@ -55,6 +54,23 @@ void fill(Ptr<Backend> backend, T* begin, T* end, T value) {
CUDA_CHECK(cudaStreamSynchronize(0));
}
+template <>
+void fill<float16>(Ptr<Backend> backend, float16* begin, float16* end, float16 value) {
+ int size = end - begin;
+ if (size == 0)
+ return;
+#if COMPILE_FP16
+ CUDA_CHECK(cudaSetDevice(backend->getDeviceId().no));
+ int threadsPerBlock = std::min(MAX_THREADS, size);
+ int blocks = (size / threadsPerBlock) + (size % threadsPerBlock != 0); // @TODO: (size+threadsPerBlock-1)/threadsPerBlock or CeilDiv(a,b)
+ gFill<<<blocks, threadsPerBlock>>>((__half*)begin, size, (__half)value);
+ CUDA_CHECK(cudaStreamSynchronize(0));
+#else
+ ABORT("FP16 not supported with current hardware or CUDA version");
+#endif
+}
+
+template void fill<bool>(Ptr<Backend>, bool*, bool*, bool);
template void fill<int8_t>(Ptr<Backend>, int8_t*, int8_t*, int8_t);
template void fill<int16_t>(Ptr<Backend>, int16_t*, int16_t*, int16_t);
template void fill<int32_t>(Ptr<Backend>, int32_t*, int32_t*, int32_t);
@@ -84,7 +100,7 @@ __global__ void gSwap(T* d_v1, T* d_v2, int size) {
if(index < size) {
T temp = d_v1[index];
d_v1[index] = d_v2[index];
- d_v2[index] = temp;
+ d_v2[index] = temp;
}
}
@@ -93,7 +109,7 @@ void swap_ranges(Ptr<Backend> backend, T* begin, T* end, T* dest) {
int size = end - begin;
if (size == 0)
return;
-
+
CUDA_CHECK(cudaSetDevice(backend->getDeviceId().no));
int threadsPerBlock = std::min(MAX_THREADS, size);
int blocks = (size / threadsPerBlock) + (size % threadsPerBlock != 0); // @TODO: (size+threadsPerBlock-1)/threadsPerBlock or CeilDiv(a,b)
@@ -101,7 +117,25 @@ void swap_ranges(Ptr<Backend> backend, T* begin, T* end, T* dest) {
CUDA_CHECK(cudaStreamSynchronize(0));
}
+template <>
+void swap_ranges<float16>(Ptr<Backend> backend, float16* begin, float16* end, float16* dest) {
+ int size = end - begin;
+ if (size == 0)
+ return;
+
+#if COMPILE_FP16
+ CUDA_CHECK(cudaSetDevice(backend->getDeviceId().no));
+ int threadsPerBlock = std::min(MAX_THREADS, size);
+ int blocks = (size / threadsPerBlock) + (size % threadsPerBlock != 0); // @TODO: (size+threadsPerBlock-1)/threadsPerBlock or CeilDiv(a,b)
+ gSwap<<<blocks, threadsPerBlock>>>((__half*)begin, (__half*)dest, size);
+ CUDA_CHECK(cudaStreamSynchronize(0));
+#else
+ ABORT("FP16 not supported with current hardware or CUDA version");
+#endif
+}
+
// clang-format off
+template void swap_ranges<char>(Ptr<Backend>, char*, char*, char*);
template void swap_ranges<int8_t>(Ptr<Backend>, int8_t*, int8_t*, int8_t*);
template void swap_ranges<int16_t>(Ptr<Backend>, int16_t*, int16_t*, int16_t*);
template void swap_ranges<int32_t>(Ptr<Backend>, int32_t*, int32_t*, int32_t*);
@@ -112,7 +146,6 @@ template void swap_ranges<uint16_t>(Ptr<Backend>, uint16_t*, uint16_t*, uint16_t
template void swap_ranges<uint32_t>(Ptr<Backend>, uint32_t*, uint32_t*, uint32_t*);
template void swap_ranges<uint64_t>(Ptr<Backend>, uint64_t*, uint64_t*, uint64_t*);
-template void swap_ranges<char>(Ptr<Backend>, char*, char*, char*);
template void swap_ranges<float>(Ptr<Backend>, float*, float*, float*);
template void swap_ranges<double>(Ptr<Backend>, double*, double*, double*);
// clang-format on
diff --git a/src/tensors/gpu/backend.h b/src/tensors/gpu/backend.h
index ab1968c1..044c344f 100755..100644
--- a/src/tensors/gpu/backend.h
+++ b/src/tensors/gpu/backend.h
@@ -1,46 +1,68 @@
#pragma once
#include "common/config.h"
-#include "tensors/backend.h" // note: this is one folder up
+#include "tensors/backend.h" // note: this is one folder up
#include "tensors/gpu/cuda_helpers.h"
+#include "common/logging.h"
-#include <cuda.h>
#include <cublas_v2.h>
+#include <cuda.h>
#include <curand.h>
+#include <cusparse.h>
namespace marian {
namespace gpu {
+// @TODO: in the future this should pobably become a fully fledged CudaInfo class with many attributes
+struct CudaCompute {
+ int major;
+ int minor;
+};
+
class Backend : public marian::Backend {
+private:
+ void setCudaComputeCapability() {
+ CUDA_CHECK(cudaDeviceGetAttribute(&compute_.major, cudaDevAttrComputeCapabilityMajor, (int)deviceId_.no));
+ CUDA_CHECK(cudaDeviceGetAttribute(&compute_.minor, cudaDevAttrComputeCapabilityMinor, (int)deviceId_.no));
+ }
+
public:
Backend(DeviceId deviceId, size_t seed) : marian::Backend(deviceId, seed) {
setDevice();
- setHandles();
+ cublasCreate(&cublasHandle_);
+ cusparseCreate(&cusparseHandle_);
+ setCudaComputeCapability();
}
~Backend() {
setDevice();
+ cusparseDestroy(cusparseHandle_);
cublasDestroy(cublasHandle_);
}
- void setDevice() override { cudaSetDevice((int)deviceId_.no); }
+ void setDevice() override { CUDA_CHECK(cudaSetDevice((int)deviceId_.no)); }
- void synchronize() override { cudaStreamSynchronize(0); }
+ void synchronize() override { CUDA_CHECK(cudaStreamSynchronize(0)); }
cublasHandle_t getCublasHandle() { return cublasHandle_; }
+ cusparseHandle_t getCusparseHandle() { return cusparseHandle_; }
-private:
- cublasHandle_t cublasHandle_;
+ CudaCompute getCudaComputeCapability() { return compute_; }
- void setHandles() {
- cublasHandle_ = create_handle();
+ // for CPU, sets to use optimized code for inference.
+ // for GPU, this is invalid. for gpu, isOptimized() function always returns false.
+ void setOptimized(bool optimize) override {
+ LOG_ONCE(info, "setOptimized() not supported for GPU_{}", optimize);
}
-
- cublasHandle_t create_handle() {
- cublasHandle_t cublasHandle;
- cublasCreate(&cublasHandle);
- return cublasHandle;
+
+ bool isOptimized() override {
+ return false;
}
+
+private:
+ cublasHandle_t cublasHandle_;
+ cusparseHandle_t cusparseHandle_;
+ CudaCompute compute_;
};
} // namespace gpu
} // namespace marian
diff --git a/src/tensors/gpu/cuda_helpers.h b/src/tensors/gpu/cuda_helpers.h
index ba890490..d51b257d 100755
--- a/src/tensors/gpu/cuda_helpers.h
+++ b/src/tensors/gpu/cuda_helpers.h
@@ -1,6 +1,22 @@
#pragma once
#include "common/logging.h"
-#include "cuda_runtime.h"
+#include "common/types.h"
+
+#include <cuda_runtime.h>
+
+#if COMPILE_FP16
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4505) // unreferenced local function has been removed
+#endif
+#include <cuda_fp16.h>
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+#endif
+
+// template <> inline bool matchType<__half>(Type type) { return type == Type::float16; }
+// template <> inline std::string request<__half>() { return "float16"; }
// fixes a missing constant in CUDA device code
#define CUDA_FLT_MAX 1.70141e+38; // note: 'static __constant__' causes a warning on gcc; non-static fails CUDA, so #define instead
@@ -13,6 +29,12 @@ const int MAX_BLOCKS = 65535;
"CUDA error {} '{}' - {}:{}: {}", rc, cudaGetErrorString(rc), __FILE__, __LINE__, #expr); \
} while(0)
+#define CUBLAS_CHECK(expr) do { \
+ cublasStatus_t rc = (expr); \
+ ABORT_IF(rc != CUBLAS_STATUS_SUCCESS, \
+ "Cublas Error: {} - {}:{}: {}", rc, __FILE__, __LINE__, #expr); \
+} while(0)
+
#define CUSPARSE_CHECK(expr) do { \
cusparseStatus_t rc = (expr); \
ABORT_IF(rc != CUSPARSE_STATUS_SUCCESS, \
diff --git a/src/tensors/gpu/device.cu b/src/tensors/gpu/device.cu
index a82b387e..25997c75 100755..100644
--- a/src/tensors/gpu/device.cu
+++ b/src/tensors/gpu/device.cu
@@ -8,12 +8,12 @@ namespace marian {
namespace gpu {
Device::~Device() {
- // Note: The CUDA_CHECKs here are not throwing, but will terminate the program.
- CUDA_CHECK(cudaSetDevice(deviceId_.no));
+ // No CUDA error checking as this is a destructor and we cannot do anything about errors anyway.
+ cudaSetDevice(deviceId_.no);
if(data_) {
- CUDA_CHECK(cudaFree(data_));
+ cudaFree(data_);
}
- CUDA_CHECK(cudaDeviceSynchronize());
+ cudaDeviceSynchronize();
}
void Device::reserve(size_t size) {
diff --git a/src/tensors/gpu/element.cu b/src/tensors/gpu/element.cu
index e48abc07..8525b71b 100644..100755
--- a/src/tensors/gpu/element.cu
+++ b/src/tensors/gpu/element.cu
@@ -4,15 +4,16 @@
#include "functional/functional.h"
#include "functional/tensor.h"
#include "functional/tmp.h"
+
#include "tensors/gpu/cuda_helpers.h"
namespace marian {
namespace gpu {
-template <size_t K, bool broadcast, class Functor>
+template <size_t K, bool broadcast, class Functor, typename T>
__global__ void gElement(
Functor functor,
- functional::Array<functional::Tensor<float>, K> tensors) {
+ functional::Array<functional::Tensor<T>, K> tensors) {
int length = tensors[0].shape().elements();
functional::Array<int, functional::Shape::size()> dims;
functional::Array<int, K> indices;
@@ -28,32 +29,53 @@ __global__ void gElement(
indices[i] = tensors[i].shape().bindex(dims);
}
- tensors[0][index] = functional::apply(functor, tensors, indices);
+ tensors[0].data()[index] = functional::apply(functor, tensors, indices);
}
}
}
-template <class Functor, class... Tensors>
-void Element(Functor functor, Tensor out, Tensors... tensors) {
- cudaSetDevice(out->getDeviceId().no);
- constexpr size_t K = sizeof...(tensors) + 1;
- functional::Array<functional::Tensor<float>, K> gTensors = {out, tensors...};
+template <typename T, class Functor, class... Tensors>
+void ElementTyped(Functor functor, Tensor out, Tensors... tensors) {
+ //matchOrAbort<T>(out->type()); // @TODO: figure out undefined reference
+
+ cudaSetDevice(out->getDeviceId().no);
- int length = gTensors[0].shape().elements();
+ int length = out->shape().elements();
int threads = std::min(MAX_THREADS, length);
int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
+ constexpr size_t K = sizeof...(tensors) + 1;
+ functional::Array<functional::Tensor<T>, K> gTensors = {out, tensors...};
+
bool broadcast = false;
for(int i = 1; i < K; ++i)
broadcast = broadcast || gTensors[0].shape() != gTensors[i].shape();
-
if(broadcast)
gElement<K, true><<<blocks, threads>>>(functor, gTensors);
else
gElement<K, false><<<blocks, threads>>>(functor, gTensors);
}
+template <class Functor, class... Tensors>
+void Element(Functor functor, Tensor out, Tensors... tensors) {
+ checkCommonType(out, tensors...);
+
+ if(out->type() == Type::float32) {
+ ElementTyped<float>(functor, out, tensors...);
+ } else if(out->type() == Type::float16) {
+#if COMPILE_FP16
+ ElementTyped<__half>(functor, out, tensors...);
+#else
+ ABORT("FP16 not supported with chosen current hardware or CUDA version");
+#endif
+ } else if(out->type() == Type::float64) {
+ ElementTyped<double>(functor, out, tensors...);
+ } else {
+ ABORT("Type {} not yet supported", out->type());
+ }
+}
+
#include "tensors/gpu/element.inc"
} // namespace gpu
} // namespace marian
diff --git a/src/tensors/gpu/element.inc b/src/tensors/gpu/element.inc
index 50ae9793..00aff3d9 100644..100755
--- a/src/tensors/gpu/element.inc
+++ b/src/tensors/gpu/element.inc
@@ -39,12 +39,32 @@ template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunc
template void Element<Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<2>>, Capture>, Capture, Capture>>, marian::Tensor>(Assign<Var<1>, TernaryFunctor<elem::IfThenElse, BinaryFunctor<elem::Leq, UnaryFunctor<elem::Abs, Assignee<2>>, Capture>, Capture, Capture>>, marian::Tensor, marian::Tensor);
template void Element<Assign<Var<1>, BinaryFunctor<elem::Clip, Assignee<2>, Capture>>, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Clip, Assignee<2>, Capture>>, marian::Tensor, marian::Tensor);
template void Element<Assign<Var<1>, BinaryFunctor<elem::LogAddExp, Assignee<2>, Assignee<3>>>, marian::Tensor, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::LogAddExp, Assignee<2>, Assignee<3>>>, marian::Tensor, marian::Tensor, marian::Tensor);
-template void Element<Assign<Var<1>, BinaryFunctor<elem::Maximum, Assignee<2>, Assignee<3>>>, marian::Tensor, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Maximum, Assignee<2>, Assignee<3>>>, marian::Tensor, marian::Tensor, marian::Tensor);
-template void Element<Assign<Var<1>, BinaryFunctor<elem::Minimum, Assignee<2>, Assignee<3>>>, marian::Tensor, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Minimum, Assignee<2>, Assignee<3>>>, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Element<Assign<Var<1>, BinaryFunctor<elem::Max, Assignee<2>, Assignee<3>>>, marian::Tensor, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Max, Assignee<2>, Assignee<3>>>, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Element<Assign<Var<1>, BinaryFunctor<elem::Min, Assignee<2>, Assignee<3>>>, marian::Tensor, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Min, Assignee<2>, Assignee<3>>>, marian::Tensor, marian::Tensor, marian::Tensor);
template void Element<Assign<Var<1>, BinaryFunctor<elem::Div, Assignee<1>, Capture>>>(Assign<Var<1>, BinaryFunctor<elem::Div, Assignee<1>, Capture>>, marian::Tensor);
template void Element<Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Div, Assignee<3>, Capture>>, Capture>>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>>>>>, marian::Tensor, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Div, Assignee<3>, Capture>>, Capture>>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor);
template void Element<Assign<Var<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Capture>>>(Assign<Var<1>, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Mult, Capture, Assignee<1>>, Capture>>, marian::Tensor);
template void Element<Assign<Var<1>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Lt, Assignee<1>, Capture>, Capture>>>(Assign<Var<1>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Lt, Assignee<1>, Capture>, Capture>>, marian::Tensor);
template void Element<Assign<Var<1>, UnaryFunctor<elem::Neg, UnaryFunctor<elem::Log, UnaryFunctor<elem::Neg, UnaryFunctor<elem::Log, Assignee<1>>>>>>>(Assign<Var<1>, UnaryFunctor<elem::Neg, UnaryFunctor<elem::Log, UnaryFunctor<elem::Neg, UnaryFunctor<elem::Log, Assignee<1>>>>>>, marian::Tensor);
template void Element<Assign<Var<1>, UnaryFunctor<elem::Neg, UnaryFunctor<elem::Log, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Neg, UnaryFunctor<elem::Log, BinaryFunctor<elem::Plus, Assignee<1>, Capture>>>, Capture>>>>>(Assign<Var<1>, UnaryFunctor<elem::Neg, UnaryFunctor<elem::Log, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Neg, UnaryFunctor<elem::Log, BinaryFunctor<elem::Plus, Assignee<1>, Capture>>>, Capture>>>>, marian::Tensor);
-template void Element<Assign<Var<1>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<1>, Capture>, Capture>>>(Assign<Var<1>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<1>, Capture>, Capture> >, marian::Tensor);
+template void Element<Assign<Var<1>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<1>, Capture>, Capture>>>(Assign<Var<1>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Lt, Assignee<1>, Capture>, Capture>>, marian::Tensor);
+template void Element<Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, UnaryFunctor<elem::sReLU, BinaryFunctor<elem::Minus, Assignee<3>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<4>>, Assignee<4>>>>, Capture>, BinaryFunctor<elem::Mult, Assignee<4>, Assignee<4>>>>, Capture>>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, UnaryFunctor<elem::sReLU, BinaryFunctor<elem::Minus, Assignee<3>, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<4>>, Assignee<4>>>>, Capture>, BinaryFunctor<elem::Mult, Assignee<4>, Assignee<4>>>>, Capture>>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Element<Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, UnaryFunctor<elem::sReLU, BinaryFunctor<elem::Minus, Assignee<3>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<4>>, Assignee<4>>, Capture>>>, Capture>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<4>, Assignee<4>>, Capture>>>, Capture>>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, UnaryFunctor<elem::sReLU, BinaryFunctor<elem::Minus, Assignee<3>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<4>>, Assignee<4>>, Capture>>>, Capture>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<4>, Assignee<4>>, Capture>>>, Capture>>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Element<Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<4>, Assignee<4>>, Capture>>, Capture>>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<4>, Assignee<4>>, Capture>>, Capture>>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Element<Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Div, Assignee<3>, Capture>>, Capture>>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Div, Assignee<3>, Capture>>, Capture>>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Element<Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, UnaryFunctor<elem::sReLU, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Minus, Assignee<3>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<5>>, Assignee<4>>, Capture>>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<4>>, Assignee<4>>, Capture>>>, Capture>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<4>, Assignee<4>>, Capture>>>, Capture>>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, UnaryFunctor<elem::sReLU, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Minus, Assignee<3>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<5>>, Assignee<4>>, Capture>>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, BinaryFunctor<elem::Mult, Capture, Assignee<4>>, Assignee<4>>, Capture>>>, Capture>, BinaryFunctor<elem::Div, BinaryFunctor<elem::Mult, Assignee<4>, Assignee<4>>, Capture>>>, Capture>>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void Element<Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Div, Assignee<3>, Capture>>, Capture>>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor>(Assign<Var<1>, BinaryFunctor<elem::Minus, Assignee<1>, BinaryFunctor<elem::Mult, Capture, BinaryFunctor<elem::Plus, BinaryFunctor<elem::Div, BinaryFunctor<elem::Div, Assignee<2>, Capture>, BinaryFunctor<elem::Plus, UnaryFunctor<elem::Sqrt, BinaryFunctor<elem::Div, Assignee<3>, Capture>>, Capture>>, BinaryFunctor<elem::Mult, Capture, Assignee<1>>>>>>, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor);
+template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::NEq, marian::functional::BinaryFunctor<marian::functional::elem::Eq, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::BinaryFunctor<marian::functional::elem::Gt, marian::functional::Assignee<2>, marian::functional::Assignee<3> >, marian::functional::BinaryFunctor<marian::functional::elem::Lt, marian::functional::Assignee<2>, marian::functional::Assignee<3> > >, marian::functional::Capture>, marian::functional::Capture> >, marian::Tensor, marian::Tensor >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::NEq, marian::functional::BinaryFunctor<marian::functional::elem::Eq, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::BinaryFunctor<marian::functional::elem::Gt, marian::functional::Assignee<2>, marian::functional::Assignee<3> >, marian::functional::BinaryFunctor<marian::functional::elem::Lt, marian::functional::Assignee<2>, marian::functional::Assignee<3> > >, marian::functional::Capture>, marian::functional::Capture> >, marian::Tensor, marian::Tensor, marian::Tensor);
+template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqrt, marian::functional::Assignee<1> > >>(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Sqrt, marian::functional::Assignee<1> > >, marian::Tensor);
+template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<2>, marian::functional::UnaryFunctor<marian::functional::elem::Sigmoid, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<2> > > > >, marian::Tensor >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<2>, marian::functional::UnaryFunctor<marian::functional::elem::Sigmoid, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<2> > > > >, marian::Tensor, marian::Tensor);
+template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Lt, marian::functional::Assignee<1>, marian::functional::Capture>, marian::functional::Capture>, marian::functional::Capture> >>(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Lt, marian::functional::Assignee<1>, marian::functional::Capture>, marian::functional::Capture>, marian::functional::Capture> >, marian::Tensor);
+template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<2>, marian::functional::Capture>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<1> > > > > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Minus, marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Plus, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<2>, marian::functional::Capture>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Capture, marian::functional::Assignee<1> > > > > >, marian::Tensor, marian::Tensor);
+
+// How to add new specializations:
+// When you use a new specialization, it will cause a link error of this form (example):
+// .../src/tensors/tensor_operators.h:41: undefined reference to `void marian::gpu::Element<marian::functional::Assign< ... > ( ... )'
+// To fix this, copy the line with the error message in here and:
+// - replace up to including "undefined reference to `" by "template"
+// - replace final ' by a semicolon
+// - replace 'IntrusivePtr<marian::TensorBase>' with 'marian::Tensor'
+
diff --git a/src/tensors/gpu/prod.cpp b/src/tensors/gpu/prod.cpp
new file mode 100755
index 00000000..b6852a39
--- /dev/null
+++ b/src/tensors/gpu/prod.cpp
@@ -0,0 +1,495 @@
+
+#ifdef _MSC_VER
+#pragma warning(disable: 4505) // warning C4505: '__float2half_rz': unreferenced local function has been removed (missing 'static inline')
+#endif
+
+#include <cublas_v2.h>
+#include <cusparse.h>
+
+// clang-format off
+#include "tensors/gpu/prod.h"
+#include "tensors/gpu/backend.h"
+#include "tensors/gpu/cuda_helpers.h"
+// clang-format on
+
+namespace marian {
+
+namespace gpu {
+
+// The explicit version of matmult like cublasGemmEx choose their math mode based on the algorithm that
+// has been passed into the function call and seem to ignore setMathMode. Here we query the used math mode
+// to choose the algorithm.
+static bool tensorOpsEnabled(cublasHandle_t cublasHandle) {
+#if CUDA_VERSION >= 9000
+ cublasMath_t actual = CUBLAS_DEFAULT_MATH;
+ cublasGetMathMode(cublasHandle, &actual);
+ return actual == CUBLAS_TENSOR_OP_MATH;
+#else
+ return false;
+#endif
+}
+
+static void setTensorMode(cublasHandle_t cublasHandle) {
+ cublasHandle; // fool warnings
+#if CUDA_VERSION >= 9000
+ static int mode = 0; // 1: use TC; -1: do not use TC; 0: not set yet
+ if (mode == 0) { // multi-thread note: this is sort-of thread-safe, since multiple threads would determine the same value
+ const char* var = getenv("ENABLE_CUBLAS_TENSOR_OP_MATH_FP32");
+ if (!var)
+ var = "1";
+ switch(var[0]) {
+ case '0': mode = -1; break;
+ case '1': mode = 1; break;
+ default: ABORT("Invalid ENABLE_CUBLAS_TENSOR_OP_MATH_FP32={}", var);
+ }
+ if (mode > 0) { // try whether it can be set --@TODO: check whether this actually works
+ CUBLAS_CHECK(cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH));
+ cublasMath_t actual = CUBLAS_DEFAULT_MATH;
+ cublasGetMathMode(cublasHandle, &actual);
+ if (actual != CUBLAS_TENSOR_OP_MATH) {
+ LOG(warn, "[gpu] TensorCores requested but not available");
+ mode = -1;
+ }
+ }
+ if (mode > 0)
+ LOG(info, "[gpu] 16-bit TensorCores enabled for float32 matrix operations");
+ }
+ CUBLAS_CHECK(cublasSetMathMode(cublasHandle, mode > 0 ? CUBLAS_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH));
+#endif
+}
+
+static void unsetTensorMode(cublasHandle_t cublasHandle) {
+ cublasHandle; // fool warnings
+#if CUDA_VERSION >= 9000
+ CUBLAS_CHECK(cublasSetMathMode(cublasHandle, CUBLAS_DEFAULT_MATH));
+#endif
+}
+
+// overload for float, contains configuration settings for float32
+static cublasStatus_t cublasGemmTyped(cublasHandle_t handle,
+ CudaCompute computeCapability,
+ cublasOperation_t transa,
+ cublasOperation_t transb,
+ int m, int n, int k,
+ const float* alpha,
+ const float* A, int lda,
+ const float* B, int ldb,
+ const float* beta,
+ float* C, int ldc) {
+// double #if and if unfortunately required to safeguard against compilation error
+// with CUDA 8.0 and runtime error with CUDA >9.0 on GPUs with compute capability under 5
+#if CUDA_VERSION > 9000
+ // query math mode and set algorithm accordingly
+ auto algorithm = tensorOpsEnabled(handle) ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
+ if(computeCapability.major >= 5)
+ return cublasGemmEx(handle, transa, transb,
+ m, n, k, alpha,
+ A, CUDA_R_32F, lda,
+ B, CUDA_R_32F, ldb, beta,
+ C, CUDA_R_32F, ldc,
+ CUDA_R_32F, algorithm); // @TODO: review algorithm
+#endif
+ return cublasSgemm(handle, transa, transb,
+ m, n, k, alpha,
+ A, lda,
+ B, ldb, beta,
+ C, ldc);
+}
+
+#if COMPILE_FP16
+// overload for half, contains configuration settings for float16
+static cublasStatus_t cublasGemmTyped(cublasHandle_t handle,
+ CudaCompute computeCapability,
+ cublasOperation_t transa,
+ cublasOperation_t transb,
+ int m, int n, int k,
+ const half* alpha,
+ const half* A, int lda,
+ const half* B, int ldb,
+ const half* beta,
+ half* C, int ldc) {
+ ABORT_IF(computeCapability.major < 6, "Compute capability {} below 6 should not happen for FP16", computeCapability.major);
+ // query math mode and set algorithm accordingly
+ auto algorithm = tensorOpsEnabled(handle) ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
+ return cublasGemmEx(handle, transa, transb,
+ m, n, k, alpha,
+ A, CUDA_R_16F, lda,
+ B, CUDA_R_16F, ldb, beta,
+ C, CUDA_R_16F, ldc,
+ CUDA_R_16F, algorithm); // @TODO: review algorithm
+}
+#endif
+
+template <typename T>
+void ProdTyped(marian::Tensor C,
+ const marian::Tensor& A,
+ const marian::Tensor& B,
+ bool transA,
+ bool transB,
+ T beta,
+ T scalar) {
+ CUDA_CHECK(cudaSetDevice((int)C->getDeviceId().no));
+ T alpha = scalar;
+
+ int m = A->shape().elements() / A->shape().back();
+ int k = A->shape().back();
+ if(transA)
+ std::swap(m, k);
+
+ int l = B->shape().elements() / B->shape().back();
+ int n = B->shape().back();
+ if(transB)
+ std::swap(l, n);
+
+ int lda = A->shape().back();
+ int ldb = B->shape().back();
+ int ldc = B->shape().back();
+
+ if(transB)
+ ldc = B->shape().elements() / B->shape().back();
+
+ cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
+ cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+ auto backend = std::static_pointer_cast<gpu::Backend>(C->getBackend());
+ auto cublasHandle = backend->getCublasHandle();
+ auto computeCapability = backend->getCudaComputeCapability();
+
+ setTensorMode(cublasHandle);
+ CUBLAS_CHECK(cublasGemmTyped(cublasHandle,
+ computeCapability,
+ opB,
+ opA,
+ n,
+ m,
+ k,
+ &alpha,
+ B->data<T>(),
+ ldb,
+ A->data<T>(),
+ lda,
+ &beta,
+ C->data<T>(),
+ ldc));
+ unsetTensorMode(cublasHandle);
+}
+
+void Prod(marian::Tensor C,
+ const marian::Tensor& A,
+ const marian::Tensor& B,
+ bool transA,
+ bool transB,
+ float beta,
+ float scalar) {
+ if(C->type() == Type::float32) {
+ ProdTyped<float>(C, A, B, transA, transB, beta, scalar);
+#if COMPILE_FP16
+ } else if(C->type() == Type::float16) {
+ ProdTyped<half>(C, A, B, transA, transB, __float2half(beta), __float2half(scalar));
+#endif
+ } else {
+ ABORT("Prod not implemented for type {}", C->type());
+ }
+}
+
+cublasStatus_t cublasGemmBatchedTyped(cublasHandle_t handle,
+ CudaCompute computeCapability,
+ cublasOperation_t transa,
+ cublasOperation_t transb,
+ int m, int n, int k,
+ const float *alpha,
+ const float *Aarray[], int lda,
+ const float *Barray[], int ldb,
+ const float *beta,
+ float *Carray[], int ldc,
+ int batchCount) {
+// double #if and if unfortunately required to safeguard against compilation error
+// with CUDA 8.0 and runtime error with CUDA >9.0 on GPUs with compute capability under 5
+#if CUDA_VERSION > 9000
+ // query math mode and set algorithm accordingly
+ auto algorithm = tensorOpsEnabled(handle) ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
+ if(computeCapability.major >= 5)
+ return cublasGemmBatchedEx(handle, transa, transb,
+ m, n, k, alpha,
+ (void* const*)Aarray, CUDA_R_32F, lda,
+ (void* const*)Barray, CUDA_R_32F, ldb, beta,
+ (void**)Carray, CUDA_R_32F, ldc, batchCount,
+ CUDA_R_32F, algorithm);
+#endif
+ return cublasSgemmBatched(handle, transa, transb,
+ m, n, k, alpha,
+ Aarray, lda,
+ Barray, ldb, beta,
+ Carray, ldc, batchCount);
+}
+
+#if COMPILE_FP16 // should not be visible for CUDA 9.0 and below
+cublasStatus_t cublasGemmBatchedTyped(cublasHandle_t handle,
+ CudaCompute computeCapability,
+ cublasOperation_t transa,
+ cublasOperation_t transb,
+ int m, int n, int k,
+ const half *alpha,
+ const half *Aarray[], int lda,
+ const half *Barray[], int ldb,
+ const half *beta,
+ half *Carray[], int ldc,
+ int batchCount) {
+ ABORT_IF(computeCapability.major < 6, "Compute capability {} below 6 should not happen for FP16", computeCapability.major);
+ // query math mode and set algorithm accordingly
+ auto algorithm = tensorOpsEnabled(handle) ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
+ return cublasGemmBatchedEx(handle, transa, transb,
+ m, n, k, alpha,
+ (void* const*)Aarray, CUDA_R_16F, lda,
+ (void* const*)Barray, CUDA_R_16F, ldb, beta,
+ (void**)Carray, CUDA_R_16F, ldc, batchCount,
+ CUDA_R_16F, algorithm); // @TODO: to 16, this is testing
+}
+#endif
+
+template <typename T>
+void ProdBatchedTyped(marian::Tensor C,
+ Ptr<Allocator> allocator,
+ const marian::Tensor A,
+ const marian::Tensor B,
+ bool transA,
+ bool transB,
+ T beta,
+ T scalar) {
+ CUDA_CHECK(cudaSetDevice((int)C->getDeviceId().no));
+ T alpha = scalar;
+
+ int batchA = A->shape().elements() / (A->shape()[-1] * A->shape()[-2]);
+ int batchB = B->shape().elements() / (B->shape()[-1] * B->shape()[-2]);
+
+ int m = A->shape()[-2];
+ int k = A->shape()[-1];
+ if(transA)
+ std::swap(m, k);
+
+ int l = B->shape()[-2];
+ int n = B->shape()[-1];
+ if(transB)
+ std::swap(l, n);
+
+ int lda = A->shape()[-1];
+ int ldb = B->shape()[-1];
+ int ldc = B->shape()[-1];
+
+ if(transB)
+ ldc = B->shape()[-2];
+
+ cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
+ cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+ auto backend = std::static_pointer_cast<gpu::Backend>(C->getBackend());
+ auto cublasHandle = backend->getCublasHandle();
+ auto compute = backend->getCudaComputeCapability();
+
+ auto strideA = batchA == 1 ? 0 : m * k;
+ auto strideB = batchB == 1 ? 0 : n * k;
+ auto strideC = n * m;
+ auto batchC = std::max(batchA, batchB);
+
+ std::vector<const T*> aptr;
+ std::vector<const T*> bptr;
+ std::vector<T*> cptr;
+
+ for(int i = 0; i < batchC; i++) {
+ aptr.push_back(A->data<T>() + (i % batchA) * strideA);
+ bptr.push_back(B->data<T>() + (i % batchB) * strideB);
+ cptr.push_back(C->data<T>() + i * strideC);
+ }
+
+ // auto fails here from weird reason
+ IPtr<MemoryPiece> mp_aptr = allocator->alloc<const T*>(aptr.size());
+ CudaCopy(aptr.data(), aptr.data() + aptr.size(), mp_aptr->data<const T*>());
+
+ IPtr<MemoryPiece> mp_bptr = allocator->alloc<const T*>(bptr.size());
+ CudaCopy(bptr.data(), bptr.data() + bptr.size(), mp_bptr->data<const T*>());
+
+ IPtr<MemoryPiece> mp_cptr = allocator->alloc<T*>(cptr.size());
+ CudaCopy(cptr.data(), cptr.data() + cptr.size(), mp_cptr->data<T*>());
+
+ setTensorMode(cublasHandle);
+ CUBLAS_CHECK(cublasGemmBatchedTyped(cublasHandle,
+ compute,
+ opB,
+ opA,
+ n,
+ m,
+ k,
+ &alpha,
+ mp_bptr->data<const T*>(),
+ ldb,
+ mp_aptr->data<const T*>(),
+ lda,
+ &beta,
+ mp_cptr->data<T*>(),
+ ldc,
+ batchC));
+ unsetTensorMode(cublasHandle);
+
+ allocator->free(mp_aptr);
+ allocator->free(mp_bptr);
+ allocator->free(mp_cptr);
+}
+
+void ProdBatched(marian::Tensor C,
+ Ptr<Allocator> allocator,
+ const marian::Tensor A,
+ const marian::Tensor B,
+ bool transA,
+ bool transB,
+ float beta,
+ float scalar) {
+ if(C->type() == Type::float32) {
+ ProdBatchedTyped<float>(C, allocator, A, B, transA, transB, beta, scalar);
+#if COMPILE_FP16
+ } else if(C->type() == Type::float16) { // not a *.cu file
+ ProdBatchedTyped<half>(C, allocator, A, B, transA, transB, __float2half(beta), __float2half(scalar));
+#endif
+ } else {
+ ABORT("ProdBatched not implemented for type {}", C->type());
+ }
+}
+
+// bug in cuSparse: sparse matrix is limited to 65535 columns
+// This function is a drop-in replacement that handles it (by slicing).
+cusparseStatus_t
+static cusparseSgemmiEx(cusparseHandle_t handle, int m,
+ int n, // the offending number of columns of matrices B and C
+ int k, int nnz, const float *alpha, const float *A, int lda,
+ const float *cscValB, const int *cscColPtrB, const int *cscRowIndB, const float *beta,
+ float *C, int ldc)
+{
+ const int nMax = 65535; // max. number of columns allowed by cuSparse 10 implementation
+ for (int j0 = 0; j0 < n; j0 += 65535) { // loop over column slices, j0 = index of first column
+ // Call original function on a column slice.
+ // Replace all parameters that relate to the column slice.
+ // nnz does not need to be corrected.
+ auto n1 = std::min(n - j0, nMax); // width of column slice is limited to max
+ auto C1 = C + j0 * ldc; // column slice into result matrix C
+ auto cscColPtrB1 = cscColPtrB + j0; // column slice into sparse factor B
+ auto rc = cusparseSgemmi(handle, m, n1, k, nnz, alpha, A, lda, cscValB, cscColPtrB1, cscRowIndB, beta, C1, ldc);
+ if (rc != CUSPARSE_STATUS_SUCCESS)
+ return rc;
+ }
+ return CUSPARSE_STATUS_SUCCESS;
+}
+
+// @TODO: make this work with fp16
+
+// C = op(S) x D if not swapOperands else C = D x op(S)
+// op(S) = S if not transA else S^T
+void CSRProd(marian::Tensor C,
+ Ptr<Allocator> allocator,
+ const marian::Tensor& S_values,
+ const marian::Tensor& S_indices,
+ const marian::Tensor& S_offsets,
+ const marian::Tensor& D,
+ bool transS,
+ bool swapOperands,
+ float beta) {
+ cudaSetDevice((int)C->getDeviceId().no);
+ auto cusparseHandle = std::static_pointer_cast<gpu::Backend>(C->getBackend())
+ ->getCusparseHandle();
+ // interpret tensor dimensions as matrix dimensions
+ const auto& shapeC = C->shape();
+ const auto& shapeD = D->shape();
+ // If swapOperands, S and D are swapped (C = D x S instead of C = S x D).
+ // In that case, in the next 6 lines, please read all dimensions as if they were reversed in order.
+ auto rowsC = shapeC[-(int)swapOperands];
+ auto colsC = shapeC.elements() / rowsC;
+ auto rowsD = shapeD[-(int)swapOperands];
+ auto colsD = shapeD.elements() / rowsD;
+ auto rowsS = transS ? rowsD : rowsC;
+ auto colsS = transS ? rowsC : rowsD;
+ ABORT_IF(colsD != colsC, "Inconsistent outer dimensions in CSR product");
+ if (swapOperands) { // make rowsX actual row dimensions again, likewise colsX
+ std::swap(rowsC, colsC);
+ std::swap(rowsD, colsD);
+ std::swap(rowsS, colsS);
+ }
+ // sparse arrays
+ auto numValues = S_values->shape().elements();
+ auto numOffsets = S_offsets->shape().elements() - 1; // -1 since last value is length
+ ABORT_IF(numOffsets != rowsS, "Unexpected number of rows in CSR argument");
+ ABORT_IF(S_values->shape() != S_indices->shape(), "CSR values and indices must have the same size");
+ float alpha = 1;
+ MemoryPiece::PtrType St_values, St_indices, St_offsets;
+ if (transS != swapOperands) {
+ // Cusparse gemmi() does not support this specific version of transpose, and csrmm() is non-deterministic.
+ // Hence, we transpose the matrix explicitly.
+ // Note that gemmi() expects a CSC, while csrmm() a CSR; hence, the strange condition (transS != swapOperands) above.
+ St_values = allocator->alloc<float>(numValues);
+ St_indices = allocator->alloc<int>(numValues);
+ St_offsets = allocator->alloc<int>(colsS + 1);
+ // transpose the second argument
+ CUSPARSE_CHECK(cusparseScsr2csc(cusparseHandle,
+ /*m=*/ rowsS, // number of rows of matrix
+ /*n=*/ colsS, // number of columns of matrix
+ /*nnz=*/ (int)numValues,
+ /*csrcVal=*/ S_values ->data<float>(),
+ /*csrcRowPtr=*/ (int*)S_offsets->data<IndexType>(),
+ /*csrcColInd=*/ (int*)S_indices->data<IndexType>(),
+ /*cscVal=*/ St_values ->data<float>(), // transposed version goes here
+ /*cscRowInd=*/ St_indices->data<int>(),
+ /*cscColPtr=*/ St_offsets->data<int>(),
+ /*copyValues=*/ CUSPARSE_ACTION_NUMERIC,
+ /*idxBase=*/ CUSPARSE_INDEX_BASE_ZERO));
+ std::swap(rowsS, colsS); // these variables now represent the dims of the explicitly transposed object
+ }
+ if (swapOperands) {
+ // C = D x S for row-major matrices
+ // Implemented via cusparse as C' = S' x D' ("csrmm") where C' and D' are column-major,
+ // and S' is CSR (if not transS then we make a transposed copy).
+ cusparseMatDescr_t descrA;
+ CUSPARSE_CHECK(cusparseCreateMatDescr(&descrA));
+ cusparseSetMatType (descrA, CUSPARSE_MATRIX_TYPE_GENERAL);
+ cusparseSetMatIndexBase(descrA, CUSPARSE_INDEX_BASE_ZERO);
+ CUSPARSE_CHECK(cusparseScsrmm(cusparseHandle,
+ CUSPARSE_OPERATION_NON_TRANSPOSE, // (we explicitly transposed above)
+ /*m=*/ rowsS, // #rows of first (CSR) factor (the transpose was done explicitly)
+ /*n=*/ rowsC, // #cols of second (col-major) factor and (col-major) result = #rows of row-major C
+ /*k=*/ colsS, // #cols of first (CSR) factor
+ /*nnz=*/ (int)numValues,
+ &alpha, descrA,
+ /*csrValA=*/ St_values ? St_values ->data<float>() : S_values ->data<float>(),
+ /*csrRowPtrA=*/ St_offsets ? St_offsets->data<int>() : (int*)S_offsets->data<IndexType>(),
+ /*csrColIndA=*/ St_indices ? St_indices->data<int>() : (int*)S_indices->data<IndexType>(),
+ D->data(),
+ /*ldb=*/ colsD, // stride
+ &beta,
+ C->data(),
+ /*ldc=*/ colsC)); // stride
+ cusparseDestroyMatDescr(descrA);
+ }
+ else {
+ // C = S x D for row-major matrices
+ // Implemented via cusparse as C' = D' x S' ("gemmi") where C' and D' are column-major.
+ CUSPARSE_CHECK(cusparseSgemmiEx(cusparseHandle,
+ /*m=*/ colsD, // #rows of first (col-major) factor = #cols of row-major D
+ /*n=*/ rowsC, // #cols of second (CSC) factor and (col-major) result = #rows of row-major C
+ /*k=*/ rowsD, // #cols of first (col-major) factor = #rows of row-major D
+ /*nnz=*/ (int)numValues,
+ &alpha,
+ /*A=*/ D->data(),
+ /*lda=*/ colsD, // stride
+ /*cscValB=*/ St_values ? St_values ->data<float>() : S_values ->data<float>(),
+ /*cscColPtrB=*/ St_offsets ? St_offsets->data<int>() : (int*)S_offsets->data<IndexType>(),
+ /*cscRowIndB=*/ St_indices ? St_indices->data<int>() : (int*)S_indices->data<IndexType>(),
+ &beta,
+ C->data(),
+ /*ldc=*/ colsC)); // stride
+ // Note: cuSparse 10 docs says this about cscColPtrB:
+ // "integer array of k + 1 elements that contains the start of every row and the end of the last row plus one."
+ // This is wrong. It should be col instead of row, and n instead of k.
+ }
+ if(St_values ) allocator->free(St_values );
+ if(St_indices) allocator->free(St_indices);
+ if(St_offsets) allocator->free(St_offsets);
+}
+
+} // namespace gpu
+} // namespace marian
diff --git a/src/tensors/gpu/prod.cu b/src/tensors/gpu/prod.cu
deleted file mode 100644
index 9558a67f..00000000
--- a/src/tensors/gpu/prod.cu
+++ /dev/null
@@ -1,200 +0,0 @@
-
-#include <cublas_v2.h>
-
-// clang-format off
-#include "tensors/gpu/prod.h"
-#include "tensors/gpu/backend.h"
-#include "tensors/gpu/cuda_helpers.h"
-// clang-format on
-
-namespace marian {
-
-namespace gpu {
-
-void Prod(marian::Tensor C,
- const marian::Tensor& A,
- const marian::Tensor& B,
- bool transA,
- bool transB,
- float beta,
- float scalar) {
- cudaSetDevice(C->getDeviceId().no);
- float alpha = scalar;
-
- size_t m = A->shape().elements() / A->shape().back();
- size_t k = A->shape().back();
- if(transA)
- std::swap(m, k);
-
- size_t l = B->shape().elements() / B->shape().back();
- size_t n = B->shape().back();
- if(transB)
- std::swap(l, n);
-
- size_t lda = A->shape().back();
- size_t ldb = B->shape().back();
- size_t ldc = B->shape().back();
-
- if(transB)
- ldc = B->shape().elements() / B->shape().back();
-
- cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
- cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
-
- auto cublasHandle = std::static_pointer_cast<gpu::Backend>(C->getBackend())
- ->getCublasHandle();
-
-#if CUDA_VERSION >= 9000
- cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH);
-#endif
-
- cublasSgemm(cublasHandle,
- opB,
- opA,
- n,
- m,
- k,
- &alpha,
- B->data(),
- ldb,
- A->data(),
- lda,
- &beta,
- C->data(),
- ldc);
-#if CUDA_VERSION >= 9000
- cublasSetMathMode(cublasHandle, CUBLAS_DEFAULT_MATH);
-#endif
-}
-
-__global__ void gAddBias(float* out,
- const float* bias,
- size_t length,
- size_t cols) {
- for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
- int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
- if(index < length) {
- size_t index2 = index % cols;
- out[index] += bias[index2];
- }
- }
-}
-
-void AddBias(marian::Tensor C, const marian::Tensor bias) {
- cudaSetDevice(C->getDeviceId().no);
-
- int length = C->shape().elements();
- int cols = bias->shape().elements();
-
- int threads = std::min(MAX_THREADS, length);
- int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
-
- gAddBias<<<blocks, threads>>>(C->data(), bias->data(), length, cols);
-
- cudaStreamSynchronize(0);
-}
-
-void ProdWithBias(marian::Tensor C,
- const marian::Tensor& A,
- const marian::Tensor& B,
- const marian::Tensor& bias,
- bool transA,
- bool transB,
- float beta,
- float scalar) {
- marian::gpu::Prod(C, A, B, transA, transB, beta, scalar);
- marian::gpu::AddBias(C, bias);
-}
-
-void ProdBatched(marian::Tensor C,
- Ptr<Allocator> allocator,
- const marian::Tensor A,
- const marian::Tensor B,
- bool transA,
- bool transB,
- float beta,
- float scalar) {
- cudaSetDevice(C->getDeviceId().no);
- float alpha = scalar;
-
- size_t batchA = A->shape().elements() / (A->shape()[-1] * A->shape()[-2]);
- size_t batchB = B->shape().elements() / (B->shape()[-1] * B->shape()[-2]);
-
- size_t m = A->shape()[-2];
- size_t k = A->shape()[-1];
- if(transA)
- std::swap(m, k);
-
- size_t l = B->shape()[-2];
- size_t n = B->shape()[-1];
- if(transB)
- std::swap(l, n);
-
- size_t lda = A->shape()[-1];
- size_t ldb = B->shape()[-1];
- size_t ldc = B->shape()[-1];
-
- if(transB)
- ldc = B->shape()[-2];
-
- cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
- cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
-
- auto cublasHandle = std::static_pointer_cast<gpu::Backend>(C->getBackend())
- ->getCublasHandle();
-
- int strideA = batchA == 1 ? 0 : m * k;
- int strideB = batchB == 1 ? 0 : n * k;
- int strideC = n * m;
- int batchC = std::max(batchA, batchB);
-
- std::vector<const float*> aptr;
- std::vector<const float*> bptr;
- std::vector<float*> cptr;
-
- for(int i = 0; i < batchC; i++) {
- aptr.push_back(A->data() + (i % batchA) * strideA);
- bptr.push_back(B->data() + (i % batchB) * strideB);
- cptr.push_back(C->data() + i * strideC);
- }
-
- auto mp_aptr = allocator->alloc<const float*>(aptr.size());
- CudaCopy(
- aptr.data(), aptr.data() + aptr.size(), mp_aptr->data<const float*>());
-
- auto mp_bptr = allocator->alloc<const float*>(bptr.size());
- CudaCopy(
- bptr.data(), bptr.data() + bptr.size(), mp_bptr->data<const float*>());
-
- auto mp_cptr = allocator->alloc<float*>(cptr.size());
- CudaCopy(cptr.data(), cptr.data() + cptr.size(), mp_cptr->data<float*>());
-
-#if CUDA_VERSION >= 9000
- cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH);
-#endif
- cublasSgemmBatched(cublasHandle,
- opB,
- opA,
- n,
- m,
- k,
- &alpha,
- mp_bptr->data<const float*>(),
- ldb,
- mp_aptr->data<const float*>(),
- lda,
- &beta,
- mp_cptr->data<float*>(),
- ldc,
- batchC);
-#if CUDA_VERSION >= 9000
- cublasSetMathMode(cublasHandle, CUBLAS_DEFAULT_MATH);
-#endif
-
- allocator->free(mp_aptr);
- allocator->free(mp_bptr);
- allocator->free(mp_cptr);
-}
-
-} // namespace gpu
-} // namespace marian
diff --git a/src/tensors/gpu/prod.h b/src/tensors/gpu/prod.h
index ce1f6cdc..9dc1220c 100644
--- a/src/tensors/gpu/prod.h
+++ b/src/tensors/gpu/prod.h
@@ -16,15 +16,6 @@ void Prod(marian::Tensor C,
float beta = 0,
float scalar = 1);
-void ProdWithBias(marian::Tensor C,
- const marian::Tensor& A,
- const marian::Tensor& B,
- const marian::Tensor& bias,
- bool transA,
- bool transB,
- float beta = 0,
- float scalar = 1);
-
void ProdBatched(marian::Tensor C,
Ptr<Allocator> allocator,
const marian::Tensor A,
@@ -33,5 +24,15 @@ void ProdBatched(marian::Tensor C,
bool transB,
float beta = 0,
float scalar = 1);
+
+void CSRProd(marian::Tensor C,
+ Ptr<Allocator> allocator,
+ const marian::Tensor& A_values,
+ const marian::Tensor& A_indices,
+ const marian::Tensor& A_offsets,
+ const marian::Tensor& B,
+ bool transA,
+ bool swapOperands,
+ float beta = 0);
} // namespace gpu
} // namespace marian
diff --git a/src/tensors/gpu/tensor_operators.cu b/src/tensors/gpu/tensor_operators.cu
index 34da4c68..eefcc405 100644..100755
--- a/src/tensors/gpu/tensor_operators.cu
+++ b/src/tensors/gpu/tensor_operators.cu
@@ -1,40 +1,147 @@
-//#include <thrust/transform_reduce.h>
-
+#include "common/types.h"
#include "tensors/tensor_operators.h"
#include "functional/functional.h"
#include "functional/tensor.h"
+#include "tensors/allocator.h"
#include "tensors/gpu/backend.h"
#include "tensors/gpu/cuda_helpers.h"
-#include "3rd_party/reduce_all.h"
+#include "tensors/gpu/add_all.h"
namespace marian {
namespace gpu {
-struct isnan_test {
- __host__ __device__ bool operator()(const float a) const { return isnan(a); }
-};
+namespace atomics {
+
+static inline __device__ void atomicAdd(float *address, float val) {
+ //*address += val;
+ ::atomicAdd(address, val);
+}
+
+#if COMPILE_FP16
+// @TODO: copied from CuTorch, adapt this better, give credit.
+static inline __device__ void atomicAdd(half *address, half val) {
+ //*address += val;
+
+#if __CUDA_ARCH__ >= 700 && CUDA_VERSION >= 10000 // compute capability 70 and higher with CUDA 10
+ ::atomicAdd(address, val);
+#else // __CUDA_ARCH__ < 700
+ unsigned int * address_as_ui =
+ (unsigned int *) ((char *)address - ((size_t)address & 2));
+ unsigned int old = *address_as_ui;
+ unsigned int assumed;
+
+ do {
+ assumed = old;
+ #if CUDA_VERSION < 9000
+ half hsum;
+ hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+ hsum = hsum + val;
+ #else
+ __half_raw hsum;
+ hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
+ half tmpres = hsum + val;
+ hsum = __half_raw(tmpres);
+ #endif
+ old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
+ old = atomicCAS(address_as_ui, assumed, old);
+ } while (assumed != old);
+#endif // __CUDA_ARCH__
+}
+#endif
+
+}
+
+template <typename T>
+__global__ void gIsNaN(const T* in, int length, bool* isNaN, bool* isInf) {
+ for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
+ int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
+ if(index < length) {
+ if(isnan((float)in[index])) *isNaN = true;
+ if(isinf((float)in[index])) *isInf = true;
+ }
+ }
+}
+
+void IsNaN(const Tensor in, Ptr<Allocator> allocator, bool& isNaN, bool& isInf) {
+ cudaSetDevice(in->getDeviceId().no);
+
+ int length = in->size();
-__device__ inline float stableSigmoid(float x) {
- if(x >= 0) {
- float z = expf(-x);
- return 1.0 / (1.0 + z);
+ int threads = std::min(MAX_THREADS, length);
+ int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
+
+ auto mem = allocator->alloc<bool>(2);
+ bool* dIsNaN = &mem->data<bool>()[0];
+ bool* dIsInf = &mem->data<bool>()[1];
+ fill(in->getBackend(), dIsNaN, dIsNaN + 2, false);
+
+ if(in->type() == Type::float32) {
+ gIsNaN<<<blocks, threads>>>(in->data<float>(), length, dIsNaN, dIsInf);
+#if COMPILE_FP16
+ } else if(in->type() == Type::float16) {
+ gIsNaN<<<blocks, threads>>>(in->data<half>(), length, dIsNaN, dIsInf);
+#endif
} else {
- float z = expf(x);
- return z / (1.0 + z);
+ ABORT("IsNaN for type {} not implemented", in->type());
}
+
+ CudaCopy(dIsNaN, dIsNaN + 1, &isNaN);
+ CudaCopy(dIsInf, dIsInf + 1, &isInf);
+
+ allocator->free(mem);
+
+ cudaStreamSynchronize(0);
}
-bool IsNan(Tensor in) {
- // cudaSetDevice(in->getDeviceId().no);
- // thrust::device_ptr<float> begin = thrust::device_pointer_cast(in->data());
- // thrust::device_ptr<float> end
- // = thrust::device_pointer_cast(in->data() + in->size());
- // return thrust::transform_reduce(
- // begin, end, isnan_test(), 0, thrust::plus<bool>());
- return false;
+template <typename To, typename From>
+__global__ void gCopyCastTo(To* out, const From* in, int length) {
+ for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
+ int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
+ if(index < length) {
+ out[index] = in[index];
+ }
+ }
+}
+
+template <typename To, typename From>
+void CopyCastTo(To* out, const From* in, int length) {
+ int threads = std::min(MAX_THREADS, length);
+ int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
+ gCopyCastTo<<<blocks, threads>>>(out, in, length);
+}
+
+template <typename T>
+void CopyCastFrom(Tensor out, const T* in, int length) {
+ if(out->type() == Type::float32) {
+ CopyCastTo(out->data<float>(), in, length);
+#if COMPILE_FP16
+ } else if(out->type() == Type::float16) {
+ CopyCastTo(out->data<half>(), in, length);
+#endif
+ } else if(out->type() == Type::float64) {
+ CopyCastTo(out->data<double>(), in, length);
+ } else {
+ ABORT("CopyCastTo to type {} not implemented", out->type());
+ }
+}
+
+void CopyCast(Tensor out, const Tensor in) {
+ cudaSetDevice(out->getDeviceId().no);
+
+ if(in->type() == Type::float32) {
+ CopyCastFrom(out, in->data<float>(), (int)in->size());
+#if COMPILE_FP16
+ } else if(in->type() == Type::float16) {
+ CopyCastFrom(out, in->data<half>(), (int)in->size());
+#endif
+ } else if(in->type() == Type::float64) {
+ CopyCastFrom(out, in->data<double>(), (int)in->size());
+ } else {
+ ABORT("CopyCastFrom from type {} not implemented", in->type());
+ }
}
void ConcatCont(Tensor out, const std::vector<Tensor>& inputs, int axis) {
@@ -46,12 +153,12 @@ void ConcatCont(Tensor out, const std::vector<Tensor>& inputs, int axis) {
size_t offset1 = 0;
for(int i = 0; i < step; ++i) {
for(auto in : inputs) {
- size_t size = in->shape().elements() / step;
+ size_t size = (in->shape().elements() / step) * sizeOf(out->type());
size_t offset2 = i * size;
- cudaMemcpy(out->data() + offset1,
- in->data() + offset2,
- size * sizeof(float),
+ cudaMemcpy(out->data<uint8_t>() + offset1,
+ in->data<uint8_t>() + offset2,
+ size,
cudaMemcpyDeviceToDevice);
offset1 += size;
@@ -60,9 +167,9 @@ void ConcatCont(Tensor out, const std::vector<Tensor>& inputs, int axis) {
cudaStreamSynchronize(0);
}
-template <bool add>
-__global__ void gInsertCols(float* out,
- const float* in,
+template <bool add, typename T>
+__global__ void gInsertCols(T* out,
+ const T* in,
size_t rows,
size_t cols,
size_t cols_out,
@@ -71,9 +178,9 @@ __global__ void gInsertCols(float* out,
size_t offset_in) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
- if(j < rows) {
- float* rowOut = out + j * cols_out + offset_out;
- const float* rowIn = in + j * cols_in + offset_in;
+ if(j < rows) { // @TODO: change to if j == rows then break, as that's what it means. In 4 functions in here.
+ T* rowOut = out + j * cols_out + offset_out;
+ const T* rowIn = in + j * cols_in + offset_in;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
@@ -87,6 +194,7 @@ __global__ void gInsertCols(float* out,
}
}
+// Special version for axis = -1 @TODO: write common better version
void Concatenate1(Tensor out, const std::vector<Tensor>& inputs) {
cudaSetDevice(out->getDeviceId().no);
@@ -103,19 +211,27 @@ void Concatenate1(Tensor out, const std::vector<Tensor>& inputs) {
int blocks = std::min(MAX_BLOCKS, rows);
int threads = std::min(MAX_THREADS, cols_in);
- gInsertCols<false><<<blocks, threads>>>(
- out->data(), in->data(), rows, cols_in, cols_out, cols_in, offset, 0);
+ if(out->type() == Type::float32) {
+ gInsertCols<false><<<blocks, threads>>>(out->data<float>(), in->data<float>(), rows, cols_in, cols_out, cols_in, offset, 0);
+#if COMPILE_FP16
+ } else if(out->type() == Type::float16) {
+ gInsertCols<false><<<blocks, threads>>>(out->data<half>(), in->data<half>(), rows, cols_in, cols_out, cols_in, offset, 0);
+#endif
+ } else {
+ ABORT("Concatenate1 not implemented for type {}", out->type());
+ }
offset += cols_in;
}
cudaStreamSynchronize(0);
}
-__global__ void gJoin2(float* out,
+template <typename T>
+__global__ void gJoin2(T* out,
size_t rowBatch,
size_t cols,
- const float* in1,
+ const T* in1,
size_t inStride1,
- const float* in2,
+ const T* in2,
size_t inStride2) {
int outStride = inStride1 + inStride2;
int rows = rowBatch * outStride;
@@ -123,7 +239,7 @@ __global__ void gJoin2(float* out,
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- float* rowOut = out + j * cols;
+ T* rowOut = out + j * cols;
int curBatch = j / outStride;
int curPos = j % outStride;
@@ -131,8 +247,8 @@ __global__ void gJoin2(float* out,
int jIn1 = (curBatch * inStride1) + curPos;
int jIn2 = (curBatch * inStride2) + curPos - inStride1;
- const float* rowIn1 = in1 + jIn1 * cols;
- const float* rowIn2 = in2 + jIn2 * cols;
+ const T* rowIn1 = in1 + jIn1 * cols;
+ const T* rowIn2 = in2 + jIn2 * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
@@ -147,6 +263,7 @@ __global__ void gJoin2(float* out,
}
}
+// Special version for axis = -2 @TODO: write common better version
void Concatenate2(Tensor out, Tensor in1, Tensor in2) {
cudaSetDevice(out->getDeviceId().no);
@@ -161,13 +278,27 @@ void Concatenate2(Tensor out, Tensor in1, Tensor in2) {
int blocks = std::min(MAX_BLOCKS, (int)rows);
int threads = std::min(MAX_THREADS, (int)cols);
- gJoin2<<<blocks, threads>>>(out->data(),
- rowBatch,
- cols,
- in1->data(),
- rowStride1,
- in2->data(),
- rowStride2);
+ if(out->type() == Type::float32) {
+ gJoin2<<<blocks, threads>>>(out->data<float>(),
+ rowBatch,
+ cols,
+ in1->data<float>(),
+ rowStride1,
+ in2->data<float>(),
+ rowStride2);
+#if COMPILE_FP16
+ } else if(out->type() == Type::float16) {
+ gJoin2<<<blocks, threads>>>(out->data<half>(),
+ rowBatch,
+ cols,
+ in1->data<half>(),
+ rowStride1,
+ in2->data<half>(),
+ rowStride2);
+#endif
+ } else {
+ ABORT("Concatenate2 not implemented for type {}", out->type());
+ }
cudaStreamSynchronize(0);
}
@@ -195,8 +326,18 @@ void Split1(std::vector<Tensor>& outputs, const Tensor in) {
int blocks = std::min(MAX_BLOCKS, rows);
int threads = std::min(MAX_THREADS, cols_out);
- gInsertCols<true><<<blocks, threads>>>(
- out->data(), in->data(), rows, cols_out, cols_out, cols_in, 0, offset);
+ if(out->type() == Type::float32) {
+ gInsertCols<true><<<blocks, threads>>>(
+ out->data<float>(), in->data<float>(), rows, cols_out, cols_out, cols_in, 0, offset);
+#if COMPILE_FP16
+ } else if(out->type() == Type::float16) {
+ gInsertCols<true><<<blocks, threads>>>(
+ out->data<half>(), in->data<half>(), rows, cols_out, cols_out, cols_in, 0, offset);
+#endif
+ } else {
+ ABORT("Split1 not implemented for type {}", out->type());
+ }
+
offset += cols_out;
}
cudaStreamSynchronize(0);
@@ -204,7 +345,8 @@ void Split1(std::vector<Tensor>& outputs, const Tensor in) {
// @TODO: this function is just a temporary fix until I come up with
// something better for the situation below.
-__global__ void gAddRow(float* out, const float* in, int length) {
+template <typename T>
+__global__ void gAddRow(T* out, const T* in, int length) {
for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
if(index < length) {
@@ -236,8 +378,15 @@ void SplitCont(std::vector<Tensor>& outputs, const Tensor in, int axis) {
int threads = std::min(MAX_THREADS, size);
int blocks = std::min(MAX_BLOCKS, size / threads + (size % threads != 0));
- gAddRow<<<blocks, threads>>>(
- out->data() + offset2, in->data() + offset1, size);
+ if(out->type() == Type::float32) {
+ gAddRow<<<blocks, threads>>>(out->data<float>() + offset2, in->data<float>() + offset1, size);
+#if COMPILE_FP16
+ } else if(out->type() == Type::float16) {
+ gAddRow<<<blocks, threads>>>(out->data<half>() + offset2, in->data<half>() + offset1, size);
+#endif
+ } else {
+ ABORT("SplitCont not implemented for type {}", out->type());
+ }
offset1 += size;
}
}
@@ -251,10 +400,10 @@ void Deconcatenate(std::vector<Tensor>& outputs, const Tensor in, int ax) {
SplitCont(outputs, in, ax);
}
-template <bool add>
+template <bool add, typename T>
__global__ void gTransposeND(
- functional::Tensor<float> out,
- const functional::Tensor<float> in,
+ functional::Tensor<T> out,
+ const functional::Tensor<T> in,
const functional::Array<int, functional::Shape::size()> permute) {
constexpr size_t N = functional::Shape::size();
functional::Array<int, N> oDims;
@@ -267,17 +416,22 @@ __global__ void gTransposeND(
out.shape().dims(index, oDims);
for(int i = 0; i < N; ++i)
pDims[permute[i]] = oDims[i];
+
+ int inIndex = in.shape().index(pDims);
+
+ // TODO: operates on raw indices, change to
+ // converting Tensor::operator[]
if(add)
- out[index] += in[pDims];
+ out.data()[index] += in.data()[inIndex];
else
- out[index] = in[pDims];
+ out.data()[index] = in.data()[inIndex];
}
}
}
-template <bool add>
-__global__ void gTranspose0213(float* out,
- const float* in,
+template <bool add, typename T>
+__global__ void gTranspose0213(T* out,
+ const T* in,
int rows,
int cols,
int stride1,
@@ -286,14 +440,14 @@ __global__ void gTranspose0213(float* out,
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- float* rowOut = out + j * cols;
+ T* rowOut = out + j * cols;
int z = j / stride;
int y = (j % stride) / stride1;
int x = (j % stride) % stride1;
int j2 = z * stride + x * stride2 + y;
- const float* rowIn = in + j2 * cols;
+ const T* rowIn = in + j2 * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
@@ -320,8 +474,15 @@ void TransposeND(Tensor out, Tensor in, const std::vector<int>& vAxis) {
int stride1 = out->shape()[-2];
int stride2 = out->shape()[-3];
- gTranspose0213<false><<<blocks, threads>>>(
- out->data(), in->data(), rows, cols, stride1, stride2);
+ if(in->type() == Type::float32) {
+ gTranspose0213<false><<<blocks, threads>>>(out->data<float>(), in->data<float>(), rows, cols, stride1, stride2);
+#if COMPILE_FP16
+ } else if(in->type() == Type::float16) {
+ gTranspose0213<false><<<blocks, threads>>>(out->data<half>(), in->data<half>(), rows, cols, stride1, stride2);
+#endif
+ } else {
+ ABORT("Transpose for type {} not implemented", in->type());
+ }
} else {
functional::Array<int, functional::Shape::size()> axes;
int diff = functional::Shape::size() - vAxis.size();
@@ -336,10 +497,19 @@ void TransposeND(Tensor out, Tensor in, const std::vector<int>& vAxis) {
int blocks
= std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
- gTransposeND<false><<<blocks, threads>>>(out, in, axes);
+ if(in->type() == Type::float32) {
+ gTransposeND<false, float><<<blocks, threads>>>(out, in, axes);
+#if COMPILE_FP16
+ } else if(in->type() == Type::float16) {
+ gTransposeND<false, half><<<blocks, threads>>>(out, in, axes);
+#endif
+ } else {
+ ABORT("Transpose for type {} not implemented", in->type());
+ }
}
}
+//@TODO: code duplication?
void TransposeNDGrad(Tensor out, Tensor in, const std::vector<int>& vAxis) {
cudaSetDevice(out->getDeviceId().no);
if(vAxis == std::vector<int>({0, 2, 1, 3})) {
@@ -352,8 +522,15 @@ void TransposeNDGrad(Tensor out, Tensor in, const std::vector<int>& vAxis) {
int stride1 = out->shape()[-2];
int stride2 = out->shape()[-3];
- gTranspose0213<true><<<blocks, threads>>>(
- out->data(), in->data(), rows, cols, stride1, stride2);
+ if(in->type() == Type::float32) {
+ gTranspose0213<true><<<blocks, threads>>>(out->data<float>(), in->data<float>(), rows, cols, stride1, stride2);
+#if COMPILE_FP16
+ } else if(in->type() == Type::float16) {
+ gTranspose0213<true><<<blocks, threads>>>(out->data<half>(), in->data<half>(), rows, cols, stride1, stride2);
+#endif
+ } else {
+ ABORT("Transpose for type {} not implemented", in->type());
+ }
} else {
functional::Array<int, functional::Shape::size()> axes;
int diff = functional::Shape::size() - vAxis.size();
@@ -365,37 +542,64 @@ void TransposeNDGrad(Tensor out, Tensor in, const std::vector<int>& vAxis) {
int length = out->shape().elements();
int threads = std::min(MAX_THREADS, length);
- int blocks
- = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
-
- gTransposeND<true><<<blocks, threads>>>(out, in, axes);
+ int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
+
+ if(in->type() == Type::float32) {
+ gTransposeND<true, float><<<blocks, threads>>>(out, in, axes);
+#if COMPILE_FP16
+ } else if(in->type() == Type::float16) {
+ gTransposeND<true, half><<<blocks, threads>>>(out, in, axes);
+#endif
+ } else {
+ ABORT("Transpose for type {} not implemented", in->type());
+ }
}
}
-__global__ void gSoftmax(float* out,
+// Computes the softmax
+// in - input tensor
+// out - output tensor
+// we compute the softmax over the the cols (last dimension)
+// rows are time, batch or beam dimensions
+// number of threads is number of cols or MAX_THREADS
+// number of blocks is number of rows or MAX_BLOCKS
+// @TODO: handle half2
+template <typename T, typename AccType = float>
+__global__ void gSoftmax(T* out,
functional::Shape outShape,
- const float* in) {
+ const T* in) {
+ using namespace functional;
+
int rows = outShape.elements() / outShape.back();
int cols = outShape.back();
- for(int bid = 0; bid < rows; bid += gridDim.x) {
- int j = bid + blockIdx.x;
- if(j < rows) {
- float* so = out + j * cols;
- const float* sp = in + j * cols;
-
- extern __shared__ float _share[];
-
- float* _max = _share + blockDim.x;
+ for(int bid = 0; bid < rows; bid += gridDim.x) { // loop over blocks of rows
+ int j = bid + blockIdx.x; // blockIdx.x - row index (within block of rows)
+ if(j < rows) { // compute softmax over one row, row elements distributed over threads
+ T* so = out + j * cols; // pointer to row input data
+ const T* sp = in + j * cols;
+
+ // CUDA complains if type or size of shared memory changes, keep size constant.
+ extern __shared__ uint8_t _sharedBytes[];
+ T* _share = (T*)_sharedBytes;
+ AccType* _shareAccType = (AccType*)_sharedBytes;
+
+ // determine max (used below to improve numeric stability)
+ T* _max = _share;
+
+ // @TODO: what's going on here with fp16?
_max[threadIdx.x] = -CUDA_FLT_MAX; // mask
+ // find max over column indices that have the same relative column index (=threadIdx.x) across all blocks of columns
for(int tid = 0; tid < cols; tid += blockDim.x) {
- int id = tid + threadIdx.x;
- if(id < cols) {
- if(sp[id] > _max[threadIdx.x])
- _max[threadIdx.x] = sp[id];
+ // threadIdx.x = column index within block of columns; we reduce over columns within a block, then over blocks
+ int i = tid + threadIdx.x;
+ if(i < cols) {
+ if(sp[i] > _max[threadIdx.x])
+ _max[threadIdx.x] = sp[i];
}
}
__syncthreads();
+ // max over columns within a column block via tree reduction
int len = blockDim.x;
while(len != 1) {
__syncthreads();
@@ -408,21 +612,23 @@ __global__ void gSoftmax(float* out,
len = (len + 1) >> 1;
}
__syncthreads();
- float max = _max[0];
+ T max = _max[0];
__syncthreads();
- float* _sum = _share + blockDim.x;
-
+ // compute denominator
+ AccType* _sum = _shareAccType; // accumulate into AccType
_sum[threadIdx.x] = 0.0;
for(int tid = 0; tid < cols; tid += blockDim.x) {
- int id = tid + threadIdx.x;
- if(id < cols) {
- float ex = __expf(sp[id] - max);
- so[id] = ex;
- _sum[threadIdx.x] += ex;
+ int i = tid + threadIdx.x;
+ if(i < cols) {
+ // @TODO: is it faster to cache the result of expf() in GPU RAM, or would it be faster to recompute it below?
+ T ex = Ops<T>::exp(sp[i] - max);
+ so[i] = (T)ex;
+ _sum[threadIdx.x] += (AccType)ex; // accumulate into AccType
}
}
__syncthreads();
+ // now reduce over all columns within the block
len = blockDim.x;
while(len != 1) {
__syncthreads();
@@ -432,13 +638,17 @@ __global__ void gSoftmax(float* out,
len = (len + 1) >> 1;
}
__syncthreads();
+
+ // produce final output data
+ AccType sum = _sum[0];
for(int tid = 0; tid < cols; tid += blockDim.x) {
- int id = tid + threadIdx.x;
- if(id < cols) {
- so[id] = so[id] / _sum[0];
+ int i = tid + threadIdx.x;
+ if(i < cols) {
+ so[i] = (T)((AccType)so[i] / sum); // divide as AccType then convert
}
}
}
+ __syncthreads();
}
}
@@ -450,26 +660,42 @@ void Softmax(Tensor out, Tensor in) {
int blocks = std::min(MAX_BLOCKS, (int)m);
int threads = std::min(MAX_THREADS, (int)k);
- int shared = sizeof(float) * threads * 2;
-
- gSoftmax<<<blocks, threads, shared>>>(out->data(), out->shape(), in->data());
+ int shared = sizeof(float) * threads; // accumulate into float
+
+ if(in->type() == Type::float32) {
+ gSoftmax<float, float><<<blocks, threads, shared>>>(out->data<float>(), out->shape(), in->data<float>());
+#if COMPILE_FP16
+ } else if (in->type() == Type::float16) {
+ gSoftmax<half, float><<<blocks, threads, shared>>>(out->data<half>(), out->shape(), in->data<half>());
+#endif
+ } else {
+ ABORT("Softmax not implemented for type {}", in->type());
+ }
}
-__global__ void gLogSoftmax(float* out,
+// @TODO: refactor to reuse code from softmax, add comments
+template <typename T, typename AccType = float>
+__global__ void gLogSoftmax(T* out,
const functional::Shape outShape,
- const float* in) {
+ const T* in) {
+
+ using namespace functional;
+
int rows = outShape.elements() / outShape.back();
int cols = outShape.back();
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- float* so = out + j * cols;
- const float* sp = in + j * cols;
+ T* so = out + j * cols;
+ const T* sp = in + j * cols;
- extern __shared__ float _share[];
+ // CUDA complains if type or size of shared memory changes, keep size constant.
+ extern __shared__ uint8_t _sharedBytes[];
+ T* _share = (T*)_sharedBytes;
+ AccType* _shareAccType = (AccType*)_sharedBytes;
- float* _max = _share + blockDim.x;
+ T* _max = _share; // 16-bit is ok for max if applicable
_max[threadIdx.x] = sp[threadIdx.x];
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
@@ -491,19 +717,19 @@ __global__ void gLogSoftmax(float* out,
len = (len + 1) >> 1;
}
__syncthreads();
- float max = _max[0];
+ T max = _max[0];
__syncthreads();
- float* _sum = _share + blockDim.x;
+ AccType* _sum = _shareAccType; // keep AccType for accumulation
_sum[threadIdx.x] = 0.0;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
- float sm = sp[id] - max;
- float ex = __expf(sm);
+ T sm = sp[id] - max;
+ AccType ex = Ops<AccType>::exp(sm); // sum with AccType
so[id] = sm;
- _sum[threadIdx.x] += ex;
+ _sum[threadIdx.x] += ex; // sum with AccType
}
}
__syncthreads();
@@ -516,12 +742,14 @@ __global__ void gLogSoftmax(float* out,
len = (len + 1) >> 1;
}
__syncthreads();
+ AccType sum = _sum[0];
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols)
- so[id] -= __logf(_sum[0]);
+ so[id] -= (T)Ops<AccType>::log(sum); // take log at the end and convert
}
}
+ __syncthreads();
}
}
@@ -533,33 +761,42 @@ void LogSoftmax(Tensor out, Tensor in) {
int blocks = std::min(MAX_BLOCKS, (int)m);
int threads = std::min(MAX_THREADS, (int)k);
- int shared = sizeof(float) * threads * 2;
-
- gLogSoftmax<<<blocks, threads, shared>>>(
- out->data(), out->shape(), in->data());
+ int shared = sizeof(float) * threads; // use float32 as accumulation type
+
+ if(in->type() == Type::float32) {
+ gLogSoftmax<float, float><<<blocks, threads, shared>>>(out->data<float>(), out->shape(), in->data<float>());
+#if COMPILE_FP16
+ } else if (in->type() == Type::float16) {
+ gLogSoftmax<half, float><<<blocks, threads, shared>>>(out->data<half>(), out->shape(), in->data<half>());
+#endif
+ } else {
+ ABORT("LogSoftmax not implemented for type {}", in->type());
+ }
}
///////////////////////////////////////////////////////
-__global__ void gSoftmaxGrad(float* grad,
- const float* adj,
- const float* val,
+template <typename T, typename AccType = float>
+__global__ void gSoftmaxGrad(T* grad,
+ const T* adj,
+ const T* val,
const int rows,
const int cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- extern __shared__ float _share[];
- float* _sum = _share + blockDim.x;
- float* gradRow = grad + j * cols;
- const float* adjRow = adj + j * cols;
- const float* valRow = val + j * cols;
- _sum[threadIdx.x] = 0.0;
+ extern __shared__ uint8_t _sharedBytes[];
+ AccType* _sum = (AccType*)_sharedBytes;
+
+ T* gradRow = grad + j * cols;
+ const T* adjRow = adj + j * cols;
+ const T* valRow = val + j * cols;
+ _sum[threadIdx.x] = (AccType)0.0f;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
- _sum[threadIdx.x] += valRow[id] * adjRow[id];
+ _sum[threadIdx.x] += (AccType)valRow[id] * (AccType)adjRow[id];
}
}
__syncthreads();
@@ -568,22 +805,24 @@ __global__ void gSoftmaxGrad(float* grad,
__syncthreads();
int skip = (len + 1) >> 1;
if(threadIdx.x < (len >> 1))
- _sum[threadIdx.x] += _sum[threadIdx.x + skip];
+ _sum[threadIdx.x] += _sum[threadIdx.x + skip]; // accumulates in AccType
len = (len + 1) >> 1;
}
__syncthreads();
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
- float val = valRow[id] * (adjRow[id] - _sum[0]);
+ AccType val = (AccType)valRow[id] * ((AccType)adjRow[id] - _sum[0]);
if(val)
- gradRow[id] += val;
+ gradRow[id] += (T)val;
}
}
}
+ __syncthreads();
}
}
+// @TODO: refactor with logsoftmax, add math
void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
cudaSetDevice(adj->getDeviceId().no);
// grad and val are both m-by-k matrices, passed as input.
@@ -595,30 +834,42 @@ void SoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
int blocks = std::min(MAX_BLOCKS, m);
int threads = std::min(MAX_THREADS, k);
- int shared = sizeof(float) * threads * 2;
- gSoftmaxGrad<<<blocks, threads, shared>>>(
- grad->data(), adj->data(), val->data(), m, k);
+ int shared = sizeof(float) * threads;
+
+ if(grad->type() == Type::float32) {
+ gSoftmaxGrad<float, float><<<blocks, threads, shared>>>(
+ grad->data<float>(), adj->data<float>(), val->data<float>(), m, k);
+#if COMPILE_FP16
+ } else if (grad->type() == Type::float16) {
+ // Accumulate into float
+ gSoftmaxGrad<half, float><<<blocks, threads, shared>>>(
+ grad->data<half>(), adj->data<half>(), val->data<half>(), m, k);
+#endif
+ } else {
+ ABORT("SoftmaxGrad not implemented for type {}", grad->type());
+ }
}
-__global__ void gLogSoftmaxGrad(float* grad,
- const float* adj,
- const float* val,
+template <typename T, typename AccType = float>
+__global__ void gLogSoftmaxGrad(T* grad,
+ const T* adj,
+ const T* val,
const int rows,
const int cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- extern __shared__ float _share[];
- float* _sum = _share + blockDim.x;
+ extern __shared__ uint8_t _sharedBytes[];
+ AccType* _sum = (AccType*)_sharedBytes;
- float* gradRow = grad + j * cols;
- const float* adjRow = adj + j * cols;
- const float* valRow = val + j * cols;
+ T* gradRow = grad + j * cols;
+ const T* adjRow = adj + j * cols;
+ const T* valRow = val + j * cols;
_sum[threadIdx.x] = 0.0;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
- _sum[threadIdx.x] += adjRow[id];
+ _sum[threadIdx.x] += (AccType)adjRow[id];
}
}
__syncthreads();
@@ -627,16 +878,17 @@ __global__ void gLogSoftmaxGrad(float* grad,
__syncthreads();
int skip = (len + 1) >> 1;
if(threadIdx.x < (len >> 1))
- _sum[threadIdx.x] += _sum[threadIdx.x + skip];
+ _sum[threadIdx.x] += _sum[threadIdx.x + skip]; // AccType
len = (len + 1) >> 1;
}
__syncthreads();
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols)
- gradRow[id] += adjRow[id] - (expf(valRow[id]) * _sum[0]);
+ gradRow[id] += (T)((AccType)adjRow[id] - (functional::Ops<AccType>::exp((AccType)valRow[id]) * _sum[0]));
}
}
+ __syncthreads();
}
}
@@ -652,35 +904,27 @@ void LogSoftmaxGrad(Tensor grad, Tensor adj, Tensor val) {
int blocks = std::min(MAX_BLOCKS, m);
int threads = std::min(MAX_THREADS, k);
- int shared = sizeof(float) * threads * 2;
- gLogSoftmaxGrad<<<blocks, threads, shared>>>(
- grad->data(), adj->data(), val->data(), m, k);
-}
-
-///////////////////////////////////////////////////////
-__global__ void gArgmax(float* out,
- const float* data,
- size_t rows,
- size_t cols) {
- size_t row = blockIdx.x;
- size_t startInd = row * cols;
- float maxScore = -99999;
- size_t maxInd;
- for(size_t col = 0; col < cols; ++col) {
- size_t ind = startInd + col;
- float score = data[ind];
- if(score > maxScore) {
- maxScore = score;
- maxInd = col;
- }
+ int shared = sizeof(float) * threads; // Use float32 as accumulation type
+
+ if(grad->type() == Type::float32) {
+ gLogSoftmaxGrad<float, float><<<blocks, threads, shared>>>(
+ grad->data<float>(), adj->data<float>(), val->data<float>(), m, k);
+#if COMPILE_FP16
+ } else if (grad->type() == Type::float16) {
+ // accumulate into float
+ gLogSoftmaxGrad<half, float><<<blocks, threads, shared>>>(
+ grad->data<half>(), adj->data<half>(), val->data<half>(), m, k);
+#endif
+ } else {
+ ABORT("LogSoftmaxGrad not implemented for type {}", grad->type());
}
- out[row] = maxInd;
}
///////////////////////////////////////////////////////
-__global__ void gCopyRows(float* out,
- const float* in,
+template <typename T>
+__global__ void gCopyRows(T* out,
+ const T* in,
size_t cols,
const IndexType* sourceRowIdx,
size_t rows) {
@@ -690,8 +934,8 @@ __global__ void gCopyRows(float* out,
size_t dstId = j;
size_t srcId = sourceRowIdx[j];
- float* rowOut = out + dstId * cols;
- const float* rowIn = in + srcId * cols;
+ T* rowOut = out + dstId * cols;
+ const T* rowIn = in + srcId * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
@@ -716,28 +960,44 @@ void CopyRows(Tensor out,
int threads = std::min(MAX_THREADS, (int)cols);
int blocks = std::min(MAX_BLOCKS, (int)rowsToCopy);
- gCopyRows<<<blocks, threads>>>(
- out->data(), in->data(), cols, indices->data<IndexType>(), rowsToCopy);
+ if(out->type() == Type::float32) {
+ gCopyRows<<<blocks, threads>>>(
+ out->data<float>(), in->data<float>(), cols, indices->data<IndexType>(), rowsToCopy);
+#if COMPILE_FP16
+ } else if (out->type() == Type::float16) {
+ gCopyRows<<<blocks, threads>>>(
+ out->data<half>(), in->data<half>(), cols, indices->data<IndexType>(), rowsToCopy);
+#endif
+ } else {
+ ABORT("CopyRows not implemented for type {}", out->type());
+ }
}
-__global__ void gPasteRows(float* out,
- const float* in,
+template <typename T>
+__global__ void gPasteRows(T* out,
+ const T* in,
size_t cols,
const IndexType* targetRowIdx,
size_t rows) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
- int j = bid + blockIdx.x;
+ int j = bid + blockIdx.x; // index into 'indices' vector
if(j < rows) {
size_t dstId = targetRowIdx[j];
size_t srcId = j;
- float* rowOut = out + dstId * cols;
- const float* rowIn = in + srcId * cols;
+ T* rowOut = out + dstId * cols;
+ const T* rowIn = in + srcId * cols;
+ // aggregate the entire row
for(int tid = 0; tid < cols; tid += blockDim.x) {
- int i = tid + threadIdx.x;
- if(i < cols)
- atomicAdd(rowOut + i, rowIn[i]);
+ int i = tid + threadIdx.x; // column index --@TODO: column index should be called 'j'
+ if(i < cols) {
+ // Note: atomicAdd() not needed if number of blocks is 1. Avoid it because it is slow for fp16.
+ if (gridDim.x == 1)
+ rowOut[i] += rowIn[i];
+ else
+ atomics::atomicAdd(rowOut + i, rowIn[i]);
+ }
}
}
}
@@ -755,16 +1015,34 @@ void PasteRows(Tensor out,
size_t rowsToCopy = indices->size();
int threads = std::min(MAX_THREADS, (int)cols);
+#if 1 // @TODO: make this configurable with a 'deterministic' flag
+ // If we only use one block, then each core operates on a different column,
+ // hence the summation becomes deterministic.
+ // However, we only use e.g. 512 cores out of possibly 3000+, so this will be
+ // 6 x slower in this example.
+ int blocks = 1;
+#else
int blocks = std::min(MAX_BLOCKS, (int)rowsToCopy);
-
- gPasteRows<<<blocks, threads>>>(
- out->data(), in->data(), cols, indices->data<IndexType>(), rowsToCopy);
+#endif
+
+ if(out->type() == Type::float32) {
+ gPasteRows<<<blocks, threads>>>(
+ out->data<float>(), in->data<float>(), cols, indices->data<IndexType>(), rowsToCopy);
+#if COMPILE_FP16
+ } else if (out->type() == Type::float16) {
+ gPasteRows<<<blocks, threads>>>(
+ out->data<half>(), in->data<half>(), cols, indices->data<IndexType>(), rowsToCopy);
+#endif
+ } else {
+ ABORT("CopyRows not implemented for type {}", out->type());
+ }
}
/////////////
-__global__ void gCopyCols(float* out,
- const float* in,
+template <typename T>
+__global__ void gCopyCols(T* out,
+ const T* in,
size_t rows,
size_t colsIn,
const IndexType* sourceColIdx,
@@ -772,8 +1050,8 @@ __global__ void gCopyCols(float* out,
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- const float* rowIn = in + j * colsIn;
- float* rowOut = out + j * colsOut;
+ const T* rowIn = in + j * colsIn;
+ T* rowOut = out + j * colsOut;
for(int tid = 0; tid < colsOut; tid += blockDim.x) {
int i = tid + threadIdx.x;
@@ -797,12 +1075,22 @@ void CopyCols(Tensor out, const Tensor in, const Tensor indices) {
int threads = std::min(MAX_THREADS, (int)colsToCopy);
int blocks = std::min(MAX_BLOCKS, (int)rows);
- gCopyCols<<<blocks, threads>>>(
- out->data(), in->data(), rows, cols, indices->data<IndexType>(), colsToCopy);
+ if(out->type() == Type::float32) {
+ gCopyCols<<<blocks, threads>>>(
+ out->data<float>(), in->data<float>(), rows, cols, indices->data<IndexType>(), colsToCopy);
+#if COMPILE_FP16
+ } else if (out->type() == Type::float16) {
+ gCopyCols<<<blocks, threads>>>(
+ out->data<half>(), in->data<half>(), rows, cols, indices->data<IndexType>(), colsToCopy);
+#endif
+ } else {
+ ABORT("CopyCols not implemented for type {}", out->type());
+ }
}
-__global__ void gPasteCols(float* out,
- const float* in,
+template <typename T>
+__global__ void gPasteCols(T* out,
+ const T* in,
size_t rows,
size_t colsOut,
const IndexType* targetColIdx,
@@ -810,13 +1098,13 @@ __global__ void gPasteCols(float* out,
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- const float* rowIn = in + j * colsIn;
- float* rowOut = out + j * colsOut;
+ const T* rowIn = in + j * colsIn;
+ T* rowOut = out + j * colsOut;
for(int tid = 0; tid < colsIn; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < colsIn)
- rowOut[targetColIdx[i]] += rowIn[i];
+ rowOut[targetColIdx[i]] += rowIn[i]; // @TODO: atomicAdd?
}
}
}
@@ -837,16 +1125,27 @@ void PasteCols(Tensor out,
int threads = std::min(MAX_THREADS, (int)colsToCopy);
int blocks = std::min(MAX_BLOCKS, (int)rows);
- gPasteCols<<<blocks, threads>>>(
- out->data(), in->data(), rows, cols, indices->data<IndexType>(), colsToCopy);
+ if(out->type() == Type::float32) {
+ gPasteCols<<<blocks, threads>>>(
+ out->data<float>(), in->data<float>(), rows, cols, indices->data<IndexType>(), colsToCopy);
+#if COMPILE_FP16
+ } else if (out->type() == Type::float16) {
+ gPasteCols<<<blocks, threads>>>(
+ out->data<half>(), in->data<half>(), rows, cols, indices->data<IndexType>(), colsToCopy);
+#endif
+ } else {
+ ABORT("PasteCols not implemented for type {}", out->type());
+ }
}
-__global__ void gSelect(float* out,
+template <typename T>
+__global__ void gSelect(T* out,
functional::Shape outShape,
- const float* in,
+ const T* in,
const functional::Shape inShape,
int axis,
- IndexType* d_indices) {
+ const IndexType* d_indices,
+ const functional::Shape idxShape) {
int length = outShape.elements();
functional::Array<int, functional::Shape::size()> dims;
@@ -854,19 +1153,22 @@ __global__ void gSelect(float* out,
int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
if(index < length) {
outShape.dims(index, dims);
- dims[axis] = d_indices[dims[axis]];
+ int idxIndex = idxShape.bindex(dims); // broadcast index into indices tensor
+ dims[axis] = (int)d_indices[idxIndex];
int inIndex = inShape.index(dims);
out[index] = in[inIndex];
}
}
}
-__global__ void gInsert(float* out,
+template <typename T>
+__global__ void gInsert(T* out,
functional::Shape outShape,
- const float* in,
+ const T* in,
const functional::Shape inShape,
int axis,
- IndexType* d_indices) {
+ const IndexType* d_indices,
+ const functional::Shape idxShape) {
int length = inShape.elements();
functional::Array<int, functional::Shape::size()> dims;
@@ -874,9 +1176,10 @@ __global__ void gInsert(float* out,
int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
if(index < length) {
inShape.dims(index, dims);
- dims[axis] = d_indices[dims[axis]];
+ int idxIndex = idxShape.bindex(dims); // broadcast index into indices tensor
+ dims[axis] = (int)d_indices[idxIndex];
int outIndex = outShape.index(dims);
- out[outIndex] += in[index];
+ out[outIndex] += in[index]; // this is probably wrong, atomicAdd?
}
}
}
@@ -895,12 +1198,28 @@ void Select(Tensor out,
int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
int axisGPU = axis + functional::Shape::size() - out->shape().size();
- gSelect<<<blocks, threads>>>(out->data(),
- out->shape(),
- in->data(),
- in->shape(),
- axisGPU,
- indices->data<IndexType>());
+
+ if(out->type() == Type::float32) {
+ gSelect<<<blocks, threads>>>(out->data<float>(),
+ out->shape(),
+ in->data<float>(),
+ in->shape(),
+ axisGPU,
+ indices->data<IndexType>(),
+ indices->shape());
+#if COMPILE_FP16
+ } else if (out->type() == Type::float16) {
+ gSelect<<<blocks, threads>>>(out->data<half>(),
+ out->shape(),
+ in->data<half>(),
+ in->shape(),
+ axisGPU,
+ indices->data<IndexType>(),
+ indices->shape());
+#endif
+ } else {
+ ABORT("Select not implemented for type {}", out->type());
+ }
}
void Insert(Tensor out,
@@ -916,51 +1235,68 @@ void Insert(Tensor out,
int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
int axisGPU = axis + functional::Shape::size() - out->shape().size();
- gInsert<<<blocks, threads>>>(out->data(),
- out->shape(),
- in->data(),
- in->shape(),
- axisGPU,
- indices->data<IndexType>());
-}
-
-__global__ void gGRUFastForward(float* out,
- const float* state,
- const float* xW,
- const float* sU,
- const float* b,
- const float* mask,
+
+ if(out->type() == Type::float32) {
+ gInsert<<<blocks, threads>>>(out->data<float>(),
+ out->shape(),
+ in->data<float>(),
+ in->shape(),
+ axisGPU,
+ indices->data<IndexType>(),
+ indices->shape());
+#if COMPILE_FP16
+ } else if (out->type() == Type::float16) {
+ gInsert<<<blocks, threads>>>(out->data<half>(),
+ out->shape(),
+ in->data<half>(),
+ in->shape(),
+ axisGPU,
+ indices->data<IndexType>(),
+ indices->shape());
+#endif
+ } else {
+ ABORT("Insert not implemented for type {}", out->type());
+ }
+}
+
+template <typename T>
+__global__ void gGRUFastForward(T* out,
+ const T* state,
+ const T* xW,
+ const T* sU,
+ const T* b,
+ const T* mask,
size_t rows,
size_t cols,
bool final) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- float m = !mask || mask[j];
- float* rowOut = out + j * cols;
- const float* rowState = state + j * cols;
+ T m = !mask || mask[j];
+ T* rowOut = out + j * cols;
+ const T* rowState = state + j * cols;
- const float* xWrow = xW + j * cols * 3;
- const float* sUrow = sU + j * cols * 3;
+ const T* xWrow = xW + j * cols * 3;
+ const T* sUrow = sU + j * cols * 3;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols) {
- float r = stableSigmoid(xWrow[i] + sUrow[i] + b[i]);
+ T r = functional::Ops<T>::sigmoid(xWrow[i] + sUrow[i] + b[i]);
int k = i + cols;
- float z = stableSigmoid(xWrow[k] + sUrow[k] + b[k]);
+ T z = functional::Ops<T>::sigmoid(xWrow[k] + sUrow[k] + b[k]);
int l = i + 2 * cols;
- float h;
+ T h;
if(final)
- h = tanhf(xWrow[l] + (sUrow[l] + b[l]) * r);
+ h = functional::Ops<T>::tanh(xWrow[l] + (sUrow[l] + b[l]) * r);
else
- h = tanhf(xWrow[l] + sUrow[l] * r + b[l]);
+ h = functional::Ops<T>::tanh(xWrow[l] + sUrow[l] * r + b[l]);
- float out = (1.0f - z) * h + z * rowState[i];
- rowOut[i] = m * out + (1 - m) * rowState[i];
+ T out = ((T)1.f - z) * h + z * rowState[i];
+ rowOut[i] = m * out + ((T)1.f - m) * rowState[i];
}
}
}
@@ -976,44 +1312,62 @@ void GRUFastForward(Tensor out, std::vector<Tensor> inputs, bool final) {
int blocks = std::min(MAX_BLOCKS, rows);
int threads = std::min(MAX_THREADS, cols);
- gGRUFastForward<<<blocks, threads>>>(
- out->data(), // output
- inputs[0]->data(), // state
- inputs[1]->data(), // xW
- inputs[2]->data(), // sU
- inputs[3]->data(), // b
- inputs.size() > 4 ? inputs[4]->data() : 0, // mask
- rows,
- cols,
- final);
-}
-
-__global__ void gGRUFastBackward(float* outState,
- float* outXW,
- float* outSU,
- float* outB,
- const float* state,
- const float* xW,
- const float* sU,
- const float* b,
- const float* mask,
- const float* adj,
+ if(out->type() == Type::float32) {
+ gGRUFastForward<<<blocks, threads>>>(
+ out->data<float>(), // output
+ inputs[0]->data<float>(), // state
+ inputs[1]->data<float>(), // xW
+ inputs[2]->data<float>(), // sU
+ inputs[3]->data<float>(), // b
+ inputs.size() > 4 ? inputs[4]->data<float>() : 0, // mask
+ rows,
+ cols,
+ final);
+#if COMPILE_FP16
+ } else if (out->type() == Type::float16) {
+ gGRUFastForward<<<blocks, threads>>>(
+ out->data<half>(), // output
+ inputs[0]->data<half>(), // state
+ inputs[1]->data<half>(), // xW
+ inputs[2]->data<half>(), // sU
+ inputs[3]->data<half>(), // b
+ inputs.size() > 4 ? inputs[4]->data<half>() : 0, // mask
+ rows,
+ cols,
+ final);
+#endif
+ } else {
+ ABORT("GRUFastForward not implemented for type {}", out->type());
+ }
+}
+
+template <typename T>
+__global__ void gGRUFastBackward(T* outState,
+ T* outXW,
+ T* outSU,
+ T* outB,
+ const T* state,
+ const T* xW,
+ const T* sU,
+ const T* b,
+ const T* mask,
+ const T* adj,
size_t rows,
size_t cols,
bool final) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- float m = !mask || mask[j];
+ T m = !mask || mask[j];
- float* rowOutState = outState + j * cols;
- float* rowOutXW = outXW + j * cols * 3;
- float* rowOutSU = outSU + j * cols * 3;
+ T* rowOutState = outState + j * cols;
+ T* rowOutXW = outXW + j * cols * 3;
+ T* rowOutSU = outSU + j * cols * 3;
- const float* rowState = state + j * cols;
- const float* rowXW = xW + j * cols * 3;
- const float* rowSU = sU + j * cols * 3;
- const float* rowAdj = adj + j * cols;
+ const T* rowState = state + j * cols;
+ const T* rowXW = xW + j * cols * 3;
+ const T* rowSU = sU + j * cols * 3;
+ const T* rowAdj = adj + j * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
@@ -1021,25 +1375,25 @@ __global__ void gGRUFastBackward(float* outState,
int k = i + cols;
int l = i + 2 * cols;
- float r = stableSigmoid(rowXW[i] + rowSU[i] + b[i]);
- float z = stableSigmoid(rowXW[k] + rowSU[k] + b[k]);
+ T r = functional::Ops<T>::sigmoid(rowXW[i] + rowSU[i] + b[i]);
+ T z = functional::Ops<T>::sigmoid(rowXW[k] + rowSU[k] + b[k]);
- float h;
+ T h;
if(final)
- h = tanhf(rowXW[l] + (rowSU[l] + b[l]) * r);
+ h = functional::Ops<T>::tanh(rowXW[l] + (rowSU[l] + b[l]) * r);
else
- h = tanhf(rowXW[l] + rowSU[l] * r + b[l]);
+ h = functional::Ops<T>::tanh(rowXW[l] + rowSU[l] * r + b[l]);
- float adj = rowAdj[i];
+ T adj = rowAdj[i];
- float t = (1 - z) * (1 - h * h);
+ T t = ((T)1.f - z) * ((T)1.f - h * h);
// df/ds
if(outState)
- rowOutState[i] += (m * z - m + 1) * adj;
+ rowOutState[i] += (m * z - m + (T)1.f) * adj;
// df/d(xW_r) ...
- float dfdxW_r = m * r * (1 - r) * t * adj;
+ T dfdxW_r = m * r * ((T)1.f - r) * t * adj;
if(final)
dfdxW_r *= rowSU[l] + b[l];
else
@@ -1049,28 +1403,28 @@ __global__ void gGRUFastBackward(float* outState,
if(outSU)
rowOutSU[i] += dfdxW_r;
if(outB)
- atomicAdd(outB + i, dfdxW_r);
+ atomics::atomicAdd(outB + i, dfdxW_r); // @TODO: get rid of atomicAdd everywhere
// df/d(xW_z) ...
- float dfdxW_z = m * (1 - z) * z * (rowState[i] - h) * adj;
+ T dfdxW_z = m * ((T)1.f - z) * z * (rowState[i] - h) * adj;
if(outXW)
rowOutXW[k] += dfdxW_z;
if(outSU)
rowOutSU[k] += dfdxW_z;
if(outB)
- atomicAdd(outB + k, dfdxW_z);
+ atomics::atomicAdd(outB + k, dfdxW_z);
// df/d(xW_x) ...
- float dfdxW_x = m * t * adj;
+ T dfdxW_x = m * t * adj;
if(outXW)
rowOutXW[l] += dfdxW_x;
if(outSU)
rowOutSU[l] += dfdxW_x * r;
if(outB)
if(final)
- atomicAdd(outB + l, dfdxW_x * r);
+ atomics::atomicAdd(outB + l, dfdxW_x * r);
else
- atomicAdd(outB + l, dfdxW_x);
+ atomics::atomicAdd(outB + l, dfdxW_x);
}
}
}
@@ -1089,38 +1443,60 @@ void GRUFastBackward(std::vector<Tensor> outputs,
int blocks = std::min(MAX_BLOCKS, rows);
int threads = std::min(MAX_THREADS, cols);
- gGRUFastBackward<<<blocks, threads>>>(
- outputs[0] ? outputs[0]->data() : 0, // state - adj
- outputs[1] ? outputs[1]->data() : 0, // xW - adj
- outputs[2] ? outputs[2]->data() : 0, // sU - adj
- outputs[3] ? outputs[3]->data() : 0, // b - adj
- inputs[0]->data(), // state
- inputs[1]->data(), // xW
- inputs[2]->data(), // sU
- inputs[3]->data(), // b
- inputs.size() > 4 ? inputs[4]->data() : 0, // mask
- adj->data(),
- rows,
- cols,
- final);
+ if(adj->type() == Type::float32) {
+ gGRUFastBackward<<<blocks, threads>>>(
+ outputs[0] ? outputs[0]->data<float>() : 0, // state - adj
+ outputs[1] ? outputs[1]->data<float>() : 0, // xW - adj
+ outputs[2] ? outputs[2]->data<float>() : 0, // sU - adj
+ outputs[3] ? outputs[3]->data<float>() : 0, // b - adj
+ inputs[0]->data<float>(), // state
+ inputs[1]->data<float>(), // xW
+ inputs[2]->data<float>(), // sU
+ inputs[3]->data<float>(), // b
+ inputs.size() > 4 ? inputs[4]->data<float>() : 0, // mask
+ adj->data<float>(),
+ rows,
+ cols,
+ final);
+#if COMPILE_FP16
+ } else if (adj->type() == Type::float16) {
+ gGRUFastBackward<<<blocks, threads>>>(
+ outputs[0] ? outputs[0]->data<half>() : 0, // state - adj
+ outputs[1] ? outputs[1]->data<half>() : 0, // xW - adj
+ outputs[2] ? outputs[2]->data<half>() : 0, // sU - adj
+ outputs[3] ? outputs[3]->data<half>() : 0, // b - adj
+ inputs[0]->data<half>(), // state
+ inputs[1]->data<half>(), // xW
+ inputs[2]->data<half>(), // sU
+ inputs[3]->data<half>(), // b
+ inputs.size() > 4 ? inputs[4]->data<half>() : 0, // mask
+ adj->data<half>(),
+ rows,
+ cols,
+ final);
+#endif
+ } else {
+ ABORT("gGRUFastBackward not implemented for type {}", adj->type());
+ }
}
-__global__ void gCrossEntropyPick(float* out,
+template <typename T, typename AccType = float>
+__global__ void gCrossEntropyPick(T* out,
const functional::Shape outShape,
- const float* in,
+ const T* in,
const functional::Shape inShape,
const IndexType* pick) {
int rows = inShape.elements() / inShape.back();
int cols = inShape.back();
+ extern __shared__ uint8_t _sharedBytes[];
+
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- const float* sp = in + j * cols;
-
- extern __shared__ float _share[];
- float* _max = _share + blockDim.x;
+ const T* sp = in + j * cols;
+ T* _max = (T*)_sharedBytes;
_max[threadIdx.x] = sp[threadIdx.x];
for(int tid = 1; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
@@ -1142,15 +1518,16 @@ __global__ void gCrossEntropyPick(float* out,
len = (len + 1) >> 1;
}
__syncthreads();
- float max = _max[0];
+ T max = _max[0];
__syncthreads();
- float* _sum = _share + blockDim.x;
- _sum[threadIdx.x] = 0.0;
+ AccType* _sum = (AccType*)_sharedBytes;
+ _sum[threadIdx.x] = (AccType)0.0f;
+
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
- _sum[threadIdx.x] += __expf(sp[id] - max);
+ _sum[threadIdx.x] += functional::Ops<AccType>::exp(sp[id] - max);
}
}
__syncthreads();
@@ -1165,20 +1542,21 @@ __global__ void gCrossEntropyPick(float* out,
__syncthreads();
// cross-entropy
+ auto sum = _sum[0];
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
- if(id == (int)pick[j]) {
- out[j] = __logf(_sum[0]) - sp[id] + max;
- }
+ if(id == (int)pick[j])
+ out[j] = (T)functional::Ops<AccType>::log(sum) - sp[id] + max;
}
}
+ __syncthreads();
}
}
// In each j-th row, take the corresponding j-th label index i from indices and compute:
-// For each vocabulary item v, the only non-zero element in a row in the sum is the item
-// that matches the label indexed by i (the picked element).
-// C = sum_{v in V}(-logsoftmax(A) * delta(v, i) = -logsoftmax(A)[i]
+// For each vocabulary item v, the only non-zero element in a row in the sum is the item
+// that matches the label indexed by i (the picked element).
+// C = sum_{v in V}(-logsoftmax(A) * delta(v, i) = -logsoftmax(A)[i]
void CrossEntropyPick(Tensor out, Tensor in, Tensor indices) {
matchOrAbort<IndexType>(indices->type());
@@ -1189,28 +1567,38 @@ void CrossEntropyPick(Tensor out, Tensor in, Tensor indices) {
int blocks = std::min(MAX_BLOCKS, (int)rows);
int threads = std::min(MAX_THREADS, (int)cols);
- int shared = sizeof(float) * threads * 2;
-
- gCrossEntropyPick<<<blocks, threads, shared>>>(
- out->data(), out->shape(), in->data(), in->shape(), indices->data<IndexType>());
+ int shared = sizeof(float) * threads; // Use float32 as accumulation type
+
+ if(out->type() == Type::float32) {
+ gCrossEntropyPick<float, float><<<blocks, threads, shared>>>(
+ out->data<float>(), out->shape(), in->data<float>(), in->shape(), indices->data<IndexType>());
+#if COMPILE_FP16
+ } else if(out->type() == Type::float16) {
+ gCrossEntropyPick<half, float><<<blocks, threads, shared>>>(
+ out->data<half>(), out->shape(), in->data<half>(), in->shape(), indices->data<IndexType>());
+#endif
+ } else {
+ ABORT("CrossEntropyPick not implemented for type {}", out->type());
+ }
}
-__global__ void gCrossEntropyPickBackward(float* out,
+template <typename T, typename AccType = float>
+__global__ void gCrossEntropyPickBackward(T* out,
const functional::Shape outShape,
- const float* adj,
- const float* in,
+ const T* adj,
+ const T* in,
const IndexType* pick) {
int rows = outShape.elements() / outShape.back();
int cols = outShape.back();
+
+ extern __shared__ uint8_t _sharedBytes[];
+
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- const float* sp = in + j * cols;
- float* so = out + j * cols;
-
- extern __shared__ float _share[];
- float* _max = _share + blockDim.x;
-
+ const T* sp = in + j * cols;
+ T* so = out + j * cols;
+ T* _max = (T*)_sharedBytes;
_max[threadIdx.x] = sp[threadIdx.x];
for(int tid = 1; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
@@ -1232,15 +1620,15 @@ __global__ void gCrossEntropyPickBackward(float* out,
len = (len + 1) >> 1;
}
__syncthreads();
- float max = _max[0];
+ T max = _max[0];
__syncthreads();
- float* _sum = _share + blockDim.x;
+ AccType* _sum = (AccType*)_sharedBytes;
_sum[threadIdx.x] = 0.0;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
- float ex = __expf(sp[id] - max);
+ AccType ex = functional::Ops<AccType>::exp(sp[id] - max);
_sum[threadIdx.x] += ex;
}
}
@@ -1259,11 +1647,13 @@ __global__ void gCrossEntropyPickBackward(float* out,
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
- float sub = (float)(id == (int)pick[j]);
- so[id] += adj[j] * (__expf(sp[id] - max) / _sum[0] - sub);
+ AccType sub = (AccType)(id == (int)pick[j]);
+ auto softmax = functional::Ops<AccType>::exp(sp[id] - max) / _sum[0];
+ so[id] += (AccType)adj[j] * (softmax - sub);
}
}
}
+ __syncthreads();
}
}
@@ -1277,37 +1667,49 @@ void CrossEntropyPickBackward(Tensor out, Tensor adj, Tensor a, Tensor indices)
int blocks = std::min(MAX_BLOCKS, (int)rows);
int threads = std::min(MAX_THREADS, (int)cols);
- int shared = sizeof(float) * threads * 2;
-
- gCrossEntropyPickBackward<<<blocks, threads, shared>>>(
- out->data(), out->shape(), adj->data(), a->data(), indices->data<IndexType>());
+ int shared = sizeof(float) * threads; // use float as accumulation type
+
+ if(out->type() == Type::float32) {
+ gCrossEntropyPickBackward<float, float><<<blocks, threads, shared>>>(
+ out->data<float>(), out->shape(), adj->data<float>(), a->data<float>(), indices->data<IndexType>());
+#if COMPILE_FP16
+ } else if(out->type() == Type::float16) {
+ gCrossEntropyPickBackward<half, float><<<blocks, threads, shared>>>(
+ out->data<half>(), out->shape(), adj->data<half>(), a->data<half>(), indices->data<IndexType>());
+#endif
+ } else {
+ ABORT("CrossEntropyPick not implemented for type {}", out->type());
+ }
}
-float L2Norm(Tensor in) {
+// computes the L2Norm of tensor and returns value as flaot on the CPU,
+// this is mostly used for diagnostic purposes and gradient clipping
+float L2Norm(Tensor in, Ptr<Allocator> allocator) { // @TODO: reverse order of arguments
cudaSetDevice(in->getDeviceId().no);
int size = in->shape().elements();
int threads = std::min(MAX_THREADS, size);
- int blocks = std::min(MAX_BLOCKS, size / threads + (size % threads != 0));
-
- uint8_t* data;
- cudaMalloc(&data, blocks * sizeof(float));
- Tensor out(new TensorBase(New<MemoryPiece>(data, blocks * sizeof(float)),
- {1, blocks},
- in->getBackend()));
+ int blocks = std::min(MAX_BLOCKS, size / threads + (size % threads != 0));
using namespace functional;
- ReduceAll(_1 * _1, out, in);
- float dataCpu = sqrtf(out->get(0));
- out.reset();
- cudaFree(data);
- return dataCpu;
+ float l2Norm;
+ if(in->type() == Type::float32) {
+ l2Norm = std::sqrt(AggregateAllAndReturn</*ElementType=*/float, /*AccType=*/float>(allocator, /*functor=*/_1 * _1, /*aggInit=*/0.f, /*aggFunctor=*/_1 + _2, /*scale=*/1.f, in));
+#if COMPILE_FP16
+ } else if(in->type() == Type::float16) {
+ l2Norm = std::sqrt(AggregateAllAndReturn</*ElementType=*/half, /*AccType=*/float>(allocator, /*functor=*/_1 * _1, /*aggInit=*/0.f, /*aggFunctor=*/_1 + _2, /*scale=*/1.f, in));
+#endif
+ } else {
+ ABORT("L2Norm not implemented for type {}", in->type());
+ }
+ return l2Norm;
}
-__global__ void gAtt(float* out,
- const float* va,
- const float* ctx,
- const float* state,
+template <typename T, typename AccType = float>
+__global__ void gAtt(T* out,
+ const T* va,
+ const T* ctx,
+ const T* state,
int m, // total rows (batch x time x beam)
int k, // depth
int b, // batch size
@@ -1319,19 +1721,19 @@ __global__ void gAtt(float* out,
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- const float* vaRow = va;
- const float* ctxRow = ctx + (j % (b * t)) * cols;
- const float* stateRow = state + ((j / (b * t)) * b + j % b) * cols;
+ const T* vaRow = va;
+ const T* ctxRow = ctx + (j % (b * t)) * cols;
+ const T* stateRow = state + ((j / (b * t)) * b + j % b) * cols;
- extern __shared__ float _share[];
- float* _sum = _share + blockDim.x;
+ extern __shared__ AccType _share[];
+ AccType* _sum = _share;
- _sum[threadIdx.x] = 0.0;
+ _sum[threadIdx.x] = 0.f;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
- float z = ctxRow[id] + stateRow[id];
- float ex = tanhf(z) * vaRow[id];
+ AccType z = (AccType)ctxRow[id] + (AccType)stateRow[id];
+ AccType ex = functional::Ops<AccType>::tanh(z) * (AccType)vaRow[id];
_sum[threadIdx.x] += ex;
}
}
@@ -1345,35 +1747,45 @@ __global__ void gAtt(float* out,
len = (len + 1) >> 1;
}
__syncthreads();
- out[j] = _sum[0];
- __syncthreads();
+ out[j] = (T)_sum[0];
}
+ __syncthreads();
}
}
void Att(Tensor out, Tensor va, Tensor context, Tensor state) {
cudaSetDevice(out->getDeviceId().no);
- size_t m = out->shape().elements() / out->shape().back();
- size_t k = context->shape()[-1];
- size_t b = context->shape()[-2];
- size_t t = context->shape()[-3];
-
- int blocks = std::min(MAX_BLOCKS, (int)m);
- int threads = std::min(MAX_THREADS, (int)k);
- int shared = sizeof(float) * threads * 2;
-
- gAtt<<<blocks, threads, shared>>>(
- out->data(), va->data(), context->data(), state->data(), m, k, b, t);
+ size_t totalRows = out->shape().elements() / out->shape().back(); // number of rows
+ size_t modelDim = context->shape()[-1]; // number of cols
+ size_t batchDim = context->shape()[-2];
+ size_t contextWordsDim = context->shape()[-3];
+
+ int blocks = std::min(MAX_BLOCKS, (int)totalRows);
+ int threads = std::min(MAX_THREADS, (int)modelDim);
+ int shared = sizeof(float) * threads;
+
+ if(out->type() == Type::float32) {
+ gAtt<float, float><<<blocks, threads, shared>>>(
+ out->data<float>(), va->data<float>(), context->data<float>(), state->data<float>(), totalRows, modelDim, batchDim, contextWordsDim);
+#if COMPILE_FP16
+ } else if (out->type() == Type::float16) {
+ gAtt<half, float><<<blocks, threads, shared>>>(
+ out->data<half>(), va->data<half>(), context->data<half>(), state->data<half>(), totalRows, modelDim, batchDim, contextWordsDim);
+#endif
+ } else {
+ ABORT("gAtt not implemented for type {}", out->type());
+ }
}
-__global__ void gAttBack(float* gVa,
- float* gContext,
- float* gState,
- const float* va,
- const float* context,
- const float* state,
- const float* adj,
+template <typename T>
+__global__ void gAttBack(T* gVa,
+ T* gContext,
+ T* gState,
+ const T* va,
+ const T* context,
+ const T* state,
+ const T* adj,
int m, // rows
int k, // cols
int n // batch size
@@ -1383,23 +1795,23 @@ __global__ void gAttBack(float* gVa,
for(int bid = 0; bid < m; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- float* gcRow = gContext + j * cols;
- float* gsRow = gState + (j % n) * cols;
+ T* gcRow = gContext + j * cols;
+ T* gsRow = gState + (j % n) * cols;
- const float* cRow = context + j * cols;
- const float* sRow = state + (j % n) * cols;
+ const T* cRow = context + j * cols;
+ const T* sRow = state + (j % n) * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
- float z = cRow[id] + sRow[id];
+ T z = cRow[id] + sRow[id];
- float t = tanhf(z);
- float r = va[id] * (1.f - t * t);
+ T t = functional::Ops<T>::tanh(z);
+ T r = va[id] * ((T)1.f - t * t);
- gcRow[id] += r * adj[j];
+ gcRow[id] += r * adj[j]; // atomicAdd? reasons for instabilities?
gsRow[id] += r * adj[j];
- atomicAdd(gVa + id, t * adj[j]);
+ atomics::atomicAdd(gVa + id, t * adj[j]); // @TODO: get rid of atomicAdd via Matmul
}
}
}
@@ -1422,41 +1834,60 @@ void AttBack(Tensor gVa,
int blocks = std::min(MAX_BLOCKS, (int)n);
int threads = std::min(MAX_THREADS, (int)k);
- gAttBack<<<blocks, threads>>>(gVa->data(),
- gContext->data(),
- gState->data(),
-
- va->data(),
- context->data(),
- state->data(),
-
- adj->data(),
- m,
- k,
- n);
+ if(gVa->type() == Type::float32) {
+ gAttBack<<<blocks, threads>>>(gVa->data<float>(),
+ gContext->data<float>(),
+ gState->data<float>(),
+ va->data<float>(),
+ context->data<float>(),
+ state->data<float>(),
+ adj->data<float>(),
+ m,
+ k,
+ n);
+#if COMPILE_FP16
+ } else if (gVa->type() == Type::float16) {
+ gAttBack<<<blocks, threads>>>(gVa->data<half>(),
+ gContext->data<half>(),
+ gState->data<half>(),
+ va->data<half>(),
+ context->data<half>(),
+ state->data<half>(),
+ adj->data<half>(),
+ m,
+ k,
+ n);
+#endif
+ } else {
+ ABORT("gAttBack not implemented for type {}", gVa->type());
+ }
}
-__global__ void gLNormalization(float* out,
- const float* in,
- const float* alpha,
- const float* beta,
+template <typename T, typename AccType = float>
+__global__ void gLNormalization(T* out,
+ const T* in,
+ const T* gamma,
+ const T* beta,
int rows,
int cols,
- float eps = 1e-9) {
- extern __shared__ float _share[];
+ AccType eps = 1e-9) {
+ extern __shared__ uint8_t _sharedBytes[];
+ AccType* _shareAccType = (AccType*)_sharedBytes;
+
+ AccType N = cols;
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- float* so = out + j * cols;
- const float* sp = in + j * cols;
+ T* yRow = out + j * cols;
+ const T* xRow = in + j * cols;
- float* _sum = _share + blockDim.x;
- _sum[threadIdx.x] = 0.0f;
+ AccType* _sum = _shareAccType; // accumulate into floats
+ _sum[threadIdx.x] = (AccType)0.0f;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
- _sum[threadIdx.x] += sp[id];
+ _sum[threadIdx.x] += (AccType)xRow[id];
}
}
__syncthreads();
@@ -1470,16 +1901,17 @@ __global__ void gLNormalization(float* out,
len = (len + 1) >> 1;
}
__syncthreads();
- float mean = _sum[0] / cols;
+ AccType mean = _sum[0] / N;
__syncthreads();
- float* _sqSum = _share + blockDim.x;
+ AccType* _sqSum = _shareAccType;
- _sqSum[threadIdx.x] = 0.0;
+ _sqSum[threadIdx.x] = (AccType)0.0f;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
- float ex = sp[id] - mean;
+ AccType xv = (AccType)xRow[id];
+ AccType ex = xv - mean;
_sqSum[threadIdx.x] += ex * ex;
}
}
@@ -1493,19 +1925,22 @@ __global__ void gLNormalization(float* out,
len = (len + 1) >> 1;
}
__syncthreads();
- float sigma = sqrtf(eps + (_sqSum[0] / cols));
+ AccType sigma = functional::Ops<AccType>::sqrt(_sqSum[0] / N + eps); // all AccType
__syncthreads();
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
- float t = alpha[id] * ((sp[id] - mean) / sigma);
- if(beta != nullptr)
- t += beta[id];
- so[id] = t;
+ AccType gammav = (AccType)gamma[id];
+ AccType xv = (AccType)xRow[id];
+ AccType betav = beta ? (AccType)beta[id] : (AccType)0.f;
+ AccType lv = (xv - mean) / sigma;
+ AccType y = gammav * lv + betav;
+ yRow[id] = (T)y;
}
}
}
+ __syncthreads();
}
}
@@ -1521,55 +1956,77 @@ void LayerNormalization(Tensor out,
int blocks = std::min(MAX_BLOCKS, (int)rows);
int threads = std::min(MAX_THREADS, (int)cols);
- int shared = 2 * threads * sizeof(float);
-
- gLNormalization<<<blocks, threads, shared>>>(out->data(),
- in->data(),
- gamma->data(),
- beta ? beta->data() : nullptr,
- rows,
- cols,
- eps);
-}
-
-__global__ void gLayerNormalizationGrad(float* gradX,
- float* gradGamma,
- float* gradBeta,
- float* adj,
- float* y,
- float* x,
- float* gamma,
- float* beta,
+ int shared = threads * sizeof(float);
+
+ if(out->type() == Type::float32) {
+ gLNormalization<float, float><<<blocks, threads, shared>>>(out->data<float>(),
+ in->data<float>(),
+ gamma->data<float>(),
+ beta ? beta->data<float>() : nullptr,
+ rows,
+ cols,
+ eps);
+#if COMPILE_FP16
+ } else if (out->type() == Type::float16) {
+ gLNormalization<half, float><<<blocks, threads, shared>>>(out->data<half>(),
+ in->data<half>(),
+ gamma->data<half>(),
+ beta ? beta->data<half>() : nullptr,
+ rows,
+ cols,
+ eps);
+#endif
+ } else {
+ ABORT("LayerNormalization not implemented for type {}", out->type());
+ }
+}
+
+template <typename T, typename AccType = float>
+__global__ void gLayerNormalizationGrad(T* gradX,
+ T* gradGamma,
+ T* adj,
+ T* y,
+ T* x,
+ T* gamma,
+ T* beta,
int rows,
int cols,
- float eps = 1e-9) {
- extern __shared__ float shared[];
+ AccType eps = 1e-9) {
+ extern __shared__ uint8_t sharedBytes[];
+ AccType* shared = (AccType*)sharedBytes;
+
+ AccType N = cols;
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- float* sum_adj = shared;
- float* sum_adj_x = shared + blockDim.x;
- float* sum_x = shared + 2 * blockDim.x;
- float* sum_sqr = shared + 3 * blockDim.x;
+ AccType* sum_adj = shared; // sum of gradient coming in
+ AccType* sum_adj_l = shared + blockDim.x; // sum of gradient coming in times layerNorm from value
+ AccType* sum_x = shared + 2 * blockDim.x; // sum of input value x
+ AccType* sum_sqr = shared + 3 * blockDim.x; // sum of (x - mean)^2
- const float* xRow = x + j * cols;
- const float* yRow = y + j * cols;
- const float* adjRow = adj + j * cols;
- float* gradXRow = gradX + j * cols;
+ const T* xRow = x + j * cols;
+ const T* yRow = y + j * cols;
+ const T* adjRow = adj + j * cols;
- sum_x[threadIdx.x] = 0.0f;
- sum_adj[threadIdx.x] = 0.0f;
- sum_adj_x[threadIdx.x] = 0.0f;
- sum_sqr[threadIdx.x] = 0.0f;
+ sum_x[threadIdx.x] = (AccType)0.0f;
+ sum_adj[threadIdx.x] = (AccType)0.0f;
+ sum_adj_l[threadIdx.x] = (AccType)0.0f;
+ sum_sqr[threadIdx.x] = (AccType)0.0f;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
- sum_x[threadIdx.x] += xRow[id];
- sum_adj_x[threadIdx.x]
- += adjRow[id] * (yRow[id] - ((beta) ? beta[id] : 0)) / gamma[id];
- sum_adj[threadIdx.x] += adjRow[id];
+ AccType xv = xRow[id];
+ AccType yv = yRow[id];
+ AccType betav = beta ? (AccType)beta[id] : (AccType)0.f;
+ AccType gammav = (AccType)gamma[id];
+ AccType adjv = adjRow[id];
+ AccType lv = (yv - betav) / gammav; // go back to LN(x) from scaled and shifted version for accumulation
+
+ sum_x[threadIdx.x] += xv;
+ sum_adj_l[threadIdx.x] += adjv * lv;
+ sum_adj[threadIdx.x] += adjv;
}
}
__syncthreads();
@@ -1578,20 +2035,21 @@ __global__ void gLayerNormalizationGrad(float* gradX,
__syncthreads();
int skip = (len + 1) >> 1;
if(threadIdx.x < (len >> 1)) {
- sum_x[threadIdx.x] += sum_x[threadIdx.x + skip];
- sum_adj[threadIdx.x] += sum_adj[threadIdx.x + skip];
- sum_adj_x[threadIdx.x] += sum_adj_x[threadIdx.x + skip];
+ sum_x[threadIdx.x] += sum_x[threadIdx.x + skip]; // Accumulates in AccType
+ sum_adj[threadIdx.x] += sum_adj[threadIdx.x + skip]; // Accumulates in AccType
+ sum_adj_l[threadIdx.x] += sum_adj_l[threadIdx.x + skip]; // Accumulates in AccType
}
len = (len + 1) >> 1;
}
__syncthreads();
- float mean = sum_x[0] / cols;
+ AccType mean = sum_x[0] / N;
__syncthreads();
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
- float ex = xRow[id] - mean;
+ AccType xv = xRow[id];
+ AccType ex = xv - mean;
sum_sqr[threadIdx.x] += ex * ex;
}
}
@@ -1602,39 +2060,53 @@ __global__ void gLayerNormalizationGrad(float* gradX,
__syncthreads();
int skip = (len + 1) >> 1;
if(threadIdx.x < (len >> 1))
- sum_sqr[threadIdx.x] += sum_sqr[threadIdx.x + skip];
+ sum_sqr[threadIdx.x] += sum_sqr[threadIdx.x + skip]; // Accumulates in AccType
len = (len + 1) >> 1;
}
__syncthreads();
- float sigma = sqrtf(eps + (sum_sqr[0] / cols));
+ AccType sigma = functional::Ops<AccType>::sqrt(sum_sqr[0] / N + eps);
__syncthreads();
+ // Jacobian of layer norm
+ // J = [ \frac{1}{N\sigma} (N\delta_{ij} - l_i l_j - 1) ]_{ij}
+ // J * a = dC/dx_i = ( N a_i - l_i \sum_j l_j a_j - \sum_j a_j ) / (N \sigma)
+
for(int tid = 0; tid < cols; tid += blockDim.x) {
int id = tid + threadIdx.x;
if(id < cols) {
- float grad_x = 0.0f;
- float x_hat = (yRow[id] - ((beta) ? beta[id] : 0)) / gamma[id];
- grad_x += cols * adjRow[id];
- grad_x -= sum_adj[0];
- grad_x -= sum_adj_x[0] * x_hat;
- grad_x /= (cols * sigma);
-
- float valX = gamma[id] * grad_x;
- float sign = (0.f < valX) - (valX < 0.f);
- valX = fabs(valX) > 1000 ? sign * 1000 : valX;
-
- gradXRow[id] += valX;
- atomicAdd(gradGamma + id, adjRow[id] * x_hat);
- if(beta) {
- atomicAdd(gradBeta + id, adjRow[id]);
- }
+
+ AccType xv = xRow[id];
+ AccType gammav = (AccType)gamma[id];
+ AccType adjv = adjRow[id];
+ AccType lv = (xv - mean) / sigma;
+
+ AccType gradLv = N * adjv - lv * sum_adj_l[0] - sum_adj[0];
+ gradLv /= N * sigma;
+
+ AccType gradXv = gammav * gradLv;
+
+ // Keep LN gradient between [-1000, 1000] for TensorOps, this currently used for making values fit into fp16. @TODO: to be fixed and removed.
+ AccType sign = functional::Ops<AccType>::sgn(gradXv);
+ AccType cutoff = (AccType)1000.f; // @TODO: expose this somehow as an option?
+ // or better: make obsolete.
+ gradXv = functional::Ops<AccType>::abs(gradXv) > cutoff ? sign * cutoff : gradXv;
+
+ T* gradXRow = gradX + j * cols;
+ gradXRow[id] += (T)(gradXv);
+
+ T* gradGammaRow = gradGamma + j * cols;
+ // assignment is correct here as this gets summed up
+ // in the next kernel via matrix product
+ gradGammaRow[id] = (T)(adjv * lv);
}
}
}
+ __syncthreads();
}
}
-void LayerNormalizationGrad(Tensor gradX,
+void LayerNormalizationGrad(Ptr<Allocator> allocator,
+ Tensor gradX,
Tensor gradGamma,
Tensor gradBeta,
Tensor adj,
@@ -1649,25 +2121,62 @@ void LayerNormalizationGrad(Tensor gradX,
int threads = std::min(MAX_THREADS, cols);
int blocks = std::min(MAX_BLOCKS, rows);
- int shared = sizeof(float) * threads * 4;
-
- gLayerNormalizationGrad<<<blocks, threads, shared>>>(
- gradX->data(),
- gradGamma->data(),
- (gradBeta) ? gradBeta->data() : nullptr,
- adj->data(),
- y->data(),
- x->data(),
- gamma->data(),
- (beta) ? beta->data() : nullptr,
+
+ auto tempGradGammaMemory = allocator->alloc(adj->memory()->size());
+ Tensor tempGradGamma = TensorBase::New(tempGradGammaMemory, adj->shape(), adj->type(), adj->getBackend());
+ tempGradGamma->set(0.f);
+
+ auto tempOnesMemory = allocator->alloc(rows * sizeOf(adj->type()));
+ Tensor tempOnes = TensorBase::New(tempOnesMemory, Shape({1, rows}), adj->type(), adj->getBackend());
+ tempOnes->set(1.f);
+
+ if(gradX->type() == Type::float32) {
+ int shared = sizeof(float) * threads * 4;
+ gLayerNormalizationGrad<float, float><<<blocks, threads, shared>>>(
+ gradX->data<float>(),
+ tempGradGamma->data<float>(),
+ adj->data<float>(),
+ y->data<float>(),
+ x->data<float>(),
+ gamma->data<float>(),
+ (beta) ? beta->data<float>() : nullptr,
rows,
cols,
eps);
+#if COMPILE_FP16
+ } else if (gradX->type() == Type::float16) {
+ // accumulate in float
+ int shared = sizeof(float) * threads * 4;
+ gLayerNormalizationGrad<half, float><<<blocks, threads, shared>>>(
+ gradX->data<half>(),
+ tempGradGamma->data<half>(),
+ adj->data<half>(),
+ y->data<half>(),
+ x->data<half>(),
+ gamma->data<half>(),
+ (beta) ? beta->data<half>() : nullptr,
+ rows,
+ cols,
+ eps);
+#endif
+ } else {
+ ABORT("LayerNormalizationGrad not implemented for type {}", gradX->type());
+ }
+
+ // We use this go get rid of the atomicAdd and perform a reduce of the gradients afterwards.
+ // This is much faster for fp16 which seems to have a broken atomicAdd implementation
+ gpu::Prod(gradGamma, tempOnes, tempGradGamma, false, false, 1, 1); // beta set to one to add
+
+ if(gradBeta) // dC/dbeta = adj - inverse broadcasting (reduction)
+ gpu::Prod(gradBeta, tempOnes, adj, false, false, 1, 1); // beta set to one to add
+
+ allocator->free(tempGradGammaMemory);
+ allocator->free(tempOnesMemory);
}
-template <bool add>
-__global__ void gShift(float* out,
- const float* in,
+template <bool add, typename T>
+__global__ void gShift(T* out,
+ const T* in,
int length,
int offset,
float padValue) {
@@ -1679,7 +2188,7 @@ __global__ void gShift(float* out,
out[index] += in[index - offset];
} else {
if(index - offset < 0 || index - offset >= length)
- out[index] = padValue;
+ out[index] = (T)padValue;
else
out[index] = in[index - offset];
}
@@ -1710,8 +2219,17 @@ void Shift(Tensor out,
int threads = std::min(MAX_THREADS, length);
int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
- gShift<false>
- <<<blocks, threads>>>(out->data(), in->data(), length, offset, padValue);
+ if(out->type() == Type::float32) {
+ gShift<false>
+ <<<blocks, threads>>>(out->data<float>(), in->data<float>(), length, offset, padValue);
+#if COMPILE_FP16
+ } else if(out->type() == Type::float16) {
+ gShift<false>
+ <<<blocks, threads>>>(out->data<half>(), in->data<half>(), length, offset, padValue);
+#endif
+ } else {
+ ABORT("Shift not implemented for type {}", out->type());
+ }
}
void ShiftGrad(Tensor out, Tensor in, marian::Shape shift, bool invert) {
@@ -1733,8 +2251,17 @@ void ShiftGrad(Tensor out, Tensor in, marian::Shape shift, bool invert) {
int threads = std::min(MAX_THREADS, length);
int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
- gShift<true>
- <<<blocks, threads>>>(out->data(), in->data(), length, offset, 0.f);
+ if(out->type() == Type::float32) {
+ gShift<true>
+ <<<blocks, threads>>>(out->data<float>(), in->data<float>(), length, offset, 0.f); // @TODO: What about padValue?
+#if COMPILE_FP16
+ } else if(out->type() == Type::float16) {
+ gShift<true>
+ <<<blocks, threads>>>(out->data<half>(), in->data<half>(), length, offset, 0.f);
+#endif
+ } else {
+ ABORT("Shift not implemented for type {}", out->type());
+ }
}
__global__ void gSetSparse(float* out,
@@ -1777,38 +2304,39 @@ void SetSparse(float* out,
/******************************************************************************/
-__global__ void gLSTMCellForward(float* out,
- const float* cell,
- const float* xW,
- const float* sU,
- const float* b,
- const float* mask,
+template <typename T>
+__global__ void gLSTMCellForward(T* out,
+ const T* cell,
+ const T* xW,
+ const T* sU,
+ const T* b,
+ const T* mask,
size_t rows,
size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- float m = !mask || mask[j];
+ T m = !mask || mask[j];
- float* rowOut = out + j * cols;
- const float* rowCell = cell + j * cols;
+ T* rowOut = out + j * cols;
+ const T* rowCell = cell + j * cols;
- const float* xWrow = xW + j * cols * 4;
- const float* sUrow = sU + j * cols * 4;
+ const T* xWrow = xW + j * cols * 4;
+ const T* sUrow = sU + j * cols * 4;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols) {
- float gf = stableSigmoid(xWrow[i] + sUrow[i] + b[i]);
+ T gf = functional::Ops<T>::sigmoid(xWrow[i] + sUrow[i] + b[i]);
int k = i + cols;
- float gi = stableSigmoid(xWrow[k] + sUrow[k] + b[k]);
+ T gi = functional::Ops<T>::sigmoid(xWrow[k] + sUrow[k] + b[k]);
int l = i + 2 * cols;
- float gc = tanhf(xWrow[l] + sUrow[l] + b[l]);
+ T gc = functional::Ops<T>::tanh(xWrow[l] + sUrow[l] + b[l]);
- float cout = gf * rowCell[i] + gi * gc;
- rowOut[i] = m * cout + (1 - m) * rowCell[i];
+ T cout = gf * rowCell[i] + gi * gc;
+ rowOut[i] = m * cout + ((T)1.f - m) * rowCell[i];
}
}
}
@@ -1824,40 +2352,56 @@ void LSTMCellForward(Tensor out, std::vector<Tensor> inputs) {
int blocks = std::min(MAX_BLOCKS, rows);
int threads = std::min(MAX_THREADS, cols);
- gLSTMCellForward<<<blocks, threads>>>(
- out->data(), // output
- inputs[0]->data(), // cell state
- inputs[1]->data(), // xW
- inputs[2]->data(), // sU
- inputs[3]->data(), // b
- inputs.size() > 4 ? inputs[4]->data() : 0, // mask
+ if(out->type() == Type::float32) {
+ gLSTMCellForward<<<blocks, threads>>>(
+ out->data<float>(), // output
+ inputs[0]->data<float>(), // cell state
+ inputs[1]->data<float>(), // xW
+ inputs[2]->data<float>(), // sU
+ inputs[3]->data<float>(), // b
+ inputs.size() > 4 ? inputs[4]->data<float>() : 0, // mask
rows,
cols);
+#if COMPILE_FP16
+ } else if (out->type() == Type::float16) {
+ gLSTMCellForward<<<blocks, threads>>>(
+ out->data<half>(), // output
+ inputs[0]->data<half>(), // cell state
+ inputs[1]->data<half>(), // xW
+ inputs[2]->data<half>(), // sU
+ inputs[3]->data<half>(), // b
+ inputs.size() > 4 ? inputs[4]->data<half>() : 0, // mask
+ rows,
+ cols);
+#endif
+ } else {
+ ABORT("LSTMCellForward not implemented for type {}", out->type());
+ }
}
-__global__ void gLSTMOutputForward(float* out,
- const float* cell,
- const float* xW,
- const float* sU,
- const float* b,
+template <typename T>
+__global__ void gLSTMOutputForward(T* out,
+ const T* cell,
+ const T* xW,
+ const T* sU,
+ const T* b,
size_t rows,
size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- float* rowOut = out + j * cols;
- const float* rowCell = cell + j * cols;
+ T* rowOut = out + j * cols;
+ const T* rowCell = cell + j * cols;
- const float* xWrow = xW + j * cols * 4;
- const float* sUrow = sU + j * cols * 4;
+ const T* xWrow = xW + j * cols * 4;
+ const T* sUrow = sU + j * cols * 4;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols) {
int k = i + 3 * cols;
- float go = stableSigmoid(xWrow[k] + sUrow[k] + b[k]);
-
- rowOut[i] = go * tanhf(rowCell[i]);
+ T go = functional::Ops<T>::sigmoid(xWrow[k] + sUrow[k] + b[k]);
+ rowOut[i] = go * functional::Ops<T>::tanh(rowCell[i]);
}
}
}
@@ -1873,85 +2417,100 @@ void LSTMOutputForward(Tensor out, std::vector<Tensor> inputs) {
int blocks = std::min(MAX_BLOCKS, rows);
int threads = std::min(MAX_THREADS, cols);
- gLSTMOutputForward<<<blocks, threads>>>(out->data(), // output
- inputs[0]->data(), // cell state
- inputs[1]->data(), // xW
- inputs[2]->data(), // sU
- inputs[3]->data(), // b
- rows,
- cols);
-}
-
-__global__ void gLSTMCellBackward(float* outCell,
- float* outXW,
- float* outSU,
- float* outB,
- const float* cell,
- const float* xW,
- const float* sU,
- const float* b,
- const float* mask,
- const float* adj,
+ if(out->type() == Type::float32) {
+ gLSTMOutputForward<<<blocks, threads>>>(out->data<float>(), // output
+ inputs[0]->data<float>(), // cell state
+ inputs[1]->data<float>(), // xW
+ inputs[2]->data<float>(), // sU
+ inputs[3]->data<float>(), // b
+ rows,
+ cols);
+#if COMPILE_FP16
+ } else if (out->type() == Type::float16) {
+ gLSTMOutputForward<<<blocks, threads>>>(out->data<half>(), // output
+ inputs[0]->data<half>(), // cell state
+ inputs[1]->data<half>(), // xW
+ inputs[2]->data<half>(), // sU
+ inputs[3]->data<half>(), // b
+ rows,
+ cols);
+#endif
+ } else {
+ ABORT("gLSTMOutputForward not implemented for type {}", out->type());
+ }
+}
+
+template <typename T>
+__global__ void gLSTMCellBackward(T* outCell,
+ T* outXW,
+ T* outSU,
+ T* outB,
+ const T* cell,
+ const T* xW,
+ const T* sU,
+ const T* b,
+ const T* mask,
+ const T* adj,
size_t rows,
size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- float m = !mask || mask[j];
+ T m = !mask || mask[j];
- float* rowOutCell = outCell + j * cols;
- float* rowOutXW = outXW + j * cols * 4;
- float* rowOutSU = outSU + j * cols * 4;
+ T* rowOutCell = outCell + j * cols;
+ T* rowOutXW = outXW + j * cols * 4;
+ T* rowOutSU = outSU + j * cols * 4;
- const float* rowCell = cell + j * cols;
- const float* xWrow = xW + j * cols * 4;
- const float* sUrow = sU + j * cols * 4;
+ const T* rowCell = cell + j * cols;
+ const T* xWrow = xW + j * cols * 4;
+ const T* sUrow = sU + j * cols * 4;
- const float* rowAdj = adj + j * cols;
+ const T* rowAdj = adj + j * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols) {
- float gf = stableSigmoid(xWrow[i] + sUrow[i] + b[i]);
+ T gf = functional::Ops<T>::sigmoid(xWrow[i] + sUrow[i] + b[i]);
int k = i + cols;
- float gi = stableSigmoid(xWrow[k] + sUrow[k] + b[k]);
+ T gi = functional::Ops<T>::sigmoid(xWrow[k] + sUrow[k] + b[k]);
int l = i + 2 * cols;
- float gc = tanhf(xWrow[l] + sUrow[l] + b[l]);
+ T gc = functional::Ops<T>::tanh(xWrow[l] + sUrow[l] + b[l]);
- float adj = rowAdj[i];
+ T adj = rowAdj[i];
// dc/dc_{t-1}
if(outCell)
- rowOutCell[i] += (m * gf - m + 1) * adj;
+ rowOutCell[i] += (m * gf - m + (T)1.f) * adj;
// dc/d(b_f) = dc/d(xW_f) ...
- float dcdxf = m * rowCell[i] * gf * (1 - gf) * adj;
+ T dcdxf = m * rowCell[i] * gf * ((T)1.f - gf) * adj;
if(outXW)
rowOutXW[i] += dcdxf;
if(outSU)
rowOutSU[i] += dcdxf;
if(outB)
- atomicAdd(outB + i, dcdxf);
+ atomics::atomicAdd(outB + i, dcdxf); // @TODO: get rid of atomicAdd everywhere!
// dc/d(b_i) ...
- float dcdb_i = m * gc * gi * (1 - gi) * adj;
+ T dcdb_i = m * gc * gi * ((T)1.f - gi) * adj;
if(outXW)
rowOutXW[k] += dcdb_i;
if(outSU)
rowOutSU[k] += dcdb_i;
if(outB)
- atomicAdd(outB + k, dcdb_i);
+ atomics::atomicAdd(outB + k, dcdb_i);
// dc/d(b_c) ...
- float dcdxc = m * gi * (1 - gc * gc) * adj;
+ T dcdxc = m * gi * ((T)1.f - gc * gc) * adj;
if(outXW)
rowOutXW[l] += dcdxc;
if(outSU)
rowOutSU[l] += dcdxc;
if(outB)
- atomicAdd(outB + l, dcdxc);
+ atomics::atomicAdd(outB + l, dcdxc);
}
}
}
@@ -1969,67 +2528,89 @@ void LSTMCellBackward(std::vector<Tensor> outputs,
int blocks = std::min(MAX_BLOCKS, rows);
int threads = std::min(MAX_THREADS, cols);
- gLSTMCellBackward<<<blocks, threads>>>(
- outputs[0] ? outputs[0]->data() : 0, // state - adj
- outputs[1] ? outputs[1]->data() : 0, // xW - adj
- outputs[2] ? outputs[2]->data() : 0, // sU - adj
- outputs[3] ? outputs[3]->data() : 0, // b - adj
- inputs[0]->data(), // state
- inputs[1]->data(), // xW
- inputs[2]->data(), // sU
- inputs[3]->data(), // b
- inputs.size() > 4 ? inputs[4]->data() : 0, // mask
- adj->data(),
+ if(adj->type() == Type::float32) {
+ gLSTMCellBackward<<<blocks, threads>>>(
+ outputs[0] ? outputs[0]->data<float>() : 0, // state - adj
+ outputs[1] ? outputs[1]->data<float>() : 0, // xW - adj
+ outputs[2] ? outputs[2]->data<float>() : 0, // sU - adj
+ outputs[3] ? outputs[3]->data<float>() : 0, // b - adj
+ inputs[0]->data<float>(), // state
+ inputs[1]->data<float>(), // xW
+ inputs[2]->data<float>(), // sU
+ inputs[3]->data<float>(), // b
+ inputs.size() > 4 ? inputs[4]->data<float>() : 0, // mask
+ adj->data<float>(),
rows,
cols);
+#if COMPILE_FP16
+ } else if (adj->type() == Type::float16) {
+ gLSTMCellBackward<<<blocks, threads>>>(
+ outputs[0] ? outputs[0]->data<half>() : 0, // state - adj
+ outputs[1] ? outputs[1]->data<half>() : 0, // xW - adj
+ outputs[2] ? outputs[2]->data<half>() : 0, // sU - adj
+ outputs[3] ? outputs[3]->data<half>() : 0, // b - adj
+ inputs[0]->data<half>(), // state
+ inputs[1]->data<half>(), // xW
+ inputs[2]->data<half>(), // sU
+ inputs[3]->data<half>(), // b
+ inputs.size() > 4 ? inputs[4]->data<half>() : 0, // mask
+ adj->data<half>(),
+ rows,
+ cols);
+#endif
+ } else {
+ ABORT("gLSTMCellBackward not implemented for type {}", adj->type());
+ }
+
}
-__global__ void gLSTMOutputBackward(float* outCell,
- float* outXW,
- float* outSU,
- float* outB,
- const float* cell,
- const float* xW,
- const float* sU,
- const float* b,
- const float* adj,
+template <typename T>
+__global__ void gLSTMOutputBackward(T* outCell,
+ T* outXW,
+ T* outSU,
+ T* outB,
+ const T* cell,
+ const T* xW,
+ const T* sU,
+ const T* b,
+ const T* adj,
size_t rows,
size_t cols) {
for(int bid = 0; bid < rows; bid += gridDim.x) {
int j = bid + blockIdx.x;
if(j < rows) {
- float* rowOutCell = outCell + j * cols;
- float* rowOutXW = outXW + j * cols * 4;
- float* rowOutSU = outSU + j * cols * 4;
+ T* rowOutCell = outCell + j * cols;
+ T* rowOutXW = outXW + j * cols * 4;
+ T* rowOutSU = outSU + j * cols * 4;
- const float* rowCell = cell + j * cols;
- const float* xWrow = xW + j * cols * 4;
- const float* sUrow = sU + j * cols * 4;
+ const T* rowCell = cell + j * cols;
+ const T* xWrow = xW + j * cols * 4;
+ const T* sUrow = sU + j * cols * 4;
- const float* rowAdj = adj + j * cols;
+ const T* rowAdj = adj + j * cols;
for(int tid = 0; tid < cols; tid += blockDim.x) {
int i = tid + threadIdx.x;
if(i < cols) {
int k = i + 3 * cols;
- float go = stableSigmoid(xWrow[k] + sUrow[k] + b[k]);
+ T go = functional::Ops<T>::sigmoid(xWrow[k] + sUrow[k] + b[k]);
- float t = tanhf(rowCell[i]);
+ T t = functional::Ops<T>::tanh(rowCell[i]);
- float adj = rowAdj[i];
+ T adj = rowAdj[i];
// dc/dc_{t-1}
if(outCell)
- rowOutCell[i] += go * (1 - t * t) * adj;
+ rowOutCell[i] += go * ((T)1.f - t * t) * adj;
// dc/d(b_o) = dc/d(xW_f) ...
- float dcdxo = t * go * (1 - go) * adj;
+ float dcdxo = t * go * ((T)1.f - go) * adj;
if(outXW)
rowOutXW[k] += dcdxo;
if(outSU)
rowOutSU[k] += dcdxo;
if(outB)
- atomicAdd(outB + k, dcdxo);
+ atomics::atomicAdd(outB + k, dcdxo); // @TODO: get rid of atomicAdd
}
}
}
@@ -2047,30 +2628,50 @@ void LSTMOutputBackward(std::vector<Tensor> outputs,
int blocks = std::min(MAX_BLOCKS, rows);
int threads = std::min(MAX_THREADS, cols);
- gLSTMOutputBackward<<<blocks, threads>>>(
- outputs[0] ? outputs[0]->data() : 0, // state - adj
- outputs[1] ? outputs[1]->data() : 0, // xW - adj
- outputs[2] ? outputs[2]->data() : 0, // sU - adj
- outputs[3] ? outputs[3]->data() : 0, // b - adj
- inputs[0]->data(), // state
- inputs[1]->data(), // xW
- inputs[2]->data(), // sU
- inputs[3]->data(), // b
- adj->data(),
- rows,
- cols);
+ if(adj->type() == Type::float32) {
+ gLSTMOutputBackward<<<blocks, threads>>>(
+ outputs[0] ? outputs[0]->data<float>() : 0, // state - adj
+ outputs[1] ? outputs[1]->data<float>() : 0, // xW - adj
+ outputs[2] ? outputs[2]->data<float>() : 0, // sU - adj
+ outputs[3] ? outputs[3]->data<float>() : 0, // b - adj
+ inputs[0]->data<float>(), // state
+ inputs[1]->data<float>(), // xW
+ inputs[2]->data<float>(), // sU
+ inputs[3]->data<float>(), // b
+ adj->data<float>(),
+ rows,
+ cols);
+#if COMPILE_FP16
+ } else if (adj->type() == Type::float16) {
+ gLSTMOutputBackward<<<blocks, threads>>>(
+ outputs[0] ? outputs[0]->data<half>() : 0, // state - adj
+ outputs[1] ? outputs[1]->data<half>() : 0, // xW - adj
+ outputs[2] ? outputs[2]->data<half>() : 0, // sU - adj
+ outputs[3] ? outputs[3]->data<half>() : 0, // b - adj
+ inputs[0]->data<half>(), // state
+ inputs[1]->data<half>(), // xW
+ inputs[2]->data<half>(), // sU
+ inputs[3]->data<half>(), // b
+ adj->data<half>(),
+ rows,
+ cols);
+#endif
+ } else {
+ ABORT("gLSTMOutputBackward not implemented for type {}", adj->type());
+ }
}
-__global__ void gHighwayForward(float* out,
- const float* in1,
- const float* in2,
- const float* t,
+template <typename T>
+__global__ void gHighwayForward(T* out,
+ const T* in1,
+ const T* in2,
+ const T* t,
size_t length) {
for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
if(index < length) {
- float sigma = stableSigmoid(t[index]);
- out[index] = in1[index] * sigma + in2[index] * (1.f - sigma);
+ T sigma = functional::Ops<T>::sigmoid(t[index]);
+ out[index] = in1[index] * sigma + in2[index] * ((T)1.f - sigma);
}
}
}
@@ -2086,26 +2687,36 @@ void HighwayForward(Tensor out,
int threads = std::min(MAX_THREADS, length);
int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
- gHighwayForward<<<blocks, threads>>>(
- out->data(), in1->data(), in2->data(), t->data(), length);
+ if(out->type() == Type::float32) {
+ gHighwayForward<<<blocks, threads>>>(
+ out->data<float>(), in1->data<float>(), in2->data<float>(), t->data<float>(), length);
+#if COMPILE_FP16
+ } else if(out->type() == Type::float16) {
+ gHighwayForward<<<blocks, threads>>>(
+ out->data<half>(), in1->data<half>(), in2->data<half>(), t->data<half>(), length);
+#endif
+ } else {
+ ABORT("HighwayForward not implemented for type {}", out->type());
+ }
}
-__global__ void gHighwayBackward(float* out1,
- float* out2,
- float* outt,
- const float* in1,
- const float* in2,
- const float* t,
- const float* adj,
+template <typename T>
+__global__ void gHighwayBackward(T* out1,
+ T* out2,
+ T* outt,
+ const T* in1,
+ const T* in2,
+ const T* t,
+ const T* adj,
size_t length) {
for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
if(index < length) {
- float sigma = stableSigmoid(t[index]);
+ T sigma = functional::Ops<T>::sigmoid(t[index]);
out1[index] = sigma * adj[index];
- out2[index] = (1.f - sigma) * adj[index];
+ out2[index] = ((T)1.f - sigma) * adj[index];
outt[index]
- = sigma * (1.f - sigma) * (in1[index] - in2[index]) * adj[index];
+ = sigma * ((T)1.f - sigma) * (in1[index] - in2[index]) * adj[index];
}
}
}
@@ -2124,14 +2735,29 @@ void HighwayBackward(Tensor out1,
int threads = std::min(MAX_THREADS, length);
int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
- gHighwayBackward<<<blocks, threads>>>(out1->data(),
- out2->data(),
- outt->data(),
- in1->data(),
- in2->data(),
- t->data(),
- adj->data(),
- length);
+ if(out1->type() == Type::float32) {
+ gHighwayBackward<<<blocks, threads>>>(out1->data<float>(),
+ out2->data<float>(),
+ outt->data<float>(),
+ in1->data<float>(),
+ in2->data<float>(),
+ t->data<float>(),
+ adj->data<float>(),
+ length);
+#if COMPILE_FP16
+ } else if(out1->type() == Type::float16) {
+ gHighwayBackward<<<blocks, threads>>>(out1->data<half>(),
+ out2->data<half>(),
+ outt->data<half>(),
+ in1->data<half>(),
+ in2->data<half>(),
+ t->data<half>(),
+ adj->data<half>(),
+ length);
+#endif
+ } else {
+ ABORT("HighwayForward not implemented for type {}", out1->type());
+ }
}
__global__ void gMaxPoolingForward(float* out,
@@ -2175,6 +2801,7 @@ void PoolingWithMaskingForward(Tensor out,
Tensor mask,
int width,
bool isEven) {
+ matchOrAbort<float>(out->type());
int n = out->shape().elements();
int threads = std::min(n, MAX_THREADS);
int blocks = n / threads + (n % threads != 0);
@@ -2247,6 +2874,7 @@ void PoolingWithMaskingBackward(Tensor adj,
Tensor mask,
int width,
bool isEven) {
+ matchOrAbort<float>(adj->type());
int n = adj->shape().elements();
int threads = std::min(n, 512);
int blocks = n / threads + (n % threads != 0);
diff --git a/src/tensors/memory_piece.h b/src/tensors/memory_piece.h
index c3435263..cfe82949 100644
--- a/src/tensors/memory_piece.h
+++ b/src/tensors/memory_piece.h
@@ -1,5 +1,7 @@
#pragma once
+#include "common/definitions.h"
+
#include <iostream>
namespace marian {
@@ -9,9 +11,21 @@ private:
uint8_t* data_;
size_t size_;
-public:
+ ENABLE_INTRUSIVE_PTR(MemoryPiece)
+
+ // Contructor is private, use MemoryPiece::New(...)
MemoryPiece(uint8_t* data, size_t size) : data_(data), size_(size) {}
+public:
+ // Use this whenever pointing to MemoryPiece
+ typedef IPtr<MemoryPiece> PtrType;
+
+ // Use this whenever creating a pointer to MemoryPiece
+ template <class ...Args>
+ static PtrType New(Args&& ...args) {
+ return PtrType(new MemoryPiece(std::forward<Args>(args)...));
+ }
+
uint8_t* data() const { return data_; }
uint8_t* data() { return data_; }
@@ -39,5 +53,6 @@ public:
<< " size: " << mp.size();
return out;
}
+
};
} // namespace marian
diff --git a/src/tensors/rand.cpp b/src/tensors/rand.cpp
index 3c7a519b..e6dbc46e 100755
--- a/src/tensors/rand.cpp
+++ b/src/tensors/rand.cpp
@@ -4,7 +4,6 @@
#ifdef CUDA_FOUND
#include "gpu/cuda_helpers.h"
-
#include <curand.h>
#endif
@@ -81,9 +80,10 @@ CurandRandomGenerator::CurandRandomGenerator(size_t seed, DeviceId deviceId)
}
CurandRandomGenerator::~CurandRandomGenerator() {
- if(deviceId_.type == DeviceType::gpu)
- cudaSetDevice((int)deviceId_.no);
- CURAND_CHECK(curandDestroyGenerator(generator_));
+ // No CUDA error checking as this is a destructor and we cannot do anything about errors anyway.
+ if(deviceId_.type == DeviceType::gpu)
+ cudaSetDevice((int)deviceId_.no);
+ curandDestroyGenerator(generator_);
}
void CurandRandomGenerator::uniform(Tensor tensor, float a, float b) {
diff --git a/src/tensors/rand.h b/src/tensors/rand.h
index 568ac932..94b44a97 100644
--- a/src/tensors/rand.h
+++ b/src/tensors/rand.h
@@ -7,7 +7,7 @@
namespace marian {
class TensorBase;
-typedef Ptr<TensorBase> Tensor;
+typedef IPtr<TensorBase> Tensor;
class RandomGenerator {
protected:
@@ -15,11 +15,11 @@ protected:
public:
RandomGenerator(size_t seed) : seed_(seed) { }
-
+ virtual ~RandomGenerator() {}
virtual void uniform(Tensor, float a, float b) = 0;
virtual void normal(Tensor, float mean, float stddev) = 0;
};
Ptr<RandomGenerator> createRandomGenerator(size_t /*seed*/, DeviceId);
-} \ No newline at end of file
+}
diff --git a/src/tensors/tensor.cpp b/src/tensors/tensor.cpp
new file mode 100755
index 00000000..6133732d
--- /dev/null
+++ b/src/tensors/tensor.cpp
@@ -0,0 +1,142 @@
+#include "tensors/tensor.h"
+#include "tensors/tensor_operators.h"
+#include "common/io.h"
+
+namespace marian {
+
+template <typename T>
+std::string TensorBase::debug(int precision, int dispCols) {
+ // values
+ size_t totSize = shape_.elements();
+ std::vector<T> values(totSize);
+
+ get(values);
+
+ std::stringstream strm;
+ assert(shape_.size());
+ strm << shape_;
+ strm << " type=" << type_;
+ strm << " device=" << backend_->getDeviceId();
+ strm << " ptr=" << (size_t)memory_->data();
+ strm << " bytes=" << memory_->size();
+ strm << std::endl;
+
+ int colWidth = precision + 4;
+
+ if(isFloat(type_))
+ strm << std::fixed << std::setprecision(precision) << std::setfill(' ');
+ else
+ strm << std::fixed << std::setprecision(0) << std::setfill(' ');
+
+ double maxv = std::numeric_limits<double>::lowest();
+ double minv = std::numeric_limits<double>::max();
+ double l2Sum = 0.0;
+ for(int i = 0; i < values.size(); ++i) {
+ if((double)values[i] > maxv) maxv = (double)values[i];
+ if((double)values[i] < minv) minv = (double)values[i];
+ l2Sum += (double)values[i] * (double)values[i];
+ }
+ strm << "min: " << minv << " max: " << maxv << " l2-norm: " << sqrt(l2Sum) << std::endl;
+
+ for(int i = 0; i < values.size(); ++i) {
+ std::vector<int> dims;
+ shape().dims(i, dims);
+
+ bool disp = true;
+ for(int j = 0; j < dims.size(); ++j)
+ disp = disp && (dims[j] < dispCols || dims[j] >= shape()[j] - dispCols);
+
+ if(disp) {
+ if(dims.back() == 0) {
+ bool par = true;
+ std::vector<std::string> p;
+ for(int j = (int)dims.size() - 1; j >= 0; --j) {
+ if(dims[j] != 0)
+ par = false;
+
+ p.push_back(par ? "[" : " ");
+ }
+ for(auto it = p.rbegin(); it != p.rend(); ++it)
+ strm << *it;
+ strm << " ";
+ }
+
+ strm << std::setw(colWidth);
+ if(isFloat(type_)) {
+ strm << (double)values[i];
+ } else if(isSignedInt(type_)) {
+ strm << (int64_t)values[i];
+ } else {
+ strm << (uint64_t)values[i];
+ }
+ strm << " ";
+
+ if(dims.back() + 1 == shape().back()) {
+ for(int j = (int)dims.size() - 1; j >= 0; --j) {
+ if(dims[j] + 1 != shape()[j])
+ break;
+ strm << "]";
+ }
+ strm << std::endl;
+ }
+
+ bool prev = true;
+ for(int j = (int)dims.size() - 1; j >= 0; --j) {
+ if(j < (int)dims.size() - 1)
+ prev = prev && dims[j + 1] + 1 == shape()[j + 1];
+ if(prev && dims[j] + 1 == dispCols && shape()[j] > 2 * dispCols) {
+ if(j < (int)dims.size() - 1)
+ for(int k = 0; k <= j; ++k)
+ strm << " ";
+ strm << "... ";
+ if(j < (int)dims.size() - 1)
+ strm << std::endl;
+ break;
+ }
+ }
+ }
+ }
+ strm << std::endl;
+ return strm.str();
+}
+
+template std::string TensorBase::debug<float16>(int, int);
+template std::string TensorBase::debug<float >(int, int);
+template std::string TensorBase::debug<double >(int, int);
+
+template std::string TensorBase::debug<uint8_t >(int, int);
+template std::string TensorBase::debug<uint16_t>(int, int);
+template std::string TensorBase::debug<uint32_t>(int, int);
+template std::string TensorBase::debug<uint64_t>(int, int);
+
+template std::string TensorBase::debug<int8_t >(int, int);
+template std::string TensorBase::debug<int16_t>(int, int);
+template std::string TensorBase::debug<int32_t>(int, int);
+template std::string TensorBase::debug<int64_t>(int, int);
+
+// fill an io::item with data from a Tensor, used for saving
+// and other IO operations.
+void TensorBase::get(io::Item& item, const std::string& name) {
+ item.name = name;
+ item.shape = shape_;
+ item.type = type_;
+
+ item.bytes.resize(memory_->size());
+ copy(backend_,
+ memory_->data<char>(),
+ memory_->data<char>() + memory_->size(),
+ item.bytes.data());
+}
+
+void TensorBase::set(const io::Item& item) {
+ ABORT_IF(item.type != type_, "Tensor type {} and item type {} do not match", type_, item.type);
+ ABORT_IF(item.shape != shape_, "Tensor shape {} and item shape {} do not match", shape_, item.shape);
+ ABORT_IF(item.bytes.size() > memory_->size(), "Item data size {} too large for memory {}", item.bytes.size(), memory_->size());
+ copy(backend_,
+ item.bytes.data(),
+ item.bytes.data() + item.bytes.size(),
+ memory_->data<char>());
+}
+
+} // namespace marian
+
diff --git a/src/tensors/tensor.h b/src/tensors/tensor.h
index acc7e54c..6f4c202a 100755
--- a/src/tensors/tensor.h
+++ b/src/tensors/tensor.h
@@ -17,31 +17,46 @@
namespace marian {
-class TensorBase : public std::enable_shared_from_this<TensorBase> {
-private:
- Ptr<MemoryPiece> memory_;
+namespace io {
+ struct Item;
+}
+
+class TensorBase {
+ MemoryPiece::PtrType memory_;
Shape shape_;
Type type_{Type::float32};
Ptr<Backend> backend_;
-public:
- TensorBase(Ptr<MemoryPiece> memory,
+ ENABLE_INTRUSIVE_PTR(TensorBase)
+
+ // Constructors are private, use TensorBase::New(...)
+ TensorBase(MemoryPiece::PtrType memory,
Shape shape,
Type type,
Ptr<Backend> backend)
: memory_(memory), shape_(shape), type_(type), backend_(backend) {}
- TensorBase(Ptr<MemoryPiece> memory, Shape shape, Ptr<Backend> backend)
+ TensorBase(MemoryPiece::PtrType memory, Shape shape, Ptr<Backend> backend)
: memory_(memory),
shape_(shape),
type_(Type::float32),
backend_(backend) {}
- ~TensorBase() {}
+public:
+ // Use this whenever pointing to MemoryPiece
+ typedef IPtr<TensorBase> PtrType;
+
+ // Use this whenever creating a pointer to MemoryPiece
+ template <class ...Args>
+ static PtrType New(Args&& ...args) {
+ return PtrType(new TensorBase(std::forward<Args>(args)...));
+ }
- virtual void reset(Ptr<MemoryPiece> memory) { memory_ = memory; }
+ virtual ~TensorBase() {}
- virtual Ptr<MemoryPiece> memory() { return memory_; }
+ virtual void reset(MemoryPiece::PtrType memory) { memory_ = memory; }
+
+ virtual MemoryPiece::PtrType memory() { return memory_; }
virtual Type type() { return type_; }
@@ -56,47 +71,44 @@ public:
virtual size_t size() { return shape_.elements(); }
+ // this version of scalar will abort if numeric types do not match
template <typename T>
T scalar() {
- ABORT_IF(!matchType<T>(type_),
- "Requested type ({}) and underlying type ({}) do not match",
- request<T>(),
- type_);
-
ABORT_IF(size() != 1, "Tensor is not a scalar");
return get<T>(0);
}
+ // this non-template version converts all numeric types to float
virtual float scalar() {
- return scalar<float>();
+ DISPATCH_BY_TYPE0(type_, (float)scalar);
}
Ptr<Backend> getBackend() { return backend_; }
DeviceId getDeviceId() { return backend_->getDeviceId(); }
Tensor subtensor(size_t offset, size_t size) {
- auto mem = New<MemoryPiece>(memory_->data() + sizeOf(type_) * offset,
- sizeOf(type_) * size);
- return New<TensorBase>(mem, Shape{1, (int)size}, backend_);
+ auto mem = MemoryPiece::New(memory_->data() + sizeOf(type_) * offset, sizeOf(type_) * size);
+ return TensorBase::New(mem, Shape{1, (int)size}, type(), backend_);
}
+ // @TODO: review if we can eliminate GPU-specific code here,
+ // potentially by moving this to non-class members.
template <typename T>
T get(size_t i) {
- ABORT_IF(!matchType<T>(type_),
- "Requested type ({}) and underlying type ({}) do not match",
- request<T>(),
- type_);
-
- T temp = 0;
- if(backend_->getDeviceId().type == DeviceType::cpu) {
- std::copy(data<T>() + i, data<T>() + i + 1, &temp);
- }
-#ifdef CUDA_FOUND
- else {
- gpu::copy(backend_, data<T>() + i, data<T>() + i + 1, &temp);
+ if(!matchType<T>(type_)) {
+ DISPATCH_BY_TYPE1(type_, (T)get, i);
+ } else {
+ T temp = 0;
+ if(backend_->getDeviceId().type == DeviceType::cpu) {
+ std::copy(data<T>() + i, data<T>() + i + 1, &temp);
+ }
+ #ifdef CUDA_FOUND
+ else {
+ gpu::copy(backend_, data<T>() + i, data<T>() + i + 1, &temp);
+ }
+ #endif
+ return temp;
}
-#endif
- return temp;
}
float get(size_t i) {
@@ -104,46 +116,49 @@ public:
}
template <typename T>
- void set(size_t i, T value) {
+ void get(std::vector<T>& v) {
ABORT_IF(!matchType<T>(type_),
"Requested type ({}) and underlying type ({}) do not match",
request<T>(),
type_);
+ v.resize(size());
if(backend_->getDeviceId().type == DeviceType::cpu) {
- std::copy(&value, &value + 1, data<T>() + i);
+ std::copy(data<T>(), data<T>() + size(), v.data());
}
#ifdef CUDA_FOUND
else {
- gpu::copy(backend_, &value, &value + 1, data<T>() + i);
+ gpu::copy(backend_, data<T>(), data<T>() + size(), v.data());
}
#endif
}
- template <typename T>
- void get(std::vector<T>& v) {
- ABORT_IF(!matchType<T>(type_),
- "Requested type ({}) and underlying type ({}) do not match",
- request<T>(),
- type_);
+ void get(io::Item& item, const std::string& name);
- v.resize(size());
- if(backend_->getDeviceId().type == DeviceType::cpu) {
- std::copy(data<T>(), data<T>() + size(), v.data());
- }
+ template <typename T>
+ void set(size_t i, T value) {
+ if(!matchType<T>(type_)) {
+ DISPATCH_BY_TYPE2(type_, set, i, value);
+ } else {
+ if(backend_->getDeviceId().type == DeviceType::cpu) {
+ std::copy(&value, &value + 1, data<T>() + i);
+ }
#ifdef CUDA_FOUND
- else {
- gpu::copy(backend_, data<T>(), data<T>() + size(), v.data());
- }
+ else {
+ gpu::copy(backend_, &value, &value + 1, data<T>() + i);
+ }
#endif
+ }
}
template <typename T>
void set(const T* begin, const T* end) {
- ABORT_IF(!matchType<T>(type_),
- "Requested type ({}) and underlying type ({}) do not match",
- request<T>(),
- type_);
+ ABORT_IF(end - begin != shape_.elements(),
+ "Vector size ({}) and underlying shape ({}, {}) do not match",
+ end - begin,
+ std::string(shape_),
+ memory_->size());
+ matchOrAbort<T>(type_);
if(backend_->getDeviceId().type == DeviceType::cpu) {
std::copy(begin, end, data<T>());
@@ -160,37 +175,28 @@ public:
set(v.data(), v.data() + v.size());
}
+ void set(const io::Item& item);
+
+ // For single values enable conversion to other numeric formats if possible
template <typename T>
void set(T value) {
if(!matchType<T>(type_)) {
- switch(type_) {
- case Type::float32: set<float >((float )value); break;
- case Type::float64: set<double >((double )value); break;
- case Type::int8: set<int8_t >((int8_t )value); break;
- case Type::int16: set<int16_t >((int16_t )value); break;
- case Type::int32: set<int32_t >((int32_t )value); break;
- case Type::int64: set<int64_t >((int64_t )value); break;
- case Type::uint8: set<uint8_t >((uint8_t )value); break;
- case Type::uint16: set<uint16_t>((uint16_t)value); break;
- case Type::uint32: set<uint32_t>((uint32_t)value); break;
- case Type::uint64: set<uint64_t>((uint64_t)value); break;
- default:
- ABORT(
- "Requested type ({}) cannot be converted to underlying type ({})",
- request<float>(),
- type_);
+ DISPATCH_BY_TYPE1(type_, setAs, value);
+ } else {
+ if(backend_->getDeviceId().type == DeviceType::cpu) {
+ std::fill(data<T>(), data<T>() + size(), value);
}
+ #ifdef CUDA_FOUND
+ else {
+ gpu::fill(backend_, data<T>(), data<T>() + size(), value);
+ }
+ #endif
}
-
- if(backend_->getDeviceId().type == DeviceType::cpu) {
- std::fill(data<T>(), data<T>() + size(), value);
- }
-#ifdef CUDA_FOUND
- else {
- gpu::fill(backend_, data<T>(), data<T>() + size(), value);
- }
-#endif
}
+private: // subroutine for above: helper that accepts any type and casts it to <T>
+ template <typename Tas, typename Tval>
+ void setAs(Tval value) { set((Tas)value); }
+public:
void setSparse(const std::vector<size_t>& k, const std::vector<float>& v) {
ABORT_IF(!matchType<float>(type_),
@@ -230,22 +236,7 @@ public:
}
void copyFrom(Tensor in) {
- switch(type_) {
- case Type::int8: copyFrom<int8_t>(in); break;
- case Type::int16: copyFrom<int16_t>(in); break;
- case Type::int32: copyFrom<int32_t>(in); break;
- case Type::int64: copyFrom<int64_t>(in); break;
-
- case Type::uint8: copyFrom<uint8_t>(in); break;
- case Type::uint16: copyFrom<uint16_t>(in); break;
- case Type::uint32: copyFrom<uint32_t>(in); break;
- case Type::uint64: copyFrom<uint64_t>(in); break;
-
- case Type::float32: copyFrom<float>(in); break;
- case Type::float64: copyFrom<double>(in); break;
-
- default: ABORT("Unknown type {}", type_);
- }
+ DISPATCH_BY_TYPE1(type_, copyFrom, in);
}
// Swaps the contents of the current tensor with the argument tensor
@@ -280,132 +271,30 @@ public:
}
void swap(Tensor swapee) {
- switch(type_) {
- case Type::int8: swap<int8_t>(swapee); break;
- case Type::int16: swap<int16_t>(swapee); break;
- case Type::int32: swap<int32_t>(swapee); break;
- case Type::int64: swap<int64_t>(swapee); break;
-
- case Type::uint8: swap<uint8_t>(swapee); break;
- case Type::uint16: swap<uint16_t>(swapee); break;
- case Type::uint32: swap<uint32_t>(swapee); break;
- case Type::uint64: swap<uint64_t>(swapee); break;
-
- case Type::float32: swap<float>(swapee); break;
- case Type::float64: swap<double>(swapee); break;
-
- default: ABORT("Unknown type {}", type_);
- }
+ DISPATCH_BY_TYPE1(type_, swap, swapee);
}
-
+
template <typename T>
- std::string debug() {
- ABORT_IF(!matchType<T>(type_),
- "Requested type ({}) and underlying type ({}) do not match",
- request<T>(),
- type_);
+ std::string debug(int precision = 8, int dispCols = 5);
- std::stringstream strm;
- assert(shape_.size());
- strm << shape_;
- strm << " type=" << type_;
- strm << " device=" << backend_->getDeviceId();
- strm << " ptr=" << (size_t)memory_->data();
- strm << " bytes=" << memory_->size();
- strm << std::endl;
-
- // values
- size_t totSize = shape_.elements();
- std::vector<T> values(totSize);
- get(values);
-
- int dispCols = 5;
- if(isFloat(type_))
- strm << std::fixed << std::setprecision(8) << std::setfill(' ');
- else
- strm << std::fixed << std::setprecision(0) << std::setfill(' ');
-
- for(int i = 0; i < values.size(); ++i) {
- std::vector<int> dims;
- shape().dims(i, dims);
-
- bool disp = true;
- for(int j = 0; j < dims.size(); ++j)
- disp = disp && (dims[j] < dispCols || dims[j] >= shape()[j] - dispCols);
-
- if(disp) {
- if(dims.back() == 0) {
- bool par = true;
- std::vector<std::string> p;
- for(int j = (int)dims.size() - 1; j >= 0; --j) {
- if(dims[j] != 0)
- par = false;
-
- p.push_back(par ? "[" : " ");
- }
- for(auto it = p.rbegin(); it != p.rend(); ++it)
- strm << *it;
- strm << " ";
- }
-
- strm << std::setw(12);
- if(isFloat(type_)) {
- strm << (double)values[i];
- } else if(isSignedInt(type_)) {
- strm << (int64_t)values[i];
- } else {
- strm << (uint64_t)values[i];
- }
- strm << " ";
-
- if(dims.back() + 1 == shape().back()) {
- for(int j = (int)dims.size() - 1; j >= 0; --j) {
- if(dims[j] + 1 != shape()[j])
- break;
- strm << "]";
- }
- strm << std::endl;
- }
-
- bool prev = true;
- for(int j = (int)dims.size() - 1; j >= 0; --j) {
- if(j < (int)dims.size() - 1)
- prev = prev && dims[j + 1] + 1 == shape()[j + 1];
- if(prev && dims[j] + 1 == dispCols && shape()[j] > 2 * dispCols) {
- if(j < (int)dims.size() - 1)
- for(int k = 0; k <= j; ++k)
- strm << " ";
- strm << "... ";
- if(j < (int)dims.size() - 1)
- strm << std::endl;
- break;
- }
- }
- }
- }
- strm << std::endl;
- return strm.str();
+ std::string debug(int precision = 8, int dispCols = 5) {
+ DISPATCH_BY_TYPE2(type_, debug, precision, dispCols);
}
- std::string debug() {
- switch(type_) {
- case Type::int8: return debug<int8_t>();
- case Type::int16: return debug<int16_t>();
- case Type::int32: return debug<int32_t>();
- case Type::int64: return debug<int64_t>();
-
- case Type::uint8: return debug<uint8_t>();
- case Type::uint16: return debug<uint16_t>();
- case Type::uint32: return debug<uint32_t>();
- case Type::uint64: return debug<uint64_t>();
+};
- case Type::float32: return debug<float>();
- case Type::float64: return debug<double>();
+typedef TensorBase::PtrType Tensor;
- default: ABORT("Unknown type {}", type_);
- }
+template <class TensorType0, class ...TensorTypeRest>
+static inline void checkCommonType(TensorType0 first, TensorTypeRest ...rest) {
+ std::vector<Tensor> vTensors({first, rest...});
+ Type firstType = first->type();
+ for(int i = 1; i < vTensors.size(); ++i) {
+ ABORT_IF(vTensors[i]->type() != firstType,
+ "Type of tensor {} is different from type of tensor 0 ({} != {})",
+ i, vTensors[i]->type(), firstType);
}
-};
+}
-typedef std::shared_ptr<TensorBase> Tensor;
} // namespace marian
+
diff --git a/src/tensors/tensor_allocator.h b/src/tensors/tensor_allocator.h
index 3a2d99e6..e3bc79f9 100755..100644
--- a/src/tensors/tensor_allocator.h
+++ b/src/tensors/tensor_allocator.h
@@ -44,6 +44,13 @@ public:
allocator_->reserve(mult * GROW);
}
+ void reserveExact(const std::vector<size_t>& bytes) {
+ size_t total = 0;
+ for(auto part : bytes)
+ total += allocator_->alignedSize(part);
+ reserveExact(total);
+ }
+
void reserveExact(size_t bytes = 0) {
size_t mbytes = bytes / MBYTE;
if(mbytes == 0) {
@@ -63,26 +70,25 @@ public:
void clear() { allocator_->clear(); }
size_t capacity(Shape shape, Type type = Type::float32) {
- return allocator_->capacity(shape.elements(), type);
+ return allocator_->capacity<char>(requiredBytes(shape, type));
}
- void allocate(Tensor& t, Shape shape, Type type = Type::float32) {
+ void allocate(/*out*/ Tensor& t, Shape shape, Type type = Type::float32) {
if(!t || t->shape() != shape) {
- int size = shape.elements();
- auto mem = allocator_->alloc(size, type);
- t = Tensor(new TensorBase(mem, shape, type, backend_));
+ auto mem = allocator_->alloc(requiredBytes(shape, type));
+ t = Tensor(TensorBase::New(mem, shape, type, backend_));
}
}
- void free(Tensor& t) { allocator_->free(t->memory()); }
+ void free(const Tensor& t) { allocator_->free(t->memory()); }
- Tensor asTensor() {
+ Tensor asTensor(Type type = Type::float32) {
auto mem = allocator_->memory();
- auto size = mem->size() / sizeof(float);
- return Tensor(new TensorBase(mem, {1, (int)size}, backend_));
+ auto size = mem->size() / sizeOf(type);
+ return TensorBase::New(mem, Shape({1, (int)size}), type, backend_);
}
- size_t size() { return allocator_->size() / sizeof(float); }
+ size_t size(Type type = Type::float32) { return allocator_->size() / sizeOf(type); }
Ptr<Allocator> allocator() { return allocator_; }
};
diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h
index 2cac284a..227bc953 100755..100644
--- a/src/tensors/tensor_operators.h
+++ b/src/tensors/tensor_operators.h
@@ -25,7 +25,7 @@
namespace marian {
template <typename InIt, typename OutIt>
-void copy(Ptr<Backend> backend, const InIt beg, const InIt end, OutIt it) {
+void copy(Ptr<Backend>& MAYBE_UNUSED backend, const InIt beg, const InIt end, OutIt it) {
#ifdef CUDA_FOUND
if(backend->getDeviceId().type == DeviceType::gpu)
gpu::copy(backend, beg, end, it);
@@ -34,6 +34,9 @@ void copy(Ptr<Backend> backend, const InIt beg, const InIt end, OutIt it) {
std::copy(beg, end, it);
}
+DISPATCH2(CopyCast, marian::Tensor, const marian::Tensor);
+DISPATCH4(IsNaN, const Tensor, Ptr<Allocator>, bool&, bool&);
+
template <class Functor, class... Tensors>
void Element(Functor functor, marian::Tensor out, Tensors... tensors) {
#ifdef CUDA_FOUND
@@ -51,12 +54,22 @@ void Add(Functor functor, float scale, marian::Tensor out, Tensors... tensors) {
gpu::Add(functor, scale, out, tensors...);
else
#endif
- cpu::Add(functor, scale, out, tensors...);
+ cpu::Aggregate(functor, /*aggInit=*/0.0f, functional::_1 + functional::_2, scale, out, tensors...);
}
template <class Functor, class... Tensors>
void Add(Functor functor, marian::Tensor out, Tensors... tensors) {
- Add(functor, 1, out, tensors...);
+ Add(functor, /*scale=*/1.f, out, tensors...);
+}
+
+template <class Functor, class AggFunctor, class... Tensors>
+void Aggregate(Functor functor, float aggInit, AggFunctor aggFunctor, marian::Tensor out, Tensors... tensors) {
+#ifdef CUDA_FOUND
+ if(out->getBackend()->getDeviceId().type == DeviceType::gpu)
+ gpu::Aggregate(functor, aggInit, aggFunctor, 1.0f, out, tensors...);
+ else
+#endif
+ cpu::Aggregate(functor, aggInit, aggFunctor, 1.0f, out, tensors...);
}
template <class Functor, class... Tensors>
@@ -74,9 +87,18 @@ void Reduce(Functor functor, marian::Tensor out, Tensors... tensors) {
Add(functor, out, tensors...);
}
+template <class Functor, class AggFunctor, class... Tensors>
+void Reduce(Functor functor, AggFunctor aggFunctor, float aggInit,
+ marian::Tensor out,
+ Tensors... tensors) {
+ out->set(aggInit);
+ Aggregate(functor, aggInit, aggFunctor, out, tensors...);
+}
+
// clang-format off
DISPATCH7(Prod, marian::Tensor, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float)
DISPATCH8(ProdBatched, marian::Tensor, Ptr<Allocator>, const marian::Tensor, const marian::Tensor, bool, bool, float, float)
+DISPATCH9(CSRProd, marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float)
DISPATCH2(Softmax, marian::Tensor, marian::Tensor)
DISPATCH3(SoftmaxGrad, marian::Tensor, marian::Tensor, marian::Tensor)
@@ -97,19 +119,19 @@ DISPATCH3(Concatenate, marian::Tensor, const std::vector<marian::Tensor>&, int)
// clang-format on
-static inline void Bernoulli(Tensor resultTensor, float keepProb, float scale = 1.f) {
+// Bernoulli(tensor, 0.5f, 2.f, -1.f) generates a tensor composed of 50% of 1 and 50% of -1.
+static inline void Bernoulli(Tensor resultTensor, float keepProb, float scale = 1.f, float shift = 0.f) {
// in-place uniform distribution
auto rnd = resultTensor->getBackend()->getRandomGenerator();
rnd->uniform(resultTensor, 0.f, 1.f); // temporarily mis-use this to hold the random numbers
using namespace functional;
- Element(_1 = (_1 < keepProb) * scale, resultTensor);
+ Element(_1 = (_1 < keepProb) * scale + shift, resultTensor);
}
-
static inline void Dropout(Tensor tensor, float dropProb) {
float keepProb = 1.f - dropProb;
float scale = 1.f / keepProb;
- Bernoulli(tensor, keepProb, scale);
+ Bernoulli(tensor, keepProb, scale, /*shift=*/0.f);
}
#ifdef CUDA_FOUND
@@ -139,7 +161,52 @@ static inline void Deconcatenate(std::vector<marian::Tensor>& outputs,
// clang-format off
DISPATCH5(LayerNormalization, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, float)
-DISPATCH9(LayerNormalizationGrad, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, float)
+
+#ifdef CUDA_FOUND
+namespace gpu {
+void LayerNormalizationGrad(Ptr<Allocator> allocator,
+ Tensor gradX,
+ Tensor gradGamma,
+ Tensor gradBeta,
+ Tensor adj,
+ Tensor y,
+ Tensor x,
+ Tensor gamma,
+ Tensor beta,
+ float eps);
+}
+#endif
+
+namespace cpu {
+void LayerNormalizationGrad(Tensor gradX,
+ Tensor gradGamma,
+ Tensor gradBeta,
+ Tensor adj,
+ Tensor y,
+ Tensor x,
+ Tensor gamma,
+ Tensor beta,
+ float eps);
+}
+
+static inline void LayerNormalizationGrad(
+ Ptr<Allocator> MAYBE_UNUSED allocator,
+ Tensor gradX,
+ Tensor gradGamma,
+ Tensor gradBeta,
+ Tensor adj,
+ Tensor y,
+ Tensor x,
+ Tensor gamma,
+ Tensor beta,
+ float eps) {
+#ifdef CUDA_FOUND
+ if(gradX->getBackend()->getDeviceId().type == DeviceType::gpu)
+ gpu::LayerNormalizationGrad(allocator, gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps);
+ else
+#endif
+ cpu::LayerNormalizationGrad(gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps);
+}
DISPATCH4(HighwayForward, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor)
DISPATCH7(HighwayBackward, marian::Tensor, marian::Tensor, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor)
@@ -244,21 +311,21 @@ DISPATCH7(AttBack, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tenso
#ifdef CUDA_FOUND
namespace gpu {
-float L2Norm(marian::Tensor in);
+float L2Norm(marian::Tensor in, Ptr<Allocator> allocator);
}
#endif
namespace cpu {
-float L2Norm(marian::Tensor in);
+float L2Norm(marian::Tensor in, Ptr<Allocator> allocator);
}
-static inline float L2Norm(marian::Tensor in) {
+static inline float L2Norm(marian::Tensor in, Ptr<Allocator> allocator) {
#ifdef CUDA_FOUND
if(in->getBackend()->getDeviceId().type == DeviceType::gpu)
- return gpu::L2Norm(in);
+ return gpu::L2Norm(in, allocator);
else
#endif
- return cpu::L2Norm(in);
+ return cpu::L2Norm(in, allocator);
}
// clang-format off
diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt
index 347a18b7..e31ed1d4 100644
--- a/src/tests/CMakeLists.txt
+++ b/src/tests/CMakeLists.txt
@@ -1,49 +1,25 @@
# Unit tests
-set(UNIT_TESTS
- graph_tests
- operator_tests
- rnn_tests
- attention_tests
-)
-
-foreach(test ${UNIT_TESTS})
- add_executable("run_${test}" run_tests.cpp "${test}.cpp")
- target_link_libraries("run_${test}" marian ${EXT_LIBS} Catch)
-
- if(CUDA_FOUND)
- target_link_libraries("run_${test}" marian marian_cuda ${EXT_LIBS} Catch)
- endif(CUDA_FOUND)
-
- add_test(NAME ${test} COMMAND "run_${test}")
-endforeach(test)
+add_subdirectory(units)
# Testing apps
-add_executable(logger_test logger_test.cpp)
-add_executable(dropout_test dropout_test.cpp)
-add_executable(prod_test prod.cpp)
-add_executable(cli_test cli_test.cpp)
-
-if(CUDA_FOUND)
-add_executable(pooling_test pooling_test.cpp)
-target_link_libraries(pooling_test marian ${EXT_LIBS} Catch)
-target_link_libraries(pooling_test marian marian_cuda ${EXT_LIBS} Catch)
-endif(CUDA_FOUND)
-
-add_executable(sqlite_test sqlite_test.cpp)
+set(APP_TESTS
+ logger
+ dropout
+ sqlite
+ prod
+ cli
+ pooling
+)
-foreach(exec
- logger_test
- dropout_test
- sqlite_test
- prod_test
- cli_test
- )
- target_link_libraries(${exec} marian ${EXT_LIBS} Catch)
- if(CUDA_FOUND)
- target_link_libraries(${exec} marian marian_cuda ${EXT_LIBS} Catch)
- endif(CUDA_FOUND)
+foreach(test ${APP_TESTS})
+ add_executable("test_${test}" "${test}.cpp")
- set_target_properties(${exec} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
-endforeach(exec)
+ if(CUDA_FOUND)
+ target_link_libraries("test_${test}" ${EXT_LIBS} marian ${EXT_LIBS} marian_cuda ${EXT_LIBS})
+ else(CUDA_FOUND)
+ target_link_libraries("test_${test}" marian ${EXT_LIBS})
+ endif(CUDA_FOUND)
+ set_target_properties("test_${test}" PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}")
+endforeach(test)
diff --git a/src/tests/cli_test.cpp b/src/tests/cli.cpp
index 2f40bb13..67c9bdd7 100644
--- a/src/tests/cli_test.cpp
+++ b/src/tests/cli.cpp
@@ -39,9 +39,10 @@ int main(int argc, char** argv) {
auto options = New<Options>();
{
- auto w = New<CLIWrapper>(options);
+ YAML::Node config;
+ auto w = New<CLIWrapper>(config);
w->add<int>("-i,--int", "help message")->implicit_val("555")->default_val("123");
- w->add<std::string>("-s,--str", "help message")->default_val("foo");
+ w->add<std::string>("-s,--str", "help message");
w->add<std::vector<float>>("-v,--vec", "help message")->expected(-2);
w->switchGroup("My group");
w->add<std::vector<std::string>>("--defvec,-d", "help message")->default_val("foo");
@@ -49,10 +50,11 @@ int main(int argc, char** argv) {
w->add<bool>("-x,--xbool", "false boolean option", true);
w->add<std::string>("--a-very-long-option-name-for-testing-purposes", "A very long text a very long text a very long text a very long text a very long text a very long text");
w->switchGroup();
- w->add<std::string>("-f,--file", "help message")->check(validators::file_exists);
+ //w->add<std::string>("-f,--file", "help message")->check(validators::file_exists);
//w.add<color>("-e,--enum", "help message for enum");
w->parse(argc, argv);
+ options->merge(config);
}
options->get<int>("int");
@@ -65,7 +67,12 @@ int main(int argc, char** argv) {
//w.get<color>("enum");
YAML::Emitter emit;
- OutputYaml(options->getYaml(), emit);
+ OutputYaml(options->cloneToYamlNode(), emit);
std::cout << emit.c_str() << std::endl;
+
+ std::cout << "===" << std::endl;
+ std::cout << "vec/str.hasAndNotEmpty? " << options->hasAndNotEmpty("vec") << " " << options->hasAndNotEmpty("str") << std::endl;
+ std::cout << "vec/str.has? " << options->has("vec") << " " << options->has("str") << std::endl;
+
return 0;
}
diff --git a/src/tests/conv_test.cu b/src/tests/conv.cu
index 963ad7c9..963ad7c9 100644
--- a/src/tests/conv_test.cu
+++ b/src/tests/conv.cu
diff --git a/src/tests/conv_char_test.cu b/src/tests/conv_char.cu
index 66d2986f..66d2986f 100644
--- a/src/tests/conv_char_test.cu
+++ b/src/tests/conv_char.cu
diff --git a/src/tests/dropout_test.cpp b/src/tests/dropout.cpp
index c31ee621..367029fe 100644
--- a/src/tests/dropout_test.cpp
+++ b/src/tests/dropout.cpp
@@ -20,8 +20,8 @@ int main(int argc, char** argv) {
for(int i = 0; i < 10; ++i) {
g->clear();
- auto mask1 = g->dropout(0.2, {10, 3072});
- auto mask2 = g->dropout(0.3, {1, 3072});
+ auto mask1 = g->dropoutMask(0.2, {10, 3072});
+ auto mask2 = g->dropoutMask(0.3, {1, 3072});
auto mask = mask1 + mask2;
debug(mask1, "mask1");
debug(mask2, "mask2");
diff --git a/src/tests/graph_tests.cpp b/src/tests/graph_tests.cpp
deleted file mode 100644
index 822cb2e1..00000000
--- a/src/tests/graph_tests.cpp
+++ /dev/null
@@ -1,107 +0,0 @@
-#include "catch.hpp"
-#include "graph/expression_graph.h"
-#include "graph/expression_operators.h"
-
-using namespace marian;
-
-#ifdef CUDA_FOUND
-TEST_CASE("Graph device is set", "[graph]") {
- auto graph = New<ExpressionGraph>();
-
- graph->setDevice({0, DeviceType::gpu});
-
- DeviceId testId{0, DeviceType::gpu};
- REQUIRE(graph->getDeviceId() == testId);
-}
-
-TEST_CASE("Expression graph can be initialized with constant values",
- "[graph]") {
- auto graph = New<ExpressionGraph>();
- graph->setDevice({0, DeviceType::gpu});
- graph->reserveWorkspaceMB(4);
-
- std::vector<float> values;
-
- SECTION("initializing with zeros") {
- graph->clear();
- values.clear();
- auto zeros = graph->param("0s", {2, 5}, inits::zeros);
- graph->forward();
-
- zeros->val()->get(values);
- REQUIRE(values == std::vector<float>(10, 0.0f));
- }
-
- SECTION("initializing with ones") {
- graph->clear();
- values.clear();
- auto ones = graph->param("1s", {2, 5}, inits::ones);
- graph->forward();
-
- ones->val()->get(values);
- REQUIRE(values == std::vector<float>(10, 1.0f));
- }
-
- SECTION("initializing from vector") {
- graph->clear();
- values.clear();
- std::vector<float> v({1, 2, 3, 4, 5, 6});
- auto vals = graph->param("vs", {2, 3}, inits::from_vector(v));
- graph->forward();
-
- REQUIRE(values.empty());
- vals->val()->get(values);
- REQUIRE(values == v);
- }
-}
-#endif
-
-TEST_CASE("Graph device is set (cpu)", "[graph]") {
- auto graph = New<ExpressionGraph>();
-
- graph->setDevice({0, DeviceType::cpu});
-
- DeviceId testId{0, DeviceType::cpu};
- REQUIRE(graph->getDeviceId() == testId);
-}
-
-TEST_CASE("Expression graph can be initialized with constant values (cpu)",
- "[graph]") {
- auto graph = New<ExpressionGraph>();
- graph->setDevice({0, DeviceType::cpu});
- graph->reserveWorkspaceMB(4);
-
- std::vector<float> values;
-
- SECTION("initializing with zero (cpu)") {
- graph->clear();
- values.clear();
- auto zeros = graph->param("0s", {2, 5}, inits::zeros);
- graph->forward();
-
- zeros->val()->get(values);
- REQUIRE(values == std::vector<float>(10, 0.0f));
- }
-
- SECTION("initializing with ones (cpu)") {
- graph->clear();
- values.clear();
- auto ones = graph->param("1s", {2, 5}, inits::ones);
- graph->forward();
-
- ones->val()->get(values);
- REQUIRE(values == std::vector<float>(10, 1.0f));
- }
-
- SECTION("initializing from vector (cpu)") {
- graph->clear();
- values.clear();
- std::vector<float> v({1, 2, 3, 4, 5, 6});
- auto vals = graph->param("vs", {2, 3}, inits::from_vector(v));
- graph->forward();
-
- REQUIRE(values.empty());
- vals->val()->get(values);
- REQUIRE(values == v);
- }
-}
diff --git a/src/tests/logger_test.cpp b/src/tests/logger.cpp
index ff5727f7..ff5727f7 100644
--- a/src/tests/logger_test.cpp
+++ b/src/tests/logger.cpp
diff --git a/src/tests/operator_tests.cpp b/src/tests/operator_tests.cpp
deleted file mode 100644
index 7c17ec8f..00000000
--- a/src/tests/operator_tests.cpp
+++ /dev/null
@@ -1,619 +0,0 @@
-#include "catch.hpp"
-#include "graph/expression_graph.h"
-#include "graph/expression_operators.h"
-
-using namespace marian;
-
-void tests(DeviceType device) {
- auto floatApprox = [](float x, float y) { return x == Approx(y); };
-
- Config::seed = 1234;
-
- auto graph = New<ExpressionGraph>();
- graph->setDevice({0, device});
- graph->reserveWorkspaceMB(16);
-
- std::vector<float> values;
-
- SECTION("scalar multiplication") {
- graph->clear();
- values.clear();
- std::vector<float> vB({1, 2, 3, 4, 5, 6});
-
- auto B = graph->param("B", {3, 2}, inits::from_vector(vB));
- auto B2 = B * 2.0f;
- graph->forward();
-
- CHECK(B2->shape() == Shape({3, 2}));
- B2->val()->get(values);
-
- std::vector<float> vB2({2, 4, 6, 8, 10, 12});
- CHECK(values == vB2);
- }
-
- SECTION("elementwise binary operators with broadcasting") {
- graph->clear();
- values.clear();
-
- std::vector<float> vA({1, -2, 3, -4});
- std::vector<float> vB({0.5, 1.5});
-
- std::vector<float> vAdd({1.5, -0.5, 3.5, -2.5});
- std::vector<float> vMinus({-0.5, 3.5, -2.5, 5.5});
- std::vector<float> vMult({0.5, -3.0, 1.5, -6.0});
- std::vector<float> vDiv({2.0f, -1.33333f, 6.0f, -2.66667f});
-
- auto a = graph->constant({2, 2, 1}, inits::from_vector(vA));
- auto b = graph->constant({2, 1}, inits::from_vector(vB));
-
- auto add = a + b;
- auto minus = b - a;
- auto mult = a * b;
- auto div = a / b;
-
- graph->forward();
-
- CHECK(add->shape() == Shape({2, 2, 1}));
- CHECK(minus->shape() == Shape({2, 2, 1}));
- CHECK(mult->shape() == Shape({2, 2, 1}));
- CHECK(div->shape() == Shape({2, 2, 1}));
-
- add->val()->get(values);
- CHECK( values == vAdd );
-
- minus->val()->get(values);
- CHECK( values == vMinus );
-
- mult->val()->get(values);
- CHECK( values == vMult );
-
- div->val()->get(values);
- CHECK( std::equal(values.begin(), values.end(),
- vDiv.begin(), floatApprox) );
- }
-
- SECTION("transposing and reshaping") {
- graph->clear();
- values.clear();
-
- std::vector<float> vA({1, 2, 3, 4, 5, 6, 7, 8});
-
- std::vector<float> vT1({1, 5, 2, 6, 3, 7, 4, 8});
- std::vector<float> vT3({1, 2, 5, 6, 3, 4, 7, 8});
- std::vector<float> vT4({1, 5, 3, 7, 2, 6, 4, 8});
- std::vector<float> vT5({1, 2, 5, 6, 3, 4, 7, 8});
-
- auto a = graph->constant({2, 4}, inits::from_vector(vA));
-
- auto t1 = transpose(a);
- auto t2 = transpose(t1);
- auto t3 = transpose(reshape(t1, {2, 2, 2}));
-
- auto t4 = transpose(reshape(a, {2, 1, 2, 2}), {1, 3, 2, 0});
- auto t5 = transpose(reshape(a, {2, 1, 2, 2}), {2, 0, 1, 3});
-
- graph->forward();
-
- CHECK(t1->shape() == Shape({4, 2}));
- CHECK(t2->shape() == Shape({2, 4}));
- CHECK(t3->shape() == Shape({2, 2, 2}));
- CHECK(t4->shape() == Shape({1, 2, 2, 2}));
- CHECK(t5->shape() == Shape({2, 2, 1, 2}));
-
- t1->val()->get(values);
- CHECK( values == vT1 );
-
- t2->val()->get(values);
- CHECK( values == vA );
-
- t3->val()->get(values);
- CHECK( values == vT3 );
-
- t4->val()->get(values);
- CHECK( values == vT4 );
-
- t5->val()->get(values);
- CHECK( values == vT5 );
- }
-
- SECTION("softmax and logsoftmax") {
- graph->clear();
- values.clear();
- std::vector<float> in({-.2, -.3, 4.5, 5.2, -10, 101.45, -100.05, 1.05e-5});
-
- std::vector<float> smOut({ 0.52498f, 0.47502f, 0.33181f, 0.66819f,
- 0.0f, 1.0f, 0.0f, 1.0f });
-
- std::vector<float> lsmOut({ -0.6444f, -0.7444f, -1.10319f, -0.40319f,
- -111.45f, 0.0f, -100.05001f, 0.0f });
-
- auto input = graph->constant({2, 2, 2}, inits::from_vector(in));
-
- auto sm = softmax(input);
- auto lsm = logsoftmax(input);
-
- graph->forward();
-
- CHECK(sm->shape() == Shape({2, 2, 2}));
- CHECK(lsm->shape() == Shape({2, 2, 2}));
-
- sm->val()->get(values);
-
- CHECK( std::equal(values.begin(), values.end(),
- smOut.begin(), floatApprox) );
-
- lsm->val()->get(values);
-
- CHECK( std::equal(values.begin(), values.end(),
- lsmOut.begin(), floatApprox) );
- }
-
- SECTION("layer normalization") {
- graph->clear();
- values.clear();
-
-#ifdef CUDA_FOUND
- std::vector<float> vLn({
- -1.1962, 1.43061, 0.380288, -0.614697, 0.816638, 0.622649,
- -1.69679, 0.257504, -1.12563, -0.151387, 1.61181, -0.334796,
- 1.07207, -0.622614, 0.862014, -1.31147
- });
-#else
- std::vector<float> vLn({
- -1.49821, -0.152206, 0.394932, 1.25548, -1.51701, -0.28032,
- 0.9483, 0.849025, 0.855183, 1.11657, -0.788354, -1.1834,
- -0.85939, -1.13109, 0.972076, 1.01841
- });
-#endif
-
- auto a = graph->constant({2, 2, 4}, inits::glorot_uniform);
-
- auto gamma = graph->param("gamma", {1, 4}, inits::ones);
- auto beta = graph->param("beta", {1, 4}, inits::zeros);
-
- auto ln = layerNorm(a, gamma, beta);
-
- graph->forward();
-
- CHECK(ln->shape() == Shape({2, 2, 4}));
-
- ln->val()->get(values);
- CHECK( std::equal(values.begin(), values.end(),
- vLn.begin(), floatApprox) );
-
- }
-
- SECTION("reductions") {
- graph->clear();
- values.clear();
-
- std::vector<float> vA({1, 2, 3, 4, 5, 6, 7, 8});
- std::vector<float> vS1({6, 8, 10, 12});
- std::vector<float> vS2({10, 26});
-
- std::vector<float> vW({2.77778f, 6.77778f});
-
-
- auto a = graph->constant({2, 4}, inits::from_vector(vA));
-
- auto s1 = sum(a, /*axis=*/ 0);
- auto s2 = sum(a, /*axis=*/ 1);
-
- auto m3 = mean(s1, /*axis=*/ 1);
-
- auto sp = scalar_product(s2, s2, /*axis=*/ 0);
-
- auto wa = weighted_average(a, s1, /*axis=*/ -1);
-
- graph->forward();
-
- CHECK(s1->shape() == Shape({1, 4}));
- CHECK(s2->shape() == Shape({2, 1}));
- CHECK(m3->shape() == Shape({1, 1}));
- CHECK(sp->shape() == Shape({1, 1}));
- CHECK(wa->shape() == Shape({2, 1}));
-
- s1->val()->get(values);
- CHECK( values == vS1 );
-
- s2->val()->get(values);
- CHECK( values == vS2 );
-
- CHECK( m3->val()->scalar() == 9 );
- CHECK( sp->val()->scalar() == 776 );
-
- wa->val()->get(values);
- CHECK( std::equal(values.begin(), values.end(),
- vW.begin(), floatApprox) );
- }
-
- SECTION("concatenation") {
- graph->clear();
- values.clear();
-
- std::vector<float> vO1({ 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2,
- 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4,
- 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2,
- 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4});
-
- std::vector<float> vO2({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4,
- 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4,
- 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4,
- 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4});
-
- std::vector<float> vO3({1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
- 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
- 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4});
-
- std::vector<float> vO4({1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
- 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
- 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4});
-
- auto in1 = graph->constant({1, 2, 2, 3}, inits::from_value(1));
- auto in2 = graph->constant({1, 2, 2, 3}, inits::from_value(2));
- auto in3 = graph->constant({1, 2, 2, 3}, inits::from_value(3));
- auto in4 = graph->constant({1, 2, 2, 3}, inits::from_value(4));
-
- auto c1out1 = concatenate({in1, in2, in3, in4}, /*axis=*/ 2);
- auto c1out2 = concatenate({in1, in2, in3, in4}, /*axis=*/ -1);
- auto c1out3 = concatenate({in1, in2, in3, in4}, /*axis=*/ -3);
- auto c1out4 = concatenate({in1, in2, in3, in4}, /*axis=*/ 0);
-
- graph->forward();
-
- CHECK(c1out1->shape() == Shape({1, 2, 8, 3}));
- CHECK(c1out2->shape() == Shape({1, 2, 2, 12}));
- CHECK(c1out3->shape() == Shape({1, 8, 2, 3}));
- CHECK(c1out4->shape() == Shape({4, 2, 2, 3}));
-
- c1out1->val()->get(values);
- CHECK( values == vO1 );
-
- c1out2->val()->get(values);
- CHECK( values == vO2 );
-
- c1out3->val()->get(values);
- CHECK( values == vO3 );
-
- c1out4->val()->get(values);
- CHECK( values == vO4 );
- }
-
- SECTION("dot product") {
- graph->clear();
- values.clear();
-
- std::vector<float> vA({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
- std::vector<float> vB({1, 2, 3, 4, 5, 6});
- std::vector<float> vC({22, 28, 49, 64, 76, 100, 103, 136});
-
- auto A = graph->param("A", {2, 2, 3}, inits::from_vector(vA));
- auto B = graph->param("B", {3, 2}, inits::from_vector(vB));
- auto C = dot(A, B);
- graph->forward();
-
- CHECK(C->shape() == Shape({2, 2, 2}));
- C->val()->get(values);
- CHECK(values == vC);
- }
-
- SECTION("affine transformation") {
- graph->clear();
- values.clear();
-
- std::vector<float> vA({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
- std::vector<float> vB({1, 2, 3, 4, 5, 6});
- std::vector<float> vAff({24, 30, 51, 66, 78, 102, 105, 138});
-
- auto A = graph->param("A", {4, 3}, inits::from_vector(vA));
- auto B = graph->param("B", {3, 2}, inits::from_vector(vB));
- auto C = graph->param("C", {4, 2}, inits::from_value(2));
- auto aff1 = affine(A, B, C);
- auto aff2 = dot(A, B) + C;
- graph->forward();
-
- CHECK(aff1->shape() == Shape({4, 2}));
- aff1->val()->get(values);
- CHECK(values == vAff);
-
- std::vector<float> values2;
- CHECK(aff2->shape() == aff1->shape());
- aff2->val()->get(values2);
- CHECK(values2 == values);
- }
-
- SECTION("repeat") {
- graph->clear();
- values.clear();
-
- std::vector<float> vA({1, 2, 3, 4, 5, 6});
- std::vector<float> vB({1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6});
- std::vector<float> vC({1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6});
-
- auto A = graph->param("A", {2,3}, inits::from_vector(vA));
- auto I = repeat(A, 1, 0);
- auto B = repeat(A, 2, 0);
- auto C = repeat(A, 2, 1);
- graph->forward();
-
- CHECK(I->shape() == Shape({2, 3}));
- I->val()->get(values);
- CHECK(values == vA);
-
- CHECK(B->shape() == Shape({4, 3}));
- B->val()->get(values);
- CHECK(values == vB);
-
- CHECK(C->shape() == Shape({2, 6}));
- C->val()->get(values);
- CHECK(values == vC);
- }
-
- SECTION("flatten") {
- graph->clear();
- values.clear();
-
- std::vector<float> vIn({1, 2, 3, 4, 5, 6, 7, 8});
-
- auto A = graph->param("A", {2, 4}, inits::from_vector(vIn));
- auto Af = flatten(A);
- auto B = graph->param("B", {2, 2, 1, 2}, inits::from_vector(vIn));
- auto Bf = flatten(B);
- graph->forward();
-
- CHECK(Af->shape() == Shape({8}));
- Af->val()->get(values);
- CHECK(values == vIn);
-
- CHECK(Bf->shape() == Shape({8}));
- Bf->val()->get(values);
- CHECK(values == vIn);
- }
-
- SECTION("rows selection from 2d matrix") {
- graph->clear();
- values.clear();
-
- std::vector<float> vA({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
-
- std::vector<IndexType> iB0({0}); // first row
- std::vector<IndexType> iB1({0, 1, 2}); // several consecutive rows
- std::vector<IndexType> iB2({0, 2}); // two nonconsecutive rows
- std::vector<IndexType> iB3({2, 1}); // reversed order
- std::vector<IndexType> iB4({1, 1}); // repeated rows
- std::vector<IndexType> iB5({0, 1, 2, 3}); // identity
- std::vector<IndexType> iB6({}); // empty
- std::vector<float> vB0({1, 2, 3});
- std::vector<float> vB1({1, 2, 3, 4, 5, 6, 7, 8, 9});
- std::vector<float> vB2({1, 2, 3, 7, 8, 9});
- std::vector<float> vB3({7, 8, 9, 4, 5, 6});
- std::vector<float> vB4({4, 5, 6, 4, 5, 6});
- std::vector<float> vB6;
-
- auto A = graph->param("A", {4, 3}, inits::from_vector(vA));
- auto B0 = rows(A, iB0);
- auto B1 = rows(A, iB1);
- auto B2 = rows(A, iB2);
- auto B3 = rows(A, iB3);
- auto B4 = rows(A, iB4);
- auto B5 = rows(A, iB5);
- auto B6 = rows(A, iB6);
- graph->forward();
-
- CHECK(B0->shape() == Shape({1, 3}));
- B0->val()->get(values);
- CHECK( values == vB0 );
-
- CHECK(B1->shape() == Shape({3, 3}));
- B1->val()->get(values);
- CHECK( values == vB1 );
-
- CHECK(B2->shape() == Shape({2, 3}));
- B2->val()->get(values);
- CHECK( values == vB2 );
-
- CHECK(B3->shape() == Shape({2, 3}));
- B3->val()->get(values);
- CHECK( values == vB3 );
-
- CHECK(B4->shape() == Shape({2, 3}));
- B4->val()->get(values);
- CHECK( values == vB4 );
-
- CHECK(B5->shape() == Shape({4, 3}));
- B5->val()->get(values);
- CHECK( values == vA );
-
- CHECK(B6->shape() == Shape({0, 3}));
- B6->val()->get(values);
- CHECK( values == vB6 );
- }
-
- SECTION("columns selection from 2d matrix") {
- graph->clear();
- values.clear();
-
- std::vector<float> vA({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
-
- std::vector<IndexType> iB0({0}); // first column
- std::vector<IndexType> iB1({0, 1, 2}); // several consecutive columns
- std::vector<IndexType> iB2({0, 2}); // two nonconsecutive columns
- std::vector<IndexType> iB3({2, 1}); // reversed order
- std::vector<IndexType> iB4({1, 1}); // repeated columns
- std::vector<IndexType> iB5({0, 1, 2, 3}); // identity
- std::vector<IndexType> iB6({}); // empty
-
- std::vector<float> vB0({1, 5, 9});
- std::vector<float> vB1({1, 2, 3, 5, 6, 7, 9, 10, 11});
- std::vector<float> vB2({1, 3, 5, 7, 9, 11});
- std::vector<float> vB3({3, 2, 7, 6, 11, 10});
- std::vector<float> vB4({2, 2, 6, 6, 10, 10});
- std::vector<float> vB6;
-
- auto A = graph->param("A", {3, 4}, inits::from_vector(vA));
- auto B0 = cols(A, iB0);
- auto B1 = cols(A, iB1);
- auto B2 = cols(A, iB2);
- auto B3 = cols(A, iB3);
- auto B4 = cols(A, iB4);
- auto B5 = cols(A, iB5);
- auto B6 = cols(A, iB6);
- graph->forward();
-
- CHECK(B0->shape() == Shape({3, 1}));
- B0->val()->get(values);
- CHECK( values == vB0 );
-
- CHECK(B1->shape() == Shape({3, 3}));
- B1->val()->get(values);
- CHECK( values == vB1 );
-
- CHECK(B2->shape() == Shape({3, 2}));
- B2->val()->get(values);
- CHECK( values == vB2 );
-
- CHECK(B3->shape() == Shape({3, 2}));
- B3->val()->get(values);
- CHECK( values == vB3 );
-
- CHECK(B4->shape() == Shape({3, 2}));
- B4->val()->get(values);
- CHECK( values == vB4 );
-
- CHECK(B5->shape() == Shape({3, 4}));
- B5->val()->get(values);
- CHECK( values == vA );
-
- CHECK(B6->shape() == Shape({3, 0}));
- B6->val()->get(values);
- CHECK( values == vB6 );
- }
-
- SECTION("relation of rows and columns selection using transpose") {
- graph->clear();
- values.clear();
- std::vector<float> values2;
-
- std::vector<float> vA({0, .3333, -.2, -.3, 0, 4.5, 5.2, -10, 101.45, -100.05, 0, 1.05e-5});
- std::vector<IndexType> idx({0, 1});
-
- auto A1 = graph->param("4x3", {4,3}, inits::from_vector(vA));
- auto B1 = rows(transpose(A1), idx);
- auto C1 = transpose(cols(A1, idx));
- auto A2 = graph->param("6x2", {6,2}, inits::from_vector(vA));
- auto B2 = cols(transpose(A2), idx);
- auto C2 = transpose(rows(A2, idx));
- graph->forward();
-
- CHECK(B1->shape() == C1->shape());
- B1->val()->get(values);
- C1->val()->get(values2);
- CHECK( values == values2 );
-
- values.clear();
- values2.clear();
-
- CHECK(B2->shape() == C2->shape());
- B2->val()->get(values);
- C2->val()->get(values2);
- CHECK( values == values2 );
- }
-
- SECTION("select operator") {
- using Indices = std::vector<IndexType>;
-
- graph->clear();
- values.clear();
-
- std::vector<float> in({1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12});
- std::vector<float> vB1({1, -2, 3});
- std::vector<float> vB2({1, -4, 7, -10});
- std::vector<float> vB3({-2, 5, -8, 11});
- std::vector<float> vB4({1, -2, 3, -4, 5, -6});
- std::vector<float> vD1(vB4);
- std::vector<float> vD2({5, -6, 11, -12});
- std::vector<float> vD3({1, -2, 5, -6, 7, -8, 11, -12});
-
- auto A = graph->param("4x3", {4,3}, inits::from_vector(in));
- auto B1 = select(A, Indices({0}), 0);
- auto B2 = select(A, Indices({0}), 1);
- auto B3 = select(A, Indices({1}), -1);
- auto B4 = select(A, Indices({0, 1}), 0);
-
- auto C = graph->param("2x3x2", {2, 3, 2}, inits::from_vector(in));
- auto D1 = select(C, Indices({0}), 0);
- auto D2 = select(C, Indices({2}), -2);
- auto D3 = select(C, Indices({0,2}), 1);
- graph->forward();
-
- CHECK(B1->shape() == Shape({1, 3}));
- B1->val()->get(values);
- CHECK( values == vB1 );
-
- CHECK(B2->shape() == Shape({4, 1}));
- B2->val()->get(values);
- CHECK( values == vB2 );
-
- CHECK(B3->shape() == Shape({4, 1}));
- B3->val()->get(values);
- CHECK( values == vB3 );
-
- CHECK(B4->shape() == Shape({2, 3}));
- B4->val()->get(values);
- CHECK( values == vB4 );
-
- values.clear();
-
- CHECK(D1->shape() == Shape({1, 3, 2}));
- D1->val()->get(values);
- CHECK( values == vD1 );
-
- CHECK(D2->shape() == Shape({2, 1, 2}));
- D2->val()->get(values);
- CHECK( values == vD2 );
-
- CHECK(D3->shape() == Shape({2, 2, 2}));
- D3->val()->get(values);
- CHECK( values == vD3 );
- }
-
- SECTION("rows/cols as select operations") {
- graph->clear();
- values.clear();
- std::vector<float> values2;
-
- std::vector<float> vA({0, .3333, -.2, -.3, 0, 4.5, 5.2, -10, 101.45, -100.05, 0, 1.05e-5});
- std::vector<IndexType> idx({0, 2});
-
- auto A = graph->param("4x3", {4, 3}, inits::from_vector(vA));
- auto B1 = rows(A, idx);
- auto B2 = select(A, idx, 0);
- auto C1 = cols(A, idx);
- auto C2 = select(A, idx, 1);
- graph->forward();
-
- CHECK(B1->shape() == B2->shape());
- B1->val()->get(values);
- B2->val()->get(values2);
- CHECK( values == values2 );
-
- CHECK(C1->shape() == C2->shape());
- C1->val()->get(values);
- C2->val()->get(values2);
- CHECK( values == values2 );
- }
-}
-
-#ifdef CUDA_FOUND
-TEST_CASE("Expression graph supports basic math operations (gpu)", "[operator]") {
- tests(DeviceType::gpu);
-}
-#endif
-
-#ifdef BLAS_FOUND
-TEST_CASE("Expression graph supports basic math operations (cpu)", "[operator]") {
- tests(DeviceType::cpu);
-}
-#endif
diff --git a/src/tests/pooling_test.cpp b/src/tests/pooling.cpp
index 8f77cc7b..27f3b5da 100644
--- a/src/tests/pooling_test.cpp
+++ b/src/tests/pooling.cpp
@@ -42,10 +42,10 @@ int main(int argc, char** argv) {
}
auto x = graph->param("x", {dimBatch, dimWord, batchLength},
- inits::from_vector(embData));
+ inits::fromVector(embData));
auto xMask = graph->constant({dimBatch, 1, batchLength},
- inits::from_vector(embMask));
+ inits::fromVector(embMask));
// auto pooling = MaxPooling("pooling")(x, xMask);
// auto idx = graph->constant({elemNum, 1}, inits::zeros);
diff --git a/src/tests/prod.cpp b/src/tests/prod.cpp
index 5ade6716..698712b9 100644
--- a/src/tests/prod.cpp
+++ b/src/tests/prod.cpp
@@ -1,34 +1,35 @@
#include "marian.h"
#include "common/timer.h"
-int main(int argc, char** argv) {
+int main(int /*argc*/, char** /*argv*/) {
using namespace marian;
{
- auto g = New<ExpressionGraph>(true, false);
+ auto g = New<ExpressionGraph>(true);
g->setDevice({0, DeviceType::cpu});
+ g->getBackend()->setOptimized(false);
g->reserveWorkspaceMB(2512);
timer::AutoTimer timer;
for(int i = 0; i < 100; ++i) {
g->clear();
- auto x = g->constant({1, 4, 8, 256}, inits::glorot_uniform);
+ auto x = g->constant({1, 4, 8, 256}, inits::glorotUniform());
- auto W1 = g->param("W1", {256, 2048}, inits::glorot_uniform);
- auto b1 = g->param("b1", {1, 2048}, inits::glorot_uniform);
+ auto W1 = g->param("W1", {256, 2048}, inits::glorotUniform());
+ auto b1 = g->param("b1", {1, 2048}, inits::glorotUniform());
auto out = affine(x, W1, b1);
for(int i = 2; i < 20; ++i) {
- auto Wi = g->param("W" + std::to_string(i), {2048, 2048}, inits::glorot_uniform);
- auto bi = g->param("b" + std::to_string(i), {1, 2048}, inits::glorot_uniform);
+ auto Wi = g->param("W" + std::to_string(i), {2048, 2048}, inits::glorotUniform());
+ auto bi = g->param("b" + std::to_string(i), {1, 2048}, inits::glorotUniform());
out = relu(affine(out, Wi, bi));
}
- auto Wn = g->param("Wn", {2048, 256}, inits::glorot_uniform);
- auto bn = g->param("bn", {1, 256}, inits::glorot_uniform);
+ auto Wn = g->param("Wn", {2048, 256}, inits::glorotUniform());
+ auto bn = g->param("bn", {1, 256}, inits::glorotUniform());
auto y = affine(out, Wn, bn);
@@ -37,30 +38,31 @@ int main(int argc, char** argv) {
}
{
- auto g = New<ExpressionGraph>(true, true);
+ auto g = New<ExpressionGraph>(true);
g->setDevice({0, DeviceType::cpu});
+ g->getBackend()->setOptimized(true);
g->reserveWorkspaceMB(2512);
timer::AutoTimer timer;
for(int i = 0; i < 100; ++i) {
g->clear();
- auto x = g->constant({1, 4, 8, 256}, inits::glorot_uniform);
+ auto x = g->constant({1, 4, 8, 256}, inits::glorotUniform());
- auto W1 = g->param("W1", {256, 2048}, inits::glorot_uniform);
- auto b1 = g->param("b1", {1, 2048}, inits::glorot_uniform);
+ auto W1 = g->param("W1", {256, 2048}, inits::glorotUniform());
+ auto b1 = g->param("b1", {1, 2048}, inits::glorotUniform());
auto out = affine(x, W1, b1);
for(int i = 2; i < 20; ++i) {
- auto Wi = g->param("W" + std::to_string(i), {2048, 2048}, inits::glorot_uniform);
- auto bi = g->param("b" + std::to_string(i), {1, 2048}, inits::glorot_uniform);
+ auto Wi = g->param("W" + std::to_string(i), {2048, 2048}, inits::glorotUniform());
+ auto bi = g->param("b" + std::to_string(i), {1, 2048}, inits::glorotUniform());
out = relu(affine(out, Wi, bi));
}
- auto Wn = g->param("Wn", {2048, 256}, inits::glorot_uniform);
- auto bn = g->param("bn", {1, 256}, inits::glorot_uniform);
+ auto Wn = g->param("Wn", {2048, 256}, inits::glorotUniform());
+ auto bn = g->param("bn", {1, 256}, inits::glorotUniform());
auto y = affine(out, Wn, bn);
@@ -68,6 +70,5 @@ int main(int argc, char** argv) {
}
}
-
return 0;
}
diff --git a/src/tests/sqlite_test.cpp b/src/tests/sqlite.cpp
index 21748514..f822dbc5 100644
--- a/src/tests/sqlite_test.cpp
+++ b/src/tests/sqlite.cpp
@@ -8,6 +8,8 @@
#include <fstream>
int main(int argc, char** argv) {
+ ABORT_IF(argc != 3, "FATAL ERROR: Incorrect number of command line arguments "
+ "(expected: 2) for command {}.",argv[0]);
SQLite::Database db("corpus.db", SQLite::OPEN_READWRITE|SQLite::OPEN_CREATE);
db.exec("PRAGMA temp_store_directory = '/data1/marcinjd';");
diff --git a/src/tests/tensor_test.cu b/src/tests/tensor.cu
index 72cdc276..72cdc276 100644
--- a/src/tests/tensor_test.cu
+++ b/src/tests/tensor.cu
diff --git a/src/tests/units/CMakeLists.txt b/src/tests/units/CMakeLists.txt
new file mode 100644
index 00000000..3814b481
--- /dev/null
+++ b/src/tests/units/CMakeLists.txt
@@ -0,0 +1,20 @@
+# Unit tests
+set(UNIT_TESTS
+ graph_tests
+ operator_tests
+ rnn_tests
+ attention_tests
+ fastopt_tests
+)
+
+foreach(test ${UNIT_TESTS})
+ add_executable("run_${test}" run_tests.cpp "${test}.cpp")
+
+ if(CUDA_FOUND)
+ target_link_libraries("run_${test}" ${EXT_LIBS} marian ${EXT_LIBS} marian_cuda ${EXT_LIBS} Catch)
+ else(CUDA_FOUND)
+ target_link_libraries("run_${test}" marian ${EXT_LIBS} Catch)
+ endif(CUDA_FOUND)
+
+ add_test(NAME ${test} COMMAND "run_${test}")
+endforeach(test)
diff --git a/src/tests/attention_tests.cpp b/src/tests/units/attention_tests.cpp
index eff55572..e13e7943 100644
--- a/src/tests/attention_tests.cpp
+++ b/src/tests/units/attention_tests.cpp
@@ -1,18 +1,33 @@
#include "catch.hpp"
#include "marian.h"
+#ifdef CUDA_FOUND
+#include "tensors/gpu/backend.h"
+#endif
+
#include "rnn/rnn.h"
#include "rnn/constructors.h"
#include "rnn/attention.h"
using namespace marian;
-void tests(DeviceType type) {
- auto floatApprox = [](float x, float y) { return x == Approx(y).epsilon(0.01); };
+template <typename T>
+void tests(DeviceType type, Type floatType = Type::float32) {
+
+// Checking for FP16 support and skipping if not supported.
+#ifdef CUDA_FOUND
+ if(type == DeviceType::gpu && floatType == Type::float16) {
+ auto gpuBackend = New<gpu::Backend>(DeviceId({0, type}), /*seed=*/1234);
+ auto cudaCompute = gpuBackend->getCudaComputeCapability();
+ if(cudaCompute.major < 6) return;
+ }
+#endif
+
+ auto floatApprox = [](T x, T y) { return x == Approx(y).epsilon(0.01); };
Config::seed = 1234;
- Words vWords = {
+ std::vector<IndexType> vWords = {
43, 2, 83, 78,
6, 38, 80, 40,
40, 70, 26, 60,
@@ -23,7 +38,7 @@ void tests(DeviceType type) {
0, 0, 0, 0
};
- std::vector<float> vMask = {
+ std::vector<T> vMask = {
1, 1, 1, 1,
1, 1, 1, 1,
1, 1, 1, 1,
@@ -36,10 +51,11 @@ void tests(DeviceType type) {
SECTION("Attention over encoder context") {
auto graph = New<ExpressionGraph>();
+ graph->setDefaultElementType(floatType);
graph->setDevice({0, type});
graph->reserveWorkspaceMB(16);
- std::vector<float> values;
+ std::vector<T> values;
int dimEmb = 16;
int dimBatch = 4;
@@ -47,19 +63,19 @@ void tests(DeviceType type) {
auto emb = graph->param("Embeddings",
{128, dimEmb},
- inits::glorot_uniform);
+ inits::glorotUniform());
auto input = reshape(rows(emb, vWords), {dimTime, dimBatch, dimEmb});
auto mask = graph->constant({dimTime, dimBatch, 1},
- inits::from_vector(vMask));
+ inits::fromVector(vMask));
- auto rnn = rnn::rnn(graph) //
+ auto rnn = rnn::rnn() //
("prefix", "rnntest") //
("type", "gru") //
("dimInput", 16) //
("dimState", 8) //
- .push_back(rnn::cell(graph)) //
- .construct();
+ .push_back(rnn::cell()) //
+ .construct(graph);
auto context = rnn->transduce(input, mask);
@@ -71,12 +87,12 @@ void tests(DeviceType type) {
auto att = New<rnn::Attention>(graph, options, encState);
- std::vector<float> vState(64);
+ std::vector<T> vState(64);
std::generate(vState.begin(), vState.end(),
[](){ static int n = -32; return n++ / 64.f; });
rnn::State state({graph->constant({1, 1, 4, 16},
- inits::from_vector(vState)),
+ inits::fromVector(vState)),
nullptr});
auto aligned = att->apply(state);
@@ -86,7 +102,7 @@ void tests(DeviceType type) {
CHECK(aligned->shape() == Shape({1, 1, 4, 8}));
#ifdef CUDA_FOUND
- std::vector<float> vAligned({
+ std::vector<T> vAligned({
0.0396688, -0.0124071, -0.0159668, -0.00080064,
-0.0132853, 0.0240206, 0.0744701, -0.0248388,
0.0258906, -0.00868394, -0.0374499, 0.0357639,
@@ -97,7 +113,7 @@ void tests(DeviceType type) {
-0.0330807, 0.018745, 0.0341848, -0.0111661
});
#else
- std::vector<float> vAligned({
+ std::vector<T> vAligned({
-0.061056, 0.0262615, -0.0393096, 0.115902,
0.0941305, 0.00475613, -0.0159573, 0.00293181,
-0.0919751, -0.018913, 0.00927365, -0.000343846,
@@ -116,12 +132,18 @@ void tests(DeviceType type) {
#ifdef CUDA_FOUND
TEST_CASE("Model components, Attention (gpu)", "[attention]") {
- tests(DeviceType::gpu);
+ tests<float>(DeviceType::gpu);
}
+
+#if COMPILE_FP16
+TEST_CASE("Model components, Attention (gpu, fp16)", "[attention]") {
+ tests<float16>(DeviceType::gpu, Type::float16);
+}
+#endif
#endif
#ifdef BLAS_FOUND
TEST_CASE("Model components, Attention (cpu)", "[attention]") {
- tests(DeviceType::cpu);
+ tests<float>(DeviceType::cpu);
}
#endif
diff --git a/src/tests/units/fastopt_tests.cpp b/src/tests/units/fastopt_tests.cpp
new file mode 100644
index 00000000..cec6ab90
--- /dev/null
+++ b/src/tests/units/fastopt_tests.cpp
@@ -0,0 +1,82 @@
+#include "catch.hpp"
+#include "common/fastopt.h"
+#include "3rd_party/yaml-cpp/yaml.h"
+
+using namespace marian;
+
+TEST_CASE("FastOpt can be constructed from a YAML node", "[fastopt]") {
+ YAML::Node node;
+
+ SECTION("from a simple node") {
+ YAML::Node node = YAML::Load("{foo: bar}");
+ const FastOpt o(node);
+
+ CHECK( o.has("foo") );
+ CHECK_FALSE( o.has("bar") );
+ CHECK_FALSE( o.has("baz") );
+ }
+
+ SECTION("from a sequence node") {
+ YAML::Node node = YAML::Load("{foo: [bar, baz]}");
+ const FastOpt o(node);
+ CHECK( o.has("foo") );
+ }
+
+ SECTION("from nested nodes") {
+ YAML::Node node = YAML::Load("{foo: {bar: 123, baz}}");
+ const FastOpt o(node);
+ CHECK( o.has("foo") );
+ CHECK( o["foo"].has("bar") );
+ CHECK( o["foo"].has("baz") );
+ CHECK( o["foo"]["bar"].as<int>() == 123 );
+ CHECK( o["foo"]["baz"].isNull() );
+ }
+}
+
+TEST_CASE("Options can be accessed", "[fastopt]") {
+ YAML::Node node = YAML::Load("{"
+ "foo: bar,"
+ "seq: [1, 2, 3],"
+ "subnode: {"
+ " baz: [ 111.5, False ],"
+ " qux: 222,"
+ " preprocess1: n,"
+ " preprocess2: d,"
+ " preprocess3: y,"
+ " }"
+ "}");
+
+ const FastOpt o(node);
+
+ SECTION("using operator[]") {
+ auto& oo = o["subnode"];
+ CHECK( oo.has("baz") );
+ CHECK( oo.has("qux") );
+ CHECK_NOTHROW( o["subnode"]["baz"] );
+ }
+
+ SECTION("using as<T>()") {
+ CHECK( o["foo"].as<std::string>() == "bar" );
+ CHECK( o["subnode"]["baz"][0].as<float>() == 111.5f );
+ CHECK( o["subnode"]["baz"][1].as<bool>() == false );
+ CHECK( o["subnode"]["baz"][0].as<int>() == 111 );
+ CHECK( o["subnode"]["preprocess1"].as<std::string>() == "n" ); // don't allow "n" to be cast to boolean false while converting from YAML
+ CHECK( o["subnode"]["preprocess2"].as<std::string>() == "d" );
+ CHECK( o["subnode"]["preprocess3"].as<std::string>() == "y" ); // don't allow "y" to be cast to boolean true while converting from YAML
+ }
+
+ node["foo"] = "baz";
+ if(o.has("foo")) {
+ FastOpt temp(node["foo"]);
+ const_cast<FastOpt&>(o["foo"]).swap(temp);
+ }
+
+ CHECK( o["foo"].as<std::string>() == "baz" );
+
+ // for(auto k : o[subnode].keys())
+ // o[subnode][k].type()
+
+ SECTION("using as<std::vector<T>>()") {
+ CHECK( o["seq"].as<std::vector<double>>() == std::vector<double>({1, 2, 3}) );
+ }
+}
diff --git a/src/tests/units/graph_tests.cpp b/src/tests/units/graph_tests.cpp
new file mode 100644
index 00000000..403b06e9
--- /dev/null
+++ b/src/tests/units/graph_tests.cpp
@@ -0,0 +1,123 @@
+#include "catch.hpp"
+#include "graph/expression_graph.h"
+#include "graph/expression_operators.h"
+
+#ifdef CUDA_FOUND
+#include "tensors/gpu/backend.h"
+#endif
+
+using namespace marian;
+
+#ifdef CUDA_FOUND
+TEST_CASE("Graph device is set", "[graph]") {
+ auto graph = New<ExpressionGraph>();
+ graph->setDevice({0, DeviceType::gpu});
+
+ DeviceId testId{0, DeviceType::gpu};
+ REQUIRE(graph->getDeviceId() == testId);
+}
+
+TEST_CASE("Expression graph can be initialized with constant values",
+ "[graph]") {
+
+ for(auto type : std::vector<Type>({Type::float32, Type::float16})) {
+
+ auto graph = New<ExpressionGraph>();
+ graph->setDefaultElementType(type);
+ graph->setDevice({0, DeviceType::gpu});
+ graph->reserveWorkspaceMB(4);
+
+ if(type == Type::float16) {
+ auto gpuBackend = std::dynamic_pointer_cast<gpu::Backend>(graph->getBackend());
+ if(gpuBackend) {
+ auto cudaCompute = gpuBackend->getCudaComputeCapability();
+ if(cudaCompute.major < 6) continue;
+ }
+ }
+
+ std::vector<float> values;
+
+ SECTION("initializing with zeros") {
+ graph->clear();
+ values.clear();
+ auto zeros = graph->param("0s", {2, 5}, inits::zeros());
+ graph->forward();
+
+ zeros->val()->get(values);
+ REQUIRE(values == std::vector<float>(10, 0.0f));
+ }
+
+ SECTION("initializing with ones") {
+ graph->clear();
+ values.clear();
+ auto ones = graph->param("1s", {2, 5}, inits::ones());
+ graph->forward();
+
+ ones->val()->get(values);
+ REQUIRE(values == std::vector<float>(10, 1.0f));
+ }
+
+ SECTION("initializing from vector") {
+ graph->clear();
+ values.clear();
+ std::vector<float> v({1, 2, 3, 4, 5, 6});
+ auto vals = graph->param("vs", {2, 3}, inits::fromVector(v));
+ graph->forward();
+
+ REQUIRE(values.empty());
+ vals->val()->get(values);
+ REQUIRE(values == v);
+ }
+ }
+}
+#endif
+
+TEST_CASE("Graph device is set (cpu)", "[graph]") {
+ auto graph = New<ExpressionGraph>();
+
+ graph->setDevice({0, DeviceType::cpu});
+
+ DeviceId testId{0, DeviceType::cpu};
+ REQUIRE(graph->getDeviceId() == testId);
+}
+
+TEST_CASE("Expression graph can be initialized with constant values (cpu)",
+ "[graph]") {
+ auto graph = New<ExpressionGraph>();
+ graph->setDevice({0, DeviceType::cpu});
+ graph->reserveWorkspaceMB(4);
+
+ std::vector<float> values;
+
+ SECTION("initializing with zero (cpu)") {
+ graph->clear();
+ values.clear();
+ auto zeros = graph->param("0s", {2, 5}, inits::zeros());
+ graph->forward();
+
+ zeros->val()->get(values);
+ REQUIRE(values == std::vector<float>(10, 0.0f));
+ }
+
+ SECTION("initializing with ones (cpu)") {
+ graph->clear();
+ values.clear();
+ auto ones = graph->param("1s", {2, 5}, inits::ones());
+ graph->forward();
+
+ ones->val()->get(values);
+ REQUIRE(values == std::vector<float>(10, 1.0f));
+ }
+
+ SECTION("initializing from vector (cpu)") {
+ graph->clear();
+ values.clear();
+ std::vector<float> v({1, 2, 3, 4, 5, 6});
+ auto vals = graph->param("vs", {2, 3}, inits::fromVector(v));
+ graph->forward();
+
+ REQUIRE(values.empty());
+ vals->val()->get(values);
+ REQUIRE(values == v);
+ }
+}
diff --git a/src/tests/units/operator_tests.cpp b/src/tests/units/operator_tests.cpp
new file mode 100644
index 00000000..682ef480
--- /dev/null
+++ b/src/tests/units/operator_tests.cpp
@@ -0,0 +1,847 @@
+#include "catch.hpp"
+#include "graph/expression_graph.h"
+#include "graph/expression_operators.h"
+
+#ifdef CUDA_FOUND
+#include "tensors/gpu/backend.h"
+#endif
+
+#include <cmath>
+
+using namespace marian;
+
+template <typename T>
+void tests(DeviceType device, Type floatType = Type::float32) {
+
+// Checking for FP16 support and skipping if not supported.
+#ifdef CUDA_FOUND
+ if(device == DeviceType::gpu && floatType == Type::float16) {
+ auto gpuBackend = New<gpu::Backend>(DeviceId({0, device}), /*seed=*/1234);
+ auto cudaCompute = gpuBackend->getCudaComputeCapability();
+ if(cudaCompute.major < 6) return;
+ }
+#endif
+
+ auto floatApprox = [](T x, T y) -> bool { return x == Approx(y).epsilon(0.01); };
+ auto floatEqual = [](T x, T y) -> bool { return x == y; };
+
+ Config::seed = 1234;
+ auto graph = New<ExpressionGraph>();
+ graph->setDefaultElementType(floatType);
+ graph->setDevice({0, device});
+ graph->reserveWorkspaceMB(16);
+
+ std::vector<T> values, values2;
+
+ SECTION("scalar multiplication") {
+ graph->clear();
+ values.clear();
+ std::vector<T> vB({1, 2, 3, 4, 5, 6});
+
+ auto B = graph->param("B", {3, 2}, inits::fromVector(vB));
+ auto B2 = B * 2.0f;
+ graph->forward();
+
+ CHECK(B2->shape() == Shape({3, 2}));
+ B2->val()->get(values);
+
+ std::vector<T> vB2({2, 4, 6, 8, 10, 12});
+ CHECK(values == vB2);
+ }
+
+ SECTION("elementwise binary operators with broadcasting") {
+ graph->clear();
+ values.clear();
+
+ std::vector<T> vA({1, -2, 3, -4});
+ std::vector<T> vB({0.5, 1.5});
+
+ auto a = graph->constant({2, 2, 1}, inits::fromVector(vA));
+ auto b = graph->constant({2, 1}, inits::fromVector(vB));
+
+ auto compare = [&](Expr res, std::function<float(float,float)> f, bool exactMatch) -> bool {
+ if (res->shape() != Shape({ 2, 2, 1 }))
+ return false;
+ res->val()->get(values);
+ std::vector<float> ref{f(vA[0], vB[0]), f(vA[1], vB[1]), f(vA[2], vB[0]), f(vA[3], vB[1])};
+ return std::equal(values.begin(), values.end(), ref.begin(), exactMatch ? floatEqual : floatApprox);
+ };
+
+ auto rplus = a + b;
+ auto rminus = a - b;
+ auto rmult = a * b;
+ auto rdiv = a / b;
+ auto rlae = logaddexp(a, b);
+ auto rmax = maximum(a, b);
+ auto rmin = minimum(a, b);
+ auto rlt = lt(a, b);
+ auto req = eq(a, b);
+ auto rgt = gt(a, b);
+ auto rge = ge(a, b);
+ auto rne = ne(a, b);
+ auto rle = le(a, b);
+
+ graph->forward();
+
+ CHECK(compare(rplus, [](float a, float b) {return a + b;}, true));
+ CHECK(compare(rminus, [](float a, float b) {return a - b;}, true));
+ CHECK(compare(rmult, [](float a, float b) {return a * b;}, true));
+ CHECK(compare(rdiv, [](float a, float b) {return a / b;}, false));
+ CHECK(compare(rlae, [](float a, float b) {return logf(expf(a) + expf(b));}, false));
+ CHECK(compare(rmax, [](float a, float b) {return std::max(a, b);}, true));
+ CHECK(compare(rmin, [](float a, float b) {return std::min(a, b);}, true));
+ CHECK(compare(rlt, [](float a, float b) {return a < b;}, true));
+ CHECK(compare(req, [](float a, float b) {return a == b;}, true));
+ CHECK(compare(rgt, [](float a, float b) {return a > b;}, true));
+ CHECK(compare(rge, [](float a, float b) {return a >= b;}, true));
+ CHECK(compare(rne, [](float a, float b) {return a != b;}, true));
+ CHECK(compare(rle, [](float a, float b) {return a <= b;}, true));
+ }
+
+ SECTION("transposing and reshaping") {
+ graph->clear();
+ values.clear();
+
+ std::vector<T> vA({1, 2, 3, 4, 5, 6, 7, 8});
+
+ std::vector<T> vT1({1, 5, 2, 6, 3, 7, 4, 8});
+ std::vector<T> vT3({1, 2, 5, 6, 3, 4, 7, 8});
+ std::vector<T> vT4({1, 5, 3, 7, 2, 6, 4, 8});
+ std::vector<T> vT5({1, 2, 5, 6, 3, 4, 7, 8});
+
+ auto a = graph->constant({2, 4}, inits::fromVector(vA));
+
+ auto t1 = transpose(a);
+ auto t2 = transpose(t1);
+ auto t3 = transpose(reshape(t1, {2, 2, 2}));
+
+ auto t4 = transpose(reshape(a, {2, 1, 2, 2}), {1, 3, 2, 0});
+ auto t5 = transpose(reshape(a, {2, 1, 2, 2}), {2, 0, 1, 3});
+
+ auto t6 = stopGradient(a);
+
+ graph->forward();
+
+ CHECK(t1->shape() == Shape({4, 2}));
+ CHECK(t2->shape() == Shape({2, 4}));
+ CHECK(t3->shape() == Shape({2, 2, 2}));
+ CHECK(t4->shape() == Shape({1, 2, 2, 2}));
+ CHECK(t5->shape() == Shape({2, 2, 1, 2}));
+ CHECK(t6->shape() == a->shape());
+
+ t1->val()->get(values);
+ CHECK( values == vT1 );
+
+ t2->val()->get(values);
+ CHECK( values == vA );
+
+ t3->val()->get(values);
+ CHECK( values == vT3 );
+
+ t4->val()->get(values);
+ CHECK( values == vT4 );
+
+ t5->val()->get(values);
+ CHECK( values == vT5 );
+
+ t6->val()->get(values);
+ CHECK(values == vA);
+ CHECK(!t6->trainable());
+ }
+
+ SECTION("softmax and logsoftmax") {
+ graph->clear();
+ values.clear();
+ std::vector<T> in({-.2, -.3, 4.5, 5.2, -10, 101.45, -100.05, 1.05e-5});
+
+ std::vector<T> smOut({ 0.52498f, 0.47502f, 0.33181f, 0.66819f,
+ 0.0f, 1.0f, 0.0f, 1.0f });
+
+ std::vector<T> lsmOut({ -0.6444f, -0.7444f, -1.10319f, -0.40319f,
+ -111.45f, 0.0f, -100.05001f, 0.0f });
+
+ auto input = graph->constant({2, 2, 2}, inits::fromVector(in));
+
+ auto sm = softmax(input);
+ auto lsm = logsoftmax(input);
+
+ graph->forward();
+
+ CHECK(sm->shape() == Shape({2, 2, 2}));
+ CHECK(lsm->shape() == Shape({2, 2, 2}));
+
+ sm->val()->get(values);
+
+ CHECK( std::equal(values.begin(), values.end(),
+ smOut.begin(), floatApprox) );
+
+ lsm->val()->get(values);
+
+ CHECK( std::equal(values.begin(), values.end(),
+ lsmOut.begin(), floatApprox) );
+ }
+
+ SECTION("layer normalization") {
+ graph->clear();
+ values.clear();
+
+#ifdef CUDA_FOUND
+ std::vector<T> vLn({
+ -1.1962, 1.43061, 0.380288, -0.614697, 0.816638, 0.622649,
+ -1.69679, 0.257504, -1.12563, -0.151387, 1.61181, -0.334796,
+ 1.07207, -0.622614, 0.862014, -1.31147
+ });
+#else
+ std::vector<T> vLn({
+ -1.49821, -0.152206, 0.394932, 1.25548, -1.51701, -0.28032,
+ 0.9483, 0.849025, 0.855183, 1.11657, -0.788354, -1.1834,
+ -0.85939, -1.13109, 0.972076, 1.01841
+ });
+#endif
+
+ auto a = graph->constant({2, 2, 4}, inits::glorotUniform());
+ auto gamma = graph->param("gamma", {1, 4}, inits::ones());
+ auto beta = graph->param("beta", {1, 4}, inits::zeros());
+ auto ln = layerNorm(a, gamma, beta);
+
+ graph->forward();
+
+ CHECK(ln->shape() == Shape({2, 2, 4}));
+
+ ln->val()->get(values);
+ CHECK( std::equal(values.begin(), values.end(),
+ vLn.begin(), floatApprox) );
+
+ }
+
+ SECTION("reductions") {
+ graph->clear();
+ values.clear();
+
+ std::vector<T> vA({1, 6, 3, 8,
+ 5, 2, 7, 4});
+ // import numpy as np
+ // a = np.array([[1, 6, 3, 8], [5, 2, 7, 4]])
+ std::vector<T> vS1({6, 8, 10, 12}); // s1 = np.sum(a, axis=0)
+ std::vector<T> vS2({18, 18}); // np.sum(a, axis = 1)
+ std::vector<T> vS4({2.6925824f, 1.80277564f}); // np.std(a, axis = 1)
+ std::vector<T> vV5({7.25, 3.25}); // np.var(a, axis = 1)
+ std::vector<T> vM6({8, 7}); // np.max(a, axis = 1)
+ std::vector<T> vM7({1, 2}); // np.min(a, axis = 1)
+ std::vector<T> vP8({144, 280}); // np.prod(a, axis = 1)
+ std::vector<T> vL9({8.13364336f, 7.17551536f}); // np.log(np.sum(np.exp(a), axis=1))
+ std::vector<T> vW({5.0f, 4.55555556f}); // np.mean(a*s1,axis=-1) / np.mean(s1,axis=-1)
+
+ auto a = graph->constant({2, 4}, inits::fromVector(vA));
+
+ auto s1 = sum(a, /*axis=*/ 0);
+ auto s2 = sum(a, /*axis=*/ 1);
+
+ auto m3 = mean(s1, /*axis=*/ 1);
+
+ auto s4 = marian::std(a, /*axis=*/ 1);
+ auto v5 = var(a, /*axis=*/ 1);
+
+ auto m6 = max(a, /*axis=*/ 1);
+ auto m7 = min(a, /*axis=*/ 1);
+ auto p8 = prod(a, /*axis=*/ 1);
+ auto l9 = logsumexp(a, /*axis=*/ 1);
+
+ auto sp = scalar_product(s2, s2, /*axis=*/ 0);
+
+ auto wa = weighted_average(a, s1, /*axis=*/ -1);
+
+ graph->forward();
+
+ CHECK(s1->shape() == Shape({1, 4}));
+ CHECK(s2->shape() == Shape({2, 1}));
+ CHECK(m3->shape() == Shape({1, 1}));
+ CHECK(s4->shape() == Shape({2, 1}));
+ CHECK(v5->shape() == Shape({2, 1}));
+ CHECK(m6->shape() == Shape({2, 1}));
+ CHECK(m7->shape() == Shape({2, 1}));
+ CHECK(p8->shape() == Shape({2, 1}));
+ CHECK(l9->shape() == Shape({2, 1}));
+ CHECK(sp->shape() == Shape({1, 1}));
+ CHECK(wa->shape() == Shape({2, 1}));
+
+ s1->val()->get(values); CHECK(values == vS1);
+ s2->val()->get(values); CHECK(values == vS2);
+
+ CHECK(m3->val()->scalar() == 9);
+
+ s4->val()->get(values); CHECK(std::equal(values.begin(), values.end(), vS4.begin(), floatApprox));
+ v5->val()->get(values); CHECK(values == vV5);
+ m6->val()->get(values); CHECK(values == vM6);
+ m7->val()->get(values); CHECK(values == vM7);
+ p8->val()->get(values); CHECK(values == vP8);
+ l9->val()->get(values); CHECK(std::equal(values.begin(), values.end(), vL9.begin(), floatApprox));
+
+ CHECK(sp->val()->scalar() == 648);
+
+ wa->val()->get(values); CHECK(std::equal(values.begin(), values.end(), vW.begin(), floatApprox));
+ }
+
+ SECTION("concatenation") {
+ graph->clear();
+ values.clear();
+
+ std::vector<T> vO1({ 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2,
+ 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4,
+ 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2,
+ 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4});
+
+ std::vector<T> vO2({1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4,
+ 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4,
+ 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4,
+ 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4});
+
+ std::vector<T> vO3({1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
+ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4});
+
+ std::vector<T> vO4({1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
+ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4});
+
+ auto in1 = graph->constant({1, 2, 2, 3}, inits::fromValue(1));
+ auto in2 = graph->constant({1, 2, 2, 3}, inits::fromValue(2));
+ auto in3 = graph->constant({1, 2, 2, 3}, inits::fromValue(3));
+ auto in4 = graph->constant({1, 2, 2, 3}, inits::fromValue(4));
+
+ auto c1out1 = concatenate({in1, in2, in3, in4}, /*axis=*/ 2);
+ auto c1out2 = concatenate({in1, in2, in3, in4}, /*axis=*/ -1);
+ auto c1out3 = concatenate({in1, in2, in3, in4}, /*axis=*/ -3);
+ auto c1out4 = concatenate({in1, in2, in3, in4}, /*axis=*/ 0);
+
+ graph->forward();
+
+ CHECK(c1out1->shape() == Shape({1, 2, 8, 3}));
+ CHECK(c1out2->shape() == Shape({1, 2, 2, 12}));
+ CHECK(c1out3->shape() == Shape({1, 8, 2, 3}));
+ CHECK(c1out4->shape() == Shape({4, 2, 2, 3}));
+
+ c1out1->val()->get(values);
+ CHECK( values == vO1 );
+
+ c1out2->val()->get(values);
+ CHECK( values == vO2 );
+
+ c1out3->val()->get(values);
+ CHECK( values == vO3 );
+
+ c1out4->val()->get(values);
+ CHECK( values == vO4 );
+ }
+
+ SECTION("dot product") {
+ graph->clear();
+ values.clear();
+
+ std::vector<T> vA({1, 2, 3,
+ 4, 5, 6,
+ 7, 8, 9,
+ 10, 11, 12});
+ std::vector<T> vB({1, 2,
+ 3, 4,
+ 5, 6});
+ std::vector<T> vC({22, 28,
+ 49, 64,
+ 76, 100,
+ 103, 136});
+
+ auto A = graph->param("A", {2, 2, 3}, inits::fromVector(vA));
+ auto B = graph->param("B", {3, 2}, inits::fromVector(vB));
+ auto C = dot(A, B);
+
+ CHECK(C->shape() == Shape({2, 2, 2}));
+
+ graph->forward();
+
+ C->val()->get(values);
+ CHECK(values == vC);
+ }
+
+ // Currently no support for fp16 or CPU - TODO use MKL for CPU, convert to float32 on the fly for fp16 via cast(x, Type::float16) or internally
+ if(device == DeviceType::gpu && floatType == Type::float32) {
+ SECTION("csr-dot product") {
+ graph->clear();
+ values.clear();
+ // CSR dot product, tested against dense product on the same values
+ std::vector<float> vS({1, 0, 0, 1, // sparse
+ 0, 0, 1, 1.5});
+ std::vector<float> vD({1, 2, 3, 1.2, 5.6, // dense
+ 4, 5, 6, 2.3, 6.7,
+ 7, 8, 9, 3.4, 7.8,
+ 1, 1, 2, 4.5, 8.9});
+ auto S = graph->param("S", { 2, 4 }, inits::fromVector(vS));
+ auto D = graph->param("D", { 4, 5 }, inits::fromVector(vD));
+ auto DT = graph->param("DT", { 5, 4 }, inits::fromVector(vD)); // example matrix with transposed dimensions
+ std::vector<float> SV; // create CSR version of S
+ std::vector<IndexType> SI, SO;
+ SO.push_back((IndexType)SI.size());
+ for (IndexType i = 0; i < S->shape()[0]; i++) {
+ for (IndexType j = 0; j < S->shape()[1]; j++) {
+ auto k = 4 * i + j;
+ if (vS[k] != 0) {
+ SV.push_back(vS[k]);
+ SI.push_back(j);
+ }
+ }
+ SO.push_back((IndexType)SI.size());
+ }
+
+ auto SxDd = dot(S, D);
+ auto STxSxDd = dot(S, SxDd, /*transA=*/true);
+ auto SxDs = csr_dot( // sparse x dense
+ S->shape(),
+ graph->constant({(int)SV.size()}, inits::fromVector(SV), floatType),
+ graph->constant({(int)SI.size()}, inits::fromVector(SI), Type::uint32),
+ graph->constant({(int)SO.size()}, inits::fromVector(SO), Type::uint32),
+ D);
+ auto STxSxDs = csr_dot( // transpose(sparse) x dense; we use result of previous since dimensions match
+ S->shape(),
+ graph->constant({(int)SV.size()}, inits::fromVector(SV), floatType),
+ graph->constant({(int)SI.size()}, inits::fromVector(SI), Type::uint32),
+ graph->constant({(int)SO.size()}, inits::fromVector(SO), Type::uint32),
+ SxDd, /*transS=*/true);
+
+ auto DTxSTd = dot(DT, S, /*transA=*/false, /*transB=*/true);
+ auto DTxSTxSd = dot(DTxSTd, S);
+ auto DTxSTs = dot_csr( // dense x sparse
+ DT,
+ S->shape(),
+ graph->constant({(int)SV.size()}, inits::fromVector(SV), floatType),
+ graph->constant({(int)SI.size()}, inits::fromVector(SI), Type::uint32),
+ graph->constant({(int)SO.size()}, inits::fromVector(SO), Type::uint32),
+ /*transS=*/true);
+ auto DTxSTxSs = dot_csr( // dense x transpose(sparse)
+ DTxSTd,
+ S->shape(),
+ graph->constant({(int)SV.size()}, inits::fromVector(SV), floatType),
+ graph->constant({(int)SI.size()}, inits::fromVector(SI), Type::uint32),
+ graph->constant({(int)SO.size()}, inits::fromVector(SO), Type::uint32));
+
+ CHECK(SxDs->shape() == SxDd->shape());
+ CHECK(STxSxDs->shape() == STxSxDd->shape());
+ CHECK(DTxSTs->shape() == DTxSTd->shape());
+ CHECK(DTxSTxSs->shape() == DTxSTxSd->shape());
+
+ graph->forward();
+
+ // dense and sparse operation results must be the same
+ SxDd ->val()->get(values2); SxDs ->val()->get(values); CHECK(values == values2);
+ STxSxDd ->val()->get(values2); STxSxDs ->val()->get(values); CHECK(values == values2);
+ DTxSTd ->val()->get(values2); DTxSTs ->val()->get(values); CHECK(values == values2);
+ DTxSTxSd->val()->get(values2); DTxSTxSs->val()->get(values); CHECK(values == values2);
+ }
+ }
+
+ SECTION("affine transformation") {
+ graph->clear();
+ values.clear();
+
+ std::vector<T> vA({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ std::vector<T> vB({1, 2, 3, 4, 5, 6});
+ std::vector<T> vAff({24, 30, 51, 66, 78, 102, 105, 138});
+
+ auto A = graph->param("A", {4, 3}, inits::fromVector(vA));
+ auto B = graph->param("B", {3, 2}, inits::fromVector(vB));
+ auto C = graph->param("C", {4, 2}, inits::fromValue(2));
+
+ auto aff1 = affine(A, B, C);
+ auto aff2 = dot(A, B) + C;
+
+ graph->forward();
+
+ CHECK(aff1->shape() == Shape({4, 2}));
+ aff1->val()->get(values);
+ CHECK(values == vAff);
+
+ std::vector<T> values2;
+ CHECK(aff2->shape() == aff1->shape());
+ aff2->val()->get(values2);
+ CHECK(values2 == values);
+ }
+
+ SECTION("repeat") {
+ graph->clear();
+ values.clear();
+
+ std::vector<T> vA({1, 2, 3, 4, 5, 6});
+ std::vector<T> vB({1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6});
+ std::vector<T> vC({1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6});
+
+ auto A = graph->param("A", {2,3}, inits::fromVector(vA));
+ auto I = repeat(A, 1, 0);
+ auto B = repeat(A, 2, 0);
+ auto C = repeat(A, 2, 1);
+ graph->forward();
+
+ CHECK(I->shape() == Shape({2, 3}));
+ I->val()->get(values);
+ CHECK(values == vA);
+
+ CHECK(B->shape() == Shape({4, 3}));
+ B->val()->get(values);
+ CHECK(values == vB);
+
+ CHECK(C->shape() == Shape({2, 6}));
+ C->val()->get(values);
+ CHECK(values == vC);
+ }
+
+ SECTION("flatten") {
+ graph->clear();
+ values.clear();
+
+ std::vector<T> vIn({1, 2, 3, 4, 5, 6, 7, 8});
+
+ auto A = graph->param("A", {2, 4}, inits::fromVector(vIn));
+ auto Af = flatten(A);
+ auto B = graph->param("B", {2, 2, 1, 2}, inits::fromVector(vIn));
+ auto Bf = flatten(B);
+ graph->forward();
+
+ CHECK(Af->shape() == Shape({8}));
+ Af->val()->get(values);
+ CHECK(values == vIn);
+
+ CHECK(Bf->shape() == Shape({8}));
+ Bf->val()->get(values);
+ CHECK(values == vIn);
+ }
+
+ SECTION("rows selection from 2d matrix") {
+ graph->clear();
+ values.clear();
+
+ std::vector<T> vA({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+
+ std::vector<IndexType> iB0({0}); // first row
+ std::vector<IndexType> iB1({0, 1, 2}); // several consecutive rows
+ std::vector<IndexType> iB2({0, 2}); // two nonconsecutive rows
+ std::vector<IndexType> iB3({2, 1}); // reversed order
+ std::vector<IndexType> iB4({1, 1}); // repeated rows
+ std::vector<IndexType> iB5({0, 1, 2, 3}); // identity
+ std::vector<IndexType> iB6({}); // empty
+ std::vector<T> vB0({1, 2, 3});
+ std::vector<T> vB1({1, 2, 3, 4, 5, 6, 7, 8, 9});
+ std::vector<T> vB2({1, 2, 3, 7, 8, 9});
+ std::vector<T> vB3({7, 8, 9, 4, 5, 6});
+ std::vector<T> vB4({4, 5, 6, 4, 5, 6});
+ std::vector<T> vB6;
+
+ auto A = graph->param("A", {4, 3}, inits::fromVector(vA));
+ auto B0 = rows(A, iB0);
+ auto B1 = rows(A, iB1);
+ auto B2 = rows(A, iB2);
+ auto B3 = rows(A, iB3);
+ auto B4 = rows(A, iB4);
+ auto B5 = rows(A, iB5);
+ auto B6 = rows(A, iB6);
+ graph->forward();
+
+ CHECK(B0->shape() == Shape({1, 3}));
+ B0->val()->get(values);
+ CHECK( values == vB0 );
+
+ CHECK(B1->shape() == Shape({3, 3}));
+ B1->val()->get(values);
+ CHECK( values == vB1 );
+
+ CHECK(B2->shape() == Shape({2, 3}));
+ B2->val()->get(values);
+ CHECK( values == vB2 );
+
+ CHECK(B3->shape() == Shape({2, 3}));
+ B3->val()->get(values);
+ CHECK( values == vB3 );
+
+ CHECK(B4->shape() == Shape({2, 3}));
+ B4->val()->get(values);
+ CHECK( values == vB4 );
+
+ CHECK(B5->shape() == Shape({4, 3}));
+ B5->val()->get(values);
+ CHECK( values == vA );
+
+ CHECK(B6->shape() == Shape({0, 3}));
+ B6->val()->get(values);
+ CHECK( values == vB6 );
+ }
+
+ SECTION("columns selection from 2d matrix") {
+ graph->clear();
+ values.clear();
+
+ std::vector<T> vA({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+
+ std::vector<IndexType> iB0({0}); // first column
+ std::vector<IndexType> iB1({0, 1, 2}); // several consecutive columns
+ std::vector<IndexType> iB2({0, 2}); // two nonconsecutive columns
+ std::vector<IndexType> iB3({2, 1}); // reversed order
+ std::vector<IndexType> iB4({1, 1}); // repeated columns
+ std::vector<IndexType> iB5({0, 1, 2, 3}); // identity
+ std::vector<IndexType> iB6({}); // empty
+
+ std::vector<T> vB0({1, 5, 9});
+ std::vector<T> vB1({1, 2, 3, 5, 6, 7, 9, 10, 11});
+ std::vector<T> vB2({1, 3, 5, 7, 9, 11});
+ std::vector<T> vB3({3, 2, 7, 6, 11, 10});
+ std::vector<T> vB4({2, 2, 6, 6, 10, 10});
+ std::vector<T> vB6;
+
+ auto A = graph->param("A", {3, 4}, inits::fromVector(vA));
+ auto B0 = cols(A, iB0);
+ auto B1 = cols(A, iB1);
+ auto B2 = cols(A, iB2);
+ auto B3 = cols(A, iB3);
+ auto B4 = cols(A, iB4);
+ auto B5 = cols(A, iB5);
+ auto B6 = cols(A, iB6);
+ graph->forward();
+
+ CHECK(B0->shape() == Shape({3, 1}));
+ B0->val()->get(values);
+ CHECK( values == vB0 );
+
+ CHECK(B1->shape() == Shape({3, 3}));
+ B1->val()->get(values);
+ CHECK( values == vB1 );
+
+ CHECK(B2->shape() == Shape({3, 2}));
+ B2->val()->get(values);
+ CHECK( values == vB2 );
+
+ CHECK(B3->shape() == Shape({3, 2}));
+ B3->val()->get(values);
+ CHECK( values == vB3 );
+
+ CHECK(B4->shape() == Shape({3, 2}));
+ B4->val()->get(values);
+ CHECK( values == vB4 );
+
+ CHECK(B5->shape() == Shape({3, 4}));
+ B5->val()->get(values);
+ CHECK( values == vA );
+
+ CHECK(B6->shape() == Shape({3, 0}));
+ B6->val()->get(values);
+ CHECK( values == vB6 );
+ }
+
+ SECTION("relation of rows and columns selection using transpose") {
+ graph->clear();
+ values.clear();
+ std::vector<T> values2;
+
+ std::vector<T> vA({0, .3333, -.2, -.3, 0, 4.5, 5.2, -10, 101.45, -100.05, 0, 1.05e-5});
+ std::vector<IndexType> idx({0, 1});
+
+ auto A1 = graph->param("4x3", {4,3}, inits::fromVector(vA));
+ auto B1 = rows(transpose(A1), idx);
+ auto C1 = transpose(cols(A1, idx));
+ auto A2 = graph->param("6x2", {6,2}, inits::fromVector(vA));
+ auto B2 = cols(transpose(A2), idx);
+ auto C2 = transpose(rows(A2, idx));
+ graph->forward();
+
+ CHECK(B1->shape() == C1->shape());
+ B1->val()->get(values);
+ C1->val()->get(values2);
+ CHECK( values == values2 );
+
+ values.clear();
+ values2.clear();
+
+ CHECK(B2->shape() == C2->shape());
+ B2->val()->get(values);
+ C2->val()->get(values2);
+ CHECK( values == values2 );
+ }
+
+ SECTION("select, step, slice operators") {
+ using IndexVector = std::vector<IndexType>;
+
+ graph->clear();
+ values.clear();
+
+ std::vector<T> vA({ 1, -2, 3,
+ -4, 5, -6,
+ 7, -8, 9,
+ -10, 11, -12});
+ std::vector<T> vC({ 1, -2, // C = np.array([1, -2, 3, -4, 5, -6, 7, -8, 9, -10, 11, -12]).reshape((2, 3, 2))
+ 3, -4,
+ 5, -6,
+
+ 7, -8,
+ 9, -10,
+ 11, -12 });
+ std::vector<T> vB1({1, -2, 3});
+ std::vector<T> vB2({1, -4, 7, -10});
+ std::vector<T> vB3({-2, 5, -8, 11});
+ std::vector<T> vB4({1, -2, 3, -4, 5, -6});
+ std::vector<T> vD1(vB4);
+ std::vector<T> vD2({5, -6, 11, -12});
+ std::vector<T> vD3({1, -2, 5, -6, 7, -8, 11, -12}); // C[:,(0,2),:]
+ std::vector<T> vD4({5, -6, 3, -4, 7, -8, 11, -12}); // [C[0,(2,1),:],C[1,(0,2),:]]
+ std::vector<T> vS1({7, -8, 9});
+ std::vector<T> vS2({-4, 5, -6, 7, -8, 9});
+ std::vector<T> vS3({7, -8, 9, -10, 11, -12});
+
+ auto A = graph->param("4x3", {4,3}, inits::fromVector(vA));
+ auto B1a = index_select(A, 0, IndexVector({0})); // always uses gather()
+ auto B1b = slice(A, 0, 0); // memory-consecutive view
+ auto B2 = slice(A, 1, 0); // not memory-consecutive
+ auto B3 = slice(A, -1, 1);
+ auto B4a = index_select(A, 0, IndexVector({0, 1}));
+ auto B4b = slice(A, 0, Slice(0, 2)); // this is memory-consecutive
+ auto B5 = slice(A, 0, Slice(0, 4)); // this is a no-op
+ CHECK(B1a->type() == "rows"); // actually optimized to rows()
+ CHECK(B1b->type() == "sliceView"); // must use view
+ CHECK(B2->type() == "gather"); // cannot use view
+ CHECK(B4a->type() == "rows");
+ CHECK(B4b->type() == "sliceView"); // must use view
+ CHECK(B5.get() == A.get()); // must be no-op
+
+ auto C = graph->param("2x3x2", {2, 3, 2}, inits::fromVector(vC));
+ auto D1 = slice(C, 0, 0);
+ auto D2 = slice(C, -2, 2);
+ auto D3 = index_select(C, 1, IndexVector({0, 2})); // C[:,(0,2),:]
+ CHECK(D1->type() == "sliceView");
+ CHECK(D2->type() == "gather");
+ // enable this once gather() supports batched indices:
+ auto D4 = gather(C, 1, graph->constant({2, 2, 1}, // [C[0,(2,1),:],C[1,(0,2),:]]
+ inits::fromVector(std::vector<IndexType>{
+ 2, 1,
+ 0, 2 }),
+ Type::uint32));
+
+ auto S1 = slice(A, 0, 2);
+ auto S2 = narrow(A, 0, 1, 2);
+ auto S3 = slice(A, 0, Slice(-2, Slice::END));
+
+ graph->forward();
+
+ CHECK(B1a->shape() == Shape({1, 3})); B1a->val()->get(values); CHECK( values == vB1 );
+ CHECK(B1b->shape() == Shape({1, 3})); B1b->val()->get(values); CHECK( values == vB1 );
+ CHECK(B2->shape() == Shape({4, 1})); B2->val()->get(values); CHECK( values == vB2 );
+ CHECK(B3->shape() == Shape({4, 1})); B3->val()->get(values); CHECK( values == vB3 );
+ CHECK(B4a->shape() == Shape({2, 3})); B4a->val()->get(values); CHECK( values == vB4 );
+ CHECK(B4b->shape() == Shape({2, 3})); B4b->val()->get(values); CHECK( values == vB4 );
+
+ CHECK(D1->shape() == Shape({1, 3, 2})); D1->val()->get(values); CHECK( values == vD1 );
+ CHECK(D2->shape() == Shape({2, 1, 2})); D2->val()->get(values); CHECK( values == vD2 );
+ CHECK(D3->shape() == Shape({2, 2, 2})); D3->val()->get(values); CHECK( values == vD3 );
+ CHECK(D4->shape() == Shape({2, 2, 2})); D4->val()->get(values); CHECK( values == vD4 );
+
+ CHECK(S1->shape() == Shape({1,3})); S1->val()->get(values); CHECK(values == vS1);
+ CHECK(S2->shape() == Shape({2,3})); S2->val()->get(values); CHECK(values == vS2);
+ CHECK(S3->shape() == Shape({2,3})); S3->val()->get(values); CHECK(values == vS3);
+ }
+
+ SECTION("rows/cols as gather operations") {
+ graph->clear();
+ values.clear();
+ std::vector<T> values2;
+
+
+ std::vector<T> vA({0, .3333, -.2, -.3, 0, 4.5, 5.2, -10, 101.45, -100.05, 0, 1.05e-5});
+ std::vector<IndexType> indices({0, 2});
+
+ auto A = graph->param("4x3", {4, 3}, inits::fromVector(vA));
+ auto B1 = rows(A, indices);
+ auto B2 = gather(A, 0, graph->indices(indices, A, 0));
+ auto C1 = cols(A, indices);
+ auto C2 = gather(A, 1, graph->indices(indices, A, 1));
+
+ graph->forward();
+
+ CHECK(B1->shape() == B2->shape());
+ B1->val()->get(values);
+ B2->val()->get(values2);
+ CHECK( values == values2 );
+
+ CHECK(C1->shape() == C2->shape());
+ C1->val()->get(values);
+ C2->val()->get(values2);
+ CHECK( values == values2 );
+ }
+}
+
+#ifdef CUDA_FOUND
+TEST_CASE("Expression graph supports basic math operations (gpu)", "[operator]") {
+ tests<float>(DeviceType::gpu);
+}
+
+#if COMPILE_FP16
+TEST_CASE("Expression graph supports basic math operations (gpu fp16)", "[operator]") {
+ tests<float16>(DeviceType::gpu, Type::float16);
+}
+#endif
+#endif
+
+#ifdef BLAS_FOUND
+TEST_CASE("Expression graph supports basic math operations (cpu)", "[operator]") {
+ tests<float>(DeviceType::cpu);
+}
+#endif
+
+#ifdef BLAS_FOUND
+#ifdef CUDA_FOUND
+
+TEST_CASE("Compare aggregate operator", "[graph]") {
+ auto floatApprox = [](float x, float y) -> bool { return x == Approx(y).epsilon(0.01); };
+
+ Config::seed = 1234;
+
+ std::vector<float> initc;
+ std::vector<float> inita;
+
+ {
+ auto graph = New<ExpressionGraph>();
+ graph->setDevice({0, DeviceType::cpu});
+ graph->reserveWorkspaceMB(40);
+
+ auto chl = graph->param("1x10x512x2048", {1, 10, 512, 2048}, inits::normal());
+ auto adj = graph->param("1x1x512x2048", {1, 1, 512, 2048}, inits::normal());
+ graph->forward();
+
+ chl->val()->get(initc);
+ adj->val()->get(inita);
+ }
+
+ SECTION("initializing with zero (cpu)") {
+ std::vector<float> values1;
+ std::vector<float> values2;
+
+ auto graph1 = New<ExpressionGraph>();
+ graph1->setDevice({0, DeviceType::cpu});
+ graph1->reserveWorkspaceMB(40);
+
+ auto graph2 = New<ExpressionGraph>();
+ graph2->setDevice({0, DeviceType::gpu});
+ graph2->reserveWorkspaceMB(40);
+
+ auto chl1 = graph1->param("1x10x512x2048", {1, 10, 512, 2048}, inits::fromVector(initc));
+ auto adj1 = graph1->param("1x1x512x2048", {1, 1, 512, 2048}, inits::fromVector(inita));
+ auto prod1 = scalar_product(chl1, adj1, -1);
+ graph1->forward();
+
+ auto chl2 = graph2->param("1x10x512x2048", {1, 10, 512, 2048}, inits::fromVector(initc));
+ auto adj2 = graph2->param("1x1x512x2048", {1, 1, 512, 2048}, inits::fromVector(inita));
+ auto prod2 = scalar_product(chl2, adj2, -1);
+ graph2->forward();
+
+ prod1->val()->get(values1);
+ prod2->val()->get(values2);
+
+ CHECK( std::equal(values1.begin(), values1.end(), values2.begin(), floatApprox) );
+ }
+}
+
+ #endif
+ #endif \ No newline at end of file
diff --git a/src/tests/rnn_tests.cpp b/src/tests/units/rnn_tests.cpp
index 87639c86..56a2d1fd 100644
--- a/src/tests/rnn_tests.cpp
+++ b/src/tests/units/rnn_tests.cpp
@@ -1,15 +1,30 @@
#include "catch.hpp"
#include "marian.h"
+#ifdef CUDA_FOUND
+#include "tensors/gpu/backend.h"
+#endif
+
#include "rnn/rnn.h"
#include "rnn/constructors.h"
using namespace marian;
-void tests(DeviceType type) {
- auto floatApprox = [](float x, float y) { return x == Approx(y).epsilon(0.01); };
+template <typename T>
+void tests(DeviceType type, Type floatType = Type::float32) {
+
+// Checking for FP16 support and skipping if not supported.
+#ifdef CUDA_FOUND
+ if(type == DeviceType::gpu && floatType == Type::float16) {
+ auto gpuBackend = New<gpu::Backend>(DeviceId({0, type}), /*seed=*/1234);
+ auto cudaCompute = gpuBackend->getCudaComputeCapability();
+ if(cudaCompute.major < 6) return;
+ }
+#endif
+
+ auto floatApprox = [](T x, T y) { return x == Approx(y).epsilon(0.01); };
- Words vWords = {
+ std::vector<IndexType> vWords = {
43, 2, 83, 78,
6, 38, 80, 40,
40, 70, 26, 60,
@@ -20,7 +35,7 @@ void tests(DeviceType type) {
0, 0, 0, 0
};
- std::vector<float> vMask = {
+ std::vector<T> vMask = {
1, 1, 1, 1,
1, 1, 1, 1,
1, 1, 1, 1,
@@ -35,21 +50,22 @@ void tests(DeviceType type) {
Config::seed = 1234;
auto graph = New<ExpressionGraph>();
+ graph->setDefaultElementType(floatType);
graph->setDevice({0, type});
graph->reserveWorkspaceMB(16);
- std::vector<float> values;
+ std::vector<T> values;
auto input = graph->constant({4, 1, 4},
- inits::glorot_uniform);
+ inits::glorotUniform());
- auto rnn = rnn::rnn(graph) //
- ("prefix", "rnntest") //
- ("type", "tanh") //
- ("dimInput", 4) //
- ("dimState", 4) //
- .push_back(rnn::cell(graph)) //
- .construct();
+ auto rnn = rnn::rnn() //
+ ("prefix", "rnntest") //
+ ("type", "tanh") //
+ ("dimInput", 4) //
+ ("dimState", 4) //
+ .push_back(rnn::cell()) //
+ .construct(graph);
auto output = rnn->transduce(input);
@@ -58,14 +74,14 @@ void tests(DeviceType type) {
CHECK(output->shape() == Shape({4, 1, 4}));
#ifdef CUDA_FOUND
- std::vector<float> vOutput({
+ std::vector<T> vOutput({
0.637288, 0.906478, 0.603604, 0.152291,
-0.5333, -0.854558, 0.458454, -0.179582,
0.736857, 0.964425, 0.43848, 0.0261131,
-0.533659, -0.733491, -0.953666, -0.965717
});
#else
- std::vector<float> vOutput({
+ std::vector<T> vOutput({
-0.523228, 0.645143, 0.430939, 0.273439,
-0.747293, 0.131912, 0.115222, 0.363874,
0.367535, -0.819531, -0.313036, -0.387701,
@@ -82,10 +98,11 @@ void tests(DeviceType type) {
Config::seed = 1234;
auto graph = New<ExpressionGraph>();
+ graph->setDefaultElementType(floatType);
graph->setDevice({0, type});
graph->reserveWorkspaceMB(16);
- std::vector<float> values;
+ std::vector<T> values;
auto buildRnn = [&graph] (std::string prefix,
Expr input, Expr mask,
@@ -117,7 +134,7 @@ void tests(DeviceType type) {
auto backward = type == "alternating" ? rnn::dir::alternating_backward
: rnn::dir::backward;
- auto rnnFw = rnn::rnn(graph) //
+ auto rnnFw = rnn::rnn() //
("type", cellType) //
("direction", forward) //
("dimInput", dimEmb) //
@@ -126,7 +143,7 @@ void tests(DeviceType type) {
("skip", skip);
for(int i = 1; i <= first; ++i) {
- auto stacked = rnn::stacked_cell(graph);
+ auto stacked = rnn::stacked_cell();
for(int j = 1; j <= cellDepth; ++j) {
std::string paramPrefix = prefix + "_bi";
if(i > 1)
@@ -134,21 +151,22 @@ void tests(DeviceType type) {
if(i > 1 || j > 1)
paramPrefix += "_cell" + std::to_string(j);
- stacked.push_back(rnn::cell(graph)("prefix", paramPrefix));
+ stacked.push_back(rnn::cell()("prefix", paramPrefix));
}
rnnFw.push_back(stacked);
}
- auto rnnBw = rnn::rnn(graph) //
- ("type", cellType) //
- ("direction", backward) //
- ("dimInput", dimEmb) //
- ("dimState", dimRnn) //
- ("layer-normalization", layerNorm) //
+
+ auto rnnBw = rnn::rnn() //
+ ("type", cellType) //
+ ("direction", backward) //
+ ("dimInput", dimEmb) //
+ ("dimState", dimRnn) //
+ ("layer-normalization", layerNorm) //
("skip", skip);
for(int i = 1; i <= first; ++i) {
- auto stacked = rnn::stacked_cell(graph);
+ auto stacked = rnn::stacked_cell();
for(int j = 1; j <= cellDepth; ++j) {
std::string paramPrefix = prefix + "_bi_r";
if(i > 1)
@@ -156,13 +174,13 @@ void tests(DeviceType type) {
if(i > 1 || j > 1)
paramPrefix += "_cell" + std::to_string(j);
- stacked.push_back(rnn::cell(graph)("prefix", paramPrefix));
+ stacked.push_back(rnn::cell()("prefix", paramPrefix));
}
rnnBw.push_back(stacked);
}
- auto context = concatenate({rnnFw->transduce(input, mask),
- rnnBw->transduce(input, mask)},
+ auto context = concatenate({rnnFw.construct(graph)->transduce(input, mask),
+ rnnBw.construct(graph)->transduce(input, mask)},
/*axis =*/ input->shape().size() - 1);
if(second > 0) {
@@ -170,25 +188,25 @@ void tests(DeviceType type) {
// previous bidirectional RNN through multiple layers
// construct RNN first
- auto rnnUni = rnn::rnn(graph) //
- ("type", cellType) //
- ("dimInput", 2 * dimRnn) //
- ("dimState", dimRnn) //
- ("layer-normalization", layerNorm) //
+ auto rnnUni = rnn::rnn() //
+ ("type", cellType) //
+ ("dimInput", 2 * dimRnn) //
+ ("dimState", dimRnn) //
+ ("layer-normalization", layerNorm) //
("skip", skip);
for(int i = first + 1; i <= second + first; ++i) {
- auto stacked = rnn::stacked_cell(graph);
+ auto stacked = rnn::stacked_cell();
for(int j = 1; j <= cellDepth; ++j) {
std::string paramPrefix = prefix + "_l" + std::to_string(i) + "_cell"
+ std::to_string(j);
- stacked.push_back(rnn::cell(graph)("prefix", paramPrefix));
+ stacked.push_back(rnn::cell()("prefix", paramPrefix));
}
rnnUni.push_back(stacked);
}
// transduce context to new context
- context = rnnUni->transduce(context);
+ context = rnnUni.construct(graph)->transduce(context);
}
return context;
};
@@ -199,11 +217,11 @@ void tests(DeviceType type) {
auto emb = graph->param("Embeddings",
{128, dimEmb},
- inits::glorot_uniform);
+ inits::glorotUniform());
auto input = reshape(rows(emb, vWords), {dimTime, dimBatch, dimEmb});
auto mask = graph->constant({dimTime, dimBatch, 1},
- inits::from_vector(vMask));
+ inits::fromVector(vMask));
int dimRnn = 32;
auto context1 = buildRnn("enc1", input, mask, dimRnn);
@@ -225,7 +243,7 @@ void tests(DeviceType type) {
CHECK(contextSum1->shape() == Shape({dimTime, dimBatch, 1}));
#ifdef CUDA_FOUND
- std::vector<float> vContextSum1({
+ std::vector<T> vContextSum1({
-0.110829, -0.510232, 0.265193, 0.194025,
-0.242112, 0.185029, 0.0530527, 0.359336,
0.60218, 0.46511, -0.240092, 0.100453,
@@ -236,7 +254,7 @@ void tests(DeviceType type) {
0.360119, 0.422752, 0.55825, 0.0469481
});
#else
- std::vector<float> vContextSum1({
+ std::vector<T> vContextSum1({
-0.0674548, 0.383986, -0.613574, 0.226154,
-0.819571, 0.47317, -1.39324, -0.401005,
-0.24099, 0.64791, -0.120434, -0.818529,
@@ -255,7 +273,7 @@ void tests(DeviceType type) {
CHECK(contextSum2->shape() == Shape({dimTime, dimBatch, 1}));
#ifdef CUDA_FOUND
- std::vector<float> vContextSum2({
+ std::vector<T> vContextSum2({
-0.0282316, 0.0219561, -0.012136, 0.0206684,
-0.0755229, 0.00091961, 0.0206883, 0.0176061,
-0.0272491, 0.0833994, 0.0279131, 0.0170246,
@@ -266,7 +284,7 @@ void tests(DeviceType type) {
0.123207, 0.0774718, 0.0741554, 0.0548368
});
#else
- std::vector<float> vContextSum2({
+ std::vector<T> vContextSum2({
0.0193405, -0.0580973, -0.0213983, 0.0381918,
-0.0135365, -0.0934286, -0.0171637, 0.0198686,
-0.0102693, -0.0865369, -0.0160779, 0.0393178,
@@ -284,7 +302,7 @@ void tests(DeviceType type) {
//CHECK(context3->shape() == Shape({dimBatch, 2 * dimRnn, dimTime}));
//CHECK(contextSum3->shape() == Shape({dimBatch, 1, dimTime}));
//
- //std::vector<float> vContextSum3({
+ //std::vector<T> vContextSum3({
// 1.135, 2.40939, 2.37631, 2.03765,
// 0.0583942, -4.89241, 5.31731, -1.52973,
// 3.52754, 1.02098, -4.05162, -1.11594,
@@ -311,12 +329,18 @@ void tests(DeviceType type) {
#ifdef CUDA_FOUND
TEST_CASE("Model components, RNN etc. (gpu)", "[model]") {
- tests(DeviceType::gpu);
+ tests<float>(DeviceType::gpu);
+}
+
+#if COMPILE_FP16
+TEST_CASE("Model components, RNN etc. (gpu, fp16)", "[model]") {
+ tests<float16>(DeviceType::gpu, Type::float16);
}
#endif
+#endif
#ifdef BLAS_FOUND
TEST_CASE("Model components, RNN etc. (cpu)", "[model]") {
- tests(DeviceType::cpu);
+ tests<float>(DeviceType::cpu);
}
#endif
diff --git a/src/tests/run_tests.cpp b/src/tests/units/run_tests.cpp
index 0c7c351f..0c7c351f 100644
--- a/src/tests/run_tests.cpp
+++ b/src/tests/units/run_tests.cpp
diff --git a/src/training/communicator.cpp b/src/training/communicator.cpp
index e158fcb2..98b7be7f 100755..100644
--- a/src/training/communicator.cpp
+++ b/src/training/communicator.cpp
@@ -38,7 +38,7 @@ Ptr<ICommunicator> createCommunicator(
}
// the actual implementation is inside communicator.cu
- return New<NCCLCommunicator>(graphs, mpi);
+ return New<NCCLCommunicator>(graphs, mpi);
#else // no CUDA or no NCCL
noNccl; // (unused)
return New<DefaultCommunicator>(graphs, mpi);
@@ -72,20 +72,19 @@ class MPIWrapper : public IMPIWrapper
public:
MPIWrapper(bool multiThreaded) {
- int requiredThreadingMode = multiThreaded ? MPI_THREAD_MULTIPLE : MPI_THREAD_SINGLE;
+ int requiredThreadingMode = multiThreaded ? MPI_THREAD_MULTIPLE : MPI_THREAD_FUNNELED; // FUNNELED means only one thread ever calls MPI
int argc = 1; char* argv[] = { const_cast<char*>("this.exe") }; char** argvp = argv; // dummy argc/argv since MPI_Init needs something here
int providedThreadingMode;
HANDLE_MPI_ERROR(MPI_Init_thread(&argc, &argvp, MPI_THREAD_MULTIPLE, &providedThreadingMode));
MPI_Comm_set_errhandler(MPI_COMM_WORLD, MPI_ERRORS_RETURN); // have errors reported as return codes
- ABORT_IF(
- providedThreadingMode < requiredThreadingMode,
- "Your version of MPI does not support multi-threaded communication.");
-
MPI_Comm_size(MPI_COMM_WORLD, &comm_world_size_);
MPI_Comm_rank(MPI_COMM_WORLD, &my_rank_);
+ ABORT_IF(comm_world_size_ > 1 && providedThreadingMode < requiredThreadingMode,
+ "Your version of MPI does not support multi-threaded communication.");
+
// patch logging pattern to include the MPI rank, so that we can associate error messages with nodes
if (numMPIProcesses() > 1) {
std::string rankStr = std::to_string(MPIWrapper::myMPIRank());
@@ -124,6 +123,8 @@ public:
HANDLE_MPI_ERROR(MPI_Recv(buf, (int)count, datatype, (int)sourceRank, tag, comm, status));
}
virtual void allReduce(const void* sendbuf, void* recvbuf, size_t count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) const override {
+ if (sendbuf == recvbuf)
+ sendbuf = MPI_IN_PLACE; // MSMPI requires this
HANDLE_MPI_ERROR(MPI_Allreduce(sendbuf, recvbuf, (int)count, datatype, op, comm));
}
virtual void finalize() override {
@@ -140,7 +141,7 @@ public:
FakeMPIWrapper(bool) {
LOG(warn, "Compiled without MPI support. Falling back to FakeMPIWrapper");
}
-
+ virtual ~FakeMPIWrapper() {}
virtual size_t myMPIRank() const override { return 0; };
virtual size_t numMPIProcesses() const override { return 1; };
@@ -168,7 +169,7 @@ public:
// to only accept one parameter, and remove this error check can be removed.
ABORT_IF(sendbuf != recvbuf, "FakeMPIWrapper::allReduce() only implemented for in-place operation"); // otherwise it's not a no-op, we must copy data
}
-#pragma warning(push)
+#pragma warning(pop)
virtual void finalize() override { }
};
diff --git a/src/training/communicator.h b/src/training/communicator.h
index eef136f2..47274491 100755..100644
--- a/src/training/communicator.h
+++ b/src/training/communicator.h
@@ -37,8 +37,8 @@ public:
virtual void foreach(const ForeachFunc& func, bool parallel = true) const = 0;
// @TODO: We probably can still share foreach() between the two implementations. Just need to move some helper functions from the .cu file.
- virtual void scatterReduce() const = 0; // reduce param gradients and scatter into gradient shards
- virtual void allGather() const = 0; // redistribute value shards into param values
+ virtual void scatterReduceAndResetGrads() const = 0; // reduce param gradients and scatter into gradient shards
+ virtual void allGatherParams() const = 0; // redistribute value shards into param values
virtual void swapParams(const std::vector<Tensor>& paramShards) const = 0;
@@ -153,14 +153,11 @@ public:
t.join();
}
- void scatterReduce() const override {
+ void scatterReduceAndResetGrads() const override {
const_cast<DefaultCommunicator*>(this)->lazyInit();
- int totalSize = (int)graphs_[0]->params()->vals()->size();
- int shardSize = (int)ceil(totalSize / (float)graphs_.size());
-
// Gather gradients from different devices into current gradient shards
- auto scatter = [this, shardSize](size_t idx, size_t begin, size_t end) {
+ auto scatter = [this](size_t idx, size_t begin, size_t end) {
auto curGrad = graphs_[idx]->params()->grads()->subtensor(begin, end-begin);
// collect and sum gradients
@@ -175,15 +172,23 @@ public:
}
};
+ // reset gradients outside current shard
+ auto reset = [this](size_t idx, size_t begin, size_t end) {
+ auto grad = graphs_[idx]->params()->grads();
+ if (begin > 0)
+ grad->subtensor(0, begin)->set(0);
+ if (end < grad->size())
+ grad->subtensor(end, grad->size()-end)->set(0);
+ };
+
foreach(scatter);
+ foreach(reset);
}
- void allGather() const override {
- int totalSize = (int)graphs_[0]->params()->vals()->size();
- int shardSize = (int)ceil(totalSize / (float)graphs_.size());
+ void allGatherParams() const override {
// Update all graphs with parameter shard
- auto gather = [this, shardSize](size_t idx, size_t begin, size_t end) {
+ auto gather = [this](size_t idx, size_t begin, size_t end) {
auto getShard = [&](Ptr<ExpressionGraph> graph) {
return graph->params()->vals()->subtensor(begin, end-begin);
};
@@ -203,7 +208,6 @@ public:
void swapParams(const std::vector<Tensor>& paramShards) const override {
// Update all graphs with parameter shard
-
auto gather = [this, paramShards](size_t idx, size_t begin, size_t end) {
ABORT_IF(end - begin != paramShards[idx]->size(), "inconsistent shard size (swapParams, [{}], {} vs {})??", idx, end-begin, paramShards[idx]->size());
// Copy parameter shard to each graph, apart from last graph
diff --git a/src/training/communicator_nccl.h b/src/training/communicator_nccl.h
index bf1f07df..55b48528 100755..100644
--- a/src/training/communicator_nccl.h
+++ b/src/training/communicator_nccl.h
@@ -4,10 +4,11 @@
#include "3rd_party/threadpool.h"
#include "tensors/gpu/cuda_helpers.h"
+#include "common/timer.h"
+
// Generated by NCCL make files in build/nccl/include;
// include dir has been set in CMake files. NCCL add version number etc.
#include "nccl.h"
-
#include <cuda_runtime.h>
#if (NCCL_MAJOR<3 || NCCL_MINOR<2)
@@ -44,6 +45,14 @@ private:
}
}
+ void synchronizeAllOnNullStream() const {
+ for (int i = 0; i < graphs_.size(); ++i) {
+ auto backend = graphs_[i]->params()->vals()->getBackend();
+ backend->setDevice();
+ backend->synchronize(); // note: synchronize() does not set the device by itself
+ }
+ }
+
std::string mpiIdStr() const { // (for logging)
return mpi_ ? mpi_->idStr() : "";
}
@@ -150,6 +159,16 @@ public:
CUDA_CHECK(cudaStreamCreate(&streams_[i]));
}
+ // Note: due to a bug in NCCL 2.3.5, NCCL's allocation of shared memory intermittently fails with
+ // Failed, NCCL error 2 'unhandled system error' - ncclGroupEnd()
+ // include/shm.h:26 NCCL WARN Unable to allocate shared memory (4263936 bytes) : Interrupted system call
+ // This is caused by SIGPROF signals being raised, causing EINTR, which NCCL does not handle.
+ // Reported as Issue #137 on the NCCL Github, and supposedly fixed for 2.3.7 (to be verified).
+ // To work around, we disable the SIGPROF signal during NCCL initialization.
+#define SIG_BAD 27 // SIGPROF
+ BlockSignal blockThread(SIG_BAD, pthread_sigmask); // Note: I don't know yet which of these two makes the difference.
+ BlockSignal blockProc(SIG_BAD, sigprocmask); // So for now just block both.
+
// set up NCCL
// Since we want to use MPI, we cannot use NCCL's handy convenience function. Instead, we must go the laborious route.
// cf. https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/index.html#multidevprothrd
@@ -160,35 +179,19 @@ public:
NCCL_CHECK(ncclGetUniqueId(&uniqueId));
if (mpi_) {
- //LOG(info, "[{}] before bcast", mpiIdStr());
static_assert(sizeof(uniqueId) == NCCL_UNIQUE_ID_BYTES, "wrong NCCL_UNIQUE_ID_BYTES??"); // (this value is used in NVidia examples)
mpi_->bCast(&uniqueId, sizeof(uniqueId), MPI_BYTE, 0);
- //LOG(info, "[{}] after bcast", mpiIdStr());
}
- //mpiBarrier(); // should not be needed since bCast is a barrier
-
- // Note: due to a bug in NCCL 2.3.5, NCCL's allocation of shared memory intermittently fails with
- // Failed, NCCL error 2 'unhandled system error' - ncclGroupEnd()
- // include/shm.h:26 NCCL WARN Unable to allocate shared memory (4263936 bytes) : Interrupted system call
- // This is caused by SIGPROF signals being raised, causing EINTR, which NCCL does not handle.
- // Reported as Issue #137 on the NCCL Github.
- // To work around, we disable the SIGPROF signal during NCCL initialization.
-#define SIG_BAD 27 // SIGPROF
- BlockSignal blockThread(SIG_BAD, pthread_sigmask); // Note: I don't know yet which of these two makes the difference.
- BlockSignal blockProc(SIG_BAD, sigprocmask); // So for now just block both.
-
groupStart();
for (int localDeviceIndex = 0; localDeviceIndex < devices_.size(); localDeviceIndex++) {
CUDA_CHECK(cudaSetDevice(devices_[localDeviceIndex]));
- //LOG(info, "[{}] ncclCommInitRank {} out of {}: GPU[{}]", mpiIdStr(), myNcclRank(localDeviceIndex), numNcclRanks(), localDeviceIndex);
NCCL_CHECK(ncclCommInitRank(&comms_[localDeviceIndex], numNcclRanks(), uniqueId, myNcclRank(localDeviceIndex)));
- //LOG(info, "[{}] done ncclCommInitRank {} out of {}, GPU[{}]", mpiIdStr(), myNcclRank(localDeviceIndex), numNcclRanks(), localDeviceIndex);
}
groupEnd();
mpiBarrier(); // (synchronize the log messages)
- LOG(debug, "NCCLCommunicator constructed successfully for {}", mpiIdStr());
+ LOG(info, "[comm] NCCLCommunicator constructed successfully.");
mpiBarrier(); // (synchronize the log messages)
}
@@ -206,61 +209,59 @@ public:
for(size_t i = 0; i < graphs_.size(); ++i) {
size_t begin, end; std::tie
(begin, end) = localShardRange(i);
- //std::cerr << "[" << mpiIdStr() << "] foreach " << begin << " " << end << std::endl;
-try{
if (parallel)
threadResults_[i] = threadPool_.enqueue(func, i, begin, end);
- //group.emplace_back(func, i, begin, end);
- //threadPool_.enqueue([&](size_t i){
- // func(i, begin, end);
- //}, i);
else
func(i, begin, end);
-}
-catch (const std::exception& e) // something leaks thread handles
-{
- // keeping this around, in case the error still happens --@TODO: remove once this has not been observed anymore
- LOG(info, "caught exception in foreach {}", i);
- system("ps -T -A");
- throw;
-}
}
if (parallel)
for(size_t i = 0; i < graphs_.size(); ++i)
threadResults_[i].wait();
- //for(auto& t : group) // (note: group is empty is not parallel)
- // t.join();
}
- void scatterReduce() const override {
+ void scatterReduceAndResetGrads() const override {
+ synchronizeAllOnNullStream();
+
groupStart();
for(int i = 0; i < graphs_.size(); ++i) {
size_t begin, end; std::tie
(begin, end) = localShardRange(i);
- //std::cerr << "[" << mpiIdStr() << "] scatterReduce " << begin << " " << end << std::endl;
auto grads = graphs_[i]->params()->grads();
const auto* sendbuf = grads->data();
auto* recvbuf = grads->subtensor(begin, end-begin)->data();
size_t bufsize = shardSize();
+ ABORT_IF(grads->subtensor(begin, end-begin)->size() != bufsize, "unexpected subtensor size??");
NCCL_CHECK(ncclReduceScatter(sendbuf, recvbuf, bufsize, ncclFloat, ncclSum, comms_[i], streams_[i]));
}
groupEnd();
- //std::cerr << "scatterReduce submitted" << std::endl;
synchronizeAll();
- //std::cerr << "scatterReduce completed" << std::endl;
+
+ // reset gradients
+ // In the future, we can keep quantization residuals here straight in the grads themselves.
+ auto resetGrads = [&](size_t i, size_t begin, size_t end) {
+ auto grads = graphs_[i]->params()->grads();
+ auto size = grads->size();
+ // reset everything outside the shard that we reduce in
+ if (begin > 0)
+ grads->subtensor(0, begin)->set(0.f);
+ if (end < size)
+ grads->subtensor(end, size - end)->set(0.f);
+ };
+ foreach(resetGrads);
}
// This distributes all 64 model shards to all 64 GPUs.
- // @TODO: For unknown reasons, this takes longer than any other operation incl. scatterReduce().
+ // @TODO: For unknown reasons, this takes longer than any other operation incl. scatterReduceAndResetGrads().
// But both should have the same number of data transfers of the same size.
- void allGather() const override {
+ void allGatherParams() const override {
+ synchronizeAllOnNullStream();
+
groupStart();
for(int i = 0; i < graphs_.size(); ++i) {
size_t begin, end; std::tie
(begin, end) = localShardRange(i);
- //std::cerr << "[" << mpiIdStr() << "] allGather " << begin << " " << end << std::endl;
auto vals = graphs_[i]->params()->vals();
const auto* sendbuf = vals->subtensor(begin, end-begin)->data();
@@ -281,14 +282,12 @@ catch (const std::exception& e) // something leaks thread handles
auto distributedParams = gatherState([&](size_t localDeviceIndex) {
std::vector<float> tmp;
distributedParamShards[localDeviceIndex]->get(tmp);
- //LOG(info, "[{}] swapParams.getFn({}) -> size {}, ({}, {}, {}, ...)", mpiIdStr(), localDeviceIndex, tmp.size(), tmp[0], tmp[1], tmp[2]);
return tmp;
});
// Now all MPI processes hold an identical copy of a concatenation of all distributedParamShards[] across local and remote devices.
std::vector<float> localParams;
graphs_[0]->params()->vals()->get(localParams);
// Now all MPI processes hold an identical copy of params() (remember, we assumed all devices hold the same params()).
- //LOG(info, "[{}] swapParams: distributedParams.size = {}, localParams.size = {}", mpiIdStr(), distributedParams.size(), localParams.size());
ABORT_IF(distributedParams.size() != localParams.size(), "distributed sharded and local params have different size??");
// swap
@@ -331,7 +330,6 @@ catch (const std::exception& e) // something leaks thread handles
tmp = getFn(localDeviceIndex);
localData.insert(localData.end(), tmp.begin(), tmp.end());
}
- //LOG(info, "[{}] gatherState: localData.size = {}", mpiIdStr(), localData.size());
// second, concatenate across MPI processes
// Note that all local devices occupy consecutive ncclRanks in order.
std::vector<float> data;
diff --git a/src/training/exponential_smoothing.h b/src/training/exponential_smoothing.h
index bc2f2761..30e64430 100755..100644
--- a/src/training/exponential_smoothing.h
+++ b/src/training/exponential_smoothing.h
@@ -3,6 +3,7 @@
#include "common/definitions.h"
#include "functional/functional.h"
#include "tensors/tensor_operators.h"
+#include "optimizers/optimizers.h"
namespace marian {
@@ -12,18 +13,35 @@ namespace marian {
*/
class ExponentialSmoothing {
public:
- ExponentialSmoothing(float decay = 0.0f)
- : mvAvg_{decay > 0}, mvDecay_{decay} {}
+ ExponentialSmoothing(Ptr<Options> options) {
+ mvDecayBy_ = options->get<float>("exponential-smoothing");
+ refBatchTrgWords_ = options->get<size_t>("mini-batch-words-ref"); // adjust as if our MB size (in target labels) was this value
+ mvAvg_ = (mvDecayBy_ > 0);
+ }
protected:
- void updateAvgParams(Tensor paramsAvg, Tensor params, size_t batches) {
+ void updateAvgParams(Tensor paramsAvg, Tensor params, size_t batches, size_t actualBatchTrgWords = OptimizerBase::mbSizeNotProvided) {
+ double beta = 1. - mvDecayBy_;
+ // correction term if batch size is different from what mvDecayBy_ was specified for
+ if (refBatchTrgWords_) {
+ LOG_ONCE(info, "Exponential smoothing gets automatically adjusted as if update size was {} target words", refBatchTrgWords_);
+ ABORT_IF(actualBatchTrgWords == OptimizerBase::mbSizeNotProvided,
+ "This graph-group type does not support reference batch size specification for exponential-smoothing");
+ beta = pow(beta, (double)actualBatchTrgWords / (double)refBatchTrgWords_);
+ // If actual size differs from reference, then try to estimate the equivalent number of batches.
+ // E.g. if MB size is growing over time, then this is an overestimate, which would diminish the
+ // effect overly quickly, but in a range where that should be OK.
+ batches = std::max(batches, batches * actualBatchTrgWords / refBatchTrgWords_); // @BUGBUG: Does not consider that batch size is changing
+ }
+ // reduce effect of decay parameter in early training stages
+ float decayBy = std::max(1.f - (float)beta,
+ 1.f - (float)(batches + 1) / (float)(batches + 10));
using namespace functional;
- float decay = std::max(mvDecay_,
- 1.f - (float)(batches + 1) / (float)(batches + 10));
- Element(_1 = ((1.f - decay) * _1) + (decay * _2), paramsAvg, params);
+ Element(_1 = ((1.f - decayBy) * _1) + (decayBy * _2), paramsAvg, params);
}
bool mvAvg_{false};
- float mvDecay_{1e-4f};
+ float mvDecayBy_{1e-4f}; // decay prior model by this factor
+ size_t refBatchTrgWords_{0}; // mvDecayBy_ is specified for this batch size (in target words) (0 means not specified)
};
} // namespace marian
diff --git a/src/training/gradient_dropping/sparse_tensor.h b/src/training/gradient_dropping/sparse_tensor.h
index 5effa3d7..652a77ec 100755..100644
--- a/src/training/gradient_dropping/sparse_tensor.h
+++ b/src/training/gradient_dropping/sparse_tensor.h
@@ -118,7 +118,7 @@ public:
}
// Convert a tensor into a sparse tensor format
- void fromDense(Tensor t) {
+ void fromDense(Tensor MAYBE_UNUSED t) {
if(backend_->getDeviceId().type == DeviceType::cpu) {
ABORT("Gradient Dropping for CPU is not yet supported");
}
diff --git a/src/training/graph_group.h b/src/training/graph_group.h
index fc372adc..56b8afe3 100755..100644
--- a/src/training/graph_group.h
+++ b/src/training/graph_group.h
@@ -21,6 +21,7 @@ protected:
Ptr<OptimizerBase> opt_; // the optimizer
Ptr<Scheduler> scheduler_; // scheduler that keeps track of how much has been processed
bool finalized_{false}; // 'true' if training has completed (further updates are no longer allowed)
+ size_t typicalTrgBatchWords_{ 0 }; // for dynamic batch sizing: typical batch size in words
public:
GraphGroup(Ptr<Options> options) : options_(options), opt_(Optimizer(options)) {}
@@ -33,6 +34,10 @@ public:
virtual void save(bool isFinal = false) = 0;
+ void validate() {
+ ABORT_IF(finalized_, "Training has already finished.");
+ }
+
virtual void finalize() {
finalized_ = true;
}
@@ -48,13 +53,14 @@ public:
* The actual allowed size is then determined by multiplying it with the
* number of devices, which is passed in as the 'multiplier'.
*/
- virtual Ptr<data::BatchStats> collectStats(Ptr<ExpressionGraph> graph,
- Ptr<models::ModelBase> model,
- size_t multiplier = 1) {
+ // @TODO: Can this be made const? It seems wrong to have a stateful method that still returns a result.
+ Ptr<data::BatchStats> collectStats(Ptr<ExpressionGraph> graph,
+ Ptr<models::ICriterionFunction> model,
+ const std::vector<Ptr<Vocab>>& vocabs,
+ double multiplier = 1.) {
auto stats = New<data::BatchStats>();
- size_t numFiles
- = options_->get<std::vector<std::string>>("train-sets").size();
+ size_t numFiles = options_->get<std::vector<std::string>>("train-sets").size();
// Initialize first batch to step size
size_t first = options_->get<size_t>("mini-batch-fit-step");
@@ -65,34 +71,46 @@ public:
size_t maxLength = options_->get<size_t>("max-length");
maxLength = (size_t)(std::ceil(maxLength / (float)step) * step);
- // @TODO: ugly
- auto toptions = New<Options>();
- toptions->merge(options_);
+ // this should be only one class label per line on input, hence restricting length to 1
+ std::vector<size_t> localMaxes(numFiles, maxLength);
+ auto inputTypes = options_->get<std::vector<std::string>>("input-types", {});
+ for(int i = 0; i < inputTypes.size(); ++i)
+ if(inputTypes[i] == "class")
+ localMaxes[i] = 1;
size_t maxBatch = 512;
bool fits = true;
while(fits) {
std::vector<size_t> lengths(numFiles, first);
- auto batch = data::CorpusBatch::fakeBatch(lengths, maxBatch, toptions);
+ for(int j = 0; j < lengths.size(); ++j) // apply length restrictions
+ lengths[j] = std::min(lengths[j], localMaxes[j]);
+
+ auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, maxBatch, options_);
auto cost = model->build(graph, batch);
fits = graph->fits();
if(fits)
maxBatch *= 2;
}
+ // Do a binary search for maxmimum batch size that fits into given workspace memory
+ // for a tested sentence length.
for(size_t i = step; i <= maxLength; i += step) {
size_t start = 1;
size_t end = maxBatch;
std::vector<size_t> lengths(numFiles, i);
+ for(int j = 0; j < lengths.size(); ++j) // apply length restrictions
+ lengths[j] = std::min(lengths[j], localMaxes[j]);
fits = true;
do {
size_t current = (start + end) / 2;
- auto batch = data::CorpusBatch::fakeBatch(lengths, current, toptions);
+ auto batch = data::CorpusBatch::fakeBatch(lengths, vocabs, current, options_);
auto cost = model->build(graph, batch);
fits = graph->fits();
+ LOG(debug, "[batching] length: {} - size: {} - fits: {}", lengths[0], current, fits);
+
if(fits) {
stats->add(batch, multiplier);
start = current + 1;
@@ -105,6 +123,10 @@ public:
}
return stats;
}
+
+ void setTypicalTrgBatchWords(size_t typicalTrgBatchWords) { // needed for dynamic MB scaling
+ typicalTrgBatchWords_ = typicalTrgBatchWords;
+ }
};
/**
@@ -120,17 +142,14 @@ protected:
std::vector<size_t> devices_; // [num local GPUs]
/** Graph builders for clients (which run forward and backward passes). */
- std::vector<Ptr<models::ModelBase>> clientBuilders_;
+ std::vector<Ptr<models::ICriterionFunction>> clientBuilders_;
/** Graphs of clients. One entry per GPU on this node. */
std::vector<Ptr<ExpressionGraph>> clientGraphs_; // [num local GPUs]
public:
- MultiNodeGraphGroupBase(Ptr<Options> options)
- : Base(options) {
-
- // Setup MPI
- setupMPI();
+ MultiNodeGraphGroupBase(Ptr<Options> options, Ptr<IMPIWrapper> mpi)
+ : Base(options), mpi_(mpi) {
// Set up devices for this node
std::vector<size_t> devices; // set of GPU device ids for this MPI process
@@ -143,18 +162,11 @@ public:
clientGraphs_.push_back(New<ExpressionGraph>());
clientGraphs_[i]->setDevice({ devices_[i], DeviceType::gpu });
clientGraphs_[i]->reserveWorkspaceMB(options_->get<size_t>("workspace"));
- clientBuilders_.push_back(models::from_options(options_, models::usage::training));
+ clientBuilders_.push_back(models::createCriterionFunctionFromOptions(options_, models::usage::training));
}
}
/**
- * Setup MPI world size and rank of this node.
- */
- void setupMPI() {
- mpi_ = initMPI(/*multiThreaded=*/!options_->get<bool>("sync-sgd"));
- }
-
- /**
* Load the GPU configuration of this node (i.e. which GPUs to use) and the
* number of GPUs on the other nodes.
*/
@@ -194,10 +206,8 @@ public:
}
virtual void finalize() override {
- if (mpi_) {
+ if (mpi_)
finalizeMPI(std::move(mpi_));
- ABORT_IF(mpi_, "MPI not finalized??");
- }
Base::finalize();
}
};
diff --git a/src/training/graph_group_async.cpp b/src/training/graph_group_async.cpp
index 1d004201..4636f4b0 100755..100644
--- a/src/training/graph_group_async.cpp
+++ b/src/training/graph_group_async.cpp
@@ -5,23 +5,26 @@
namespace marian {
-AsyncGraphGroup::AsyncGraphGroup(Ptr<Options> config)
+AsyncGraphGroup::AsyncGraphGroup(Ptr<Options> config, Ptr<IMPIWrapper> mpi)
: GraphGroup(config),
- ExponentialSmoothing{options_->get<float>("exponential-smoothing")},
+ ExponentialSmoothing(options_),
devices_{Config::getDevices(options_)},
shardSync_(devices_.size()),
- optimizerDelay_{options_->get<size_t>("optimizer-delay")} {
+ optimizerDelay_((size_t)options_->get<double>("optimizer-delay")) {
+ ABORT_IF(mpi->numMPIProcesses() != 1, "AsyncGraphGroup presently does not support multiple MPI processes");
+ ABORT_IF((double)optimizerDelay_ != options_->get<double>("optimizer-delay"), "AsyncGraphGroup presently does not implement fractional values for --optimizer-delay");
pool_.reset(new ThreadPool(devices_.size(), devices_.size()));
for(auto device : devices_) {
auto graph = New<ExpressionGraph>();
graph->setDevice(device);
+ graph->setCheckpointing(options_->get<bool>("gradient-checkpointing"));
graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_.push_back(graph);
shardOpt_.push_back(Optimizer(options_));
- builders_.push_back(models::from_options(options_, models::usage::training));
+ builders_.push_back(models::createCriterionFunctionFromOptions(options_, models::usage::training));
}
}
@@ -187,12 +190,12 @@ void AsyncGraphGroup::execute(Ptr<data::Batch> batch) {
auto task = [this](Ptr<data::Batch> batch) {
static size_t i = 0;
thread_local Ptr<ExpressionGraph> graph;
- thread_local Ptr<models::ModelBase> builder;
+ thread_local Ptr<models::ICriterionFunction> builder;
thread_local size_t t = 0;
thread_local size_t num_seen_words = 0;
thread_local size_t num_seen_sentences = 0;
thread_local int t_id = 0;
- thread_local float cost = 0;
+ thread_local StaticLoss loss;
thread_local Tensor accGradients;
thread_local Ptr<TensorAllocator> accAlloc;
@@ -204,14 +207,14 @@ void AsyncGraphGroup::execute(Ptr<data::Batch> batch) {
builder = builders_[i++];
}
- auto costNode = builder->build(graph, batch);
+ Ptr<RationalLoss> dynamicLoss = builder->build(graph, batch);
if(t % optimizerDelay_ == 0) {
fetchParams(graph->params()->vals(), params_, t_id);
}
graph->forward();
- cost += costNode->scalar();
+ loss += *dynamicLoss;
graph->backward();
Tensor gradients;
@@ -249,23 +252,21 @@ void AsyncGraphGroup::execute(Ptr<data::Batch> batch) {
// Wait until the thread that wants to do validation is finished.
pool_->wait_for_one(lock);
- if(options_->get<std::string>("cost-type") != "ce-sum")
- cost /= optimizerDelay_;
-
if(optimizerDelay_ > 1) {
std::vector<size_t> fakeLength = {1, 1};
- auto fb = data::CorpusBatch::fakeBatch(
- fakeLength, num_seen_sentences, NULL);
+ std::vector<Ptr<Vocab>> vocabs;
+ auto fb = data::CorpusBatch::fakeBatch(fakeLength, vocabs, num_seen_sentences, NULL);
fb->front()->setWords(num_seen_words);
- scheduler_->update(cost, fb);
+
+ scheduler_->update(loss, fb);
num_seen_words = 0;
num_seen_sentences = 0;
} else {
- scheduler_->update(cost, batch);
+ scheduler_->update(loss, batch);
}
- cost = 0;
+ loss.reset();
if(scheduler_->saving() || scheduler_->validating()) {
// Wait with validation or saving until all other threads are done with
@@ -325,7 +326,7 @@ void AsyncGraphGroup::load() {
setFn(i, data.begin() + begin, data.begin() + end);
}
});
- } else if(options_->has("pretrained-model")) {
+ } else if(options_->hasAndNotEmpty("pretrained-model")) {
std::string nameInit = options_->get<std::string>("pretrained-model");
LOG(info,
"Initialize model weights with the pre-trained model {}",
diff --git a/src/training/graph_group_async.h b/src/training/graph_group_async.h
index d3af0f22..dd5254bb 100755..100644
--- a/src/training/graph_group_async.h
+++ b/src/training/graph_group_async.h
@@ -16,7 +16,7 @@ public:
protected:
bool first_{true};
- std::vector<Ptr<models::ModelBase>> builders_;
+ std::vector<Ptr<models::ICriterionFunction>> builders_;
std::vector<Ptr<ExpressionGraph>> graphs_;
std::vector<DeviceId> devices_;
@@ -52,10 +52,10 @@ protected:
void execute(Ptr<data::Batch> batch);
public:
- AsyncGraphGroup(Ptr<Options> config);
+ AsyncGraphGroup(Ptr<Options> config, Ptr<IMPIWrapper> mpi);
void update(Ptr<data::Batch> batch) override {
- ABORT_IF(finalized_, "Training has already finished");
+ validate();
execute(batch);
}
@@ -63,8 +63,9 @@ public:
void save(bool final = false) override;
void save(Ptr<ExpressionGraph>, bool final = false);
- Ptr<data::BatchStats> collectStats() {
- return GraphGroup::collectStats(graphs_[0], builders_[0]);
+ // @TODO: give it a fake batch generator which own vocabs instead of passing vocabs
+ virtual Ptr<data::BatchStats> collectStats(const std::vector<Ptr<Vocab>>& vocabs) {
+ return GraphGroup::collectStats(graphs_[0], builders_[0], vocabs);
}
virtual void finalize() override;
diff --git a/src/training/graph_group_async_drop.cpp b/src/training/graph_group_async_drop.cpp
index 35d6f05d..35d6f05d 100755..100644
--- a/src/training/graph_group_async_drop.cpp
+++ b/src/training/graph_group_async_drop.cpp
diff --git a/src/training/graph_group_async_drop.h b/src/training/graph_group_async_drop.h
index 7d22208b..3e313440 100755..100644
--- a/src/training/graph_group_async_drop.h
+++ b/src/training/graph_group_async_drop.h
@@ -30,8 +30,8 @@ protected:
int device_id) override;
public:
- AsyncGraphGroupDrop(Ptr<Options> options)
- : AsyncGraphGroup(options),
+ AsyncGraphGroupDrop(Ptr<Options> options, Ptr<IMPIWrapper> mpi)
+ : AsyncGraphGroup(options, mpi),
dropping_warmup{options->get<size_t>("grad-dropping-warmup")},
droping_rate{options->get<float>("grad-dropping-rate")},
dropping_momentum{options->get<float>("grad-dropping-momentum")} {}
diff --git a/src/training/graph_group_multinode.cpp b/src/training/graph_group_multinode.cpp
index f5cbafd3..28c25641 100755..100644
--- a/src/training/graph_group_multinode.cpp
+++ b/src/training/graph_group_multinode.cpp
@@ -512,11 +512,11 @@ void MultiNodeGraphGroup::execute(Ptr<data::Batch> batch) {
auto task = [this](Ptr<data::Batch> batch) {
static size_t i = 0;
thread_local Ptr<ExpressionGraph> graph;
- thread_local Ptr<models::ModelBase> builder;
+ thread_local Ptr<models::ICriterionFunction> builder;
thread_local size_t my_id = 0;
thread_local size_t t = 0;
// only for scheduler statistic
- thread_local float cost = 0;
+ thread_local StaticLoss loss;
thread_local size_t num_seen_words = 0;
thread_local size_t num_seen_sentences = 0;
@@ -527,7 +527,7 @@ void MultiNodeGraphGroup::execute(Ptr<data::Batch> batch) {
builder = clientBuilders_[i++];
}
- auto costNode = builder->build(graph, batch);
+ auto lossNode = builder->build(graph, batch);
if(t == 0) {
mpi_->barrier();
@@ -537,7 +537,7 @@ void MultiNodeGraphGroup::execute(Ptr<data::Batch> batch) {
}
graph->forward();
- cost += costNode->scalar();
+ loss += *lossNode;
num_seen_words += batch->words();
num_seen_sentences += batch->size();
graph->backward();
@@ -592,22 +592,19 @@ void MultiNodeGraphGroup::execute(Ptr<data::Batch> batch) {
// Wait until the thread that wants to do validation is finished.
clientThreadPool_->wait_for_one(lock);
- if(options_->get<std::string>("cost-type") != "ce-sum")
- cost /= tau_;
-
if(tau_ > 1) {
std::vector<size_t> fakeLength = {1, 1};
- auto fb = data::CorpusBatch::fakeBatch(
- fakeLength, num_seen_sentences, NULL);
+ std::vector<Ptr<Vocab>> vocabs;
+ auto fb = data::CorpusBatch::fakeBatch(fakeLength, vocabs, num_seen_sentences, NULL);
fb->front()->setWords(num_seen_words);
- scheduler_->update(cost, fb);
+ scheduler_->update(loss, fb);
} else {
- scheduler_->update(cost, batch);
+ scheduler_->update(loss, batch);
}
num_seen_words = 0;
num_seen_sentences = 0;
- cost = 0;
+ loss.reset();
if((scheduler_->saving() || scheduler_->validating())) {
// Wait with validation or saving until all other threads are done with
diff --git a/src/training/graph_group_multinode.h b/src/training/graph_group_multinode.h
index c86225eb..340fdcbb 100755..100644
--- a/src/training/graph_group_multinode.h
+++ b/src/training/graph_group_multinode.h
@@ -351,10 +351,10 @@ public:
/**
* (Constructor) Call super class and initialize client graphs and builders.
*/
- MultiNodeGraphGroup(Ptr<Options> options)
- : Base(options),
+ MultiNodeGraphGroup(Ptr<Options> options, Ptr<IMPIWrapper> mpi)
+ : Base(options, mpi),
clientCommOverlap{options_->get<bool>("multi-node-overlap")},
- tau_{options_->get<size_t>("optimizer-delay")} { }
+ tau_{(size_t)options_->get<double>("optimizer-delay")} { }
/**
* (Destructor) Shut down server shard thread and (if comm. overlap enabled)
@@ -376,7 +376,7 @@ public:
* Update any client model with given batch if batch is assigned to this node.
*/
void update(Ptr<data::Batch> batch) override {
- ABORT_IF(finalized_, "Training has already finished");
+ validate();
// Only take batch assigned to this node
if(batchIter_ % mpi_->numMPIProcesses() == (size_t)mpi_->myMPIRank()) {
execute(batch);
@@ -397,7 +397,7 @@ public:
size_t i = 0;
for(auto graph : clientGraphs_)
clientBuilders_[i++]->load(graph, name);
- } else if(options_->has("pretrained-model")) {
+ } else if(options_->hasAndNotEmpty("pretrained-model")) {
std::string init = options_->get<std::string>("pretrained-model");
LOG(info,
"Initialize model weights with the pre-trained model {}",
@@ -454,8 +454,8 @@ public:
/**
* Collect statistics from first client's graph.
*/
- Ptr<data::BatchStats> collectStats() {
- return GraphGroup::collectStats(clientGraphs_[0], clientBuilders_[0]);
+ Ptr<data::BatchStats> collectStats(const std::vector<Ptr<Vocab>>& vocabs) {
+ return GraphGroup::collectStats(clientGraphs_[0], clientBuilders_[0], vocabs);
}
};
} // namespace marian
diff --git a/src/training/graph_group_multinode_sync.cpp b/src/training/graph_group_multinode_sync.cpp
index 32865705..904d614f 100755..100644
--- a/src/training/graph_group_multinode_sync.cpp
+++ b/src/training/graph_group_multinode_sync.cpp
@@ -175,7 +175,7 @@ void MultiNodeGraphGroupSync::execute(Ptr<data::Batch> fullBatch) {
static int t = 0;
- static float cost = 0;
+ static StaticLoss loss;
static size_t num_seen_words = 0;
static size_t num_seen_sentences = 0;
@@ -185,7 +185,7 @@ void MultiNodeGraphGroupSync::execute(Ptr<data::Batch> fullBatch) {
auto graph = clientGraphs_[my_id];
auto builder = clientBuilders_[my_id];
- auto costNode = builder->build(graph, batch);
+ auto lossNode = builder->build(graph, batch);
if(t == 0) {
if(my_id != 0)
@@ -195,7 +195,7 @@ void MultiNodeGraphGroupSync::execute(Ptr<data::Batch> fullBatch) {
graph->forward();
{
std::lock_guard<std::mutex> guard(sumCostMutex_);
- cost += costNode->scalar();
+ loss += *lossNode;
num_seen_words += batch->words();
num_seen_sentences += batch->size();
}
@@ -219,22 +219,18 @@ void MultiNodeGraphGroupSync::execute(Ptr<data::Batch> fullBatch) {
// Run scheduler (if enabled)
if(t % tau_ == 0 && scheduler_) {
- if(options_->get<std::string>("cost-type") != "ce-sum")
- cost /= (tau_ * devices_.size());
-
if(tau_ > 1) {
std::vector<size_t> fakeLength = {1, 1};
- auto fb
- = data::CorpusBatch::fakeBatch(fakeLength, num_seen_sentences, NULL);
+ auto fb = data::CorpusBatch::fakeBatch(fakeLength, std::vector<Ptr<Vocab>>(), num_seen_sentences, NULL);
fb->front()->setWords(num_seen_words);
- scheduler_->update(cost, fb);
+ scheduler_->update(loss, fb);
} else {
- scheduler_->update(cost, fullBatch);
+ scheduler_->update(loss, fullBatch);
}
num_seen_words = 0;
num_seen_sentences = 0;
- cost = 0;
+ loss.reset();
if((scheduler_->saving() || scheduler_->validating())) {
// wait until other nodes are ready
diff --git a/src/training/graph_group_multinode_sync.h b/src/training/graph_group_multinode_sync.h
index bfef2050..0a089d7b 100755..100644
--- a/src/training/graph_group_multinode_sync.h
+++ b/src/training/graph_group_multinode_sync.h
@@ -63,7 +63,6 @@ private:
Tensor paramsAvg_;
std::vector<float> accGradientsSync_cpu;
std::vector<float> receiveBuffer_cpu;
- bool synchronization_happened{false};
Ptr<OptimizerBase> syncOptimizer_;
@@ -131,9 +130,9 @@ public:
/**
* (Constructor) Call super class and initialize client graphs and builders.
*/
- MultiNodeGraphGroupSync(Ptr<Options> options)
- : Base(options),
- tau_{options_->get<size_t>("optimizer-delay")},
+ MultiNodeGraphGroupSync(Ptr<Options> options, Ptr<IMPIWrapper> mpi)
+ : Base(options, mpi),
+ tau_{(size_t)options_->get<double>("optimizer-delay")},
syncOptimizer_{Optimizer(options_)},
movingAvg_{options_->get<float>("exponential-smoothing") > 0},
mvDecay_{options_->get<float>("exponential-smoothing")} {
@@ -143,7 +142,7 @@ public:
* Update any client model with given batch if batch is assigned to this node.
*/
void update(Ptr<data::Batch> batch) override {
- ABORT_IF(finalized_, "Training has already finished");
+ validate();
if(batchIter_ % mpi_->numMPIProcesses() == mpi_->myMPIRank()) { // Only take batch assigned to this node
execute(batch);
}
@@ -163,7 +162,7 @@ public:
size_t i = 0;
for(auto graph : clientGraphs_)
clientBuilders_[i++]->load(graph, name);
- } else if(options_->has("pretrained-model")) {
+ } else if(options_->hasAndNotEmpty("pretrained-model")) {
std::string init = options_->get<std::string>("pretrained-model");
LOG(info,
"Initialize model weights with the pre-trained model {}",
@@ -220,9 +219,9 @@ public:
/**
* Collect statistics from first client's graph.
*/
- Ptr<data::BatchStats> collectStats() {
+ Ptr<data::BatchStats> collectStats(const std::vector<Ptr<Vocab>>& vocabs) {
return GraphGroup::collectStats(
- clientGraphs_[0], clientBuilders_[0], devices_.size());
+ clientGraphs_[0], clientBuilders_[0], vocabs, (double)devices_.size());
}
};
} // namespace marian
diff --git a/src/training/graph_group_singleton.cpp b/src/training/graph_group_singleton.cpp
index ac6ef5d0..21083408 100755..100644
--- a/src/training/graph_group_singleton.cpp
+++ b/src/training/graph_group_singleton.cpp
@@ -10,10 +10,8 @@ void SingletonGraph::setScheduler(Ptr<Scheduler> scheduler) {
}
void SingletonGraph::execute(Ptr<data::Batch> batch) {
- auto costNode = builder_->build(graph_, batch);
-
+ auto lossNode = builder_->build(graph_, batch);
graph_->forward();
- float cost = costNode->scalar();
graph_->backward();
// Get batch stats
@@ -34,7 +32,7 @@ void SingletonGraph::execute(Ptr<data::Batch> batch) {
}
if(scheduler_) {
- scheduler_->update(cost, batch);
+ scheduler_->update(*lossNode, batch);
if(scheduler_->validating()) {
if(mvAvg_) {
diff --git a/src/training/graph_group_singleton.h b/src/training/graph_group_singleton.h
index 1eb589d9..8a93af71 100755..100644
--- a/src/training/graph_group_singleton.h
+++ b/src/training/graph_group_singleton.h
@@ -16,16 +16,17 @@ public:
virtual void setScheduler(Ptr<Scheduler> scheduler) override;
private:
- Ptr<models::ModelBase> builder_;
+ Ptr<models::ICriterionFunction> builder_;
Ptr<ExpressionGraph> graph_;
Ptr<ExpressionGraph> graphAvg_;
void execute(Ptr<data::Batch> batch);
public:
- SingletonGraph(Ptr<Options> config)
+ SingletonGraph(Ptr<Options> config, Ptr<IMPIWrapper> mpi)
: GraphGroup(config),
- ExponentialSmoothing(options_->get<float>("exponential-smoothing")) {
+ ExponentialSmoothing(config) {
+ ABORT_IF(mpi->numMPIProcesses() != 1, "SingletonGraph does not support multiple MPI processes");
// Get device ID
auto devices = Config::getDevices(options_);
ABORT_IF(devices.size() != 1, "Only one device ID should be provided for singleton training");
@@ -33,14 +34,15 @@ public:
// Initialize graph
graph_ = New<ExpressionGraph>();
graph_->setDevice(deviceId);
+ graph_->setCheckpointing(options_->get<bool>("gradient-checkpointing"));
graph_->getBackend()->setClip(options_->get<float>("clip-gemm"));
graph_->reserveWorkspaceMB(options_->get<size_t>("workspace"));
opt_ = Optimizer(options_);
- builder_ = models::from_options(options_, models::usage::training);
+ builder_ = models::createCriterionFunctionFromOptions(options_, models::usage::training);
}
void update(Ptr<data::Batch> batch) override {
- ABORT_IF(finalized_, "Training has already finished");
+ validate();
execute(batch);
}
@@ -69,7 +71,7 @@ public:
/*scatterStateFn=*/[&](const std::vector<float>& data, const OptimizerBase::ScatterStateSetFunc& setFn) {
setFn(/*localDeviceIndex=*/0, data.begin(), data.end());
});
- } else if(options_->has("pretrained-model")) {
+ } else if(options_->hasAndNotEmpty("pretrained-model")) {
std::string init = options_->get<std::string>("pretrained-model");
LOG(info,
"Initialize model weights with the pre-trained model {}",
@@ -125,8 +127,8 @@ public:
});
}
- Ptr<data::BatchStats> collectStats() {
- return GraphGroup::collectStats(graph_, builder_);
+ Ptr<data::BatchStats> collectStats(const std::vector<Ptr<Vocab>>& vocabs) {
+ return GraphGroup::collectStats(graph_, builder_, vocabs);
}
virtual void finalize() override { finalized_ = true; }
diff --git a/src/training/graph_group_sync.cpp b/src/training/graph_group_sync.cpp
index 6699484c..1b3c16de 100755
--- a/src/training/graph_group_sync.cpp
+++ b/src/training/graph_group_sync.cpp
@@ -2,32 +2,37 @@
namespace marian {
-SyncGraphGroup::SyncGraphGroup(Ptr<Options> config)
- : GraphGroup(config),
- ExponentialSmoothing{options_->get<float>("exponential-smoothing")},
- delay_{options_->get<size_t>("optimizer-delay")} { // @TODO: rename to something else; delay means delayed updated, not accumulation
-
- mpi_ = initMPI(/*multiThreaded=*/false); // when not running under MPI, this will be a fake object that represents a one-MPI-process setup
+SyncGraphGroup::SyncGraphGroup(Ptr<Options> config, Ptr<IMPIWrapper> mpi)
+ : GraphGroup(config), ExponentialSmoothing(config),
+ delay_{options_->get<double>("optimizer-delay")}, mpi_(mpi) { // @TODO: rename delay_ to something else; delay means delayed updated, not accumulation
devices_ = Config::getDevices(options_, mpi_->myMPIRank(), mpi_->numMPIProcesses());
for(auto device : devices_) {
auto graph = New<ExpressionGraph>();
graph->setDevice(device);
+ graph->setCheckpointing(options_->get<bool>("gradient-checkpointing"));
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
graphs_.push_back(graph);
shardOpt_.push_back(Optimizer(options_));
- builders_.push_back(models::from_options(options_, models::usage::training));
+ builders_.push_back(models::createCriterionFunctionFromOptions(options_, models::usage::training));
}
// Note: We may well end up with only one MPI process or only one graph per worker.
// This part of the code will not special-case any of this here.
// Rather, it is assumed that the communicator knows to reduce unnecessary transfers to no-ops.
comm_ = createCommunicator(graphs_, /*noNccl=*/options_->get<bool>("no-nccl", false), /*mpi=*/mpi_);
+
+ auto formattedDeviceType = utils::utf8ToUpper(devices_.front().typeAsString()) + "s";
+ if (mpi_->numMPIProcesses() > 1)
+ LOG(info, "[training] Using {} {}, distributed over {} MPI processes", mpi_->numMPIProcesses() * devices_.size(), formattedDeviceType, mpi_->numMPIProcesses());
+ else
+ LOG(info, "[training] Using {} {}", devices_.size(), formattedDeviceType);
}
void SyncGraphGroup::setScheduler(Ptr<Scheduler> scheduler) /*override*/ {
+ validate();
scheduler_ = scheduler;
// optimizer has to be registered last to see changes of learning rate
// @TODO: ^^Fix this comment. Either it refers to the scheduler, or it should be moved. Which one?
@@ -38,38 +43,39 @@ void SyncGraphGroup::setScheduler(Ptr<Scheduler> scheduler) /*override*/ {
}
void SyncGraphGroup::initialize(const Ptr<data::Batch>& exampleBatch) {
- // Initialize 0th graph with random weights in one forward step
- // @TODO: Why do we need the THREAD_GUARD here? Why not run this on the main thread?
- THREAD_GUARD({
- builders_[0]->build(graphs_[0], exampleBatch);
- graphs_[0]->forward();
+ // Initialize graphs with random weights in one forward step
+ // Also allocate and clear the gradients
+ comm_->foreach([&](size_t i, size_t /*begin*/, size_t /*end*/) {
+ builders_[i]->build(graphs_[i], exampleBatch);
+ graphs_[i]->forward();
+ graphs_[i]->params()->allocateBackward();
+ graphs_[i]->params()->set_zero_adjoint();
});
- // Copy weights from 0th graph to all other graphs
+ // Copy weights from 0-th graph to all other graphs
// to have equal weights across devices
- ThreadPool pool(graphs_.size() - 1, graphs_.size() - 1);
- for(size_t i = 1; i < graphs_.size(); ++i) {
- auto init = [&](size_t i) {
- // initialize t-th graph and weights
- builders_[i]->build(graphs_[i], exampleBatch);
- graphs_[i]->forward();
- // overwrite weights of t-th graph with weights from 0th graph
+ comm_->foreach([&](size_t i, size_t /*begin*/, size_t /*end*/) {
+ if (i > 0)
graphs_[i]->params()->vals()->copyFrom(graphs_[0]->params()->vals());
- };
- pool.enqueue(init, i);
- }
- // ThreadPool destructor waits until completion of all tasks.
- // @TODO: can we use comm_->foreach()?
+ });
}
void SyncGraphGroup::initializeAvg() {
Ptr<ExpressionGraph> graphAvg; // CPU-side temp
std::string name = options_->get<std::string>("model");
- if(filesystem::exists(name + ".orig.npz")) {
+ std::string suffix = name.substr(name.size() - 4);
+ ABORT_IF(suffix != ".npz" && suffix != ".bin", "Unknown model suffix {}", suffix);
+
+ if(filesystem::exists(name + ".orig" + suffix)) {
// Load the averaged parameters into a temporary graph
graphAvg = New<ExpressionGraph>();
graphAvg->setDevice({0, DeviceType::cpu});
- graphAvg->load(name, false);
+
+ // load model through builder to activate model specific loading functions.
+ // This is important if a model is overloading Model::load(...) and e.g.
+ // mapping matrix names as in Amun.h
+ auto builder = models::createCriterionFunctionFromOptions(options_, models::usage::training);
+ builder->load(graphAvg, name, false);
graphAvg->forward(); // initialize parameters if needed
}
@@ -100,32 +106,201 @@ void SyncGraphGroup::initializeAvg() {
comm_->foreach(init, /*parallel=*/false); // @TODO: is sequential operation necessary here? (is the allocation stuff sufficiently reentrant or thread-separated?)
}
-Ptr<data::BatchStats> SyncGraphGroup::collectStats() {
- // @TODO: This should only run on MPI process 0. Also we can share vv this vv expression with update().
- size_t multiplier = devices_.size() * mpi_->numMPIProcesses() * delay_;
- return GraphGroup::collectStats(graphs_[0], builders_[0], multiplier);
+Ptr<data::BatchStats> SyncGraphGroup::collectStats(const std::vector<Ptr<Vocab>>& vocabs) {
+ // This function determines the granularity in which the reader provides data.
+ // If no mini-batch-fit, then user provides a constant number. It reads that much. We won't get into this function.
+ // If mini-batch-fit, then we get here and set miniBatchFitMultiplier_. Then...
+ // If dynamic MB scaling, then we want fine-grained minibatches of the size of one GPU.
+ // If not, we prefer a single large batch that can be split into equal-size parts over GPUs,
+ // so that we have perfect load balancing and read precisely as much as we need (no waste).
+ double multiplier = devices_.size() * mpi_->numMPIProcesses() * delay_;
+ bool isDynamic = scheduler_->isDynamicMBSizeScaling();
+ double readerMultiplier = isDynamic ? 1. : multiplier; // multiplier applied already by reader
+ updateMultiplier_ = isDynamic ? multiplier : 1.; // multiplier applied later in update()
+ return GraphGroup::collectStats(graphs_[0], builders_[0], vocabs, readerMultiplier);
+}
+
+// helper for MB scaling: quantize the ratio with a given error margin
+static double roundUpRatio(double ratio) {
+ if (ratio == 0)
+ return ratio;
+ // find largest power of two that fits into ratio
+ double p = 1;
+ while (p*2 < ratio)
+ p *= 2;
+ // round up to nearest multiple of a largest power of 2 where relative error is within margin
+ // 25% error margin seems acceptable:
+ // - using a 25% larger MB size should not break convergence
+ // - @TODO: not using the first 25% of the next block is OK since those are dominated by data exchange
+ double maxError = 0.25;
+ while (p >= 1) {
+ double proposedRatio = ceil(ratio / p) * p;
+ double error = (proposedRatio - ratio) / ratio;
+ if (fabs(error) <= maxError)
+ return proposedRatio;
+ p /= 2;
+ }
+ return ratio;
+}
+
+// helper routine that handles accumulation and load-balancing of sub-batches to fill all devices
+// It adds 'newBatch' to 'pendingBatches_', and if sufficient batches have been queued, then
+// returns 'pendingBatches_' in 'subBatches' and resets it. If not, it returns false.
+bool SyncGraphGroup::tryGetSubBatches(Ptr<data::Batch> newBatch,
+ std::vector<Ptr<data::Batch>>& subBatches, size_t& numReadBatches) {
+ // The reader delivers in chunks of these sizes, according to case:
+ // - no dynamic MB-size scaling:
+ // - reader batch size = update batch size, with...
+ // - mini-batch-fit:
+ // - update batch size = what fits into all GPUs, times decay_ to allow experimenting with fractional sizes
+ // - no mini-batch-fit:
+ // - update batch size = user-specified size (user guarantees that it fits if distributed over delay_ GPUs)
+ // - dynamic MB-size scaling:
+ // - update batch size = aggregate reader batch size * (dynamic progress-based ratio * reference adjustment), with...
+ // - mini-batch-fit:
+ // - aggregate reader batch size = equal to what fits into one GPU * warpSize * delay_
+ // - no mini-batch-fit:
+ // - aggregate reader batch size = user-specified size (user guarantees that it fits if distributed over delay_ GPUs)
+ // - reference adjustment =
+ // - reference batch size specified: (reference batch size / typical aggregate reader batch size)
+ // - no ref size specified: 1
+
+ size_t warpSize = devices_.size() * mpi_->numMPIProcesses(); // warp := set of batches processed concurrently across GPus and workers
+
+ // if not dynamic then return the big batch, but first split it over GPUs as it may be too large
+ if (!scheduler_->isDynamicMBSizeScaling()) {
+ // If mini-batch-fit, then the read batch is (devices_.size() * mpi_->numMPIProcesses() * delay_)
+ // times what fits one GPU. If not mini-batch-fit, it is whatever the user has specified, which
+ // is the user's responsibility to guarantee that it fits into 'delay_' warps.
+ // Distribute evenly over all GPUs we have, using multiple warps if needed.
+ size_t numWarps = (size_t)ceil(delay_);
+ subBatches = newBatch->split(numWarps * warpSize);
+ numReadBatches = 1;
+ return true;
+ }
+ LOG_ONCE(info, "[training] Dynamic mini-batch scaling enabled");
+
+ // if dynamic and mini-batch-fit, then we get batches in the size of what fits into one GPU
+ pendingBatches_.push_back(newBatch);
+
+ // what ratio (how many batches in reader's batch size) do we want, based on current training progress schedule?
+ double ratio = scheduler_->getDynamicMBSizeMultiplier();
+
+ // relative to what base? (what does ratio == 1 mean)
+ ratio *= updateMultiplier_; // if mini-batch-fit, this is = warpSize * delay_, otherwise 1
+
+ // If a reference is given, then at progress == mbWarmup.n (ratio=1), we would like to have refBatchLabels instead of whichever
+ // the actual batch size is. Since we cannot know the future actual batch sizes that will be delivered
+ // by the reader, we approximate them with (typicalTrgBatchWords * updateMultiplier), and scale ratio accordingly.
+ auto refBatchLabels = options_->get<size_t>("mini-batch-words-ref");
+ if (refBatchLabels != 0) {
+ LOG_ONCE(info, "[scheduler] Scaling to {} reference labels, using actual-batch-word estimate of {}", refBatchLabels, typicalTrgBatchWords_);
+ ABORT_IF(typicalTrgBatchWords_ == 0, "Dynamic scaling with words target requires MB size to be known in words"); // happens if MB size is specified in sentences
+ ratio *= (double)refBatchLabels / (double)(typicalTrgBatchWords_ * updateMultiplier_);
+ }
+
+ // round up to full batches if within a certain error margin --@BUGBUG: Not invariant w.r.t. GPU size, as ratio is relative to what fits into 1 GPU
+ ratio = roundUpRatio(ratio);
+
+ if (pendingBatches_.size() < ratio)
+ return false; // not enough data yet
+
+ // now we have enough to fill at least 'ratio' batches
+ // @BUGBUG: We do not handle the case that fixed MB size * ratio exceeds GPU memory (we'd need to split that).
+
+ numReadBatches = pendingBatches_.size(); // remember original batch-counter increment from reader (which is not always the same as subBatches.size() in the end)
+
+ // in fact, we got too much, so make up for it by shortening all batches to accurately reflect desired ratio
+ // e.g. ratio = 3.3 for 4 batches -> Reduce each by 3.3/4
+ // Alternatively, we could just shorten the last 'warp', but that would not be invariant to warp size.
+ for (auto& batch : pendingBatches_) {
+ auto reducedBatchSize = (size_t)ceil((double)batch->size() * ratio / (double)pendingBatches_.size());
+ size_t minSize = 1;
+ if (pendingBatches_.size() == 1) { // enforce a minimum (only needed/correct if still in first batch)
+ size_t minTrgWords = 256; // don't go below this number of target words, as it seems excessive --@TODO: parameterize?
+ minSize = 1 + (minTrgWords * batch->size() - 1) / batch->wordsTrg(); // approximately convert minTrgWords into a #sentences
+ }
+ reducedBatchSize = std::max(reducedBatchSize, minSize);
+ if (reducedBatchSize < batch->size())
+ batch = batch->split(/*numSubBatches=*/1, reducedBatchSize).front();
+ }
+
+ // load-balance: distribute the last numWarps-group's batches over GPUs
+ // This is tricky since batches do not have the same length, therefore we can only split, but not merge.
+ auto numWarps = (pendingBatches_.size() - 1) / warpSize + 1; // = ceil(#buffers / (#GPUs * #workers))
+ auto availableDevices = numWarps * warpSize; // we will run this many GPUs: better use them all
+ if (pendingBatches_.size() < availableDevices) {
+ // last warp does not use all available GPUs: try to re-balance
+ auto fullWarpsBatches = (numWarps - 1) * warpSize; // number of batches in all but the last warp. Those warps that are fully used.
+ auto lastWarpSize = pendingBatches_.size() - fullWarpsBatches; // the last warp is possibly not fully used
+ //LOG(info, "attempting to redistribute last {} batches over {} devices", lastWarpSize, warpSize);
+ auto splitInto = warpSize / lastWarpSize;
+ if (splitInto > 1) { // unfortunately we can only split in integer ratios
+ // split each of last numWarps's batches into 'splitInto' batches
+ // pop them first
+ std::vector<Ptr<data::Batch>> batchesToSplit;
+ while (pendingBatches_.size() > fullWarpsBatches) {
+ batchesToSplit.push_back(pendingBatches_.back());
+ pendingBatches_.pop_back();
+ }
+ // now split them and push them back
+ for (auto& batchToSplit : batchesToSplit) {
+ //LOG(info, "{}-way splitting batchToSplit with size {}", splitInto, batchToSplit->size());
+ auto splitBatches = batchToSplit->split(splitInto);
+ for (auto& splitBatch : splitBatches) {
+ //LOG(info, " -> getting batchToSplit with size {}", splitBatch->size());
+ pendingBatches_.push_back(splitBatch);
+ }
+ }
+ }
+ ABORT_IF(pendingBatches_.size() > availableDevices, "somehow split into too many batches??");
+ }
+ subBatches = std::move(pendingBatches_);
+
+ // @TODO: sort by width, so that in case of delay > 1, each GPU gets about the same size
+ return true;
}
-void SyncGraphGroup::update(Ptr<data::Batch> batch) /*override*/ {
- ABORT_IF(finalized_, "Training has already finished");
+void SyncGraphGroup::update(Ptr<data::Batch> newBatch) /*override*/ {
+ validate();
+
+ std::vector<Ptr<data::Batch>> subBatches;
+ size_t numReadBatches; // actual #batches delivered by reader, for restoring from checkpoint --@TODO: reader should checkpoint itself; should not go via the scheduler
+ bool gotSubBatches = tryGetSubBatches(newBatch, subBatches, numReadBatches);
- // distribute the batch over (delay, local device, MPI rank)
- size_t numSubBatches = delay_ * devices_.size() * mpi_->numMPIProcesses();
- auto subBatches = batch->split(numSubBatches);
- subBatches.resize(numSubBatches); // pad with nullptrs if out of data
+ // not enough data yet: return right away
+ if (!gotSubBatches)
+ return;
+
+ update(subBatches, numReadBatches);
+}
+
+void SyncGraphGroup::update(std::vector<Ptr<data::Batch>> subBatches, size_t numReadBatches) {
+ // determine num words for dynamic hyper-parameter adjustment
+ // @TODO: We can return these directly from tryGetSubBatches()
+ size_t batchSize = 0;
+ size_t batchTrgWords = 0;
+ for (const auto& batch : subBatches) {
+ batchSize += batch->size();
+ batchTrgWords += batch->wordsTrg();
+ }
// Helper to access the subBatches array
- auto getSubBatch = [&](size_t t, size_t localDeviceIndex, size_t rank) {
- // 't' (the delay) should be slowest changing dimension. If subBatches are sorted by
+ auto getSubBatch = [&](size_t warp, size_t localDeviceIndex, size_t rank) -> Ptr<data::Batch> {
+ // Warp should be slowest changing dimension. If subBatches are sorted by
// length, then grouping sentences of similar length into the same delay step can
// reduce unnecessary time spent in padding.
- return subBatches[(t * mpi_->numMPIProcesses() + rank) * devices_.size() + localDeviceIndex];
+ auto index = (warp * mpi_->numMPIProcesses() + rank) * devices_.size() + localDeviceIndex;
+ if (index < subBatches.size())
+ return subBatches[index];
+ else
+ return nullptr; // null if we reached beyond the end
};
// Upon very first execution, reset everything
if(first_) {
- LOG(debug, "[{}] Processing first minibatch. Batches are processed as {} processes x {} GPUs/process x {} delay steps",
- mpi_->idStr(), mpi_->numMPIProcesses(), devices_.size(), delay_);
+ LOG(info, "[training] Batches are processed as {} process(es) x {} devices/process",
+ mpi_->numMPIProcesses(), devices_.size());
initialize(subBatches.front());
if(mvAvg_ && paramsAvg_.empty())
initializeAvg();
@@ -133,34 +308,26 @@ void SyncGraphGroup::update(Ptr<data::Batch> batch) /*override*/ {
}
// Compute gradients
- // This happens in multiple steps in case of delay_ > 1.
- std::vector<float> localDeviceCosts(devices_.size(), 0.f); // [local device index] aggregate cost for each local device
- for (size_t t = 0; t < delay_; t++) {
- // Execute single forward/backward step
- auto forwardBackward = [&](size_t localDeviceIndex, size_t /*begin*/, size_t /*end*/) {
- auto graph = graphs_[localDeviceIndex];
- auto subBatch = getSubBatch(t, localDeviceIndex, mpi_->myMPIRank());
-
- if(subBatch) {
- timer::Timer timer;
- auto costNode = builders_[localDeviceIndex]->build(graph, subBatch);
- //LOG(info, timer.format(2, "after build: %ws"));
- graph->forward();
- //LOG(info, timer.format(2, "after forward (no sync): %ws"));
- localDeviceCosts[localDeviceIndex] += costNode->scalar();
- graph->backward(/*zero=*/t == 0); // only reset gradients to 0 if t = 0
- //LOG(info, timer.format(2, "after backward (no sync): %ws"));
- //localDeviceCosts[localDeviceIndex] += costNode->scalar(); // moved here for time measurements; @TODO: move this back
- //LOG(info, timer.format(2, "after scalar() (that's a sync): %ws"));
- }
- else { // empty batch: execute do-nothing fw-bw step for proper inits and resets
- graph->forward();
- graph->backward(/*zero=*/t == 0);
- }
- };
-
- comm_->foreach(forwardBackward); // compute gradients in parallel on each device. Aggregate if delay_ > 1.
- }
+ std::vector<StaticLoss> localDeviceLosses(devices_.size()); // [local device index] aggregate cost for each local device
+ comm_->foreach([&](size_t localDeviceIndex, size_t /*begin*/, size_t /*end*/) { // parallel across devices. Aggregate for warp > 1.
+ auto graph = graphs_[localDeviceIndex];
+ // reset gradient --presently done outside
+ //graph->params()->allocateBackward();
+ //graph->params()->set_zero_adjoint();
+ // This happens in multiple steps if there are more subbatches than devices.
+ for (size_t warp = 0; ; warp++) {
+ // Execute single forward/backward step
+ auto subBatch = getSubBatch(warp, localDeviceIndex, mpi_->myMPIRank());
+ if (!subBatch)
+ break;
+
+ auto rationalLoss = builders_[localDeviceIndex]->build(graph, subBatch);
+ graph->forward();
+
+ localDeviceLosses[localDeviceIndex] += *rationalLoss;
+ graph->backward(/*zero=*/false); // (gradients are reset before we get here)
+ }
+ });
// At this point, each device on each MPI process has a gradient aggregated over a subset of the sub-batches.
// Update parameter shard with gradient shard
@@ -168,45 +335,37 @@ void SyncGraphGroup::update(Ptr<data::Batch> batch) /*override*/ {
auto curGrad = graphs_[idx]->params()->grads()->subtensor(begin, end-begin);
auto curParam = graphs_[idx]->params()->vals()->subtensor(begin, end-begin);
- // if individual gradients were averages, then need to average again over all subBatches
- auto div = subBatches.size();
- if (options_->get<std::string>("cost-type") == "ce-sum")
- div = 1;
- if(div != 1) {
- using namespace functional;
- Element(_1 = _1 / (float)div, curGrad);
- }
-
// actual model update
- shardOpt_[idx]->update(curParam, curGrad);
+ auto updateTrgWords =
+ /*if*/(options_->get<std::string>("cost-type") == "ce-sum") ?
+ batchTrgWords
+ /*else*/:
+ OptimizerBase::mbSizeNotProvided;
+ shardOpt_[idx]->update(curParam, curGrad, updateTrgWords);
+ curGrad->set(0.f);
if(mvAvg_)
updateAvgParams(
- paramsAvg_[idx], curParam, scheduler_->numberOfBatches());
+ paramsAvg_[idx], curParam, scheduler_->numberOfBatches(), updateTrgWords);
};
- timer::Timer timer;
- comm_->scatterReduce(); // reduce gradients across all devices (globally) into shards
- //LOG(info, timer.format(2, "after scatterReduce (has sync): %ws"));
- comm_->foreach(update); // per-shard model-update
- //LOG(info, timer.format(2, "after model update (no sync): %ws"));
- //graphs_.front()->getBackend()->synchronize(); // @TODO: This is strictly for time measurement. Make sure it doesn't accidentally stay in here!!
- //LOG(info, timer.format(2, "after model update sync (which is unnecessary except for time measurements): %ws"));
- comm_->allGather(); // distribute param value shards back
- //LOG(info, timer.format(2, "after allGather (has sync): %ws"));
-
// cost across all local devices (scheduler will aggregate cross-process)
- float localCost = 0;
- for(auto& c : localDeviceCosts) // localDeviceCosts is already summed up over delay steps
- localCost += c;
-
- // if localCost is average-based, we need to turn the sum over devices into an average as well
- if(options_->get<std::string>("cost-type") != "ce-sum")
- localCost /= numSubBatches;
+ StaticLoss localLoss;
+ for(auto& l : localDeviceLosses) // localDeviceLosses is already summed up over delay steps
+ localLoss += l;
+
+ // model update
+ if (std::isfinite(localLoss.loss) || mpi_->numMPIProcesses() > 1) { // guard against NaN (except with MPI, as this simple way could hang it)
+ comm_->scatterReduceAndResetGrads(); // reduce gradients across all devices and MPI nodes into shards
+ comm_->foreach(update); // per-shard model-update
+ comm_->allGatherParams(); // distribute param value shards back
+ }
+ else
+ LOG(info, "[training] skipping {}-th update due to loss being {}", scheduler_->numberOfBatches(), localLoss.loss);
if(scheduler_) {
- // track and log localCost
- scheduler_->update(localCost, subBatches, mpi_);
+ // track and log localLoss
+ scheduler_->update(localLoss, numReadBatches, batchSize, batchTrgWords, mpi_);
// save intermediate model (and optimizer state) to file
if(scheduler_->saving())
@@ -224,6 +383,7 @@ void SyncGraphGroup::update(Ptr<data::Batch> batch) /*override*/ {
}
void SyncGraphGroup::load() /*override*/ {
+ validate();
// This function loads the main parameters in the graphs.
// In case of exponential smoothing, we also need to restore paramsAvg_.
@@ -237,9 +397,12 @@ void SyncGraphGroup::load() /*override*/ {
scheduler_->load(name);
std::string nameGraph = name;
- if(mvAvg_ && filesystem::exists(name + ".orig.npz"))
+ std::string suffix = name.substr(name.size() - 4);
+ ABORT_IF(suffix != ".npz" && suffix != ".bin", "Unknown model suffix {}", suffix);
+
+ if(mvAvg_ && filesystem::exists(name + ".orig" + suffix))
// Load the original parameters from model.npz.orig.npz
- nameGraph += ".orig.npz";
+ nameGraph += ".orig" + suffix;
size_t i = 0;
for(auto graph : graphs_)
@@ -249,14 +412,15 @@ void SyncGraphGroup::load() /*override*/ {
std::vector<Ptr<Backend>> backends;
for(auto graph : graphs_)
backends.push_back(graph->getBackend());
- shardOpt_[0]->load(name + ".optimizer.npz", shardOpt_, backends,
+ shardOpt_[0]->load(name + ".optimizer.npz", shardOpt_, backends, // keep npz suffix for optimize checkpoint
[&](const std::vector<float>& optimizerStateVector, const OptimizerBase::ScatterStateSetFunc& setShardFn) {
comm_->scatterState(optimizerStateVector, setShardFn);
});
- } else if(options_->has("pretrained-model")) {
+ LOG(info, "[training] Model reloaded from {}", name);
+ } else if(options_->hasAndNotEmpty("pretrained-model")) {
std::string nameInit = options_->get<std::string>("pretrained-model");
LOG(info,
- "Initialize model weights with the pre-trained model {}",
+ "[training] Initializing model weights with the pre-trained model {}",
nameInit);
size_t i = 0;
@@ -267,86 +431,72 @@ void SyncGraphGroup::load() /*override*/ {
}
void SyncGraphGroup::save(bool final) /*override*/ {
+ // validate(); @TODO: get rid of this everywhere (SyncGraphGroup)
barrier(); // (for better grouping of log messages)
- //LOG(info, "[{}] save() line {}!", this->mpi_->idStr(), __LINE__);
// do final validation
if(final && scheduler_) {
// bring the smoothed model in
// Note that it is sharded. For multi-node, it is sharded over multiple machines, so this is a network access.
// Also note that the swap must run on all MPI processes concurrently, although only one actually validates.
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
swapParamsAvg();
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
if (isMainProcess()) // in multi-node, only first MPI process saves the model (they are all identical)
scheduler_->validate(graphs_, true);
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
swapParamsAvg();
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
}
+ // @TODO: put all this in one place, in new branch this is already localized in one place and class, this is a quick hack which will be
+ // done better after the next merge. Not doing this in other graph_groups as this would only make the merge harder.
+ // Determine model suffix *.npz or *.bin, then use the same suffix for all following models saved.
std::string name = options_->get<std::string>("model");
+ std::string suffix = name.substr(name.size() - 4);
+ ABORT_IF(suffix != ".npz" && suffix != ".bin", "Unknown model suffix {}", suffix);
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
barrier(); // (for better grouping of log messages)
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
// if smoothing then save original (unsmoothed) parameters as well
- // @TODO: Check whether we are reloading the correct file (the unsmoothed one).
if(mvAvg_ && paramsAvg_.size() > 0 && isMainProcess()) // only save from one MPI process
// Save the original parameters in model.npz.orig.npz
- builders_[0]->save(graphs_[0], name + ".orig.npz", true);
+ builders_[0]->save(graphs_[0], name + ".orig" + suffix, true);
// Temporarily switch to the averaged parameters
// Note: the smoothed model is sharded across GPUs, and across MPI processes if applicable. This brings it into MPI process[*].device[*]
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
swapParamsAvg();
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
// save main model file
if (isMainProcess()) { // only save from one MPI process
// if not overwrite then save a copy with number of updates in the model pathname
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
if(!options_->get<bool>("overwrite") && !final) {
std::string numberOfBatches
= scheduler_ ? std::to_string(scheduler_->numberOfBatches())
: "unknown";
std::string nameOverwrite = name;
- nameOverwrite.replace(name.size() - 4, 4, ".iter" + numberOfBatches + ".npz"); // @TODO: use insert?
+ nameOverwrite.replace(name.size() - 4, 4, ".iter" + numberOfBatches + suffix); // @TODO: use insert?
builders_[0]->save(graphs_[0], nameOverwrite);
}
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
// save main model file
builders_[0]->save(graphs_[0], name, true);
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
// save scheduler-related state
if (scheduler_)
scheduler_->save(name);
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
}
// Switch back to the original parameters
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
swapParamsAvg();
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
-#if 0 // temporary, for testing of saving distributed models; must be identical to .orig.npz
- if(mvAvg_ && paramsAvg_.size() > 0 && isMainProcess())
- builders_[0]->save(graphs_[0], name + ".orig_after_swapping.npz", true);
-#endif
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
barrier(); // (for better grouping of log messages)
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
// persist optimizer state
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
shardOpt_[0]->save(name + ".optimizer.npz", shardOpt_,
[&](const OptimizerBase::GatherStateGetFunc& getShardFn) {
return comm_->gatherState(getShardFn);
},
isMainProcess());
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
barrier(); // (for better grouping of log messages)
- //LOG(info, "[{}] save() line {}", this->mpi_->idStr(), __LINE__);
+}
+
+void SyncGraphGroup::finalize() /*override*/ {
+ validate();
+ Base::finalize();
}
} // namespace marian
diff --git a/src/training/graph_group_sync.h b/src/training/graph_group_sync.h
index 47816699..147b172c 100755
--- a/src/training/graph_group_sync.h
+++ b/src/training/graph_group_sync.h
@@ -7,23 +7,26 @@
namespace marian {
class SyncGraphGroup : public GraphGroup, public ExponentialSmoothing {
- const size_t delay_{ 1 }; // optimizer-delay parameter
+ using Base = GraphGroup;
+ const double delay_{1.}; // optimizer-delay parameter. Fractional means to use a fraction of whatever the MB size is
Ptr<ICommunicator> comm_; // [not null] communicator, e.g. NCCLCommunicator
Ptr<IMPIWrapper> mpi_; // [not null] all MPI-like communication goes through this (this is a dummy implementation if no MPI run)
- std::vector<DeviceId> devices_; // [deviceIndex]
- std::vector<Ptr<models::ModelBase>> builders_; // [deviceIndex]
- std::vector<Ptr<ExpressionGraph>> graphs_; // [deviceIndex]
+ std::vector<DeviceId> devices_; // [deviceIndex]
+ std::vector<Ptr<models::ICriterionFunction>> builders_; // [deviceIndex]
+ std::vector<Ptr<ExpressionGraph>> graphs_; // [deviceIndex]
std::vector<Ptr<OptimizerBase>> shardOpt_; // [deviceIndex]
-
std::vector<Tensor> paramsAvg_; // [deviceIndex] exponentially smoothed parameters, sharded
// @TODO: instead, create an array of ExponentialSmoothing objects, and don't use ExponentialSmoothing as a base class
std::vector<Ptr<TensorAllocator>> paramsAllocs_; // [deviceIndex] we must hold a reference to the memory until this class dies
// @TODO: move this nto ExponentialSmoothing, together with paramsAvg_?
- bool first_{ true }; // gets interpreted and cleared by update()
+ // state for update()
+ bool first_{ true }; // gets interpreted and cleared by update()
+ std::vector<Ptr<data::Batch>> pendingBatches_; // in case of dynamic MB-size scaling, we temporarly buffer up batches across update() calls until enough
+ double updateMultiplier_{1}; // multiplier not applied in collectStats() (no multiplier if not mini-batch-fit)
void initialize(const Ptr<data::Batch>& exampleBatch);
void initializeAvg();
@@ -32,8 +35,11 @@ class SyncGraphGroup : public GraphGroup, public ExponentialSmoothing {
void barrier() const { mpi_->barrier(); } // (we need this several times)
void swapParamsAvg() { if (mvAvg_ && paramsAvg_.size() > 0) comm_->swapParams(paramsAvg_); } // note: must call this on all MPI ranks in parallel
+ bool tryGetSubBatches(Ptr<data::Batch> newBatch, std::vector<Ptr<data::Batch>>& subBatches, size_t& numReadBatches);
+ void update(std::vector<Ptr<data::Batch>> subBatches, size_t numReadBatches);
+
public:
- SyncGraphGroup(Ptr<Options> config);
+ SyncGraphGroup(Ptr<Options> config, Ptr<IMPIWrapper> mpi);
void setScheduler(Ptr<Scheduler> scheduler) override;
@@ -42,7 +48,9 @@ public:
void load() override;
void save(bool final = false) override;
- Ptr<data::BatchStats> collectStats();
+ Ptr<data::BatchStats> collectStats(const std::vector<Ptr<Vocab>>&);
+ void finalize() override;
+
// @TODO: consider to make this a virtual as well? Currently it is a template dispatch
};
} // namespace marian
diff --git a/src/training/scheduler.cpp b/src/training/scheduler.cpp
new file mode 100644
index 00000000..4c30cb04
--- /dev/null
+++ b/src/training/scheduler.cpp
@@ -0,0 +1,43 @@
+#include "scheduler.h"
+#include <signal.h>
+#include <cassert>
+
+namespace marian {
+
+// SIGNAL HANDLING, see scheduler.cpp for definitions
+// Currently, only the following is handled by a custom signal handler:
+// SIGTERM: When SIGTERM is received, the global (static member) flag sigterm_ (false by default) is set to true
+// by signalHandler(). When sigterm_ is true, keepGoing() returns false, and the current state of training models
+// is saved prior to exiting.
+// This functionality is helpful when training on clusters with time limits on compute slots, e.g., on s
+// clusters managed by slurm. Slurm can be asked to sending a (custom) warning signal to a process at a given
+// point in time prior to the hard "time's up".
+
+bool sigterm_{false}; // flag signalling that SIGTERM has been received false by default, set to true by signalHandler(SIGTERM)
+
+void signalHandler(int sig) {
+ // Note: sys_siglist[sig] or stdsignal() describe the effect (e.g.,
+ // 'Terminated' rather than provide the signal name (which are #define(s)
+ // in signal.h), so we have to do custom log messages here.
+ switch (sig) {
+ case SIGTERM: // save models and exit
+ LOG(info, "[training] Scheduler received signal SIGTERM"); // @TODO: figure out if this is safe. The logs are global and thread-safe, so should be OK?
+ sigterm_ = true;
+ break;
+ default:
+ ABORT("No action defined for signal {}", sig);
+ }
+}
+
+// installs signalHandler() for select signals (currently only SIGTERM)
+void installSignalHandlers() {
+ // TODO: use sigaction instead of signal,
+ // cf. https://stackoverflow.com/questions/231912/what-is-the-difference-between-sigaction-and-signal
+ signal(SIGTERM, signalHandler);
+}
+
+bool getSigtermFlag() {
+ return sigterm_;
+}
+
+}
diff --git a/src/training/scheduler.h b/src/training/scheduler.h
index dee62496..7e601632 100755
--- a/src/training/scheduler.h
+++ b/src/training/scheduler.h
@@ -4,53 +4,159 @@
#include "training/training_state.h"
#include "training/validator.h"
#include "training/communicator.h"
+#include "layers/loss.h"
namespace marian {
+bool getSigtermFlag();
+void installSignalHandlers();
+
class Scheduler : public TrainingObserver {
private:
Ptr<Options> options_;
+ Ptr<TrainingState> state_;
std::vector<Ptr<ValidatorBase>> validators_;
bool first_{true};
- Ptr<TrainingState> state_;
-
- timer::Timer timer_, heartBeatTimer_;
+ timer::Timer timer_;
+ timer::Timer heartBeatTimer_;
+
+ // determine scheduled LR decay factor (--lr-decay-inv-sqrt option)
+ float getScheduledLRDecayFactor(const TrainingState& state) const {
+ auto args = options_->get<std::vector<std::string>>("lr-decay-inv-sqrt");
+ ABORT_IF(args.empty() || args.size() > 2, "--lr-decay-inv-sqrt argument must be one or two numbers with units");
+ auto decayGoogle = SchedulingParameter::parse(args[0]);
+ size_t progress = state.getProgressIn(decayGoogle.unit);
+ size_t start = decayGoogle.n;
+ if (args.size() > 1) {
+ auto decayStart = SchedulingParameter::parse(args[1]);
+ ABORT_IF(decayStart && decayStart.unit != decayGoogle.unit,
+ "both --lr-decay-inv-sqrt arguments must have the same unit");
+ start = decayStart.n;
+ }
+ if (decayGoogle && progress > start) {
+ progress = progress - start + decayGoogle.n; // shift so that we get 1 at progress==start
+ return (float)(std::sqrt((double)decayGoogle.n / (double)progress));
+ }
+ else
+ return 1.f;
+ }
- float getLearningRate(TrainingState& state) {
+ // update current learning rate in state.eta
+ // This considers
+ // - base LR (--learn-rate)
+ // - LR warm-up (--lr-warmup, --lr=warmup-start-rate)
+ // - scheduled LR decay (--lr-decay-inv-sqrt)
+ // - state-based LR decay (--lr-decay, --lr-decay-strategy)
+ void updateLearningRate(TrainingState& state) const {
float baselr = options_->get<float>("learn-rate");
- float mult1 = 1.f;
- auto warmup = SchedulingParameter::parse(options_->get<std::string>("lr-warmup"));
- if(warmup) {
- ABORT_IF(state.warmupStart && state.warmupStart.unit != warmup.unit, "lr-warmup and warmup-start must have the same unit");
- auto bno = state.getProgressIn(warmup.unit) - state.warmupStart.n;
- mult1 = std::min(1.f, (float)bno / (float)warmup.n);
+ // warm-up factor
+ float warmupFactor = 1.f;
+ auto warmupParam = SchedulingParameter::parse(options_->get<std::string>("lr-warmup"));
+ if(warmupParam) {
+ ABORT_IF(state.warmupStart && state.warmupStart.unit != warmupParam.unit,
+ "lr-warmup and warmup-start must have the same unit");
+ auto bno = state.getProgressIn(warmupParam.unit) - state.warmupStart.n;
+ warmupFactor = std::min(1.f, (float)bno / (float)warmupParam.n);
}
- float mult2 = 1.f;
- auto decayGoogle = SchedulingParameter::parse(options_->get<std::string>("lr-decay-inv-sqrt"));
- if(decayGoogle) {
- mult2 = std::min(1.f, (float)(std::sqrt(decayGoogle.n) / std::sqrt(state.getProgressIn(decayGoogle.unit))));
- }
+ // TODO: why lr-warmup-start-rate is extracted from options_ instead of using state.warmupStart?
+ float lrStart = options_->get<float>("lr-warmup-start-rate");
+ baselr = lrStart + (baselr - lrStart) * warmupFactor; // linear interpolation between
+ // lr-warmup-start-rate to learn-rate
- baselr = baselr * mult1 * mult2;
+ // schedule-based decay factor (--lr-decay-inv-sqrt)
+ float scheduledDecayFactor = getScheduledLRDecayFactor(state);
+ baselr = baselr * scheduledDecayFactor;
- float lrStart = options_->get<float>("lr-warmup-start-rate");
- if(lrStart > 0)
- baselr = baselr - lrStart * mult1 * mult2 + lrStart * mult2;
+ // factor in state-based decay and set final LR as state.eta
+ state.updateEta(baselr);
+ }
- return baselr;
+ std::string formatLoss(std::string lossType,
+ bool dispLabelCounts,
+ size_t batchLabels,
+ Ptr<TrainingState> state) {
+ std::stringstream ss;
+ ss << "Cost ";
+ ss << std::setprecision(8) << std::fixed;
+
+ // @TODO: put a single loss formatting function into loss.h and reuse here to avoid code duplication
+ // @TODO: use dispLabelCounts with any display type?
+ // @TODO: bugbug cost-type ce-mean-words with multi-loss-type mean divides too much in display
+ if(lossType == "ce-mean-words") {
+ ss << state->costSum / state->costCount;
+ } else if(lossType == "ce-sum" && dispLabelCounts) {
+ ss << state->costSum / state->costCount
+ << " * " << utils::withCommas(state->costCount);
+ if(batchLabels > 0)
+ ss << " @ " << utils::withCommas(batchLabels);
+ ss << " after " << utils::withCommas(state->labelsTotal);
+ } else if(lossType == "ce-sum" && !dispLabelCounts) {
+ ss << state->costSum / state->updatesDisp; // average over batches
+ } else if(lossType == "perplexity") {
+ ss << std::exp(state->costSum / state->costCount);
+ } else if(lossType == "cross-entropy" || lossType == "ce-mean") { // backwards-compat, @TODO: get rid of this?
+ ss << state->costSum / state->samplesDisp;
+ } else {
+ ABORT("Unknown loss type {}", lossType);
+ }
+
+ return ss.str();
}
public:
+ // test if any parameters specify dynamic MB scaling
+ bool isDynamicMBSizeScaling() const {
+ auto mbWarmup = SchedulingParameter::parse(options_->get<std::string>("mini-batch-warmup"));
+ auto mbTracking = options_->get<bool>("mini-batch-track-lr");
+ return mbWarmup || mbTracking;
+ }
+
+ // determine dynamic MB scaling factor
+ double getDynamicMBSizeMultiplier() const {
+ double ratio = 1.0;
+
+ auto mbWarmup = SchedulingParameter::parse(options_->get<std::string>("mini-batch-warmup"));
+ if (mbWarmup) {
+ // mini-batch-warmup
+ LOG_ONCE(info, "[scheduler] Mini-batch size warmup {}", std::string(mbWarmup));
+ // This ramps up MB size at start, relative to progress within warm-up period.
+ size_t progress = state_->getProgressIn(mbWarmup.unit); // number of updates/labels processed
+ auto progressRatio = (double)progress / (double)mbWarmup.n; // where are we relatively within target warm-up period
+ // if unit is labels, then account for the fact that our increment itself is not constant
+ if (mbWarmup.unit == SchedulingUnit::trgLabels)
+ progressRatio = std::sqrt(progressRatio);
+ if (progressRatio < 1)
+ ratio *= progressRatio;
+ }
+
+ // dynamic MB-size tracking with learning rate
+ // As LR goes down, MB gets ramped up by the same ratio, which has been found to be safe.
+ auto mbTracking = options_->get<bool>("mini-batch-track-lr");
+ if (mbTracking) {
+ auto lrFactor = getScheduledLRDecayFactor(*state_) * state_->factor; // (don't include lr-warmup)
+ if (lrFactor != 1)
+ LOG_ONCE(info, "[scheduler] Dynamic mini-batch size adjustment enabled and kicking in");
+ ratio /= lrFactor;
+ }
+ return ratio;
+ }
+
Scheduler(Ptr<Options> options, Ptr<TrainingState> state)
: options_(options), state_(state) {
- state_->eta = getLearningRate(*state);
+ ABORT_IF(state_->factor != 1, "state.factor unexpectedly not 1 at this point??");
+ updateLearningRate(*state);
+ installSignalHandlers();
}
bool keepGoing() {
+
+ if(getSigtermFlag()) // received signal SIGERM => exit gracefully
+ return false;
+
// stop if it reached the maximum number of epochs
size_t stopAfterEpochs = options_->get<size_t>("after-epochs");
if(stopAfterEpochs > 0 && state_->epochs > stopAfterEpochs)
@@ -77,15 +183,20 @@ public:
}
void started() { LOG(info, "Training started"); }
- void finished() { LOG(info, "Training finished"); }
+ void finished() {
+ if (getSigtermFlag())
+ LOG(info, "Training interrupted (SIGTERM).");
+ else
+ LOG(info, "Training finished");
+ }
+
void addValidator(Ptr<ValidatorBase> validator) {
validators_.push_back(validator);
registerTrainingObserver(validators_.back());
if(!state_->loaded) {
- state_->validators[validator->type()]["last-best"]
- = validator->initScore();
+ state_->validators[validator->type()]["last-best"] = validator->initScore();
state_->validators[validator->type()]["stalled"] = 0;
}
if(validators_.size() == 1)
@@ -103,12 +214,12 @@ public:
}
void validate(const std::vector<Ptr<ExpressionGraph>>& graphs,
- bool final = false) {
+ bool isFinal = false) {
// Do not validate if already validated (for instance, after the model is
- // loaded) or if validation is scheduled for another update
- if(state_->validated
- || (!state_->enteredNewPeriodOf(options_->get<std::string>("valid-freq"))
- && !final))
+ // loaded) or if validation is scheduled for another update, or when signal SIGTERM was received
+ if(getSigtermFlag() // SIGTERM was received
+ || state_->validated // already validated (in resumed training, for example)
+ || (!state_->enteredNewPeriodOf(options_->get<std::string>("valid-freq")) && !isFinal)) // not now
return;
bool firstValidator = true;
@@ -117,15 +228,15 @@ public:
continue;
size_t stalledPrev = validator->stalled();
- float value = validator->validate(graphs);
+ float value = validator->validate(graphs, state_);
if(validator->stalled() > 0) {
LOG_VALID(info,
- "Ep. {} : Up. {} : {} : {} : stalled {} times",
+ "Ep. {} : Up. {} : {} : {} : stalled {} times (last best: {})",
state_->epochs,
state_->batches,
validator->type(),
value,
- validator->stalled());
+ validator->stalled(), validator->lastBest());
} else {
LOG_VALID(info,
"Ep. {} : Up. {} : {} : {} : new best",
@@ -158,126 +269,97 @@ public:
return 0;
}
- void update(float cost, Ptr<data::Batch> batch) {
- update(cost, std::vector<Ptr<data::Batch>>({batch}));
+ void update(StaticLoss rationalLoss, Ptr<data::Batch> batch) {
+ update(rationalLoss, /*numReadBatches=*/1, /*batchSize=*/batch->size(), /*batchLabels=*/batch->wordsTrg());
}
- void update(float cost, const std::vector<Ptr<data::Batch>>& batches, Ptr<IMPIWrapper> mpi = nullptr) {
- state_->rememberPreviousProgress(); // note: epoch increases happen at the wrong place, hence -freq parameters do not support epoch units
+ // @TODO: go back to function which takes batch as an argument? The current arguments make it hard
+ // to choose which subbatch should be used for speed display. For sequence-classifiers it's more interesting
+ // to see the source-words consumed rather than the labels.
+ void update(StaticLoss rationalLoss,
+ size_t numReadBatches, // number of batches read by the reader (for seeking in case of restart)
+ size_t batchSize, // total number of sentences in batch
+ size_t batchLabels, // total number of target words in batch
+ Ptr<IMPIWrapper> mpi = nullptr) {
+ state_->rememberPreviousProgress(); // note: epoch increases happen at the wrong place, hence
+ // -freq parameters do not support epoch units
state_->validated = false;
- size_t batchSize = 0; // number of sentences in batch
- size_t batchLabels = 0; // number of target words in batch
+ // Since batchLabels is counted across all MPI processes, we also should temporarily
+ // extrapolate cost across MPI processes, to have numbers in the right range.
+ // When doing the actual log, we then aggregate across MPI processes to get the accurate number.
+ if(mpi)
+ rationalLoss.loss *= mpi->numMPIProcesses();
- for(const auto& batch : batches) {
- if (batch) { // (nullptr is allowed as result of split)
- batchSize += batch->size();
- batchLabels += batch->words(-1);
- }
- }
+ // @BUGBUG: rationalLoss.count is float, not a count. Possible solution: make (costSum, costCount) a StaticLoss object as well
+ state_->costSum += rationalLoss.loss; // aggregate sum cost since last display
+ state_->costCount += (size_t)rationalLoss.count; // cost gets normalized w.r.t. this in display
- // extrapolate cost across MPI processes, so that we have numbers in the right range
- // When doing the actual log, we then aggregate across MPI processes to get the accurate number.
- if (mpi)
- cost *= mpi->numMPIProcesses(); // @BUGBUG: this is presently correct for ce-sum, but possibly not the av-based losses
+ state_->updatesDisp += 1;
+ state_->samplesDisp += batchSize;
+ state_->wordsDisp += batchLabels; //@TODO: this is wrong // words at given input processed since last display, for speed display
- // reconstruct sum cost, for displaying epoch-level averages instead of minibatch-level
- auto costType = options_->get<std::string>("cost-type");
- auto dispLabelCounts = options_->get<bool>(
- "disp-label-counts"); // if true then show as "cost per label * number of labels"
- if(dispLabelCounts) {
- auto count = // what was cost normalized with originally?
- /*if*/ (costType == "ce-sum") ?
- 1
- /*else if*/ : ((costType == "ce-mean-words") ?
- batchLabels
- /*else*/ : // all others: treat like ce-mean (not correct for some)
- batchSize);
- state_->costSum += cost * count; // aggregate sum cost since last display
- state_->costCount += batchLabels; // cost gets normalized w.r.t. this in display
- } else { // (back compat)
- state_->costSum += cost * batchSize;
- state_->costCount += batchSize;
- }
- state_->wordsDisp += batchLabels; // target words processed since last display, for speed display
- state_->samplesEpoch += batchSize; // sentences processed in this epoch
- state_->labelsTotal += batchLabels; // total labels processed
+ state_->samplesEpoch += batchSize; // sentences processed in this epoch
+ // @BUGBUG: rationalLoss.count is float, not a count
+ state_->labelsTotal += (size_t)rationalLoss.count; // total labels processed
- state_->newBatch();
+ state_->newUpdate(numReadBatches);
+
+ // reconstruct sum cost, for displaying epoch-level averages instead of minibatch-level
+ auto lossType = options_->get<std::string>("cost-type");
+ auto dispLabelCounts = options_->get<bool>("disp-label-counts"); // if true then show as "cost per label * number of labels"
if(state_->enteredNewPeriodOf(options_->get<std::string>("disp-freq")) ||
state_->batches <= options_->get<size_t>("disp-first")) {
// if MPI then aggregate precise cost across workers
- if (mpi) {
- //LOG(info, "all-reducing cost from {}", state_->costSum);
+ if(mpi) {
state_->costSum /= mpi->numMPIProcesses(); // undo the extra scaling
mpi->allReduce(&state_->costSum, &state_->costSum, 1, MPI_FLOAT, MPI_SUM);
- //LOG(info, "all-reduced cost to {}", state_->costSum);
}
- if (mpi && mpi->myMPIRank() != 0)
- ; // skip the report on alternate worker processes
- else if(dispLabelCounts) {
- if(options_->get<bool>("lr-report")) { // if true then show the learning rate
- LOG(info,
- "Ep. {} : Up. {} : Sen. {} : Cost {:.8f} * {} after {} : Time {:.2f}s : {:.2f} "
- "words/s : L.r. {:.4e}",
- state_->epochs,
- state_->batches,
- utils::withCommas(state_->samplesEpoch),
- state_->costSum / state_->costCount,
- utils::withCommas(state_->costCount), // show cost as "av * count"
- utils::withCommas(state_->labelsTotal),
- timer_.elapsed(),
- state_->wordsDisp / timer_.elapsed(),
- state_->eta);
- } else {
- LOG(info,
- "Ep. {} : Up. {} : Sen. {} : Cost {:.8f} * {} after {} : Time {:.2f}s : {:.2f} "
- "words/s",
- state_->epochs,
- state_->batches,
- utils::withCommas(state_->samplesEpoch),
- state_->costSum / state_->costCount,
- utils::withCommas(state_->costCount),
- utils::withCommas(state_->labelsTotal),
- timer_.elapsed(),
- state_->wordsDisp / timer_.elapsed());
- }
+
+ if(mpi && mpi->myMPIRank() != 0) {
+ // skip the report on alternate worker processes
+ } else if(options_->get<bool>("lr-report")) {
+ LOG(info,
+ "Ep. {} : Up. {} : Sen. {} : {} : Time {:.2f}s : {:.2f} words/s : L.r. {:.4e}",
+ state_->epochs,
+ state_->batches,
+ utils::withCommas(state_->samplesEpoch),
+ formatLoss(lossType, dispLabelCounts, batchLabels, state_),
+ timer_.elapsed(),
+ state_->wordsDisp / timer_.elapsed(),
+ state_->eta);
} else {
- if(options_->get<bool>("lr-report")) {
- LOG(info,
- "Ep. {} : Up. {} : Sen. {} : Cost {:.8f} : Time {:.2f}s : {:.2f} words/s : L.r. {:.4e}",
- state_->epochs,
- state_->batches,
- utils::withCommas(state_->samplesEpoch),
- state_->costSum / state_->costCount,
- timer_.elapsed(),
- state_->wordsDisp / timer_.elapsed(),
- state_->eta);
- } else {
- LOG(info,
- "Ep. {} : Up. {} : Sen. {} : Cost {:.8f} : Time {:.2f}s : {:.2f} words/s",
- state_->epochs,
- state_->batches,
- utils::withCommas(state_->samplesEpoch),
- state_->costSum / state_->costCount,
- timer_.elapsed(),
- state_->wordsDisp / timer_.elapsed());
- }
+ LOG(info,
+ "Ep. {} : Up. {} : Sen. {} : {} : Time {:.2f}s : {:.2f} words/s",
+ state_->epochs,
+ state_->batches,
+ utils::withCommas(state_->samplesEpoch),
+ formatLoss(lossType, dispLabelCounts, 0, state_), // ignore batchLabels
+ timer_.elapsed(),
+ state_->wordsDisp / timer_.elapsed());
}
+
+
timer_.start();
- state_->costSum = 0;
- state_->costCount = 0;
- state_->wordsDisp = 0;
+ state_->costSum = 0;
+ state_->costCount = 0;
+
+ state_->updatesDisp = 0;
+ state_->samplesDisp = 0;
+ state_->wordsDisp = 0;
}
+
// progress heartbeat for MS-internal Philly compute cluster
// This environment variable exists when running on the cluster.
+ using namespace std::chrono;
if((!mpi || mpi->myMPIRank() == 0) && getenv("PHILLY_JOB_ID")
&& heartBeatTimer_.elapsed<std::chrono::minutes>() >= 10) {
- printf("PROGRESS: %.2f%%\nEVALERR: %.7f\n", (double)state_->epochs, state_->costSum / state_->costCount), fflush(stdout);
-#if 0
- LOG(info, "heart beat after {} updates", state_->batches);
-#endif
+ printf("PROGRESS: %.2f%%\nEVALERR: %.7f%%\n",
+ (double)state_->epochs,
+ state_->costSum / (state_->costCount ? state_->costCount : 1) / (mpi ? mpi->numMPIProcesses() : 1));
+ fflush(stdout);
+ std::cout << "MBSIZE: " << batchLabels << " after " << state_->batches << " updates = " << state_->labelsTotal << " labels" << std::endl << std::flush;
heartBeatTimer_.start();
}
}
@@ -289,9 +371,20 @@ public:
if(options_->get<bool>("no-restore-corpus")) {
state_->samplesEpoch = 0;
- state_->costSum = 0;
- state_->costCount = 0;
- state_->wordsDisp = 0;
+ state_->costSum = 0;
+ state_->costCount = 0;
+
+ state_->updatesDisp = 0;
+ state_->samplesDisp = 0;
+ state_->wordsDisp = 0;
+ }
+
+ if(options_->get<bool>("valid-reset-stalled")) {
+ state_->stalled = 0;
+ state_->maxStalled = 0;
+ for(const auto& validator : validators_) {
+ state_->validators[validator->type()]["stalled"] = 0;
+ }
}
state_->newLoad();
@@ -299,9 +392,8 @@ public:
void save(const std::string& name) {
// Save config options
- YAML::Node yaml = options_->getYaml();
std::ofstream fout(name + ".yml");
- fout << yaml;
+ fout << options_->asYamlString();
// Save training progress
state_->save(name + ".progress.yml");
}
@@ -313,10 +405,9 @@ public:
}
void actAfterEpoch(TrainingState& state) override {
- float factor = (float)options_->get<double>("lr-decay"); // @TODO: <float>?
+ float factor = options_->get<float>("lr-decay");
- float baselr = getLearningRate(state);
- state.eta = baselr * state.factor;
+ updateLearningRate(state);
if(factor > 0.0) {
bool decay = false;
@@ -346,11 +437,8 @@ public:
if(decay) {
state.factor *= factor;
- state.eta = baselr * state.factor;
- LOG(info,
- "Decaying learning rate to {} in epoch {}",
- state.eta,
- state.epochs);
+ updateLearningRate(state);
+ LOG(info, "Decaying learning rate to {} in epoch {}", state.eta, state.epochs);
state.reset = options_->get<bool>("lr-decay-reset-optimizer");
if(state.reset)
@@ -365,11 +453,10 @@ public:
}
void actAfterBatches(TrainingState& state) override {
- float factor = (float)options_->get<double>("lr-decay"); // @TODO: <float>?
+ float factor = options_->get<float>("lr-decay");
state.reset = false;
- float baselr = getLearningRate(state);
- state.eta = baselr * state.factor;
+ updateLearningRate(state);
if(factor > 0.0) {
if(options_->get<std::string>("lr-decay-strategy") == "batches") {
@@ -379,11 +466,8 @@ public:
if(start > 0 && freq > 0 && state.batches >= start
&& ((state.batches - start) % freq == 0)) {
state.factor *= factor;
- state.eta = baselr * state.factor;
- LOG(info,
- "Decaying learning rate to {} after {} batches",
- state.eta,
- state.batches);
+ updateLearningRate(state);
+ LOG(info, "Decaying learning rate to {} after {} batches", state.eta, state.batches);
state.reset = options_->get<bool>("lr-decay-reset-optimizer");
if(state.reset)
@@ -391,6 +475,7 @@ public:
if(options_->get<bool>("lr-decay-repeat-warmup")) {
LOG(info, "Restarting learning rate warmup");
+ // TODO: avoid repeating this many times and minimize calls to options_->get
state.warmupStart.n = state.getProgressIn(SchedulingParameter::parse(options_->get<std::string>("lr-warmup")).unit);
}
}
@@ -411,21 +496,19 @@ public:
}
void actAfterStalled(TrainingState& state) override {
- float factor = (float)options_->get<double>("lr-decay"); // @TODO: <float>?
+ float factor = options_->get<float>("lr-decay");
state.reset = false;
- float baselr = getLearningRate(state);
- state.eta = baselr * state.factor;
+ updateLearningRate(state);
if(factor > 0.0) {
if(options_->get<std::string>("lr-decay-strategy") == "stalled") {
- size_t startStalled
- = options_->get<std::vector<size_t>>("lr-decay-start").front();
+ size_t startStalled = options_->get<std::vector<size_t>>("lr-decay-start").front();
if(startStalled && state.stalled && state.stalled % startStalled == 0) {
state.factor *= factor;
- state.eta = baselr * state.factor;
+ updateLearningRate(state);
LOG(info,
- "Decaying learning rate to {} after stalled {} time(s)",
+ "Decaying learning rate to {} after having stalled {} time(s)",
state.eta,
state.stalled);
diff --git a/src/training/training.h b/src/training/training.h
index 2209b96e..d4733963 100755..100644
--- a/src/training/training.h
+++ b/src/training/training.h
@@ -34,50 +34,60 @@ public:
dataset->prepare();
+ auto mpi = initMPI(/*multiThreaded=*/!options_->get<bool>("sync-sgd")); // @TODO: do we need the multiThreaded distinction at all?
+
Ptr<BatchStats> stats;
if(options_->get<bool>("mini-batch-fit")) {
LOG(info,
- "[batching] Collecting statistics for batch fitting with step size "
- "{}",
+ "[batching] Collecting statistics for batch fitting with step size {}",
options_->get<size_t>("mini-batch-fit-step"));
- // @TODO, better fake batch with vocabulary
- auto model = New<ModelWrapper>(options_);
- THREAD_GUARD(stats = model->collectStats());
- LOG(info, "[batching] Done");
+ // @TODO this should receive a function object that can generate a fake batch;
+ // that way vocabs would not be exposed.
+ auto model = New<ModelWrapper>(options_, mpi);
+
+ // use temporary scheduler to make sure everything gets destroyed properly
+ // otherwise the scheduler believes that registered objects still exist
+ auto tempTrainState = New<TrainingState>(options_->get<float>("learn-rate"));
+ auto tempScheduler = New<Scheduler>(options_, tempTrainState);
+
+ model->setScheduler(tempScheduler); // collectStats() needs to know about dynamic MB scaling
+ stats = model->collectStats(dataset->getVocabs());
+ LOG(info, "[batching] Done. Typical MB size is {} target words", stats->estimateTypicalTrgWords());
}
auto trainState = New<TrainingState>(options_->get<float>("learn-rate"));
auto scheduler = New<Scheduler>(options_, trainState);
- if((options_->has("valid-sets") || options_->has("valid-script-path"))
+ if((options_->hasAndNotEmpty("valid-sets") || options_->hasAndNotEmpty("valid-script-path"))
&& SchedulingParameter::parse(options_->get<std::string>("valid-freq"))) {
for(auto validator : Validators(dataset->getVocabs(), options_))
scheduler->addValidator(validator);
}
auto batchGenerator = New<CorpusBatchGenerator>(dataset, options_, stats);
+
scheduler->registerTrainingObserver(batchGenerator);
- auto model = New<ModelWrapper>(options_);
+ auto model = New<ModelWrapper>(options_, mpi);
model->setScheduler(scheduler);
+ model->setTypicalTrgBatchWords(batchGenerator->estimateTypicalTrgBatchWords()); // needed for dynamic MB scaling
model->load();
- // @TODO: shuffle_ as a private attribute in BG
- auto shuffle = !options_->get<bool>("no-shuffle");
bool restored = !options_->get<bool>("no-restore-corpus")
- && batchGenerator->restore(trainState, shuffle);
+ && batchGenerator->restore(trainState);
+ // -- main training loop
scheduler->started();
while(scheduler->keepGoing()) {
if(!restored)
- batchGenerator->prepare(shuffle);
+ batchGenerator->prepare();
restored = false;
- // @TODO: try to use for(auto ...)
- for(auto batchIt = std::begin(*batchGenerator);
- batchIt != std::end(*batchGenerator) && scheduler->keepGoing();
- batchIt++) {
- model->update(*batchIt);
+ // main training loop for one epoch
+ for(auto batch : *batchGenerator) {
+ if (!scheduler->keepGoing())
+ break;
+ model->update(batch);
}
if(scheduler->keepGoing())
@@ -85,12 +95,15 @@ public:
}
scheduler->finished();
- model->finalize();
+ model->finalize(); // allow async to sync before final save --@TODO: rename, or move into save()
- // Avoid saving the model twice if it has been loaded and training did not
- // progress
+ // Avoid saving the model twice if it has been loaded and training did not progress
if(!trainState->loaded)
model->save(true);
+
+ // Signal success to a potential MPI runner
+ model = nullptr; // release any reference to MPI that model may hold
+ finalizeMPI(std::move(mpi));
}
};
} // namespace marian
diff --git a/src/training/training_state.h b/src/training/training_state.h
index 2deefe47..c1ddfc83 100755..100644
--- a/src/training/training_state.h
+++ b/src/training/training_state.h
@@ -2,6 +2,7 @@
#include "common/definitions.h"
#include "common/filesystem.h"
+#include "common/utils.h"
#include <fstream>
#include <vector>
@@ -12,6 +13,8 @@ class TrainingState;
class TrainingObserver {
public:
+ virtual ~TrainingObserver() {}
+
virtual void init(TrainingState&) {}
virtual void actAfterEpoch(TrainingState&) {}
virtual void actAfterBatches(TrainingState&) {}
@@ -25,35 +28,39 @@ enum class SchedulingUnit {
updates, // "u": number of updates so far (batches)
epochs // "e": number of epochs begun so far (very first epoch is 1)
};
+
struct SchedulingParameter {
size_t n{0}; // number of steps measured in 'unit'
SchedulingUnit unit{SchedulingUnit::updates}; // unit of value
// parses scheduling parameters of the form NU where N=unsigned int and U=unit
- // Examples of valid inputs: "16000u" (16000 updates), "32000000t" (32 million target labels), "100e" (100 epochs).
+ // Examples of valid inputs: "16000u" (16000 updates), "32000000t" (32 million target labels),
+ // "100e" (100 epochs).
static SchedulingParameter parse(std::string param) {
SchedulingParameter res;
- if (!param.empty() && param.back() >= 'a') {
- switch (param.back()) {
- case 't': res.unit = SchedulingUnit::trgLabels; break;
- case 'u': res.unit = SchedulingUnit::updates; break;
- case 'e': res.unit = SchedulingUnit::epochs; break;
- default: ABORT("invalid unit '{}' in {}", param.back(), param);
+ if(!param.empty() && param.back() >= 'a') {
+ switch(param.back()) {
+ case 't': res.unit = SchedulingUnit::trgLabels; break;
+ case 'u': res.unit = SchedulingUnit::updates; break;
+ case 'e': res.unit = SchedulingUnit::epochs; break;
+ default: ABORT("invalid unit '{}' in {}", param.back(), param);
}
param.pop_back();
}
- res.n = (size_t)std::stoull(param);
+ double number = utils::parseNumber(param);
+ res.n = (size_t)number;
+ ABORT_IF(number != (double)res.n, "Scheduling parameters must be whole numbers");
return res;
}
operator bool() const { return n > 0; } // check whether it is specified
operator std::string() const { // convert back for storing in config
- switch (unit) {
- case SchedulingUnit::trgLabels: return std::to_string(n) + "t";
- case SchedulingUnit::updates : return std::to_string(n) + "u";
- case SchedulingUnit::epochs : return std::to_string(n) + "e";
- default: ABORT("corrupt enum value");
+ switch(unit) {
+ case SchedulingUnit::trgLabels: return std::to_string(n) + "t";
+ case SchedulingUnit::updates : return std::to_string(n) + "u";
+ case SchedulingUnit::epochs : return std::to_string(n) + "e";
+ default: ABORT("corrupt enum value for scheduling unit");
}
}
};
@@ -62,9 +69,9 @@ class TrainingState {
public:
// Current epoch
size_t epochs{1};
- // The total number of batches (=updates) processed since beginning of training --@TODO: rename to 'updates'
+ // The total number of updates since beginning of training --@TODO: rename to 'updates'
size_t batches{0};
- // The number of batches seen in this epoch --@TODO: rename to 'updatesEpoch' or 'updatesInCurrentEpoch'
+ // The number of batches seen in this epoch --note: not updates; an update can consist of multiple batches
size_t batchesEpoch{0};
// The number of sentences seen in this epoch --@TODO: rename to 'sentencesEpoch'
size_t samplesEpoch{0};
@@ -86,19 +93,27 @@ public:
// Reset optimizer parameters
bool reset{false};
- // Learning rate
+ // Current learning rate, representing all adjustment processes and factors
float eta;
- // Multiplication factor for learning rate
+ void updateEta(float dynamicBaseLR) { // note: no other function may write to 'eta' (besides load())
+ eta = dynamicBaseLR * factor;
+ }
+ // State-based multiplication factor for learning rate
float factor{1.f};
SchedulingParameter warmupStart; // has same unit as lr-warmup
// Sum of costs since last display
float costSum{0};
- // Number of words/labels/samples (depending on cost-type) aggregated in
+ // Number of labels aggregated in
// costSum since last display
size_t costCount{0};
- // Number of words seen since last display, for speed measurement
+
+ // Number of words seen since last display
size_t wordsDisp{0};
+ // Number of samples/sentences seen since last display
+ size_t samplesDisp{0};
+ // Number of updates seen since last display
+ size_t updatesDisp{0};
// The state of the random number generator from a batch generator
std::string seedBatch;
@@ -111,20 +126,22 @@ public:
// Set flag if the model was validated in the current batch
bool validated{false};
- TrainingState(float learnRate) : eta(learnRate) {}
+ TrainingState(float learnRate) {
+ updateEta(learnRate);
+ }
void registerObserver(Ptr<TrainingObserver> observer) {
- observers_.push_back(observer);
- observers_.back()->init(*this);
+ observer->init(*this);
+ wObservers_.push_back(observer);
}
// return the totals count that corresponds to the given unit (batches, labels, or epochs)
size_t getProgressIn(SchedulingUnit u) const {
- switch (u) {
- case SchedulingUnit::trgLabels: return labelsTotal;
- case SchedulingUnit::updates : return batches;
- case SchedulingUnit::epochs : return epochs;
- default: ABORT("corrupt enum value");
+ switch(u) {
+ case SchedulingUnit::trgLabels: return labelsTotal;
+ case SchedulingUnit::updates : return batches;
+ case SchedulingUnit::epochs : return epochs;
+ default: ABORT("corrupt enum value");
}
}
@@ -137,11 +154,11 @@ public:
}
size_t getPreviousProgressIn(SchedulingUnit u) const {
- switch (u) {
- case SchedulingUnit::trgLabels: return prevLabelsTotal;
- case SchedulingUnit::updates : return prevBatches;
- case SchedulingUnit::epochs : return prevEpochs;
- default: ABORT("corrupt enum value");
+ switch(u) {
+ case SchedulingUnit::trgLabels: return prevLabelsTotal;
+ case SchedulingUnit::updates : return prevBatches;
+ case SchedulingUnit::epochs : return prevEpochs;
+ default: ABORT("corrupt enum value");
}
}
@@ -149,6 +166,7 @@ public:
// unit in which that parameter is given. There are a few edge cases:
// - this function will be called many times within the same epoch
// - labelsTotal does not increment by 1, so simple modulus does not work
+ //
// So instead of modulus==0, this function compares the previous progress/period
// to the current, and triggers if they differ (i.e. the border between two
// periods was crossed). This requires that rememberPreviousProgress() is called
@@ -158,7 +176,8 @@ public:
bool enteredNewPeriodOf(std::string schedulingParam) const {
auto period = SchedulingParameter::parse(schedulingParam);
ABORT_IF(period.unit == SchedulingUnit::epochs,
- "Unit {} is not supported for frequency parameters (the one(s) with value {})", schedulingParam);
+ "Unit {} is not supported for frequency parameters (the one(s) with value {})",
+ schedulingParam);
auto previousProgress = getPreviousProgressIn(period.unit);
auto progress = getProgressIn(period.unit);
return period && progress / period.n != previousProgress / period.n;
@@ -166,33 +185,45 @@ public:
void newEpoch() {
++epochs;
- for(auto observer : observers_)
+ for(auto wObserver : wObservers_) {
+ auto observer = wObserver.lock();
+ ABORT_IF(!observer, "Training observer object expired. Make sure all registered observers exist during scheduler life time");
observer->actAfterEpoch(*this);
+ }
samplesEpoch = 0;
batchesEpoch = 0;
}
- void newBatch() {
+ void newUpdate(size_t batchesInUpdate) {
++batches;
- ++batchesEpoch;
+ batchesEpoch += batchesInUpdate;
loaded = false;
validated = false;
- for(auto observer : observers_)
+ for(auto wObserver : wObservers_) {
+ auto observer = wObserver.lock();
+ ABORT_IF(!observer, "Training observer object expired. Make sure all registered observers exist during scheduler life time");
observer->actAfterBatches(*this);
+ }
}
void newStalled(size_t num) {
stalled = num;
if(num > maxStalled)
++maxStalled;
- for(auto observer : observers_)
+ for(auto wObserver : wObservers_) {
+ auto observer = wObserver.lock();
+ ABORT_IF(!observer, "Training observer object expired. Make sure all registered observers exist during scheduler life time");
observer->actAfterStalled(*this);
+ }
}
void newLoad() {
loaded = true;
- for(auto observer : observers_)
+ for(auto wObserver : wObservers_) {
+ auto observer = wObserver.lock();
+ ABORT_IF(!observer, "Training observer object expired. Make sure all registered observers exist during scheduler life time");
observer->actAfterLoaded(*this);
+ }
}
void load(const std::string& name) {
@@ -204,13 +235,16 @@ public:
epochs = config["epochs"].as<size_t>();
batches = config["batches"].as<size_t>();
batchesEpoch = config["batches-epoch"].as<size_t>();
- // (different serialization name for back compat)
+ // different serialization name for backward compatibility
samplesEpoch = config["samples"].as<size_t>();
- // (optional for back compat)
+
+ // clang-format off
+ // optional for backward compatibility
labelsTotal = config["labels-total"] ? config["labels-total"].as<size_t>() : 0;
prevLabelsTotal = config["prev-labels-total"] ? config["prev-labels-total"].as<size_t>() : 0;
prevBatches = config["prev-batches"] ? config["prev-batches"].as<size_t>() : 0;
prevEpochs = config["prev-epochs"] ? config["prev-epochs"].as<size_t>() : 0;
+ // clang-format on
stalled = config["stalled"].as<size_t>();
maxStalled = config["stalled-max"].as<size_t>();
@@ -224,15 +258,17 @@ public:
warmupStart = SchedulingParameter::parse(config["warmup-start"].as<std::string>());
costSum = config["cost-sum"].as<float>();
- // (different serialization name for back compat)
- costCount = config["disp-samples"].as<size_t>();
+ costCount = config["cost-count"].as<size_t>();
+
wordsDisp = config["disp-words"].as<size_t>();
+ samplesDisp = config["disp-samples"].as<size_t>();
+ updatesDisp = config["disp-updates"].as<size_t>();
seedBatch = config["seed-batch"].as<std::string>();
seedCorpus = config["seed-corpus"].as<std::string>();
}
- void save(const std::string& name) {
+ void save(const std::string& name) const {
std::ofstream fout(name);
YAML::Node config;
@@ -257,7 +293,10 @@ public:
config["warmup-start"] = std::string(warmupStart);
config["cost-sum"] = costSum;
- config["disp-samples"] = costCount;
+ config["cost-count"] = costCount;
+
+ config["disp-updates"] = updatesDisp;
+ config["disp-samples"] = samplesDisp;
config["disp-words"] = wordsDisp;
config["seed-batch"] = seedBatch;
@@ -266,7 +305,20 @@ public:
fout << config;
}
+ std::string fillTemplate(const std::string& templ) const {
+ // The formatting below uses fmtlib, which is included with spdlog
+ // and is included via the logger.
+ return fmt::format(templ.c_str(),
+ fmt::arg("E", epochs),
+ fmt::arg("U", batches),
+ fmt::arg("B", batchesEpoch),
+ fmt::arg("T", labelsTotal));
+ }
+
private:
- std::vector<Ptr<TrainingObserver>> observers_;
+ // this needs to be a vector of weak pointers, otherwise
+ // it is likely to cause circular dependencies.
+ std::vector<Weak<TrainingObserver>> wObservers_;
+
};
} // namespace marian
diff --git a/src/training/validator.cpp b/src/training/validator.cpp
index a3f55c71..79829e3a 100644
--- a/src/training/validator.cpp
+++ b/src/training/validator.cpp
@@ -2,10 +2,10 @@
namespace marian {
-std::vector<Ptr<Validator<data::Corpus>>> Validators(
+std::vector<Ptr<ValidatorBase/*<data::Corpus>*/>> Validators(
std::vector<Ptr<Vocab>> vocabs,
Ptr<Options> config) {
- std::vector<Ptr<Validator<data::Corpus>>> validators;
+ std::vector<Ptr<ValidatorBase/*<data::Corpus>*/>> validators;
auto validMetrics = config->get<std::vector<std::string>>("valid-metrics");
@@ -31,6 +31,15 @@ std::vector<Ptr<Validator<data::Corpus>>> Validators(
} else if(metric == "bleu-detok") {
auto validator = New<BleuValidator>(vocabs, config, true);
validators.push_back(validator);
+ } else if(metric == "accuracy") {
+ auto validator = New<AccuracyValidator>(vocabs, config);
+ validators.push_back(validator);
+ } else if(metric == "bert-lm-accuracy") {
+ auto validator = New<BertAccuracyValidator>(vocabs, config, true);
+ validators.push_back(validator);
+ } else if(metric == "bert-sentence-accuracy") {
+ auto validator = New<BertAccuracyValidator>(vocabs, config, false);
+ validators.push_back(validator);
} else {
LOG_VALID(warn, "Unrecognized validation metric: {}", metric);
}
@@ -38,4 +47,595 @@ std::vector<Ptr<Validator<data::Corpus>>> Validators(
return validators;
}
+
+
+///////////////////////////////////////////////////////////////////////////////////////
+float ValidatorBase::initScore() {
+ return lowerIsBetter_ ? std::numeric_limits<float>::max() : std::numeric_limits<float>::lowest();
+}
+
+void ValidatorBase::actAfterLoaded(TrainingState& state) {
+ if(state.validators[type()]) {
+ lastBest_ = state.validators[type()]["last-best"].as<float>();
+ stalled_ = state.validators[type()]["stalled"].as<size_t>();
+ }
+}
+
+///////////////////////////////////////////////////////////////////////////////////////
+CrossEntropyValidator::CrossEntropyValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
+ : Validator(vocabs, options) {
+ createBatchGenerator(/*isTranslating=*/false);
+
+ auto opts = options_->with("inference",
+ true, // @TODO: check if required
+ "cost-type",
+ "ce-sum");
+ // @TODO: remove, only used for saving?
+ builder_ = models::createCriterionFunctionFromOptions(opts, models::usage::scoring);
+}
+
+float CrossEntropyValidator::validateBG(const std::vector<Ptr<ExpressionGraph>>& graphs) {
+ auto ctype = options_->get<std::string>("cost-type");
+
+ // @TODO: use with(...) everywhere, this will help with creating immutable options.
+ // Make options const everywhere and get rid of "set"?
+ auto opts = options_->with("inference", true, "cost-type", "ce-sum");
+
+ StaticLoss loss;
+ size_t samples = 0;
+ std::deque<Ptr<ExpressionGraph>> graphQueue(graphs.begin(), graphs.end());
+
+ auto task = [=, &loss, &samples, &graphQueue](BatchPtr batch) {
+ thread_local Ptr<ExpressionGraph> graph;
+
+ if(!graph) {
+ std::unique_lock<std::mutex> lock(mutex_);
+ ABORT_IF(graphQueue.empty(), "Asking for graph, but none left on queue");
+ graph = graphQueue.front();
+ graphQueue.pop_front();
+ }
+
+ auto builder = models::createCriterionFunctionFromOptions(options_, models::usage::scoring);
+
+ builder->clear(graph);
+ auto dynamicLoss = builder->build(graph, batch);
+ graph->forward();
+
+ std::unique_lock<std::mutex> lock(mutex_);
+ loss += *dynamicLoss;
+ samples += batch->size();
+ };
+
+ {
+ threadPool_.reserve(graphs.size());
+ TaskBarrier taskBarrier;
+ for(auto batch : *batchGenerator_)
+ taskBarrier.push_back(threadPool_.enqueue(task, batch));
+ // ~TaskBarrier waits until all are done
+ }
+
+ if(ctype == "perplexity")
+ return std::exp(loss.loss / loss.count);
+ if(ctype == "ce-mean-words")
+ return loss.loss / loss.count;
+ if(ctype == "ce-sum")
+ return loss.loss;
+ else
+ return loss.loss / samples; // @TODO: back-compat, to be removed
+}
+
+///////////////////////////////////////////////////////////////////////////////////////
+AccuracyValidator::AccuracyValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
+ : Validator(vocabs, options, /*lowerIsBetter=*/false) {
+ createBatchGenerator(/*isTranslating=*/false);
+
+ // @TODO: remove, only used for saving?
+ builder_ = models::createModelFromOptions(options_, models::usage::raw);
+}
+
+float AccuracyValidator::validateBG(const std::vector<Ptr<ExpressionGraph>>& graphs) {
+ size_t correct = 0;
+ size_t totalLabels = 0;
+ std::deque<Ptr<ExpressionGraph>> graphQueue(graphs.begin(), graphs.end());
+
+ auto task = [=, &correct, &totalLabels, &graphQueue](BatchPtr batch) {
+ thread_local Ptr<ExpressionGraph> graph;
+
+ if(!graph) {
+ std::unique_lock<std::mutex> lock(mutex_);
+ ABORT_IF(graphQueue.empty(), "Asking for graph, but none left on queue");
+ graph = graphQueue.front();
+ graphQueue.pop_front();
+ }
+
+ auto builder = models::createModelFromOptions(options_, models::usage::raw);
+
+ builder->clear(graph);
+ Expr logits = builder->build(graph, batch).getLogits();
+ graph->forward();
+
+ std::vector<float> vLogits;
+ logits->val()->get(vLogits);
+
+ const auto& groundTruth = batch->back()->data();
+
+ IndexType cols = logits->shape()[-1];
+
+ size_t thisCorrect = 0;
+ size_t thisLabels = groundTruth.size();
+
+ for(int i = 0; i < thisLabels; ++i) {
+ // CPU-side Argmax
+ Word bestWord = Word::NONE;
+ float bestValue = std::numeric_limits<float>::lowest();
+ for(IndexType j = 0; j < cols; ++j) {
+ float currValue = vLogits[i * cols + j];
+ if(currValue > bestValue) {
+ bestValue = currValue;
+ bestWord = Word::fromWordIndex(j);
+ }
+ }
+ thisCorrect += (size_t)(bestWord == groundTruth[i]);
+ }
+
+ std::unique_lock<std::mutex> lock(mutex_);
+ totalLabels += thisLabels;
+ correct += thisCorrect;
+ };
+
+ {
+ threadPool_.reserve(graphs.size());
+
+ TaskBarrier taskBarrier;
+ for(auto batch : *batchGenerator_)
+ taskBarrier.push_back(threadPool_.enqueue(task, batch));
+
+ // ~TaskBarrier waits until all are done
+ }
+
+ return (float)correct / (float)totalLabels;
+}
+
+///////////////////////////////////////////////////////////////////////////////////////
+BertAccuracyValidator::BertAccuracyValidator(std::vector<Ptr<Vocab>> vocabs,
+ Ptr<Options> options,
+ bool evalMaskedLM)
+ : Validator(vocabs, options, /*lowerIsBetter=*/false), evalMaskedLM_(evalMaskedLM) {
+ createBatchGenerator(/*isTranslating=*/false);
+ // @TODO: remove, only used for saving?
+ builder_ = models::createModelFromOptions(options_, models::usage::raw);
+}
+
+float BertAccuracyValidator::validateBG(const std::vector<Ptr<ExpressionGraph>>& graphs) {
+ size_t correct = 0;
+ size_t totalLabels = 0;
+ size_t batchId = 0;
+ std::deque<Ptr<ExpressionGraph>> graphQueue(graphs.begin(), graphs.end());
+
+ auto task = [=, &correct, &totalLabels, &graphQueue](BatchPtr batch, size_t batchId) {
+ thread_local Ptr<ExpressionGraph> graph;
+
+ if(!graph) {
+ std::unique_lock<std::mutex> lock(mutex_);
+ ABORT_IF(graphQueue.empty(), "Asking for graph, but none left on queue");
+ graph = graphQueue.front();
+ graphQueue.pop_front();
+ }
+
+ auto builder = models::createModelFromOptions(options_, models::usage::raw);
+
+ thread_local std::unique_ptr<std::mt19937> engine;
+ if(!engine)
+ engine.reset(new std::mt19937((unsigned int)(Config::seed + batchId)));
+
+ auto bertBatch = New<data::BertBatch>(batch,
+ *engine,
+ options_->get<float>("bert-masking-fraction"),
+ options_->get<std::string>("bert-mask-symbol"),
+ options_->get<std::string>("bert-sep-symbol"),
+ options_->get<std::string>("bert-class-symbol"),
+ options_->get<int>("bert-type-vocab-size"));
+
+ builder->clear(graph);
+ auto classifierStates
+ = std::dynamic_pointer_cast<BertEncoderClassifier>(builder)->apply(graph, bertBatch, true);
+ graph->forward();
+
+ auto maskedLMLogits = classifierStates[0]->getLogProbs();
+ const auto& maskedLMLabels = bertBatch->bertMaskedWords();
+
+ auto sentenceLogits = classifierStates[1]->getLogProbs();
+ const auto& sentenceLabels = bertBatch->back()->data();
+
+ auto count = [=, &correct, &totalLabels](Expr logits, const Words& labels) {
+ IndexType cols = logits->shape()[-1];
+ size_t thisCorrect = 0;
+ size_t thisLabels = labels.size();
+
+ std::vector<float> vLogits;
+ logits->val()->get(vLogits);
+
+ for(int i = 0; i < thisLabels; ++i) {
+ // CPU-side Argmax
+ IndexType bestIndex = 0;
+ float bestValue = std::numeric_limits<float>::lowest();
+ for(IndexType j = 0; j < cols; ++j) {
+ float currValue = vLogits[i * cols + j];
+ if(currValue > bestValue) {
+ bestValue = currValue;
+ bestIndex = j;
+ }
+ }
+ thisCorrect += (size_t)(bestIndex == labels[i].toWordIndex());
+ }
+
+ std::unique_lock<std::mutex> lock(mutex_);
+ totalLabels += thisLabels;
+ correct += thisCorrect;
+ };
+
+ if(evalMaskedLM_)
+ count(maskedLMLogits, maskedLMLabels);
+ else
+ count(sentenceLogits, sentenceLabels);
+ };
+
+ {
+ threadPool_.reserve(graphs.size());
+ TaskBarrier taskBarrier;
+ for(auto batch : *batchGenerator_) {
+ taskBarrier.push_back(threadPool_.enqueue(task, batch, batchId));
+ batchId++;
+ }
+ // ~TaskBarrier waits until all are done
+ }
+
+ return (float)correct / (float)totalLabels;
+}
+
+///////////////////////////////////////////////////////////////////////////////////////
+ScriptValidator::ScriptValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
+ : Validator(vocabs, options, false) {
+ // @TODO: remove, only used for saving?
+ builder_ = models::createModelFromOptions(options_, models::usage::raw);
+
+ ABORT_IF(!options_->hasAndNotEmpty("valid-script-path"),
+ "valid-script metric but no script given");
+}
+
+float ScriptValidator::validate(const std::vector<Ptr<ExpressionGraph>>& graphs,
+ Ptr<const TrainingState> /*ignored*/) {
+ using namespace data;
+ auto model = options_->get<std::string>("model");
+ std::string suffix = model.substr(model.size() - 4);
+ ABORT_IF(suffix != ".npz" && suffix != ".bin", "Unknown model suffix {}", suffix);
+
+ builder_->save(graphs[0], model + ".dev" + suffix, true);
+
+ auto valStr = utils::exec(options_->get<std::string>("valid-script-path"),
+ options_->get<std::vector<std::string>>("valid-script-args"));
+ float val = (float)std::atof(valStr.c_str());
+ updateStalled(graphs, val);
+
+ return val;
+}
+
+///////////////////////////////////////////////////////////////////////////////////////
+TranslationValidator::TranslationValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
+ : Validator(vocabs, options, false), quiet_(options_->get<bool>("quiet-translation")) {
+ // @TODO: remove, only used for saving?
+ builder_ = models::createModelFromOptions(options_, models::usage::translation);
+
+ if(!options_->hasAndNotEmpty("valid-script-path"))
+ LOG_VALID(warn, "No post-processing script given for validating translator");
+
+ createBatchGenerator(/*isTranslating=*/true);
+}
+
+float TranslationValidator::validate(const std::vector<Ptr<ExpressionGraph>>& graphs,
+ Ptr<const TrainingState> state) {
+ using namespace data;
+
+ // Generate batches
+ batchGenerator_->prepare();
+
+ // Create scorer
+ auto model = options_->get<std::string>("model");
+
+ std::vector<Ptr<Scorer>> scorers;
+ for(auto graph : graphs) {
+ auto builder = models::createModelFromOptions(options_, models::usage::translation);
+ Ptr<Scorer> scorer = New<ScorerWrapper>(builder, "", 1.0f, model);
+ scorers.push_back(scorer); // @TODO: should this be done in the contructor?
+ }
+
+ // Set up output file
+ std::string fileName;
+ Ptr<io::TemporaryFile> tempFile;
+
+ if(options_->hasAndNotEmpty("valid-translation-output")) {
+ fileName = options_->get<std::string>("valid-translation-output");
+ // fileName can be a template with fields for training state parameters:
+ fileName = state->fillTemplate(fileName);
+ } else {
+ tempFile.reset(new io::TemporaryFile(options_->get<std::string>("tempdir"), false));
+ fileName = tempFile->getFileName();
+ }
+
+ for(auto graph : graphs)
+ graph->setInference(true);
+
+ if(!quiet_)
+ LOG(info, "Translating validation set...");
+
+ timer::Timer timer;
+ {
+ auto printer = New<OutputPrinter>(options_, vocabs_.back());
+ // @TODO: This can be simplified. If there is no "valid-translation-output", fileName already
+ // contains the name of temporary file that should be used?
+ auto collector = options_->hasAndNotEmpty("valid-translation-output")
+ ? New<OutputCollector>(fileName)
+ : New<OutputCollector>(tempFile->getFileName());
+
+ if(quiet_)
+ collector->setPrintingStrategy(New<QuietPrinting>());
+ else
+ collector->setPrintingStrategy(New<GeometricPrinting>());
+
+ std::deque<Ptr<ExpressionGraph>> graphQueue(graphs.begin(), graphs.end());
+ std::deque<Ptr<Scorer>> scorerQueue(scorers.begin(), scorers.end());
+ auto task = [=, &graphQueue, &scorerQueue](BatchPtr batch) {
+ thread_local Ptr<ExpressionGraph> graph;
+ thread_local Ptr<Scorer> scorer;
+
+ if(!graph) {
+ std::unique_lock<std::mutex> lock(mutex_);
+ ABORT_IF(graphQueue.empty(), "Asking for graph, but none left on queue");
+ graph = graphQueue.front();
+ graphQueue.pop_front();
+
+ ABORT_IF(scorerQueue.empty(), "Asking for scorer, but none left on queue");
+ scorer = scorerQueue.front();
+ scorerQueue.pop_front();
+ }
+
+ auto search = New<BeamSearch>(options_, std::vector<Ptr<Scorer>>{scorer}, vocabs_.back());
+ auto histories = search->search(graph, batch);
+
+ for(auto history : histories) {
+ std::stringstream best1;
+ std::stringstream bestn;
+ printer->print(history, best1, bestn);
+ collector->Write(
+ (long)history->getLineNum(), best1.str(), bestn.str(), options_->get<bool>("n-best"));
+ }
+ };
+
+ threadPool_.reserve(graphs.size());
+ TaskBarrier taskBarrier;
+ for(auto batch : *batchGenerator_)
+ taskBarrier.push_back(threadPool_.enqueue(task, batch));
+ // ~TaskBarrier waits until all are done
+ }
+
+ if(!quiet_)
+ LOG(info, "Total translation time: {:.5f}s", timer.elapsed());
+
+ for(auto graph : graphs)
+ graph->setInference(false);
+
+ float val = 0.0f;
+
+ // Run post-processing script if given
+ if(options_->hasAndNotEmpty("valid-script-path")) {
+ // auto command = options_->get<std::string>("valid-script-path") + " " + fileName;
+ // auto valStr = utils::exec(command);
+ auto valStr = utils::exec(options_->get<std::string>("valid-script-path"),
+ options_->get<std::vector<std::string>>("valid-script-args"),
+ fileName);
+ val = (float)std::atof(valStr.c_str());
+ updateStalled(graphs, val);
+ }
+
+ return val;
+};
+
+///////////////////////////////////////////////////////////////////////////////////////
+BleuValidator::BleuValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options, bool detok)
+ : Validator(vocabs, options, false),
+ detok_(detok),
+ quiet_(options_->get<bool>("quiet-translation")) {
+ // @TODO: remove, only used for saving?
+ builder_ = models::createModelFromOptions(options_, models::usage::translation);
+
+ // @TODO: replace bleu-detok by a separate parameter to enable (various forms of) detok
+ auto vocab = vocabs_.back();
+ ABORT_IF(detok_ && vocab->type() != "SentencePieceVocab" && vocab->type() != "FactoredVocab",
+ "Detokenizing BLEU validator expects the target vocabulary to be SentencePieceVocab or "
+ "FactoredVocab. "
+ "Current vocabulary type is {}",
+ vocab->type());
+
+ createBatchGenerator(/*isTranslating=*/true);
+}
+
+float BleuValidator::validate(const std::vector<Ptr<ExpressionGraph>>& graphs,
+ Ptr<const TrainingState> state) {
+ using namespace data;
+
+ // Generate batches
+ batchGenerator_->prepare();
+
+ // Create scorer
+ auto model = options_->get<std::string>("model");
+
+ // @TODO: check if required - Temporary options for translation
+ auto mopts = New<Options>();
+ mopts->merge(options_);
+ mopts->set("inference", true);
+
+ std::vector<Ptr<Scorer>> scorers;
+ for(auto graph : graphs) {
+ auto builder = models::createModelFromOptions(options_, models::usage::translation);
+ Ptr<Scorer> scorer = New<ScorerWrapper>(builder, "", 1.0f, model);
+ scorers.push_back(scorer);
+ }
+
+ for(auto graph : graphs)
+ graph->setInference(true);
+
+ if(!quiet_)
+ LOG(info, "Translating validation set...");
+
+ // 0: 1-grams matched, 1: 1-grams total,
+ // ...,
+ // 6: 4-grams matched, 7: 4-grams total,
+ // 8: reference length
+ std::vector<float> stats(9, 0.f);
+
+ timer::Timer timer;
+ {
+ auto printer = New<OutputPrinter>(options_, vocabs_.back());
+
+ Ptr<OutputCollector> collector;
+ if(options_->hasAndNotEmpty("valid-translation-output")) {
+ auto fileName = options_->get<std::string>("valid-translation-output");
+ // fileName can be a template with fields for training state parameters:
+ fileName = state->fillTemplate(fileName);
+ collector = New<OutputCollector>(fileName); // for debugging
+ } else {
+ collector = New<OutputCollector>(/* null */); // don't print, but log
+ }
+
+ if(quiet_)
+ collector->setPrintingStrategy(New<QuietPrinting>());
+ else
+ collector->setPrintingStrategy(New<GeometricPrinting>());
+
+ std::deque<Ptr<ExpressionGraph>> graphQueue(graphs.begin(), graphs.end());
+ std::deque<Ptr<Scorer>> scorerQueue(scorers.begin(), scorers.end());
+ auto task = [=, &stats, &graphQueue, &scorerQueue](BatchPtr batch) {
+ thread_local Ptr<ExpressionGraph> graph;
+ thread_local Ptr<Scorer> scorer;
+
+ if(!graph) {
+ std::unique_lock<std::mutex> lock(mutex_);
+ ABORT_IF(graphQueue.empty(), "Asking for graph, but none left on queue");
+ graph = graphQueue.front();
+ graphQueue.pop_front();
+
+ ABORT_IF(scorerQueue.empty(), "Asking for scorer, but none left on queue");
+ scorer = scorerQueue.front();
+ scorerQueue.pop_front();
+ }
+
+ auto search = New<BeamSearch>(options_, std::vector<Ptr<Scorer>>{scorer}, vocabs_.back());
+ auto histories = search->search(graph, batch);
+
+ size_t no = 0;
+ std::lock_guard<std::mutex> statsLock(mutex_);
+ for(auto history : histories) {
+ auto result = history->top();
+ const auto& words = std::get<0>(result);
+ updateStats(stats, words, batch, no, vocabs_.back()->getEosId());
+
+ std::stringstream best1;
+ std::stringstream bestn;
+ printer->print(history, best1, bestn);
+ collector->Write((long)history->getLineNum(),
+ best1.str(),
+ bestn.str(),
+ /*nbest=*/false);
+ no++;
+ }
+ };
+
+ threadPool_.reserve(graphs.size());
+ TaskBarrier taskBarrier;
+ for(auto batch : *batchGenerator_)
+ taskBarrier.push_back(threadPool_.enqueue(task, batch));
+ // ~TaskBarrier waits until all are done
+ }
+
+ if(!quiet_)
+ LOG(info, "Total translation time: {:.5f}s", timer.elapsed());
+
+ for(auto graph : graphs)
+ graph->setInference(false);
+
+ float val = calcBLEU(stats);
+ updateStalled(graphs, val);
+
+ return val;
+}
+
+std::vector<std::string> BleuValidator::decode(const Words& words, bool addEOS) {
+ auto vocab = vocabs_.back();
+ auto tokenString = vocab->surfaceForm(words); // detokenize to surface form
+ tokenString = tokenize(tokenString); // tokenize according to SacreBLEU rules
+ tokenString
+ = tokenizeContinuousScript(tokenString); // CJT scripts only: further break into characters
+ auto tokens = utils::splitAny(tokenString, " ");
+ if(addEOS)
+ tokens.push_back("</s>");
+ return tokens;
+}
+
+void BleuValidator::updateStats(std::vector<float>& stats,
+ const Words& cand,
+ const Ptr<data::Batch> batch,
+ size_t no,
+ Word eos) {
+ auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
+ auto subBatch = corpusBatch->back();
+
+ size_t size = subBatch->batchSize();
+ size_t width = subBatch->batchWidth();
+
+ Words ref; // fill ref
+ for(size_t i = 0; i < width; ++i) {
+ Word w = subBatch->data()[i * size + no];
+ if(w == eos)
+ break;
+ ref.push_back(w);
+ }
+
+ bool detok = detok_;
+#if 1 // hack for now, to get this feature when running under Flo
+ // Problem is that Flo pieces that pass 'bleu' do not know whether vocab is factored,
+ // hence cannot select 'bleu-detok'.
+ // @TODO: We agreed that we will replace bleu-detok by bleu with an additional
+ // parameter to select the detokenization method, which will default to detok for
+ // FactoredSegmenter, and no-op for base vocab.
+ if(vocabs_.back()->type() == "FactoredVocab") {
+ if(!quiet_)
+ LOG_ONCE(info, "[valid] FactoredVocab implies using detokenized BLEU");
+ detok = true; // always use bleu-detok
+ }
+#endif
+ if(detok) { // log the first detokenized string
+ LOG_ONCE(info, "[valid] First sentence's tokens after detokenization, as scored:");
+ LOG_ONCE(info, "[valid] Hyp: {}", utils::join(decode(cand, /*addEOS=*/true)));
+ LOG_ONCE(info, "[valid] Ref: {}", utils::join(decode(ref)));
+ }
+ if(detok)
+ updateStats(stats, decode(cand, /*addEOS=*/true), decode(ref));
+ else
+ updateStats(stats, cand, ref);
+}
+
+float BleuValidator::calcBLEU(const std::vector<float>& stats) {
+ float logbleu = 0;
+ for(int i = 0; i < 8; i += 2) {
+ if(stats[i] == 0.f)
+ return 0.f;
+ logbleu += std::log(stats[i] / stats[i + 1]);
+ }
+
+ logbleu /= 4.f;
+
+ float brev_penalty = 1.f - std::max(stats[8] / stats[1], 1.f);
+ return std::exp(logbleu + brev_penalty) * 100;
+}
+
} // namespace marian
diff --git a/src/training/validator.h b/src/training/validator.h
index efed50cb..1658dff3 100755
--- a/src/training/validator.h
+++ b/src/training/validator.h
@@ -15,6 +15,7 @@
#include "translator/output_collector.h"
#include "translator/output_printer.h"
#include "translator/scorers.h"
+#include "models/bert.h"
#include <cstdio>
#include <cstdlib>
@@ -35,29 +36,23 @@ protected:
public:
ValidatorBase(bool lowerIsBetter) : lowerIsBetter_(lowerIsBetter), lastBest_{initScore()} {}
+ virtual ~ValidatorBase() {}
- virtual float validate(const std::vector<Ptr<ExpressionGraph>>& graphs) = 0;
+ virtual float validate(const std::vector<Ptr<ExpressionGraph>>& graphs,
+ Ptr<const TrainingState> state) = 0;
virtual std::string type() = 0;
float lastBest() { return lastBest_; }
size_t stalled() { return stalled_; }
- virtual float initScore() {
- return lowerIsBetter_ ? std::numeric_limits<float>::max()
- : std::numeric_limits<float>::lowest();
- }
-
- virtual void actAfterLoaded(TrainingState& state) override {
- if(state.validators[type()]) {
- lastBest_ = state.validators[type()]["last-best"].as<float>();
- stalled_ = state.validators[type()]["stalled"].as<size_t>();
- }
- }
+ virtual float initScore();
+ virtual void actAfterLoaded(TrainingState& state) override;
};
-template <class DataSet>
+template <class DataSet, class BuilderType> // @TODO: BuilderType doesn't really serve a purpose here? Review and remove.
class Validator : public ValidatorBase {
public:
+ virtual ~Validator() {}
Validator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options, bool lowerIsBetter = true)
: ValidatorBase(lowerIsBetter),
vocabs_(vocabs),
@@ -65,50 +60,38 @@ public:
options_(New<Options>(options->clone())) {
// set options common for all validators
options_->set("inference", true);
- if(options_->has("valid-max-length"))
+ options_->set("shuffle", "none"); // don't shuffle validation sets
+
+ if(options_->has("valid-max-length")) {
options_->set("max-length", options_->get<size_t>("valid-max-length"));
+ options_->set("max-length-crop", true); // @TODO: make this configureable
+ }
if(options_->has("valid-mini-batch"))
options_->set("mini-batch", options_->get<size_t>("valid-mini-batch"));
options_->set("mini-batch-sort", "src");
options_->set("maxi-batch", 10);
}
-protected:
- void createBatchGenerator(bool isTranslating) {
- // Create the BatchGenerator. Note that ScriptValidator does not use batchGenerator_.
-
- // Update validation options
- auto opts = New<Options>();
- opts->merge(options_);
- opts->set("inference", true);
-
- if (isTranslating) { // TranslationValidator and BleuValidator
- opts->set("max-length", 1000);
- opts->set("mini-batch", options_->get<int>("valid-mini-batch"));
- opts->set("maxi-batch", 10);
- }
- else { // CrossEntropyValidator
- opts->set("max-length", options_->get<size_t>("valid-max-length"));
- if(options_->has("valid-mini-batch"))
- opts->set("mini-batch", options_->get<size_t>("valid-mini-batch"));
- opts->set("mini-batch-sort", "src");
- }
+ typedef typename DataSet::batch_ptr BatchPtr;
+protected:
+ // Create the BatchGenerator. Note that ScriptValidator does not use batchGenerator_.
+ void createBatchGenerator(bool /*isTranslating*/) {
// Create corpus
auto validPaths = options_->get<std::vector<std::string>>("valid-sets");
auto corpus = New<DataSet>(validPaths, vocabs_, options_);
// Create batch generator
- batchGenerator_ = New<data::BatchGenerator<DataSet>>(corpus, opts);
+ batchGenerator_ = New<data::BatchGenerator<DataSet>>(corpus, options_);
}
public:
- virtual float validate(const std::vector<Ptr<ExpressionGraph>>& graphs) override {
-
+ virtual float validate(const std::vector<Ptr<ExpressionGraph>>& graphs,
+ Ptr<const TrainingState> /*ignored*/) override {
for(auto graph : graphs)
graph->setInference(true);
- batchGenerator_->prepare(false);
+ batchGenerator_->prepare();
// Validate on batches
float val = validateBG(graphs);
@@ -123,7 +106,7 @@ public:
protected:
std::vector<Ptr<Vocab>> vocabs_;
Ptr<Options> options_;
- Ptr<models::ModelBase> builder_;
+ Ptr<BuilderType> builder_; // @TODO: remove, this is not guaranteed to be state-free, hence not thread-safe, but we are using validators with multi-threading.
Ptr<data::BatchGenerator<DataSet>> batchGenerator_;
virtual float validateBG(const std::vector<Ptr<ExpressionGraph>>&)
@@ -137,108 +120,72 @@ protected:
lastBest_ = val;
if(options_->get<bool>("keep-best"))
keepBest(graphs);
- } else {
+ } else /* if (lastBest_ != val) */ { // (special case 0 at start) @TODO: needed? Seems stall count gets reset each time it does improve. If not needed, remove "if(...)" again.
stalled_++;
}
}
virtual void keepBest(const std::vector<Ptr<ExpressionGraph>>& graphs) {
auto model = options_->get<std::string>("model");
- builder_->save(graphs[0], model + ".best-" + type() + ".npz", true);
+ std::string suffix = model.substr(model.size() - 4);
+ ABORT_IF(suffix != ".npz" && suffix != ".bin", "Unknown model suffix {}", suffix);
+
+ builder_->save(graphs[0], model + ".best-" + type() + suffix, true);
}
};
-class CrossEntropyValidator : public Validator<data::Corpus> {
+class CrossEntropyValidator : public Validator<data::Corpus, models::ICriterionFunction> {
+ using Validator::BatchPtr;
+
public:
- CrossEntropyValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
- : Validator(vocabs, options) {
- createBatchGenerator(/*isTranslating=*/false);
-
- // @TODO: check if this is required.
- Ptr<Options> opts = New<Options>();
- opts->merge(options);
- opts->set("inference", true);
- opts->set("cost-type", "ce-sum");
- builder_ = models::from_options(opts, models::usage::scoring);
- }
+ CrossEntropyValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options);
+ virtual ~CrossEntropyValidator() {}
std::string type() override { return options_->get<std::string>("cost-type"); }
protected:
- virtual float validateBG(const std::vector<Ptr<ExpressionGraph>>& graphs) override {
-
- auto ctype = options_->get<std::string>("cost-type");
- options_->set("cost-type", "ce-sum");
-
- float cost = 0;
- size_t samples = 0;
- size_t words = 0;
- size_t batchId = 0;
-
- {
- threadPool_.reserve(graphs.size());
-
- TaskBarrier taskBarrier;
- for(auto batch : *batchGenerator_) {
- auto task = [=, &cost, &samples, &words](size_t id) {
- thread_local Ptr<ExpressionGraph> graph;
- thread_local auto builder = models::from_options(options_, models::usage::scoring);
+ virtual float validateBG(const std::vector<Ptr<ExpressionGraph>>& graphs) override;
+};
- if(!graph) {
- graph = graphs[id % graphs.size()];
- }
+// Used for validating with classifiers. Compute prediction accuracy versus ground truth for a set of classes
+class AccuracyValidator : public Validator<data::Corpus, models::IModel> {
+public:
+ AccuracyValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options);
+ virtual ~AccuracyValidator() {}
- builder->clear(graph);
- auto costNode = builder->build(graph, batch);
- graph->forward();
+ std::string type() override { return "accuracy"; }
- std::unique_lock<std::mutex> lock(mutex_);
- cost += costNode->scalar();
- samples += batch->size();
- words += batch->back()->batchWords();
- };
+protected:
+ virtual float validateBG(const std::vector<Ptr<ExpressionGraph>>& graphs) override;
+};
- taskBarrier.push_back(threadPool_.enqueue(task, batchId));
- batchId++;
- }
- // ~TaskBarrier waits until all are done
- }
+class BertAccuracyValidator : public Validator<data::Corpus, models::IModel> {
+private:
+ bool evalMaskedLM_{true};
- // get back to the original cost type
- options_->set("cost-type", ctype);
+public:
+ BertAccuracyValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options, bool evalMaskedLM);
+ virtual ~BertAccuracyValidator() {}
- if(ctype == "perplexity")
- return std::exp(cost / words);
- if(ctype == "ce-mean-words")
- return cost / words;
- if(ctype == "ce-sum")
- return cost;
+ std::string type() override {
+ if(evalMaskedLM_)
+ return "bert-lm-accuracy";
else
- return cost / samples;
+ return "bert-sentence-accuracy";
}
-};
-
-class ScriptValidator : public Validator<data::Corpus> {
-public:
- ScriptValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
- : Validator(vocabs, options, false) {
- builder_ = models::from_options(options_, models::usage::raw);
- ABORT_IF(!options_->has("valid-script-path"), "valid-script metric but no script given");
- }
+protected:
+ virtual float validateBG(const std::vector<Ptr<ExpressionGraph>>& graphs) override;
+};
- virtual float validate(const std::vector<Ptr<ExpressionGraph>>& graphs) override {
- using namespace data;
- auto model = options_->get<std::string>("model");
- builder_->save(graphs[0], model + ".dev.npz", true);
- auto command = options_->get<std::string>("valid-script-path");
- auto valStr = utils::exec(command);
- float val = (float)std::atof(valStr.c_str());
- updateStalled(graphs, val);
+class ScriptValidator : public Validator<data::Corpus, models::IModel> {
+public:
+ ScriptValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options);
+ virtual ~ScriptValidator() {}
- return val;
- };
+ virtual float validate(const std::vector<Ptr<ExpressionGraph>>& graphs,
+ Ptr<const TrainingState> /*ignored*/) override;
std::string type() override { return "valid-script"; }
@@ -248,125 +195,14 @@ protected:
}
};
-class TranslationValidator : public Validator<data::Corpus> {
+// validator that translates and computes BLEU (or any metric) with an external script
+class TranslationValidator : public Validator<data::Corpus, models::IModel> {
public:
- TranslationValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options)
- : Validator(vocabs, options, false),
- quiet_(options_->get<bool>("quiet-translation")) {
- builder_ = models::from_options(options_, models::usage::translation);
-
- if(!options_->has("valid-script-path"))
- LOG_VALID(warn,
- "No post-processing script given for validating translator");
-
- createBatchGenerator(/*isTranslating=*/true);
- }
-
- virtual float validate(const std::vector<Ptr<ExpressionGraph>>& graphs) override {
- using namespace data;
-
- // Generate batches
- batchGenerator_->prepare(false);
-
- // Create scorer
- auto model = options_->get<std::string>("model");
-
- // Temporary options for translation
- auto mopts = New<Options>();
- mopts->merge(options_);
- mopts->set("inference", true);
-
- std::vector<Ptr<Scorer>> scorers;
- for(auto graph : graphs) {
- auto builder = models::from_options(options_, models::usage::translation);
- Ptr<Scorer> scorer = New<ScorerWrapper>(builder, "", 1.0f, model);
- scorers.push_back(scorer);
- }
-
- // Set up output file
- std::string fileName;
- Ptr<io::TemporaryFile> tempFile;
-
- if(options_->has("valid-translation-output")) {
- fileName = options_->get<std::string>("valid-translation-output");
- } else {
- tempFile.reset(new io::TemporaryFile(options_->get<std::string>("tempdir"), false));
- fileName = tempFile->getFileName();
- }
-
- for(auto graph : graphs)
- graph->setInference(true);
-
- if(!quiet_)
- LOG(info, "Translating validation set...");
-
- timer::Timer timer;
- {
- auto printer = New<OutputPrinter>(options_, vocabs_.back());
- auto collector = options_->has("valid-translation-output")
- ? New<OutputCollector>(fileName)
- : New<OutputCollector>(*tempFile);
-
- if(quiet_)
- collector->setPrintingStrategy(New<QuietPrinting>());
- else
- collector->setPrintingStrategy(New<GeometricPrinting>());
-
- threadPool_.reserve(graphs.size());
-
- size_t sentenceId = 0;
- TaskBarrier taskBarrier;
- for(auto batch : *batchGenerator_) {
- auto task = [=](size_t id) {
- thread_local Ptr<ExpressionGraph> graph;
- thread_local Ptr<Scorer> scorer;
-
- if(!graph) {
- graph = graphs[id % graphs.size()];
- scorer = scorers[id % graphs.size()];
- }
-
- auto search = New<BeamSearch>(options_,
- std::vector<Ptr<Scorer>>{scorer},
- vocabs_.back()->getEosId(),
- vocabs_.back()->getUnkId());
- auto histories = search->search(graph, batch);
-
- for(auto history : histories) {
- std::stringstream best1;
- std::stringstream bestn;
- printer->print(history, best1, bestn);
- collector->Write((long)history->GetLineNum(),
- best1.str(),
- bestn.str(),
- options_->get<bool>("n-best"));
- }
- };
-
- taskBarrier.push_back(threadPool_.enqueue(task, sentenceId));
- sentenceId++;
- }
- // ~TaskBarrier waits until all are done
- }
+ TranslationValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options);
+ virtual ~TranslationValidator() {}
- if(!quiet_)
- LOG(info, "Total translation time: {:.5f}s", timer.elapsed());
-
- for(auto graph : graphs)
- graph->setInference(false);
-
- float val = 0.0f;
-
- // Run post-processing script if given
- if(options_->has("valid-script-path")) {
- auto command = options_->get<std::string>("valid-script-path") + " " + fileName;
- auto valStr = utils::exec(command);
- val = (float)std::atof(valStr.c_str());
- updateStalled(graphs, val);
- }
-
- return val;
- };
+ virtual float validate(const std::vector<Ptr<ExpressionGraph>>& graphs,
+ Ptr<const TrainingState> state) override;
std::string type() override { return "translation"; }
@@ -378,146 +214,22 @@ protected:
}
};
+// validator that translates and computes BLEU internally, with or without decoding
// @TODO: combine with TranslationValidator (above) to avoid code duplication
-class BleuValidator : public Validator<data::Corpus> {
-private:
- bool detok_{false};
-
+class BleuValidator : public Validator<data::Corpus, models::IModel> {
public:
- BleuValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options, bool detok = false)
- : Validator(vocabs, options, false),
- detok_(detok),
- quiet_(options_->get<bool>("quiet-translation")) {
- builder_ = models::from_options(options_, models::usage::translation);
-
-#ifdef USE_SENTENCEPIECE
- auto vocab = vocabs_.back();
- ABORT_IF(detok_ && vocab->type() != "SentencePieceVocab",
- "Detokenizing BLEU validator expects the target vocabulary to be SentencePieceVocab. "
- "Current vocabulary type is {}", vocab->type());
-#else
- ABORT_IF(detok_,
- "Detokenizing BLEU validator expects the target vocabulary to be SentencePieceVocab. "
- "Marian has not been compiled with SentencePieceVocab support");
-#endif
-
- createBatchGenerator(/*isTranslating=*/true);
- }
-
- virtual float validate(const std::vector<Ptr<ExpressionGraph>>& graphs) override {
- using namespace data;
-
- // Generate batches
- batchGenerator_->prepare(false);
-
- // Create scorer
- auto model = options_->get<std::string>("model");
-
- // @TODO: check if required - Temporary options for translation
- auto mopts = New<Options>();
- mopts->merge(options_);
- mopts->set("inference", true);
-
- std::vector<Ptr<Scorer>> scorers;
- for(auto graph : graphs) {
- auto builder = models::from_options(options_, models::usage::translation);
- Ptr<Scorer> scorer = New<ScorerWrapper>(builder, "", 1.0f, model);
- scorers.push_back(scorer);
- }
-
- for(auto graph : graphs)
- graph->setInference(true);
-
- if(!quiet_)
- LOG(info, "Translating validation set...");
-
- // 0: 1-grams matched, 1: 1-grams total,
- // ...,
- // 6: 4-grams matched, 7: 4-grams total,
- // 8: reference length
- std::vector<float> stats(9, 0.f);
-
- timer::Timer timer;
- {
- auto printer = New<OutputPrinter>(options_, vocabs_.back());
-
- Ptr<OutputCollector> collector;
- if(options_->has("valid-translation-output")) {
- auto fileName = options_->get<std::string>("valid-translation-output");
- collector = New<OutputCollector>(fileName); // for debugging
- }
- else {
- collector = New<OutputCollector>(/* null */); // don't print, but log
- }
+ BleuValidator(std::vector<Ptr<Vocab>> vocabs, Ptr<Options> options, bool detok = false);
+ virtual ~BleuValidator() {}
- if(quiet_)
- collector->setPrintingStrategy(New<QuietPrinting>());
- else
- collector->setPrintingStrategy(New<GeometricPrinting>());
-
- threadPool_.reserve(graphs.size());
-
- size_t sentenceId = 0;
- TaskBarrier taskBarrier;
- for(auto batch : *batchGenerator_) {
- auto task = [=, &stats](size_t id) {
- thread_local Ptr<ExpressionGraph> graph;
- thread_local Ptr<Scorer> scorer;
-
- if(!graph) {
- graph = graphs[id % graphs.size()];
- scorer = scorers[id % graphs.size()];
- }
-
- auto search = New<BeamSearch>(options_,
- std::vector<Ptr<Scorer>>{scorer},
- vocabs_.back()->getEosId(),
- vocabs_.back()->getUnkId());
- auto histories = search->search(graph, batch);
-
- size_t no = 0;
- std::lock_guard<std::mutex> statsLock(mutex_);
- for(auto history : histories) {
- auto result = history->Top();
- const auto& words = std::get<0>(result);
- updateStats(stats, words, batch, no, vocabs_.back()->getEosId());
-
- std::stringstream best1;
- std::stringstream bestn;
- printer->print(history, best1, bestn);
- collector->Write((long)history->GetLineNum(),
- best1.str(),
- bestn.str(),
- /*nbest=*/ false);
- no++;
- }
- };
-
- taskBarrier.push_back(threadPool_.enqueue(task, sentenceId));
- sentenceId++;
- }
- // ~TaskBarrier waits until all are done
- }
-
- if(!quiet_)
- LOG(info, "Total translation time: {:.5f}s", timer.elapsed());
-
- for(auto graph : graphs)
- graph->setInference(false);
-
- float val = calcBLEU(stats);
- updateStalled(graphs, val);
-
- return val;
- };
+ virtual float validate(const std::vector<Ptr<ExpressionGraph>>& graphs,
+ Ptr<const TrainingState> state) override;
+ // @TODO: why do we return this string, but not pass it to the constructor?
std::string type() override { return detok_ ? "bleu-detok" : "bleu"; }
protected:
- bool quiet_{false};
-
// Tokenizer function adapted from multi-bleu-detok.pl, corresponds to sacreBLEU.py
- std::string tokenize(const std::string& text) {
+ static std::string tokenize(const std::string& text) {
std::string normText = text;
// language-independent part:
@@ -542,14 +254,30 @@ protected:
return normText;
}
- std::vector<std::string> decode(const Words& words, bool addEOS = false) {
- auto vocab = vocabs_.back();
- auto tokens = utils::splitAny(tokenize(vocab->decode(words)), " ");
- if(addEOS)
- tokens.push_back("</s>");
- return tokens;
+ static std::string tokenizeContinuousScript(const std::string& sUTF8) {
+ // We want BLEU-like scores that are comparable across different tokenization schemes.
+ // For continuous scripts (Chinese, Japanese, Thai), we would need a language-specific
+ // statistical word segmenter, which is outside the scope of Marian. As a practical
+ // compromise, we segment continuous-script sequences into individual characters, while
+ // leaving Western scripts as words. This way we can use the same settings for Western
+ // languages, where Marian would report SacreBLEU scores, and Asian languages, where
+ // scores are not standard but internally comparable across tokenization schemes.
+ // @TODO: Check what sacrebleu.py is doing, and whether we can replicate that here faithfully.
+ auto in = utils::utf8ToUnicodeString(sUTF8);
+ auto out = in.substr(0, 0); // (out should be same type as in, don't want to bother with exact type)
+ for (auto c : in) {
+ auto isCS = utils::isContinuousScript(c);
+ if (isCS) // surround continuous-script chars by spaces on each side
+ out.push_back(' '); // (duplicate spaces are ignored when splitting later)
+ out.push_back(c);
+ if (isCS)
+ out.push_back(' ');
+ }
+ return utils::utf8FromUnicodeString(out);
}
+ std::vector<std::string> decode(const Words& words, bool addEOS = false);
+
// Update document-wide sufficient statistics for BLEU with single sentence n-gram stats.
template <typename T>
void updateStats(std::vector<float>& stats,
@@ -593,45 +321,17 @@ protected:
const Words& cand,
const Ptr<data::Batch> batch,
size_t no,
- Word eos) {
-
- auto corpusBatch = std::static_pointer_cast<data::CorpusBatch>(batch);
- auto subBatch = corpusBatch->back();
-
- size_t size = subBatch->batchSize();
- size_t width = subBatch->batchWidth();
-
- Words ref; // fill ref
- for(size_t i = 0; i < width; ++i) {
- Word w = subBatch->data()[i * size + no];
- if(w == eos)
- break;
- ref.push_back(w);
- }
-
- if(detok_)
- updateStats(stats, decode(cand, /*addEOS=*/ true), decode(ref));
- else
- updateStats(stats, cand, ref);
- }
-
- float calcBLEU(const std::vector<float>& stats) {
- float logbleu = 0;
- for(int i = 0; i < 8; i += 2) {
- if(stats[i] == 0.f)
- return 0.f;
- logbleu += std::log(stats[i] / stats[i + 1]);
- }
+ Word eos);
- logbleu /= 4.f;
-
- float brev_penalty = 1.f - std::max(stats[8] / stats[1], 1.f);
- return std::exp(logbleu + brev_penalty) * 100;
- }
+ float calcBLEU(const std::vector<float>& stats);
virtual float validateBG(const std::vector<Ptr<ExpressionGraph>>& /*graphs*/) override {
return 0;
}
+
+private:
+ bool detok_;
+ bool quiet_{ false };
};
/**
@@ -645,7 +345,7 @@ protected:
*
* @return Vector of validator objects
*/
-std::vector<Ptr<Validator<data::Corpus>>> Validators(
+std::vector<Ptr<ValidatorBase/*<data::Corpus>*/>> Validators(
std::vector<Ptr<Vocab>> vocabs,
Ptr<Options> config);
} // namespace marian
diff --git a/src/translator/beam_search.h b/src/translator/beam_search.h
index 798312fe..4ea17d8b 100755
--- a/src/translator/beam_search.h
+++ b/src/translator/beam_search.h
@@ -4,6 +4,7 @@
#include "marian.h"
#include "translator/history.h"
#include "translator/scorers.h"
+#include "data/factored_vocab.h"
#include "translator/helpers.h"
#include "translator/nth_element.h"
@@ -15,100 +16,198 @@ private:
Ptr<Options> options_;
std::vector<Ptr<Scorer>> scorers_;
size_t beamSize_;
- Word trgEosId_ = (Word)-1;
- Word trgUnkId_ = (Word)-1;
+ Ptr<const Vocab> trgVocab_;
+
+ const float INVALID_PATH_SCORE = std::numeric_limits<float>::lowest(); // @TODO: observe this closely
+ const bool PURGE_BATCH = true; // @TODO: diagnostic, to-be-removed once confirmed there are no issues.
public:
BeamSearch(Ptr<Options> options,
const std::vector<Ptr<Scorer>>& scorers,
- Word trgEosId,
- Word trgUnkId = -1)
+ const Ptr<const Vocab> trgVocab)
: options_(options),
scorers_(scorers),
- beamSize_(options_->has("beam-size")
- ? options_->get<size_t>("beam-size")
- : 3),
- trgEosId_(trgEosId),
- trgUnkId_(trgUnkId) {}
-
- Beams toHyps(const std::vector<unsigned int> keys,
- const std::vector<float> pathScores,
- size_t vocabSize,
+ beamSize_(options_->get<size_t>("beam-size")),
+ trgVocab_(trgVocab) {}
+
+ // combine new expandedPathScores and previous beams into new set of beams
+ Beams toHyps(const std::vector<unsigned int>& nBestKeys, // [currentDimBatch, beamSize] flattened -> ((batchIdx, beamHypIdx) flattened, word idx) flattened
+ const std::vector<float>& nBestPathScores, // [currentDimBatch, beamSize] flattened
+ const size_t nBestBeamSize, // for interpretation of nBestKeys
+ const size_t vocabSize, // ditto.
const Beams& beams,
- std::vector<Ptr<ScorerState>>& states,
- size_t beamSize,
- bool first,
- Ptr<data::CorpusBatch> batch) {
- Beams newBeams(beams.size());
+ const std::vector<Ptr<ScorerState /*const*/>>& states,
+ Ptr<data::CorpusBatch /*const*/> batch, // for alignments only
+ Ptr<FactoredVocab/*const*/> factoredVocab, size_t factorGroup,
+ const std::vector<bool>& dropBatchEntries, // [origDimBatch] - empty source batch entries are marked with true, should be cleared after first use.
+ const std::vector<IndexType>& batchIdxMap) const { // [origBatchIdx -> currentBatchIdx]
+ std::vector<float> align; // collects alignment information from the last executed time step
+ if(options_->hasAndNotEmpty("alignment") && factorGroup == 0)
+ align = scorers_[0]->getAlignment(); // [beam depth * max src length * current batch size] -> P(s|t); use alignments from the first scorer, even if ensemble,
+
+ const auto origDimBatch = beams.size(); // see function search for definition of origDimBatch and currentDimBatch etc.
+ Beams newBeams(origDimBatch); // return value of this function goes here. There are always origDimBatch beams.
+
+ // create a reverse batchMap to obtain original batchIdx in the starting batch size
+ // and calculate the current batch size based on non-empty beams
+ std::vector<IndexType> reverseBatchIdxMap; // empty if not purging batch entries
+ size_t currentDimBatch = beams.size();
+ if(PURGE_BATCH) {
+ reverseBatchIdxMap.resize(batchIdxMap.size()); // adjust size if doing batch purging.
+ currentDimBatch = 0;
+ for(int i = 0; i < batchIdxMap.size(); ++i) {
+ reverseBatchIdxMap[batchIdxMap[i]] = i; // reverse batch index mapping, multiple occurences get overwritten with the last one,
+ // which is expected due to down-shifting
+ if(!beams[i].empty())
+ currentDimBatch++;
+ }
+ }
- std::vector<float> align;
- if(options_->has("alignment"))
- // Use alignments from the first scorer, even if ensemble
- align = scorers_[0]->getAlignment();
-
- for(size_t i = 0; i < keys.size(); ++i) {
- // Keys contains indices to vocab items in the entire beam.
- // Values can be between 0 and beamSize * vocabSize.
- Word embIdx = (Word)(keys[i] % vocabSize);
- auto beamIdx = i / beamSize;
-
- // Retrieve short list for final softmax (based on words aligned
- // to source sentences). If short list has been set, map the indices
- // in the sub-selected vocabulary matrix back to their original positions.
- auto shortlist = scorers_[0]->getShortlist();
- if(shortlist)
- embIdx = shortlist->reverseMap(embIdx); // @TODO: should reverseMap accept a size_t or a Word?
-
- if(newBeams[beamIdx].size() < beams[beamIdx].size()) {
- auto& beam = beams[beamIdx];
- auto& newBeam = newBeams[beamIdx];
-
- auto hypIdx = (IndexType)(keys[i] / vocabSize);
- float pathScore = pathScores[i];
-
- auto hypIdxTrans
- = IndexType((hypIdx / beamSize) + (hypIdx % beamSize) * beams.size());
- if(first)
- hypIdxTrans = hypIdx;
-
- size_t beamHypIdx = hypIdx % beamSize;
- if(beamHypIdx >= (int)beam.size())
- beamHypIdx = beamHypIdx % beam.size();
-
- if(first)
- beamHypIdx = 0;
-
- auto hyp = New<Hypothesis>(beam[beamHypIdx], embIdx, hypIdxTrans, pathScore);
-
- // Set score breakdown for n-best lists
- if(options_->get<bool>("n-best")) {
- std::vector<float> breakDown(states.size(), 0);
- beam[beamHypIdx]->GetScoreBreakdown().resize(states.size(), 0);
- for(size_t j = 0; j < states.size(); ++j) {
- size_t key = embIdx + hypIdxTrans * vocabSize;
- breakDown[j] = states[j]->breakDown(key)
- + beam[beamHypIdx]->GetScoreBreakdown()[j];
- }
- hyp->GetScoreBreakdown() = breakDown;
+ for(size_t i = 0; i < nBestKeys.size(); ++i) { // [currentDimBatch, beamSize] flattened
+ // Keys encode batchIdx, beamHypIdx, and word index in the entire beam.
+ // They can be between 0 and (vocabSize * nBestBeamSize * batchSize)-1.
+ // (beamHypIdx refers to the GPU tensors, *not* the beams[] array; they are not the same in case of purging)
+ const auto key = nBestKeys[i];
+
+ // decompose key into individual indices (batchIdx, beamHypIdx, wordIdx)
+ const auto beamHypIdx = (key / vocabSize) % nBestBeamSize;
+ const auto currentBatchIdx = (key / vocabSize) / nBestBeamSize;
+ const auto origBatchIdx = reverseBatchIdxMap.empty() ? currentBatchIdx : reverseBatchIdxMap[currentBatchIdx]; // map currentBatchIdx back into original position within starting maximal batch size, required to find correct beam
+
+ bool dropHyp = !dropBatchEntries.empty() && dropBatchEntries[origBatchIdx] && factorGroup == 0;
+
+ WordIndex wordIdx;
+ if(dropHyp) { // if we force=drop the hypothesis, assign EOS, otherwise the expected word id.
+ if(factoredVocab) { // when using factoredVocab, extract the EOS lemma index from the word id, we predicting factors one by one here, hence lemma only
+ std::vector<size_t> eosFactors;
+ factoredVocab->word2factors(factoredVocab->getEosId(), eosFactors);
+ wordIdx = (WordIndex)eosFactors[0];
+ } else { // without factoredVocab lemma index and word index are the same. Safe cruising.
+ wordIdx = trgVocab_->getEosId().toWordIndex();
}
+ } else { // we are not dropping anything, just assign the normal index
+ wordIdx = (WordIndex)(key % vocabSize);
+ }
- // Set alignments
- if(!align.empty()) {
- hyp->SetAlignment(
- getAlignmentsForHypothesis(align, batch, (int)beamHypIdx, (int)beamIdx));
+ // @TODO: We currently assign a log probability of 0 to all beam entries of the dropped batch entry, instead it might be a good idea to use
+ // the per Hyp pathScore without the current expansion (a bit hard to obtain).
+ // For the case where we drop empty inputs, 0 is fine. For other use cases like a forced stop, the penultimate pathScore might be better.
+ // For the empty hyp this would naturally result in 0, too.
+ const float pathScore = dropHyp ? 0.f : nBestPathScores[i]; // 0 (Prob = 1, maximum score) if dropped or expanded path score for (batchIdx, beamHypIdx, word)
+
+ const auto& beam = beams[origBatchIdx];
+ auto& newBeam = newBeams[origBatchIdx]; // extended hypotheses are going to be placed in this new beam
+
+ if(newBeam.size() >= beam.size()) // getNBestList() generates N for all batch entries incl. those that already have a narrower beam
+ continue;
+ if(pathScore == INVALID_PATH_SCORE) // (dummy slot or word that cannot be expanded by current factor)
+ continue;
+
+ ABORT_IF(pathScore < INVALID_PATH_SCORE, "Actual pathScore ({}) is lower than INVALID_PATH_SCORE ({})??", pathScore, INVALID_PATH_SCORE); // This should not happen in valid situations. Currently the only smaller value would be -inf (effect of overflow in summation?)
+ ABORT_IF(beamHypIdx >= beam.size(), "Out of bounds beamHypIdx??"); // effectively this is equivalent to ABORT_IF(beams[origBatchIdx].empty(), ...)
+
+ // map wordIdx to word
+ auto prevBeamHypIdx = beamHypIdx; // back pointer
+ auto prevHyp = beam[prevBeamHypIdx];
+ Word word;
+ // If short list has been set, then wordIdx is an index into the short-listed word set,
+ // rather than the true word index.
+ auto shortlist = scorers_[0]->getShortlist();
+ if (factoredVocab) {
+ // For factored decoding, the word is built over multiple decoding steps,
+ // starting with the lemma, then adding factors one by one.
+ if (factorGroup == 0) {
+ word = factoredVocab->lemma2Word(shortlist ? shortlist->reverseMap(wordIdx) : wordIdx); // @BUGBUG: reverseMap is only correct if factoredVocab_->getGroupRange(0).first == 0
+ std::vector<size_t> factorIndices; factoredVocab->word2factors(word, factorIndices);
+ //LOG(info, "{} + {} ({}) -> {} -> {}",
+ // factoredVocab->decode(prevHyp->tracebackWords()),
+ // factoredVocab->word2string(word), factorIndices[0], prevHyp->getPathScore(), pathScore);
+ }
+ else {
+ //LOG(info, "{} |{} ({}) = {} ({}) -> {} -> {}",
+ // factoredVocab->decodeForDiagnostics(beam[beamHypIdx]->tracebackWords()),
+ // factoredVocab->getFactorGroupPrefix(factorGroup), factorGroup,
+ // factoredVocab->getFactorName(factorGroup, wordIdx), wordIdx,
+ // prevHyp->getPathScore(), pathScore);
+ word = beam[beamHypIdx]->getWord();
+ ABORT_IF(!factoredVocab->canExpandFactoredWord(word, factorGroup),
+ "A word without this factor snuck through to here??");
+ word = factoredVocab->expandFactoredWord(word, factorGroup, wordIdx);
+ prevBeamHypIdx = prevHyp->getPrevStateIndex();
+ prevHyp = prevHyp->getPrevHyp(); // short-circuit the backpointer, so that the traceback does not contain partially factored words
+ }
+ }
+ else if (shortlist)
+ word = Word::fromWordIndex(shortlist->reverseMap(wordIdx));
+ else
+ word = Word::fromWordIndex(wordIdx);
+
+ auto hyp = Hypothesis::New(prevHyp, word, prevBeamHypIdx, pathScore);
+
+ // Set score breakdown for n-best lists
+ if(options_->get<bool>("n-best")) {
+ auto breakDown = beam[beamHypIdx]->getScoreBreakdown();
+ ABORT_IF(factoredVocab && factorGroup > 0 && !factoredVocab->canExpandFactoredWord(word, factorGroup),
+ "A word without this factor snuck through to here??");
+ breakDown.resize(states.size(), 0); // at start, this is empty, so this will set the initial score to 0
+ for(size_t j = 0; j < states.size(); ++j) {
+ auto lval = states[j]->getLogProbs().getFactoredLogitsTensor(factorGroup); // [maxBeamSize, 1, currentDimBatch, dimFactorVocab]
+ // The flatting happens based on actual (current) batch size and batch index computed with batch-pruning as we are looking into the pruned tensor
+ size_t flattenedLogitIndex = (beamHypIdx * currentDimBatch + currentBatchIdx) * vocabSize + wordIdx; // (beam idx, batch idx, word idx); note: beam and batch are transposed, compared to 'key'
+
+ // @TODO: use a function on shape() to index, or new method val->at({i1, i2, i3, i4}) with broadcasting
+ ABORT_IF(lval->shape() != Shape({(int)nBestBeamSize, 1, (int)currentDimBatch, (int)vocabSize}) &&
+ (beamHypIdx == 0 && lval->shape() != Shape({1, 1, (int)currentDimBatch, (int)vocabSize})),
+ "Unexpected shape of logits?? {} != {}", lval->shape(), Shape({(int)nBestBeamSize, 1, (int)currentDimBatch, (int)vocabSize}));
+
+ breakDown[j] += lval->get(flattenedLogitIndex);
}
+ hyp->setScoreBreakdown(breakDown);
+ }
- newBeam.push_back(hyp);
+ // Set alignments
+ if(!align.empty())
+ hyp->setAlignment(getAlignmentsForHypothesis(align, batch, (int)beamHypIdx, (int)currentBatchIdx, (int)origBatchIdx, (int)currentDimBatch));
+ else // not first factor: just copy
+ hyp->setAlignment(beam[beamHypIdx]->getAlignment());
+
+ newBeam.push_back(hyp);
+ }
+
+ // if factored vocab and this is not the first factor, we need to
+ // also propagate factored hypotheses that do not get expanded in this step because they don't have this factor
+ if (factorGroup > 0) {
+ for (size_t batchIdx = 0; batchIdx < beams.size(); batchIdx++) {
+ const auto& beam = beams[batchIdx];
+ auto& newBeam = newBeams[batchIdx];
+ for (const auto& beamHyp : beam) {
+ auto word = beamHyp->getWord();
+ //LOG(info, "Checking {}", factoredVocab->word2string(word));
+ if (factoredVocab->canExpandFactoredWord(word, factorGroup)) // handled above
+ continue;
+ //LOG(info, "Forwarded {}", factoredVocab->word2string(word));
+ newBeam.push_back(beamHyp);
+ }
+ if (newBeam.size() > beam.size()) {
+ //LOG(info, "Size {}, sorting...", newBeam.size());
+ std::nth_element(newBeam.begin(), newBeam.begin() + beam.size(), newBeam.end(), [](Hypothesis::PtrType a, Hypothesis::PtrType b) {
+ return a->getPathScore() > b->getPathScore(); // (sort highest score first)
+ });
+ //LOG(info, "Size {}, sorted...", newBeam.size());
+ newBeam.resize(beam.size());
+ }
}
}
return newBeams;
}
- std::vector<float> getAlignmentsForHypothesis(
- const std::vector<float> alignAll,
+ std::vector<float> getAlignmentsForHypothesis( // -> P(s|t) for current t and given beam and batch dim
+ const std::vector<float> alignAll, // [beam depth, max src length, batch size, 1], flattened vector of all attention probablities
Ptr<data::CorpusBatch> batch,
int beamHypIdx,
- int beamIdx) {
+ int currentBatchIdx,
+ int origBatchIdx,
+ int currentDimBatch) const {
// Let's B be the beam size, N be the number of batched sentences,
// and L the number of words in the longest sentence in the batch.
// The alignment vector:
@@ -126,182 +225,316 @@ public:
// in a single beam, i.e.:
// * [word1-batch1, word1-batch2, ..., word2-batch1, ...]
//
- size_t batchSize = batch->size();
- size_t batchWidth = batch->width() * batchSize;
+
+ size_t origDimBatch = batch->size(); // number of sentences in batch
+ size_t batchWidth = batch->width(); // max src length
+
+ // loop over words of batch entry 'currentBatchIdx' and beam entry 'beamHypIdx'
std::vector<float> align;
+ for(size_t srcPos = 0; srcPos < batchWidth; ++srcPos) { // loop over source positions
+ // We are looking into the probabilites from an actual tensor, hence we need to use currentDimBatch and currentBatchIdx.
+ size_t currentAttIdx = (batchWidth * beamHypIdx + srcPos) * currentDimBatch + currentBatchIdx; // = flatten [beam index, s, batch index, 0]
- for(size_t w = 0; w < batchWidth / batchSize; ++w) {
- size_t a = ((batchWidth * beamHypIdx) + beamIdx) + (batchSize * w);
- size_t m = a % batchWidth;
- if(batch->front()->mask()[m] != 0)
- align.emplace_back(alignAll[a]);
- }
+ // We are looking into the mask from the orginal batch, hence we need to use origDmBatch and origBatchIdx.
+ size_t origAttIdx = (batchWidth * beamHypIdx + srcPos) * origDimBatch + origBatchIdx;; // = flatten [beam index, s, batch index, 0]
+ size_t origMaskIdx = origAttIdx % (batchWidth * origDimBatch); // == batchIdx + (batchSize * srcPos) = flatten [0, s, batch index, 0]
+ // If the original position is not masked out used the corresponding current attention score.
+ if(batch->front()->mask()[origMaskIdx] != 0)
+ align.emplace_back(alignAll[currentAttIdx]);
+ }
return align;
}
- Beams pruneBeam(const Beams& beams) {
+ // remove all beam entries that have reached EOS
+ Beams purgeBeams(const Beams& beams, /*in/out=*/std::vector<IndexType>& batchIdxMap) {
+ const auto trgEosId = trgVocab_->getEosId();
Beams newBeams;
+ size_t beamIdx = 0; // beam index
for(auto beam : beams) {
- Beam newBeam;
- for(auto hyp : beam) {
- if(hyp->GetWord() != trgEosId_) {
- newBeam.push_back(hyp);
- }
+ Beam newBeam; // a beam of surviving hyps
+ for(auto hyp : beam)
+ if(hyp->getWord() != trgEosId) // if this hyp is not finished,
+ newBeam.push_back(hyp); // move over to beam of surviving hyps
+
+ if(PURGE_BATCH)
+ if(newBeam.empty() && !beam.empty()) { // previous beam had hyps, but all were finished in this step, newBeam will now stay empty
+ for(size_t i = beamIdx + 1; i < beams.size(); ++i) // for all entries above this beam
+ batchIdxMap[i] = batchIdxMap[i] - 1; // make them look at one batch index below, as the current entry will be removed from the batch.
}
+
newBeams.push_back(newBeam);
+ beamIdx++; // move to next beam index
}
return newBeams;
}
+ //**********************************************************************
// main decoding function
Histories search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
- int dimBatch = (int)batch->size();
+ auto factoredVocab = trgVocab_->tryAs<FactoredVocab>();
+#if 0 // use '1' here to disable factored decoding, e.g. for comparisons
+ factoredVocab.reset();
+#endif
+ size_t numFactorGroups = factoredVocab ? factoredVocab->getNumGroups() : 1;
+ if (numFactorGroups == 1) // if no factors then we didn't need this object in the first place
+ factoredVocab.reset();
+
+ // We will use the prefix "origBatch..." whenever we refer to batch dimensions of the original batch. These do not change during search.
+ // We will use the prefix "currentBatch.." whenever we refer to batch dimension that can change due to batch-pruning.
+ const int origDimBatch = (int)batch->size();
+ const auto trgEosId = trgVocab_->getEosId();
+ const auto trgUnkId = trgVocab_->getUnkId();
+
+ auto getNBestList = createGetNBestListFn(beamSize_, origDimBatch, graph->getDeviceId());
+
+ for(auto scorer : scorers_) {
+ scorer->clear(graph);
+ }
- Histories histories;
- for(int i = 0; i < dimBatch; ++i) {
+ Histories histories(origDimBatch);
+ for(int i = 0; i < origDimBatch; ++i) {
size_t sentId = batch->getSentenceIds()[i];
- auto history = New<History>(sentId,
+ histories[i] = New<History>(sentId,
options_->get<float>("normalize"),
options_->get<float>("word-penalty"));
- histories.push_back(history);
}
- size_t localBeamSize = beamSize_; // max over beam sizes of active sentence hypotheses
-
- auto getNBestList = createGetNBestListFn(localBeamSize, dimBatch, graph->getDeviceId());
-
- Beams beams(dimBatch); // [batchIndex][beamIndex] is one sentence hypothesis
- for(auto& beam : beams)
- beam.resize(localBeamSize, New<Hypothesis>());
-
- bool first = true;
- bool final = false;
-
- for(int i = 0; i < dimBatch; ++i)
- histories[i]->Add(beams[i], trgEosId_);
-
+ // start states
std::vector<Ptr<ScorerState>> states;
-
for(auto scorer : scorers_) {
- scorer->clear(graph);
+ states.push_back(scorer->startState(graph, batch));
}
- for(auto scorer : scorers_) {
- states.push_back(scorer->startState(graph, batch));
+ // create one beam per batch entry with sentence-start hypothesis
+ Beams beams(origDimBatch, Beam(beamSize_, Hypothesis::New())); // array [origDimBatch] of array [maxBeamSize] of Hypothesis, keeps full size through search.
+ // batch purging is determined from an empty sub-beam.
+ std::vector<IndexType> batchIdxMap(origDimBatch); // Record at which batch entry a beam is looking.
+ // By default that corresponds to position in array,
+ // but shifts in the course of removing batch entries when they are finished.
+
+ const std::vector<bool> emptyBatchEntries; // used for recording if there are empty input batch entries
+ for(int origBatchIdx = 0; origBatchIdx < origDimBatch; ++origBatchIdx) {
+ batchIdxMap[origBatchIdx] = origBatchIdx; // map to same position on initialization
+ auto& beam = beams[origBatchIdx];
+ histories[origBatchIdx]->add(beam, trgEosId); // add beams with start-hypotheses to traceback grid
+
+ // Mark batch entries that consist only of source <EOS> i.e. these are empty lines. They will be forced to EOS and purged from batch
+ const auto& srcEosId = batch->front()->vocab()->getEosId();
+ const_cast<std::vector<bool>&>(emptyBatchEntries).push_back(batch->front()->data()[origBatchIdx] == srcEosId); // const_cast during construction
+ }
+
+ // determine index of UNK in the log prob vectors if we want to suppress it in the decoding process
+ int unkColId = -1;
+ if (trgUnkId != Word::NONE && !options_->get<bool>("allow-unk", false)) { // do we need to suppress unk?
+ unkColId = factoredVocab ? factoredVocab->getUnkIndex() : trgUnkId.toWordIndex(); // what's the raw index of unk in the log prob vector?
+ auto shortlist = scorers_[0]->getShortlist(); // first shortlist is generally ok, @TODO: make sure they are the same across scorers?
+ if (shortlist)
+ unkColId = shortlist->tryForwardMap(unkColId); // use shifted postion of unk in case of using a shortlist, shortlist may have removed unk which results in -1
}
- // main loop over output tokens
- do {
- //**********************************************************************
- // create constant containing previous path scores for current beam
- // also create mapping of hyp indices, which are not 1:1 if sentences complete
- std::vector<IndexType> hypIndices; // [beamIndex * activeBatchSize + batchIndex] backpointers, concatenated over beam positions. Used for reordering hypotheses
- std::vector<IndexType> embIndices;
- Expr prevPathScores; // [beam, 1, 1, 1]
- if(first) {
- // no scores yet
- prevPathScores = graph->constant({1, 1, 1, 1}, inits::from_value(0));
- } else {
- std::vector<float> beamScores;
-
- dimBatch = (int)batch->size();
-
- for(size_t i = 0; i < localBeamSize; ++i) {
- for(size_t j = 0; j < beams.size(); ++j) { // loop over batch entries (active sentences)
- auto& beam = beams[j];
- if(i < beam.size()) {
- auto hyp = beam[i];
- hypIndices.push_back((IndexType)hyp->GetPrevStateIndex()); // backpointer
- embIndices.push_back(hyp->GetWord());
- beamScores.push_back(hyp->GetPathScore());
- } else { // dummy hypothesis
- hypIndices.push_back(0);
- embIndices.push_back(0); // (unused)
- beamScores.push_back(-9999);
+ // the decoding process updates the following state information in each output time step:
+ // - beams: array [origDimBatch] of array [maxBeamSize] of Hypothesis
+ // - current output time step's set of active hypotheses, aka active search space
+ // - states[.]: ScorerState
+ // - NN state; one per scorer, e.g. 2 for ensemble of 2
+ // and it forms the following return value
+ // - histories: array [origDimBatch] of History
+ // with History: vector [t] of array [maxBeamSize] of Hypothesis
+ // with Hypothesis: (last word, aggregate score, prev Hypothesis)
+
+ IndexType currentDimBatch = origDimBatch;
+ auto prevBatchIdxMap = batchIdxMap; // [origBatchIdx -> currentBatchIdx] but shifted by one time step
+ // main loop over output time steps
+ for (size_t t = 0; ; t++) {
+ ABORT_IF(origDimBatch != beams.size(), "Lost a batch entry??");
+ // determine beam size for next output time step, as max over still-active sentences
+ // E.g. if all batch entries are down from beam 5 to no more than 4 surviving hyps, then
+ // switch to beam of 4 for all. If all are done, then beam ends up being 0, and we are done.
+ size_t maxBeamSize = 0; // @TODO: is there some std::algorithm for this?
+ for(auto& beam : beams)
+ if(beam.size() > maxBeamSize)
+ maxBeamSize = beam.size();
+
+ // done if all batch entries have reached EOS on all beam entries
+ if (maxBeamSize == 0)
+ break;
+
+ for (size_t factorGroup = 0; factorGroup < numFactorGroups; factorGroup++) {
+ // for factored vocabs, we do one factor at a time, but without updating the scorer for secondary factors
+
+ //**********************************************************************
+ // create constant containing previous path scores for current beam
+ // Also create mapping of hyp indices, for reordering the decoder-state tensors.
+ std::vector<IndexType> batchIndices; // [1, 1, currentDimBatch, 1] indices of currently used batch indices with regard to current, actual tensors
+ std::vector<IndexType> hypIndices; // [maxBeamSize, 1, currentDimBatch, 1] (flattened) tensor index ((beamHypIdx, batchIdx), flattened) of prev hyp that a hyp originated from
+ std::vector<Word> prevWords; // [maxBeamSize, 1, currentDimBatch, 1] (flattened) word that a hyp ended in, for advancing the decoder-model's history
+ Expr prevPathScores; // [maxBeamSize, 1, currentDimBatch, 1], path score that a hyp ended in (last axis will broadcast into vocab size when adding expandedPathScores)
+
+ bool anyCanExpand = false; // stays false if all hyps are invalid factor expansions
+ if(t == 0 && factorGroup == 0) { // no scores yet
+ prevPathScores = graph->constant({1, 1, 1, 1}, inits::fromValue(0));
+ anyCanExpand = true;
+
+ // at the beginning all batch entries are used
+ batchIndices.resize(origDimBatch);
+ std::iota(batchIndices.begin(), batchIndices.end(), 0);
+ } else {
+ if(factorGroup == 0) // only factorGroup==0 can subselect neural state
+ for(int currentBatchIdx = 0; currentBatchIdx < beams.size(); ++currentBatchIdx) // loop over batch entries (active sentences)
+ if(!beams[currentBatchIdx].empty() || !PURGE_BATCH) // for each beam check
+ batchIndices.push_back(prevBatchIdxMap[currentBatchIdx]); // which batch entries were active in previous step
+
+ std::vector<float> prevScores;
+ for(size_t beamHypIdx = 0; beamHypIdx < maxBeamSize; ++beamHypIdx) { // loop over globally maximal beam-size (maxBeamSize)
+ for(int origBatchIdx = 0; origBatchIdx < origDimBatch; ++origBatchIdx) { // loop over all batch entries (active and inactive)
+ auto& beam = beams[origBatchIdx];
+ if(beamHypIdx < beam.size()) {
+ auto hyp = beam[beamHypIdx];
+ auto word = hyp->getWord();
+ auto canExpand = (!factoredVocab || factoredVocab->canExpandFactoredWord(hyp->getWord(), factorGroup));
+ //LOG(info, "[{}, {}] Can expand {} with {} -> {}", batchIdx, beamHypIdx, (*batch->back()->vocab())[hyp->getWord()], factorGroup, canExpand);
+ anyCanExpand |= canExpand;
+
+ auto currentBatchIdx = origBatchIdx;
+ if(PURGE_BATCH) {
+ if(factorGroup == 0)
+ currentBatchIdx = prevBatchIdxMap[origBatchIdx]; // subselection may happen for factorGroup == 0
+ else
+ currentBatchIdx = batchIdxMap[origBatchIdx]; // no subselection happens for factorGroup > 0,
+ // but we treat it like a next step, since a step
+ // happened for factorGroup == 0
+ }
+
+ auto hypIndex = (IndexType)(hyp->getPrevStateIndex() * currentDimBatch + currentBatchIdx); // (beamHypIdx, batchIdx), flattened, for index_select() operation
+
+ hypIndices.push_back(hypIndex); // (beamHypIdx, batchIdx), flattened as said above.
+ prevWords .push_back(word);
+ prevScores.push_back(canExpand ? hyp->getPathScore() : INVALID_PATH_SCORE);
+ } else { // pad to maxBeamSize (dummy hypothesis)
+ if(!PURGE_BATCH || !beam.empty()) { // but only if we are not pruning and the beam is not deactivated yet
+ hypIndices.push_back(0);
+ prevWords.push_back(trgEosId); // (unused, but must be valid)
+ prevScores.push_back((float)INVALID_PATH_SCORE);
+ }
+ }
}
}
+ if(factorGroup == 0)
+ currentDimBatch = (IndexType) batchIndices.size(); // keep batch size constant for all factor groups in a time step
+ prevPathScores = graph->constant({(int)maxBeamSize, 1, (int)currentDimBatch, 1}, inits::fromVector(prevScores));
+ }
+ if (!anyCanExpand) // all words cannot expand this factor: skip
+ continue;
+
+ //**********************************************************************
+ // compute expanded path scores with word prediction probs from all scorers
+ auto expandedPathScores = prevPathScores; // will become [maxBeamSize, 1, currDimBatch, dimVocab]
+ Expr logProbs;
+ for(size_t i = 0; i < scorers_.size(); ++i) {
+ if (factorGroup == 0) {
+ // compute output probabilities for current output time step
+ // - uses hypIndices[index in beam, 1, batch index, 1] to reorder scorer state to reflect the top-N in beams[][]
+ // - adds prevWords [index in beam, 1, batch index, 1] to the scorer's target history
+ // - performs one step of the scorer
+ // - returns new NN state for use in next output time step
+ // - returns vector of prediction probabilities over output vocab via newState
+ // update state in-place for next output time step
+ //if (t > 0) for (size_t kk = 0; kk < prevWords.size(); kk++)
+ // LOG(info, "prevWords[{},{}]={} -> {}", t/numFactorGroups, factorGroup,
+ // factoredVocab ? factoredVocab->word2string(prevWords[kk]) : (*batch->back()->vocab())[prevWords[kk]],
+ // prevScores[kk]);
+ states[i] = scorers_[i]->step(graph, states[i], hypIndices, prevWords, batchIndices, (int)maxBeamSize);
+ if (numFactorGroups == 1) // @TODO: this branch can go away
+ logProbs = states[i]->getLogProbs().getLogits(); // [maxBeamSize, 1, currentDimBatch, dimVocab]
+ else
+ {
+ auto shortlist = scorers_[i]->getShortlist();
+ logProbs = states[i]->getLogProbs().getFactoredLogits(factorGroup, shortlist); // [maxBeamSize, 1, currentDimBatch, dimVocab]
+ }
+ }
+ else {
+ // add secondary factors
+ // For those, we don't update the decoder-model state in any way.
+ // Instead, we just keep expanding with the factors.
+ // We will have temporary Word entries in hyps with some factors set to FACTOR_NOT_SPECIFIED.
+ // For some lemmas, a factor is not applicable. For those, the factor score is the same (zero)
+ // for all factor values. This would thus unnecessarily pollute the beam with identical copies,
+ // and push out other hypotheses. Hence, we exclude those here by setting the path score to
+ // INVALID_PATH_SCORE. Instead, toHyps() explicitly propagates those hyps by simply copying the
+ // previous hypothesis.
+ logProbs = states[i]->getLogProbs().getFactoredLogits(factorGroup, /*shortlist=*/ nullptr, hypIndices, maxBeamSize); // [maxBeamSize, 1, currentDimBatch, dimVocab]
+ }
+ // expand all hypotheses, [maxBeamSize, 1, currentDimBatch, 1] -> [maxBeamSize, 1, currentDimBatch, dimVocab]
+ expandedPathScores = expandedPathScores + scorers_[i]->getWeight() * logProbs;
}
- prevPathScores = graph->constant({(int)localBeamSize, 1, dimBatch, 1},
- inits::from_vector(beamScores));
- }
-
- //**********************************************************************
- // prepare scores for beam search
- auto pathScores = prevPathScores;
-
- for(size_t i = 0; i < scorers_.size(); ++i) {
- states[i] = scorers_[i]->step(
- graph, states[i], hypIndices, embIndices, dimBatch, (int)localBeamSize);
+ // make beams continuous
+ expandedPathScores = swapAxes(expandedPathScores, 0, 2); // -> [currentDimBatch, 1, maxBeamSize, dimVocab]
- if(scorers_[i]->getWeight() != 1.f)
- pathScores = pathScores + scorers_[i]->getWeight() * states[i]->getLogProbs();
+ // perform NN computation
+ if(t == 0 && factorGroup == 0)
+ graph->forward();
else
- pathScores = pathScores + states[i]->getLogProbs();
- }
-
- // make beams continuous
- if(dimBatch > 1 && localBeamSize > 1)
- pathScores = transpose(pathScores, {2, 1, 0, 3});
-
- if(first)
- graph->forward();
- else
- graph->forwardNext();
-
- //**********************************************************************
- // suppress specific symbols if not at right positions
- if(trgUnkId_ != -1 && options_->has("allow-unk")
- && !options_->get<bool>("allow-unk"))
- suppressWord(pathScores, trgUnkId_);
- for(auto state : states)
- state->blacklist(pathScores, batch);
-
- //**********************************************************************
- // perform beam search and pruning
- std::vector<unsigned int> outKeys;
- std::vector<float> outPathScores;
-
- std::vector<size_t> beamSizes(dimBatch, localBeamSize);
- getNBestList(beamSizes, pathScores->val(), outPathScores, outKeys, first);
-
- int dimTrgVoc = pathScores->shape()[-1];
- beams = toHyps(outKeys,
- outPathScores,
- dimTrgVoc,
- beams,
- states,
- localBeamSize,
- first,
- batch);
-
- auto prunedBeams = pruneBeam(beams);
- for(int i = 0; i < dimBatch; ++i) {
- if(!beams[i].empty()) {
- final = final
- || histories[i]->size()
- >= options_->get<float>("max-length-factor")
- * batch->front()->batchWidth();
- histories[i]->Add(
- beams[i], trgEosId_, prunedBeams[i].empty() || final);
+ graph->forwardNext();
+
+ //**********************************************************************
+ // suppress specific symbols if not at right positions
+ if(unkColId != -1 && factorGroup == 0)
+ suppressWord(expandedPathScores, unkColId);
+ for(auto state : states)
+ state->blacklist(expandedPathScores, batch);
+
+ //**********************************************************************
+ // perform beam search
+
+ // find N best amongst the (maxBeamSize * dimVocab) hypotheses
+ std::vector<unsigned int> nBestKeys; // [currentDimBatch, maxBeamSize] flattened -> (batchIdx, beamHypIdx, word idx) flattened
+ std::vector<float> nBestPathScores; // [currentDimBatch, maxBeamSize] flattened
+ getNBestList(/*in*/ expandedPathScores->val(), // [currentDimBatch, 1, maxBeamSize, dimVocab or dimShortlist]
+ /*N=*/ maxBeamSize, // desired beam size
+ /*out*/ nBestPathScores, /*out*/ nBestKeys,
+ /*first=*/t == 0 && factorGroup == 0); // @TODO: this is only used for checking presently, and should be removed altogether
+ // Now, nBestPathScores contain N-best expandedPathScores for each batch and beam,
+ // and nBestKeys for each their original location (batchIdx, beamHypIdx, word).
+
+ // combine N-best sets with existing search space (beams) to updated search space
+ beams = toHyps(nBestKeys, nBestPathScores,
+ /*nBestBeamSize*/expandedPathScores->shape()[-2], // used for interpretation of keys
+ /*vocabSize=*/expandedPathScores->shape()[-1], // used for interpretation of keys
+ beams,
+ states, // used for keeping track of per-ensemble-member path score
+ batch, // only used for propagating alignment info
+ factoredVocab, factorGroup,
+ emptyBatchEntries, // [origDimBatch] - empty source batch entries are marked with true
+ batchIdxMap); // used to create a reverse batch index map to recover original batch indices for this step
+ } // END FOR factorGroup = 0 .. numFactorGroups-1
+
+ prevBatchIdxMap = batchIdxMap; // save current batchIdx map to be used in next step; we are then going to look one step back
+
+ // remove all hyps that end in EOS
+ // The position of a hyp in the beam may change.
+ // in/out = shifts the batch index map if a beam gets fully purged
+ const auto purgedNewBeams = purgeBeams(beams, /*in/out=*/batchIdxMap);
+
+ // add updated search space (beams) to our return value
+ bool maxLengthReached = false;
+ for(int batchIdx = 0; batchIdx < origDimBatch; ++batchIdx) {
+ // if this batch entry has surviving hyps then add them to the traceback grid
+ if(!beams[batchIdx].empty()) { // if the beam is not empty expand the history object associated with the beam
+ if (histories[batchIdx]->size() >= options_->get<float>("max-length-factor") * batch->front()->batchWidth())
+ maxLengthReached = true;
+ histories[batchIdx]->add(beams[batchIdx], trgEosId, purgedNewBeams[batchIdx].empty() || maxLengthReached);
}
}
- beams = prunedBeams;
-
- // determine beam size for next sentence, as max over still-active sentences
- if(!first) {
- size_t maxBeam = 0;
- for(auto& beam : beams)
- if(beam.size() > maxBeam)
- maxBeam = beam.size();
- localBeamSize = maxBeam;
- }
- first = false;
+ if (maxLengthReached) // early exit if max length limit was reached
+ break;
- } while(localBeamSize != 0 && !final); // end of main loop over output tokens
+ // this is the search space for the next output time step
+ beams = purgedNewBeams;
+ } // end of main loop over output time steps
- return histories;
+ return histories; // [origDimBatch][t][N best hyps]
}
};
} // namespace marian
diff --git a/src/translator/helpers.cpp b/src/translator/helpers.cpp
index ccc45040..f4b75da0 100755..100644
--- a/src/translator/helpers.cpp
+++ b/src/translator/helpers.cpp
@@ -24,18 +24,18 @@ void SetColumn(Tensor in_, size_t col, float value) {
}
}
-void suppressWord(Expr logProbs, Word id) {
- SetColumn(logProbs->val(), id, std::numeric_limits<float>::lowest());
+void suppressWord(Expr logProbs, WordIndex wordIndex) {
+ SetColumn(logProbs->val(), wordIndex, std::numeric_limits<float>::lowest());
}
} // namespace cpu
-void suppressWord(Expr logProbs, Word id) {
+void suppressWord(Expr logProbs, WordIndex wordIndex) {
if(logProbs->val()->getBackend()->getDeviceId().type == DeviceType::cpu) {
- cpu::suppressWord(logProbs, id);
+ cpu::suppressWord(logProbs, wordIndex);
}
#ifdef CUDA_FOUND
else {
- gpu::suppressWord(logProbs, id);
+ gpu::suppressWord(logProbs, wordIndex);
}
#endif
}
diff --git a/src/translator/helpers.cu b/src/translator/helpers.cu
index 8b40d8db..48a079fc 100644
--- a/src/translator/helpers.cu
+++ b/src/translator/helpers.cu
@@ -10,15 +10,18 @@
#include "tensors/tensor.h"
#include "translator/helpers.h"
+#include "tensors/gpu/cuda_helpers.h"
+
namespace marian {
namespace gpu {
-__global__ void gSetColumn(float* d_in,
+template <typename T>
+__global__ void gSetColumn(T* d_in,
size_t n_columns,
size_t n_rows,
size_t noColumn,
- float value) {
+ T value) {
size_t rowNumber = threadIdx.x + blockDim.x * blockIdx.x;
size_t index = noColumn + rowNumber * n_columns;
@@ -27,18 +30,26 @@ __global__ void gSetColumn(float* d_in,
}
}
-void SetColumn(Tensor in_, size_t col, float value) {
- int nRows = in_->shape().elements() / in_->shape()[-1];
- int nColumns = in_->shape()[-1];
+void SetColumn(Tensor in, size_t col, float value) {
+ int nRows = in->shape().elements() / in->shape()[-1];
+ int nColumns = in->shape()[-1];
int nBlocks = nRows / 512 + ((nRows % 512 == 0) ? 0 : 1);
int nThreads = std::min(512, nRows);
- gSetColumn<<<nBlocks, nThreads>>>(in_->data(), nColumns, nRows, col, value);
+ if(in->type() == Type::float32) {
+ gSetColumn<<<nBlocks, nThreads>>>(in->data<float>(), nColumns, nRows, col, value);
+#if COMPILE_FP16
+ } else if(in->type() == Type::float16) {
+ gSetColumn<<<nBlocks, nThreads>>>(in->data<half>(), nColumns, nRows, col, (half)value);
+#endif
+ } else {
+ ABORT("suppressWord not implemented for type {}", in->type());
+ }
}
-void suppressWord(Expr probs, Word id) {
- SetColumn(probs->val(), id, std::numeric_limits<float>::lowest());
+void suppressWord(Expr probs, WordIndex wordIndex) {
+ SetColumn(probs->val(), wordIndex, NumericLimits<float>(probs->value_type()).lowest);
}
} // namespace gpu
} // namespace marian
diff --git a/src/translator/helpers.h b/src/translator/helpers.h
index d4ff3a94..71b1eb20 100755..100644
--- a/src/translator/helpers.h
+++ b/src/translator/helpers.h
@@ -11,13 +11,13 @@ namespace marian {
namespace cpu {
-void suppressWord(Expr logProbs, Word id);
+void suppressWord(Expr logProbs, WordIndex wordIndex);
}
namespace gpu {
-void suppressWord(Expr logProbs, Word id);
+void suppressWord(Expr logProbs, WordIndex wordIndex);
}
-void suppressWord(Expr logProbs, Word id);
+void suppressWord(Expr logProbs, WordIndex wordIndex);
} // namespace marian
diff --git a/src/translator/history.h b/src/translator/history.h
index e70b80c6..463c75a1 100755..100644
--- a/src/translator/history.h
+++ b/src/translator/history.h
@@ -14,26 +14,22 @@ private:
struct SentenceHypothesisCoord {
bool operator<(const SentenceHypothesisCoord& hc) const { return normalizedPathScore < hc.normalizedPathScore; }
- size_t i; // last time step of this sentence hypothesis
- size_t j; // which beam entry
+ size_t timeStepIdx; // last time step of this sentence hypothesis
+ size_t beamIdx; // which beam entry
float normalizedPathScore; // length-normalized sentence score
};
+ float lengthPenalty(size_t length) { return std::pow((float)length, alpha_); }
+ float wordPenalty(size_t length) { return wp_ * (float)length; }
public:
History(size_t lineNo, float alpha = 1.f, float wp_ = 0.f);
- float LengthPenalty(size_t length) { return std::pow((float)length, alpha_); }
- float WordPenalty(size_t length) { return wp_ * (float)length; }
-
- void Add(const Beam& beam, Word trgEosId, bool last = false) {
- if(beam.back()->GetPrevHyp() != nullptr) {
- for(size_t j = 0; j < beam.size(); ++j)
- if(beam[j]->GetWord() == trgEosId || last) {
- float pathScore = (beam[j]->GetPathScore() - WordPenalty(history_.size()))
- / LengthPenalty(history_.size());
- topHyps_.push({history_.size(), j, pathScore});
- // std::cerr << "Add " << history_.size() << " " << j << " " << pathScore
- // << std::endl;
+ void add(const Beam& beam, Word trgEosId, bool last = false) {
+ if(beam.back()->getPrevHyp() != nullptr) { // if not start hyp do
+ for(size_t beamIdx = 0; beamIdx < beam.size(); ++beamIdx)
+ if(beam[beamIdx]->getWord() == trgEosId || last) { // if this is a final hyp do
+ float pathScore = (beam[beamIdx]->getPathScore() - wordPenalty(history_.size())) / lengthPenalty(history_.size()); // get and normalize path score
+ topHyps_.push({history_.size(), beamIdx, pathScore}); // push final hyp on queue of scored hyps
}
}
history_.push_back(beam);
@@ -41,37 +37,39 @@ public:
size_t size() const { return history_.size(); } // number of time steps
- NBestList NBest(size_t n) const {
+ NBestList nBest(size_t n) const {
NBestList nbest;
for (auto topHypsCopy = topHyps_; nbest.size() < n && !topHypsCopy.empty(); topHypsCopy.pop()) {
auto bestHypCoord = topHypsCopy.top();
- const size_t start = bestHypCoord.i; // last time step of this hypothesis
- const size_t j = bestHypCoord.j; // which beam entry
- Ptr<Hypothesis> bestHyp = history_[start][j];
- // float c = bestHypCoord.normalizedPathScore;
- // std::cerr << "h: " << start << " " << j << " " << c << std::endl;
+ const size_t timeStepIdx = bestHypCoord.timeStepIdx; // last time step of this hypothesis
+ const size_t beamIdx = bestHypCoord.beamIdx; // which beam entry
+ Hypothesis::PtrType bestHyp = history_[timeStepIdx][beamIdx];
// trace back best path
- Words targetWords = bestHyp->TracebackWords();
+ Words targetWords = bestHyp->tracebackWords();
- // note: bestHyp->GetPathScore() is not normalized, while bestHypCoord.normalizedPathScore is
+ // note: bestHyp->getPathScore() is not normalized, while bestHypCoord.normalizedPathScore is
nbest.emplace_back(targetWords, bestHyp, bestHypCoord.normalizedPathScore);
}
return nbest;
}
- Result Top() const { return NBest(1)[0]; }
+ Result top() const {
+ const NBestList& nbest = nBest(1);
+ ABORT_IF(nbest.empty(), "No hypotheses in n-best list??");
+ return nbest[0];
+ }
- size_t GetLineNum() const { return lineNo_; }
+ size_t getLineNum() const { return lineNo_; }
private:
- std::vector<Beam> history_; // [time step][index into beam] search grid
+ std::vector<Beam> history_; // [time step][index into beam] search grid @TODO: simplify as this is currently an expensive length count
std::priority_queue<SentenceHypothesisCoord> topHyps_; // all sentence hypotheses (those that reached eos), sorted by score
size_t lineNo_;
float alpha_;
float wp_;
};
-typedef std::vector<Ptr<History>> Histories;
+typedef std::vector<Ptr<History>> Histories; // [batchDim]
} // namespace marian
diff --git a/src/translator/hypothesis.h b/src/translator/hypothesis.h
index c13ac716..9bd21dcd 100755..100644
--- a/src/translator/hypothesis.h
+++ b/src/translator/hypothesis.h
@@ -6,65 +6,96 @@
namespace marian {
+// one single (partial or full) hypothesis in beam search
+// key elements:
+// - the word that this hyp ends with
+// - the aggregate score up to and including the word
+// - back pointer to previous hypothesis for traceback
class Hypothesis {
public:
- Hypothesis() : prevHyp_(nullptr), prevIndex_(0), word_(0), pathScore_(0.0) {}
+ typedef IPtr<Hypothesis> PtrType;
- Hypothesis(const Ptr<Hypothesis> prevHyp,
+private:
+ // Constructors are private, use Hypothesis::New(...)
+
+ Hypothesis() : prevHyp_(nullptr), prevBeamHypIdx_(0), word_(Word::ZERO), pathScore_(0.0) {}
+
+ Hypothesis(const PtrType prevHyp,
Word word,
- IndexType prevIndex,
+ size_t prevBeamHypIdx, // beam-hyp index that this hypothesis originated from
float pathScore)
- : prevHyp_(prevHyp), prevIndex_(prevIndex), word_(word), pathScore_(pathScore) {}
+ : prevHyp_(prevHyp), prevBeamHypIdx_(prevBeamHypIdx), word_(word), pathScore_(pathScore) {}
+
+public:
+ // Use this whenever creating a pointer to MemoryPiece
+ template <class ...Args>
+ static PtrType New(Args&& ...args) {
+ return PtrType(new Hypothesis(std::forward<Args>(args)...));
+ }
+
+ const PtrType getPrevHyp() const { return prevHyp_; }
- const Ptr<Hypothesis> GetPrevHyp() const { return prevHyp_; }
+ Word getWord() const { return word_; }
- Word GetWord() const { return word_; }
+ size_t getPrevStateIndex() const { return prevBeamHypIdx_; }
- IndexType GetPrevStateIndex() const { return prevIndex_; }
+ float getPathScore() const { return pathScore_; }
- float GetPathScore() const { return pathScore_; }
+ const std::vector<float>& getScoreBreakdown() { return scoreBreakdown_; }
+ void setScoreBreakdown(const std::vector<float>& scoreBreakdown) { scoreBreakdown_ = scoreBreakdown; }
- std::vector<float>& GetScoreBreakdown() { return scoreBreakdown_; }
- std::vector<float>& GetAlignment() { return alignment_; }
+ const std::vector<float>& getAlignment() { return alignment_; }
+ void setAlignment(const std::vector<float>& align) { alignment_ = align; };
- void SetAlignment(const std::vector<float>& align) { alignment_ = align; };
+ // trace back paths referenced from this hypothesis
+ Words tracebackWords() {
+ Words targetWords;
+ for(auto hyp = this; hyp->getPrevHyp(); hyp = hyp->getPrevHyp().get()) {
+ targetWords.push_back(hyp->getWord());
+ }
+ std::reverse(targetWords.begin(), targetWords.end());
+ return targetWords;
+ }
- // helpers to trace back paths referenced from this hypothesis
- Words TracebackWords()
- {
- Words targetWords;
- for (auto hyp = this; hyp->GetPrevHyp(); hyp = hyp->GetPrevHyp().get()) {
- targetWords.push_back(hyp->GetWord());
- // std::cerr << hyp->GetWord() << " " << hyp << std::endl;
- }
- std::reverse(targetWords.begin(), targetWords.end());
- return targetWords;
+ // calculate word-level scores for each target word by de-aggregating the path score
+ std::vector<float> tracebackWordScores() {
+ std::vector<float> scores;
+ // traverse hypotheses backward
+ for(auto hyp = this; hyp->getPrevHyp(); hyp = hyp->getPrevHyp().get()) {
+ // a path score is a cumulative score including scores from all preceding hypotheses (words),
+ // so calculate a word-level score by subtracting the previous path score from the current path score
+ auto prevPathScore = hyp->getPrevHyp() ? hyp->getPrevHyp().get()->pathScore_ : 0.f;
+ scores.push_back(hyp->pathScore_ - prevPathScore);
+ }
+ std::reverse(scores.begin(), scores.end());
+ return scores;
}
- // get soft alignments for each target word starting from the hyp one
+ // get soft alignments [t][s] -> P(s|t) for each target word starting from the hyp one
typedef data::SoftAlignment SoftAlignment;
- SoftAlignment TracebackAlignment()
- {
- SoftAlignment align;
- for (auto hyp = this; hyp->GetPrevHyp(); hyp = hyp->GetPrevHyp().get()) {
- align.push_back(hyp->GetAlignment());
- }
- std::reverse(align.begin(), align.end());
- return align;
+ SoftAlignment tracebackAlignment() {
+ SoftAlignment align;
+ for(auto hyp = this; hyp->getPrevHyp(); hyp = hyp->getPrevHyp().get()) {
+ align.push_back(hyp->getAlignment());
+ }
+ std::reverse(align.begin(), align.end());
+ return align; // [t][s] -> P(s|t)
}
private:
- const Ptr<Hypothesis> prevHyp_;
- const IndexType prevIndex_;
+ const PtrType prevHyp_;
+ const size_t prevBeamHypIdx_;
const Word word_;
const float pathScore_;
- std::vector<float> scoreBreakdown_;
+ std::vector<float> scoreBreakdown_; // [num scorers]
std::vector<float> alignment_;
+
+ ENABLE_INTRUSIVE_PTR(Hypothesis)
};
-typedef std::vector<Ptr<Hypothesis>> Beam; // Beam = vector of hypotheses
-typedef std::vector<Beam> Beams; // Beams = vector of vector of hypotheses
-typedef std::tuple<Words, Ptr<Hypothesis>, float> Result; // (word ids for hyp, hyp, normalized sentence score for hyp)
+typedef std::vector<IPtr<Hypothesis>> Beam; // Beam = vector [beamSize] of hypotheses
+typedef std::vector<Beam> Beams; // Beams = vector [batchDim] of vector [beamSize] of hypotheses
+typedef std::tuple<Words, IPtr<Hypothesis>, float> Result; // (word ids for hyp, hyp, normalized sentence score for hyp)
typedef std::vector<Result> NBestList; // sorted vector of (word ids, hyp, sent score) tuples
} // namespace marian
diff --git a/src/translator/nth_element.cpp b/src/translator/nth_element.cpp
index 54983323..8b2f8947 100755..100644
--- a/src/translator/nth_element.cpp
+++ b/src/translator/nth_element.cpp
@@ -14,87 +14,73 @@ namespace marian {
class NthElementCPU {
std::vector<int> h_res_idx;
std::vector<float> h_res;
- size_t lastN;
+ //size_t lastN_;
public:
- NthElementCPU() = delete;
+ NthElementCPU() {}
NthElementCPU(const NthElementCPU& copy) = delete;
- NthElementCPU(size_t maxBeamSize, size_t maxBatchSize) {
- size_t maxSize = maxBeamSize * maxBatchSize;
+
+public:
+ void getNBestList(Tensor scores, // [dimBatch, 1, beamSize, dimVocab or dimShortlist]
+ size_t N,
+ std::vector<float>& outPathScores,
+ std::vector<unsigned>& outKeys,
+ const bool isFirst) {
+ const auto vocabSize = scores->shape()[-1];
+ const auto inputN = scores->shape()[-2];
+ const auto dimBatch = scores->shape()[-4];
+ ABORT_IF(inputN != (isFirst ? 1 : N), "Input tensor has wrong beam dim??"); // @TODO: Remove isFirst argument altogether
+ const float* scoresData = scores->data();
+
+ size_t maxSize = N * dimBatch;
h_res.resize(maxSize);
h_res_idx.resize(maxSize);
- }
+ size_t pos = 0; // iterates through h_res and h_res_idx
-private:
- void getNBestList(float* scores,
- const std::vector<int>& batchFirstElementIdxs,
- const std::vector<int>& cumulativeBeamSizes) {
- /* For each batch, select the max N elements, where N is the beam size for
- * this batch. Locally record these elements (their current value and index
- * in 'scores') before updating each element to a large negative value, such
- * that they won't be a maximum if we're called again on the same input.
- */
-
- int numProbs = batchFirstElementIdxs.back();
- std::vector<int> idxs(numProbs);
+ size_t batchOffset = inputN * vocabSize;
+ std::vector<int> idxs(batchOffset); // re-used for each batch
std::iota(idxs.begin(), idxs.end(), 0);
- size_t numBatches = batchFirstElementIdxs.size() - 1;
- for(size_t batchIdx = 0; batchIdx < numBatches; ++batchIdx) {
- int pos = cumulativeBeamSizes[batchIdx];
- int beamSize = cumulativeBeamSizes[batchIdx + 1] - pos;
-
- std::vector<int>::iterator begin = idxs.begin() + batchFirstElementIdxs[batchIdx];
- std::vector<int>::iterator middle = begin + beamSize;
- std::vector<int>::iterator end = idxs.begin() + batchFirstElementIdxs[batchIdx + 1];
- std::partial_sort(
- begin, middle, end, [=](int a, int b) { return scores[a] > scores[b]; });
-
- while(begin != middle) {
- int idx = *begin++;
- h_res_idx[pos] = idx;
- h_res[pos] = scores[idx];
- scores[idx] = std::numeric_limits<float>::lowest();
+ for(size_t batchIdx = 0; batchIdx < dimBatch; ++batchIdx) {
+
+ std::partial_sort(
+ // sorts the top N (beam size) idxs by score to the front
+ idxs.begin(),
+ idxs.begin() + N,
+ idxs.end(),
+ [&](int a, int b) { return scoresData[a] > scoresData[b]; }
+ );
+
+ // copy top N idxs and scores to return vectors
+ for(size_t i = 0; i < N; ++i) {
+ int idx = idxs[i];
+ // since idxs is re-used for each batch, add batch offset to each idx to get absolute position
+ h_res_idx[pos] = idx + batchIdx * batchOffset;
+ h_res[pos] = scoresData[idx];
++pos;
}
- }
- }
-public:
- void getNBestList(const std::vector<size_t>& beamSizes,
- Tensor scores,
- std::vector<float>& outPathScores,
- std::vector<unsigned>& outKeys,
- const bool isFirst) {
- std::vector<int> cumulativeBeamSizes(beamSizes.size() + 1, 0);
- std::vector<int> batchFirstElementIdxs(beamSizes.size() + 1, 0);
-
- auto vocabSize = scores->shape()[-1];
- for(int i = 0; i < beamSizes.size(); ++i) {
- cumulativeBeamSizes[i + 1] = cumulativeBeamSizes[i] + (int)beamSizes[i];
- batchFirstElementIdxs[i + 1]
- += (isFirst ? i + 1 : cumulativeBeamSizes[i + 1]) * vocabSize;
+ // advance pointer to next batch's beginning
+ scoresData += batchOffset;
}
-
- getNBestList(scores->data(), batchFirstElementIdxs, cumulativeBeamSizes);
- getPairs(cumulativeBeamSizes.back(), outKeys, outPathScores);
+ getPairs(/*cumulativeBeamSizes.back(),*/ outKeys, outPathScores);
}
private:
- void getPairs(size_t number,
+ void getPairs(/*size_t number,*/
std::vector<unsigned>& outKeys,
std::vector<float>& outValues) {
- std::copy(h_res_idx.begin(), h_res_idx.begin() + number, std::back_inserter(outKeys));
- std::copy(h_res .begin(), h_res .begin() + number, std::back_inserter(outValues));
- lastN = number;
+ std::copy(h_res_idx.begin(), h_res_idx.end(), std::back_inserter(outKeys));
+ std::copy(h_res .begin(), h_res .end(), std::back_inserter(outValues));
+ //lastN_ = number;
}
- void getValueByKey(std::vector<float>& out, float* d_in) {
- for(size_t i = 0; i < lastN; ++i) {
- out[i] = d_in[h_res_idx[i]];
- }
- }
+ //void getValueByKey(std::vector<float>& out, float* d_in) {
+ // for(size_t i = 0; i < lastN_; ++i) {
+ // out[i] = d_in[h_res_idx[i]];
+ // }
+ //}
};
#ifdef CUDA_FOUND
@@ -108,15 +94,11 @@ GetNBestListFn createGetNBestListFn(size_t beamSize, size_t dimBatch, DeviceId d
if(deviceId.type == DeviceType::gpu)
return createGetNBestListGPUFn(beamSize, dimBatch, deviceId);
#else
- deviceId; // (unused)
+ deviceId; beamSize; dimBatch; // (unused)
#endif
- auto nth = New<NthElementCPU>(beamSize, dimBatch);
- return [nth](const std::vector<size_t>& beamSizes,
- Tensor logProbs,
- std::vector<float>& outCosts,
- std::vector<unsigned>& outKeys,
- const bool isFirst) {
- return nth->getNBestList(beamSizes, logProbs, outCosts, outKeys, isFirst);
+ auto nth = New<NthElementCPU>();
+ return [nth](Tensor logProbs, size_t N, std::vector<float>& outCosts, std::vector<unsigned>& outKeys, const bool isFirst) {
+ return nth->getNBestList(logProbs, N, outCosts, outKeys, isFirst);
};
}
diff --git a/src/translator/nth_element.cu b/src/translator/nth_element.cu
index d5377f90..e8786ee7 100755..100644
--- a/src/translator/nth_element.cu
+++ b/src/translator/nth_element.cu
@@ -20,11 +20,14 @@ namespace marian {
} \
}
+template <typename T>
__global__ void gMaxElement(float* d_out,
int* d_ind,
- float* d_in,
+ T* d_in, // this is the probs array, only one with type float or half
int numBatches,
- int* batchFirstElementIdxs) {
+ int* batchFirstElementIdxs,
+ float disabledPathScore) // disabledPathScore is used to blank out found values, type-dependent
+{
extern __shared__ float sdata[];
__shared__ int indices[512];
@@ -36,16 +39,16 @@ __global__ void gMaxElement(float* d_out,
int i = begin + blockIdx.x * (blockDim.x * 2) + tid;
- sdata[tid] = -3.40282e+38f;
+ sdata[tid] = disabledPathScore;
if(i < end) {
- sdata[tid] = d_in[i];
+ sdata[tid] = (float)d_in[i];
indices[tid] = i;
}
if(i + blockDim.x < end) {
- float a = d_in[i];
- float b = d_in[i + blockDim.x];
+ float a = (float)d_in[i];
+ float b = (float)d_in[i + blockDim.x];
if(a > b) {
sdata[tid] = a;
indices[tid] = i;
@@ -58,14 +61,14 @@ __global__ void gMaxElement(float* d_out,
while(i + 2 * gridDim.x * blockDim.x < end) {
i += 2 * gridDim.x * blockDim.x;
- float a = d_in[i];
+ float a = (float)d_in[i];
if(a > sdata[tid]) {
sdata[tid] = a;
indices[tid] = i;
}
if(i + blockDim.x < end) {
- float b = d_in[i + blockDim.x];
+ float b = (float)d_in[i + blockDim.x];
if(b > sdata[tid]) {
sdata[tid] = b;
indices[tid] = i + blockDim.x;
@@ -100,14 +103,16 @@ __global__ void gMaxElement(float* d_out,
}
}
+template <typename T>
__global__ void gMaxElementUpdate(float* binCosts,
int* binIdxs,
- float* probs,
+ T* probs, // should work well enough with half, uses float everywhere else
int* batchFirstElements,
float* outCosts,
int* outIdxs,
- int* cummulatedBeamSizes,
- int NUM_BLOCKS) {
+ int* cumulativeBeamSizes,
+ int NUM_BLOCKS,
+ float disabledPathScore) {
extern __shared__ float sdata[];
__shared__ int indices[512];
__shared__ float bestBinCost;
@@ -121,12 +126,12 @@ __global__ void gMaxElementUpdate(float* binCosts,
num_bins = 500;
}
- for(int pos = cummulatedBeamSizes[batchIdx];
- pos < cummulatedBeamSizes[batchIdx + 1];
+ for(int pos = cumulativeBeamSizes[batchIdx];
+ pos < cumulativeBeamSizes[batchIdx + 1];
++pos) {
int i = tid;
- sdata[tid] = -3.40282e+38f;
+ sdata[tid] = disabledPathScore;
if(i < num_bins) {
sdata[tid] = binCosts[batchIdx * NUM_BLOCKS + i];
@@ -186,7 +191,7 @@ __global__ void gMaxElementUpdate(float* binCosts,
bestBinCost = sdata[0];
bestBinCostIdx = batchIdx * NUM_BLOCKS + indices[0];
- probs[binIdxs[bestBinCostIdx]] = -3.40282e+38f;
+ probs[binIdxs[bestBinCostIdx]] = disabledPathScore;
outIdxs[pos] = binIdxs[bestBinCostIdx];
outCosts[pos] = bestBinCost;
@@ -198,16 +203,16 @@ __global__ void gMaxElementUpdate(float* binCosts,
+ (bestBinCostIdx - batchIdx * NUM_BLOCKS) * (blockDim.x * 2) + tid;
const int dist = num_bins * 2 * blockDim.x;
- sdata[tid] = -3.40282e+38f;
+ sdata[tid] = disabledPathScore;
if(i < batchFirstElements[batchIdx + 1]) {
- sdata[tid] = probs[i];
+ sdata[tid] = (float)probs[i];
indices[tid] = i;
}
if(i + blockDim.x < batchFirstElements[batchIdx + 1]) {
- float a = probs[i];
- float b = probs[i + blockDim.x];
+ float a = (float)probs[i];
+ float b = (float)probs[i + blockDim.x];
if(a > b) {
sdata[tid] = a;
indices[tid] = i;
@@ -220,14 +225,14 @@ __global__ void gMaxElementUpdate(float* binCosts,
while(i + dist < batchFirstElements[batchIdx + 1]) {
i += dist;
- float a = probs[i];
+ float a = (float)probs[i];
if(a > sdata[tid]) {
sdata[tid] = a;
indices[tid] = i;
}
if(i + blockDim.x < batchFirstElements[batchIdx + 1]) {
- float b = probs[i + blockDim.x];
+ float b = (float)probs[i + blockDim.x];
if(b > sdata[tid]) {
sdata[tid] = b;
indices[tid] = i + blockDim.x;
@@ -279,6 +284,7 @@ public:
size_t maxBatchSize,
DeviceId deviceId)
: deviceId_(deviceId),
+ maxBeamSize_(maxBeamSize), maxBatchSize_(maxBatchSize),
NUM_BLOCKS(std::min(
500,
int(maxBeamSize* MAX_VOCAB_SIZE / (2 * BLOCK_SIZE))
@@ -302,23 +308,26 @@ public:
}
~NthElementGPU() {
+ // No CUDA error checking as this is a destructor and we cannot do anything about errors anyway.
cudaSetDevice(deviceId_.no);
-
- CUDA_CHECK(cudaFree(d_cumBeamSizes));
- CUDA_CHECK(cudaFree(d_batchPosition));
- CUDA_CHECK(cudaFree(d_breakdown));
- CUDA_CHECK(cudaFreeHost(h_res_idx));
- CUDA_CHECK(cudaFreeHost(h_res));
- CUDA_CHECK(cudaFree(d_res));
- CUDA_CHECK(cudaFree(d_res_idx));
- CUDA_CHECK(cudaFree(d_out));
- CUDA_CHECK(cudaFree(d_ind));
+ cudaFree(d_cumBeamSizes);
+ cudaFree(d_batchPosition);
+ cudaFree(d_breakdown);
+ cudaFreeHost(h_res_idx);
+ cudaFreeHost(h_res);
+ cudaFree(d_res);
+ cudaFree(d_res_idx);
+ cudaFree(d_out);
+ cudaFree(d_ind);
}
private:
- void getNBestList(float* probs,
- const std::vector<int>& batchFirstElementIdxs,
- const std::vector<int>& cummulatedBeamSizes) {
+ template <typename T>
+ void selectNBest(T* probs,
+ const std::vector<int>& batchFirstElementIdxs,
+ const std::vector<int>& cumulativeBeamSizes,
+ float disabledPathScore) {
+
cudaSetDevice(deviceId_.no);
CUDA_CHECK(cudaMemcpyAsync(d_batchPosition,
batchFirstElementIdxs.data(),
@@ -326,8 +335,8 @@ private:
cudaMemcpyHostToDevice,
/* stream_ */ 0));
CUDA_CHECK(cudaMemcpyAsync(d_cumBeamSizes,
- cummulatedBeamSizes.data(),
- cummulatedBeamSizes.size() * sizeof(int),
+ cumulativeBeamSizes.data(),
+ cumulativeBeamSizes.size() * sizeof(int),
cudaMemcpyHostToDevice,
/* stream_ */ 0));
@@ -335,13 +344,13 @@ private:
gMaxElement<<<NUM_BLOCKS,
BLOCK_SIZE,
- BLOCK_SIZE * sizeof(float),
+ BLOCK_SIZE * sizeof(float), // shared memory size
/* stream_ */ 0>>>(
- d_out, d_ind, probs, numBatches, d_batchPosition);
+ d_out, d_ind, probs, numBatches, d_batchPosition, disabledPathScore);
gMaxElementUpdate<<<numBatches,
BLOCK_SIZE,
- BLOCK_SIZE * sizeof(float),
+ BLOCK_SIZE * sizeof(float), // shared memory size
/* stream_ */ 0>>>(d_out,
d_ind,
probs,
@@ -349,30 +358,58 @@ private:
d_res,
d_res_idx,
d_cumBeamSizes,
- NUM_BLOCKS);
+ NUM_BLOCKS,
+ disabledPathScore);
}
public:
- void getNBestList(const std::vector<size_t>& beamSizes,
- Tensor Probs,
+ void getNBestList(Tensor scores,
+ size_t N,
std::vector<float>& outCosts,
std::vector<unsigned>& outKeys,
const bool isFirst) {
cudaSetDevice(deviceId_.no);
- std::vector<int> cummulatedBeamSizes(beamSizes.size() + 1, 0);
- std::vector<int> batchFirstElementIdxs(beamSizes.size() + 1, 0);
+ const auto vocabSize = scores->shape()[-1];
+ const auto inputN = scores->shape()[-2];
+ const auto dimBatch = scores->shape()[-4];
+ ABORT_IF(inputN != (isFirst ? 1 : N), "Input tensor has wrong beam dim??"); // @TODO: Remove isFirst argument altogether
+ ABORT_IF(vocabSize > MAX_VOCAB_SIZE, "GetNBestList(): actual vocab size {} exceeds MAX_VOCAB_SIZE of {}", vocabSize, MAX_VOCAB_SIZE);
+ ABORT_IF(dimBatch > maxBatchSize_, "GetNBestList(): actual batch size {} exceeds initialization parameter {}", dimBatch, maxBatchSize_);
+ ABORT_IF(std::max(N, (size_t)inputN) > maxBeamSize_, "GetNBestList(): actual beam size {} exceeds initialization parameter {}", N, maxBeamSize_);
- const size_t vocabSize = Probs->shape()[-1];
+ const std::vector<size_t> beamSizes(dimBatch, N);
+ std::vector<int> cumulativeBeamSizes(beamSizes.size() + 1, 0);
+ std::vector<int> batchFirstElementIdxs(beamSizes.size() + 1, 0);
- for(size_t i = 0; i < beamSizes.size(); ++i) {
- cummulatedBeamSizes[i + 1] = cummulatedBeamSizes[i] + beamSizes[i];
- batchFirstElementIdxs[i + 1]
- += ((isFirst) ? (i + 1) : cummulatedBeamSizes[i + 1]) * vocabSize;
+ for(size_t batchIdx = 0; batchIdx < beamSizes.size(); ++batchIdx) {
+#if 1
+ cumulativeBeamSizes[batchIdx + 1] = (batchIdx + 1) * (int)N;
+ batchFirstElementIdxs[batchIdx + 1] += (batchIdx + 1) * inputN * vocabSize;
+ ABORT_IF(cumulativeBeamSizes[batchIdx + 1] != cumulativeBeamSizes[batchIdx] + (int)N, "cumulativeBeamSizes wrong??");
+ ABORT_IF((isFirst ? batchIdx + 1 : cumulativeBeamSizes[batchIdx + 1]) != (batchIdx + 1) * inputN, "inputN wrong??");
+#else
+ cumulativeBeamSizes[batchIdx + 1] = cumulativeBeamSizes[batchIdx] + beamSizes[batchIdx];
+ ABORT_IF(cumulativeBeamSizes[batchIdx + 1] != (batchIdx + 1) * N, "cumulativeBeamSizes wrong??");
+ batchFirstElementIdxs[batchIdx + 1]
+ += ((isFirst) ? (batchIdx + 1) : cumulativeBeamSizes[batchIdx + 1]) * vocabSize;
+ ABORT_IF((isFirst ? batchIdx + 1 : cumulativeBeamSizes[batchIdx + 1]) != (batchIdx + 1) * inputN, "inputN wrong??");
+#endif
}
- getNBestList(Probs->data(), batchFirstElementIdxs, cummulatedBeamSizes);
- getPairs(cummulatedBeamSizes.back(), outKeys, outCosts);
+ if(scores->type() == Type::float32) {
+ float disabledPathScore = NumericLimits<float>(scores->type()).lowest;
+ selectNBest(scores->data<float>(), batchFirstElementIdxs, cumulativeBeamSizes, disabledPathScore);
+#if COMPILE_FP16
+ } else if(scores->type() == Type::float16) {
+ float disabledPathScore = NumericLimits<float>(scores->type()).lowest;
+ selectNBest(scores->data<half>(), batchFirstElementIdxs, cumulativeBeamSizes, disabledPathScore);
+#endif
+ } else {
+ ABORT("getNBestList not implemented for type {}", scores->type());
+ }
+ getPairs(dimBatch * N, outKeys, outCosts);
+ ABORT_IF(cumulativeBeamSizes.back() != dimBatch * N, "cumulativeBeamSizes.back() wrong??");
}
private:
@@ -397,55 +434,39 @@ private:
outValues.push_back(h_res[i]);
}
- lastN = number;
- }
-
- void getValueByKey(std::vector<float>& out, float* d_in) {
- cudaSetDevice(deviceId_.no);
-
- gGetValueByKey<<<1, lastN, 0, /* stream_ */ 0>>>(
- d_in, d_breakdown, h_res_idx, lastN);
-
- CUDA_CHECK(cudaMemcpyAsync(out.data(),
- d_breakdown,
- lastN * sizeof(float),
- cudaMemcpyDeviceToHost,
- /* stream_ */ 0));
- CUDA_CHECK(cudaStreamSynchronize(/* stream_ */ 0));
+ //lastN = number;
}
DeviceId deviceId_;
- const int MAX_VOCAB_SIZE = 100000;
+ const int MAX_VOCAB_SIZE = 500000;
+ size_t maxBeamSize_;
+ size_t maxBatchSize_;
const int BLOCK_SIZE = 512;
const int NUM_BLOCKS;
- int* d_ind;
- float* d_out;
+ int* d_ind; // [maxBatchSize * NUM_BLOCKS]
+ float* d_out; // [maxBatchSize * NUM_BLOCKS]
- int* d_res_idx;
- float* d_res;
+ int* d_res_idx; // [maxBatchSize * maxBeamSize]
+ float* d_res; // [maxBatchSize * maxBeamSize]
- int* h_res_idx;
- float* h_res;
+ int* h_res_idx; // [maxBeamSize * maxBatchSize]
+ float* h_res; // [maxBeamSize * maxBatchSize]
- float* d_breakdown;
- int* d_batchPosition;
- int* d_cumBeamSizes;
- size_t lastN;
+ float* d_breakdown; // [maxBeamSize]
+ int* d_batchPosition; // [maxBatchSize + 1]
+ int* d_cumBeamSizes; // [maxBatchSize + 1]
+ //size_t lastN;
};
// factory function
// Returns a lambda with the same signature as the getNBestList() function.
GetNBestListFn createGetNBestListGPUFn(size_t beamSize, size_t dimBatch, DeviceId deviceId) {
auto nth = New<NthElementGPU>(beamSize, dimBatch, deviceId);
- return [nth](const std::vector<size_t>& beamSizes,
- Tensor logProbs,
- std::vector<float>& outCosts,
- std::vector<unsigned>& outKeys,
- const bool isFirst) {
- return nth->getNBestList(beamSizes, logProbs, outCosts, outKeys, isFirst);
+ return [nth](Tensor logProbs, size_t N, std::vector<float>& outCosts, std::vector<unsigned>& outKeys, const bool isFirst) {
+ return nth->getNBestList(logProbs, N, outCosts, outKeys, isFirst);
};
}
diff --git a/src/translator/nth_element.h b/src/translator/nth_element.h
index 91ea3792..ca325ed0 100755..100644
--- a/src/translator/nth_element.h
+++ b/src/translator/nth_element.h
@@ -10,8 +10,8 @@
namespace marian {
-typedef std::function<void(const std::vector<size_t>& beamSizes,
- Tensor logProbs,
+typedef std::function<void(Tensor logProbs,
+ size_t N,
std::vector<float>& outCosts,
std::vector<unsigned>& outKeys,
const bool isFirst)> GetNBestListFn;
diff --git a/src/translator/output_collector.cpp b/src/translator/output_collector.cpp
index 58fba69b..76bc4cbb 100755..100644
--- a/src/translator/output_collector.cpp
+++ b/src/translator/output_collector.cpp
@@ -12,7 +12,7 @@ OutputCollector::OutputCollector()
OutputCollector::OutputCollector(std::string outFile)
: nextId_(0),
- outStrm_(new io::OutputFileStream(std::cout)),
+ outStrm_(new std::ostream(std::cout.rdbuf())),
printing_(new DefaultPrinting()) {
if (outFile != "stdout")
outStrm_.reset(new io::OutputFileStream(outFile));
diff --git a/src/translator/output_collector.h b/src/translator/output_collector.h
index 51b47159..ffcbd2d5 100755..100644
--- a/src/translator/output_collector.h
+++ b/src/translator/output_collector.h
@@ -11,6 +11,7 @@ namespace marian {
class PrintingStrategy {
public:
+ virtual ~PrintingStrategy() {}
virtual bool shouldBePrinted(long) = 0;
};
@@ -49,9 +50,7 @@ public:
OutputCollector(std::string outFile);
template <class T>
- OutputCollector(T&& arg)
- : nextId_(0),
- outStrm_(new io::OutputFileStream(arg)) {}
+ OutputCollector(T&& arg) : nextId_(0), outStrm_(new io::OutputFileStream(arg)) {}
OutputCollector(const OutputCollector&) = delete;
@@ -68,7 +67,7 @@ protected:
typedef std::map<long, std::pair<std::string, std::string>> Outputs;
Outputs outputs_;
long nextId_;
- UPtr<io::OutputFileStream> outStrm_;
+ UPtr<std::ostream> outStrm_;
Ptr<PrintingStrategy> printing_;
std::mutex mutex_;
};
diff --git a/src/translator/output_printer.cpp b/src/translator/output_printer.cpp
index 000af789..f57ec9da 100755..100644
--- a/src/translator/output_printer.cpp
+++ b/src/translator/output_printer.cpp
@@ -1,14 +1,16 @@
#include "output_printer.h"
+#include <sstream>
+
namespace marian {
-std::string OutputPrinter::getAlignment(const Ptr<Hypothesis>& hyp) {
+std::string OutputPrinter::getAlignment(const Hypothesis::PtrType& hyp) {
data::SoftAlignment align;
auto last = hyp;
// get soft alignments for each target word starting from the last one
- while(last->GetPrevHyp().get() != nullptr) {
- align.push_back(last->GetAlignment());
- last = last->GetPrevHyp();
+ while(last->getPrevHyp().get() != nullptr) {
+ align.push_back(last->getAlignment());
+ last = last->getPrevHyp();
}
// reverse alignments
@@ -19,11 +21,18 @@ std::string OutputPrinter::getAlignment(const Ptr<Hypothesis>& hyp) {
} else if(alignment_ == "hard") {
return data::ConvertSoftAlignToHardAlign(align, 1.f).toString();
} else if(alignmentThreshold_ > 0.f) {
- return data::ConvertSoftAlignToHardAlign(align, alignmentThreshold_)
- .toString();
+ return data::ConvertSoftAlignToHardAlign(align, alignmentThreshold_).toString();
} else {
ABORT("Unrecognized word alignment type");
}
}
+std::string OutputPrinter::getWordScores(const Hypothesis::PtrType& hyp) {
+ std::ostringstream scores;
+ scores.precision(5);
+ for(const auto& score : hyp->tracebackWordScores())
+ scores << " " << std::fixed << score;
+ return scores.str();
+}
+
} // namespace marian
diff --git a/src/translator/output_printer.h b/src/translator/output_printer.h
index b27b591c..603eedba 100755
--- a/src/translator/output_printer.h
+++ b/src/translator/output_printer.h
@@ -13,19 +13,21 @@ namespace marian {
class OutputPrinter {
public:
- OutputPrinter(Ptr<Options> options, Ptr<Vocab> vocab)
+ OutputPrinter(Ptr<const Options> options, Ptr<const Vocab> vocab)
: vocab_(vocab),
reverse_(options->get<bool>("right-left")),
nbest_(options->get<bool>("n-best", false)
? options->get<size_t>("beam-size")
: 0),
alignment_(options->get<std::string>("alignment", "")),
- alignmentThreshold_(getAlignmentThreshold(alignment_)) {}
+ alignmentThreshold_(getAlignmentThreshold(alignment_)),
+ wordScores_(options->get<bool>("word-scores")) {}
template <class OStream>
- void print(Ptr<History> history, OStream& best1, OStream& bestn) {
- const auto& nbl = history->NBest(nbest_);
+ void print(Ptr<const History> history, OStream& best1, OStream& bestn) {
+ const auto& nbl = history->nBest(nbest_);
+ // prepare n-best list output
for(size_t i = 0; i < nbl.size(); ++i) {
const auto& result = nbl[i];
const auto& hypo = std::get<1>(result);
@@ -35,17 +37,20 @@ public:
std::reverse(words.begin(), words.end());
std::string translation = vocab_->decode(words);
- bestn << history->GetLineNum() << " ||| " << translation;
+ bestn << history->getLineNum() << " ||| " << translation;
if(!alignment_.empty())
bestn << " ||| " << getAlignment(hypo);
+ if(wordScores_)
+ bestn << " ||| WordScores=" << getWordScores(hypo);
+
bestn << " |||";
- if(hypo->GetScoreBreakdown().empty()) {
- bestn << " F0=" << hypo->GetPathScore();
+ if(hypo->getScoreBreakdown().empty()) {
+ bestn << " F0=" << hypo->getPathScore();
} else {
- for(size_t j = 0; j < hypo->GetScoreBreakdown().size(); ++j) {
- bestn << " F" << j << "= " << hypo->GetScoreBreakdown()[j];
+ for(size_t j = 0; j < hypo->getScoreBreakdown().size(); ++j) {
+ bestn << " F" << j << "= " << hypo->getScoreBreakdown()[j];
}
}
@@ -58,7 +63,7 @@ public:
bestn << std::flush;
}
- auto result = history->Top();
+ auto result = history->top();
auto words = std::get<0>(result);
if(reverse_)
@@ -71,17 +76,27 @@ public:
const auto& hypo = std::get<1>(result);
best1 << " ||| " << getAlignment(hypo);
}
+
+ if(wordScores_) {
+ const auto& hypo = std::get<1>(result);
+ best1 << " ||| WordScores=" << getWordScores(hypo);
+ }
+
best1 << std::flush;
}
private:
- Ptr<Vocab> vocab_;
- bool reverse_{false};
- size_t nbest_{0};
- std::string alignment_;
- float alignmentThreshold_{0.f};
-
- std::string getAlignment(const Ptr<Hypothesis>& hyp);
+ Ptr<Vocab const> vocab_;
+ bool reverse_{false}; // If it is a right-to-left model that needs reversed word order
+ size_t nbest_{0}; // Size of the n-best list to print
+ std::string alignment_; // A non-empty string indicates the type of word alignment
+ float alignmentThreshold_{0.f}; // Threshold for converting attention into hard word alignment
+ bool wordScores_{false}; // Whether to print word-level scores or not
+
+ // Get word alignment pairs or soft alignment
+ std::string getAlignment(const Hypothesis::PtrType& hyp);
+ // Get word-level scores
+ std::string getWordScores(const Hypothesis::PtrType& hyp);
float getAlignmentThreshold(const std::string& str) {
try {
diff --git a/src/translator/scorers.cpp b/src/translator/scorers.cpp
index 9afa1d7b..d1c8b160 100755..100644
--- a/src/translator/scorers.cpp
+++ b/src/translator/scorers.cpp
@@ -17,7 +17,7 @@ Ptr<Scorer> scorerByType(const std::string& fname,
}
bool skipCost = options->get<bool>("skip-cost");
- auto encdec = models::from_options(
+ auto encdec = models::createModelFromOptions(
options, skipCost ? models::usage::raw : models::usage::translation);
LOG(info, "Loading scorer of type {} as feature {}", type, fname);
@@ -39,7 +39,7 @@ Ptr<Scorer> scorerByType(const std::string& fname,
}
bool skipCost = options->get<bool>("skip-cost");
- auto encdec = models::from_options(
+ auto encdec = models::createModelFromOptions(
options, skipCost ? models::usage::raw : models::usage::translation);
LOG(info, "Loading scorer of type {} as feature {}", type, fname);
@@ -53,9 +53,10 @@ std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) {
auto models = options->get<std::vector<std::string>>("models");
std::vector<float> weights(models.size(), 1.f);
- if(options->has("weights"))
+ if(options->hasAndNotEmpty("weights"))
weights = options->get<std::vector<float>>("weights");
+ bool isPrevRightLeft = false; // if the previous model was a right-to-left model
size_t i = 0;
for(auto model : models) {
std::string fname = "F" + std::to_string(i);
@@ -72,6 +73,18 @@ std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options) {
LOG(warn, "No model settings found in model file");
}
+ // l2r and r2l cannot be used in the same ensemble
+ if(models.size() > 1 && modelOptions->has("right-left")) {
+ if(i == 0) {
+ isPrevRightLeft = modelOptions->get<bool>("right-left");
+ } else {
+ // abort as soon as there are two consecutive models with opposite directions
+ ABORT_IF(isPrevRightLeft != modelOptions->get<bool>("right-left"),
+ "Left-to-right and right-to-left models cannot be used together in ensembles");
+ isPrevRightLeft = modelOptions->get<bool>("right-left");
+ }
+ }
+
scorers.push_back(scorerByType(fname, weights[i], model, modelOptions));
i++;
}
@@ -83,7 +96,7 @@ std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<c
std::vector<Ptr<Scorer>> scorers;
std::vector<float> weights(ptrs.size(), 1.f);
- if(options->has("weights"))
+ if(options->hasAndNotEmpty("weights"))
weights = options->get<std::vector<float>>("weights");
size_t i = 0;
@@ -109,4 +122,13 @@ std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<c
return scorers;
}
+std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<mio::mmap_source>& mmaps) {
+ std::vector<const void*> ptrs;
+ for(const auto& mmap : mmaps) {
+ ABORT_IF(!mmap.is_mapped(), "Memory mapping did not succeed");
+ ptrs.push_back(mmap.data());
+ }
+ return createScorers(options, ptrs);
+}
+
} // namespace marian
diff --git a/src/translator/scorers.h b/src/translator/scorers.h
index 16a066b0..a5a0be2c 100755
--- a/src/translator/scorers.h
+++ b/src/translator/scorers.h
@@ -4,14 +4,15 @@
#include "data/shortlist.h"
#include "models/model_factory.h"
+#include "3rd_party/mio/mio.hpp"
namespace marian {
class ScorerState {
public:
- virtual Expr getLogProbs() = 0;
+ virtual ~ScorerState(){}
- virtual float breakDown(size_t i) { return getLogProbs()->val()->get(i); }
+ virtual Logits getLogProbs() const = 0;
virtual void blacklist(Expr /*totalCosts*/, Ptr<data::CorpusBatch> /*batch*/){};
};
@@ -25,6 +26,8 @@ public:
Scorer(const std::string& name, float weight)
: name_(name), weight_(weight) {}
+ virtual ~Scorer(){}
+
std::string getName() { return name_; }
float getWeight() { return weight_; }
@@ -35,14 +38,14 @@ public:
virtual Ptr<ScorerState> step(Ptr<ExpressionGraph>,
Ptr<ScorerState>,
const std::vector<IndexType>&,
- const std::vector<IndexType>&,
- int dimBatch,
+ const Words&,
+ const std::vector<IndexType>& batchIndices,
int beamSize)
= 0;
virtual void init(Ptr<ExpressionGraph>) {}
- virtual void setShortlistGenerator(Ptr<data::ShortlistGenerator> /*shortlistGenerator*/){};
+ virtual void setShortlistGenerator(Ptr<const data::ShortlistGenerator> /*shortlistGenerator*/){};
virtual Ptr<data::Shortlist> getShortlist() { return nullptr; };
virtual std::vector<float> getAlignment() { return {}; };
@@ -54,41 +57,44 @@ protected:
public:
ScorerWrapperState(Ptr<DecoderState> state) : state_(state) {}
+ virtual ~ScorerWrapperState() {}
virtual Ptr<DecoderState> getState() { return state_; }
- virtual Expr getLogProbs() override { return state_->getLogProbs(); };
+ virtual Logits getLogProbs() const override { return state_->getLogProbs(); };
virtual void blacklist(Expr totalCosts, Ptr<data::CorpusBatch> batch) override {
state_->blacklist(totalCosts, batch);
}
};
-// class to wrap EncoderDecoderBase in a Scorer interface
+// class to wrap IEncoderDecoder in a Scorer interface
class ScorerWrapper : public Scorer {
private:
- Ptr<EncoderDecoderBase> encdec_;
+ Ptr<IEncoderDecoder> encdec_;
std::string fname_;
const void* ptr_;
public:
- ScorerWrapper(Ptr<models::ModelBase> encdec,
+ ScorerWrapper(Ptr<models::IModel> encdec,
const std::string& name,
float weight,
const std::string& fname)
: Scorer(name, weight),
- encdec_(std::static_pointer_cast<EncoderDecoderBase>(encdec)),
+ encdec_(std::static_pointer_cast<IEncoderDecoder>(encdec)),
fname_(fname),
ptr_{0} {}
- ScorerWrapper(Ptr<models::ModelBase> encdec,
+ ScorerWrapper(Ptr<models::IModel> encdec,
const std::string& name,
float weight,
const void* ptr)
: Scorer(name, weight),
- encdec_(std::static_pointer_cast<EncoderDecoderBase>(encdec)),
+ encdec_(std::static_pointer_cast<IEncoderDecoder>(encdec)),
ptr_{ptr} {}
+ virtual ~ScorerWrapper() {}
+
virtual void init(Ptr<ExpressionGraph> graph) override {
graph->switchParams(getName());
if(ptr_)
@@ -111,17 +117,17 @@ public:
virtual Ptr<ScorerState> step(Ptr<ExpressionGraph> graph,
Ptr<ScorerState> state,
const std::vector<IndexType>& hypIndices,
- const std::vector<IndexType>& embIndices,
- int dimBatch,
+ const Words& words,
+ const std::vector<IndexType>& batchIndices,
int beamSize) override {
graph->switchParams(getName());
auto wrapperState = std::dynamic_pointer_cast<ScorerWrapperState>(state);
- auto newState = encdec_->step(graph, wrapperState->getState(), hypIndices, embIndices, dimBatch, beamSize);
+ auto newState = encdec_->step(graph, wrapperState->getState(), hypIndices, words, batchIndices, beamSize);
return New<ScorerWrapperState>(newState);
}
virtual void setShortlistGenerator(
- Ptr<data::ShortlistGenerator> shortlistGenerator) override {
+ Ptr<const data::ShortlistGenerator> shortlistGenerator) override {
encdec_->setShortlistGenerator(shortlistGenerator);
};
@@ -130,7 +136,9 @@ public:
};
virtual std::vector<float> getAlignment() override {
- return encdec_->getAlignment().front();
+ // This is called during decoding, where alignments only exist for the last time step. Hence front().
+ // This makes as copy. @TODO: It should be OK to return this as a const&.
+ return encdec_->getAlignment().front(); // [beam depth * max src length * batch size]
}
};
@@ -147,5 +155,6 @@ Ptr<Scorer> scorerByType(const std::string& fname,
Ptr<Options> config);
std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<const void*>& ptrs);
+std::vector<Ptr<Scorer>> createScorers(Ptr<Options> options, const std::vector<mio::mmap_source>& mmaps);
} // namespace marian
diff --git a/src/translator/translator.h b/src/translator/translator.h
index 9f973113..cc68a4f0 100755
--- a/src/translator/translator.h
+++ b/src/translator/translator.h
@@ -6,6 +6,7 @@
#include "data/text_input.h"
#include "3rd_party/threadpool.h"
+
#include "translator/history.h"
#include "translator/output_collector.h"
#include "translator/output_printer.h"
@@ -13,6 +14,14 @@
#include "models/model_task.h"
#include "translator/scorers.h"
+// currently for diagnostics only, will try to mmap files ending in *.bin suffix when enabled.
+// @TODO: add this as an actual feature.
+#define MMAP 0
+
+#if MMAP
+#include "3rd_party/mio/mio.hpp"
+#endif
+
namespace marian {
template <class Search>
@@ -24,15 +33,22 @@ private:
Ptr<data::Corpus> corpus_;
Ptr<Vocab> trgVocab_;
- Ptr<data::ShortlistGenerator> shortlistGenerator_;
+ Ptr<const data::ShortlistGenerator> shortlistGenerator_;
size_t numDevices_;
+#if MMAP
+ std::vector<mio::mmap_source> mmaps_;
+#endif
+
public:
- Translate(Ptr<Options> options) : options_(options) {
+ Translate(Ptr<Options> options)
+ : options_(New<Options>(options->clone())) { // @TODO: clone should return Ptr<Options> same as "with"?
// This is currently safe as the translator is either created stand-alone or
// or config is created anew from Options in the validator
- options_->set("inference", true);
+
+ options_->set("inference", true,
+ "shuffle", "none");
corpus_ = New<data::Corpus>(options_, true);
@@ -41,7 +57,7 @@ public:
trgVocab_->load(vocabs.back());
auto srcVocab = corpus_->getVocabs()[0];
- if(options_->has("shortlist"))
+ if(options_->hasAndNotEmpty("shortlist"))
shortlistGenerator_ = New<data::LexicalShortlistGenerator>(
options_, srcVocab, trgVocab_, 0, 1, vocabs.front() == vocabs.back());
@@ -52,16 +68,35 @@ public:
scorers_.resize(numDevices_);
graphs_.resize(numDevices_);
+#if MMAP
+ auto models = options->get<std::vector<std::string>>("models");
+ for(auto model : models) {
+ marian::filesystem::Path modelPath(model);
+ ABORT_IF(modelPath.extension() != marian::filesystem::Path(".bin"),
+ "Non-binarized models cannot be mmapped");
+ mmaps_.push_back(std::move(mio::mmap_source(model)));
+ }
+#endif
+
size_t id = 0;
for(auto device : devices) {
auto task = [&](DeviceId device, size_t id) {
- auto graph = New<ExpressionGraph>(true, options_->get<bool>("optimize"));
+ auto graph = New<ExpressionGraph>(true);
+ auto prec = options_->get<std::vector<std::string>>("precision", {"float32"});
+ graph->setDefaultElementType(typeFromString(prec[0]));
graph->setDevice(device);
graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
+ if (device.type == DeviceType::cpu) {
+ graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
+ }
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_[id] = graph;
+#if MMAP
+ auto scorers = createScorers(options_, mmaps_);
+#else
auto scorers = createScorers(options_);
+#endif
for(auto scorer : scorers) {
scorer->init(graph);
if(shortlistGenerator_)
@@ -74,6 +109,17 @@ public:
threadPool.enqueue(task, device, id++);
}
+
+ if(options_->get<bool>("output-sampling", false)) {
+ if(options_->get<size_t>("beam-size") > 1)
+ LOG(warn,
+ "[warning] Output sampling and beam search (beam-size > 1) are contradictory methods "
+ "and using them together is not recommended. Set beam-size to 1");
+ if(options_->get<std::vector<std::string>>("models").size() > 1)
+ LOG(warn,
+ "[warning] Output sampling and model ensembling are contradictory methods and using "
+ "them together is not recommended. Use a single model");
+ }
}
void run() override {
@@ -87,8 +133,9 @@ public:
if(options_->get<bool>("quiet-translation"))
collector->setPrintingStrategy(New<QuietPrinting>());
- bg.prepare(false);
+ bg.prepare();
+ bool doNbest = options_->get<bool>("n-best");
for(auto batch : bg) {
auto task = [=](size_t id) {
thread_local Ptr<ExpressionGraph> graph;
@@ -99,17 +146,17 @@ public:
scorers = scorers_[id % numDevices_];
}
- auto search = New<Search>(options_, scorers, trgVocab_->getEosId(), trgVocab_->getUnkId());
+ auto search = New<Search>(options_, scorers, trgVocab_);
auto histories = search->search(graph, batch);
for(auto history : histories) {
std::stringstream best1;
std::stringstream bestn;
printer->print(history, best1, bestn);
- collector->Write((long)history->GetLineNum(),
+ collector->Write((long)history->getLineNum(),
best1.str(),
bestn.str(),
- options_->get<bool>("n-best"));
+ doNbest);
}
@@ -119,8 +166,8 @@ public:
&& id % 1000 == 0) // hard beat once every 1000 batches
{
auto progress = 0.f; //fake progress for now
- fprintf(stdout, "PROGRESS: %.2f%%\n", progress);
- fflush(stdout);
+ fprintf(stderr, "PROGRESS: %.2f%%\n", progress);
+ fflush(stderr);
}
};
@@ -139,17 +186,18 @@ private:
std::vector<Ptr<Vocab>> srcVocabs_;
Ptr<Vocab> trgVocab_;
+ Ptr<const data::ShortlistGenerator> shortlistGenerator_;
size_t numDevices_;
public:
virtual ~TranslateService() {}
- TranslateService(Ptr<Options> options) : options_(options) { init(); }
-
- void init() override {
+ TranslateService(Ptr<Options> options)
+ : options_(New<Options>(options->clone())) {
// initialize vocabs
options_->set("inference", true);
+ options_->set("shuffle", "none");
auto vocabPaths = options_->get<std::vector<std::string>>("vocabs");
std::vector<int> maxVocabs = options_->get<std::vector<int>>("dim-vocabs");
@@ -163,21 +211,35 @@ public:
trgVocab_ = New<Vocab>(options_, vocabPaths.size() - 1);
trgVocab_->load(vocabPaths.back());
+ // load lexical shortlist
+ if(options_->hasAndNotEmpty("shortlist"))
+ shortlistGenerator_ = New<data::LexicalShortlistGenerator>(
+ options_, srcVocabs_.front(), trgVocab_, 0, 1, vocabPaths.front() == vocabPaths.back());
+
// get device IDs
auto devices = Config::getDevices(options_);
numDevices_ = devices.size();
// initialize scorers
for(auto device : devices) {
- auto graph = New<ExpressionGraph>(true, options_->get<bool>("optimize"));
+ auto graph = New<ExpressionGraph>(true);
+
+ auto precison = options_->get<std::vector<std::string>>("precision", {"float32"});
+ graph->setDefaultElementType(typeFromString(precison[0])); // only use first type, used for parameter type in graph
graph->setDevice(device);
graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
+ if (device.type == DeviceType::cpu) {
+ graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
+ }
graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
graphs_.push_back(graph);
auto scorers = createScorers(options_);
- for(auto scorer : scorers)
+ for(auto scorer : scorers) {
scorer->init(graph);
+ if(shortlistGenerator_)
+ scorer->setShortlistGenerator(shortlistGenerator_);
+ }
scorers_.push_back(scorers);
}
}
@@ -190,7 +252,7 @@ public:
auto printer = New<OutputPrinter>(options_, trgVocab_);
size_t batchId = 0;
- batchGenerator.prepare(false);
+ batchGenerator.prepare();
{
ThreadPool threadPool_(numDevices_, numDevices_);
@@ -206,14 +268,14 @@ public:
scorers = scorers_[id % numDevices_];
}
- auto search = New<Search>(options_, scorers, trgVocab_->getEosId(), trgVocab_->getUnkId());
+ auto search = New<Search>(options_, scorers, trgVocab_);
auto histories = search->search(graph, batch);
for(auto history : histories) {
std::stringstream best1;
std::stringstream bestn;
printer->print(history, best1, bestn);
- collector->add((long)history->GetLineNum(), best1.str(), bestn.str());
+ collector->add((long)history->getLineNum(), best1.str(), bestn.str());
}
};
diff --git a/vs/Marian.sln b/vs/Marian.sln
index c2eb49c9..b98a1246 100755
--- a/vs/Marian.sln
+++ b/vs/Marian.sln
@@ -1,377 +1,25 @@
+
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio 15
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ALL_BUILD", "ALL_BUILD.vcxproj", "{5216F769-E887-369E-AD1E-D6A1F69E834E}"
- ProjectSection(ProjectDependencies) = postProject
- {17E8F84B-76CD-326B-B50A-C4F3C3A8CE33} = {17E8F84B-76CD-326B-B50A-C4F3C3A8CE33}
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- {5AF43E07-5917-3D8F-9BF0-B41F698242EA} = {5AF43E07-5917-3D8F-9BF0-B41F698242EA}
- {885D3D2B-7278-30EF-BB1B-50E83D1635C4} = {885D3D2B-7278-30EF-BB1B-50E83D1635C4}
- {3CD61EAE-244E-33AB-8C7D-F5182481E033} = {3CD61EAE-244E-33AB-8C7D-F5182481E033}
- {97131187-E592-3981-886F-222EE20FB669} = {97131187-E592-3981-886F-222EE20FB669}
- {25A05D30-AFC2-3F0E-B475-0B2B81530151} = {25A05D30-AFC2-3F0E-B475-0B2B81530151}
- {8A6B1F60-8E2D-3171-828B-07E732C8E7D7} = {8A6B1F60-8E2D-3171-828B-07E732C8E7D7}
- {3784D69C-33A9-33A7-A557-F809EF2F4D34} = {3784D69C-33A9-33A7-A557-F809EF2F4D34}
- {EA3973A2-F92E-3124-9817-81B2458EC8DC} = {EA3973A2-F92E-3124-9817-81B2458EC8DC}
- {36953645-6D01-37E4-ACF7-D3F9BFFCA49D} = {36953645-6D01-37E4-ACF7-D3F9BFFCA49D}
- {F4AD2C38-E6B9-3C4A-A281-4AB7440D6162} = {F4AD2C38-E6B9-3C4A-A281-4AB7440D6162}
- {D9D20410-4011-370C-8E15-A6F5C311F337} = {D9D20410-4011-370C-8E15-A6F5C311F337}
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678} = {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}
- {5857EF98-C87F-3197-A399-F0F9A20913FC} = {5857EF98-C87F-3197-A399-F0F9A20913FC}
- {F6E7B14E-D9E6-343C-B58D-CA0381A3BB8F} = {F6E7B14E-D9E6-343C-B58D-CA0381A3BB8F}
- {FBB107B9-523B-3094-95CF-A103E2388006} = {FBB107B9-523B-3094-95CF-A103E2388006}
- {5B4A6D26-C638-3350-9E1A-0F987C448DEC} = {5B4A6D26-C638-3350-9E1A-0F987C448DEC}
- {11AB9AE9-CF65-341B-B425-9EDFC4E2F22F} = {11AB9AE9-CF65-341B-B425-9EDFC4E2F22F}
- {1134F859-3DE4-34B1-924F-82CA38D4D4F3} = {1134F859-3DE4-34B1-924F-82CA38D4D4F3}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "INSTALL", "INSTALL.vcxproj", "{9DAF8CA3-052E-3480-A332-34676CAE852B}"
- ProjectSection(ProjectDependencies) = postProject
- {5216F769-E887-369E-AD1E-D6A1F69E834E} = {5216F769-E887-369E-AD1E-D6A1F69E834E}
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "PACKAGE", "PACKAGE.vcxproj", "{3A3C6EA5-65CD-324E-90F4-6B4D70DD5A37}"
- ProjectSection(ProjectDependencies) = postProject
- {5216F769-E887-369E-AD1E-D6A1F69E834E} = {5216F769-E887-369E-AD1E-D6A1F69E834E}
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "SQLiteCpp", "src\3rd_party\SQLiteCpp\SQLiteCpp.vcxproj", "{17E8F84B-76CD-326B-B50A-C4F3C3A8CE33}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ZERO_CHECK", "ZERO_CHECK.vcxproj", "{806A44E1-15D4-3368-B0B9-2A6CC352D505}"
- ProjectSection(ProjectDependencies) = postProject
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "libyaml-cpp", "src\3rd_party\yaml-cpp\libyaml-cpp.vcxproj", "{5AF43E07-5917-3D8F-9BF0-B41F698242EA}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "marian", "src\marian.vcxproj", "{885D3D2B-7278-30EF-BB1B-50E83D1635C4}"
- ProjectSection(ProjectDependencies) = postProject
- {17E8F84B-76CD-326B-B50A-C4F3C3A8CE33} = {17E8F84B-76CD-326B-B50A-C4F3C3A8CE33}
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- {5AF43E07-5917-3D8F-9BF0-B41F698242EA} = {5AF43E07-5917-3D8F-9BF0-B41F698242EA}
- {55A27783-64A4-3AA7-A4B1-49C4B628F18C} = {55A27783-64A4-3AA7-A4B1-49C4B628F18C}
- {F4AD2C38-E6B9-3C4A-A281-4AB7440D6162} = {F4AD2C38-E6B9-3C4A-A281-4AB7440D6162}
- {1134F859-3DE4-34B1-924F-82CA38D4D4F3} = {1134F859-3DE4-34B1-924F-82CA38D4D4F3}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "marian_conv", "src\marian_conv.vcxproj", "{3CD61EAE-244E-33AB-8C7D-F5182481E033}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- {885D3D2B-7278-30EF-BB1B-50E83D1635C4} = {885D3D2B-7278-30EF-BB1B-50E83D1635C4}
- {97131187-E592-3981-886F-222EE20FB669} = {97131187-E592-3981-886F-222EE20FB669}
- {D9D20410-4011-370C-8E15-A6F5C311F337} = {D9D20410-4011-370C-8E15-A6F5C311F337}
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678} = {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "marian_cuda", "src\marian_cuda.vcxproj", "{97131187-E592-3981-886F-222EE20FB669}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "marian_decoder", "src\marian_decoder.vcxproj", "{25A05D30-AFC2-3F0E-B475-0B2B81530151}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- {885D3D2B-7278-30EF-BB1B-50E83D1635C4} = {885D3D2B-7278-30EF-BB1B-50E83D1635C4}
- {97131187-E592-3981-886F-222EE20FB669} = {97131187-E592-3981-886F-222EE20FB669}
- {D9D20410-4011-370C-8E15-A6F5C311F337} = {D9D20410-4011-370C-8E15-A6F5C311F337}
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678} = {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "marian_scorer", "src\marian_scorer.vcxproj", "{8A6B1F60-8E2D-3171-828B-07E732C8E7D7}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- {885D3D2B-7278-30EF-BB1B-50E83D1635C4} = {885D3D2B-7278-30EF-BB1B-50E83D1635C4}
- {97131187-E592-3981-886F-222EE20FB669} = {97131187-E592-3981-886F-222EE20FB669}
- {D9D20410-4011-370C-8E15-A6F5C311F337} = {D9D20410-4011-370C-8E15-A6F5C311F337}
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678} = {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "marian_server", "src\marian_server.vcxproj", "{3784D69C-33A9-33A7-A557-F809EF2F4D34}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- {885D3D2B-7278-30EF-BB1B-50E83D1635C4} = {885D3D2B-7278-30EF-BB1B-50E83D1635C4}
- {97131187-E592-3981-886F-222EE20FB669} = {97131187-E592-3981-886F-222EE20FB669}
- {D9D20410-4011-370C-8E15-A6F5C311F337} = {D9D20410-4011-370C-8E15-A6F5C311F337}
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678} = {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "marian_train", "src\marian_train.vcxproj", "{EA3973A2-F92E-3124-9817-81B2458EC8DC}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- {885D3D2B-7278-30EF-BB1B-50E83D1635C4} = {885D3D2B-7278-30EF-BB1B-50E83D1635C4}
- {97131187-E592-3981-886F-222EE20FB669} = {97131187-E592-3981-886F-222EE20FB669}
- {D9D20410-4011-370C-8E15-A6F5C311F337} = {D9D20410-4011-370C-8E15-A6F5C311F337}
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678} = {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "marian_version", "src\marian_version.vcxproj", "{55A27783-64A4-3AA7-A4B1-49C4B628F18C}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "marian_vocab", "src\marian_vocab.vcxproj", "{36953645-6D01-37E4-ACF7-D3F9BFFCA49D}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- {885D3D2B-7278-30EF-BB1B-50E83D1635C4} = {885D3D2B-7278-30EF-BB1B-50E83D1635C4}
- {97131187-E592-3981-886F-222EE20FB669} = {97131187-E592-3981-886F-222EE20FB669}
- {D9D20410-4011-370C-8E15-A6F5C311F337} = {D9D20410-4011-370C-8E15-A6F5C311F337}
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678} = {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "pathie-cpp", "src\3rd_party\pathie-cpp\pathie-cpp.vcxproj", "{F4AD2C38-E6B9-3C4A-A281-4AB7440D6162}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "sentencepiece-static", "src\3rd_party\sentencepiece\src\sentencepiece-static.vcxproj", "{D9D20410-4011-370C-8E15-A6F5C311F337}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "sentencepiece_train-static", "src\3rd_party\sentencepiece\src\sentencepiece_train-static.vcxproj", "{4A20AD5F-7334-31D3-B31D-9AAF53CC6678}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "spm_decode", "src\3rd_party\sentencepiece\src\spm_decode.vcxproj", "{5857EF98-C87F-3197-A399-F0F9A20913FC}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- {D9D20410-4011-370C-8E15-A6F5C311F337} = {D9D20410-4011-370C-8E15-A6F5C311F337}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "spm_encode", "src\3rd_party\sentencepiece\src\spm_encode.vcxproj", "{F6E7B14E-D9E6-343C-B58D-CA0381A3BB8F}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- {D9D20410-4011-370C-8E15-A6F5C311F337} = {D9D20410-4011-370C-8E15-A6F5C311F337}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "spm_export_vocab", "src\3rd_party\sentencepiece\src\spm_export_vocab.vcxproj", "{FBB107B9-523B-3094-95CF-A103E2388006}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- {D9D20410-4011-370C-8E15-A6F5C311F337} = {D9D20410-4011-370C-8E15-A6F5C311F337}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "spm_normalize", "src\3rd_party\sentencepiece\src\spm_normalize.vcxproj", "{5B4A6D26-C638-3350-9E1A-0F987C448DEC}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- {D9D20410-4011-370C-8E15-A6F5C311F337} = {D9D20410-4011-370C-8E15-A6F5C311F337}
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678} = {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "spm_train", "src\3rd_party\sentencepiece\src\spm_train.vcxproj", "{11AB9AE9-CF65-341B-B425-9EDFC4E2F22F}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- {D9D20410-4011-370C-8E15-A6F5C311F337} = {D9D20410-4011-370C-8E15-A6F5C311F337}
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678} = {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}
- EndProjectSection
-EndProject
-Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "zlib", "src\3rd_party\zlib\zlib.vcxproj", "{1134F859-3DE4-34B1-924F-82CA38D4D4F3}"
- ProjectSection(ProjectDependencies) = postProject
- {806A44E1-15D4-3368-B0B9-2A6CC352D505} = {806A44E1-15D4-3368-B0B9-2A6CC352D505}
- EndProjectSection
+VisualStudioVersion = 15.0.28307.902
+MinimumVisualStudioVersion = 10.0.40219.1
+Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Marian", "Marian.vcxproj", "{E2F320FE-0C01-4C80-810C-3A92205A29DC}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|x64 = Debug|x64
Release|x64 = Release|x64
- MinSizeRel|x64 = MinSizeRel|x64
- RelWithDebInfo|x64 = RelWithDebInfo|x64
EndGlobalSection
GlobalSection(ProjectConfigurationPlatforms) = postSolution
- {5216F769-E887-369E-AD1E-D6A1F69E834E}.Debug|x64.ActiveCfg = Debug|x64
- {5216F769-E887-369E-AD1E-D6A1F69E834E}.Debug|x64.Build.0 = Debug|x64
- {5216F769-E887-369E-AD1E-D6A1F69E834E}.Release|x64.ActiveCfg = Release|x64
- {5216F769-E887-369E-AD1E-D6A1F69E834E}.Release|x64.Build.0 = Release|x64
- {5216F769-E887-369E-AD1E-D6A1F69E834E}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {5216F769-E887-369E-AD1E-D6A1F69E834E}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {5216F769-E887-369E-AD1E-D6A1F69E834E}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {5216F769-E887-369E-AD1E-D6A1F69E834E}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {9DAF8CA3-052E-3480-A332-34676CAE852B}.Debug|x64.ActiveCfg = Debug|x64
- {9DAF8CA3-052E-3480-A332-34676CAE852B}.Release|x64.ActiveCfg = Release|x64
- {9DAF8CA3-052E-3480-A332-34676CAE852B}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {9DAF8CA3-052E-3480-A332-34676CAE852B}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {3A3C6EA5-65CD-324E-90F4-6B4D70DD5A37}.Debug|x64.ActiveCfg = Debug|x64
- {3A3C6EA5-65CD-324E-90F4-6B4D70DD5A37}.Release|x64.ActiveCfg = Release|x64
- {3A3C6EA5-65CD-324E-90F4-6B4D70DD5A37}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {3A3C6EA5-65CD-324E-90F4-6B4D70DD5A37}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {17E8F84B-76CD-326B-B50A-C4F3C3A8CE33}.Debug|x64.ActiveCfg = Debug|x64
- {17E8F84B-76CD-326B-B50A-C4F3C3A8CE33}.Debug|x64.Build.0 = Debug|x64
- {17E8F84B-76CD-326B-B50A-C4F3C3A8CE33}.Release|x64.ActiveCfg = Release|x64
- {17E8F84B-76CD-326B-B50A-C4F3C3A8CE33}.Release|x64.Build.0 = Release|x64
- {17E8F84B-76CD-326B-B50A-C4F3C3A8CE33}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {17E8F84B-76CD-326B-B50A-C4F3C3A8CE33}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {17E8F84B-76CD-326B-B50A-C4F3C3A8CE33}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {17E8F84B-76CD-326B-B50A-C4F3C3A8CE33}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {806A44E1-15D4-3368-B0B9-2A6CC352D505}.Debug|x64.ActiveCfg = Debug|x64
- {806A44E1-15D4-3368-B0B9-2A6CC352D505}.Debug|x64.Build.0 = Debug|x64
- {806A44E1-15D4-3368-B0B9-2A6CC352D505}.Release|x64.ActiveCfg = Release|x64
- {806A44E1-15D4-3368-B0B9-2A6CC352D505}.Release|x64.Build.0 = Release|x64
- {806A44E1-15D4-3368-B0B9-2A6CC352D505}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {806A44E1-15D4-3368-B0B9-2A6CC352D505}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {806A44E1-15D4-3368-B0B9-2A6CC352D505}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {806A44E1-15D4-3368-B0B9-2A6CC352D505}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {5AF43E07-5917-3D8F-9BF0-B41F698242EA}.Debug|x64.ActiveCfg = Debug|x64
- {5AF43E07-5917-3D8F-9BF0-B41F698242EA}.Debug|x64.Build.0 = Debug|x64
- {5AF43E07-5917-3D8F-9BF0-B41F698242EA}.Release|x64.ActiveCfg = Release|x64
- {5AF43E07-5917-3D8F-9BF0-B41F698242EA}.Release|x64.Build.0 = Release|x64
- {5AF43E07-5917-3D8F-9BF0-B41F698242EA}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {5AF43E07-5917-3D8F-9BF0-B41F698242EA}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {5AF43E07-5917-3D8F-9BF0-B41F698242EA}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {5AF43E07-5917-3D8F-9BF0-B41F698242EA}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {885D3D2B-7278-30EF-BB1B-50E83D1635C4}.Debug|x64.ActiveCfg = Debug|x64
- {885D3D2B-7278-30EF-BB1B-50E83D1635C4}.Debug|x64.Build.0 = Debug|x64
- {885D3D2B-7278-30EF-BB1B-50E83D1635C4}.Release|x64.ActiveCfg = Release|x64
- {885D3D2B-7278-30EF-BB1B-50E83D1635C4}.Release|x64.Build.0 = Release|x64
- {885D3D2B-7278-30EF-BB1B-50E83D1635C4}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {885D3D2B-7278-30EF-BB1B-50E83D1635C4}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {885D3D2B-7278-30EF-BB1B-50E83D1635C4}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {885D3D2B-7278-30EF-BB1B-50E83D1635C4}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {3CD61EAE-244E-33AB-8C7D-F5182481E033}.Debug|x64.ActiveCfg = Debug|x64
- {3CD61EAE-244E-33AB-8C7D-F5182481E033}.Debug|x64.Build.0 = Debug|x64
- {3CD61EAE-244E-33AB-8C7D-F5182481E033}.Release|x64.ActiveCfg = Release|x64
- {3CD61EAE-244E-33AB-8C7D-F5182481E033}.Release|x64.Build.0 = Release|x64
- {3CD61EAE-244E-33AB-8C7D-F5182481E033}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {3CD61EAE-244E-33AB-8C7D-F5182481E033}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {3CD61EAE-244E-33AB-8C7D-F5182481E033}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {3CD61EAE-244E-33AB-8C7D-F5182481E033}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {97131187-E592-3981-886F-222EE20FB669}.Debug|x64.ActiveCfg = Debug|x64
- {97131187-E592-3981-886F-222EE20FB669}.Debug|x64.Build.0 = Debug|x64
- {97131187-E592-3981-886F-222EE20FB669}.Release|x64.ActiveCfg = Release|x64
- {97131187-E592-3981-886F-222EE20FB669}.Release|x64.Build.0 = Release|x64
- {97131187-E592-3981-886F-222EE20FB669}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {97131187-E592-3981-886F-222EE20FB669}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {97131187-E592-3981-886F-222EE20FB669}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {97131187-E592-3981-886F-222EE20FB669}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {25A05D30-AFC2-3F0E-B475-0B2B81530151}.Debug|x64.ActiveCfg = Debug|x64
- {25A05D30-AFC2-3F0E-B475-0B2B81530151}.Debug|x64.Build.0 = Debug|x64
- {25A05D30-AFC2-3F0E-B475-0B2B81530151}.Release|x64.ActiveCfg = Release|x64
- {25A05D30-AFC2-3F0E-B475-0B2B81530151}.Release|x64.Build.0 = Release|x64
- {25A05D30-AFC2-3F0E-B475-0B2B81530151}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {25A05D30-AFC2-3F0E-B475-0B2B81530151}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {25A05D30-AFC2-3F0E-B475-0B2B81530151}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {25A05D30-AFC2-3F0E-B475-0B2B81530151}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {8A6B1F60-8E2D-3171-828B-07E732C8E7D7}.Debug|x64.ActiveCfg = Debug|x64
- {8A6B1F60-8E2D-3171-828B-07E732C8E7D7}.Debug|x64.Build.0 = Debug|x64
- {8A6B1F60-8E2D-3171-828B-07E732C8E7D7}.Release|x64.ActiveCfg = Release|x64
- {8A6B1F60-8E2D-3171-828B-07E732C8E7D7}.Release|x64.Build.0 = Release|x64
- {8A6B1F60-8E2D-3171-828B-07E732C8E7D7}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {8A6B1F60-8E2D-3171-828B-07E732C8E7D7}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {8A6B1F60-8E2D-3171-828B-07E732C8E7D7}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {8A6B1F60-8E2D-3171-828B-07E732C8E7D7}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {3784D69C-33A9-33A7-A557-F809EF2F4D34}.Debug|x64.ActiveCfg = Debug|x64
- {3784D69C-33A9-33A7-A557-F809EF2F4D34}.Debug|x64.Build.0 = Debug|x64
- {3784D69C-33A9-33A7-A557-F809EF2F4D34}.Release|x64.ActiveCfg = Release|x64
- {3784D69C-33A9-33A7-A557-F809EF2F4D34}.Release|x64.Build.0 = Release|x64
- {3784D69C-33A9-33A7-A557-F809EF2F4D34}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {3784D69C-33A9-33A7-A557-F809EF2F4D34}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {3784D69C-33A9-33A7-A557-F809EF2F4D34}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {3784D69C-33A9-33A7-A557-F809EF2F4D34}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {EA3973A2-F92E-3124-9817-81B2458EC8DC}.Debug|x64.ActiveCfg = Debug|x64
- {EA3973A2-F92E-3124-9817-81B2458EC8DC}.Debug|x64.Build.0 = Debug|x64
- {EA3973A2-F92E-3124-9817-81B2458EC8DC}.Release|x64.ActiveCfg = Release|x64
- {EA3973A2-F92E-3124-9817-81B2458EC8DC}.Release|x64.Build.0 = Release|x64
- {EA3973A2-F92E-3124-9817-81B2458EC8DC}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {EA3973A2-F92E-3124-9817-81B2458EC8DC}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {EA3973A2-F92E-3124-9817-81B2458EC8DC}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {EA3973A2-F92E-3124-9817-81B2458EC8DC}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {55A27783-64A4-3AA7-A4B1-49C4B628F18C}.Debug|x64.ActiveCfg = Debug|x64
- {55A27783-64A4-3AA7-A4B1-49C4B628F18C}.Debug|x64.Build.0 = Debug|x64
- {55A27783-64A4-3AA7-A4B1-49C4B628F18C}.Release|x64.ActiveCfg = Release|x64
- {55A27783-64A4-3AA7-A4B1-49C4B628F18C}.Release|x64.Build.0 = Release|x64
- {55A27783-64A4-3AA7-A4B1-49C4B628F18C}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {55A27783-64A4-3AA7-A4B1-49C4B628F18C}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {55A27783-64A4-3AA7-A4B1-49C4B628F18C}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {55A27783-64A4-3AA7-A4B1-49C4B628F18C}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {36953645-6D01-37E4-ACF7-D3F9BFFCA49D}.Debug|x64.ActiveCfg = Debug|x64
- {36953645-6D01-37E4-ACF7-D3F9BFFCA49D}.Debug|x64.Build.0 = Debug|x64
- {36953645-6D01-37E4-ACF7-D3F9BFFCA49D}.Release|x64.ActiveCfg = Release|x64
- {36953645-6D01-37E4-ACF7-D3F9BFFCA49D}.Release|x64.Build.0 = Release|x64
- {36953645-6D01-37E4-ACF7-D3F9BFFCA49D}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {36953645-6D01-37E4-ACF7-D3F9BFFCA49D}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {36953645-6D01-37E4-ACF7-D3F9BFFCA49D}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {36953645-6D01-37E4-ACF7-D3F9BFFCA49D}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {F4AD2C38-E6B9-3C4A-A281-4AB7440D6162}.Debug|x64.ActiveCfg = Debug|x64
- {F4AD2C38-E6B9-3C4A-A281-4AB7440D6162}.Debug|x64.Build.0 = Debug|x64
- {F4AD2C38-E6B9-3C4A-A281-4AB7440D6162}.Release|x64.ActiveCfg = Release|x64
- {F4AD2C38-E6B9-3C4A-A281-4AB7440D6162}.Release|x64.Build.0 = Release|x64
- {F4AD2C38-E6B9-3C4A-A281-4AB7440D6162}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {F4AD2C38-E6B9-3C4A-A281-4AB7440D6162}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {F4AD2C38-E6B9-3C4A-A281-4AB7440D6162}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {F4AD2C38-E6B9-3C4A-A281-4AB7440D6162}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {D9D20410-4011-370C-8E15-A6F5C311F337}.Debug|x64.ActiveCfg = Debug|x64
- {D9D20410-4011-370C-8E15-A6F5C311F337}.Debug|x64.Build.0 = Debug|x64
- {D9D20410-4011-370C-8E15-A6F5C311F337}.Release|x64.ActiveCfg = Release|x64
- {D9D20410-4011-370C-8E15-A6F5C311F337}.Release|x64.Build.0 = Release|x64
- {D9D20410-4011-370C-8E15-A6F5C311F337}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {D9D20410-4011-370C-8E15-A6F5C311F337}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {D9D20410-4011-370C-8E15-A6F5C311F337}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {D9D20410-4011-370C-8E15-A6F5C311F337}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}.Debug|x64.ActiveCfg = Debug|x64
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}.Debug|x64.Build.0 = Debug|x64
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}.Release|x64.ActiveCfg = Release|x64
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}.Release|x64.Build.0 = Release|x64
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {4A20AD5F-7334-31D3-B31D-9AAF53CC6678}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {5857EF98-C87F-3197-A399-F0F9A20913FC}.Debug|x64.ActiveCfg = Debug|x64
- {5857EF98-C87F-3197-A399-F0F9A20913FC}.Debug|x64.Build.0 = Debug|x64
- {5857EF98-C87F-3197-A399-F0F9A20913FC}.Release|x64.ActiveCfg = Release|x64
- {5857EF98-C87F-3197-A399-F0F9A20913FC}.Release|x64.Build.0 = Release|x64
- {5857EF98-C87F-3197-A399-F0F9A20913FC}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {5857EF98-C87F-3197-A399-F0F9A20913FC}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {5857EF98-C87F-3197-A399-F0F9A20913FC}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {5857EF98-C87F-3197-A399-F0F9A20913FC}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {F6E7B14E-D9E6-343C-B58D-CA0381A3BB8F}.Debug|x64.ActiveCfg = Debug|x64
- {F6E7B14E-D9E6-343C-B58D-CA0381A3BB8F}.Debug|x64.Build.0 = Debug|x64
- {F6E7B14E-D9E6-343C-B58D-CA0381A3BB8F}.Release|x64.ActiveCfg = Release|x64
- {F6E7B14E-D9E6-343C-B58D-CA0381A3BB8F}.Release|x64.Build.0 = Release|x64
- {F6E7B14E-D9E6-343C-B58D-CA0381A3BB8F}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {F6E7B14E-D9E6-343C-B58D-CA0381A3BB8F}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {F6E7B14E-D9E6-343C-B58D-CA0381A3BB8F}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {F6E7B14E-D9E6-343C-B58D-CA0381A3BB8F}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {FBB107B9-523B-3094-95CF-A103E2388006}.Debug|x64.ActiveCfg = Debug|x64
- {FBB107B9-523B-3094-95CF-A103E2388006}.Debug|x64.Build.0 = Debug|x64
- {FBB107B9-523B-3094-95CF-A103E2388006}.Release|x64.ActiveCfg = Release|x64
- {FBB107B9-523B-3094-95CF-A103E2388006}.Release|x64.Build.0 = Release|x64
- {FBB107B9-523B-3094-95CF-A103E2388006}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {FBB107B9-523B-3094-95CF-A103E2388006}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {FBB107B9-523B-3094-95CF-A103E2388006}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {FBB107B9-523B-3094-95CF-A103E2388006}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {5B4A6D26-C638-3350-9E1A-0F987C448DEC}.Debug|x64.ActiveCfg = Debug|x64
- {5B4A6D26-C638-3350-9E1A-0F987C448DEC}.Debug|x64.Build.0 = Debug|x64
- {5B4A6D26-C638-3350-9E1A-0F987C448DEC}.Release|x64.ActiveCfg = Release|x64
- {5B4A6D26-C638-3350-9E1A-0F987C448DEC}.Release|x64.Build.0 = Release|x64
- {5B4A6D26-C638-3350-9E1A-0F987C448DEC}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {5B4A6D26-C638-3350-9E1A-0F987C448DEC}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {5B4A6D26-C638-3350-9E1A-0F987C448DEC}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {5B4A6D26-C638-3350-9E1A-0F987C448DEC}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {11AB9AE9-CF65-341B-B425-9EDFC4E2F22F}.Debug|x64.ActiveCfg = Debug|x64
- {11AB9AE9-CF65-341B-B425-9EDFC4E2F22F}.Debug|x64.Build.0 = Debug|x64
- {11AB9AE9-CF65-341B-B425-9EDFC4E2F22F}.Release|x64.ActiveCfg = Release|x64
- {11AB9AE9-CF65-341B-B425-9EDFC4E2F22F}.Release|x64.Build.0 = Release|x64
- {11AB9AE9-CF65-341B-B425-9EDFC4E2F22F}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {11AB9AE9-CF65-341B-B425-9EDFC4E2F22F}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {11AB9AE9-CF65-341B-B425-9EDFC4E2F22F}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {11AB9AE9-CF65-341B-B425-9EDFC4E2F22F}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
- {1134F859-3DE4-34B1-924F-82CA38D4D4F3}.Debug|x64.ActiveCfg = Debug|x64
- {1134F859-3DE4-34B1-924F-82CA38D4D4F3}.Debug|x64.Build.0 = Debug|x64
- {1134F859-3DE4-34B1-924F-82CA38D4D4F3}.Release|x64.ActiveCfg = Release|x64
- {1134F859-3DE4-34B1-924F-82CA38D4D4F3}.Release|x64.Build.0 = Release|x64
- {1134F859-3DE4-34B1-924F-82CA38D4D4F3}.MinSizeRel|x64.ActiveCfg = MinSizeRel|x64
- {1134F859-3DE4-34B1-924F-82CA38D4D4F3}.MinSizeRel|x64.Build.0 = MinSizeRel|x64
- {1134F859-3DE4-34B1-924F-82CA38D4D4F3}.RelWithDebInfo|x64.ActiveCfg = RelWithDebInfo|x64
- {1134F859-3DE4-34B1-924F-82CA38D4D4F3}.RelWithDebInfo|x64.Build.0 = RelWithDebInfo|x64
+ {E2F320FE-0C01-4C80-810C-3A92205A29DC}.Debug|x64.ActiveCfg = Debug|x64
+ {E2F320FE-0C01-4C80-810C-3A92205A29DC}.Debug|x64.Build.0 = Debug|x64
+ {E2F320FE-0C01-4C80-810C-3A92205A29DC}.Release|x64.ActiveCfg = Release|x64
+ {E2F320FE-0C01-4C80-810C-3A92205A29DC}.Release|x64.Build.0 = Release|x64
EndGlobalSection
- GlobalSection(ExtensibilityGlobals) = postSolution
- SolutionGuid = {A73289FB-DB51-3D6F-802E-B474CC102EDA}
+ GlobalSection(SolutionProperties) = preSolution
+ HideSolutionNode = FALSE
EndGlobalSection
- GlobalSection(ExtensibilityAddIns) = postSolution
+ GlobalSection(ExtensibilityGlobals) = postSolution
+ SolutionGuid = {3B922907-3384-4D39-9CEB-816BF7BB390D}
EndGlobalSection
EndGlobal
diff --git a/vs/Marian.vcxproj b/vs/Marian.vcxproj
index bebb987b..241aa307 100755
--- a/vs/Marian.vcxproj
+++ b/vs/Marian.vcxproj
@@ -15,40 +15,43 @@
<Keyword>Win32Proj</Keyword>
<RootNamespace>Marian</RootNamespace>
<ProjectName>Marian</ProjectName>
- <WindowsTargetPlatformVersion>8.1</WindowsTargetPlatformVersion>
+ <WindowsTargetPlatformVersion>10.0.17763.0</WindowsTargetPlatformVersion>
+ <CudaToolkitDir Condition="'$(CudaToolkitDir)' == ''">$(CUDA_PATH)</CudaToolkitDir>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>true</UseDebugLibraries>
- <PlatformToolset>v140</PlatformToolset>
+ <PlatformToolset>v141</PlatformToolset>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
<ConfigurationType>Application</ConfigurationType>
<UseDebugLibraries>false</UseDebugLibraries>
- <PlatformToolset>v140</PlatformToolset>
+ <PlatformToolset>v141</PlatformToolset>
<WholeProgramOptimization>true</WholeProgramOptimization>
<CharacterSet>Unicode</CharacterSet>
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
- <ImportGroup Label="ExtensionSettings" />
+ <ImportGroup Label="ExtensionSettings">
+ <Import Project="$(CudaToolkitDir)\extras\visual_studio_integration\MSBuildExtensions\CUDA 10.1.props" />
+ </ImportGroup>
<ImportGroup Label="PropertySheets">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
</ImportGroup>
<PropertyGroup Label="UserMacros" />
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<LinkIncremental>true</LinkIncremental>
- <IntDir>$(Platform)\$(Configuration)\Marian\</IntDir>
- <IncludePath>..\src;..\src\3rd_party;%BOOST_INCLUDE_PATH%;%ZLIB_PATH%\include;%MKL_PATH%\include;$(VC_IncludePath);$(WindowsSDK_IncludePath);</IncludePath>
- <LibraryPath>%BOOST_LIB_PATH%;%ZLIB_PATH%\lib;%MKL_PATH%\lib\intel64;$(VC_LibraryPath_x64);$(WindowsSDK_LibraryPath_x64);$(NETFXKitsDir)Lib\um\x64</LibraryPath>
+ <IntDir>$(SolutionDir)$(Platform)\$(Configuration)\Marian\</IntDir>
+ <IncludePath>$(CudaToolkitIncludeDir);..\src\3rd_party\fbgemm\third_party\googletest\googletest;..\src\3rd_party\fbgemm\third_party\googletest\googletest\include;..\src\3rd_party\fbgemm\third_party\cpuinfo\deps\clog\include;..\src\3rd_party\fbgemm\third_party\cpuinfo\src;..\src\3rd_party\fbgemm\third_party\cpuinfo\include;..\src\3rd_party\fbgemm;..\src\3rd_party\fbgemm\third_party\asmjit\src;%MKL_PATH%\include;..\src\3rd_party\fbgemm\include;..\src;..\src\3rd_party;%BOOST_INCLUDE_PATH%;%ZLIB_PATH%\include;$(VC_IncludePath);$(WindowsSDK_IncludePath)</IncludePath>
+ <LibraryPath>$(CudaToolkitLibDir);%BOOST_LIB_PATH%;%ZLIB_PATH%\lib;%MKL_PATH%\lib\intel64;$(VC_LibraryPath_x64);$(WindowsSDK_LibraryPath_x64);$(NETFXKitsDir)Lib\um\x64</LibraryPath>
</PropertyGroup>
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<LinkIncremental>false</LinkIncremental>
<ExecutablePath>$(ExecutablePath)</ExecutablePath>
- <IntDir>$(Platform)\$(Configuration)\Marian\</IntDir>
- <IncludePath>..\src;..\src\3rd_party;%BOOST_INCLUDE_PATH%;%ZLIB_PATH%\include;%MKL_PATH%\include;$(VC_IncludePath);$(WindowsSDK_IncludePath);</IncludePath>
- <LibraryPath>%BOOST_LIB_PATH%;%ZLIB_PATH%\lib;%MKL_PATH%\lib\intel64;$(VC_LibraryPath_x64);$(WindowsSDK_LibraryPath_x64);$(NETFXKitsDir)Lib\um\x64</LibraryPath>
+ <IntDir>$(SolutionDir)$(Platform)\$(Configuration)\Marian\</IntDir>
+ <IncludePath>$(CudaToolkitIncludeDir);..\src\3rd_party\fbgemm\third_party\googletest\googletest;..\src\3rd_party\fbgemm\third_party\googletest\googletest\include;..\src\3rd_party\fbgemm\third_party\cpuinfo\deps\clog\include;..\src\3rd_party\fbgemm\third_party\cpuinfo\src;..\src\3rd_party\fbgemm\third_party\cpuinfo\include;..\src\3rd_party\fbgemm;..\src\3rd_party\fbgemm\third_party\asmjit\src;%MKL_PATH%\include;..\src\3rd_party\fbgemm\include;..\src;..\src\3rd_party;%BOOST_INCLUDE_PATH%;%ZLIB_PATH%\include;$(VC_IncludePath);$(WindowsSDK_IncludePath)</IncludePath>
+ <LibraryPath>$(CudaToolkitLibDir);%BOOST_LIB_PATH%;%ZLIB_PATH%\lib;%MKL_PATH%\lib\intel64;$(VC_LibraryPath_x64);$(WindowsSDK_LibraryPath_x64);$(NETFXKitsDir)Lib\um\x64</LibraryPath>
</PropertyGroup>
<ItemDefinitionGroup>
<ClCompile>
@@ -57,6 +60,9 @@
<Link>
<AdditionalLibraryDirectories>$(OutDir);$(SolutionDir)$(Platform)\$(Configuration);$(SolutionDir)$(Platform)\$(Configuration);$(MSMPI_LIB64)</AdditionalLibraryDirectories>
</Link>
+ <CudaCompile>
+ <TargetMachinePlatform>64</TargetMachinePlatform>
+ </CudaCompile>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<ClCompile>
@@ -64,22 +70,34 @@
</PrecompiledHeader>
<WarningLevel>Level4</WarningLevel>
<Optimization>Disabled</Optimization>
- <PreprocessorDefinitions>MKL_FOUND=1; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
- <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;_DEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <SDLCheck>false</SDLCheck>
<TreatWarningAsError>true</TreatWarningAsError>
- <AdditionalOptions>/bigobj %(AdditionalOptions)</AdditionalOptions>
+ <AdditionalOptions>/bigobj /arch:AVX %(AdditionalOptions)</AdditionalOptions>
<RuntimeLibrary Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">MultiThreadedDebugDLL</RuntimeLibrary>
<DisableSpecificWarnings>4996; 4702</DisableSpecificWarnings>
<MultiProcessorCompilation>true</MultiProcessorCompilation>
<MinimalRebuild>false</MinimalRebuild>
+ <ObjectFileName>$(IntDir)%(RelativeDir)</ObjectFileName>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
- <AdditionalDependencies>zlib.lib; msmpi.lib; mkl_intel_ilp64.lib; mkl_sequential.lib; mkl_core.lib; kernel32.lib; user32.lib; gdi32.lib; winspool.lib; comdlg32.lib; advapi32.lib; shell32.lib; ole32.lib; oleaut32.lib; uuid.lib; odbc32.lib; odbccp32.lib; %(AdditionalDependencies)</AdditionalDependencies>
+ <AdditionalDependencies>cudart_static.lib;cublas.lib;cusparse.lib;curand.lib;zlib.lib;msmpi.lib;mkl_intel_ilp64.lib;mkl_sequential.lib;mkl_core.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;shlwapi.lib;%(AdditionalDependencies)</AdditionalDependencies>
<StackReserveSize>100000000</StackReserveSize>
- <TreatLinkerWarningAsErrors>true</TreatLinkerWarningAsErrors>
+ <TreatLinkerWarningAsErrors>false</TreatLinkerWarningAsErrors>
+ <AdditionalOptions>/ignore:4049 /ignore:4217 %(AdditionalOptions)</AdditionalOptions>
</Link>
+ <CudaCompile>
+ <Include>$(SolutionDir)..\src\;$(SolutionDir)..\src\3rd_party</Include>
+ <CodeGeneration>compute_50,sm_50</CodeGeneration>
+ <KeepDir>$(CudaIntDir)</KeepDir>
+ <CompileOut>$(IntDir)%(Filename)%(Extension).obj</CompileOut>
+ <Warning>W2</Warning>
+ <AdditionalCompilerOptions>
+ </AdditionalCompilerOptions>
+ <Defines>_SCL_SECURE_NO_WARNINGS</Defines>
+ </CudaCompile>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<ClCompile>
@@ -89,10 +107,10 @@
<Optimization>MaxSpeed</Optimization>
<FunctionLevelLinking>true</FunctionLevelLinking>
<IntrinsicFunctions>true</IntrinsicFunctions>
- <PreprocessorDefinitions>MKL_FOUND=1; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
- <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>USE_MKL;ASMJIT_EXPORTS;BOOST_CONFIG_SUPPRESS_OUTDATED_MESSAGE; FBGEMM_EXPORTS; USE_FBGEMM=1; USE_SSE2=1; CUDA_FOUND=1; MKL_FOUND=1; MPI_FOUND=1; BLAS_FOUND=1; MKL_ILP64; WIN32;NDEBUG;_CONSOLE;_LIB;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <SDLCheck>false</SDLCheck>
<FavorSizeOrSpeed>Speed</FavorSizeOrSpeed>
- <AdditionalOptions>/d2Zi+ /bigobj %(AdditionalOptions)</AdditionalOptions>
+ <AdditionalOptions>/d2Zi+ /bigobj /arch:AVX %(AdditionalOptions)</AdditionalOptions>
<TreatWarningAsError>true</TreatWarningAsError>
<RuntimeLibrary Condition="'$(Configuration)|$(Platform)'=='Release|x64'">MultiThreadedDLL</RuntimeLibrary>
<RuntimeLibrary Condition="'$(Configuration)|$(Platform)'=='Release_NoOpt|x64'">MultiThreaded</RuntimeLibrary>
@@ -100,19 +118,693 @@
<OmitFramePointers>true</OmitFramePointers>
<DisableSpecificWarnings>4996; 4702</DisableSpecificWarnings>
<MultiProcessorCompilation>true</MultiProcessorCompilation>
+ <ObjectFileName>$(IntDir)%(RelativeDir)</ObjectFileName>
</ClCompile>
<Link>
<SubSystem>Console</SubSystem>
<GenerateDebugInformation>true</GenerateDebugInformation>
<EnableCOMDATFolding>true</EnableCOMDATFolding>
<OptimizeReferences>true</OptimizeReferences>
- <AdditionalDependencies>zlib.lib; msmpi.lib; mkl_intel_ilp64.lib; mkl_sequential.lib; mkl_core.lib; kernel32.lib; user32.lib; gdi32.lib; winspool.lib; comdlg32.lib; advapi32.lib; shell32.lib; ole32.lib; oleaut32.lib; uuid.lib; odbc32.lib; odbccp32.lib; %(AdditionalDependencies)</AdditionalDependencies>
+ <AdditionalDependencies>cudart_static.lib;cublas.lib;cusparse.lib;curand.lib;zlib.lib;msmpi.lib;mkl_intel_ilp64.lib;mkl_sequential.lib;mkl_core.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;shlwapi.lib;%(AdditionalDependencies)</AdditionalDependencies>
<StackReserveSize>100000000</StackReserveSize>
- <TreatLinkerWarningAsErrors>true</TreatLinkerWarningAsErrors>
+ <TreatLinkerWarningAsErrors>false</TreatLinkerWarningAsErrors>
+ <AdditionalOptions>/ignore:4049 /ignore:4217 %(AdditionalOptions)</AdditionalOptions>
</Link>
+ <CudaCompile>
+ <Include>$(SolutionDir)..\src\;$(SolutionDir)..\src\3rd_party</Include>
+ <CodeGeneration>compute_50,sm_50</CodeGeneration>
+ <Warning>W2</Warning>
+ <AdditionalCompilerOptions>
+ </AdditionalCompilerOptions>
+ <Defines>_SCL_SECURE_NO_WARNINGS</Defines>
+ </CudaCompile>
</ItemDefinitionGroup>
<ItemGroup>
<ClCompile Include="..\src\3rd_party\ExceptionWithCallStack.cpp" />
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\BenchUtils.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\ConvUnifiedBenchmark.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\Depthwise3DBenchmark.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\DepthwiseBenchmark.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\FP16Benchmark.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\GEMMsBenchmark.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\GEMMsTunableBenchmark.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\GroupwiseConvRequantizeBenchmark.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\I8SpmdmBenchmark.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\Im2ColFusedRequantizeBenchmark.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\PackedFloatInOutBenchmark.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\PackedRequantizeAcc16Benchmark.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\PackedRequantizeAcc32Benchmark.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\RequantizeBenchmark.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\RowOffsetBenchmark.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\codegen_fp16fp32.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\ExecuteKernel.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\ExecuteKernelU8S8.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\Fbgemm.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmConv.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmFP16.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmFP16UKernelsAvx2.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmI8Depthwise3DAvx2.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmI8DepthwiseAvx2.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmI8Spmdm.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC16.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC16Avx512.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC16Avx512VNNI.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC32.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC32Avx512.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC32Avx512VNNI.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\GroupwiseConvAcc32Avx2.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\OptimizedKernelsAvx2.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackAMatrix.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackAWithIm2Col.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackAWithQuantRowOffset.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackAWithRowOffset.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackBMatrix.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackDepthwiseConvMatrixAvx2.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackMatrix.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackWeightMatrixForGConv.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackWeightsForConv.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\QuantUtils.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\QuantUtilsAvx2.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\RefImplementations.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\Utils.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\UtilsAvx2.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\UtilsAvx512.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\FP16Test.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\GConvTest.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\I8DepthwiseTest.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\I8SpmdmTest.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\Im2ColFusedRequantizeTest.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\PackedRequantizeAcc16Test.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\PackedRequantizeTest.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\QuantizationHelpers.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\QuantUtilsTest.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\RequantizeOnlyTest.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\TestUtils.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\UniConvTest.cc">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\arch.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\assembler.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\builder.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\callconv.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\codeholder.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\compiler.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\constpool.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\cpuinfo.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\emitter.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\func.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\globals.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\inst.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\jitallocator.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\jitruntime.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\logging.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\operand.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\osutils.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\ralocal.cpp">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\rapass.cpp">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\rastack.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\string.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\support.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\target.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\type.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\virtmem.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zone.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonehash.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonelist.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonestack.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonetree.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonevector.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86assembler.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86builder.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86callconv.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86compiler.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86features.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86instapi.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86instdb.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86internal.cpp">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86logging.cpp">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86operand.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86rapass.cpp">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\deps\clog\src\clog.c">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\api.c">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\init.c">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\cache\descriptor.c">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\cache\deterministic.c">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\cache\init.c">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\info.c">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\init.c">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\isa.c">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\name.c">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\topology.c">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\uarch.c">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\vendor.c">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\windows\init.c">
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</TreatWarningAsError>
+ <TreatWarningAsError Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</TreatWarningAsError>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\half_float\HalfPrecisionFloatTest.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\googletest\googletest\src\gtest-all.cc" />
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\googletest\googletest\src\gtest_main.cc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\pathie-cpp\src\entry_iterator.cpp" />
+ <ClCompile Include="..\src\3rd_party\pathie-cpp\src\errors.cpp" />
+ <ClCompile Include="..\src\3rd_party\pathie-cpp\src\path.cpp" />
+ <ClCompile Include="..\src\3rd_party\pathie-cpp\src\pathie.cpp" />
+ <ClCompile Include="..\src\3rd_party\pathie-cpp\src\pathie_ifstream.cpp" />
+ <ClCompile Include="..\src\3rd_party\pathie-cpp\src\pathie_ofstream.cpp" />
+ <ClCompile Include="..\src\3rd_party\pathie-cpp\src\temp.cpp" />
+ <ClCompile Include="..\src\3rd_party\phf\phf.cc">
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">TurnOffAllWarnings</WarningLevel>
+ <WarningLevel Condition="'$(Configuration)|$(Platform)'=='Release|x64'">TurnOffAllWarnings</WarningLevel>
+ </ClCompile>
<ClCompile Include="..\src\3rd_party\sentencepiece\src\bpe_model.cc">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
@@ -315,7 +1007,205 @@
</ClCompile>
<ClCompile Include="..\src\3rd_party\yaml-cpp\binary_renamed.cpp" />
<ClCompile Include="..\src\3rd_party\yaml-cpp\yaml-node.cpp" />
+ <ClInclude Include="..\src\3rd_party\any_type.h" />
+ <ClInclude Include="..\src\3rd_party\avx_mathfun.h" />
<ClInclude Include="..\src\3rd_party\ExceptionWithCallStack.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\bench\AlignedVec.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\bench\BenchUtils.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\ConvUtils.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\Fbgemm.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\FbgemmBuild.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\FbgemmFP16.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\FbgemmI8DepthwiseAvx2.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\FbgemmI8Spmdm.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\OutputProcessing-inl.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\PackingTraits-inl.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\QuantUtils.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\QuantUtilsAvx2.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\Types.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\Utils.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\UtilsAvx2.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\CodeCache.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\ExecuteKernel.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\ExecuteKernelGeneric.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\ExecuteKernelU8S8.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\FbgemmFP16UKernelsAvx2.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\FbgemmI8DepthwiseAvx2-inl.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\GenerateKernel.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\GroupwiseConv.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\OptimizedKernelsAvx2.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\RefImplementations.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\TransposeUtils.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\test\QuantizationHelpers.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\test\TestUtils.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\asmjit.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\arch.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\assembler.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\build.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\builder.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\callconv.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\codebufferwriter_p.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\codeholder.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\compiler.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\constpool.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\cpuinfo.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\datatypes.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\emitter.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\features.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\func.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\globals.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\inst.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\jitallocator.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\jitruntime.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\logging.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\misc_p.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\operand.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\osutils.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\raassignment_p.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\rabuilders_p.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\radefs_p.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\ralocal_p.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\rapass_p.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\rastack_p.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\string.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\support.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\target.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\type.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\virtmem.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zone.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonehash.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonelist.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonestack.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonestring.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonetree.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonevector.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86assembler.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86builder.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86callconv_p.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86compiler.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86emitter.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86features.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86globals.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86instapi_p.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86instdb.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86instdb_p.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86internal_p.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86logging_p.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86opcode_p.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86operand.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86rapass_p.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\deps\clog\include\clog.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\include\cpuinfo-mock.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\include\cpuinfo.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo\common.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo\internal-api.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo\log.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo\utils.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\api.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\cpuid.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\windows\api.h" />
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\googletest\googletest\src\gtest-internal-inl.h" />
+ <ClInclude Include="..\src\3rd_party\half_float\umHalf.h" />
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\collectives.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\all_gather.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\all_reduce.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\broadcast.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\common.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\common_kernel.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\ll_kernel.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\primitives.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\reduce.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\reduce_kernel.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\reduce_scatter.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\bootstrap.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\common_coll.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\core.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\debug.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\enqueue.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\group.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\ibvwrap.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\nccl_net.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\net.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\nvlink.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\nvmlwrap.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\param.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\ring.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\rings.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\shm.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\socket.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\topo.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\transport.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\utils.h">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\pathie-cpp\include\entry_iterator.hpp" />
+ <ClInclude Include="..\src\3rd_party\pathie-cpp\include\errors.hpp" />
+ <ClInclude Include="..\src\3rd_party\pathie-cpp\include\path.hpp" />
+ <ClInclude Include="..\src\3rd_party\pathie-cpp\include\pathie.hpp" />
+ <ClInclude Include="..\src\3rd_party\pathie-cpp\include\pathie_ifstream.hpp" />
+ <ClInclude Include="..\src\3rd_party\pathie-cpp\include\pathie_ofstream.hpp" />
+ <ClInclude Include="..\src\3rd_party\pathie-cpp\include\temp.hpp" />
+ <ClInclude Include="..\src\3rd_party\phf\phf.h" />
<ClInclude Include="..\src\3rd_party\sentencepiece\src\bpe_model.h">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
@@ -416,6 +1306,9 @@
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</ClInclude>
+ <ClInclude Include="..\src\3rd_party\sse_mathfun.h" />
+ <ClInclude Include="..\src\3rd_party\zstr\strict_fstream.hpp" />
+ <ClInclude Include="..\src\3rd_party\zstr\zstr.hpp" />
<ClInclude Include="..\src\command\marian_decoder.cpp">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
@@ -440,11 +1333,17 @@
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</ClInclude>
+ <ClCompile Include="..\src\common\aliases.cpp" />
<ClCompile Include="..\src\common\binary.cpp" />
<ClCompile Include="..\src\common\cli_helper.cpp" />
<ClCompile Include="..\src\common\cli_wrapper.cpp" />
<ClCompile Include="..\src\common\config_validator.cpp" />
+ <ClCompile Include="..\src\common\fastopt.cpp" />
+ <ClCompile Include="..\src\common\filesystem.cpp" />
+ <ClCompile Include="..\src\common\file_stream.cpp" />
<ClCompile Include="..\src\common\io.cpp" />
+ <ClCompile Include="..\src\common\options.cpp" />
+ <ClCompile Include="..\src\common\types.cpp" />
<ClCompile Include="..\src\common\utils.cpp" />
<ClCompile Include="..\src\common\logging.cpp" />
<ClCompile Include="..\src\common\config.cpp" />
@@ -452,6 +1351,7 @@
<ClCompile Include="..\src\common\version.cpp" />
<ClCompile Include="..\src\data\alignment.cpp" />
<ClCompile Include="..\src\data\default_vocab.cpp" />
+ <ClCompile Include="..\src\data\factored_vocab.cpp" />
<ClCompile Include="..\src\data\sentencepiece_vocab.cpp" />
<ClCompile Include="..\src\data\vocab.cpp" />
<ClCompile Include="..\src\data\corpus_base.cpp" />
@@ -459,15 +1359,29 @@
<ClCompile Include="..\src\data\corpus_nbest.cpp" />
<ClCompile Include="..\src\data\text_input.cpp" />
<ClCompile Include="..\src\3rd_party\cnpy\cnpy.cpp" />
- <ClCompile Include="..\src\3rd_party\svd\svd.cpp" />
+ <ClCompile Include="..\src\examples\iris\helper.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="..\src\examples\iris\iris.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="..\src\examples\mnist\mnist_ffnn.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="..\src\layers\generic.cpp" />
<ClCompile Include="..\src\layers\loss.cpp" />
<ClCompile Include="..\src\layers\weight.cpp" />
<ClCompile Include="..\src\microsoft\quicksand.cpp">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
</ClCompile>
+ <ClCompile Include="..\src\models\transformer_stub.cpp" />
<ClCompile Include="..\src\rescorer\score_collector.cpp" />
<ClCompile Include="..\src\tensors\backend.cpp" />
<ClCompile Include="..\src\tensors\cpu\device.cpp" />
+ <ClCompile Include="..\src\tensors\cpu\fbgemm\packed_gemm.cpp" />
<ClCompile Include="..\src\tensors\cpu\prod.cpp" />
<ClCompile Include="..\src\tensors\cpu\sharp\avx_gemm.cpp" />
<ClCompile Include="..\src\tensors\cpu\sharp\int_gemm.cpp" />
@@ -485,8 +1399,54 @@
<ClCompile Include="..\src\models\model_factory.cpp" />
<ClCompile Include="..\src\models\encoder_decoder.cpp" />
<ClCompile Include="..\src\tensors\rand.cpp" />
+ <ClCompile Include="..\src\tensors\tensor.cpp" />
+ <ClCompile Include="..\src\tests\attention_tests.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\cli_test.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\dropout_test.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\graph_tests.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\logger_test.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\operator_tests.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\pooling_test.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\prod.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\rnn_tests.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\run_tests.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\sqlite_test.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </ClCompile>
<ClCompile Include="..\src\training\communicator.cpp" />
<ClCompile Include="..\src\training\graph_group_multinode_sync.cpp" />
+ <ClCompile Include="..\src\training\scheduler.cpp" />
<ClCompile Include="..\src\translator\history.cpp" />
<ClCompile Include="..\src\translator\output_collector.cpp" />
<ClCompile Include="..\src\translator\nth_element.cpp" />
@@ -533,6 +1493,7 @@
<ClInclude Include="..\src\common\cli_helper.h" />
<ClInclude Include="..\src\common\cli_wrapper.h" />
<ClInclude Include="..\src\common\config_validator.h" />
+ <ClInclude Include="..\src\common\fastopt.h" />
<ClInclude Include="..\src\common\filesystem.h" />
<ClInclude Include="..\src\common\hash.h" />
<ClInclude Include="..\src\common\io.h" />
@@ -540,11 +1501,19 @@
<ClInclude Include="..\src\common\timer.h" />
<ClInclude Include="..\src\common\types.h" />
<ClInclude Include="..\src\common\version.h" />
+ <ClInclude Include="..\src\data\factored_vocab.h" />
+ <ClInclude Include="..\src\data\vocab_base.h" />
+ <ClInclude Include="..\src\examples\mnist\dataset.h" />
+ <ClInclude Include="..\src\examples\mnist\model.h" />
+ <ClInclude Include="..\src\examples\mnist\model_lenet.h" />
+ <ClInclude Include="..\src\examples\mnist\training.h" />
+ <ClInclude Include="..\src\examples\mnist\validator.h" />
+ <ClInclude Include="..\src\functional\approx.h" />
+ <ClInclude Include="..\src\functional\operators.h" />
<ClInclude Include="..\src\layers\loss.h" />
<ClInclude Include="..\src\layers\weight.h" />
<ClInclude Include="..\src\marian.h" />
<ClInclude Include="..\src\3rd_party\catch.hpp" />
- <ClInclude Include="..\src\3rd_party\exception.h" />
<ClInclude Include="..\src\3rd_party\reduce_all.h" />
<ClInclude Include="..\src\3rd_party\threadpool.h" />
<ClInclude Include="..\src\3rd_party\cnpy\cnpy.h" />
@@ -620,8 +1589,6 @@
<ClInclude Include="..\src\3rd_party\spdlog\tests\catch.hpp" />
<ClInclude Include="..\src\3rd_party\spdlog\tests\includes.h" />
<ClInclude Include="..\src\3rd_party\spdlog\tests\utils.h" />
- <ClInclude Include="..\src\3rd_party\svd\defs_and_types.h" />
- <ClInclude Include="..\src\3rd_party\svd\svd.h" />
<ClInclude Include="..\src\3rd_party\yaml-cpp\anchor.h" />
<ClInclude Include="..\src\3rd_party\yaml-cpp\binary.h" />
<ClInclude Include="..\src\3rd_party\yaml-cpp\collectionstack.h" />
@@ -733,19 +1700,21 @@
<ClInclude Include="..\src\layers\word2vec_reader.h" />
<ClInclude Include="..\src\microsoft\quicksand.h" />
<ClInclude Include="..\src\models\amun.h" />
+ <ClInclude Include="..\src\models\bert.h" />
<ClInclude Include="..\src\models\char_s2s.h" />
+ <ClInclude Include="..\src\models\classifier.h" />
<ClInclude Include="..\src\models\costs.h" />
<ClInclude Include="..\src\models\decoder.h" />
<ClInclude Include="..\src\models\encoder.h" />
+ <ClInclude Include="..\src\models\encoder_classifier.h" />
<ClInclude Include="..\src\models\encoder_decoder.h" />
- <ClInclude Include="..\src\models\hardatt.h" />
<ClInclude Include="..\src\models\model_base.h" />
<ClInclude Include="..\src\models\model_factory.h" />
<ClInclude Include="..\src\models\model_task.h" />
<ClInclude Include="..\src\models\nematus.h" />
<ClInclude Include="..\src\models\s2s.h" />
<ClInclude Include="..\src\models\states.h" />
- <ClCompile Include="..\src\models\transformer.h" />
+ <ClInclude Include="..\src\models\transformer.h" />
<ClInclude Include="..\src\models\experimental\lex_probs.h" />
<ClInclude Include="..\src\models\transformer_factory.h" />
<ClInclude Include="..\src\optimizers\clippers.h" />
@@ -760,6 +1729,9 @@
<ClInclude Include="..\src\rnn\types.h" />
<ClInclude Include="..\src\tensors\allocator.h" />
<ClInclude Include="..\src\tensors\backend.h" />
+ <ClInclude Include="..\src\tensors\cpu\fbgemm\expanded_gemm.h" />
+ <ClInclude Include="..\src\tensors\cpu\fbgemm\expression_graph_packable.h" />
+ <ClInclude Include="..\src\tensors\cpu\fbgemm\packed_gemm.h" />
<ClInclude Include="..\src\tensors\cpu\sharp\int_gemm.h" />
<ClInclude Include="..\src\tensors\device.h" />
<ClInclude Include="..\src\tensors\dispatch.h" />
@@ -776,7 +1748,6 @@
<ClInclude Include="..\src\tensors\tensor.h" />
<ClInclude Include="..\src\tensors\tensor_allocator.h" />
<ClInclude Include="..\src\tensors\tensor_operators.h" />
- <ClInclude Include="..\src\tensors\types.h" />
<ClInclude Include="..\src\tensors\cpu\add.h" />
<ClInclude Include="..\src\tensors\cpu\backend.h" />
<ClInclude Include="..\src\tensors\cpu\element.h" />
@@ -806,52 +1777,174 @@
<ClInclude Include="..\src\translator\printer.h" />
<ClInclude Include="..\src\translator\scorers.h" />
<ClInclude Include="..\src\translator\translator.h" />
+ <ClInclude Include="..\src\training\communicator_nccl.h" />
</ItemGroup>
<ItemGroup>
- <None Include="..\src\3rd_party\sentencepiece\src\sentencepiece.proto">
- <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <None Include="..\src\3rd_party\half_float\Readme.md" />
+ <None Include="..\src\3rd_party\half_float\umHalf.inl" />
+ <None Include="..\src\3rd_party\nccl\src\bootstrap.cu">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</None>
- <None Include="..\src\3rd_party\sentencepiece\src\sentencepiece_model.proto">
- <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <None Include="..\src\3rd_party\nccl\src\collectives\all_gather.cu">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</None>
- <None Include="..\src\tensors\gpu\add.cu">
- <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <None Include="..\src\3rd_party\nccl\src\collectives\all_reduce.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</None>
- <None Include="..\src\tensors\gpu\add.inc" />
- <None Include="..\src\tensors\gpu\algorithm.cu">
- <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <None Include="..\src\3rd_party\nccl\src\collectives\broadcast.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</None>
- <None Include="..\src\tensors\gpu\cudnn_wrappers.cu">
- <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <None Include="..\src\3rd_party\nccl\src\collectives\device\all_gather.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</None>
- <None Include="..\src\tensors\gpu\device.cu">
- <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <None Include="..\src\3rd_party\nccl\src\collectives\device\all_reduce.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</None>
- <None Include="..\src\tensors\gpu\element.cu">
- <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <None Include="..\src\3rd_party\nccl\src\collectives\device\broadcast.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\collectives\device\functions.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</None>
- <None Include="..\src\tensors\gpu\element.inc" />
- <None Include="..\src\tensors\gpu\prod.cu">
+ <None Include="..\src\3rd_party\nccl\src\collectives\device\Makefile">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\collectives\device\reduce.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\collectives\device\reduce_scatter.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\collectives\reduce.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\collectives\reduce_scatter.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\init.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\Makefile">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\misc\enqueue.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\misc\group.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\misc\ibvwrap.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\misc\nvmlwrap.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\misc\rings.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\misc\utils.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\nccl.h.in">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\ring.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\transport.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\transport\net.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\transport\net_ib.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\transport\net_socket.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\transport\p2p.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\transport\shm.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </None>
+ </ItemGroup>
+ <ItemGroup>
+ <CudaCompile Include="..\src\tensors\gpu\add.cu" />
+ <ClInclude Include="..\src\tensors\gpu\add.inc">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <FileType>Document</FileType>
+ </ClInclude>
+ <CudaCompile Include="..\src\tensors\gpu\algorithm.cu" />
+ <CudaCompile Include="..\src\tensors\gpu\cudnn_wrappers.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</ExcludedFromBuild>
+ </CudaCompile>
+ <CudaCompile Include="..\src\tensors\gpu\device.cu" />
+ <CudaCompile Include="..\src\tensors\gpu\element.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</ExcludedFromBuild>
+ </CudaCompile>
+ <ClInclude Include="..\src\tensors\gpu\element.inc">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ <FileType>Document</FileType>
+ </ClInclude>
+ <ClCompile Include="..\src\tensors\gpu\prod.cpp">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</ExcludedFromBuild>
+ <ObjectFileName Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">$(IntDir)%(RelativeDir)</ObjectFileName>
+ <ObjectFileName Condition="'$(Configuration)|$(Platform)'=='Release|x64'">$(IntDir)%(RelativeDir)</ObjectFileName>
+ </ClCompile>
+ <CudaCompile Include="..\src\tensors\gpu\sparse.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
+ </CudaCompile>
+ <CudaCompile Include="..\src\tensors\gpu\tensor_operators.cu">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">false</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">false</ExcludedFromBuild>
+ </CudaCompile>
+ <CudaCompile Include="..\src\training\gradient_dropping\gpu\dropper.cu" />
+ <CudaCompile Include="..\src\training\gradient_dropping\gpu\sparse_algorithm.cu" />
+ <CudaCompile Include="..\src\translator\helpers.cu" />
+ <CudaCompile Include="..\src\translator\nth_element.cu" />
+ </ItemGroup>
+ <ItemGroup>
+ <None Include="..\src\tests\README.md" />
+ <None Include="..\src\3rd_party\pathie-cpp\CHANGELOG" />
+ <None Include="..\src\3rd_party\pathie-cpp\LICENSE" />
+ <None Include="..\src\3rd_party\pathie-cpp\README.md" />
+ <None Include="..\src\3rd_party\sentencepiece\src\sentencepiece.proto">
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</None>
- <None Include="..\src\tensors\gpu\sparse.cu">
+ <None Include="..\src\3rd_party\sentencepiece\src\sentencepiece_model.proto">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</None>
- <None Include="..\src\tensors\gpu\tensor_operators.cu">
+ <None Include="..\src\examples\cmake_install.cmake" />
+ <None Include="..\src\examples\iris\iris.data" />
+ <None Include="..\src\examples\Makefile" />
+ <None Include="..\src\examples\mnist\download.sh" />
+ <None Include="..\src\examples\README.md">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
+ <FileType>Document</FileType>
+ <ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</None>
- <None Include="..\src\training\communicator_nccl.h" />
- <None Include=".editorConfig" />
</ItemGroup>
<ItemGroup>
+ <Text Include="..\src\3rd_party\pathie-cpp\CMakeLists.txt" />
<Text Include="..\src\3rd_party\sentencepiece\src\CMakeLists.txt">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</Text>
+ <Text Include="..\src\examples\CMakeLists.txt" />
+ <Text Include="..\src\tests\CMakeLists.txt" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
+ <Import Project="$(CudaToolkitDir)\extras\visual_studio_integration\MSBuildExtensions\CUDA 10.1.targets" />
</ImportGroup>
</Project> \ No newline at end of file
diff --git a/vs/Marian.vcxproj.filters b/vs/Marian.vcxproj.filters
index 12aa5b4e..a4cbc827 100755
--- a/vs/Marian.vcxproj.filters
+++ b/vs/Marian.vcxproj.filters
@@ -31,9 +31,6 @@
<ClCompile Include="..\src\3rd_party\cnpy\cnpy.cpp">
<Filter>3rd_party\cnpy</Filter>
</ClCompile>
- <ClCompile Include="..\src\3rd_party\svd\svd.cpp">
- <Filter>3rd_party\svd</Filter>
- </ClCompile>
<ClCompile Include="..\src\tensors\backend.cpp">
<Filter>tensors</Filter>
</ClCompile>
@@ -205,9 +202,6 @@
<ClCompile Include="..\src\tensors\cpu\sharp\sse_gemm.cpp">
<Filter>tensors\cpu\sharp</Filter>
</ClCompile>
- <ClCompile Include="..\src\models\transformer.h">
- <Filter>models</Filter>
- </ClCompile>
<ClCompile Include="..\src\common\io.cpp">
<Filter>common</Filter>
</ClCompile>
@@ -421,8 +415,473 @@
<ClCompile Include="..\src\rescorer\score_collector.cpp">
<Filter>rescorer</Filter>
</ClCompile>
- <ClCompile Include="..\src\command\marian_train.cpp">
- <Filter>command</Filter>
+ <ClCompile Include="..\src\3rd_party\pathie-cpp\src\entry_iterator.cpp">
+ <Filter>3rd_party\pathie-cpp\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\pathie-cpp\src\errors.cpp">
+ <Filter>3rd_party\pathie-cpp\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\pathie-cpp\src\path.cpp">
+ <Filter>3rd_party\pathie-cpp\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\pathie-cpp\src\pathie.cpp">
+ <Filter>3rd_party\pathie-cpp\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\pathie-cpp\src\pathie_ifstream.cpp">
+ <Filter>3rd_party\pathie-cpp\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\pathie-cpp\src\pathie_ofstream.cpp">
+ <Filter>3rd_party\pathie-cpp\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\pathie-cpp\src\temp.cpp">
+ <Filter>3rd_party\pathie-cpp\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\attention_tests.cpp">
+ <Filter>tests</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\cli_test.cpp">
+ <Filter>tests</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\dropout_test.cpp">
+ <Filter>tests</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\graph_tests.cpp">
+ <Filter>tests</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\logger_test.cpp">
+ <Filter>tests</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\operator_tests.cpp">
+ <Filter>tests</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\pooling_test.cpp">
+ <Filter>tests</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\prod.cpp">
+ <Filter>tests</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\rnn_tests.cpp">
+ <Filter>tests</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\run_tests.cpp">
+ <Filter>tests</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\tests\sqlite_test.cpp">
+ <Filter>tests</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\examples\mnist\mnist_ffnn.cpp">
+ <Filter>examples\mnist</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\examples\iris\helper.cpp">
+ <Filter>examples\iris</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\examples\iris\iris.cpp">
+ <Filter>examples\iris</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\models\transformer_stub.cpp">
+ <Filter>models</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\layers\generic.cpp">
+ <Filter>layers</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\data\factored_vocab.cpp">
+ <Filter>data</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\tensors\gpu\prod.cpp">
+ <Filter>tensors\gpu</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\ExecuteKernel.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\ExecuteKernelU8S8.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\Fbgemm.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmConv.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmFP16.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmFP16UKernelsAvx2.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmI8DepthwiseAvx2.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmI8Spmdm.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC16.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC16Avx512.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC32.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC32Avx512.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\GroupwiseConvAcc32Avx2.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\OptimizedKernelsAvx2.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackAMatrix.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackAWithIm2Col.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackAWithQuantRowOffset.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackAWithRowOffset.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackBMatrix.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackMatrix.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackWeightMatrixForGConv.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackWeightsForConv.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\QuantUtils.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\QuantUtilsAvx2.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\RefImplementations.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\Utils.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\UtilsAvx2.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\UtilsAvx512.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\api.c">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\init.c">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\info.c">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\init.c">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\isa.c">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\name.c">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\topology.c">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\uarch.c">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\vendor.c">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\windows\init.c">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86\windows</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\cache\descriptor.c">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86\cacehe</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\cache\deterministic.c">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86\cacehe</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\cache\init.c">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86\cacehe</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\deps\clog\src\clog.c">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\deps\clog\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\common\aliases.cpp">
+ <Filter>common</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\common\filesystem.cpp">
+ <Filter>common</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\arch.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\assembler.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\builder.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\callconv.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\codeholder.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\compiler.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\constpool.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\cpuinfo.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\emitter.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\func.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\globals.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\inst.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\jitallocator.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\jitruntime.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\logging.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\operand.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\osutils.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\ralocal.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\rapass.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\rastack.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\string.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\support.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\target.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\type.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\virtmem.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zone.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonehash.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonelist.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonestack.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonetree.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonevector.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86assembler.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86builder.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86callconv.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86compiler.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86features.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86instapi.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86instdb.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86internal.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86logging.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86operand.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86rapass.cpp">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\codegen_fp16fp32.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\FbgemmI8Depthwise3DAvx2.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC16Avx512VNNI.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\GenerateKernelU8S8S32ACC32Avx512VNNI.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\src\PackDepthwiseConvMatrixAvx2.cc">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\FP16Test.cc">
+ <Filter>3rd_party\fbgemm\test</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\GConvTest.cc">
+ <Filter>3rd_party\fbgemm\test</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\I8DepthwiseTest.cc">
+ <Filter>3rd_party\fbgemm\test</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\I8SpmdmTest.cc">
+ <Filter>3rd_party\fbgemm\test</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\Im2ColFusedRequantizeTest.cc">
+ <Filter>3rd_party\fbgemm\test</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\PackedRequantizeAcc16Test.cc">
+ <Filter>3rd_party\fbgemm\test</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\PackedRequantizeTest.cc">
+ <Filter>3rd_party\fbgemm\test</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\QuantizationHelpers.cc">
+ <Filter>3rd_party\fbgemm\test</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\QuantUtilsTest.cc">
+ <Filter>3rd_party\fbgemm\test</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\RequantizeOnlyTest.cc">
+ <Filter>3rd_party\fbgemm\test</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\TestUtils.cc">
+ <Filter>3rd_party\fbgemm\test</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\test\UniConvTest.cc">
+ <Filter>3rd_party\fbgemm\test</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\BenchUtils.cc">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\ConvUnifiedBenchmark.cc">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\Depthwise3DBenchmark.cc">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\DepthwiseBenchmark.cc">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\FP16Benchmark.cc">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\GEMMsBenchmark.cc">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\GEMMsTunableBenchmark.cc">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\GroupwiseConvRequantizeBenchmark.cc">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\I8SpmdmBenchmark.cc">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\Im2ColFusedRequantizeBenchmark.cc">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\PackedFloatInOutBenchmark.cc">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\PackedRequantizeAcc16Benchmark.cc">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\PackedRequantizeAcc32Benchmark.cc">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\RequantizeBenchmark.cc">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\bench\RowOffsetBenchmark.cc">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\googletest\googletest\src\gtest_main.cc">
+ <Filter>3rd_party\fbgemm\third_party\googletest\googletest\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\fbgemm\third_party\googletest\googletest\src\gtest-all.cc">
+ <Filter>3rd_party\fbgemm\third_party\googletest\googletest\src</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\common\aliases.cpp">
+ <Filter>common</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\half_float\HalfPrecisionFloatTest.cpp">
+ <Filter>3rd_party\half_float</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\tensors\tensor.cpp">
+ <Filter>tensors</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\common\filesystem.cpp">
+ <Filter>common</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\common\file_stream.cpp">
+ <Filter>common</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\training\scheduler.cpp">
+ <Filter>training</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\common\options.cpp">
+ <Filter>common</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\common\types.cpp">
+ <Filter>common</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\common\fastopt.cpp">
+ <Filter>common</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\3rd_party\phf\phf.cc">
+ <Filter>3rd_party\phf</Filter>
+ </ClCompile>
+ <ClCompile Include="..\src\tensors\cpu\fbgemm\packed_gemm.cpp">
+ <Filter>tensors\cpu\fbgemm</Filter>
</ClCompile>
</ItemGroup>
<ItemGroup>
@@ -430,9 +889,6 @@
<ClInclude Include="..\src\3rd_party\catch.hpp">
<Filter>3rd_party</Filter>
</ClInclude>
- <ClInclude Include="..\src\3rd_party\exception.h">
- <Filter>3rd_party</Filter>
- </ClInclude>
<ClInclude Include="..\src\3rd_party\reduce_all.h">
<Filter>3rd_party</Filter>
</ClInclude>
@@ -640,12 +1096,6 @@
<ClInclude Include="..\src\3rd_party\spdlog\tests\utils.h">
<Filter>3rd_party\spdlog\tests</Filter>
</ClInclude>
- <ClInclude Include="..\src\3rd_party\svd\defs_and_types.h">
- <Filter>3rd_party\svd</Filter>
- </ClInclude>
- <ClInclude Include="..\src\3rd_party\svd\svd.h">
- <Filter>3rd_party\svd</Filter>
- </ClInclude>
<ClInclude Include="..\src\3rd_party\yaml-cpp\anchor.h">
<Filter>3rd_party\yaml-cpp</Filter>
</ClInclude>
@@ -988,9 +1438,6 @@
<ClInclude Include="..\src\models\encoder_decoder.h">
<Filter>models</Filter>
</ClInclude>
- <ClInclude Include="..\src\models\hardatt.h">
- <Filter>models</Filter>
- </ClInclude>
<ClInclude Include="..\src\models\model_base.h">
<Filter>models</Filter>
</ClInclude>
@@ -1066,9 +1513,6 @@
<ClInclude Include="..\src\tensors\tensor_operators.h">
<Filter>tensors</Filter>
</ClInclude>
- <ClInclude Include="..\src\tensors\types.h">
- <Filter>tensors</Filter>
- </ClInclude>
<ClInclude Include="..\src\tensors\cpu\add.h">
<Filter>tensors\cpu</Filter>
</ClInclude>
@@ -1345,6 +1789,495 @@
<ClInclude Include="..\src\3rd_party\sentencepiece\src\word_model_trainer.h">
<Filter>3rd_party\sentencepiece\src</Filter>
</ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\collectives.h">
+ <Filter>3rd_party\nccl\src\collectives</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\all_gather.h">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\all_reduce.h">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\broadcast.h">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\common.h">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\common_kernel.h">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\ll_kernel.h">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\primitives.h">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\reduce.h">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\reduce_kernel.h">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\collectives\device\reduce_scatter.h">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\bootstrap.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\common_coll.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\core.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\debug.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\enqueue.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\group.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\ibvwrap.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\nccl_net.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\net.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\nvlink.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\nvmlwrap.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\param.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\ring.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\rings.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\shm.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\socket.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\topo.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\transport.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\nccl\src\include\utils.h">
+ <Filter>3rd_party\nccl\src\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\pathie-cpp\include\entry_iterator.hpp">
+ <Filter>3rd_party\pathie-cpp\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\pathie-cpp\include\errors.hpp">
+ <Filter>3rd_party\pathie-cpp\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\pathie-cpp\include\path.hpp">
+ <Filter>3rd_party\pathie-cpp\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\pathie-cpp\include\pathie.hpp">
+ <Filter>3rd_party\pathie-cpp\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\pathie-cpp\include\pathie_ifstream.hpp">
+ <Filter>3rd_party\pathie-cpp\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\pathie-cpp\include\pathie_ofstream.hpp">
+ <Filter>3rd_party\pathie-cpp\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\pathie-cpp\include\temp.hpp">
+ <Filter>3rd_party\pathie-cpp\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\examples\mnist\dataset.h">
+ <Filter>examples\mnist</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\examples\mnist\model.h">
+ <Filter>examples\mnist</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\examples\mnist\model_lenet.h">
+ <Filter>examples\mnist</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\examples\mnist\training.h">
+ <Filter>examples\mnist</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\examples\mnist\validator.h">
+ <Filter>examples\mnist</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\command\marian_train.cpp">
+ <Filter>command</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\models\transformer.h">
+ <Filter>models</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\models\bert.h">
+ <Filter>models</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\data\vocab_base.h">
+ <Filter>data</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\data\factored_vocab.h">
+ <Filter>data</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\models\classifier.h">
+ <Filter>models</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\models\encoder_classifier.h">
+ <Filter>models</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\models\transformer.h">
+ <Filter>models</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\training\communicator_nccl.h">
+ <Filter>training</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\tensors\gpu\element.inc">
+ <Filter>tensors\gpu</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\tensors\gpu\add.inc">
+ <Filter>tensors\gpu</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\ConvUtils.h">
+ <Filter>3rd_party\fbgemm\include\fbgemm</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\Fbgemm.h">
+ <Filter>3rd_party\fbgemm\include\fbgemm</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\FbgemmBuild.h">
+ <Filter>3rd_party\fbgemm\include\fbgemm</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\FbgemmFP16.h">
+ <Filter>3rd_party\fbgemm\include\fbgemm</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\FbgemmI8DepthwiseAvx2.h">
+ <Filter>3rd_party\fbgemm\include\fbgemm</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\FbgemmI8Spmdm.h">
+ <Filter>3rd_party\fbgemm\include\fbgemm</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\OutputProcessing-inl.h">
+ <Filter>3rd_party\fbgemm\include\fbgemm</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\PackingTraits-inl.h">
+ <Filter>3rd_party\fbgemm\include\fbgemm</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\QuantUtils.h">
+ <Filter>3rd_party\fbgemm\include\fbgemm</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\QuantUtilsAvx2.h">
+ <Filter>3rd_party\fbgemm\include\fbgemm</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\Types.h">
+ <Filter>3rd_party\fbgemm\include\fbgemm</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\Utils.h">
+ <Filter>3rd_party\fbgemm\include\fbgemm</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\include\fbgemm\UtilsAvx2.h">
+ <Filter>3rd_party\fbgemm\include\fbgemm</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\ExecuteKernel.h">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\ExecuteKernelGeneric.h">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\ExecuteKernelU8S8.h">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\FbgemmFP16UKernelsAvx2.h">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\GenerateKernel.h">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\GroupwiseConv.h">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\OptimizedKernelsAvx2.h">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\RefImplementations.h">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\TransposeUtils.h">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\include\cpuinfo.h">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\include\cpuinfo-mock.h">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\api.h">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\cpuid.h">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\x86\windows\api.h">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\x86\windows</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo\common.h">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo\internal-api.h">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo\log.h">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo\utils.h">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\asmjit.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\cpuinfo\deps\clog\include\clog.h">
+ <Filter>3rd_party\fbgemm\third_party\cpuinfo\deps\clog\include</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\arch.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\assembler.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\build.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\builder.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\callconv.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\codebufferwriter_p.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\codeholder.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\compiler.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\constpool.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\cpuinfo.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\datatypes.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\emitter.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\features.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\func.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\globals.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\inst.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\jitallocator.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\jitruntime.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\logging.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\misc_p.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\operand.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\osutils.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\raassignment_p.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\rabuilders_p.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\radefs_p.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\ralocal_p.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\rapass_p.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\rastack_p.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\string.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\support.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\target.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\type.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\virtmem.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zone.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonehash.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonelist.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonestack.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonestring.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonetree.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core\zonevector.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\core</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86assembler.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86builder.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86callconv_p.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86compiler.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86emitter.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86features.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86globals.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86instapi_p.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86instdb.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86instdb_p.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86internal_p.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86logging_p.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86opcode_p.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86operand.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86\x86rapass_p.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\CodeCache.h">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\src\FbgemmI8DepthwiseAvx2-inl.h">
+ <Filter>3rd_party\fbgemm\src</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\test\QuantizationHelpers.h">
+ <Filter>3rd_party\fbgemm\test</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\test\TestUtils.h">
+ <Filter>3rd_party\fbgemm\test</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\bench\AlignedVec.h">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\bench\BenchUtils.h">
+ <Filter>3rd_party\fbgemm\bench</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\googletest\googletest\src\gtest-internal-inl.h">
+ <Filter>3rd_party\fbgemm\third_party\googletest\googletest\src</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\half_float\umHalf.h">
+ <Filter>3rd_party\half_float</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\any_type.h">
+ <Filter>3rd_party</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\avx_mathfun.h">
+ <Filter>3rd_party</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\sse_mathfun.h">
+ <Filter>3rd_party</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\functional\approx.h">
+ <Filter>functional</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\functional\operators.h">
+ <Filter>functional</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\zstr\strict_fstream.hpp">
+ <Filter>3rd_party</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\zstr\zstr.hpp">
+ <Filter>3rd_party</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\common\fastopt.h">
+ <Filter>common</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\phf\phf.h">
+ <Filter>3rd_party\phf</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\3rd_party\fbgemm\third_party\asmjit\src\asmjit\core.h">
+ <Filter>3rd_party\fbgemm\third_party\asmjit\src\asmjit</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\tensors\cpu\fbgemm\expanded_gemm.h">
+ <Filter>tensors\cpu\fbgemm</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\tensors\cpu\fbgemm\expression_graph_packable.h">
+ <Filter>tensors\cpu\fbgemm</Filter>
+ </ClInclude>
+ <ClInclude Include="..\src\tensors\cpu\fbgemm\packed_gemm.h">
+ <Filter>tensors\cpu\fbgemm</Filter>
+ </ClInclude>
</ItemGroup>
<ItemGroup>
<Filter Include="3rd_party">
@@ -1386,9 +2319,6 @@
<Filter Include="3rd_party\spdlog\tests">
<UniqueIdentifier>{880c8f51-3306-4d80-a682-7242341b0041}</UniqueIdentifier>
</Filter>
- <Filter Include="3rd_party\svd">
- <UniqueIdentifier>{880c8f51-3306-4d80-a682-7242341b0044}</UniqueIdentifier>
- </Filter>
<Filter Include="3rd_party\yaml-cpp">
<UniqueIdentifier>{880c8f51-3306-4d80-a682-7242341b0047}</UniqueIdentifier>
</Filter>
@@ -1482,52 +2412,302 @@
<Filter Include="3rd_party\sentencepiece\src">
<UniqueIdentifier>{638bf0e1-4f83-4b37-9077-2be549d75909}</UniqueIdentifier>
</Filter>
+ <Filter Include="3rd_party\nccl">
+ <UniqueIdentifier>{0ba105eb-79fb-4e2a-8940-f1ecebbcd4fe}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\nccl\src">
+ <UniqueIdentifier>{fbc17f5e-3f10-44a9-b3ad-66ce12573174}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\nccl\src\collectives">
+ <UniqueIdentifier>{c6036c35-5848-4fd5-b1a0-59e2042cbb69}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\nccl\src\misc">
+ <UniqueIdentifier>{7b9a131d-9e0a-4c28-8a51-08232ff2e35e}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\nccl\src\transport">
+ <UniqueIdentifier>{0bd9cca8-660b-46f6-aac6-691fb50245f0}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\nccl\src\include">
+ <UniqueIdentifier>{2beba56f-5dda-4994-bef0-16170b6552b4}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\nccl\src\collectives\device">
+ <UniqueIdentifier>{ac585624-4e66-42cd-8e4e-62cb90029610}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\pathie-cpp">
+ <UniqueIdentifier>{825beb7c-2997-408b-af81-34ab5f14593a}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\pathie-cpp\include">
+ <UniqueIdentifier>{db1dd5a2-f331-495d-9e3b-6dc1c01528ab}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\pathie-cpp\src">
+ <UniqueIdentifier>{5d5ee615-192f-4b7f-bdfd-fb8316ceabc8}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="tests">
+ <UniqueIdentifier>{a86d650a-2268-43d9-9d74-cb17cd6b534b}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm">
+ <UniqueIdentifier>{4bb88f6d-7ddf-41e0-91be-a43dbcd0e9b0}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\include">
+ <UniqueIdentifier>{6c2bef00-97a0-4881-a6f0-ded54b8520bf}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\include\fbgemm">
+ <UniqueIdentifier>{95f7ce7c-c649-4d57-8d2a-d724bd75fe84}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\src">
+ <UniqueIdentifier>{41f7fbeb-2a73-4747-800c-46307cd0b52b}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party">
+ <UniqueIdentifier>{dc2722bc-af78-4923-82cb-9a09cb290fbf}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\asmjit">
+ <UniqueIdentifier>{577ae810-9593-423d-a398-0787252022b4}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\cpuinfo">
+ <UniqueIdentifier>{f97ae984-fe9a-45f6-a3f4-af90875209ba}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\cpuinfo\include">
+ <UniqueIdentifier>{4e7efd32-ec9d-4a1f-b454-656ba5c03275}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\cpuinfo\src">
+ <UniqueIdentifier>{15b8bcc0-2a07-4d39-8e03-18daa0c33d09}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\cpuinfo\src\x86">
+ <UniqueIdentifier>{ffd4cf44-177f-47a2-870a-438df9ca3be4}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\cpuinfo\src\x86\cacehe">
+ <UniqueIdentifier>{b600923b-21c1-492a-bfd9-0aa1082ebcd7}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\cpuinfo\src\x86\windows">
+ <UniqueIdentifier>{79535a0d-1cdc-45a9-89fb-e9c5794ddff5}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\cpuinfo\src\cpuinfo">
+ <UniqueIdentifier>{5709c1ff-41f9-4f83-badb-a7a7c98c1fae}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\asmjit\src">
+ <UniqueIdentifier>{a35aa317-6132-4c31-8f9a-8ec68a4b1c39}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\asmjit\src\asmjit">
+ <UniqueIdentifier>{fc12d7c4-41df-48c0-9017-e8f4d7538cf8}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\asmjit\src\asmjit\x86">
+ <UniqueIdentifier>{5818c959-7963-4d8e-9e87-b61f340476c2}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\cpuinfo\deps">
+ <UniqueIdentifier>{d4505c8d-5e6e-4baf-8525-dc59ae8b6415}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\cpuinfo\deps\clog">
+ <UniqueIdentifier>{fb9777f1-6887-4286-a58c-0956b356a815}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\cpuinfo\deps\clog\include">
+ <UniqueIdentifier>{17125bd0-f21b-4e95-a922-690f5665e9b6}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\cpuinfo\deps\clog\src">
+ <UniqueIdentifier>{8fd74b1e-d3c1-4158-ad46-4a447222934e}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\asmjit\src\asmjit\core">
+ <UniqueIdentifier>{b3b34c5f-5b98-436a-b34c-11e2dccb7ea2}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\test">
+ <UniqueIdentifier>{40576dca-07d5-4904-8119-ffbc982451a3}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\bench">
+ <UniqueIdentifier>{9f11c8f1-78f7-47c6-9eac-34cd2c6cd909}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\googletest">
+ <UniqueIdentifier>{75f9df88-0eb1-4d9a-858e-4e0b8fc3aa8a}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\googletest\googletest">
+ <UniqueIdentifier>{9f77e916-1d2f-4c15-9eba-46bcbddd2658}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\fbgemm\third_party\googletest\googletest\src">
+ <UniqueIdentifier>{050ba410-c56a-4607-8401-935f58f598b5}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\half_float">
+ <UniqueIdentifier>{defd3aec-3c56-4d70-a4bb-90ba9003d98d}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="3rd_party\phf">
+ <UniqueIdentifier>{352ac0e9-daed-437a-bc36-fb85ecd037eb}</UniqueIdentifier>
+ </Filter>
+ <Filter Include="tensors\cpu\fbgemm">
+ <UniqueIdentifier>{bf361868-f451-45b8-9695-570d67924972}</UniqueIdentifier>
+ </Filter>
</ItemGroup>
<ItemGroup>
- <None Include=".editorConfig" />
- <None Include="..\src\training\communicator_nccl.h">
- <Filter>training</Filter>
+ <None Include="..\src\3rd_party\nccl\src\bootstrap.cu">
+ <Filter>3rd_party\nccl\src</Filter>
</None>
- <None Include="..\src\tensors\gpu\add.cu">
- <Filter>tensors\gpu</Filter>
+ <None Include="..\src\3rd_party\nccl\src\init.cu">
+ <Filter>3rd_party\nccl\src</Filter>
</None>
- <None Include="..\src\tensors\gpu\add.inc">
- <Filter>tensors\gpu</Filter>
+ <None Include="..\src\3rd_party\sentencepiece\src\sentencepiece.proto">
+ <Filter>3rd_party\sentencepiece\src</Filter>
</None>
- <None Include="..\src\tensors\gpu\algorithm.cu">
- <Filter>tensors\gpu</Filter>
+ <None Include="..\src\3rd_party\sentencepiece\src\sentencepiece_model.proto">
+ <Filter>3rd_party\sentencepiece\src</Filter>
</None>
- <None Include="..\src\tensors\gpu\cudnn_wrappers.cu">
- <Filter>tensors\gpu</Filter>
+ <None Include="..\src\3rd_party\nccl\src\Makefile">
+ <Filter>3rd_party\nccl\src</Filter>
</None>
- <None Include="..\src\tensors\gpu\device.cu">
- <Filter>tensors\gpu</Filter>
+ <None Include="..\src\3rd_party\nccl\src\nccl.h.in">
+ <Filter>3rd_party\nccl\src</Filter>
</None>
- <None Include="..\src\tensors\gpu\element.cu">
- <Filter>tensors\gpu</Filter>
+ <None Include="..\src\3rd_party\nccl\src\ring.cu">
+ <Filter>3rd_party\nccl\src</Filter>
</None>
- <None Include="..\src\tensors\gpu\element.inc">
- <Filter>tensors\gpu</Filter>
+ <None Include="..\src\3rd_party\nccl\src\transport.cu">
+ <Filter>3rd_party\nccl\src</Filter>
</None>
- <None Include="..\src\tensors\gpu\prod.cu">
- <Filter>tensors\gpu</Filter>
+ <None Include="..\src\3rd_party\nccl\src\collectives\all_gather.cu">
+ <Filter>3rd_party\nccl\src\collectives</Filter>
</None>
- <None Include="..\src\tensors\gpu\sparse.cu">
- <Filter>tensors\gpu</Filter>
+ <None Include="..\src\3rd_party\nccl\src\collectives\all_reduce.cu">
+ <Filter>3rd_party\nccl\src\collectives</Filter>
</None>
- <None Include="..\src\tensors\gpu\tensor_operators.cu">
- <Filter>tensors\gpu</Filter>
+ <None Include="..\src\3rd_party\nccl\src\collectives\broadcast.cu">
+ <Filter>3rd_party\nccl\src\collectives</Filter>
</None>
- <None Include="..\src\3rd_party\sentencepiece\src\sentencepiece.proto">
- <Filter>3rd_party\sentencepiece\src</Filter>
+ <None Include="..\src\3rd_party\nccl\src\collectives\reduce.cu">
+ <Filter>3rd_party\nccl\src\collectives</Filter>
</None>
- <None Include="..\src\3rd_party\sentencepiece\src\sentencepiece_model.proto">
- <Filter>3rd_party\sentencepiece\src</Filter>
+ <None Include="..\src\3rd_party\nccl\src\collectives\reduce_scatter.cu">
+ <Filter>3rd_party\nccl\src\collectives</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\collectives\device\all_gather.cu">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\collectives\device\all_reduce.cu">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\collectives\device\broadcast.cu">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\collectives\device\functions.cu">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\collectives\device\Makefile">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\collectives\device\reduce.cu">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\collectives\device\reduce_scatter.cu">
+ <Filter>3rd_party\nccl\src\collectives\device</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\misc\enqueue.cu">
+ <Filter>3rd_party\nccl\src\misc</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\misc\group.cu">
+ <Filter>3rd_party\nccl\src\misc</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\misc\ibvwrap.cu">
+ <Filter>3rd_party\nccl\src\misc</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\misc\nvmlwrap.cu">
+ <Filter>3rd_party\nccl\src\misc</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\misc\rings.cu">
+ <Filter>3rd_party\nccl\src\misc</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\misc\utils.cu">
+ <Filter>3rd_party\nccl\src\misc</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\transport\net.cu">
+ <Filter>3rd_party\nccl\src\transport</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\transport\net_ib.cu">
+ <Filter>3rd_party\nccl\src\transport</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\transport\net_socket.cu">
+ <Filter>3rd_party\nccl\src\transport</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\transport\p2p.cu">
+ <Filter>3rd_party\nccl\src\transport</Filter>
+ </None>
+ <None Include="..\src\3rd_party\nccl\src\transport\shm.cu">
+ <Filter>3rd_party\nccl\src\transport</Filter>
+ </None>
+ <None Include="..\src\3rd_party\pathie-cpp\CHANGELOG">
+ <Filter>3rd_party\pathie-cpp</Filter>
+ </None>
+ <None Include="..\src\3rd_party\pathie-cpp\LICENSE">
+ <Filter>3rd_party\pathie-cpp</Filter>
+ </None>
+ <None Include="..\src\3rd_party\pathie-cpp\README.md">
+ <Filter>3rd_party\pathie-cpp</Filter>
+ </None>
+ <None Include="..\src\tests\README.md">
+ <Filter>tests</Filter>
+ </None>
+ <None Include="..\src\examples\mnist\download.sh">
+ <Filter>examples\mnist</Filter>
+ </None>
+ <None Include="..\src\examples\iris\iris.data">
+ <Filter>examples\iris</Filter>
+ </None>
+ <None Include="..\src\examples\cmake_install.cmake">
+ <Filter>examples</Filter>
+ </None>
+ <None Include="..\src\examples\Makefile">
+ <Filter>examples</Filter>
+ </None>
+ <None Include="..\src\examples\README.md">
+ <Filter>examples</Filter>
+ </None>
+ <None Include="..\src\3rd_party\half_float\Readme.md">
+ <Filter>3rd_party\half_float</Filter>
+ </None>
+ <None Include="..\src\3rd_party\half_float\umHalf.inl">
+ <Filter>3rd_party\half_float</Filter>
</None>
</ItemGroup>
<ItemGroup>
<Text Include="..\src\3rd_party\sentencepiece\src\CMakeLists.txt">
<Filter>3rd_party\sentencepiece\src</Filter>
</Text>
+ <Text Include="..\src\3rd_party\pathie-cpp\CMakeLists.txt">
+ <Filter>3rd_party\pathie-cpp</Filter>
+ </Text>
+ <Text Include="..\src\tests\CMakeLists.txt">
+ <Filter>tests</Filter>
+ </Text>
+ <Text Include="..\src\examples\CMakeLists.txt">
+ <Filter>examples</Filter>
+ </Text>
+ </ItemGroup>
+ <ItemGroup>
+ <CudaCompile Include="..\src\tensors\gpu\add.cu">
+ <Filter>tensors\gpu</Filter>
+ </CudaCompile>
+ <CudaCompile Include="..\src\tensors\gpu\algorithm.cu">
+ <Filter>tensors\gpu</Filter>
+ </CudaCompile>
+ <CudaCompile Include="..\src\tensors\gpu\cudnn_wrappers.cu">
+ <Filter>tensors\gpu</Filter>
+ </CudaCompile>
+ <CudaCompile Include="..\src\tensors\gpu\device.cu">
+ <Filter>tensors\gpu</Filter>
+ </CudaCompile>
+ <CudaCompile Include="..\src\tensors\gpu\element.cu">
+ <Filter>tensors\gpu</Filter>
+ </CudaCompile>
+ <CudaCompile Include="..\src\translator\helpers.cu">
+ <Filter>translator</Filter>
+ </CudaCompile>
+ <CudaCompile Include="..\src\translator\nth_element.cu">
+ <Filter>translator</Filter>
+ </CudaCompile>
+ <CudaCompile Include="..\src\tensors\gpu\tensor_operators.cu">
+ <Filter>tensors\gpu</Filter>
+ </CudaCompile>
+ <CudaCompile Include="..\src\tensors\gpu\sparse.cu">
+ <Filter>tensors\gpu</Filter>
+ </CudaCompile>
+ <CudaCompile Include="..\src\training\gradient_dropping\gpu\dropper.cu">
+ <Filter>training\gradient_dropping\gpu</Filter>
+ </CudaCompile>
+ <CudaCompile Include="..\src\training\gradient_dropping\gpu\sparse_algorithm.cu">
+ <Filter>training\gradient_dropping\gpu</Filter>
+ </CudaCompile>
</ItemGroup>
</Project> \ No newline at end of file