From 8735b752c6f7038c0c2c3b5e257f54fbd28e40f8 Mon Sep 17 00:00:00 2001 From: Safia Abdalla Date: Tue, 19 Apr 2022 15:11:36 -0700 Subject: Capture target from delegates and MethodInfo (#41253) * Capture target from delegates and MethodInfo * Address feedback from peer review * Fix up tests and targetFactory invocation * Fix build error in tests * Validate delegate invoked in tests --- .../Http.Extensions/src/RequestDelegateFactory.cs | 25 +++-- .../test/RequestDelegateFactoryTests.cs | 120 +++++++++++++++++++++ 2 files changed, 137 insertions(+), 8 deletions(-) diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 7b25889d0c..edd832b148 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -121,7 +121,9 @@ public static partial class RequestDelegateFactory var factoryContext = CreateFactoryContext(options); - var targetableRequestDelegate = CreateTargetableRequestDelegate(handler.Method, targetExpression, factoryContext); + Expression> targetFactory = (httpContext) => handler.Target; + + var targetableRequestDelegate = CreateTargetableRequestDelegate(handler.Method, targetExpression, factoryContext, targetFactory); return new RequestDelegateResult(httpContext => targetableRequestDelegate(handler.Target, httpContext), factoryContext.Metadata); } @@ -162,7 +164,7 @@ public static partial class RequestDelegateFactory } var targetExpression = Expression.Convert(TargetExpr, methodInfo.DeclaringType); - var targetableRequestDelegate = CreateTargetableRequestDelegate(methodInfo, targetExpression, factoryContext); + var targetableRequestDelegate = CreateTargetableRequestDelegate(methodInfo, targetExpression, factoryContext, context => targetFactory(context)); return new RequestDelegateResult(httpContext => targetableRequestDelegate(targetFactory(httpContext), httpContext), factoryContext.Metadata); } @@ -187,7 +189,7 @@ public static partial class RequestDelegateFactory return context; } - private static Func CreateTargetableRequestDelegate(MethodInfo methodInfo, Expression? targetExpression, FactoryContext factoryContext) + private static Func CreateTargetableRequestDelegate(MethodInfo methodInfo, Expression? targetExpression, FactoryContext factoryContext, Expression>? targetFactory = null) { // Non void return type @@ -223,7 +225,7 @@ public static partial class RequestDelegateFactory // return type associated with the request to allow for the filter invocation pipeline. if (factoryContext.Filters is { Count: > 0 }) { - var filterPipeline = CreateFilterPipeline(methodInfo, targetExpression, factoryContext); + var filterPipeline = CreateFilterPipeline(methodInfo, targetExpression, factoryContext, targetFactory); Expression>> invokePipeline = (context) => filterPipeline(context); returnType = typeof(ValueTask); // var filterContext = new RouteHandlerInvocationContext(httpContext, new[] { (object)name_local, (object)int_local }); @@ -250,22 +252,29 @@ public static partial class RequestDelegateFactory return HandleRequestBodyAndCompileRequestDelegate(responseWritingMethodCall, factoryContext); } - private static RouteHandlerFilterDelegate CreateFilterPipeline(MethodInfo methodInfo, Expression? target, FactoryContext factoryContext) + private static RouteHandlerFilterDelegate CreateFilterPipeline(MethodInfo methodInfo, Expression? targetExpression, FactoryContext factoryContext, Expression>? targetFactory) { Debug.Assert(factoryContext.Filters is not null); // httpContext.Response.StatusCode >= 400 // ? Task.CompletedTask - // : handler((string)context.Parameters[0], (int)context.Parameters[1]) + // : { + // target = targetFactory(httpContext); + // handler is ((Type)target).MethodName(parameters); + // handler((string)context.Parameters[0], (int)context.Parameters[1]); + // } var filteredInvocation = Expression.Lambda( Expression.Condition( Expression.GreaterThanOrEqual(FilterContextHttpContextStatusCodeExpr, Expression.Constant(400)), CompletedValueTaskExpr, Expression.Block( new[] { TargetExpr }, + targetFactory == null + ? Expression.Empty() + : Expression.Assign(TargetExpr, Expression.Invoke(targetFactory, FilterContextHttpContextExpr)), Expression.Call(WrapObjectAsValueTaskMethod, - target is null + targetExpression is null ? Expression.Call(methodInfo, factoryContext.ContextArgAccess) - : Expression.Call(target, methodInfo, factoryContext.ContextArgAccess)) + : Expression.Call(targetExpression, methodInfo, factoryContext.ContextArgAccess)) )), FilterContextExpr).Compile(); var routeHandlerContext = new RouteHandlerContext( diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 5501077e54..727219b43b 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -4251,6 +4251,126 @@ public class RequestDelegateFactoryTests : LoggedTest Assert.Equal(400, httpContext.Response.StatusCode); } + [Fact] + public async Task RequestDelegateFactory_InvokesFilters_OnDelegateWithTarget() + { + // Arrange + var httpContext = CreateHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["name"] = "TestName" + }); + + // Act + var factoryResult = RequestDelegateFactory.Create((string name) => $"Hello, {name}!", new RequestDelegateFactoryOptions() + { + RouteHandlerFilterFactories = new List>() + { + (routeHandlerContext, next) => async (context) => + { + return await next(context); + } + } + }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + + Assert.Equal(200, httpContext.Response.StatusCode); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal("Hello, TestName!", decodedResponseBody); + } + + string GetString(string name) + { + return $"Hello, {name}!"; + } + + [Fact] + public async Task RequestDelegateFactory_InvokesFilters_OnMethodInfoWithNullTargetFactory() + { + // Arrange + var methodInfo = typeof(RequestDelegateFactoryTests).GetMethod( + nameof(GetString), + BindingFlags.NonPublic | BindingFlags.Instance, + new[] { typeof(string) }); + var httpContext = CreateHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["name"] = "TestName" + }); + + // Act + var factoryResult = RequestDelegateFactory.Create(methodInfo!, null, new RequestDelegateFactoryOptions() + { + RouteHandlerFilterFactories = new List>() + { + (routeHandlerContext, next) => async (context) => + { + return await next(context); + } + } + }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + Assert.Equal(200, httpContext.Response.StatusCode); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal("Hello, TestName!", decodedResponseBody); + } + + [Fact] + public async Task RequestDelegateFactory_InvokesFilters_OnMethodInfoWithProvidedTargetFactory() + { + // Arrange + var invoked = false; + var methodInfo = typeof(RequestDelegateFactoryTests).GetMethod( + nameof(GetString), + BindingFlags.NonPublic | BindingFlags.Instance, + new[] { typeof(string) }); + var httpContext = CreateHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + httpContext.Request.Query = new QueryCollection(new Dictionary + { + ["name"] = "TestName" + }); + + // Act + Func targetFactory = (context) => + { + invoked = true; + context.Items["invoked"] = true; + return Activator.CreateInstance(methodInfo!.DeclaringType!)!; + }; + var factoryResult = RequestDelegateFactory.Create(methodInfo!, targetFactory, new RequestDelegateFactoryOptions() + { + RouteHandlerFilterFactories = new List>() + { + (routeHandlerContext, next) => async (context) => + { + return await next(context); + } + } + }); + var requestDelegate = factoryResult.RequestDelegate; + await requestDelegate(httpContext); + + // Assert + Assert.Equal(200, httpContext.Response.StatusCode); + Assert.True(invoked); + var invokedInContext = Assert.IsType(httpContext.Items["invoked"]); + Assert.True(invokedInContext); + var decodedResponseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + Assert.Equal("Hello, TestName!", decodedResponseBody); + } + [Fact] public async Task RequestDelegateFactory_CanInvokeSingleEndpointFilter_ThatProvidesCustomErrorMessage() { -- cgit v1.2.3