How to get the Where clause from an IQueryable returned through an interface

  class Program
{
    /// <summary>
    /// First example: fills a SampleClass with two ClassString rows, composes a Where filter on the
    /// query returned by Query&lt;ClassString&gt;(), writes the query object to the console and waits
    /// for a line of input.
    /// </summary>
    static void Main(string[] args)
    {
        // Nested collection initializer: calls Add on the list created by the constructor.
        var sample = new SampleClass<ClassString>
        {
            ClassStrings =
            {
                new ClassString { Name1 = "1", Name2 = "1" },
                new ClassString { Name1 = "2", Name2 = "2" }
            }
        };

        IQueryable<ClassString> filtered = sample.Query<ClassString>()
            .Where(item => item.Name1.Equals("2"));

        Console.WriteLine(filtered);
        Console.ReadLine();
    }
}

/// <summary>
/// Simple two-string row used by the first example's in-memory query.
/// </summary>
public class ClassString
{
    /// <summary>First value; the example query filters on it.</summary>
    public string Name1 { get; set; }

    /// <summary>Second value; not referenced by the example query.</summary>
    public string Name2 { get; set; }
}



/// <summary>
/// Common query entry point shared by the example data sources.
/// </summary>
public interface ISampleQ
{
    /// <summary>Returns a composable query over items of type <typeparamref name="T"/>.</summary>
    /// <typeparam name="T">Reference type with a public parameterless constructor.</typeparam>
    IQueryable<T> Query<T>() where T : class, new();
}
/// <summary>
/// In-memory <see cref="ISampleQ"/> implementation backed by a List&lt;X&gt;.
/// </summary>
/// <typeparam name="X">Element type stored in <see cref="ClassStrings"/>.</typeparam>
public class SampleClass<X> : ISampleQ
{
    /// <summary>Backing items; the constructor creates the list empty.</summary>
    public List<X> ClassStrings { get; private set; }

    public SampleClass()
    {
        ClassStrings = new List<X>();
    }

    /// <summary>
    /// Exposes <see cref="ClassStrings"/> as an in-memory IQueryable&lt;T&gt;.
    /// </summary>
    /// <exception cref="InvalidCastException">List&lt;X&gt; cannot be viewed as IEnumerable&lt;T&gt;.</exception>
    public IQueryable<T> Query<T>() where T : class, new()
    {
        // A caller's .Where(...) is not visible here: it is appended afterwards to the expression tree
        // of the query returned below. To observe it, inspect that query's Expression or supply a
        // custom IQueryProvider.
        IEnumerable<T> items = (IEnumerable<T>)ClassStrings;
        return items.AsQueryable();
    }
}

      

I have looked at solution1, solution2 and solution3, but they don't seem to apply to my question, because the Where clause is applied outside the class, through the interface the class implements. How do I get at that expression inside the Query method, when nothing is passed to it?

Purpose: I want to retrieve the predicate and re-apply it to the real destination (a DbContext exposed as IQueryable), since all our data sources share a common interface such as ISampleQ.


I added a new code example for the same scenario:

 internal class Program
{
    /// <summary>
    /// Second example: composes a Where filter on the query returned by
    /// OracleDbContext.Query&lt;Person&gt;(), writes the query object to the console and waits for a
    /// line of input.
    /// </summary>
    private static void Main(string[] args)
    {
        var oracleDbContext = new OracleDbContext();
        var result = oracleDbContext.Query<Person>().Where(person => person.Name.Equals("username"));

        // Print the composed query (as the first example does) so that `result` is actually observed.
        Console.WriteLine(result);
        Console.ReadLine();
    }
}

/// <summary>
/// Query entry point shared by the Oracle and MSSQL contexts of the second example.
/// </summary>
public interface IGenericQuery
{
    /// <summary>Returns a composable query over entities of type <typeparamref name="T"/>.</summary>
    /// <typeparam name="T">Reference type with a public parameterless constructor.</typeparam>
    IQueryable<T> Query<T>() where T : class , new();
}

/// <summary>
/// Oracle implementation of <see cref="IGenericQuery"/> (placeholder in this example).
/// </summary>
public class OracleDbContext : IGenericQuery
{
    public OracleDbContext()
    {
        //Will hold all oracle operations here. For brevity, only
        //Query are exposed.
    }

    /// <summary>
    /// Returns the root query for <typeparamref name="T"/>.
    /// </summary>
    /// <returns>
    /// Currently an empty in-memory query. It is never null, so callers can compose operators such
    /// as Where(...) on it.
    /// </returns>
    public IQueryable<T> Query<T>() where T : class, new()
    {
        //Get the where predicate here. Since the where was defined outside of the
        //class. I want to retrieve since the IQueryable<T> is generic to both class
        //OracleDbContext and MssqlDbContext. I want to re-inject the where or add 
        //new expression before calling.
        //
        //For eg.
        //oracleDbContext.Query<T>(where clause from here)

        // Placeholder until a real data source exists. Returning null made Queryable.Where(...) in
        // callers throw ArgumentNullException("source"); an empty query composes safely.
        return Enumerable.Empty<T>().AsQueryable();
    }
}

/// <summary>
/// MSSQL implementation of <see cref="IGenericQuery"/> (placeholder in this example).
/// </summary>
public class MssqlDbContext : IGenericQuery
{
    public MssqlDbContext()
    {
        //Will hold all MSSQL operations here. For brevity, only
        //Query are exposed.
    }

