C# 如何在执行之前包装实体框架以截获LINQ表达式?

C# 如何在执行之前包装实体框架以截获LINQ表达式?,c#,linq,entity-framework,expression-trees,C#,Linq,Entity Framework,Expression Trees,我想在执行之前重写LINQ表达式的某些部分。我在将重写器注入正确的位置时遇到了问题(事实上) 查看实体框架源代码(在reflector中),它最终归结为IQueryProvider.Execute,它在EF中通过ObjectContext提供内部IQueryProvider提供程序{get;}属性耦合到表达式 因此,我创建了一个包装类(实现IQueryProvider),在调用Execute时重写表达式,然后将其传递给原始提供程序 问题是,Provider后面的字段是private Object

我想在执行之前重写LINQ表达式的某些部分。但我在把重写器注入到正确的位置时遇到了问题(事实上,我根本找不到可以注入的位置)

查看实体框架源代码(在reflector中),它最终归结为
IQueryProvider.Execute
,它在EF中通过
ObjectContext
提供
internal IQueryProvider Provider { get; }
属性耦合到表达式

因此,我创建了一个包装类(实现
IQueryProvider
),在调用Execute时重写表达式,然后将其传递给原始提供程序

问题是,
Provider
后面的字段是
private ObjectQueryProvider _queryProvider。此
ObjectQueryProvider
是一个内部密封类,这意味着不可能创建提供添加重写的子类

由于ObjectContext的紧密耦合,这种方法使我陷入了死胡同

如何解决这个问题?我看错方向了吗?是否有一种方法可以围绕这个
ObjectQueryProvider
注入我自己的重写逻辑?


更新:当您使用存储库模式“包装”ObjectContext时,所提供的解决方案都可以工作,但最好使用允许直接使用ObjectContext生成的子类的解决方案。从而与动态数据脚手架保持兼容

我有您需要的源代码,但不知道如何附加文件

以下是一些代码片段(代码片段!我不得不修改这段代码,所以它可能无法编译):

IQueryable:

/// <summary>
/// Queryable wrapper whose provider is a <see cref="QueryTranslatorProvider{T}"/>,
/// so that enumeration goes through that provider's <c>ExecuteEnumerable</c>
/// instead of hitting the wrapped source directly.
/// </summary>
/// <typeparam name="T">Element type reported by <see cref="ElementType"/>.</typeparam>
public class QueryTranslator<T> : IOrderedQueryable<T>
{
    private Expression _queryExpression = null;
    private QueryTranslatorProvider<T> _translatorProvider = null;

    /// <summary>
    /// Creates a root query: its expression is a constant referring to this instance.
    /// </summary>
    /// <param name="source">Query handed to the translating provider.</param>
    public QueryTranslator(IQueryable source)
    {
        _translatorProvider = new QueryTranslatorProvider<T>(source);
        _queryExpression = Expression.Constant(this);
    }

    /// <summary>
    /// Creates a query over an already-built expression tree.
    /// </summary>
    /// <param name="source">Query handed to the translating provider.</param>
    /// <param name="e">Expression tree this query represents.</param>
    /// <exception cref="ArgumentNullException"><paramref name="e"/> is null.</exception>
    public QueryTranslator(IQueryable source, Expression e)
    {
        if (e == null)
        {
            throw new ArgumentNullException("e");
        }

        _queryExpression = e;
        _translatorProvider = new QueryTranslatorProvider<T>(source);
    }

    /// <summary>
    /// Runs the query through the translating provider and enumerates the typed results.
    /// </summary>
    public IEnumerator<T> GetEnumerator()
    {
        IEnumerable<T> typedResults = (IEnumerable<T>)_translatorProvider.ExecuteEnumerable(_queryExpression);
        return typedResults.GetEnumerator();
    }

