Welcome to mirror list, hosted at ThFree Co, Russian Federation.

force_render_red.cpp « fuzz « source - github.com/KhronosGroup/SPIRV-Tools.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 3267487ad6f2210cfc971775191400a40a40f12b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/fuzz/force_render_red.h"

#include "source/fuzz/fact_manager/fact_manager.h"
#include "source/fuzz/instruction_descriptor.h"
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
#include "source/fuzz/transformation_context.h"
#include "source/fuzz/transformation_replace_constant_with_uniform.h"
#include "source/opt/build_module.h"
#include "source/opt/ir_context.h"
#include "source/opt/types.h"
#include "source/util/make_unique.h"

namespace spvtools {
namespace fuzz {

namespace {

// Helper method to find the function that serves as the fragment shader entry
// point. Reports an error through |message_consumer| and returns nullptr if the
// module does not declare the Shader capability or has no entry point with the
// Fragment execution model. If several Fragment entry points exist, the first
// one declared is used.
//
// |message_consumer| is taken by const reference so that the underlying
// std::function is not copied on every call.
opt::Function* FindFragmentShaderEntryPoint(
    opt::IRContext* ir_context, const MessageConsumer& message_consumer) {
  // Check that the module declares the Shader capability.
  bool found_capability_shader = false;
  for (auto& capability : ir_context->capabilities()) {
    assert(capability.opcode() == SpvOpCapability);
    if (capability.GetSingleWordInOperand(0) == SpvCapabilityShader) {
      found_capability_shader = true;
      break;
    }
  }
  if (!found_capability_shader) {
    message_consumer(
        SPV_MSG_ERROR, nullptr, {},
        "Forcing of red rendering requires the Shader capability.");
    return nullptr;
  }

  // Find the first entry point whose execution model (in-operand 0 of
  // OpEntryPoint) is Fragment.
  opt::Instruction* fragment_entry_point = nullptr;
  for (auto& entry_point : ir_context->module()->entry_points()) {
    if (entry_point.GetSingleWordInOperand(0) == SpvExecutionModelFragment) {
      fragment_entry_point = &entry_point;
      break;
    }
  }
  if (fragment_entry_point == nullptr) {
    message_consumer(SPV_MSG_ERROR, nullptr, {},
                     "Forcing of red rendering requires an entry point with "
                     "the Fragment execution model.");
    return nullptr;
  }

  // In-operand 1 of OpEntryPoint is the id of the entry point function.
  const uint32_t entry_point_function_id =
      fragment_entry_point->GetSingleWordInOperand(1);
  for (auto& function : *ir_context->module()) {
    if (function.result_id() == entry_point_function_id) {
      return &function;
    }
  }
  assert(
      false &&
      "A valid module must have a function associated with each entry point.");
  return nullptr;
}

// Helper method to check that the module declares exactly one variable in the
// Output storage class and that its pointee type is a 4-component vector of
// 32-bit floats, returning a pointer to that variable. Reports an error through
// |message_consumer| and returns nullptr otherwise.
//
// The 32-bit width is required because the red colour written by
// ForceRenderRed is built from 32-bit float constants; storing it into, e.g., a
// vector of 16-bit floats would produce an invalid module.
//
// |message_consumer| is taken by const reference so that the underlying
// std::function is not copied on every call.
opt::Instruction* FindVec4OutputVariable(
    opt::IRContext* ir_context, const MessageConsumer& message_consumer) {
  opt::Instruction* output_variable = nullptr;
  for (auto& inst : ir_context->types_values()) {
    if (inst.opcode() == SpvOpVariable &&
        inst.GetSingleWordInOperand(0) == SpvStorageClassOutput) {
      if (output_variable != nullptr) {
        message_consumer(SPV_MSG_ERROR, nullptr, {},
                         "Only one output variable can be handled at present; "
                         "found multiple.");
        return nullptr;
      }
      output_variable = &inst;
      // Do not break, as we want to check for multiple output variables.
    }
  }
  if (output_variable == nullptr) {
    message_consumer(SPV_MSG_ERROR, nullptr, {},
                     "No output variable to which to write red was found.");
    return nullptr;
  }

  // The type of an OpVariable is a pointer; inspect what it points to.
  auto output_variable_base_type = ir_context->get_type_mgr()
                                       ->GetType(output_variable->type_id())
                                       ->AsPointer()
                                       ->pointee_type()
                                       ->AsVector();
  const opt::analysis::Float* element_float_type =
      output_variable_base_type
          ? output_variable_base_type->element_type()->AsFloat()
          : nullptr;
  if (!output_variable_base_type ||
      output_variable_base_type->element_count() != 4 ||
      !element_float_type || element_float_type->width() != 32) {
    message_consumer(SPV_MSG_ERROR, nullptr, {},
                     "The output variable must have type vec4.");
    return nullptr;
  }

  return output_variable;
}

// Returns the result ids of the constants 0.0 and 1.0 of |float_type| (in that
// order), registering the constants with the constant manager and creating
// their defining instructions if the module does not already have them.
// |float_type| is expected to be a 32-bit float type: 1.0 is encoded below as
// the single-word bit pattern of a C++ float.
std::pair<uint32_t, uint32_t> FindOrCreateFloatZeroAndOne(
    opt::IRContext* ir_context, opt::analysis::Float* float_type) {
  auto* constant_manager = ir_context->get_constant_mgr();

  // Registers the |float_type| constant whose single literal word is
  // |literal_word| and yields the result id of its defining instruction,
  // creating that instruction when necessary.
  auto find_or_create = [constant_manager,
                         float_type](uint32_t literal_word) -> uint32_t {
    const std::vector<uint32_t> literal_words{literal_word};
    const auto* registered = constant_manager->RegisterConstant(
        MakeUnique<opt::analysis::FloatConstant>(float_type, literal_words));
    return constant_manager->GetDefiningInstruction(registered)->result_id();
  };

  // Reinterpret 1.0f as its raw 32-bit encoding.
  const float one_value = 1.0f;
  uint32_t one_word = 0;
  memcpy(&one_word, &one_value, sizeof(float));

  // Zero is handled first so that any freshly created ids are allocated in the
  // order zero, then one.
  const uint32_t zero_id = find_or_create(0);
  const uint32_t one_id = find_or_create(one_word);
  return {zero_id, one_id};
}

// Helper to build (but not apply) a transformation that replaces the use of
// |constant_id| at in-operand |in_operand_index| of the OpFOrdGreaterThan
// instruction whose result id is |greater_than_instruction| with a value taken
// from the first uniform that |fact_manager| records as holding that constant.
// The caller must ensure that at least one such uniform is known.
std::unique_ptr<TransformationReplaceConstantWithUniform>
MakeConstantUniformReplacement(opt::IRContext* ir_context,
                               const FactManager& fact_manager,
                               uint32_t constant_id,
                               uint32_t greater_than_instruction,
                               uint32_t in_operand_index) {
  auto uniform_descriptors =
      fact_manager.GetUniformDescriptorsForConstant(constant_id);
  assert(!uniform_descriptors.empty() &&
         "The constant must be known to be available from a uniform.");

  // C++ leaves the evaluation order of function arguments unspecified, so the
  // two fresh ids are taken in separate statements. Otherwise which id ends up
  // in which constructor parameter would depend on the compiler, making the
  // output binary differ between toolchains.
  const uint32_t first_fresh_id = ir_context->TakeNextId();
  const uint32_t second_fresh_id = ir_context->TakeNextId();

  return MakeUnique<TransformationReplaceConstantWithUniform>(
      MakeIdUseDescriptor(constant_id,
                          MakeInstructionDescriptor(greater_than_instruction,
                                                    SpvOpFOrdGreaterThan, 0),
                          in_operand_index),
      uniform_descriptors[0], first_fresh_id, second_fresh_id);
}

}  // namespace

// Transforms the fragment shader in |binary_in| so that a new entry block
// first writes red, vec4(1.0, 0.0, 0.0, 1.0), to the shader's single vec4
// output variable. It then branches on a condition that is always false, so
// control goes to a new exit block that returns. The original body is kept but
// never executed.
//
// When facts about at least two distinct float uniform values are available,
// the condition is built as 'smaller > larger'. Each operand is then replaced
// by a value read from a uniform known to hold it. Otherwise OpConstantFalse is
// used.
//
// Returns false, after reporting through |message_consumer|, if the input is
// invalid or unsuitable. Otherwise writes the transformed module to
// |binary_out| and returns true.
bool ForceRenderRed(
    const spv_target_env& target_env, spv_validator_options validator_options,
    const std::vector<uint32_t>& binary_in,
    const spvtools::fuzz::protobufs::FactSequence& initial_facts,
    const MessageConsumer& message_consumer,
    std::vector<uint32_t>* binary_out) {
  spvtools::SpirvTools tools(target_env);
  if (!tools.IsValid()) {
    message_consumer(SPV_MSG_ERROR, nullptr, {},
                     "Failed to create SPIRV-Tools interface; stopping.");
    return false;
  }

  // Initial binary should be valid. data() is used rather than &binary_in[0],
  // because indexing an empty vector is undefined behaviour.
  if (!tools.Validate(binary_in.data(), binary_in.size(), validator_options)) {
    message_consumer(SPV_MSG_ERROR, nullptr, {},
                     "Initial binary is invalid; stopping.");
    return false;
  }

  // Build the module from the input binary.
  std::unique_ptr<opt::IRContext> ir_context = BuildModule(
      target_env, message_consumer, binary_in.data(), binary_in.size());
  if (!ir_context) {
    // Not expected once validation has passed. Report it rather than
    // dereferencing a null context in builds where assert() is compiled out.
    message_consumer(SPV_MSG_ERROR, nullptr, {},
                     "Failed to build module from initial binary; stopping.");
    return false;
  }

  // Set up a fact manager with any given initial facts.
  TransformationContext transformation_context(
      MakeUnique<FactManager>(ir_context.get()), validator_options);
  for (auto& fact : initial_facts.fact()) {
    transformation_context.GetFactManager()->MaybeAddFact(fact);
  }

  auto entry_point_function =
      FindFragmentShaderEntryPoint(ir_context.get(), message_consumer);
  auto output_variable =
      FindVec4OutputVariable(ir_context.get(), message_consumer);
  if (entry_point_function == nullptr || output_variable == nullptr) {
    return false;
  }

  opt::analysis::Float temp_float_type(32);
  opt::analysis::Float* float_type = ir_context->get_type_mgr()
                                         ->GetRegisteredType(&temp_float_type)
                                         ->AsFloat();
  std::pair<uint32_t, uint32_t> zero_one_float_ids =
      FindOrCreateFloatZeroAndOne(ir_context.get(), float_type);

  // Make the new exit block; it simply returns.
  auto new_exit_block_id = ir_context->TakeNextId();
  {
    auto label = MakeUnique<opt::Instruction>(ir_context.get(), SpvOpLabel, 0,
                                              new_exit_block_id,
                                              opt::Instruction::OperandList());
    auto new_exit_block = MakeUnique<opt::BasicBlock>(std::move(label));
    new_exit_block->AddInstruction(MakeUnique<opt::Instruction>(
        ir_context.get(), SpvOpReturn, 0, 0, opt::Instruction::OperandList()));
    entry_point_function->AddBasicBlock(std::move(new_exit_block));
  }

  // Make the new entry block
  {
    // The current entry block becomes the never-taken 'true' target of the new
    // entry block's conditional branch.
    opt::BasicBlock* original_entry_block = entry_point_function->entry().get();

    auto label = MakeUnique<opt::Instruction>(ir_context.get(), SpvOpLabel, 0,
                                              ir_context->TakeNextId(),
                                              opt::Instruction::OperandList());
    auto new_entry_block = MakeUnique<opt::BasicBlock>(std::move(label));

    // SPIR-V requires all function-scope OpVariable instructions to be the
    // first instructions of the function's first block. Move those of the
    // original entry block into the new entry block before anything else is
    // added to it.
    std::vector<opt::Instruction*> variables_to_move;
    for (auto& inst : *original_entry_block) {
      if (inst.opcode() == SpvOpVariable) {
        variables_to_move.push_back(&inst);
      }
    }
    for (opt::Instruction* variable : variables_to_move) {
      variable->RemoveFromList();
      new_entry_block->AddInstruction(
          std::unique_ptr<opt::Instruction>(variable));
    }

    // Make an instruction to construct vec4(1.0, 0.0, 0.0, 1.0), representing
    // the colour red.
    opt::Operand zero_float = {SPV_OPERAND_TYPE_ID, {zero_one_float_ids.first}};
    opt::Operand one_float = {SPV_OPERAND_TYPE_ID, {zero_one_float_ids.second}};
    opt::Instruction::OperandList op_composite_construct_operands = {
        one_float, zero_float, zero_float, one_float};
    auto temp_vec4 = opt::analysis::Vector(float_type, 4);
    auto vec4_id = ir_context->get_type_mgr()->GetId(&temp_vec4);
    auto red = MakeUnique<opt::Instruction>(
        ir_context.get(), SpvOpCompositeConstruct, vec4_id,
        ir_context->TakeNextId(), op_composite_construct_operands);
    auto red_id = red->result_id();
    new_entry_block->AddInstruction(std::move(red));

    // Make an instruction to store red into the output color.
    opt::Operand variable_to_store_into = {SPV_OPERAND_TYPE_ID,
                                           {output_variable->result_id()}};
    opt::Operand value_to_be_stored = {SPV_OPERAND_TYPE_ID, {red_id}};
    opt::Instruction::OperandList op_store_operands = {variable_to_store_into,
                                                       value_to_be_stored};
    new_entry_block->AddInstruction(MakeUnique<opt::Instruction>(
        ir_context.get(), SpvOpStore, 0, 0, op_store_operands));

    // We are going to attempt to construct 'false' as an expression of the form
    // 'literal1 > literal2'. If we succeed, we will later replace each literal
    // with a uniform of the same value - we can only do that replacement once
    // we have added the entry block to the module.
    std::unique_ptr<TransformationReplaceConstantWithUniform>
        first_greater_than_operand_replacement = nullptr;
    std::unique_ptr<TransformationReplaceConstantWithUniform>
        second_greater_than_operand_replacement = nullptr;
    uint32_t id_guaranteed_to_be_false = 0;

    opt::analysis::Bool temp_bool_type;
    opt::analysis::Bool* registered_bool_type =
        ir_context->get_type_mgr()
            ->GetRegisteredType(&temp_bool_type)
            ->AsBool();

    auto float_type_id = ir_context->get_type_mgr()->GetId(float_type);
    auto types_for_which_uniforms_are_known =
        transformation_context.GetFactManager()
            ->GetTypesForWhichUniformValuesAreKnown();

    // Check whether we have any float uniforms.
    if (std::find(types_for_which_uniforms_are_known.begin(),
                  types_for_which_uniforms_are_known.end(),
                  float_type_id) != types_for_which_uniforms_are_known.end()) {
      // We have at least one float uniform; let's see whether we have at least
      // two.
      auto available_constants =
          transformation_context.GetFactManager()
              ->GetConstantsAvailableFromUniformsForType(float_type_id);
      if (available_constants.size() > 1) {
        // Grab the float constants associated with the first two known float
        // uniforms.
        auto first_constant =
            ir_context->get_constant_mgr()
                ->GetConstantFromInst(ir_context->get_def_use_mgr()->GetDef(
                    available_constants[0]))
                ->AsFloatConstant();
        auto second_constant =
            ir_context->get_constant_mgr()
                ->GetConstantFromInst(ir_context->get_def_use_mgr()->GetDef(
                    available_constants[1]))
                ->AsFloatConstant();

        // Now work out which of the two constants is larger than the other.
        // If they are equal (or unordered, e.g. NaN), both indices stay 0.
        uint32_t larger_constant_index = 0;
        uint32_t smaller_constant_index = 0;
        if (first_constant->GetFloat() > second_constant->GetFloat()) {
          larger_constant_index = 0;
          smaller_constant_index = 1;
        } else if (first_constant->GetFloat() < second_constant->GetFloat()) {
          larger_constant_index = 1;
          smaller_constant_index = 0;
        }

        // Only proceed with these constants if they have turned out to be
        // distinct.
        if (larger_constant_index != smaller_constant_index) {
          // We are in a position to create 'false' as 'literal1 > literal2', so
          // reserve an id for this computation; this id will end up being
          // guaranteed to be 'false'.
          id_guaranteed_to_be_false = ir_context->TakeNextId();

          auto smaller_constant = available_constants[smaller_constant_index];
          auto larger_constant = available_constants[larger_constant_index];

          opt::Instruction::OperandList greater_than_operands = {
              {SPV_OPERAND_TYPE_ID, {smaller_constant}},
              {SPV_OPERAND_TYPE_ID, {larger_constant}}};
          new_entry_block->AddInstruction(MakeUnique<opt::Instruction>(
              ir_context.get(), SpvOpFOrdGreaterThan,
              ir_context->get_type_mgr()->GetId(registered_bool_type),
              id_guaranteed_to_be_false, greater_than_operands));

          first_greater_than_operand_replacement =
              MakeConstantUniformReplacement(
                  ir_context.get(), *transformation_context.GetFactManager(),
                  smaller_constant, id_guaranteed_to_be_false, 0);
          second_greater_than_operand_replacement =
              MakeConstantUniformReplacement(
                  ir_context.get(), *transformation_context.GetFactManager(),
                  larger_constant, id_guaranteed_to_be_false, 1);
        }
      }
    }

    if (id_guaranteed_to_be_false == 0) {
      auto constant_false = ir_context->get_constant_mgr()->RegisterConstant(
          MakeUnique<opt::analysis::BoolConstant>(registered_bool_type, false));
      id_guaranteed_to_be_false = ir_context->get_constant_mgr()
                                      ->GetDefiningInstruction(constant_false)
                                      ->result_id();
    }

    // With the Shader capability, control flow must be structured. A
    // conditional branch whose targets are two ordinary blocks must be preceded
    // by a merge instruction, so declare the new exit block as the merge block
    // of this selection.
    opt::Instruction::OperandList op_selection_merge_operands = {
        {SPV_OPERAND_TYPE_ID, {new_exit_block_id}},
        {SPV_OPERAND_TYPE_SELECTION_CONTROL, {SpvSelectionControlMaskNone}}};
    new_entry_block->AddInstruction(
        MakeUnique<opt::Instruction>(ir_context.get(), SpvOpSelectionMerge, 0,
                                     0, op_selection_merge_operands));

    opt::Operand false_condition = {SPV_OPERAND_TYPE_ID,
                                    {id_guaranteed_to_be_false}};
    opt::Operand then_block = {SPV_OPERAND_TYPE_ID,
                               {original_entry_block->id()}};
    opt::Operand else_block = {SPV_OPERAND_TYPE_ID, {new_exit_block_id}};
    opt::Instruction::OperandList op_branch_conditional_operands = {
        false_condition, then_block, else_block};
    new_entry_block->AddInstruction(
        MakeUnique<opt::Instruction>(ir_context.get(), SpvOpBranchConditional,
                                     0, 0, op_branch_conditional_operands));

    entry_point_function->InsertBasicBlockBefore(std::move(new_entry_block),
                                                 original_entry_block);

    // Instructions have been added to, and moved within, the module without
    // updating the IRContext's analyses (def-use, instruction-to-block, ...).
    // Invalidate them so that the replacement transformations below see the
    // current module.
    ir_context->InvalidateAnalysesExceptFor(
        opt::IRContext::Analysis::kAnalysisNone);

    for (auto& replacement : {first_greater_than_operand_replacement.get(),
                              second_greater_than_operand_replacement.get()}) {
      if (replacement) {
        assert(replacement->IsApplicable(ir_context.get(),
                                         transformation_context));
        replacement->Apply(ir_context.get(), &transformation_context);
      }
    }
  }

  // Write out the module as a binary.
  ir_context->module()->ToBinary(binary_out, false);
  return true;
}

}  // namespace fuzz
}  // namespace spvtools