    /// <summary>
    /// Returns the root query for <typeparamref name="T"/>.
    /// </summary>
    /// <returns>
    /// Currently an empty in-memory query. It is never null, so callers can compose operators such
    /// as Where(...) on it.
    /// </returns>
    public IQueryable<T> Query<T>() where T : class, new()
    {
        //Get the where predicate here.

        // Placeholder: an empty query instead of null, so Query<T>().Where(...) does not throw
        // ArgumentNullException before a real data source is wired in.
        return Enumerable.Empty<T>().AsQueryable();
    }
}

/// <summary>
/// Sample entity used by the second example.
/// </summary>
public class Person
{
    /// <summary>Numeric identifier.</summary>
    public int Id { get; set; }

    /// <summary>
    /// Person's name. It is a string because the example filters with
    /// person.Name.Equals("username"); declared as int, that call bound to object.Equals and could
    /// never be true.
    /// </summary>
    public string Name { get; set; }
}

      

+3


source to share


1 answer


This is fairly tricky. Queryable.Where() is implemented like this:

public static IQueryable<TSource> Where<TSource>(this IQueryable<TSource> source, Expression<Func<TSource, bool>> predicate)
{
    return source.Provider.CreateQuery<TSource>(Expression.Call(null, ... 

      

So Queryable.Where calls source.Provider.CreateQuery(), which builds a new IQueryable<>. If you want to "see" the Where() while it is being added (and manipulate it), you must "be" the IQueryable<>.Provider and have your own CreateQuery(). That means creating a class that implements IQueryProvider (and probably a class that implements IQueryable<T>).

Another (much easier) way is a simple query "converter": a method that takes an IQueryable<> and returns a modified IQueryable<>:

var result = c.Query<ClassString>().Where(s => s.Name1.Equals("2")).FixMyQuery();

      

As I said, the full solution is quite long:

namespace Utilities
{
    using System;
    using System.Collections;
    using System.Collections.Generic;
    using System.Collections.ObjectModel;
    using System.Data.Entity;
    using System.Data.Entity.Infrastructure;
    using System.Linq;
    using System.Linq.Expressions;
    using System.Reflection;
    using System.Threading;
    using System.Threading.Tasks;

    /// <summary>
    /// A <see cref="DbContext"/> whose sets are replaced by ProxyDbSet&lt;T&gt; wrappers, so that query
    /// expressions pass through <see cref="Manipulator"/> before reaching EF's own query provider.
    /// </summary>
    public class ProxyDbContext : DbContext
    {
        /// <summary>Handle to ProxifySets&lt;TContext&gt;, closed over the runtime context type by the constructors.</summary>
        protected static readonly MethodInfo ProxifySetsMethod = typeof(ProxyDbContext).GetMethod("ProxifySets", BindingFlags.Instance | BindingFlags.NonPublic);

        /// <summary>
        /// Builds, once per context type, a compiled delegate that replaces each public property of
        /// <typeparamref name="TContext"/> that has a getter, a public setter, and a type implementing
        /// IDbSet&lt;T&gt; that can hold a DbSet&lt;T&gt; with a ProxyDbSet&lt;T&gt; wrapping the current value.
        /// </summary>
        protected static class ProxyDbContexSetter<TContext> where TContext : ProxyDbContext
        {
            /// <summary>Applies the proxies to a context instance; stays a no-op when nothing qualifies.</summary>
            public static readonly Action<TContext> Do = x => { };

            static ProxyDbContexSetter()
            {
                var properties = typeof(TContext).GetProperties(BindingFlags.Instance | BindingFlags.Public | BindingFlags.FlattenHierarchy);

                ParameterExpression context = Expression.Parameter(typeof(TContext), "context");

                FieldInfo manipulatorField = typeof(ProxyDbContext).GetField("Manipulator", BindingFlags.Instance | BindingFlags.Public);
                Expression manipulator = Expression.Field(context, manipulatorField);

                var sets = new List<Expression>();

                foreach (PropertyInfo property in properties)
                {
                    // The property is read (current set) and then assigned (the proxy), so it needs a
                    // getter AND a public setter. A get-only property must be skipped: Expression.Assign
                    // would otherwise throw inside this type initializer.
                    MethodInfo setMethod = property.SetMethod;
                    if (property.GetMethod == null || setMethod == null || !setMethod.IsPublic)
                    {
                        continue;
                    }

                    // Indexers cannot be addressed by Expression.Property(context, property).
                    if (property.GetIndexParameters().Length != 0)
                    {
                        continue;
                    }

                    Type type = property.PropertyType;
                    Type entityType = GetIDbSetTypeArgument(type);

                    if (entityType == null)
                    {
                        continue;
                    }

                    Type dbSetType = typeof(DbSet<>).MakeGenericType(entityType);

                    // The property must be able to hold a DbSet<T> (and therefore a ProxyDbSet<T>).
                    if (!type.IsAssignableFrom(dbSetType))
                    {
                        continue;
                    }

                    ConstructorInfo constructor = typeof(ProxyDbSet<>)
                        .MakeGenericType(entityType)
                        .GetConstructor(new[]
                        {
                            dbSetType,
                            typeof(Func<bool, Expression, Expression>)
                        });

                    // context.Prop = new ProxyDbSet<T>((DbSet<T>)context.Prop, context.Manipulator)
                    MemberExpression propertyAccess = Expression.Property(context, property);
                    BinaryExpression assign = Expression.Assign(propertyAccess, Expression.New(constructor, Expression.Convert(propertyAccess, dbSetType), manipulator));

                    sets.Add(assign);
                }

                // Expression.Block rejects an empty expression list, which would turn a context with
                // no proxifiable sets into a TypeInitializationException. Keep the no-op delegate instead.
                if (sets.Count == 0)
                {
                    return;
                }

                Expression<Action<TContext>> lambda = Expression.Lambda<Action<TContext>>(Expression.Block(sets), context);
                Do = lambda.Compile();
            }

            // Gets the T of IDbSet<T> (null when the type does not implement IDbSet<>; throws if it
            // implements it for more than one T, because of SingleOrDefault).
            private static Type GetIDbSetTypeArgument(Type type)
            {
                IEnumerable<Type> interfaces = type.IsInterface ?
                    new[] { type }.Concat(type.GetInterfaces()) :
                    type.GetInterfaces();

                Type argument = (from x in interfaces
                                 where x.IsGenericType
                                 let gt = x.GetGenericTypeDefinition()
                                 where gt == typeof(IDbSet<>)
                                 select x.GetGenericArguments()[0]).SingleOrDefault();
                return argument;
            }
        }

        /// <summary>Expression rewriter shared by all proxied sets; first parameter: true for Execute, false for CreateQuery.</summary>
        public readonly Func<bool, Expression, Expression> Manipulator;

        /// <summary>
        /// Creates the context through DbContext's parameterless constructor.
        /// </summary>
        /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
        /// <param name="resetSets">True to have all the DbSet&lt;TEntity&gt; and IDbSet&lt;TEntity&gt; proxified</param>
        public ProxyDbContext(Func<bool, Expression, Expression> manipulator, bool resetSets = true)
        {
            Manipulator = manipulator;
            ProxifySetsIfRequested(resetSets);
        }

        /// <summary>
        /// Creates the context through DbContext(string nameOrConnectionString).
        /// </summary>
        /// <param name="nameOrConnectionString">Passed unchanged to the DbContext constructor.</param>
        /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
        /// <param name="resetSets">True to have all the DbSet&lt;TEntity&gt; and IDbSet&lt;TEntity&gt; proxified</param>
        public ProxyDbContext(string nameOrConnectionString, Func<bool, Expression, Expression> manipulator, bool resetSets = true)
            : base(nameOrConnectionString)
        {
            Manipulator = manipulator;
            ProxifySetsIfRequested(resetSets);
        }

        // Shared constructor tail: runs ProxifySets<TContext> closed over the runtime (most derived) type,
        // so the properties declared on the user's subclass are the ones proxified.
        private void ProxifySetsIfRequested(bool resetSets)
        {
            if (resetSets)
            {
                ProxifySetsMethod.MakeGenericMethod(GetType()).Invoke(this, null);
            }
        }

        /// <summary>Applies the cached per-type setter to this instance.</summary>
        protected void ProxifySets<TContext>() where TContext : ProxyDbContext
        {
            ProxyDbContexSetter<TContext>.Do((TContext)this);
        }

        /// <summary>
        /// Returns EF's set for <typeparamref name="TEntity"/> wrapped in a new ProxyDbSet&lt;TEntity&gt;
        /// that uses <see cref="Manipulator"/>. A new wrapper is created on every call.
        /// </summary>
        public override DbSet<TEntity> Set<TEntity>()
        {
            return new ProxyDbSet<TEntity>(base.Set<TEntity>(), Manipulator);
        }

        /// <summary>
        /// Returns EF's non-generic set for <paramref name="entityType"/> wrapped (via reflection) in a
        /// ProxyDbSetNonGeneric&lt;entityType&gt; that uses <see cref="Manipulator"/>.
        /// </summary>
        public override DbSet Set(Type entityType)
        {
            DbSet set = base.Set(entityType);
            ConstructorInfo constructor = typeof(ProxyDbSetNonGeneric<>)
                .MakeGenericType(entityType)
                .GetConstructor(new[]
                {
                    typeof(DbSet),
                    typeof(Func<bool, Expression, Expression>)
                });

            return (DbSet)constructor.Invoke(new object[] { set, Manipulator });
        }
    }

    /// <summary>
    /// Non-generic counterpart of ProxyDbSet&lt;TEntity&gt;. Instances are created by ProxyDbContext.Set(Type)
    /// and by ProxyDbSet's implicit conversion. The wrapped <see cref="DbSet"/> is the one EF implements
    /// as InternalDbSet&lt;&gt;. Query members (enumeration, Expression, Provider, async enumeration) are
    /// served by a query built through <see cref="ProxyDbProvider"/>; every other member forwards to the
    /// wrapped set.
    /// </summary>
    /// <typeparam name="TEntity">Entity type of the wrapped set.</typeparam>
    public class ProxyDbSetNonGeneric<TEntity> : DbSet, IQueryable<TEntity>, IEnumerable<TEntity>, IDbAsyncEnumerable<TEntity>, IQueryable, IEnumerable, IDbAsyncEnumerable where TEntity : class
    {
        /// <summary>The EF set that receives all non-query calls.</summary>
        protected readonly DbSet BaseDbSet;

        /// <summary>The proxied query backing enumeration, Expression and Provider.</summary>
        protected readonly IQueryable<TEntity> ProxyQueryable;

        /// <summary>Expression rewriter; first parameter: true for Execute, false for CreateQuery.</summary>
        public readonly Func<bool, Expression, Expression> Manipulator;

        /// <summary>EF's private "_internalSet" field looked up on <see cref="DbSet"/>; when null, nothing is copied.</summary>
        protected readonly FieldInfo InternalSetField = typeof(DbSet).GetField("_internalSet", BindingFlags.Instance | BindingFlags.NonPublic);

        /// <summary>
        /// Wraps <paramref name="baseDbSet"/>, building the proxied query from its provider and expression.
        /// </summary>
        /// <param name="baseDbSet">The EF set to wrap.</param>
        /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
        public ProxyDbSetNonGeneric(DbSet baseDbSet, Func<bool, Expression, Expression> manipulator)
        {
            BaseDbSet = baseDbSet;

            var baseQueryable = (IQueryable)baseDbSet;
            var proxyProvider = new ProxyDbProvider(baseQueryable.Provider, manipulator);
            ProxyQueryable = proxyProvider.CreateQuery<TEntity>(baseQueryable.Expression);
            Manipulator = manipulator;

            CopyInternalSetFrom(baseDbSet);
        }

        /// <summary>
        /// Wraps <paramref name="baseDbSet"/> using an already-built proxied query.
        /// </summary>
        /// <param name="baseDbSet">The EF set to wrap.</param>
        /// <param name="proxyQueryable">Query that will back enumeration, Expression and Provider.</param>
        /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
        public ProxyDbSetNonGeneric(DbSet baseDbSet, ProxyQueryable<TEntity> proxyQueryable, Func<bool, Expression, Expression> manipulator)
        {
            BaseDbSet = baseDbSet;
            ProxyQueryable = proxyQueryable;
            Manipulator = manipulator;

            CopyInternalSetFrom(baseDbSet);
        }

        // ---- Change-tracking members: forwarded unchanged to the wrapped set ----

        public override object Add(object entity)
        {
            return BaseDbSet.Add(entity);
        }

        public override IEnumerable AddRange(IEnumerable entities)
        {
            return BaseDbSet.AddRange(entities);
        }

        public override object Attach(object entity)
        {
            return BaseDbSet.Attach(entity);
        }

        public override object Create(Type derivedEntityType)
        {
            return BaseDbSet.Create(derivedEntityType);
        }

        public override object Create()
        {
            return BaseDbSet.Create();
        }

        public override object Find(params object[] keyValues)
        {
            return BaseDbSet.Find(keyValues);
        }

        public override Task<object> FindAsync(CancellationToken cancellationToken, params object[] keyValues)
        {
            return BaseDbSet.FindAsync(cancellationToken, keyValues);
        }

        public override Task<object> FindAsync(params object[] keyValues)
        {
            return BaseDbSet.FindAsync(keyValues);
        }

        public override IList Local
        {
            get
            {
                return BaseDbSet.Local;
            }
        }

        public override object Remove(object entity)
        {
            return BaseDbSet.Remove(entity);
        }

        public override IEnumerable RemoveRange(IEnumerable entities)
        {
            return BaseDbSet.RemoveRange(entities);
        }

        public override DbSqlQuery SqlQuery(string sql, params object[] parameters)
        {
            return BaseDbSet.SqlQuery(sql, parameters);
        }

        // ---- Query-shaping members: a new proxy over a query derived from BaseDbSet ----

        public override DbQuery AsNoTracking()
        {
            return Rewrap(BaseDbSet.AsNoTracking());
        }

        [Obsolete]
        public override DbQuery AsStreaming()
        {
#pragma warning disable 618
            return Rewrap(BaseDbSet.AsStreaming());
#pragma warning restore 618
        }

        public override DbQuery Include(string path)
        {
            return Rewrap(BaseDbSet.Include(path));
        }

        // ---- LINQ surface: served by the proxied query ----

        IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
        {
            return ProxyQueryable.GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            var untyped = (IEnumerable)ProxyQueryable;
            return untyped.GetEnumerator();
        }

        Type IQueryable.ElementType
        {
            get { return ProxyQueryable.ElementType; }
        }

        Expression IQueryable.Expression
        {
            get { return ProxyQueryable.Expression; }
        }

        IQueryProvider IQueryable.Provider
        {
            get { return ProxyQueryable.Provider; }
        }

        IDbAsyncEnumerator<TEntity> IDbAsyncEnumerable<TEntity>.GetAsyncEnumerator()
        {
            var asyncSource = (IDbAsyncEnumerable<TEntity>)ProxyQueryable;
            return asyncSource.GetAsyncEnumerator();
        }

        IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
        {
            var asyncSource = (IDbAsyncEnumerable)ProxyQueryable;
            return asyncSource.GetAsyncEnumerator();
        }

        public override string ToString()
        {
            return ProxyQueryable.ToString();
        }

        // ---- Helpers ----

        // Copies the wrapped set's EF internal set onto this instance (when the field was found),
        // presumably so DbSet members not overridden here act on the same underlying set.
        private void CopyInternalSetFrom(DbSet source)
        {
            if (InternalSetField == null)
            {
                return;
            }

            object internalSet = InternalSetField.GetValue(source);
            InternalSetField.SetValue(this, internalSet);
        }

        // Wraps a query derived from BaseDbSet in a new proxy that shares this instance's provider.
        // NOTE(review): the derived query always starts from BaseDbSet, not from ProxyQueryable, so
        // chaining e.g. AsNoTracking() then Include(...) keeps only the last call's effect.
        private ProxyDbSetNonGeneric<TEntity> Rewrap(DbQuery derivedQuery)
        {
            var provider = (ProxyDbProvider)ProxyQueryable.Provider;
            var proxied = new ProxyQueryable<TEntity>(provider, (IQueryable<TEntity>)derivedQuery);
            return new ProxyDbSetNonGeneric<TEntity>(BaseDbSet, proxied, Manipulator);
        }
    }

    /// <summary>
    /// A DbSet&lt;TEntity&gt; whose query surface (enumeration, Expression, Provider and async
    /// enumeration) is served by a query created through <see cref="ProxyDbProvider"/>, so operators
    /// composed on it pass through <see cref="Manipulator"/>. Change-tracking members (Add, Attach,
    /// Create, Find, Remove, SqlQuery, Local, ...) forward unchanged to the wrapped set.
    /// </summary>
    /// <typeparam name="TEntity">Entity type of the wrapped set.</typeparam>
    public class ProxyDbSet<TEntity> : DbSet<TEntity>, IQueryable<TEntity>, IEnumerable<TEntity>, IDbAsyncEnumerable<TEntity>, IQueryable, IEnumerable, IDbAsyncEnumerable where TEntity : class
    {
        /// <summary>The EF set that receives all non-query calls.</summary>
        protected readonly DbSet<TEntity> BaseDbSet;

        /// <summary>The proxied query backing enumeration, Expression and Provider.</summary>
        protected readonly IQueryable<TEntity> ProxyQueryable;

        /// <summary>Expression rewriter; first parameter: true for Execute, false for CreateQuery.</summary>
        public readonly Func<bool, Expression, Expression> Manipulator;

        /// <summary>EF's private "_internalSet" field on DbSet&lt;TEntity&gt;; when null, nothing is copied.</summary>
        protected readonly FieldInfo InternalSetField = typeof(DbSet<TEntity>).GetField("_internalSet", BindingFlags.Instance | BindingFlags.NonPublic);

        /// <summary>
        /// Wraps <paramref name="baseDbSet"/>, building the proxied query from its provider and expression.
        /// </summary>
        /// <param name="baseDbSet">The EF set to wrap.</param>
        /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
        public ProxyDbSet(DbSet<TEntity> baseDbSet, Func<bool, Expression, Expression> manipulator)
        {
            BaseDbSet = baseDbSet;

            var baseQueryable = (IQueryable)baseDbSet;
            var proxyProvider = new ProxyDbProvider(baseQueryable.Provider, manipulator);
            ProxyQueryable = proxyProvider.CreateQuery<TEntity>(baseQueryable.Expression);
            Manipulator = manipulator;

            CopyInternalSetFrom(baseDbSet);
        }

        /// <summary>
        /// Wraps <paramref name="baseDbSet"/> using an already-built proxied query.
        /// </summary>
        /// <param name="baseDbSet">The EF set to wrap.</param>
        /// <param name="proxyQueryable">Query that will back enumeration, Expression and Provider.</param>
        /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
        public ProxyDbSet(DbSet<TEntity> baseDbSet, ProxyQueryable<TEntity> proxyQueryable, Func<bool, Expression, Expression> manipulator)
        {
            BaseDbSet = baseDbSet;
            ProxyQueryable = proxyQueryable;
            Manipulator = manipulator;

            CopyInternalSetFrom(baseDbSet);
        }

        // ---- Change-tracking members: forwarded unchanged to the wrapped set ----

        public override TEntity Add(TEntity entity)
        {
            return BaseDbSet.Add(entity);
        }

        public override IEnumerable<TEntity> AddRange(IEnumerable<TEntity> entities)
        {
            return BaseDbSet.AddRange(entities);
        }

        public override TEntity Attach(TEntity entity)
        {
            return BaseDbSet.Attach(entity);
        }

        public override TDerivedEntity Create<TDerivedEntity>()
        {
            return BaseDbSet.Create<TDerivedEntity>();
        }

        public override TEntity Create()
        {
            return BaseDbSet.Create();
        }

        public override TEntity Find(params object[] keyValues)
        {
            return BaseDbSet.Find(keyValues);
        }

        public override Task<TEntity> FindAsync(CancellationToken cancellationToken, params object[] keyValues)
        {
            return BaseDbSet.FindAsync(cancellationToken, keyValues);
        }

        public override Task<TEntity> FindAsync(params object[] keyValues)
        {
            return BaseDbSet.FindAsync(keyValues);
        }

        public override ObservableCollection<TEntity> Local
        {
            get
            {
                return BaseDbSet.Local;
            }
        }

        public override TEntity Remove(TEntity entity)
        {
            return BaseDbSet.Remove(entity);
        }

        public override IEnumerable<TEntity> RemoveRange(IEnumerable<TEntity> entities)
        {
            return BaseDbSet.RemoveRange(entities);
        }

        public override DbSqlQuery<TEntity> SqlQuery(string sql, params object[] parameters)
        {
            return BaseDbSet.SqlQuery(sql, parameters);
        }

        // ---- Query-shaping members: a new proxy over a query derived from BaseDbSet ----

        public override DbQuery<TEntity> AsNoTracking()
        {
            return Rewrap(BaseDbSet.AsNoTracking());
        }

        [Obsolete]
        public override DbQuery<TEntity> AsStreaming()
        {
#pragma warning disable 618
            return Rewrap(BaseDbSet.AsStreaming());
#pragma warning restore 618
        }

        public override DbQuery<TEntity> Include(string path)
        {
            return Rewrap(BaseDbSet.Include(path));
        }

        // ---- LINQ surface: served by the proxied query ----

        IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
        {
            return ProxyQueryable.GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            var untyped = (IEnumerable)ProxyQueryable;
            return untyped.GetEnumerator();
        }

        Type IQueryable.ElementType
        {
            get { return ProxyQueryable.ElementType; }
        }

        Expression IQueryable.Expression
        {
            get { return ProxyQueryable.Expression; }
        }

        IQueryProvider IQueryable.Provider
        {
            get { return ProxyQueryable.Provider; }
        }

        IDbAsyncEnumerator<TEntity> IDbAsyncEnumerable<TEntity>.GetAsyncEnumerator()
        {
            var asyncSource = (IDbAsyncEnumerable<TEntity>)ProxyQueryable;
            return asyncSource.GetAsyncEnumerator();
        }

        IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
        {
            var asyncSource = (IDbAsyncEnumerable)ProxyQueryable;
            return asyncSource.GetAsyncEnumerator();
        }

        public override string ToString()
        {
            return ProxyQueryable.ToString();
        }

        // ---- Helpers ----

        // Copies the wrapped set's EF internal set onto this instance (when the field was found),
        // presumably so DbSet<TEntity> members not overridden here act on the same underlying set.
        private void CopyInternalSetFrom(DbSet<TEntity> source)
        {
            if (InternalSetField == null)
            {
                return;
            }

            object internalSet = InternalSetField.GetValue(source);
            InternalSetField.SetValue(this, internalSet);
        }

        // Wraps a query derived from BaseDbSet in a new proxy that shares this instance's provider.
        // NOTE(review): the derived query always starts from BaseDbSet, not from ProxyQueryable, so
        // chaining e.g. AsNoTracking() then Include(...) keeps only the last call's effect.
        private ProxyDbSet<TEntity> Rewrap(IQueryable<TEntity> derivedQuery)
        {
            var provider = (ProxyDbProvider)ProxyQueryable.Provider;
            var proxied = new ProxyQueryable<TEntity>(provider, derivedQuery);
            return new ProxyDbSet<TEntity>(BaseDbSet, proxied, Manipulator);
        }

        // The conversion operator is not virtual. After
        //     DbSet<Foo> foo = new ProxyDbSet<Foo>(...);
        //     DbSet foo2 = (DbSet)foo;
        // the compiler picks DbSet<Foo>'s own conversion, so foo2 is a non-proxied DbSet.
        // Only a conversion from a variable statically typed as ProxyDbSet<Foo> uses this operator.
        public static implicit operator ProxyDbSetNonGeneric<TEntity>(ProxyDbSet<TEntity> entry)
        {
            return new ProxyDbSetNonGeneric<TEntity>((DbSet)entry.BaseDbSet, entry.Manipulator);
        }
    }

    /// <summary>
    /// Decorating query provider. Every expression passed to CreateQuery, Execute or ExecuteAsync is
    /// first given to <see cref="Manipulator"/> (when it is not null) and then forwarded to the wrapped
    /// provider. Queries it creates are wrapped in ProxyQueryable, so further composition comes back
    /// through a ProxyDbProvider.
    /// </summary>
    public class ProxyDbProvider : IQueryProvider, IDbAsyncQueryProvider
    {
        /// <summary>The provider that actually builds and executes queries.</summary>
        protected readonly IQueryProvider BaseQueryProvider;

        /// <summary>Expression rewriter; first parameter: true for Execute, false for CreateQuery. May be null.</summary>
        public readonly Func<bool, Expression, Expression> Manipulator;

        // Used by the non-generic CreateQuery to build a ProxyQueryable<T> for a runtime element type.
        protected static readonly MethodInfo CreateQueryNonGenericToGenericMethod = typeof(ProxyDbProvider).GetMethod("CreateQueryNonGenericToGeneric", BindingFlags.Static | BindingFlags.NonPublic);

        /// <summary>
        /// Wraps <paramref name="baseQueryProvider"/>.
        /// </summary>
        /// <param name="baseQueryProvider">Provider that receives the (possibly rewritten) expressions.</param>
        /// <param name="manipulator">First parameter: true for Execute, false for CreateQuery.</param>
        public ProxyDbProvider(IQueryProvider baseQueryProvider, Func<bool, Expression, Expression> manipulator)
        {
            BaseQueryProvider = baseQueryProvider;
            Manipulator = manipulator;
        }

        public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
        {
            IQueryable<TElement> inner = BaseQueryProvider.CreateQuery<TElement>(Manipulate(false, expression));
            return new ProxyQueryable<TElement>(ProviderFor(inner.Provider), inner);
        }

        public IQueryable CreateQuery(Expression expression)
        {
            IQueryable inner = BaseQueryProvider.CreateQuery(Manipulate(false, expression));
            ProxyDbProvider proxy = ProviderFor(inner.Provider);

            // Prefer the generic wrapper when the created query is an IQueryable<T>.
            Type elementType = GetIQueryableTypeArgument(inner.GetType());
            if (elementType != null)
            {
                MethodInfo factory = CreateQueryNonGenericToGenericMethod.MakeGenericMethod(elementType);
                return (IQueryable)factory.Invoke(null, new object[] { proxy, inner });
            }

            return new ProxyQueryable(proxy, inner);
        }

        protected static ProxyQueryable<TElement> CreateQueryNonGenericToGeneric<TElement>(ProxyDbProvider proxy, IQueryable<TElement> query)
        {
            return new ProxyQueryable<TElement>(proxy, query);
        }

        public TResult Execute<TResult>(Expression expression)
        {
            return BaseQueryProvider.Execute<TResult>(Manipulate(true, expression));
        }

        public object Execute(Expression expression)
        {
            return BaseQueryProvider.Execute(Manipulate(true, expression));
        }

        public Task<TResult> ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
        {
            IDbAsyncQueryProvider asyncProvider = RequireAsyncProvider();
            return asyncProvider.ExecuteAsync<TResult>(Manipulate(true, expression), cancellationToken);
        }

        public Task<object> ExecuteAsync(Expression expression, CancellationToken cancellationToken)
        {
            IDbAsyncQueryProvider asyncProvider = RequireAsyncProvider();
            return asyncProvider.ExecuteAsync(Manipulate(true, expression), cancellationToken);
        }

        // Gets the first T for which the type (or, for an interface, the type itself) is IQueryable<T>;
        // null when there is none.
        protected static Type GetIQueryableTypeArgument(Type type)
        {
            IEnumerable<Type> candidates = type.IsInterface
                ? new[] { type }.Concat(type.GetInterfaces())
                : type.GetInterfaces();

            foreach (Type candidate in candidates)
            {
                if (candidate.IsGenericType && candidate.GetGenericTypeDefinition() == typeof(IQueryable<>))
                {
                    return candidate.GetGenericArguments()[0];
                }
            }

            return null;
        }

        // Applies the manipulator when one was supplied; otherwise returns the expression untouched.
        private Expression Manipulate(bool executing, Expression expression)
        {
            return Manipulator == null ? expression : Manipulator(executing, expression);
        }

        // Reuses this instance when the created query still reports the wrapped provider; otherwise
        // wraps the new provider with the same manipulator.
        private ProxyDbProvider ProviderFor(IQueryProvider provider)
        {
            return provider == BaseQueryProvider ? this : new ProxyDbProvider(provider, Manipulator);
        }

        // The wrapped provider as IDbAsyncQueryProvider, or NotSupportedException when it is not one.
        private IDbAsyncQueryProvider RequireAsyncProvider()
        {
            var asyncProvider = BaseQueryProvider as IDbAsyncQueryProvider;
            if (asyncProvider == null)
            {
                throw new NotSupportedException();
            }

            return asyncProvider;
        }
    }

    /// <summary>
    /// Non-generic query wrapper. It exposes the wrapped query's element type, expression and
    /// enumeration, but reports a <see cref="ProxyDbProvider"/> as its provider, so operators applied to
    /// it are routed through that provider.
    /// </summary>
    public class ProxyQueryable : IOrderedQueryable, IQueryable, IEnumerable, IDbAsyncEnumerable
    {
        /// <summary>Provider reported by <see cref="Provider"/>.</summary>
        protected readonly ProxyDbProvider ProxyDbProvider;

        /// <summary>Query that supplies the element type, expression and results.</summary>
        protected readonly IQueryable BaseQueryable;

        public ProxyQueryable(ProxyDbProvider proxyDbProvider, IQueryable baseQueryable)
        {
            BaseQueryable = baseQueryable;
            ProxyDbProvider = proxyDbProvider;
        }

        public IQueryProvider Provider
        {
            get { return ProxyDbProvider; }
        }

        public Type ElementType
        {
            get { return BaseQueryable.ElementType; }
        }

        public Expression Expression
        {
            get { return BaseQueryable.Expression; }
        }

        // Enumeration goes straight to the wrapped query (no provider Execute call involved).
        public IEnumerator GetEnumerator()
        {
            return BaseQueryable.GetEnumerator();
        }

        IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
        {
            var asyncSource = BaseQueryable as IDbAsyncEnumerable;
            if (asyncSource != null)
            {
                return asyncSource.GetAsyncEnumerator();
            }

            throw new NotSupportedException();
        }

        public override string ToString()
        {
            return BaseQueryable.ToString();
        }
    }

    /// <summary>
    /// Generic query wrapper. It exposes the wrapped query's element type, expression and enumeration,
    /// but reports a <see cref="ProxyDbProvider"/> as its provider, so operators applied to it are
    /// routed through that provider.
    /// </summary>
    /// <typeparam name="TElement">Element type of the wrapped query.</typeparam>
    public class ProxyQueryable<TElement> : IOrderedQueryable<TElement>, IQueryable<TElement>, IEnumerable<TElement>, IDbAsyncEnumerable<TElement>, IOrderedQueryable, IQueryable, IEnumerable, IDbAsyncEnumerable
    {
        /// <summary>Provider reported by <see cref="Provider"/>.</summary>
        protected readonly ProxyDbProvider ProxyDbProvider;

        /// <summary>Query that supplies the element type, expression and results.</summary>
        protected readonly IQueryable<TElement> BaseQueryable;

        public ProxyQueryable(ProxyDbProvider proxyDbProvider, IQueryable<TElement> baseQueryable)
        {
            BaseQueryable = baseQueryable;
            ProxyDbProvider = proxyDbProvider;
        }

        public IQueryProvider Provider
        {
            get { return ProxyDbProvider; }
        }

        public Type ElementType
        {
            get { return BaseQueryable.ElementType; }
        }

        public Expression Expression
        {
            get { return BaseQueryable.Expression; }
        }

        // Enumeration goes straight to the wrapped query (no provider Execute call involved).
        public IEnumerator<TElement> GetEnumerator()
        {
            return BaseQueryable.GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            var untyped = (IEnumerable)BaseQueryable;
            return untyped.GetEnumerator();
        }

        public IDbAsyncEnumerator<TElement> GetAsyncEnumerator()
        {
            var asyncSource = BaseQueryable as IDbAsyncEnumerable<TElement>;
            if (asyncSource != null)
            {
                return asyncSource.GetAsyncEnumerator();
            }

            throw new NotSupportedException();
        }

        IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
        {
            var asyncSource = BaseQueryable as IDbAsyncEnumerable;
            if (asyncSource != null)
            {
                return asyncSource.GetAsyncEnumerator();
            }

            throw new NotSupportedException();
        }

        public override string ToString()
        {
            return BaseQueryable.ToString();
        }
    }
}

      

Example expression manipulator (this one converts .Where(x => something) into .Where(x => something && something)):

namespace My
{
    using System.Linq;
    using System.Linq.Expressions;

    /// <summary>
    /// Example visitor that rewrites every two-argument Queryable.Where(source, x =&gt; body) into
    /// Queryable.Where(source, x =&gt; body &amp;&amp; body), leaving predicates whose body is already
    /// an AndAlso untouched.
    /// </summary>
    public class MyExpressionManipulator : ExpressionVisitor
    {
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            MethodCallExpression candidate = DuplicateWherePredicate(node);
            return base.VisitMethodCall(candidate);
        }

        // Returns a copy of the call with its predicate body duplicated, or the call itself when it is
        // not Queryable.Where(source, Quote(lambda)) or the lambda body is already an AndAlso.
        private static MethodCallExpression DuplicateWherePredicate(MethodCallExpression call)
        {
            bool isWhere = call.Method.DeclaringType == typeof(Queryable)
                && call.Method.Name == "Where"
                && call.Arguments.Count == 2;
            if (!isWhere || call.Arguments[1].NodeType != ExpressionType.Quote)
            {
                return call;
            }

            var quote = (UnaryExpression)call.Arguments[1];
            if (quote.Operand.NodeType != ExpressionType.Lambda)
            {
                return call;
            }

            var predicate = (LambdaExpression)quote.Operand;

            // The whole tree is re-visited each time an operator is added, so avoid rewriting twice.
            // For ctx.Where(x => true).Where(x => true).Select(x => 1):
            //  first visit:  ctx.Where(x => true)                     -> ctx.Where(x => true && true)
            //  second visit: ctx.Where(x => true && true).Where(x => true)
            //                -> ctx.Where(x => true && true).Where(x => true && true)
            //  and NOT ctx.Where(x => (true && true) && (true && true)).Where(x => true && true).
            // Side effect of this heuristic: a predicate written by the user as a && b is never doubled.
            if (predicate.Body.NodeType == ExpressionType.AndAlso)
            {
                return call;
            }

            Expression[] rewrittenArguments = call.Arguments.ToArray();
            LambdaExpression doubled = Expression.Lambda(Expression.AndAlso(predicate.Body, predicate.Body), predicate.Parameters);
            rewrittenArguments[1] = Expression.Quote(doubled);

            return Expression.Call(call.Object, call.Method, rewrittenArguments);
        }
    }
}

      



Now, how do you use it? The best way is to derive your context (in this case Model1) from ProxyDbContext instead of DbContext, for example:

public partial class Model1 : ProxyDbContext
{
    public Model1()
        : base("name=Model1", Manipulate)
    {
    }

    /// <summary>
    /// Manipulator handed to ProxyDbContext: rewrites every query expression with MyExpressionManipulator.
    /// </summary>
    /// <param name="executing">true: the returned Expression will be executed directly (IQueryProvider.Execute/ExecuteAsync); false: the returned expression will be turned into a new IQueryable&lt;&gt; (IQueryProvider.CreateQuery).</param>
    /// <param name="expression">The full query expression built so far.</param>
    /// <returns>The expression to pass on to the underlying EF provider.</returns>
    private static Expression Manipulate(bool executing, Expression expression)
    {
        // The visitor runs on both paths, so the whole tree is visited again each time an operator is
        // added; MyExpressionManipulator therefore has to tolerate already-rewritten expressions.
        // Manipulating only when executing == true would avoid the repeated visits. However, enumerating
        // a query (foreach/ToList) goes through the wrapped query's enumerator rather than Execute, so
        // such queries would then never be manipulated.
        var visitor = new MyExpressionManipulator();
        return visitor.Visit(expression);
    }

    // Some tables
    public virtual DbSet<Parent> Parent { get; set; }
    public virtual IDbSet<Child> Child { get; set; }

      

Then it is completely transparent:

// Where Model1: class Model1 : ProxyDbContext {}
using (var ctx = new Model1())
{
    // Your query
    // ctx.Parent was replaced by a ProxyDbSet<Parent> when the context was constructed, so Where()
    // goes through ProxyDbProvider.CreateQuery, which calls Manipulate(false, expression) first.
    var res = ctx.Parent.Where(x => x.Id > 100);
    // The query is automatically manipulated by your Manipulate method
}

      

Another way to do it, without subclassing ProxyDbContext:

// Where Model1: class Model1 : DbContext {} (a plain context, not a ProxyDbContext)
using (var ctx = new Model1())
{
    // ProxyDbSet<T> expects a Func<bool, Expression, Expression> (first argument: true for Execute,
    // false for CreateQuery), so adapt the visitor's single-argument Visit method to that shape.
    var visitor = new MyExpressionManipulator();
    Func<bool, Expression, Expression> manipulator = (executing, expression) => visitor.Visit(expression);

    ctx.Parent = new ProxyDbSet<Parent>(ctx.Parent, manipulator);

    // Child is declared as IDbSet<Child>, while the ProxyDbSet constructor takes a DbSet<Child>.
    ctx.Child = new ProxyDbSet<Child>((DbSet<Child>)ctx.Child, manipulator);

    // Your query
    var res = ctx.Parent.Where(x => x.Id > 100);
}

      

ProxyDbContext replaces the public, settable DbSet<> / IDbSet<> properties of your context with ProxyDbSet<>.

In the second example this replacement is done explicitly. There are other options:

- Wrap it in a helper method.
- Create a factory for your context: a static method that returns a context whose DbSet<> properties are proxied.
- Do the proxying in your context's constructor. This works because the initial DbSet<> initialization happens in the DbContext constructor, which runs before the body of your context's constructor.
- Create several subclasses of your context, each with a constructor that proxies in a different way.

Note that the first approach (subclassing ProxyDbContext) also "fixes" the Set<TEntity>() and Set(Type) methods. Without it, you would have to fix them yourself by copying those two overrides from ProxyDbContext.

+4


source







All Articles