    /// <summary>
    /// Non-generic counterpart of <see cref="GetEnumerator"/>.
    /// </summary>
    IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        IEnumerable untypedResults = _translatorProvider.ExecuteEnumerable(_queryExpression);
        return untypedResults.GetEnumerator();
    }

    /// <summary>Always <c>typeof(T)</c>.</summary>
    public Type ElementType
    {
        get
        {
            return typeof(T);
        }
    }

    /// <summary>Expression tree this query was constructed with.</summary>
    public Expression Expression
    {
        get
        {
            return _queryExpression;
        }
    }

    /// <summary>The translating provider created in the constructor.</summary>
    public IQueryProvider Provider
    {
        get
        {
            return _translatorProvider;
        }
    }
}
公共类查询Translator:IOrderedQueryable
{
私有表达式_Expression=null;
私有查询TranslatorProvider _provider=null;
公共查询Translator(iQueryTable源)
{
_表达式=表达式常数(this);
_provider=新的QueryTranslatorProvider(源);
}
公共查询Translator(IQueryable源,表达式e)
{
如果(e==null)抛出新的ArgumentNullException(“e”);
_表达式=e;
_provider=新的QueryTranslatorProvider(源);
}
公共IEnumerator GetEnumerator()
{
return((IEnumerable)_provider.executenumerable(this._表达式)).GetEnumerator();
}
IEnumerator System.Collections.IEnumerable.GetEnumerator()
{
返回_provider.executenumerable(this._表达式).GetEnumerator();
}
公共类型ElementType
{
获取{return typeof(T);}
}
公开表达
{
获取{return\u expression;}
}
公共IQueryProvider提供程序
{
获取{return\u provider;}
}
}
IQueryProvider:

/// <summary>
/// Query provider that runs every expression through <c>Visit</c> (inherited from
/// ExpressionTreeTranslator, which is not shown here) before handing it to the
/// provider of the wrapped source query.
/// </summary>
public class QueryTranslatorProvider<T> : ExpressionTreeTranslator, IQueryProvider
{
    // The wrapped query; its Provider receives the translated expressions.
    IQueryable _source;

    /// <summary>
    /// Stores the wrapped source query.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    public QueryTranslatorProvider(IQueryable source)
    {
        if (source == null) throw new ArgumentNullException("source");
        _source = source;
    }

    #region IQueryProvider Members

    /// <summary>
    /// Wraps <paramref name="expression"/> in a new QueryTranslator over the same source.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        return new QueryTranslator<TElement>(_source, expression) as IQueryable<TElement>;
    }

    /// <summary>
    /// Non-generic variant: takes the first result of FindElementTypes() on the
    /// expression's type as the element type and builds the matching
    /// QueryTranslator&lt;&gt; through reflection.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public IQueryable CreateQuery(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Type elementType = expression.Type.FindElementTypes().First();
        // Invokes the (IQueryable source, Expression e) constructor of the closed generic type.
        IQueryable result = (IQueryable)Activator.CreateInstance(typeof(QueryTranslator<>).MakeGenericType(elementType),
            new object[] { _source, expression });
        return result;
    }

    /// <summary>
    /// Typed execution: delegates to the non-generic Execute and casts the result.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public TResult Execute<TResult>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        object result = (this as IQueryProvider).Execute(expression);
        return (TResult)result;
    }

    /// <summary>
    /// Translates the expression with Visit and executes it on the source's provider.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public object Execute(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Expression translated = this.Visit(expression);

        return _source.Provider.Execute(translated);            
    }

    /// <summary>
    /// Translates the expression with Visit and returns the query the source's
    /// provider creates for it; the caller enumerates that query.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    internal IEnumerable ExecuteEnumerable(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Expression translated = this.Visit(expression);

        return _source.Provider.CreateQuery(translated);
    }

    #endregion        

    #region Visits
    // Returns method-call nodes unchanged; the node's children are not visited here.
    protected override MethodCallExpression VisitMethodCall(MethodCallExpression m)
    {
        return m;
    }

    // Rebuilds the unary node around the visited operand, using the type returned by
    // ToImplementationType() (an extension not shown here — presumably maps the node's
    // type to a concrete implementation type; confirm against its definition).
    protected override Expression VisitUnary(UnaryExpression u)
    {
         return Expression.MakeUnary(u.NodeType, base.Visit(u.Operand), u.Type.ToImplementationType(), u.Method);
    }
    #endregion
}
公共类QueryTranslator提供程序:ExpressionTreeTranslator,IQueryProvider
{
IQueryable_源;
公共查询Translator提供程序(iQueryTable源)
{
如果(source==null)抛出新的ArgumentNullException(“source”);
_来源=来源;
}
#区域提供程序成员
公共IQueryable CreateQuery(表达式)
{
如果(表达式==null)抛出新的ArgumentNullException(“表达式”);
将新的QueryTranslator(_source,expression)返回为IQueryable;
}
公共IQueryable CreateQuery(表达式)
{
如果(表达式==null)抛出新的ArgumentNullException(“表达式”);
Type elementType=expression.Type.FindElementTypes().First();
IQueryable结果=(IQueryable)Activator.CreateInstance(typeof(QueryTranslator).MakeGenericType(elementType),
新对象[]{u源,表达式});
返回结果;
}
公共TResult执行(表达式)
{
如果(表达式==null)抛出新的ArgumentNullException(“表达式”);
对象结果=(作为IQueryProvider)。执行(表达式);
返回(TResult)结果;
}
公共对象执行(表达式)
{
如果(表达式==null)抛出新的ArgumentNullException(“表达式”);
Expression translated=this.Visit(Expression);
返回_source.Provider.Execute(已翻译);
}
内部IEnumerable可执行IEnumerable(表达式)
{
如果(表达式==null)抛出新的ArgumentNullException(“表达式”);
Expression translated=this.Visit(Expression);
返回_source.Provider.CreateQuery(已翻译);
}
#端区
#地区访问
受保护的重写MethodCallExpression VisitMethodCall(MethodCallExpression m)
{
返回m;
}
受保护的重写表达式VisitUnary(UnaryU表达式)
{
返回表达式.MakeUnary(u.NodeType,base.Visite(u.Operator),u.Type.ToImplementationType(),u.Method);
}
#端区
}
用法(警告:改编代码!可能无法编译):

