Monday, September 29, 2008

Passing Predicates into Compiled Queries

I've recently been looking at generating LINQ predicates on the fly in a mapping layer between a set of business domain entities and a set of related, but different, database entities. One of the problems I've encountered concerns the way predicates are handled when using compiled queries (CompiledQuery).

To start off, let's consider the easy, non-compiled version. Here's a method:

// Get an employee by a predicate
static void GetEmployee(Expression<Func<Employee, bool>> predicate)
{
   // Just perform the select, and output the results
   using (DataClasses1DataContext context = new DataClasses1DataContext())
   {
      var results = context.Employees.Where(predicate);

      Console.WriteLine("Number of employees: {0}", results.Count());
   }
}

The usage of this is nice and simple, and the sort of thing you see in LINQ examples all over the web:





GetEmployee(e => e.EmployeeID == 1);
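To see why this version works, it helps to look at the tree that Queryable.Where() builds when it is called directly. The sketch below is an illustration rather than part of the original code: it assumes an in-memory list exposed through AsQueryable() in place of the DataContext, and Employee is a hypothetical stand-in for the generated entity. Because Where() actually runs here, it receives the predicate tree itself and embeds it in a Quote node.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical stand-in for the Employee entity generated by the LINQ to SQL designer.
public class Employee
{
    public int EmployeeID { get; set; }
}

public static class NonCompiledDemo
{
    // Mirrors GetEmployee, but queries an in-memory list via AsQueryable()
    // instead of a DataContext (an assumption so no database is needed).
    public static IQueryable<Employee> Query(
        IEnumerable<Employee> employees,
        Expression<Func<Employee, bool>> predicate)
    {
        // Where() really executes here and receives the predicate tree itself.
        return employees.AsQueryable().Where(predicate);
    }

    // Node type of the predicate argument inside the Where() call of the query tree.
    public static ExpressionType PredicateNodeType(IQueryable<Employee> query)
    {
        var whereCall = (MethodCallExpression)query.Expression;
        return whereCall.Arguments[1].NodeType;
    }
}
```

Calling Query with e => e.EmployeeID == 1 gives a query whose predicate argument is an ExpressionType.Quote node wrapping the lambda, which is the shape a query provider expects.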




Calling this does exactly what you'd expect.  My next step was to look at how this approach could be used with compiled queries.  I started with a simple method:





// Get an employee by a predicate, using a compiled expression
static void GetEmployeeCompiled(Expression<Func<Employee, bool>> predicate)
{
   // Compile the query
   var compiledQuery =
      CompiledQuery.Compile((DataClasses1DataContext context) => context.Employees.Where(predicate));

   // and using the compiled query, output the results.  This crashes :(
   using (DataClasses1DataContext context = new DataClasses1DataContext())
   {
      var results = compiledQuery(context);

      Console.WriteLine("Number of employees: {0}", results.Count());
   }
}




Obviously, this is pointless, since it recompiles the query on every call. But let's ignore that small fact; it should, after all, still work. Alas, it doesn't.



At the point where the results are enumerated, it explodes with a NotSupportedException, specifically "Unsupported overload used for query operator 'Where'.". Looking at the expression tree being compiled, with some help from Reflector to see what LINQ is doing under the covers, it becomes clear that the issue comes down to how the predicate is included in the final query expression.



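Remember, the compiler is not generating executable code here; it is just building an expression tree from the lambda. When it sees the predicate parameter being passed to the Where() method, it has little choice but to "lift" that variable into its own compiler-generated class, and it is a field on this lifted class that ends up as the argument to Where() in the tree. The non-compiled version never hits this, because there Where() is actually called at run time and receives the predicate expression itself. Inside the lambda handed to CompiledQuery.Compile, however, the field access is what LINQ to SQL sees, and that is what causes the CompiledQuery object to barf. The lifting itself is nothing special: it happens for any query that uses a local variable or parameter.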
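The lifting can be observed without a database. In this sketch (an illustration, assuming IQueryable<Employee> as a stand-in for DataClasses1DataContext and a hypothetical Employee class), Build() produces the same shape of lambda as the one passed to CompiledQuery.Compile: its Where() argument is a member access on the closure object. BuildInline(), by contrast, writes the lambda inline, and it appears as a Quote node, which is the shape the dummy-predicate trick relies on.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical stand-in for the Employee entity generated by the LINQ to SQL designer.
public class Employee
{
    public int EmployeeID { get; set; }
}

public static class CapturedPredicateDemo
{
    // Same shape as the lambda given to CompiledQuery.Compile in GetEmployeeCompiled,
    // with IQueryable<Employee> standing in for the DataContext (an assumption).
    public static Expression<Func<IQueryable<Employee>, IQueryable<Employee>>> Build(
        Expression<Func<Employee, bool>> predicate)
    {
        // 'predicate' is captured, so the compiler lifts it into a closure class
        // and the tree reads it back as a field of that class.
        return source => source.Where(predicate);
    }

    // For contrast: a lambda written inline is embedded directly, inside a Quote node.
    public static Expression<Func<IQueryable<Employee>, IQueryable<Employee>>> BuildInline()
    {
        return source => source.Where(e => e.EmployeeID == 1);
    }

    // Node type of the second argument to the Where() call at the root of the lambda body.
    public static ExpressionType WhereArgumentType(LambdaExpression query)
    {
        var whereCall = (MethodCallExpression)query.Body;
        return whereCall.Arguments[1].NodeType;
    }
}
```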



I've experimented with a number of ways of constructing the query that I'm trying to compile, but all ultimately end up with the same problem.  The solution I've found is a little nasty, but it does work.  If I've missed a cleaner way, then I'd love to hear about it!



Anyhow, the solution. It relies on the fact that the problem is the act of "passing" the predicate into the Where() method. So the trick is not to pass in the predicate at all, but to pass in a dummy predicate instead, and then walk the expression tree to swap the dummy predicate for the real one. The code looks like this:





using System;
using System.Collections.Generic;
using System.Data.Linq;        // CompiledQuery
using System.Linq;
using System.Linq.Expressions;

// Get an employee by a predicate, using a compiled expression
static void GetEmployeeCompiled2(Expression<Func<Employee, bool>> predicate)
{
   // Set up the required query, using a dummy predicate (c => true)
   Expression<Func<DataClasses1DataContext, IEnumerable<Employee>>> compiledExpression =
      context => context.Employees.Where(c => true);

   // Dig out the dummy predicate from the expression tree created above:
   // Body is the Where() call, Arguments[1] is the quoted lambda, and Operand is
   // the lambda itself. This path only holds while Where() is the outermost call.
   Expression template = ((UnaryExpression)((MethodCallExpression)(compiledExpression.Body)).Arguments[1]).Operand;

   // Swap out the template for the predicate
   compiledExpression = (Expression<Func<DataClasses1DataContext, IEnumerable<Employee>>>)
      ExpressionRewriter.Replace(compiledExpression, template, predicate);

   // Compile the query
   var compiledQuery = CompiledQuery.Compile(compiledExpression);

   // and using the compiled query, output the results.  This works :)
   using (DataClasses1DataContext context = new DataClasses1DataContext())
   {
      var results = compiledQuery(context);

      Console.WriteLine("Number of employees: {0}", results.Count());
   }
}




So the required query is itself stored as an expression, with a dummy predicate (c => true) used to get the correct "shape" of tree.  This predicate is then located and the expression tree is rewritten, swapping out the dummy predicate for the real one.  This new query expression then compiles and executes just fine.
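The cast chain used to dig out the dummy predicate assumes Where() is the outermost call in the query. If the query ends in another operator, such as Distinct(), Body is that call instead, and its Arguments collection has only one entry, so Arguments[1] throws. A more tolerant approach, sketched here under the assumption that .NET 4 or later is available (where System.Linq.Expressions.ExpressionVisitor is public), is to search the tree for the quoted c => true lambda wherever it sits. DummyPredicateFinder and DummyPredicateDemo are hypothetical names.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// Finds the first quoted lambda whose body is the constant 'true' (a dummy
// predicate such as c => true), wherever it sits in the query tree.
// Assumes the public .NET 4+ ExpressionVisitor, not the MSDN sample class.
public class DummyPredicateFinder : ExpressionVisitor
{
    private LambdaExpression _found;

    public static LambdaExpression Find(Expression tree)
    {
        var finder = new DummyPredicateFinder();
        finder.Visit(tree);
        return finder._found;
    }

    protected override Expression VisitUnary(UnaryExpression node)
    {
        if (_found == null
            && node.NodeType == ExpressionType.Quote
            && node.Operand is LambdaExpression lambda
            && lambda.Body is ConstantExpression constant
            && true.Equals(constant.Value))
        {
            // Remember the lambda itself (the Quote's operand), as the original code does.
            _found = lambda;
            return node;
        }
        return base.VisitUnary(node);
    }
}

public static class DummyPredicateDemo
{
    // Where() is not the outermost call here, so Body is the Distinct() call
    // and the fixed path ((MethodCallExpression)Body).Arguments[1] would throw.
    public static Expression<Func<IQueryable<int>, IQueryable<int>>> Query() =>
        source => source.Where(x => true).Distinct();
}
```

The lambda returned by Find() can then be passed to ExpressionRewriter.Replace in the same way as template above. It only finds the first dummy, so a query with several placeholders would need some way to tell them apart.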



For completeness, the ExpressionRewriter class is defined as:





class ExpressionRewriter : ExpressionVisitor
{
   public static Expression Replace(Expression tree, Expression toReplace, Expression replaceWith)
   {
      ExpressionRewriter rewriter = new ExpressionRewriter(toReplace, replaceWith);

      return rewriter.Visit(tree);
   }

   private readonly Expression _toReplace;
   private readonly Expression _replaceWith;

   private ExpressionRewriter(Expression toReplace, Expression replaceWith)
   {
      _toReplace = toReplace;
      _replaceWith = replaceWith;
   }

   protected override Expression Visit(Expression exp)
   {
      // Reference comparison: only the exact node instance found earlier is swapped
      if (exp == _toReplace)
      {
         return _replaceWith;
      }

      // Otherwise let the base visitor walk (and rebuild) the tree as normal
      return base.Visit(exp);
   }
}




where the ExpressionVisitor base class can be found on MSDN.
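The rewriter above overrides a protected Visit method, matching the MSDN sample class. From .NET 4 onwards, System.Linq.Expressions.ExpressionVisitor is public and its Visit method is public virtual, so a port of ExpressionRewriter has to override it as public. The following sketch assumes .NET 4 or later and replays the dummy-predicate swap from GetEmployeeCompiled2 against an in-memory IQueryable, using Expression.Compile() in place of CompiledQuery.Compile so that it runs without a database; Employee and RewriteDemo are hypothetical.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical stand-in for the Employee entity generated by the LINQ to SQL designer.
public class Employee
{
    public int EmployeeID { get; set; }
}

// ExpressionRewriter adapted to the public .NET 4+ ExpressionVisitor,
// whose Visit method is public virtual, so the override is public too.
public class ExpressionRewriter : ExpressionVisitor
{
    private readonly Expression _toReplace;
    private readonly Expression _replaceWith;

    private ExpressionRewriter(Expression toReplace, Expression replaceWith)
    {
        _toReplace = toReplace;
        _replaceWith = replaceWith;
    }

    public static Expression Replace(Expression tree, Expression toReplace, Expression replaceWith)
    {
        return new ExpressionRewriter(toReplace, replaceWith).Visit(tree);
    }

    public override Expression Visit(Expression node)
    {
        // Reference comparison: only the exact node instance is swapped.
        return node == _toReplace ? _replaceWith : base.Visit(node);
    }
}

public static class RewriteDemo
{
    // The dummy-predicate swap from GetEmployeeCompiled2, with IQueryable<Employee>
    // and Expression.Compile() standing in for the DataContext and CompiledQuery.
    public static Func<IQueryable<Employee>, IQueryable<Employee>> Compile(
        Expression<Func<Employee, bool>> predicate)
    {
        Expression<Func<IQueryable<Employee>, IQueryable<Employee>>> query =
            source => source.Where(c => true);

        // Same cast chain as the original: Where() call -> Quote -> lambda.
        Expression dummy =
            ((UnaryExpression)((MethodCallExpression)query.Body).Arguments[1]).Operand;

        var rewritten = (Expression<Func<IQueryable<Employee>, IQueryable<Employee>>>)
            ExpressionRewriter.Replace(query, dummy, predicate);

        return rewritten.Compile();
    }
}
```

Because the real predicate replaces the quoted dummy, the rewritten tree has the same shape as if the predicate had been written inline, with no closure field involved.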

2 comments:

Adrian said...

Great work, helped me a lot

Naveed Iqbal said...

I am using this expression:
private static Expression>>
compiledExpression = db => (from v in db.UnitComponents_Vs
join u in db.Units.Where(c=>true) on v.UnitID equals u.UnitID
//join c in db.Components.Where(componentFilters) on v.ComponentID equals c.ComponentID
join c in db.Components on v.ComponentID equals c.ComponentID
//join c in db.Components on v.ComponentID equals c.ComponentID
join ct in db.ComponentTypes on c.ComponentTypeID equals ct.ComponentTypeID
//join i in db.Items.Where(itemFilters) on c.ItemID equals i.ItemID
join i in db.Items on c.ItemID equals i.ItemID
select ct).Distinct();


static Expression queryTemplate = ((UnaryExpression)((MethodCallExpression)(compiledExpression.Body)).Arguments[1]).Operand;

unitFilters = u => listProjectIDs.Contains(u.Phase.ProjectID);


compiledExpression = (Expression>>)ExpressionRewriter.Replace(compiledExpression, queryTemplate, unitFilters);


It throws error:
"Index was out of range. Must be non-negative and less than the size of the collection.\r\nParameter name: index"