Welcome to mirror list, hosted at ThFree Co, Russian Federation.

ClientCertLoader.cs « RequestProcessing « src « HttpSys « Servers « src - github.com/dotnet/aspnetcore.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 7243ad6dc22f9061c946fcd403ba72bc8e143cbb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Diagnostics.Contracts;
using System.Runtime.InteropServices;
using System.Security;
using System.Security.Authentication.ExtendedProtection;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using Microsoft.AspNetCore.HttpSys.Internal;
using Microsoft.Extensions.Logging;

namespace Microsoft.AspNetCore.Server.HttpSys;

// Loads the TLS client certificate of a request's connection on demand via HttpReceiveClientCertificate.
// Because client certificates are optional, every failure is handled internally and surfaced through
// ClientCertException or ClientCertError; the Task returned by LoadClientCertificateAsync always
// completes successfully.
internal sealed unsafe partial class ClientCertLoader : IAsyncResult, IDisposable
{
    // Initial size (bytes) of the HTTP_SSL_CLIENT_CERT_INFO buffer; it is grown on ERROR_MORE_DATA.
    // (Sic: "Bobl" is a historical typo for "Blob"; the name is kept unchanged.)
    private const uint CertBoblSize = 1500;
    private static readonly IOCompletionCallback IOCallback = new IOCompletionCallback(WaitCallback);
    private static readonly int RequestChannelBindStatusSize =
        Marshal.SizeOf<HttpApiTypes.HTTP_REQUEST_CHANNEL_BIND_STATUS>();

    private SafeNativeOverlapped? _overlapped;
    // Managed storage behind _memoryBlob; handed to AllocateNativeOverlapped as pinData so its address is stable.
    private byte[]? _backingBuffer;
    private HttpApiTypes.HTTP_SSL_CLIENT_CERT_INFO* _memoryBlob;
    private uint _size;
    private readonly TaskCompletionSource<object?> _tcs;
    private readonly RequestContext _requestContext;

    private int _clientCertError;
    private X509Certificate2? _clientCert;
    private Exception? _clientCertException;
    private readonly CancellationTokenRegistration _cancellationRegistration;

    /// <summary>
    /// Creates a loader for the connection of <paramref name="requestContext"/> and allocates the
    /// initial native buffer and overlapped structure.
    /// </summary>
    /// <param name="requestContext">The request whose connection's client certificate will be received.</param>
    /// <param name="cancellationToken">
    /// When cancelable, registered through the request context's RegisterForCancellation;
    /// the registration is released when this loader is disposed.
    /// </param>
    internal ClientCertLoader(RequestContext requestContext, CancellationToken cancellationToken)
    {
        _requestContext = requestContext;
        _tcs = new TaskCompletionSource<object?>();
        // Allocate the buffer and the overlapped structure used to issue the async receive.
        Reset(CertBoblSize);

        if (cancellationToken.CanBeCanceled)
        {
            _cancellationRegistration = RequestContext.RegisterForCancellation(cancellationToken);
        }
    }

    internal SafeHandle RequestQueueHandle => _requestContext.Server.RequestQueue.Handle;

    /// <summary>The received client certificate, or null if none was received. Valid only after the load completed.</summary>
    internal X509Certificate2? ClientCert
    {
        get
        {
            Contract.Assert(Task.IsCompleted);
            return _clientCert;
        }
    }

    /// <summary>The CertFlags value reported by http.sys for the received certificate (0 when none). Valid only after completion.</summary>
    internal int ClientCertError
    {
        get
        {
            Contract.Assert(Task.IsCompleted);
            return _clientCertError;
        }
    }

    /// <summary>The exception that stopped the load, or null. Valid only after completion.</summary>
    internal Exception? ClientCertException
    {
        get
        {
            Contract.Assert(Task.IsCompleted);
            return _clientCertException;
        }
    }

    private RequestContext RequestContext => _requestContext;

    private Task Task => _tcs.Task;

    private SafeNativeOverlapped? NativeOverlapped => _overlapped;

    private HttpApiTypes.HTTP_SSL_CLIENT_CERT_INFO* RequestBlob => _memoryBlob;