private Dictionary\u table=new Dictionary();
公共覆盖IQueryable GetObjectQuery()
{
if(!\u表格容器(类型))
{
_表[type]=新的查询Translator(
_ctx.CreateQuery(“[”+typeof(T.Name+“]”);
}
返回(IQueryable)\表[类型];
}
表达访客/翻译:

编辑:添加了FindElementTypes()。希望所有的方法现在都存在

    /// <summary>
    /// Collects the enumerable interfaces reachable from <paramref name="seqType"/>
    /// through its interfaces and base types.
    /// </summary>
    /// <param name="seqType">Type to inspect; may be null.</param>
    /// <returns>
    /// An empty sequence for null, <c>object</c> and <c>string</c>; the non-generic
    /// IEnumerable for arrays and for IEnumerable itself; the type plus IEnumerable
    /// for a closed generic IEnumerable; otherwise the union of what the base type
    /// and every implemented interface yield.
    /// </returns>
    public static IQueryable<Type> FindIEnumerables(this Type seqType)
    {
        bool yieldsNothing = seqType == null || seqType == typeof(object) || seqType == typeof(string);
        if (yieldsNothing)
        {
            return new Type[0].AsQueryable();
        }

        // Arrays are reported through the non-generic interface only.
        if (seqType == typeof(IEnumerable) || seqType.IsArray)
        {
            return new[] { typeof(IEnumerable) }.AsQueryable();
        }

        bool isClosedGenericEnumerable = seqType.IsGenericType
            && seqType.GetGenericArguments().Length == 1
            && seqType.GetGenericTypeDefinition() == typeof(IEnumerable<>);
        if (isClosedGenericEnumerable)
        {
            return new[] { seqType, typeof(IEnumerable) }.AsQueryable();
        }

        // Recurse into every implemented interface first, then into the base type.
        List<Type> fromInterfaces = (seqType.GetInterfaces() ?? new Type[0])
            .SelectMany(implemented => FindIEnumerables(implemented))
            .ToList();

        IQueryable<Type> fromBaseType = FindIEnumerables(seqType.BaseType);
        return fromBaseType.Union(fromInterfaces);
    }

    /// <summary>
    /// Maps every enumerable interface found by FindIEnumerables to its element type:
    /// the single generic argument for a generic interface, <c>object</c> otherwise.
    /// </summary>
    /// <param name="seqType">Sequence type to inspect.</param>
    /// <returns>One element type per enumerable interface found.</returns>
    public static IQueryable<Type> FindElementTypes(this Type seqType)
    {
        IQueryable<Type> enumerableTypes = seqType.FindIEnumerables();
        return enumerableTypes.Select(candidate => candidate.IsGenericType
            ? candidate.GetGenericArguments().Single()
            : typeof(object));
    }
//
///查找给定类型的所有已实现IEnumerable
/// 
公共静态IQueryable FindEnumerables(此类型为seqType)
{
if(seqType==null | | seqType==typeof(object)| | seqType==typeof(string))
返回新类型[]{}.AsQueryable();
if(seqType.IsArray | seqType==typeof(IEnumerable))
返回新类型[]{typeof(IEnumerable)}.AsQueryable();
if(seqType.IsGenericType&&seqType.GetGenericArguments().Length==1&&seqType.GetGenericTypeDefinition()==typeof(IEnumerable))
    /// <summary>
    /// Collects the enumerable interfaces reachable from <paramref name="seqType"/>
    /// through its interfaces and base types.
    /// </summary>
    /// <param name="seqType">Type to inspect; may be null.</param>
    /// <returns>
    /// An empty sequence for null, <c>object</c> and <c>string</c>; the non-generic
    /// IEnumerable for arrays and for IEnumerable itself; the type plus IEnumerable
    /// for a closed generic IEnumerable; otherwise the union of what the base type
    /// and every implemented interface yield.
    /// </returns>
    public static IQueryable<Type> FindIEnumerables(this Type seqType)
    {
        bool yieldsNothing = seqType == null || seqType == typeof(object) || seqType == typeof(string);
        if (yieldsNothing)
        {
            return new Type[0].AsQueryable();
        }

        // Arrays are reported through the non-generic interface only.
        if (seqType == typeof(IEnumerable) || seqType.IsArray)
        {
            return new[] { typeof(IEnumerable) }.AsQueryable();
        }

        bool isClosedGenericEnumerable = seqType.IsGenericType
            && seqType.GetGenericArguments().Length == 1
            && seqType.GetGenericTypeDefinition() == typeof(IEnumerable<>);
        if (isClosedGenericEnumerable)
        {
            return new[] { seqType, typeof(IEnumerable) }.AsQueryable();
        }

        // Recurse into every implemented interface first, then into the base type.
        List<Type> fromInterfaces = (seqType.GetInterfaces() ?? new Type[0])
            .SelectMany(implemented => FindIEnumerables(implemented))
            .ToList();

        IQueryable<Type> fromBaseType = FindIEnumerables(seqType.BaseType);
        return fromBaseType.Union(fromInterfaces);
    }

    /// <summary>
    /// Maps every enumerable interface found by FindIEnumerables to its element type:
    /// the single generic argument for a generic interface, <c>object</c> otherwise.
    /// </summary>
    /// <param name="seqType">Sequence type to inspect.</param>
    /// <returns>One element type per enumerable interface found.</returns>
    public static IQueryable<Type> FindElementTypes(this Type seqType)
    {
        IQueryable<Type> enumerableTypes = seqType.FindIEnumerables();
        return enumerableTypes.Select(candidate => candidate.IsGenericType
            ? candidate.GetGenericArguments().Single()
            : typeof(object));
    }
/// <summary>
/// Returns the cached <c>QueryTranslator&lt;T&gt;</c> for entity type
/// <typeparamref name="T"/>, creating it over the context's object set on first use.
/// </summary>
/// <typeparam name="T">Entity type whose query is requested.</typeparam>
/// <returns>The wrapped query stored in <c>_table</c> for <typeparamref name="T"/>.</returns>
public override IQueryable<T> GetObjectQuery<T>()
{
    // The cache key was never declared in this method; the entry is cast back to
    // IQueryable<T>, so it has to be keyed by the generic argument.
    Type type = typeof(T);

    if (!_table.ContainsKey(type))
    {
        // The original snippet was missing the parenthesis that closes the constructor call.
        _table[type] = new QueryTranslator<T>(
            _ctx.CreateObjectSet<T>());
    }

    return (IQueryable<T>)_table[type];
}
/// <summary>
/// Returns the cached <c>QueryTranslator&lt;T&gt;</c> for entity type
/// <typeparamref name="T"/>, creating it on first use from an entity-set query
/// named by <c>GetEntitySetName&lt;T&gt;()</c>.
/// </summary>
/// <typeparam name="T">Entity type whose query is requested.</typeparam>
/// <returns>The wrapped query stored in <c>_table</c> for <typeparamref name="T"/>.</returns>
public override IQueryable<T> GetObjectQuery<T>()
{
    // The cache key was never declared in this method; the entry is cast back to
    // IQueryable<T>, so it has to be keyed by the generic argument.
    Type type = typeof(T);

    if (!_table.ContainsKey(type))
    {
        // The entity-set name is wrapped in square brackets before being passed to CreateQuery.
        _table[type] = new QueryTranslator<T>(
            _ctx.CreateQuery<T>("[" + GetEntitySetName<T>() + "]"));
    }

    return (IQueryable<T>)_table[type];
}
/// <summary>
/// Queryable wrapper whose provider is a <see cref="QueryTranslatorProvider{T}"/>,
/// so that enumeration goes through that provider's <c>ExecuteEnumerable</c>
/// instead of hitting the wrapped source directly.
/// </summary>
/// <typeparam name="T">Element type reported by <see cref="ElementType"/>.</typeparam>
public class QueryTranslator<T> : IOrderedQueryable<T>
{
    private readonly Expression expression;
    private readonly QueryTranslatorProvider<T> provider;

    /// <summary>
    /// Creates a root query: its expression is a constant referring to this instance.
    /// </summary>
    /// <param name="source">Query handed to the translating provider.</param>
    public QueryTranslator(IQueryable source)
    {
        expression = Expression.Constant(this);
        provider = new QueryTranslatorProvider<T>(source);
    }

    /// <summary>
    /// Creates a query over an already-built expression tree.
    /// </summary>
    /// <param name="source">Query handed to the translating provider.</param>
    /// <param name="e">Expression tree this query represents.</param>
    /// <exception cref="ArgumentNullException"><paramref name="e"/> is null.</exception>
    public QueryTranslator(IQueryable source, Expression e)
    {
        if (e == null) throw new ArgumentNullException("e");
        expression = e;
        provider = new QueryTranslatorProvider<T>(source);
    }

    /// <summary>
    /// Runs the query through the translating provider and enumerates the typed results.
    /// </summary>
    public IEnumerator<T> GetEnumerator()
    {
        return ((IEnumerable<T>)provider.ExecuteEnumerable(this.expression)).GetEnumerator();
    }

    /// <summary>
    /// Non-generic counterpart of <see cref="GetEnumerator"/>.
    /// </summary>
    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return provider.ExecuteEnumerable(this.expression).GetEnumerator();
    }

    /// <summary>
    /// Returns a new root wrapper over the underlying <c>ObjectQuery&lt;T&gt;</c> with
    /// <paramref name="path"/> included.
    /// </summary>
    /// <param name="path">Navigation path forwarded to <c>ObjectQuery&lt;T&gt;.Include</c>.</param>
    /// <exception cref="InvalidOperationException">
    /// The wrapped source is not an <c>ObjectQuery&lt;T&gt;</c>, or this wrapper is not
    /// a root query (query operators have already been applied to it).
    /// </exception>
    public QueryTranslator<T> Include(String path)
    {
        // The result is rebuilt from the raw source, so anything already composed into
        // this wrapper's expression would be silently lost. Only a root wrapper, whose
        // expression is the constant pointing at itself, can be re-created safely.
        ConstantExpression root = expression as ConstantExpression;
        bool isRootQuery = root != null && ReferenceEquals(root.Value, this);

        ObjectQuery<T> possibleObjectQuery = provider.source as ObjectQuery<T>;
        if (!isRootQuery || possibleObjectQuery == null)
        {
            throw new InvalidOperationException("The Include should only happen at the beginning of a LINQ expression");
        }

        return new QueryTranslator<T>(possibleObjectQuery.Include(path));
    }

    /// <summary>Always <c>typeof(T)</c>.</summary>
    public Type ElementType
    {
        get { return typeof(T); }
    }

    /// <summary>Expression tree this query was constructed with.</summary>
    public Expression Expression
    {
        get { return expression; }
    }

    /// <summary>The translating provider created in the constructor.</summary>
    public IQueryProvider Provider
    {
        get { return provider; }
    }
}

/// <summary>
/// Query provider that rewrites every expression with this visitor before handing
/// it to the provider of the wrapped source query.
/// </summary>
/// <typeparam name="T">Element type of the wrapper that owns this provider.</typeparam>
public class QueryTranslatorProvider<T> : ExpressionVisitor, IQueryProvider
{
    // The wrapped query; its Provider receives the translated expressions.
    internal IQueryable source;

    /// <summary>
    /// Stores the wrapped source query.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
    public QueryTranslatorProvider(IQueryable source)
    {
        if (source == null) throw new ArgumentNullException("source");
        this.source = source;
    }

    /// <summary>
    /// Wraps <paramref name="expression"/> in a new QueryTranslator over the same source.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        return new QueryTranslator<TElement>(source, expression) as IQueryable<TElement>;
    }

    /// <summary>
    /// Non-generic variant: uses the first generic argument of the expression's type
    /// as the element type and builds the matching QueryTranslator&lt;&gt; through reflection.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public IQueryable CreateQuery(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        Type elementType = expression.Type.GetGenericArguments().First();
        // Invokes the (IQueryable source, Expression e) constructor of the closed generic type.
        IQueryable result = (IQueryable)Activator.CreateInstance(typeof(QueryTranslator<>).MakeGenericType(elementType),
            new object[] { source, expression });
        return result;
    }

    /// <summary>
    /// Typed execution: delegates to the non-generic Execute and casts the result.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public TResult Execute<TResult>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        object result = (this as IQueryProvider).Execute(expression);
        return (TResult)result;
    }

    /// <summary>
    /// Translates the expression with Visit and executes it on the source's provider.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    public object Execute(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Expression translated = this.Visit(expression);
        return source.Provider.Execute(translated);
    }

    /// <summary>
    /// Translates the expression with Visit and returns the query the source's
    /// provider creates for it; the caller enumerates that query.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    internal IEnumerable ExecuteEnumerable(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Expression translated = this.Visit(expression);
        return source.Provider.CreateQuery(translated);
    }

    #region Visitors
    /// <summary>
    /// Replaces the wrapper's self-referencing root constant with the wrapped
    /// source's own expression so the underlying provider never sees the wrapper.
    /// </summary>
    protected override Expression VisitConstant(ConstantExpression c)
    {
        // Match ANY closed QueryTranslator<>, not just QueryTranslator<T>: after a
        // projection (e.g. Select) this provider is closed over the projected element
        // type, while the root constant still has the original element type.
        // NOTE: every wrapper constant maps to this provider's source; queries that
        // combine several wrapped sources are not handled here.
        Type constantType = c.Type;
        bool isWrapperRoot = constantType.IsGenericType
            && constantType.GetGenericTypeDefinition() == typeof(QueryTranslator<>);

        if (isWrapperRoot)
        {
            return source.Expression;
        }

        return base.VisitConstant(c);
    }
    #endregion
}
/// <summary>
/// Wraps <c>entities.Users</c> in a QueryTranslator and includes the
/// "Department" path on it.
/// </summary>
/// <returns>The wrapped user query with the include applied.</returns>
public IQueryable<User> List()
{
    QueryTranslator<User> wrappedUsers = new QueryTranslator<User>(entities.Users);
    return wrappedUsers.Include("Department");
}