2009-10-27 16 views
5

J'essaie de trouver une manière élégante de gérer certains polynômes générés. Voici la situation que nous allons nous concentrer sur (exclusivement) pour cette question:Méthodes générées pour l'évaluation polynomiale

  1. pour est un paramètre dans la génération d'un n e polynôme d'ordre, où n: = ordre + 1.
  2. i est un paramètre entier dans la plage 0..n
  3. Le polynôme a des zéros à x_j, où j = 1.n et j ≠ i (il devrait être clair à ce stade que StackOverflow a besoin d'une nouvelle fonctionnalité ou est présent et je ne le sais pas)
  4. L'évaluateur polynomial es à 1 à x_i.

Étant donné que cet exemple de code particulier génère x_1 .. x_n, je vais expliquer comment ils se trouvent dans le code. Les points sont espacés uniformément x_j = j * elementSize/order, où n = order + 1.

Je génère un Func<double, double> pour évaluer ce polynôme¹.

private static Func<double, double> GeneratePsi(double elementSize, int order, int i) 
{ 
    if (order < 1) 
     throw new ArgumentOutOfRangeException("order", "order must be greater than 0."); 

    if (i < 0) 
     throw new ArgumentOutOfRangeException("i", "i cannot be less than zero."); 
    if (i > order) 
     throw new ArgumentException("i", "i cannot be greater than order"); 

    ParameterExpression xp = Expression.Parameter(typeof(double), "x"); 

    // generate the terms of the factored polynomial in form (x_j - x) 
    List<Expression> factors = new List<Expression>(); 
    for (int j = 0; j <= order; j++) 
    { 
     if (j == i) 
      continue; 

     double p = j * elementSize/order; 
     factors.Add(Expression.Subtract(Expression.Constant(p), xp)); 
    } 

    // evaluate the result at the point x_i to get scaleInv=1.0/scale. 
    double xi = i * elementSize/order; 
    double scaleInv = Enumerable.Range(0, order + 1).Aggregate(0.0, (product, j) => product * (j == i ? 1.0 : (j * elementSize/order - xi))); 

    /* generate an expression to evaluate 
    * (x_0 - x) * (x_1 - x) .. (x_n - x)/(x_i - x) 
    * obviously the term (x_i - x) is cancelled in this result, but included here to make the result clear 
    */ 
    Expression expr = factors.Skip(1).Aggregate(factors[0], Expression.Multiply); 
    // multiplying by scale forces the condition f(x_i)=1 
    expr = Expression.Multiply(Expression.Constant(1.0/scaleInv), expr); 

    Expression<Func<double, double>> lambdaMethod = Expression.Lambda<Func<double, double>>(expr, xp); 
    return lambdaMethod.Compile(); 
} 

Le problème: Je dois aussi évaluer ψ '= dψ/dx. Pour ce faire, je peux réécrire ψ = échelle × (x_0 - x) (x_1 - x) × .. × (x_n - x)/(x_i - x) sous la forme ψ = α_n × x^n + α_n × x^(n-1) + .. + α_1 × x + α_0. Cela donne ψ '= n × α_n × x^(n-1) + (n-1) × α_n × x^(n-2) + .. + 1 × α_1. Pour des raisons de calcul, nous pouvons réécrire la réponse finale sans appel au Math.Pow en écrivant ψ '= x × (x × (x × (..) - β_2) - β_1) - β_0.

Pour faire tout cela « supercherie » (très algèbre de base), je besoin d'un moyen propre à:

  1. Développer un Expression factorisé contenant ConstantExpression et ParameterExpression feuilles et des opérations mathématiques de base (fin en soit BinaryExpression avec le NodeType mis à l'opération) - le résultat ici peut inclure InvocationExpression éléments au MethodInfo pour Math.Pow que nous allons traiter d'une manière spéciale tout au long. Puis je prends la dérivée par rapport à certains ParameterExpression spécifié. Termes dans le résultat où le paramètre de droite à une invocation de Math.Pow était la constante 2 sont remplacés par le ConstantExpression(2) multiplié par ce qui était le côté gauche (l'invocation de Math.Pow(x,1) est supprimée). Les termes du résultat qui deviennent nuls parce qu'ils étaient constants par rapport à x sont supprimés.
  2. Ensuite, factoriser les instances de certains spécifiques ParameterExpression où ils se produisent en tant que paramètre de gauche à un appel de Math.Pow. Lorsque le côté droit de l'invocation devient un ConstantExpression avec la valeur 1, nous remplaçons l'invocation par le seul ParameterExpression lui-même.

¹ A l'avenir, je voudrais la méthode de prendre un ParameterExpression et retourner un Expression qui permet d'évaluer basé sur ce paramètre. De cette façon, je peux agréger des fonctions générées. Je ne suis pas encore là. ² À l'avenir, j'espère pouvoir publier une bibliothèque générale pour travailler avec les expressions LINQ comme mathématiques symboliques.

+4

+1 pour moi de perdre après 5 lignes ... il doit vraiment être une question intelligente;) –

+0

Je, d'autre part, comprendre toutes les maths et ne sait rien à propos de LINQ! On dirait que vous avez déjà vos algorithmes déjà bien établis. Et bonne chance avec cette bibliothèque! – Cascabel

+0

@Jefromi: Je peux générer une arborescence d'expression très bien. Ce que je veux construire est une manière élégante de transformer les arbres, en les traitant comme des expressions de mathématiques symboliques. :) –

Répondre

6

J'ai écrit les bases de plusieurs fonctions mathématiques symboliques en utilisant le type ExpressionVisitor dans .NET 4. Ce n'est pas parfait, mais cela ressemble à la base d'une solution viable.

  • Symbolic est une classe statique publique exposant des méthodes telles que Expand, Simplify et PartialDerivative
  • ExpandVisitor est un type d'aide interne qui étend les expressions
  • SimplifyVisitor est un type d'aide interne qui simplifie les expressions
  • DerivativeVisitor est- un type d'aide interne qui prend la dérivée d'une expression
  • ListPrintVisitor est un Type auxiliaire interne qui convertit un Expression à une notation de préfixe avec une syntaxe Lisp

Symbolic

public static class Symbolic 
{ 
    public static Expression Expand(Expression expression) 
    { 
     return new ExpandVisitor().Visit(expression); 
    } 

    public static Expression Simplify(Expression expression) 
    { 
     return new SimplifyVisitor().Visit(expression); 
    } 

    public static Expression PartialDerivative(Expression expression, ParameterExpression parameter) 
    { 
     bool totalDerivative = false; 
     return new DerivativeVisitor(parameter, totalDerivative).Visit(expression); 
    } 

    public static string ToString(Expression expression) 
    { 
     ConstantExpression result = (ConstantExpression)new ListPrintVisitor().Visit(expression); 
     return result.Value.ToString(); 
    } 
} 

expansibles expressions avec ExpandVisitor

internal class ExpandVisitor : ExpressionVisitor 
{ 
    protected override Expression VisitBinary(BinaryExpression node) 
    { 
     var left = Visit(node.Left); 
     var right = Visit(node.Right); 

     if (node.NodeType == ExpressionType.Multiply) 
     { 
      Expression[] leftNodes = GetAddedNodes(left).ToArray(); 
      Expression[] rightNodes = GetAddedNodes(right).ToArray(); 
      var result = 
       leftNodes 
       .SelectMany(x => rightNodes.Select(y => Expression.Multiply(x, y))) 
       .Aggregate((sum, term) => Expression.Add(sum, term)); 

      return result; 
     } 

     if (node.Left == left && node.Right == right) 
      return node; 

     return Expression.MakeBinary(node.NodeType, left, right, node.IsLiftedToNull, node.Method, node.Conversion); 
    } 

    /// <summary> 
    /// Treats the <paramref name="node"/> as the sum (or difference) of one or more child nodes and returns the 
    /// the individual addends in the sum. 
    /// </summary> 
    private static IEnumerable<Expression> GetAddedNodes(Expression node) 
    { 
     BinaryExpression binary = node as BinaryExpression; 
     if (binary != null) 
     { 
      switch (binary.NodeType) 
      { 
      case ExpressionType.Add: 
       foreach (var n in GetAddedNodes(binary.Left)) 
        yield return n; 

       foreach (var n in GetAddedNodes(binary.Right)) 
        yield return n; 

       yield break; 

      case ExpressionType.Subtract: 
       foreach (var n in GetAddedNodes(binary.Left)) 
        yield return n; 

       foreach (var n in GetAddedNodes(binary.Right)) 
        yield return Expression.Negate(n); 

       yield break; 

      default: 
       break; 
      } 
     } 

     yield return node; 
    } 
} 

Prendre un dérivé avec DerivativeVisitor

internal class DerivativeVisitor : ExpressionVisitor 
{ 
    private ParameterExpression _parameter; 
    private bool _totalDerivative; 

    public DerivativeVisitor(ParameterExpression parameter, bool totalDerivative) 
    { 
     if (_totalDerivative) 
      throw new NotImplementedException(); 

     _parameter = parameter; 
     _totalDerivative = totalDerivative; 
    } 

    protected override Expression VisitBinary(BinaryExpression node) 
    { 
     switch (node.NodeType) 
     { 
     case ExpressionType.Add: 
     case ExpressionType.Subtract: 
      return Expression.MakeBinary(node.NodeType, Visit(node.Left), Visit(node.Right)); 

     case ExpressionType.Multiply: 
      return Expression.Add(Expression.Multiply(node.Left, Visit(node.Right)), Expression.Multiply(Visit(node.Left), node.Right)); 

     case ExpressionType.Divide: 
      return Expression.Divide(Expression.Subtract(Expression.Multiply(Visit(node.Left), node.Right), Expression.Multiply(node.Left, Visit(node.Right))), Expression.Power(node.Right, Expression.Constant(2))); 

     case ExpressionType.Power: 
      if (node.Right is ConstantExpression) 
      { 
       return Expression.Multiply(node.Right, Expression.Multiply(Visit(node.Left), Expression.Subtract(node.Right, Expression.Constant(1)))); 
      } 
      else if (node.Left is ConstantExpression) 
      { 
       return Expression.Multiply(node, MathExpressions.Log(node.Left)); 
      } 
      else 
      { 
       return Expression.Multiply(node, Expression.Add(
        Expression.Multiply(Visit(node.Left), Expression.Divide(node.Right, node.Left)), 
        Expression.Multiply(Visit(node.Right), MathExpressions.Log(node.Left)) 
        )); 
      } 

     default: 
      throw new NotImplementedException(); 
     } 
    } 

    protected override Expression VisitConstant(ConstantExpression node) 
    { 
     return MathExpressions.Zero; 
    } 

    protected override Expression VisitInvocation(InvocationExpression node) 
    { 
     MemberExpression memberExpression = node.Expression as MemberExpression; 
     if (memberExpression != null) 
     { 
      var member = memberExpression.Member; 
      if (member.DeclaringType != typeof(Math)) 
       throw new NotImplementedException(); 

      switch (member.Name) 
      { 
      case "Log": 
       return Expression.Divide(Visit(node.Expression), node.Expression); 

      case "Log10": 
       return Expression.Divide(Visit(node.Expression), Expression.Multiply(Expression.Constant(Math.Log(10)), node.Expression)); 

      case "Exp": 
      case "Sin": 
      case "Cos": 
      default: 
       throw new NotImplementedException(); 
      } 
     } 

     throw new NotImplementedException(); 
    } 

    protected override Expression VisitParameter(ParameterExpression node) 
    { 
     if (node == _parameter) 
      return MathExpressions.One; 

     return MathExpressions.Zero; 
    } 
} 

expressions avec simplificatrices SimplifyVisitor

internal class SimplifyVisitor : ExpressionVisitor 
{ 
    protected override Expression VisitBinary(BinaryExpression node) 
    { 
     var left = Visit(node.Left); 
     var right = Visit(node.Right); 

     ConstantExpression leftConstant = left as ConstantExpression; 
     ConstantExpression rightConstant = right as ConstantExpression; 
     if (leftConstant != null && rightConstant != null 
      && (leftConstant.Value is double) && (rightConstant.Value is double)) 
     { 
      double leftValue = (double)leftConstant.Value; 
      double rightValue = (double)rightConstant.Value; 

      switch (node.NodeType) 
      { 
      case ExpressionType.Add: 
       return Expression.Constant(leftValue + rightValue); 
      case ExpressionType.Subtract: 
       return Expression.Constant(leftValue - rightValue); 
      case ExpressionType.Multiply: 
       return Expression.Constant(leftValue * rightValue); 
      case ExpressionType.Divide: 
       return Expression.Constant(leftValue/rightValue); 
      default: 
       throw new NotImplementedException(); 
      } 
     } 

     switch (node.NodeType) 
     { 
     case ExpressionType.Add: 
      if (IsZero(left)) 
       return right; 
      if (IsZero(right)) 
       return left; 
      break; 

     case ExpressionType.Subtract: 
      if (IsZero(left)) 
       return Expression.Negate(right); 
      if (IsZero(right)) 
       return left; 
      break; 

     case ExpressionType.Multiply: 
      if (IsZero(left) || IsZero(right)) 
       return MathExpressions.Zero; 
      if (IsOne(left)) 
       return right; 
      if (IsOne(right)) 
       return left; 
      break; 

     case ExpressionType.Divide: 
      if (IsZero(right)) 
       throw new DivideByZeroException(); 
      if (IsZero(left)) 
       return MathExpressions.Zero; 
      if (IsOne(right)) 
       return left; 
      break; 

     default: 
      throw new NotImplementedException(); 
     } 

     return Expression.MakeBinary(node.NodeType, left, right); 
    } 

    protected override Expression VisitUnary(UnaryExpression node) 
    { 
     var operand = Visit(node.Operand); 

     ConstantExpression operandConstant = operand as ConstantExpression; 
     if (operandConstant != null && (operandConstant.Value is double)) 
     { 
      double operandValue = (double)operandConstant.Value; 

      switch (node.NodeType) 
      { 
      case ExpressionType.Negate: 
       if (operandValue == 0.0) 
        return MathExpressions.Zero; 

       return Expression.Constant(-operandValue); 

      default: 
       throw new NotImplementedException(); 
      } 
     } 

     switch (node.NodeType) 
     { 
     case ExpressionType.Negate: 
      if (operand.NodeType == ExpressionType.Negate) 
      { 
       return ((UnaryExpression)operand).Operand; 
      } 

      break; 

     default: 
      throw new NotImplementedException(); 
     } 

     return Expression.MakeUnary(node.NodeType, operand, node.Type); 
    } 

    private static bool IsZero(Expression expression) 
    { 
     ConstantExpression constant = expression as ConstantExpression; 
     if (constant != null) 
     { 
      if (constant.Value.Equals(0.0)) 
       return true; 
     } 

     return false; 
    } 

    private static bool IsOne(Expression expression) 
    { 
     ConstantExpression constant = expression as ConstantExpression; 
     if (constant != null) 
     { 
      if (constant.Value.Equals(1.0)) 
       return true; 
     } 

     return false; 
    } 
} 

expressions de mise en forme pour l'affichage avec ListPrintVisitor

internal class ListPrintVisitor : ExpressionVisitor 
{ 
    protected override Expression VisitBinary(BinaryExpression node) 
    { 
     string op = null; 

     switch (node.NodeType) 
     { 
     case ExpressionType.Add: 
      op = "+"; 
      break; 
     case ExpressionType.Subtract: 
      op = "-"; 
      break; 
     case ExpressionType.Multiply: 
      op = "*"; 
      break; 
     case ExpressionType.Divide: 
      op = "/"; 
      break; 
     default: 
      throw new NotImplementedException(); 
     } 

     var left = Visit(node.Left); 
     var right = Visit(node.Right); 
     string result = string.Format("({0} {1} {2})", op, ((ConstantExpression)left).Value, ((ConstantExpression)right).Value); 
     return Expression.Constant(result); 
    } 

    protected override Expression VisitConstant(ConstantExpression node) 
    { 
     if (node.Value is string) 
      return node; 

     return Expression.Constant(node.Value.ToString()); 
    } 

    protected override Expression VisitParameter(ParameterExpression node) 
    { 
     return Expression.Constant(node.Name); 
    } 
} 

Test des résultats

[TestMethod] 
public void BasicSymbolicTest() 
{ 
    ParameterExpression x = Expression.Parameter(typeof(double), "x"); 
    Expression linear = Expression.Add(Expression.Constant(3.0), x); 
    Assert.AreEqual("(+ 3 x)", Symbolic.ToString(linear)); 

    Expression quadratic = Expression.Multiply(linear, Expression.Add(Expression.Constant(2.0), x)); 
    Assert.AreEqual("(* (+ 3 x) (+ 2 x))", Symbolic.ToString(quadratic)); 

    Expression expanded = Symbolic.Expand(quadratic); 
    Assert.AreEqual("(+ (+ (+ (* 3 2) (* 3 x)) (* x 2)) (* x x))", Symbolic.ToString(expanded)); 
    Assert.AreEqual("(+ (+ (+ 6 (* 3 x)) (* x 2)) (* x x))", Symbolic.ToString(Symbolic.Simplify(expanded))); 

    Expression derivative = Symbolic.PartialDerivative(expanded, x); 
    Assert.AreEqual("(+ (+ (+ (+ (* 3 0) (* 0 2)) (+ (* 3 1) (* 0 x))) (+ (* x 0) (* 1 2))) (+ (* x 1) (* 1 x)))", Symbolic.ToString(derivative)); 

    Expression simplified = Symbolic.Simplify(derivative); 
    Assert.AreEqual("(+ 5 (+ x x))", Symbolic.ToString(simplified)); 
}