Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/KhronosGroup/SPIRV-Tools.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorSpencer Fricke <spencerfricke@gmail.com>2022-08-08 21:45:04 +0300
committerGitHub <noreply@github.com>2022-08-08 21:45:04 +0300
commitf20e8d05f5f7fa436ba04f76ce5c4ee07c124742 (patch)
treebfe1fc874a7e934f7c58f6bd0cbc562d1c683fc9 /test
parent5e61ea2098220059e89523f1f47b0bcd8c33b89a (diff)
spirv-val: Add SPV_KHR_ray_tracing instructions (#4871)
Diffstat (limited to 'test')
-rw-r--r--test/val/CMakeLists.txt1
-rw-r--r--test/val/val_ray_tracing.cpp555
2 files changed, 556 insertions, 0 deletions
diff --git a/test/val/CMakeLists.txt b/test/val/CMakeLists.txt
index d02807a7e..e73e0f4af 100644
--- a/test/val/CMakeLists.txt
+++ b/test/val/CMakeLists.txt
@@ -91,6 +91,7 @@ add_spvtools_unittest(TARGET val_fghijklmnop
add_spvtools_unittest(TARGET val_rstuvw
SRCS
val_ray_query.cpp
+ val_ray_tracing.cpp
val_small_type_uses_test.cpp
val_ssa_test.cpp
val_state_test.cpp
diff --git a/test/val/val_ray_tracing.cpp b/test/val/val_ray_tracing.cpp
new file mode 100644
index 000000000..9486777a0
--- /dev/null
+++ b/test/val/val_ray_tracing.cpp
@@ -0,0 +1,555 @@
+// Copyright (c) 2022 The Khronos Group Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests ray tracing instructions from SPV_KHR_ray_tracing.
+
+#include <sstream>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "test/val/val_fixtures.h"
+
+namespace spvtools {
+namespace val {
+namespace {
+
+using ::testing::HasSubstr;
+using ::testing::Values;
+
+using ValidateRayTracing = spvtest::ValidateBase<bool>;
+
+TEST_F(ValidateRayTracing, IgnoreIntersectionSuccess) {
+ const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint AnyHitKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%main = OpFunction %void None %func
+%label = OpLabel
+OpIgnoreIntersectionKHR
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(body.c_str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayTracing, IgnoreIntersectionExecutionModel) {
+ const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint CallableKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%main = OpFunction %void None %func
+%label = OpLabel
+OpIgnoreIntersectionKHR
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(body.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("OpIgnoreIntersectionKHR requires AnyHitKHR execution model"));
+}
+
+TEST_F(ValidateRayTracing, TerminateRaySuccess) {
+ const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint AnyHitKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%main = OpFunction %void None %func
+%label = OpLabel
+OpIgnoreIntersectionKHR
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(body.c_str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayTracing, TerminateRayExecutionModel) {
+ const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint MissKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%main = OpFunction %void None %func
+%label = OpLabel
+OpTerminateRayKHR
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(body.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("OpTerminateRayKHR requires AnyHitKHR execution model"));
+}
+
+TEST_F(ValidateRayTracing, ReportIntersectionRaySuccess) {
+ const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint IntersectionKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%float = OpTypeFloat 32
+%float_1 = OpConstant %float 1
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%bool = OpTypeBool
+%main = OpFunction %void None %func
+%label = OpLabel
+%report = OpReportIntersectionKHR %bool %float_1 %uint_1
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(body.c_str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayTracing, ReportIntersectionExecutionModel) {
+ const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint MissKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%float = OpTypeFloat 32
+%float_1 = OpConstant %float 1
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%bool = OpTypeBool
+%main = OpFunction %void None %func
+%label = OpLabel
+%report = OpReportIntersectionKHR %bool %float_1 %uint_1
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(body.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr(
+ "OpReportIntersectionKHR requires IntersectionKHR execution model"));
+}
+
+TEST_F(ValidateRayTracing, ReportIntersectionReturnType) {
+ const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint IntersectionKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%float = OpTypeFloat 32
+%float_1 = OpConstant %float 1
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%main = OpFunction %void None %func
+%label = OpLabel
+%report = OpReportIntersectionKHR %uint %float_1 %uint_1
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(body.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("expected Result Type to be bool scalar type"));
+}
+
+TEST_F(ValidateRayTracing, ReportIntersectionHit) {
+ const std::string body = R"(
+OpCapability RayTracingKHR
+OpCapability Float64
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint IntersectionKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%float64 = OpTypeFloat 64
+%float64_1 = OpConstant %float64 1
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%bool = OpTypeBool
+%main = OpFunction %void None %func
+%label = OpLabel
+%report = OpReportIntersectionKHR %bool %float64_1 %uint_1
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(body.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Hit must be a 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayTracing, ReportIntersectionHitKind) {
+ const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint IntersectionKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%float = OpTypeFloat 32
+%float_1 = OpConstant %float 1
+%sint = OpTypeInt 32 1
+%sint_1 = OpConstant %sint 1
+%bool = OpTypeBool
+%main = OpFunction %void None %func
+%label = OpLabel
+%report = OpReportIntersectionKHR %bool %float_1 %sint_1
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(body.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Hit Kind must be a 32-bit unsigned int scalar"));
+}
+
+TEST_F(ValidateRayTracing, ExecuteCallableSuccess) {
+ const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint CallableKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%int = OpTypeInt 32 1
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%data_ptr = OpTypePointer CallableDataKHR %int
+%data = OpVariable %data_ptr CallableDataKHR
+%inData_ptr = OpTypePointer IncomingCallableDataKHR %int
+%inData = OpVariable %inData_ptr IncomingCallableDataKHR
+%main = OpFunction %void None %func
+%label = OpLabel
+OpExecuteCallableKHR %uint_0 %data
+OpExecuteCallableKHR %uint_0 %inData
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(body.c_str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayTracing, ExecuteCallableExecutionModel) {
+ const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint AnyHitKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%int = OpTypeInt 32 1
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%data_ptr = OpTypePointer CallableDataKHR %int
+%data = OpVariable %data_ptr CallableDataKHR
+%inData_ptr = OpTypePointer IncomingCallableDataKHR %int
+%inData = OpVariable %inData_ptr IncomingCallableDataKHR
+%main = OpFunction %void None %func
+%label = OpLabel
+OpExecuteCallableKHR %uint_0 %data
+OpExecuteCallableKHR %uint_0 %inData
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(body.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("OpExecuteCallableKHR requires RayGenerationKHR, "
+ "ClosestHitKHR, MissKHR and CallableKHR execution models"));
+}
+
+TEST_F(ValidateRayTracing, ExecuteCallableStorageClass) {
+ const std::string body = R"(
+OpCapability RayTracingKHR
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint RayGenerationKHR %main "main"
+OpName %main "main"
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%int = OpTypeInt 32 1
+%uint = OpTypeInt 32 0
+%uint_0 = OpConstant %uint 0
+%data_ptr = OpTypePointer RayPayloadKHR %int
+%data = OpVariable %data_ptr RayPayloadKHR
+%main = OpFunction %void None %func
+%label = OpLabel
+OpExecuteCallableKHR %uint_0 %data
+OpReturn
+OpFunctionEnd
+)";
+
+ CompileSuccessfully(body.c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Callable Data must have storage class CallableDataKHR "
+ "or IncomingCallableDataKHR"));
+}
+
+std::string GenerateRayTraceCode(
+ const std::string& body,
+ const std::string execution_model = "RayGenerationKHR") {
+ std::ostringstream ss;
+ ss << R"(
+OpCapability RayTracingKHR
+OpCapability Float64
+OpExtension "SPV_KHR_ray_tracing"
+OpMemoryModel Logical GLSL450
+OpEntryPoint )"
+ << execution_model << R"( %main "main"
+OpDecorate %top_level_as DescriptorSet 0
+OpDecorate %top_level_as Binding 0
+%void = OpTypeVoid
+%func = OpTypeFunction %void
+%type_as = OpTypeAccelerationStructureKHR
+%as_uc_ptr = OpTypePointer UniformConstant %type_as
+%top_level_as = OpVariable %as_uc_ptr UniformConstant
+%uint = OpTypeInt 32 0
+%uint_1 = OpConstant %uint 1
+%float = OpTypeFloat 32
+%float64 = OpTypeFloat 64
+%f32vec3 = OpTypeVector %float 3
+%f32vec4 = OpTypeVector %float 4
+%float_0 = OpConstant %float 0
+%float64_0 = OpConstant %float64 0
+%v3composite = OpConstantComposite %f32vec3 %float_0 %float_0 %float_0
+%v4composite = OpConstantComposite %f32vec4 %float_0 %float_0 %float_0 %float_0
+%int = OpTypeInt 32 1
+%int_1 = OpConstant %int 1
+%payload_ptr = OpTypePointer RayPayloadKHR %int
+%payload = OpVariable %payload_ptr RayPayloadKHR
+%callable_ptr = OpTypePointer CallableDataKHR %int
+%callable = OpVariable %callable_ptr CallableDataKHR
+%ptr_uint = OpTypePointer Private %uint
+%var_uint = OpVariable %ptr_uint Private
+%ptr_float = OpTypePointer Private %float
+%var_float = OpVariable %ptr_float Private
+%ptr_f32vec3 = OpTypePointer Private %f32vec3
+%var_f32vec3 = OpVariable %ptr_f32vec3 Private
+%main = OpFunction %void None %func
+%label = OpLabel
+)";
+
+ ss << body;
+
+ ss << R"(
+OpReturn
+OpFunctionEnd)";
+ return ss.str();
+}
+
+TEST_F(ValidateRayTracing, TraceRaySuccess) {
+ const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
+
+%_uint = OpLoad %uint %var_uint
+%_float = OpLoad %float %var_float
+%_f32vec3 = OpLoad %f32vec3 %var_f32vec3
+OpTraceRayKHR %as %_uint %_uint %_uint %_uint %_uint %_f32vec3 %_float %_f32vec3 %_float %payload
+)";
+
+ CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+ EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
+}
+
+TEST_F(ValidateRayTracing, TraceRayExecutionModel) {
+ const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
+)";
+
+ CompileSuccessfully(GenerateRayTraceCode(body, "CallableKHR").c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("OpTraceRayKHR requires RayGenerationKHR, "
+ "ClosestHitKHR and MissKHR execution models"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayAccelerationStructure) {
+ const std::string body = R"(
+%_uint = OpLoad %uint %var_uint
+OpTraceRayKHR %_uint %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
+)";
+
+ CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Expected Acceleration Structure to be of type "
+ "OpTypeAccelerationStructureKHR"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayRayFlags) {
+ const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %float_0 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
+)";
+
+ CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Ray Flags must be a 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayCullMask) {
+ const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %float_0 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
+)";
+
+ CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Cull Mask must be a 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayTracing, TraceRaySbtOffest) {
+ const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %float_0 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
+)";
+
+ CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("SBT Offset must be a 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayTracing, TraceRaySbtStride) {
+ const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %float_0 %uint_1 %v3composite %float_0 %v3composite %float_0 %payload
+)";
+
+ CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("SBT Stride must be a 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayMissIndex) {
+ const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %float_0 %v3composite %float_0 %v3composite %float_0 %payload
+)";
+
+ CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Miss Index must be a 32-bit int scalar"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayRayOrigin) {
+ const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %float_0 %float_0 %v3composite %float_0 %payload
+)";
+
+ CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Ray Origin must be a 32-bit float 3-component vector"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayRayTMin) {
+ const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %uint_1 %v3composite %float_0 %payload
+)";
+
+ CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Ray TMin must be a 32-bit float scalar"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayRayDirection) {
+ const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v4composite %float_0 %payload
+)";
+
+ CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(
+ getDiagnosticString(),
+ HasSubstr("Ray Direction must be a 32-bit float 3-component vector"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayRayTMax) {
+ const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float64_0 %payload
+)";
+
+ CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Ray TMax must be a 32-bit float scalar"));
+}
+
+TEST_F(ValidateRayTracing, TraceRayPayload) {
+ const std::string body = R"(
+%as = OpLoad %type_as %top_level_as
+OpTraceRayKHR %as %uint_1 %uint_1 %uint_1 %uint_1 %uint_1 %v3composite %float_0 %v3composite %float_0 %callable
+)";
+
+ CompileSuccessfully(GenerateRayTraceCode(body).c_str());
+ EXPECT_EQ(SPV_ERROR_INVALID_DATA, ValidateInstructions());
+ EXPECT_THAT(getDiagnosticString(),
+ HasSubstr("Payload must have storage class RayPayloadKHR or "
+ "IncomingRayPayloadKHR"));
+}
+
+} // namespace
+} // namespace val
+} // namespace spvtools