    /// <summary>
    /// (Re)allocates the receive buffer and its overlapped structure for <paramref name="size"/> bytes.
    /// A no-op when the size is unchanged; any previously allocated overlapped is disposed first.
    /// A size of 0 releases the buffer without allocating a new one.
    /// </summary>
    private void Reset(uint size)
    {
        if (size == _size)
        {
            return;
        }
        if (_size != 0)
        {
            _overlapped!.Dispose();
        }
        _size = size;
        if (size == 0)
        {
            _overlapped = null;
            _memoryBlob = null;
            _backingBuffer = null;
            return;
        }
        _backingBuffer = new byte[checked((int)size)];
        var boundHandle = RequestContext.Server.RequestQueue.BoundHandle;
        // Passing the buffer as pinData keeps it pinned for the lifetime of the native overlapped,
        // which is what makes UnsafeAddrOfPinnedArrayElement below valid.
        _overlapped = new SafeNativeOverlapped(boundHandle,
            boundHandle.AllocateNativeOverlapped(IOCallback, this, _backingBuffer));
        _memoryBlob = (HttpApiTypes.HTTP_SSL_CLIENT_CERT_INFO*)Marshal.UnsafeAddrOfPinnedArrayElement(_backingBuffer, 0);
    }

    // When HTTP.SYS is configured with netsh clientcertnegotiation = enable, it requests the
    // client certificate during the initial SSL handshake.
    //
    // Some apps do not want to negotiate the client cert up front (e.g. when serving default.htm).
    // They configure clientcertnegotiation = disable, so the certificate is optional and HTTP.SYS
    // does not ask for it when SSL is established. If the app later needs a certificate (e.g. for a
    // "YOUR ORDERS" page), the server must negotiate it. Calling HttpReceiveClientCertificate makes
    // HTTP.SYS perform the SEC_I_RENEGOTIATE through which the client cert demand is made.
    //
    // NOTE: HttpReceiveClientCertificate can return ERROR_NOT_FOUND, which means the client did not
    // provide a cert. If that matters, the server should respond with 403 Forbidden itself;
    // HTTP.SYS will not do this automatically.
    /// <summary>
    /// Starts receiving the client certificate. The returned Task always completes successfully;
    /// inspect ClientCert, ClientCertError and ClientCertException afterwards.
    /// </summary>
    internal Task LoadClientCertificateAsync()
    {
        uint size = CertBoblSize;
        bool retry;
        do
        {
            retry = false;
            uint bytesReceived = 0;

            uint statusCode =
                HttpApi.HttpReceiveClientCertificate(
                    RequestQueueHandle,
                    RequestContext.Request.UConnectionId,
                    (uint)HttpApiTypes.HTTP_FLAGS.NONE,
                    RequestBlob,
                    size,
                    &bytesReceived,
                    NativeOverlapped!);

            if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_MORE_DATA)
            {
                // bytesReceived only covers the fixed cert-info structure, so the encoded
                // certificate length must be added to get the required buffer size.
                HttpApiTypes.HTTP_SSL_CLIENT_CERT_INFO* pClientCertInfo = RequestBlob;
                size = bytesReceived + pClientCertInfo->CertEncodedSize;
                Reset(size);
                retry = true;
            }
            else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_NOT_FOUND)
            {
                // The client did not send a cert.
                Complete(0, null);
            }
            else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS &&
                HttpSysListener.SkipIOCPCallbackOnSuccess)
            {
                // No IOCP callback will arrive for a synchronous success in this mode; process inline.
                IOCompleted(statusCode, bytesReceived);
            }
            else if (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS &&
                statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING)
            {
                // Some other bad error, possible(?) return values are:
                // ERROR_INVALID_HANDLE, ERROR_INSUFFICIENT_BUFFER, ERROR_OPERATION_ABORTED
                // Also ERROR_BAD_DATA if we got it twice or it reported smaller size buffer required.
                Fail(new HttpSysException((int)statusCode));
            }
            // Otherwise (ERROR_IO_PENDING, or ERROR_SUCCESS with IOCP callbacks enabled) the
            // result is delivered to WaitCallback.
        }
        while (retry);

        return Task;
    }

    // Records the outcome (cert may be null), releases native resources, then completes the Task.
    private void Complete(int certErrors, X509Certificate2? cert)
    {
        _clientCert = cert;
        _clientCertError = certErrors;
        Dispose();
        _tcs.TrySetResult(null);
    }

    // Records the failure, releases native resources, then completes the Task (successfully, by design).
    private void Fail(Exception ex)
    {
        // TODO: Log
        _clientCertException = ex;
        Dispose();
        _tcs.TrySetResult(null);
    }

    private unsafe void IOCompleted(uint errorCode, uint numBytes)
    {
        IOCompleted(this, errorCode, numBytes);
    }

    /// <summary>
    /// Handles the result of a receive: re-issues it with a larger buffer on ERROR_MORE_DATA,
    /// otherwise completes or fails the loader. Every terminal path ends in Complete or Fail so the
    /// Task can never be left pending.
    /// </summary>
    [SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "Redirected to callback")]
    private static unsafe void IOCompleted(ClientCertLoader asyncResult, uint errorCode, uint numBytes)
    {
        RequestContext requestContext = asyncResult.RequestContext;
        try
        {
            if (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_MORE_DATA)
            {
                // There is a bug that has existed in http.sys since w2k3.  Bytesreceived will only
                // return the size of the initial cert structure.  To get the full size,
                // we need to add the certificate encoding size as well.
                HttpApiTypes.HTTP_SSL_CLIENT_CERT_INFO* pClientCertInfo = asyncResult.RequestBlob;
                asyncResult.Reset(numBytes + pClientCertInfo->CertEncodedSize);

                uint bytesReceived = 0;
                errorCode =
                    HttpApi.HttpReceiveClientCertificate(
                        requestContext.Server.RequestQueue.Handle,
                        requestContext.Request.UConnectionId,
                        (uint)HttpApiTypes.HTTP_FLAGS.NONE,
                        asyncResult._memoryBlob,
                        asyncResult._size,
                        &bytesReceived,
                        asyncResult._overlapped!);

                // The result of the re-issued receive will arrive through another callback.
                if (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_IO_PENDING ||
                   (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS && !HttpSysListener.SkipIOCPCallbackOnSuccess))
                {
                    return;
                }
            }

            if (errorCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_NOT_FOUND)
            {
                // The client did not send a cert.
                asyncResult.Complete(0, null);
            }
            else if (errorCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS)
            {
                asyncResult.Fail(new HttpSysException((int)errorCode));
            }
            else
            {
                HttpApiTypes.HTTP_SSL_CLIENT_CERT_INFO* pClientCertInfo = asyncResult._memoryBlob;
                if (pClientCertInfo == null || pClientCertInfo->pCertEncoded == null)
                {
                    // No certificate data. A blob without encoded certificate bytes must also complete
                    // the loader; otherwise the Task would never finish.
                    asyncResult.Complete(0, null);
                }
                else
                {
                    // Copy the encoded certificate out of the native buffer before Complete disposes it.
                    byte[] certEncoded = new byte[pClientCertInfo->CertEncodedSize];
                    Marshal.Copy((IntPtr)pClientCertInfo->pCertEncoded, certEncoded, 0, certEncoded.Length);
                    asyncResult.Complete((int)pClientCertInfo->CertFlags, new X509Certificate2(certEncoded));
                }
            }
        }
        catch (Exception exception)
        {
            // Includes CryptographicException / SecurityException thrown while parsing the certificate.
            // TODO: Log
            asyncResult.Fail(exception);
        }
    }

    // IOCP entry point: recovers the loader stored as the overlapped's state object.
    private static unsafe void WaitCallback(uint errorCode, uint numBytes, NativeOverlapped* nativeOverlapped)
    {
        var asyncResult = (ClientCertLoader)ThreadPoolBoundHandle.GetNativeOverlappedState(nativeOverlapped)!;
        IOCompleted(asyncResult, errorCode, numBytes);
    }

    /// <summary>Releases the cancellation registration and the native overlapped structure.</summary>
    public void Dispose()
    {
        Dispose(true);
    }

    private void Dispose(bool disposing)
    {
        if (disposing)
        {
            _cancellationRegistration.Dispose();
            if (_overlapped != null)
            {
                // Clear the raw pointer first so it is not used after the overlapped is released.
                _memoryBlob = null;
                _overlapped.Dispose();
            }
        }
    }

    public object? AsyncState => _tcs.Task.AsyncState;

    public WaitHandle AsyncWaitHandle => ((IAsyncResult)_tcs.Task).AsyncWaitHandle;

    public bool CompletedSynchronously => ((IAsyncResult)_tcs.Task).CompletedSynchronously;

    public bool IsCompleted => _tcs.Task.IsCompleted;

    /// <summary>
    /// Synchronously retrieves the TLS channel binding token for <paramref name="connectionId"/>.
    /// </summary>
    /// <returns>
    /// The channel binding, or null when it is unsupported (ERROR_INVALID_PARAMETER from an old
    /// schannel) or could not be retrieved (any other error, which is logged).
    /// </returns>
    internal static unsafe ChannelBinding? GetChannelBindingFromTls(RequestQueue requestQueue, ulong connectionId, ILogger logger)
    {
        // +128 since a CBT is usually <128 thus we need to call HRCC just once. If the CBT
        // is >128 we will get ERROR_MORE_DATA and call again
        int size = RequestChannelBindStatusSize + 128;

        Debug.Assert(size >= 0);

        SafeLocalFreeChannelBinding? token = null;

        uint bytesReceived = 0;
        uint statusCode;

        do
        {
            byte[] blob = new byte[size];
            fixed (byte* blobPtr = blob)
            {
                // Http.sys team: ServiceName will always be null if
                // HTTP_RECEIVE_SECURE_CHANNEL_TOKEN flag is set.
                statusCode = HttpApi.HttpReceiveClientCertificate(
                    requestQueue.Handle,
                    connectionId,
                    (uint)HttpApiTypes.HTTP_FLAGS.HTTP_RECEIVE_SECURE_CHANNEL_TOKEN,
                    blobPtr,
                    (uint)size,
                    &bytesReceived,
                    SafeNativeOverlapped.Zero);

                if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS)
                {
                    // The token bytes live inside the same blob; copy them into a LocalAlloc'd handle.
                    int tokenOffset = GetTokenOffsetFromBlob((IntPtr)blobPtr);
                    int tokenSize = GetTokenSizeFromBlob((IntPtr)blobPtr);
                    Debug.Assert(tokenSize < Int32.MaxValue);

                    token = SafeLocalFreeChannelBinding.LocalAlloc(tokenSize);

                    Marshal.Copy(blob, tokenOffset, token.DangerousGetHandle(), tokenSize);
                }
                else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_MORE_DATA)
                {
                    // The status header reports the required token size; retry with a big enough buffer.
                    int tokenSize = GetTokenSizeFromBlob((IntPtr)blobPtr);
                    Debug.Assert(tokenSize < Int32.MaxValue);

                    size = RequestChannelBindStatusSize + tokenSize;
                }
                else if (statusCode == UnsafeNclNativeMethods.ErrorCodes.ERROR_INVALID_PARAMETER)
                {
                    Log.ChannelBindingUnsupported(logger);
                    return null; // old schannel library which doesn't support CBT
                }
                else
                {
                    // It's up to the consumer to fail if the missing ChannelBinding matters to them.
                    Log.ChannelBindingMissing(logger, new HttpSysException((int)statusCode));
                    break;
                }
            }
        }
        while (statusCode != UnsafeNclNativeMethods.ErrorCodes.ERROR_SUCCESS);

        return token;
    }

    // Byte offset of the channel token within the blob, computed from the ChannelToken pointer field.
    private static int GetTokenOffsetFromBlob(IntPtr blob)
    {
        Debug.Assert(blob != IntPtr.Zero);
        IntPtr tokenPointer = Marshal.ReadIntPtr(blob, (int)Marshal.OffsetOf<HttpApiTypes.HTTP_REQUEST_CHANNEL_BIND_STATUS>("ChannelToken"));
        Debug.Assert(tokenPointer != IntPtr.Zero);
        return (int)IntPtrHelper.Subtract(tokenPointer, blob);
    }

    // Reads the ChannelTokenSize field of the HTTP_REQUEST_CHANNEL_BIND_STATUS header at the start of the blob.
    private static int GetTokenSizeFromBlob(IntPtr blob)
    {
        Debug.Assert(blob != IntPtr.Zero);
        return Marshal.ReadInt32(blob, (int)Marshal.OffsetOf<HttpApiTypes.HTTP_REQUEST_CHANNEL_BIND_STATUS>("ChannelTokenSize"));
    }
}