Welcome to mirror list, hosted at ThFree Co, Russian Federation.

MessagePackSecurity.cs « MessagePack « Scripts « Assets « MessagePack.UnityClient « src - github.com/aspnet/MessagePack-CSharp.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 2b743afa4a2010c7dfd6d2b153ba3c32465293a4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
// Copyright (c) All contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.ExceptionServices;
using MessagePack.Formatters;
using MessagePack.Internal;

namespace MessagePack
{
    /// <summary>
    /// Settings related to security, particularly relevant when deserializing data from untrusted sources.
    /// </summary>
    public class MessagePackSecurity
    {
        /// <summary>
        /// Gets an instance preconfigured with settings that omit all protections. Useful for deserializing fully-trusted and valid msgpack sequences.
        /// </summary>
        public static readonly MessagePackSecurity TrustedData = new MessagePackSecurity();

        /// <summary>
        /// Gets an instance preconfigured with protections applied with reasonable settings for deserializing untrusted msgpack sequences.
        /// </summary>
        public static readonly MessagePackSecurity UntrustedData = new MessagePackSecurity
        {
            HashCollisionResistant = true,
            MaximumObjectGraphDepth = 500,
        };

        private readonly ObjectFallbackEqualityComparer objectFallbackEqualityComparer;

        private MessagePackSecurity()
        {
            this.objectFallbackEqualityComparer = new ObjectFallbackEqualityComparer(this);
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="MessagePackSecurity"/> class
        /// with properties copied from a provided template.
        /// </summary>
        /// <param name="copyFrom">The template to copy from.</param>
        protected MessagePackSecurity(MessagePackSecurity copyFrom)
            : this()
        {
            if (copyFrom is null)
            {
                throw new ArgumentNullException(nameof(copyFrom));
            }

            this.HashCollisionResistant = copyFrom.HashCollisionResistant;
            this.MaximumObjectGraphDepth = copyFrom.MaximumObjectGraphDepth;
        }

        /// <summary>
        /// Gets a value indicating whether data to be deserialized is untrusted and thus should not be allowed to create
        /// dictionaries or other hash-based collections unless the hashed type has a hash collision resistant implementation available.
        /// This can mitigate some denial of service attacks when deserializing untrusted code.
        /// </summary>
        /// <value>
        /// The value is <c>false</c> for <see cref="TrustedData"/> and <c>true</c> for <see cref="UntrustedData"/>.
        /// </value>
        public bool HashCollisionResistant { get; private set; }

        /// <summary>
        /// Gets the maximum depth of an object graph that may be deserialized.
        /// </summary>
        /// <remarks>
        /// <para>
        /// This value can be reduced to avoid a stack overflow that would crash the process when deserializing a msgpack sequence designed to cause deep recursion.
        /// A very short callstack on a thread with 1MB of total stack space might deserialize ~2000 nested arrays before crashing due to a stack overflow.
        /// Since stack space occupied may vary by the kind of object deserialized, a conservative value for this property to defend against stack overflow attacks might be 500.
        /// </para>
        /// </remarks>
        public int MaximumObjectGraphDepth { get; private set; } = int.MaxValue;

        /// <summary>
        /// Gets a copy of these options with the <see cref="MaximumObjectGraphDepth"/> property set to a new value.
        /// </summary>
        /// <param name="maximumObjectGraphDepth">The new value for the <see cref="MaximumObjectGraphDepth"/> property.</param>
        /// <returns>The new instance; or the original if the value is unchanged.</returns>
        public MessagePackSecurity WithMaximumObjectGraphDepth(int maximumObjectGraphDepth)
        {
            if (this.MaximumObjectGraphDepth == maximumObjectGraphDepth)
            {
                return this;
            }

            var clone = this.Clone();
            clone.MaximumObjectGraphDepth = maximumObjectGraphDepth;
            return clone;
        }

        /// <summary>
        /// Gets a copy of these options with the <see cref="HashCollisionResistant"/> property set to a new value.
        /// </summary>
        /// <param name="hashCollisionResistant">The new value for the <see cref="HashCollisionResistant"/> property.</param>
        /// <returns>The new instance; or the original if the value is unchanged.</returns>
        public MessagePackSecurity WithHashCollisionResistant(bool hashCollisionResistant)
        {
            if (this.HashCollisionResistant == hashCollisionResistant)
            {
                return this;
            }

            var clone = this.Clone();
            clone.HashCollisionResistant = hashCollisionResistant;
            return clone;
        }

        /// <summary>
        /// Gets an <see cref="IEqualityComparer{T}"/> that is suitable to use with a hash-based collection.
        /// </summary>
        /// <typeparam name="T">The type of key that will be hashed in the collection.</typeparam>
        /// <returns>The <see cref="IEqualityComparer{T}"/> to use.</returns>
        /// <remarks>
        /// When <see cref="HashCollisionResistant"/> is active, this will be a collision resistant instance which may reject certain key types.
        /// When <see cref="HashCollisionResistant"/> is not active, this will be <see cref="EqualityComparer{T}.Default"/>.
        /// </remarks>
        public IEqualityComparer<T> GetEqualityComparer<T>()
        {
            return this.HashCollisionResistant ? GetHashCollisionResistantEqualityComparer<T>() : EqualityComparer<T>.Default;
        }

        /// <summary>
        /// Gets an <see cref="IEqualityComparer"/> that is suitable to use with a hash-based collection.
        /// </summary>
        /// <returns>The <see cref="IEqualityComparer"/> to use.</returns>
        /// <remarks>
        /// When <see cref="HashCollisionResistant"/> is active, this will be a collision resistant instance which may reject certain key types.
        /// When <see cref="HashCollisionResistant"/> is not active, this will be <see cref="EqualityComparer{T}.Default"/>.
        /// </remarks>
        public IEqualityComparer GetEqualityComparer()
        {
            return this.HashCollisionResistant ? GetHashCollisionResistantEqualityComparer() : EqualityComparer<object>.Default;
        }

        /// <summary>
        /// Returns a hash collision resistant equality comparer.
        /// </summary>
        /// <typeparam name="T">The type of key that will be hashed in the collection.</typeparam>
        /// <returns>A hash collision resistant equality comparer.</returns>
        protected virtual IEqualityComparer<T> GetHashCollisionResistantEqualityComparer<T>()
        {
            IEqualityComparer<T> result = null;
            if (typeof(T).GetTypeInfo().IsEnum)
            {
                Type underlyingType = typeof(T).GetTypeInfo().GetEnumUnderlyingType();
                result =
                    underlyingType == typeof(sbyte) ? CollisionResistantHasher<T>.Instance :
                    underlyingType == typeof(byte) ? CollisionResistantHasher<T>.Instance :
                    underlyingType == typeof(short) ? CollisionResistantHasher<T>.Instance :
                    underlyingType == typeof(ushort) ? CollisionResistantHasher<T>.Instance :
                    underlyingType == typeof(int) ? CollisionResistantHasher<T>.Instance :
                    underlyingType == typeof(uint) ? CollisionResistantHasher<T>.Instance :
                    null;
            }
            else
            {
                // For anything 32-bits and under, our fallback base secure hasher is usually adequate since it makes the hash unpredictable.
                // We should have special implementations for any value that is larger than 32-bits in order to make sure
                // that all the data gets hashed securely rather than trivially and predictably compressed into 32-bits before being hashed.
                // We also have to specially handle some 32-bit types (e.g. float) where multiple in-memory representations should hash to the same value.
                // Any type supported by the PrimitiveObjectFormatter should be added here if supporting it as a key in a collection makes sense.
                result =
                    // 32-bits or smaller:
                    typeof(T) == typeof(bool) ? CollisionResistantHasher<T>.Instance :
                    typeof(T) == typeof(char) ? CollisionResistantHasher<T>.Instance :
                    typeof(T) == typeof(sbyte) ? CollisionResistantHasher<T>.Instance :
                    typeof(T) == typeof(byte) ? CollisionResistantHasher<T>.Instance :
                    typeof(T) == typeof(short) ? CollisionResistantHasher<T>.Instance :
                    typeof(T) == typeof(ushort) ? CollisionResistantHasher<T>.Instance :
                    typeof(T) == typeof(int) ? CollisionResistantHasher<T>.Instance :
                    typeof(T) == typeof(uint) ? CollisionResistantHasher<T>.Instance :

                    // Larger than 32-bits (or otherwise require special handling):
                    typeof(T) == typeof(long) ? (IEqualityComparer<T>)Int64EqualityComparer.Instance :
                    typeof(T) == typeof(ulong) ? (IEqualityComparer<T>)UInt64EqualityComparer.Instance :
                    typeof(T) == typeof(float) ? (IEqualityComparer<T>)SingleEqualityComparer.Instance :
                    typeof(T) == typeof(double) ? (IEqualityComparer<T>)DoubleEqualityComparer.Instance :
                    typeof(T) == typeof(string) ? (IEqualityComparer<T>)StringEqualityComparer.Instance :
                    typeof(T) == typeof(Guid) ? (IEqualityComparer<T>)GuidEqualityComparer.Instance :
                    typeof(T) == typeof(DateTime) ? (IEqualityComparer<T>)DateTimeEqualityComparer.Instance :
                    typeof(T) == typeof(DateTimeOffset) ? (IEqualityComparer<T>)DateTimeOffsetEqualityComparer.Instance :
                    typeof(T) == typeof(object) ? (IEqualityComparer<T>)this.objectFallbackEqualityComparer :
                    null;
            }

            // Any type we don't explicitly whitelist here shouldn't be allowed to use as the key in a hash-based collection since it isn't known to be hash resistant.
            // This method can of course be overridden to add more hash collision resistant type support, or the deserializing party can indicate that the data is Trusted
            // so that this method doesn't even get called.
            return result ?? throw new TypeAccessException($"No hash-resistant equality comparer available for type: {typeof(T)}");
        }

        /// <summary>
        /// Checks the depth of the deserializing graph and increments it by 1.
        /// </summary>
        /// <param name="reader">The reader that is involved in deserialization.</param>
        /// <remarks>
        /// Callers should decrement <see cref="MessagePackReader.Depth"/> after exiting that edge in the graph.
        /// </remarks>
        /// <exception cref="InsufficientExecutionStackException">Thrown if <see cref="MessagePackReader.Depth"/> is already at or exceeds <see cref="MaximumObjectGraphDepth"/>.</exception>
        /// <remarks>
        /// Rather than wrap the body of every <see cref="IMessagePackFormatter{T}.Deserialize"/> method,
        /// this should wrap *calls* to these methods. They need not appear in pure "thunk" methods that simply delegate the deserialization to another formatter.
        /// In this way, we can avoid repeatedly incrementing and decrementing the counter when deserializing each element of a collection.
        /// </remarks>
        public void DepthStep(ref MessagePackReader reader)
        {
            if (reader.Depth >= this.MaximumObjectGraphDepth)
            {
                throw new InsufficientExecutionStackException($"This msgpack sequence has an object graph that exceeds the maximum depth allowed of {MaximumObjectGraphDepth}.");
            }

            reader.Depth++;
        }

        /// <summary>
        /// Returns a hash collision resistant equality comparer.
        /// </summary>
        /// <returns>A hash collision resistant equality comparer.</returns>
        protected virtual IEqualityComparer GetHashCollisionResistantEqualityComparer() => (IEqualityComparer)this.GetHashCollisionResistantEqualityComparer<object>();

        /// <summary>
        /// Creates a new instance that is a copy of this one.
        /// </summary>
        /// <remarks>
        /// Derived types should override this method to instantiate their own derived type.
        /// </remarks>
        protected virtual MessagePackSecurity Clone() => new MessagePackSecurity(this);

        /// <summary>
        /// A hash collision resistant implementation of <see cref="IEqualityComparer{T}"/>.
        /// </summary>
        /// <typeparam name="T">The type of key that will be hashed.</typeparam>
        private class CollisionResistantHasher<T> : IEqualityComparer<T>, IEqualityComparer
        {
            internal static readonly CollisionResistantHasher<T> Instance = new CollisionResistantHasher<T>();

            public bool Equals(T x, T y) => EqualityComparer<T>.Default.Equals(x, y);

            bool IEqualityComparer.Equals(object x, object y) => ((IEqualityComparer)EqualityComparer<T>.Default).Equals(x, y);

            public int GetHashCode(object obj) => this.GetHashCode((T)obj);

            public virtual int GetHashCode(T value) => HashCode.Combine(value);
        }

        /// <summary>
        /// A special hash-resistent equality comparer that defers picking the actual implementation
        /// till it can check the runtime type of each value to be hashed.
        /// </summary>
        private class ObjectFallbackEqualityComparer : IEqualityComparer<object>, IEqualityComparer
        {
            private static readonly Lazy<MethodInfo> GetHashCollisionResistantEqualityComparerOpenGenericMethod = new Lazy<MethodInfo>(() => typeof(MessagePackSecurity).GetTypeInfo().DeclaredMethods.Single(m => m.Name == nameof(MessagePackSecurity.GetHashCollisionResistantEqualityComparer) && m.IsGenericMethod));
            private readonly MessagePackSecurity security;
            private readonly ThreadsafeTypeKeyHashTable<IEqualityComparer> equalityComparerCache = new ThreadsafeTypeKeyHashTable<IEqualityComparer>();

            internal ObjectFallbackEqualityComparer(MessagePackSecurity security)
            {
                this.security = security ?? throw new ArgumentNullException(nameof(security));
            }

            bool IEqualityComparer<object>.Equals(object x, object y) => EqualityComparer<object>.Default.Equals(x, y);

            bool IEqualityComparer.Equals(object x, object y) => ((IEqualityComparer)EqualityComparer<object>.Default).Equals(x, y);

            public int GetHashCode(object value)
            {
                if (value is null)
                {
                    return 0;
                }

                Type valueType = value.GetType();

                // Take care to avoid recursion.
                if (valueType == typeof(object))
                {
                    // We can trust object.GetHashCode() to be collision resistant.
                    return value.GetHashCode();
                }

                if (!equalityComparerCache.TryGetValue(valueType, out IEqualityComparer equalityComparer))
                {
                    try
                    {
                        equalityComparer = (IEqualityComparer)GetHashCollisionResistantEqualityComparerOpenGenericMethod.Value.MakeGenericMethod(valueType).Invoke(this.security, Array.Empty<object>());
                    }
                    catch (TargetInvocationException ex)
                    {
                        ExceptionDispatchInfo.Capture(ex.InnerException).Throw();
                    }

                    equalityComparerCache.TryAdd(valueType, equalityComparer);
                }

                return equalityComparer.GetHashCode(value);
            }
        }

        private class UInt64EqualityComparer : CollisionResistantHasher<ulong>
        {
            internal static new readonly UInt64EqualityComparer Instance = new UInt64EqualityComparer();

            public override int GetHashCode(ulong value) => HashCode.Combine((uint)(value >> 32), unchecked((uint)value));
        }

        private class Int64EqualityComparer : CollisionResistantHasher<long>
        {
            internal static new readonly Int64EqualityComparer Instance = new Int64EqualityComparer();

            public override int GetHashCode(long value) => HashCode.Combine((int)(value >> 32), unchecked((int)value));
        }

        private class SingleEqualityComparer : CollisionResistantHasher<float>
        {
            internal static new readonly SingleEqualityComparer Instance = new SingleEqualityComparer();

            public override unsafe int GetHashCode(float value)
            {
                // Special check for 0.0 so that the hash of 0.0 and -0.0 will equal.
                if (value == 0.0f)
                {
                    return HashCode.Combine(0);
                }

                // Standardize on the binary representation of NaN prior to hashing.
                if (float.IsNaN(value))
                {
                    value = float.NaN;
                }

                int l = *(int*)&value;
                return l;
            }
        }

        private class DoubleEqualityComparer : CollisionResistantHasher<double>
        {
            internal static new readonly DoubleEqualityComparer Instance = new DoubleEqualityComparer();

            public override unsafe int GetHashCode(double value)
            {
                // Special check for 0.0 so that the hash of 0.0 and -0.0 will equal.
                if (value == 0.0)
                {
                    return HashCode.Combine(0);
                }

                // Standardize on the binary representation of NaN prior to hashing.
                if (double.IsNaN(value))
                {
                    value = double.NaN;
                }

                long l = *(long*)&value;
                return HashCode.Combine((int)(l >> 32), unchecked((int)l));
            }
        }

        private class GuidEqualityComparer : CollisionResistantHasher<Guid>
        {
            internal static new readonly GuidEqualityComparer Instance = new GuidEqualityComparer();

            public override unsafe int GetHashCode(Guid value)
            {
                var hash = default(HashCode);
                int* pGuid = (int*)&value;
                for (int i = 0; i < sizeof(Guid) / sizeof(int); i++)
                {
                    hash.Add(pGuid[i]);
                }

                return hash.ToHashCode();
            }
        }

        private class StringEqualityComparer : CollisionResistantHasher<string>
        {
            internal static new readonly StringEqualityComparer Instance = new StringEqualityComparer();

            public override int GetHashCode(string value)
            {
#if NETCOREAPP
                // .NET Core already has a secure string hashing function. Just use it.
                return value?.GetHashCode() ?? 0;
#else
                var hash = default(HashCode);
                for (int i = 0; i < value.Length; i++)
                {
                    hash.Add(value[i]);
                }

                return hash.ToHashCode();
#endif
            }
        }

        private class DateTimeEqualityComparer : CollisionResistantHasher<DateTime>
        {
            internal static new readonly DateTimeEqualityComparer Instance = new DateTimeEqualityComparer();

            public override unsafe int GetHashCode(DateTime value) => HashCode.Combine((int)(value.Ticks >> 32), unchecked((int)value.Ticks), value.Kind);
        }

        private class DateTimeOffsetEqualityComparer : CollisionResistantHasher<DateTimeOffset>
        {
            internal static new readonly DateTimeOffsetEqualityComparer Instance = new DateTimeOffsetEqualityComparer();

            public override unsafe int GetHashCode(DateTimeOffset value) => HashCode.Combine((int)(value.UtcTicks >> 32), unchecked((int)value.UtcTicks));
        }
    }
}